Spaces:
Sleeping
Sleeping
| """ | |
| Embedding cache management for UAP Data Analysis Tool | |
| Caches embeddings per column to avoid recomputation | |
| """ | |
| import pandas as pd | |
| import numpy as np | |
| import streamlit as st | |
| import hashlib | |
| import pickle | |
| import os | |
| import json | |
| from typing import Optional, Dict, Any, Tuple | |
| import logging | |
| from pathlib import Path | |
| import time | |
| logger = logging.getLogger(__name__) | |
class EmbeddingCacheManager:
    """Two-level (memory + disk) cache for per-column text embeddings.

    Each entry is identified by a cache key built from the column name, the
    model name and a hash of the column values (see ``generate_cache_key``).
    Embeddings are kept in an in-memory dict for the lifetime of the instance
    and pickled to ``cache_dir``; a ``metadata.json`` file in the same
    directory records what has been cached.
    """

    # Default cache directory (relative to the current working directory)
    CACHE_DIR = Path(".embedding_cache")

    # Characters that are replaced when a cache key is turned into a file name.
    # "/" matters most: model names such as "microsoft/harrier-oss-v1-0.6b"
    # would otherwise point into a sub-directory that does not exist.
    _UNSAFE_FILENAME_CHARS = frozenset('/\\:*?"<>|')

    def __init__(self, cache_dir: Optional[Path] = None):
        """Initialize the embedding cache manager.

        Args:
            cache_dir: Directory for the pickled embeddings and metadata.
                Falls back to ``CACHE_DIR`` when omitted. Created (including
                missing parents) if it does not exist.
        """
        self.cache_dir = Path(cache_dir) if cache_dir else self.CACHE_DIR
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        # In-memory cache for current session
        self._memory_cache: Dict[str, np.ndarray] = {}
        # Metadata file to track cached embeddings
        self.metadata_file = self.cache_dir / "metadata.json"
        self.metadata = self._load_metadata()

    def _load_metadata(self) -> Dict[str, Any]:
        """Load cache metadata.

        Returns:
            The parsed metadata dict, or an empty dict when the file is
            missing, unreadable, or does not contain a JSON object.
        """
        if not self.metadata_file.exists():
            return {}
        try:
            with open(self.metadata_file, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except Exception as e:
            logger.error(f"Error loading metadata: {e}")
            return {}
        if not isinstance(loaded, dict):
            # Every other method treats the metadata as a key -> dict mapping.
            logger.error("Error loading metadata: expected a JSON object")
            return {}
        return loaded

    def _save_metadata(self) -> None:
        """Save cache metadata (best effort: failures are logged, not raised)."""
        try:
            with open(self.metadata_file, 'w', encoding='utf-8') as f:
                json.dump(self.metadata, f, indent=2)
        except Exception as e:
            logger.error(f"Error saving metadata: {e}")

    def _cache_file(self, cache_key: str) -> Path:
        """Return the pickle path for ``cache_key`` inside ``cache_dir``.

        Keys without problematic characters map to ``<key>.pkl``. Otherwise
        those characters are replaced by ``_`` and a short hash of the raw key
        is appended so that distinct keys keep distinct files.
        """
        safe_name = "".join(
            "_" if (ch in self._UNSAFE_FILENAME_CHARS or ord(ch) < 32) else ch
            for ch in cache_key
        )
        if safe_name != cache_key:
            raw = cache_key.encode("utf-8", "surrogatepass")
            safe_name = f"{safe_name}_{hashlib.md5(raw).hexdigest()[:8]}"
        return self.cache_dir / f"{safe_name}.pkl"

    @staticmethod
    def generate_cache_key(data: pd.Series, column_name: str, model_name: str = "default") -> str:
        """
        Generate a unique cache key for column data

        The hash covers every value in row order (missing values and
        duplicates included), because the cached embeddings have one row per
        input row. Each value is length-prefixed so that ``["ab", "c"]`` and
        ``["a", "bc"]`` do not hash to the same key.

        Args:
            data: The column data
            column_name: Name of the column
            model_name: Embedding model name
        Returns:
            Unique cache key of the form ``<column>_<model>_<16 hex chars>``
        """
        digest = hashlib.md5()
        for value in data:
            encoded = str(value).encode("utf-8", "surrogatepass")
            digest.update(f"{len(encoded)}:".encode("ascii"))
            digest.update(encoded)
        data_hash = digest.hexdigest()[:16]
        # Include column name and model in the key
        return f"{column_name}_{model_name}_{data_hash}"

    def get_embeddings(self,
                       data: pd.Series,
                       column_name: str,
                       model_name: str = "microsoft/harrier-oss-v1-0.6b",
                       compute_func: Optional[callable] = None) -> Optional[np.ndarray]:
        """
        Get embeddings from cache or compute if not exists

        Lookup order: in-memory cache, then disk cache, then ``compute_func``.
        Freshly computed embeddings are stored in both caches.

        Args:
            data: The column data
            column_name: Name of the column
            model_name: Embedding model name
            compute_func: Function to compute embeddings if not cached;
                called with ``data``
        Returns:
            Embeddings array, or None when nothing is cached and no
            ``compute_func`` was given
        """
        cache_key = self.generate_cache_key(data, column_name, model_name)

        # Check in-memory cache first
        if cache_key in self._memory_cache:
            logger.info(f"Using in-memory cached embeddings for {column_name}")
            return self._memory_cache[cache_key]

        # Check disk cache
        cache_file = self._cache_file(cache_key)
        if cache_file.exists():
            try:
                logger.info(f"Loading cached embeddings from disk for {column_name}")
                with open(cache_file, 'rb') as f:
                    # NOTE(security): unpickling can execute arbitrary code;
                    # the cache directory must only hold trusted files.
                    embeddings = pickle.load(f)
                # Store in memory cache
                self._memory_cache[cache_key] = embeddings
                return embeddings
            except Exception as e:
                logger.error(f"Error loading cached embeddings: {e}")
                # Continue to recompute if loading fails

        if compute_func is None:
            return None

        # Compute embeddings if not cached and compute function provided
        logger.info(f"Computing new embeddings for {column_name}")
        start_time = time.time()
        embeddings = compute_func(data)
        compute_time = time.time() - start_time
        logger.info(f"Computed embeddings in {compute_time:.2f}s")

        # Reuse the key computed above instead of hashing the column again
        self._store_embeddings(cache_key, embeddings, data, column_name, model_name)
        return embeddings

    def cache_embeddings(self,
                         embeddings: np.ndarray,
                         data: pd.Series,
                         column_name: str,
                         model_name: str = "microsoft/harrier-oss-v1-0.6b") -> None:
        """
        Cache embeddings to disk and memory

        Args:
            embeddings: The embeddings array
            data: The original column data
            column_name: Name of the column
            model_name: Embedding model name
        """
        cache_key = self.generate_cache_key(data, column_name, model_name)
        self._store_embeddings(cache_key, embeddings, data, column_name, model_name)

    def _store_embeddings(self,
                          cache_key: str,
                          embeddings: np.ndarray,
                          data: pd.Series,
                          column_name: str,
                          model_name: str) -> None:
        """Store ``embeddings`` under ``cache_key`` in memory, on disk and in the metadata.

        Disk and metadata writes are best effort: any failure is logged and the
        in-memory entry is kept.
        """
        # Store in memory cache
        self._memory_cache[cache_key] = embeddings
        # Store on disk
        cache_file = self._cache_file(cache_key)
        try:
            # np.shape also accepts list-like results, so a non-ndarray return
            # value from the compute function does not abort after the pickle
            # has already been written.
            shape = tuple(int(dim) for dim in np.shape(embeddings))
            with open(cache_file, 'wb') as f:
                pickle.dump(embeddings, f)
            # Update metadata
            self.metadata[cache_key] = {
                'column_name': column_name,
                'model_name': model_name,
                'shape': shape,
                'cached_at': time.time(),
                'data_size': len(data),
                'unique_values': int(data.nunique())
            }
            self._save_metadata()
            logger.info(f"Cached embeddings for {column_name} to disk")
        except Exception as e:
            logger.error(f"Error caching embeddings: {e}")

    def clear_cache(self, column_name: Optional[str] = None) -> None:
        """
        Clear embedding cache

        Column-scoped clearing is driven by the metadata: only entries whose
        recorded ``column_name`` equals the argument are removed, so clearing
        ``"col"`` leaves a column named ``"col_extra"`` untouched.

        Args:
            column_name: If provided, only clear cache for this column
        """
        if column_name:
            # Clear specific column
            keys_to_remove = [
                key for key, meta in self.metadata.items()
                if isinstance(meta, dict) and meta.get('column_name') == column_name
            ]
            for key in keys_to_remove:
                # Remove from memory
                self._memory_cache.pop(key, None)
                # Remove from disk
                cache_file = self._cache_file(key)
                if cache_file.exists():
                    cache_file.unlink()
                # Remove from metadata
                self.metadata.pop(key, None)
            self._save_metadata()
            logger.info(f"Cleared cache for column: {column_name}")
        else:
            # Clear all cache
            self._memory_cache.clear()
            # Remove all cache files
            for cache_file in self.cache_dir.glob("*.pkl"):
                cache_file.unlink()
            # Clear metadata
            self.metadata.clear()
            self._save_metadata()
            logger.info("Cleared all embedding cache")

    def get_cache_info(self) -> Dict[str, Any]:
        """Get information about cached embeddings.

        Returns:
            Dict with ``total_cached`` (metadata entries), ``memory_cached``
            (in-memory entries), ``disk_size_mb`` (summed size of the existing
            pickle files listed in the metadata) and ``columns`` (per column
            name, a list of ``model`` / ``shape`` / ``cached_at`` dicts).
        """
        cache_files = (self._cache_file(key) for key in self.metadata)
        disk_size_mb = sum(
            path.stat().st_size / 1024 / 1024
            for path in cache_files
            if path.exists()
        )
        info = {
            'total_cached': len(self.metadata),
            'memory_cached': len(self._memory_cache),
            'disk_size_mb': disk_size_mb,
            'columns': {}
        }
        # Group by column name
        for meta in self.metadata.values():
            info['columns'].setdefault(meta['column_name'], []).append({
                'model': meta['model_name'],
                'shape': meta['shape'],
                'cached_at': time.strftime('%Y-%m-%d %H:%M:%S',
                                           time.localtime(meta['cached_at']))
            })
        return info
# Process-wide cache manager, created on first request.
_embedding_cache = None


def get_embedding_cache() -> EmbeddingCacheManager:
    """Return the shared EmbeddingCacheManager, building it on first use."""
    global _embedding_cache
    if _embedding_cache is not None:
        return _embedding_cache
    _embedding_cache = EmbeddingCacheManager()
    return _embedding_cache
# Wrapper around the shared cache manager that takes plain list/str arguments.
# NOTE(review): the original comment here read "Streamlit-specific caching
# decorator", but no decorator is applied in the visible source — confirm
# whether one was intended.
def get_cached_embeddings(data_hash: str,
                          column_name: str,
                          model_name: str,
                          compute_func: callable,
                          data_values: list) -> np.ndarray:
    """
    Fetch embeddings for a list of values through the shared cache manager
    Args:
        data_hash: Hash of the data (not read inside this function)
        column_name: Column name
        model_name: Model name
        compute_func: Function to compute embeddings; receives a plain list
        data_values: The actual data values
    Returns:
        Whatever the cache manager's get_embeddings returns for these inputs
    """
    # The cache manager works on a Series, so rebuild one from the raw values
    series = pd.Series(data_values)

    def _encode_series(column):
        # compute_func expects a list, while the manager hands over a Series
        return compute_func(column.tolist())

    manager = get_embedding_cache()
    return manager.get_embeddings(series, column_name, model_name, _encode_series)
# Entry point used for the UAPAnalyzer integration
def compute_embeddings_with_cache(data: pd.Series,
                                  column_name: str,
                                  model_name: str = "microsoft/harrier-oss-v1-0.6b",
                                  encoder_func: callable = None) -> np.ndarray:
    """
    Resolve an encoder and fetch embeddings through get_cached_embeddings
    Args:
        data: Column data
        column_name: Name of the column
        model_name: Embedding model name
        encoder_func: Function that takes a list of texts and returns embeddings.
            When omitted, the default model name uses get_embed_model() from
            uap_analyzer; any other name loads a SentenceTransformer.
    Returns:
        Embeddings array
    """
    encode = encoder_func
    if encode is None:
        # Imported lazily so these packages are only needed when no encoder is supplied
        from uap_analyzer import get_embed_model
        from sentence_transformers import SentenceTransformer

        if model_name != "microsoft/harrier-oss-v1-0.6b":
            encode = SentenceTransformer(model_name, model_kwargs={"dtype": "auto"}).encode
        else:
            encode = get_embed_model().encode

    # MD5 over the concatenated unique non-null values, passed on as data_hash
    distinct_values = data.dropna().unique()
    fingerprint = hashlib.md5("".join(map(str, distinct_values)).encode()).hexdigest()

    return get_cached_embeddings(fingerprint, column_name, model_name, encode, data.tolist())