# audit_assistant/src/ui_components/word_selection.py
# (commit 000885e — add Word Level Saliency and Experimental techniques)
"""
Word-Level Token Selection for Saliency Maps
This module provides utilities for selecting specific words from a query
to focus saliency map computation on those words only.
"""
import logging
from typing import List, Dict, Tuple, Optional
import torch
import numpy as np
logger = logging.getLogger(__name__)
def tokenize_query_with_word_mapping(
    processor,
    query_text: str
) -> Tuple[List[int], Dict[int, Tuple[int, int]]]:
    """
    Tokenize query and create mapping from word indices to token indices.

    Words are defined by simple whitespace splitting.  Each word maps to the
    inclusive range of token positions whose character offsets overlap the
    word's character span.  This fixes the previous cursor-based scan, which
    only ever recorded the *first* token of a multi-token word (after the
    first match the cursor jumped past the word end, so later tokens of the
    same word could never extend the range).

    Args:
        processor: ColPali processor wrapper; tokenizer is read from
            ``processor.processor.tokenizer``
        query_text: Query text string

    Returns:
        Tuple of:
            - token_ids: List of token IDs
            - word_to_tokens: Dict mapping word_index ->
              (start_token_idx, end_token_idx), both inclusive.  Words with
              no overlapping tokens are omitted.
    """
    tokenizer = processor.processor.tokenizer
    # No special tokens, so offset_mapping aligns 1:1 with query_text chars.
    tokens = tokenizer(query_text, return_offsets_mapping=True, add_special_tokens=False)
    token_ids = tokens['input_ids']
    offsets = tokens['offset_mapping']

    # Split query into words (simple whitespace-based splitting)
    words = query_text.split()
    word_to_tokens: Dict[int, Tuple[int, int]] = {}
    char_pos = 0  # search cursor so repeated words resolve to distinct spans
    for word_idx, word in enumerate(words):
        word_start = query_text.find(word, char_pos)
        if word_start == -1:
            continue
        word_end = word_start + len(word)
        char_pos = word_end
        # Tokens whose [start, end) character span overlaps the word's span.
        overlapping = [
            tok_idx
            for tok_idx, (tok_start, tok_end) in enumerate(offsets)
            if tok_end > word_start and tok_start < word_end
        ]
        if overlapping:
            word_to_tokens[word_idx] = (overlapping[0], overlapping[-1])
    return token_ids, word_to_tokens
def get_token_indices_for_words(
    processor,
    query_text: str,
    selected_word_indices: List[int],
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None
) -> List[int]:
    """
    Get token indices corresponding to selected words.

    SIMPLE APPROACH: Tokenize the query, map words to tokens via character
    offsets, return the sorted de-duplicated token indices.  These indices
    can be used directly to filter the query embedding.

    Args:
        processor: ColPali processor; tokenizer is looked up on
            ``processor.processor.tokenizer`` with a fallback to
            ``processor.tokenizer``
        query_text: Query text string
        selected_word_indices: List of word indices (0-based) to select.
            Out-of-range (including negative) indices are ignored.
        input_ids: Optional (currently unused, for future compatibility)
        attention_mask: Optional (currently unused)

    Returns:
        List of token indices corresponding to selected words; empty list
        when the tokenizer cannot be found or the mapping fails.
    """
    # Locate the tokenizer; some processor wrappers nest it one level deep.
    try:
        tokenizer = processor.processor.tokenizer
    except AttributeError:
        try:
            tokenizer = processor.tokenizer
        except AttributeError:
            logger.error("Could not access tokenizer from processor")
            return []

    # Split query into words
    words = query_text.split()
    # Lazy %-style args avoid formatting cost when INFO logging is disabled.
    logger.info("🔤 Word selection: query has %d words", len(words))
    logger.info(" Selected word indices: %s", selected_word_indices)
    # Guard both ends of the range: the old `i < len(words)` check let
    # negative indices through, silently resolving them from the list end.
    logger.info(" Selected words: %s",
                [words[i] for i in selected_word_indices if 0 <= i < len(words)])

    # Simple approach: tokenize and map words to tokens using offsets
    try:
        # No special tokens, so offsets align 1:1 with query_text characters.
        tokens = tokenizer(query_text, return_offsets_mapping=True, add_special_tokens=False)
        offsets = tokens['offset_mapping']
        token_ids = tokens['input_ids']
        logger.info(" Tokenization: %d tokens", len(token_ids))

        # Build word -> [token indices] mapping by character-span overlap.
        word_to_tokens: Dict[int, List[int]] = {}
        current_char = 0  # search cursor keeps repeated words distinct
        for word_idx, word in enumerate(words):
            # Find where this word starts in the text
            word_start = query_text.find(word, current_char)
            if word_start == -1:
                continue
            word_end = word_start + len(word)
            # Tokens whose [start, end) char span overlaps this word's span.
            word_tokens = [
                tok_idx
                for tok_idx, (tok_start, tok_end) in enumerate(offsets)
                if tok_end > word_start and tok_start < word_end
            ]
            if word_tokens:
                word_to_tokens[word_idx] = word_tokens
            current_char = word_end
        logger.info(" Word-to-token mapping: %s", word_to_tokens)

        # Collect, de-duplicate, and sort token indices for selected words.
        result_indices = sorted({
            tok_idx
            for word_idx in selected_word_indices
            for tok_idx in word_to_tokens.get(word_idx, [])
        })
        logger.info(" Result token indices: %s", result_indices)
        return result_indices
    except Exception as e:
        # Best-effort: a mapping failure degrades to "no selection" rather
        # than crashing the caller (UI path).
        logger.error("Error in word-to-token mapping: %s", e)
        import traceback
        logger.debug(traceback.format_exc())
        return []
def get_word_boundaries(query_text: str) -> List[Tuple[int, int, str]]:
    """
    Get word boundaries (start, end, word) for a query.

    Words come from whitespace splitting; each word's character span is
    located with a forward-moving cursor so repeated words map to their
    own occurrences in order.

    Args:
        query_text: Query text string

    Returns:
        List of tuples (start_char, end_char, word)
    """
    boundaries: List[Tuple[int, int, str]] = []
    cursor = 0
    for word in query_text.split():
        start = query_text.find(word, cursor)
        if start == -1:
            # Should not happen for whitespace-split words; skip defensively.
            continue
        cursor = start + len(word)
        boundaries.append((start, cursor, word))
    return boundaries