"""
Word-Level Token Selection for Saliency Maps

This module provides utilities for selecting specific words from a query
to focus saliency map computation on those words only.
"""
import logging
from typing import List, Dict, Tuple, Optional

import torch
import numpy as np

logger = logging.getLogger(__name__)
def tokenize_query_with_word_mapping(
    processor,
    query_text: str
) -> Tuple[List[int], Dict[int, Tuple[int, int]]]:
    """
    Tokenize query and create mapping from word indices to token indices.

    Words are the whitespace-separated pieces of ``query_text``. A token is
    assigned to a word when its character offsets overlap the word's character
    span, so multi-token (subword-split) words map to a contiguous range.

    Args:
        processor: ColPali processor with tokenizer (accessed at
            ``processor.processor.tokenizer``).
        query_text: Query text string.

    Returns:
        Tuple of:
            - token_ids: List of token IDs
            - word_to_tokens: Dict mapping word_index -> (start_token_idx, end_token_idx),
              both indices inclusive. Words with no overlapping token are omitted.
    """
    tokenizer = processor.processor.tokenizer

    # Tokenize with character offsets so tokens can be matched back to words.
    tokens = tokenizer(query_text, return_offsets_mapping=True, add_special_tokens=False)
    token_ids = tokens['input_ids']
    offsets = tokens['offset_mapping']

    # Split query into words (simple whitespace-based splitting).
    words = query_text.split()

    word_to_tokens: Dict[int, Tuple[int, int]] = {}
    char_pos = 0
    for word_idx, word in enumerate(words):
        # Locate this occurrence of the word; searching from char_pos keeps
        # repeated words anchored to their own positions.
        word_start = query_text.find(word, char_pos)
        if word_start == -1:
            continue
        word_end = word_start + len(word)
        char_pos = word_end

        # A token belongs to the word when its [start, end) span overlaps
        # the word's [word_start, word_end) span.
        for token_idx, (tok_start, tok_end) in enumerate(offsets):
            if tok_end > word_start and tok_start < word_end:
                if word_idx not in word_to_tokens:
                    word_to_tokens[word_idx] = (token_idx, token_idx)
                else:
                    # Extend the inclusive token range for this word.
                    word_to_tokens[word_idx] = (word_to_tokens[word_idx][0], token_idx)

    return token_ids, word_to_tokens
def get_token_indices_for_words(
    processor,
    query_text: str,
    selected_word_indices: List[int],
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None
) -> List[int]:
    """
    Get token indices corresponding to selected words.

    SIMPLE APPROACH: Tokenize the query with character offsets, map each
    whitespace-separated word to the tokens whose spans overlap it, and return
    the token indices for the selected words. These indices can be used
    directly to filter the query embedding.

    Args:
        processor: ColPali processor. The tokenizer is looked up at
            ``processor.processor.tokenizer``, falling back to
            ``processor.tokenizer``.
        query_text: Query text string.
        selected_word_indices: List of word indices (0-based) to select.
            Indices outside ``[0, num_words)`` are ignored.
        input_ids: Optional (currently unused, for future compatibility).
        attention_mask: Optional (currently unused).

    Returns:
        Sorted list of unique token indices corresponding to selected words.
        Empty list if the tokenizer is unavailable or mapping fails.
    """
    try:
        tokenizer = processor.processor.tokenizer
    except AttributeError:
        try:
            tokenizer = processor.tokenizer
        except AttributeError:
            logger.error("Could not access tokenizer from processor")
            return []

    # Split query into words.
    words = query_text.split()
    # Drop out-of-range indices up front. The previous `i < len(words)` check
    # let negative indices through, and words[i] raised IndexError for
    # i < -len(words), crashing before the mapping ran.
    valid_selected = [i for i in selected_word_indices if 0 <= i < len(words)]

    # Lazy %-style args avoid formatting cost when INFO is disabled.
    logger.info("🔤 Word selection: query has %d words", len(words))
    logger.info(" Selected word indices: %s", selected_word_indices)
    logger.info(" Selected words: %s", [words[i] for i in valid_selected])

    # Simple approach: tokenize and map words to tokens using offsets.
    try:
        tokens = tokenizer(query_text, return_offsets_mapping=True, add_special_tokens=False)
        offsets = tokens['offset_mapping']
        token_ids = tokens['input_ids']
        logger.info(" Tokenization: %d tokens", len(token_ids))

        # Build word-to-token mapping: a token belongs to a word when their
        # character spans overlap.
        word_to_tokens: Dict[int, List[int]] = {}
        current_char = 0
        for word_idx, word in enumerate(words):
            # Find where this word starts in the text; searching from
            # current_char anchors repeated words to their own occurrence.
            word_start = query_text.find(word, current_char)
            if word_start == -1:
                continue
            word_end = word_start + len(word)
            word_tokens = [
                tok_idx
                for tok_idx, (tok_start, tok_end) in enumerate(offsets)
                if tok_end > word_start and tok_start < word_end
            ]
            if word_tokens:
                word_to_tokens[word_idx] = word_tokens
            current_char = word_end

        logger.info(" Word-to-token mapping: %s", word_to_tokens)

        # Collect unique, sorted token indices for the selected words.
        result_indices = sorted({
            tok_idx
            for word_idx in valid_selected
            if word_idx in word_to_tokens
            for tok_idx in word_to_tokens[word_idx]
        })
        logger.info(" Result token indices: %s", result_indices)
        return result_indices
    except Exception as e:
        # Best-effort: callers treat an empty list as "no word filtering".
        logger.error("Error in word-to-token mapping: %s", e)
        import traceback
        logger.debug(traceback.format_exc())
        return []
def get_word_boundaries(query_text: str) -> List[Tuple[int, int, str]]:
    """
    Compute the character span of each whitespace-separated word.

    Args:
        query_text: Query text string.

    Returns:
        List of tuples (start_char, end_char, word), one per word found,
        in order of appearance.
    """
    boundaries: List[Tuple[int, int, str]] = []
    cursor = 0
    for word in query_text.split():
        # Search from the cursor so repeated words map to their own occurrence.
        start = query_text.find(word, cursor)
        if start == -1:
            continue
        end = start + len(word)
        boundaries.append((start, end, word))
        cursor = end
    return boundaries