Spaces:
Sleeping
Sleeping
| """ | |
| ColBERT embeddings cache for test set documents. | |
| Provides O(1) lookup for ColBERT embeddings during late interaction. | |
| """ | |
| import json | |
| import numpy as np | |
| from pathlib import Path | |
| from typing import Dict, Optional, Any | |
class ColBERTCache:
    """Cache for ColBERT embeddings of test set documents.

    The cache is populated once, at construction time, from a JSON file
    located under ``outputs/caches``. Each top-level key of that file maps to
    a record with the fields ``min``, ``max``, ``shape`` and ``embedding``
    (the embedding stored as 8-bit quantized values). Records are expanded
    back to float arrays on load and kept in ``embeddings_cache``.

    Loading is best-effort: a missing or unreadable file leaves the cache
    empty and prints a message instead of raising.
    """

    def __init__(self, cache_file: str = "test_set_colbert_cache.json"):
        """Resolve the cache path and load its contents immediately.

        Args:
            cache_file: File name (or path) joined onto ``outputs/caches``.
                Per ``pathlib`` joining rules, an absolute path replaces the
                ``outputs/caches`` prefix entirely.
        """
        self.cache_file = Path("outputs/caches") / cache_file
        self.embeddings_cache: Dict[str, np.ndarray] = {}
        self._load_cache()

    @staticmethod
    def _dequantize(data: Dict[str, Any]) -> np.ndarray:
        """Rebuild one embedding from its quantized cache record.

        Args:
            data: Record holding ``min``, ``max``, ``shape`` and the uint8
                values under ``embedding``.

        Returns:
            The values rescaled from [0, 255] to [min, max] and reshaped to
            ``data['shape']``.
        """
        embedding_min = data['min']
        embedding_max = data['max']
        quantized_embedding = np.array(data['embedding'], dtype=np.uint8)
        # Undo the 8-bit quantization: normalize to [0, 1], then stretch
        # back onto the stored value range.
        reconstructed = (quantized_embedding.astype(np.float32) / 255.0) * (embedding_max - embedding_min) + embedding_min
        return reconstructed.reshape(data['shape'])

    def _load_cache(self) -> None:
        """Load embeddings from the cache file into ``embeddings_cache``.

        Prints a hint and returns without changes when the file does not
        exist. Any error while reading or decoding the file is reported via
        ``print`` and leaves ``embeddings_cache`` empty.
        """
        if not self.cache_file.exists():
            print(f"β οΈ ColBERT cache not found: {self.cache_file}")
            print("π‘ Run 'python precalculate_test_set_colbert.py' to create cache")
            return

        print(f"π Loading ColBERT cache from {self.cache_file}...")
        try:
            # Explicit UTF-8: the keys are text and may be non-ASCII, so
            # decoding must not depend on the platform's default encoding.
            with open(self.cache_file, 'r', encoding='utf-8') as f:
                cache_data = json.load(f)
            # Decode into a temporary dict so a failure part-way through
            # never leaves the cache half-filled.
            loaded = {
                doc_id: self._dequantize(data)
                for doc_id, data in cache_data.items()
            }
        except Exception as e:
            # Deliberate best-effort boundary: the rest of the program is
            # expected to cope with an empty cache rather than crash here.
            print(f"β Error loading ColBERT cache: {e}")
            self.embeddings_cache = {}
            return

        self.embeddings_cache.update(loaded)
        print(f"β Loaded {len(self.embeddings_cache)} ColBERT embeddings from cache")

    def get_embedding(self, document_text: str) -> Optional[np.ndarray]:
        """Get the ColBERT embedding for a document via a dict lookup.

        Args:
            document_text: Key to look up; must match a top-level key of the
                cache file exactly.

        Returns:
            The cached array, or ``None`` when the key is not present. The
            array is the cached object itself, not a copy.
        """
        return self.embeddings_cache.get(document_text)

    def has_embedding(self, document_text: str) -> bool:
        """Check whether an embedding exists for the given key."""
        return document_text in self.embeddings_cache

    def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        Returns:
            A dict with ``total_embeddings`` (number of loaded entries),
            ``cache_file`` (the resolved path as a string) and
            ``cache_exists`` (whether that path exists right now).
        """
        return {
            'total_embeddings': len(self.embeddings_cache),
            'cache_file': str(self.cache_file),
            'cache_exists': self.cache_file.exists()
        }
# Module-wide cache instance; stays None until the first call to
# get_colbert_cache() constructs it.
_colbert_cache = None


def get_colbert_cache() -> ColBERTCache:
    """Return the shared ColBERTCache, constructing it on first use.

    Subsequent calls return the same instance that the first call created.
    """
    global _colbert_cache
    if _colbert_cache is not None:
        return _colbert_cache
    _colbert_cache = ColBERTCache()
    return _colbert_cache