"""
Word-Level Token Selection for Saliency Maps

This module provides utilities for selecting specific words from a query to
focus saliency map computation on those words only.
"""

import logging
from typing import List, Dict, Tuple, Optional

import torch
import numpy as np

logger = logging.getLogger(__name__)


def _resolve_tokenizer(processor):
    """Return the tokenizer held by ``processor``, or ``None`` if none is found.

    ``processor.processor.tokenizer`` is tried first, then
    ``processor.tokenizer``.
    """
    try:
        return processor.processor.tokenizer
    except AttributeError:
        pass
    try:
        return processor.tokenizer
    except AttributeError:
        return None


def _map_words_to_tokens(query_text: str, offsets) -> Dict[int, List[int]]:
    """Map each whitespace-delimited word index to the token indices overlapping it.

    A token belongs to a word when their half-open character spans
    ``[start, end)`` intersect. This keeps every sub-word token of a word and
    tolerates token offsets that also cover a preceding space. Words with no
    overlapping token are omitted from the result.

    Args:
        query_text: Query text string (the same text that was tokenized).
        offsets: Sequence of ``(start_char, end_char)`` pairs, one per token,
            as returned in a tokenizer's ``offset_mapping``.

    Returns:
        Dict mapping word_index -> list of token indices, in token order.
    """
    word_to_tokens: Dict[int, List[int]] = {}
    # Every word produced by str.split() is located by get_word_boundaries, so
    # enumerate() here yields the same indices as query_text.split().
    for word_idx, (word_start, word_end, _) in enumerate(get_word_boundaries(query_text)):
        overlapping = [
            tok_idx
            for tok_idx, (tok_start, tok_end) in enumerate(offsets)
            if tok_end > word_start and tok_start < word_end
        ]
        if overlapping:
            word_to_tokens[word_idx] = overlapping
    return word_to_tokens


def tokenize_query_with_word_mapping(
    processor, query_text: str
) -> Tuple[List[int], Dict[int, Tuple[int, int]]]:
    """
    Tokenize query and create mapping from word indices to token indices.

    Words are whitespace-delimited (``query_text.split()``). A token is
    assigned to every word whose character span it overlaps.

    Args:
        processor: ColPali processor with tokenizer (``processor.processor.tokenizer``
            or ``processor.tokenizer``)
        query_text: Query text string

    Returns:
        Tuple of:
        - token_ids: List of token IDs (no special tokens added)
        - word_to_tokens: Dict mapping word_index -> (start_token_idx, end_token_idx),
          both ends inclusive. Words with no overlapping token are omitted.

    Raises:
        AttributeError: If no tokenizer can be found on ``processor``.
    """
    tokenizer = _resolve_tokenizer(processor)
    if tokenizer is None:
        raise AttributeError("Could not access tokenizer from processor")

    tokens = tokenizer(query_text, return_offsets_mapping=True, add_special_tokens=False)
    token_ids = tokens['input_ids']
    offsets = tokens['offset_mapping']

    # Assumes token offsets are in increasing order, so a word's overlapping
    # tokens are contiguous and first/last give an inclusive range.
    word_to_tokens = {
        word_idx: (tok_indices[0], tok_indices[-1])
        for word_idx, tok_indices in _map_words_to_tokens(query_text, offsets).items()
    }
    return token_ids, word_to_tokens


def get_token_indices_for_words(
    processor,
    query_text: str,
    selected_word_indices: List[int],
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None
) -> List[int]:
    """
    Get token indices corresponding to selected words.

    SIMPLE APPROACH: Tokenize the query, map words to tokens by character-span
    overlap, return token indices. These indices can be used directly to filter
    the query embedding.

    This is a best-effort helper: any failure is logged and yields ``[]``.

    Args:
        processor: ColPali processor
        query_text: Query text string
        selected_word_indices: List of word indices (0-based) to select.
            Indices outside ``[0, number_of_words)`` are ignored.
        input_ids: Optional (currently unused, for future compatibility)
        attention_mask: Optional (currently unused)

    Returns:
        Sorted, de-duplicated list of token indices corresponding to the
        selected words (empty on error or if nothing matches).
    """
    tokenizer = _resolve_tokenizer(processor)
    if tokenizer is None:
        logger.error("Could not access tokenizer from processor")
        return []

    words = query_text.split()
    num_words = len(words)
    logger.info(f"🔤 Word selection: query has {len(words)} words")
    logger.info(f" Selected word indices: {selected_word_indices}")
    # Only non-negative in-range indices are real selections; a negative index
    # would otherwise show a word here that is never actually mapped.
    logger.info(f" Selected words: {[words[i] for i in selected_word_indices if 0 <= i < num_words]}")
    out_of_range = [i for i in selected_word_indices if not 0 <= i < num_words]
    if out_of_range:
        logger.warning(f" Ignoring out-of-range word indices: {out_of_range}")

    try:
        tokens = tokenizer(query_text, return_offsets_mapping=True, add_special_tokens=False)
        offsets = tokens['offset_mapping']
        token_ids = tokens['input_ids']
        logger.info(f" Tokenization: {len(token_ids)} tokens")

        word_to_tokens = _map_words_to_tokens(query_text, offsets)
        logger.info(f" Word-to-token mapping: {word_to_tokens}")

        # Union of the tokens of every selected word, de-duplicated and sorted.
        result_indices = sorted({
            tok_idx
            for word_idx in selected_word_indices
            for tok_idx in word_to_tokens.get(word_idx, ())
        })
        logger.info(f" Result token indices: {result_indices}")
        return result_indices

    except Exception as e:
        # Deliberate best-effort boundary: report and fall back to no selection.
        logger.error(f"Error in word-to-token mapping: {e}")
        logger.debug("Traceback for word-to-token mapping failure", exc_info=True)
        return []


def get_word_boundaries(query_text: str) -> List[Tuple[int, int, str]]:
    """
    Get word boundaries (start, end, word) for a query.

    Words are the whitespace-delimited pieces from ``query_text.split()``,
    located in order so repeated words get their own positions.

    Args:
        query_text: Query text string

    Returns:
        List of tuples (start_char, end_char, word), where ``end_char`` is
        exclusive (``query_text[start_char:end_char] == word``).
    """
    boundaries = []
    char_pos = 0
    for word in query_text.split():
        # Search from the end of the previous word so duplicates map forward.
        start = query_text.find(word, char_pos)
        if start != -1:
            end = start + len(word)
            boundaries.append((start, end, word))
            char_pos = end
    return boundaries