UAP-Data-Analysis-Tool / utils /embedding_cache.py
Ashoka74's picture
Deploy current work to HF Space (slim)
a1aef88
Raw
History Blame Contribute Delete
11.2 kB
"""
Embedding cache management for UAP Data Analysis Tool
Caches embeddings per column to avoid recomputation
"""
import pandas as pd
import numpy as np
import streamlit as st
import hashlib
import pickle
import os
import json
from typing import Optional, Dict, Any, Tuple
import logging
from pathlib import Path
import time
# Module-level logger, named after this module via __name__
logger = logging.getLogger(__name__)
class EmbeddingCacheManager:
    """Manages caching of embeddings for text columns.

    Embeddings are held in an in-memory dict for the lifetime of the instance
    and persisted as one pickle file per cache key inside ``cache_dir``. A
    ``metadata.json`` file in the same directory records what has been cached.

    NOTE: cache files are read back with ``pickle.load``. Only point
    ``cache_dir`` at a directory whose contents you trust, because unpickling
    untrusted data can execute arbitrary code.
    """

    # Default cache directory (relative to the current working directory)
    CACHE_DIR = Path(".embedding_cache")

    # Upper bound on the human-readable column/model parts of a cache key, so
    # the resulting file name stays well below common file-name length limits.
    _MAX_NAME_CHARS = 64

    def __init__(self, cache_dir: Optional[Path] = None):
        """Initialize the embedding cache manager.

        Args:
            cache_dir: Directory for cache files; falls back to ``CACHE_DIR``
                when omitted. A plain string path is accepted as well.
        """
        self.cache_dir = Path(cache_dir) if cache_dir else self.CACHE_DIR
        # parents=True so a nested cache directory can be created in one go
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # In-memory cache for current session
        self._memory_cache: Dict[str, np.ndarray] = {}

        # Metadata file to track cached embeddings
        self.metadata_file = self.cache_dir / "metadata.json"
        self.metadata = self._load_metadata()

    def _load_metadata(self) -> Dict[str, Any]:
        """Load cache metadata; return an empty dict if missing or unreadable."""
        try:
            with open(self.metadata_file, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except FileNotFoundError:
            # Nothing has been cached yet
            return {}
        except (OSError, ValueError) as e:
            # ValueError covers both malformed JSON and decoding errors
            logger.error(f"Error loading metadata: {e}")
            return {}
        if not isinstance(loaded, dict):
            # Every other method treats the metadata as a mapping
            logger.error("Error loading metadata: top-level JSON value is not an object")
            return {}
        return loaded

    def _save_metadata(self) -> None:
        """Save cache metadata (best effort; failures are logged)."""
        try:
            # Serialize before touching the file so that a non-serializable
            # entry cannot leave a truncated metadata.json behind.
            payload = json.dumps(self.metadata, indent=2)
            with open(self.metadata_file, 'w', encoding='utf-8') as f:
                f.write(payload)
        except (OSError, TypeError, ValueError) as e:
            logger.error(f"Error saving metadata: {e}")

    def _cache_path(self, cache_key: str) -> Path:
        """Return the pickle file path that belongs to ``cache_key``."""
        return self.cache_dir / f"{cache_key}.pkl"

    @staticmethod
    def _safe_name(name: Any, max_chars: int = 64) -> str:
        """Reduce ``name`` to characters that are safe inside a file name."""
        cleaned = ''.join(
            ch if (ch.isalnum() or ch in '-.') else '_'
            for ch in str(name)
        )
        return cleaned[:max_chars]

    @staticmethod
    def _remove_file(path: Path) -> None:
        """Delete ``path``; a file that is already gone is not an error."""
        try:
            path.unlink()
        except FileNotFoundError:
            pass

    @staticmethod
    def generate_cache_key(data: pd.Series, column_name: str, model_name: str = "default") -> str:
        """
        Generate a unique, filesystem-safe cache key for column data.

        The key has the form ``<column>_<model>_<hash>``. The column and model
        parts are sanitized (path separators and other special characters
        become ``_``) because the key is used as a file name. The hash covers
        the original column name, the model name and every value of ``data``
        in row order, each length-prefixed, so that series differing in
        length, order, duplicates or value boundaries get different keys.

        Args:
            data: The column data
            column_name: Name of the column
            model_name: Embedding model name

        Returns:
            Unique cache key
        """
        digest = hashlib.md5()

        def _feed(text: str) -> None:
            raw = text.encode('utf-8', 'surrogatepass')
            # Length prefix keeps ("ab", "c") distinct from ("a", "bc")
            digest.update(f"{len(raw)}:".encode('ascii'))
            digest.update(raw)

        # The raw names go into the hash because sanitizing is lossy
        # ("a/b" and "a_b" map to the same readable prefix).
        _feed(str(column_name))
        _feed(str(model_name))
        for value in data:
            _feed(str(value))

        data_hash = digest.hexdigest()[:16]
        limit = EmbeddingCacheManager._MAX_NAME_CHARS
        safe_column = EmbeddingCacheManager._safe_name(column_name, limit)
        safe_model = EmbeddingCacheManager._safe_name(model_name, limit)
        return f"{safe_column}_{safe_model}_{data_hash}"

    def get_embeddings(self,
                       data: pd.Series,
                       column_name: str,
                       model_name: str = "microsoft/harrier-oss-v1-0.6b",
                       compute_func: Optional[callable] = None) -> Optional[np.ndarray]:
        """
        Get embeddings from cache or compute if not exists.

        Lookup order: in-memory cache, then disk cache, then ``compute_func``.

        Args:
            data: The column data
            column_name: Name of the column
            model_name: Embedding model name
            compute_func: Function to compute embeddings if not cached; it is
                called with ``data``

        Returns:
            Embeddings array, or None when nothing is cached and no
            ``compute_func`` was given
        """
        cache_key = self.generate_cache_key(data, column_name, model_name)

        # Check in-memory cache first
        if cache_key in self._memory_cache:
            logger.info(f"Using in-memory cached embeddings for {column_name}")
            return self._memory_cache[cache_key]

        # Check disk cache
        cache_file = self._cache_path(cache_key)
        if cache_file.exists():
            try:
                logger.info(f"Loading cached embeddings from disk for {column_name}")
                with open(cache_file, 'rb') as f:
                    # SECURITY: pickle.load executes code embedded in the
                    # file; the cache directory must be trusted.
                    embeddings = pickle.load(f)
            except Exception as e:
                # Unpickling can fail in many ways; fall through and recompute
                logger.error(f"Error loading cached embeddings: {e}")
            else:
                # Store in memory cache
                self._memory_cache[cache_key] = embeddings
                return embeddings

        # Compute embeddings if not cached and compute function provided
        if compute_func is None:
            return None

        logger.info(f"Computing new embeddings for {column_name}")
        start_time = time.time()
        embeddings = compute_func(data)
        compute_time = time.time() - start_time
        logger.info(f"Computed embeddings in {compute_time:.2f}s")

        # Cache the embeddings (reuses the key computed above)
        self._store(cache_key, embeddings, data, column_name, model_name)
        return embeddings

    def cache_embeddings(self,
                         embeddings: np.ndarray,
                         data: pd.Series,
                         column_name: str,
                         model_name: str = "microsoft/harrier-oss-v1-0.6b") -> None:
        """
        Cache embeddings to disk and memory.

        Args:
            embeddings: The embeddings array
            data: The original column data
            column_name: Name of the column
            model_name: Embedding model name
        """
        cache_key = self.generate_cache_key(data, column_name, model_name)
        self._store(cache_key, embeddings, data, column_name, model_name)

    def _store(self,
               cache_key: str,
               embeddings: np.ndarray,
               data: pd.Series,
               column_name: str,
               model_name: str) -> None:
        """Store embeddings under ``cache_key`` in memory, on disk and in metadata."""
        # Store in memory cache
        self._memory_cache[cache_key] = embeddings

        # Store on disk; persisting is best effort, so failures are only logged
        try:
            with open(self._cache_path(cache_key), 'wb') as f:
                pickle.dump(embeddings, f)

            # Update metadata
            self.metadata[cache_key] = {
                'column_name': column_name,
                'model_name': model_name,
                # A list matches what comes back after a JSON round trip
                'shape': list(np.shape(embeddings)),
                'cached_at': time.time(),
                'data_size': len(data),
                'unique_values': int(data.nunique())
            }
            self._save_metadata()
            logger.info(f"Cached embeddings for {column_name} to disk")
        except Exception as e:
            logger.error(f"Error caching embeddings: {e}")

    def clear_cache(self, column_name: Optional[str] = None) -> None:
        """
        Clear embedding cache.

        Args:
            column_name: If provided, only clear cache for this column
        """
        if column_name:
            # Match on the recorded column name rather than on a key prefix,
            # so clearing "text" does not also remove entries for "text_clean".
            keys_to_remove = [
                key for key, meta in self.metadata.items()
                if isinstance(meta, dict) and meta.get('column_name') == column_name
            ]
            for key in keys_to_remove:
                # Remove from memory
                self._memory_cache.pop(key, None)
                # Remove from disk
                self._remove_file(self._cache_path(key))
                # Remove from metadata
                self.metadata.pop(key, None)
            self._save_metadata()
            logger.info(f"Cleared cache for column: {column_name}")
        else:
            # Clear all cache
            self._memory_cache.clear()
            # Remove all cache files
            for cache_file in self.cache_dir.glob("*.pkl"):
                self._remove_file(cache_file)
            # Clear metadata
            self.metadata.clear()
            self._save_metadata()
            logger.info("Cleared all embedding cache")

    def get_cache_info(self) -> Dict[str, Any]:
        """Get information about cached embeddings.

        Returns:
            Dict with ``total_cached`` (metadata entries), ``memory_cached``
            (entries held in memory), ``disk_size_mb`` (combined size of the
            existing cache files) and ``columns`` (column name -> list of
            ``model`` / ``shape`` / ``cached_at`` records).
        """
        total_bytes = 0
        for key in self.metadata:
            try:
                total_bytes += self._cache_path(key).stat().st_size
            except OSError:
                # File is missing or unreadable; it contributes nothing
                continue

        info = {
            'total_cached': len(self.metadata),
            'memory_cached': len(self._memory_cache),
            'disk_size_mb': total_bytes / 1024 / 1024,
            'columns': {}
        }

        # Group by column name
        for meta in self.metadata.values():
            entries = info['columns'].setdefault(meta['column_name'], [])
            entries.append({
                'model': meta['model_name'],
                'shape': meta['shape'],
                'cached_at': time.strftime('%Y-%m-%d %H:%M:%S',
                                           time.localtime(meta['cached_at']))
            })
        return info
# Module-level singleton; stays None until get_embedding_cache() is first called
_embedding_cache = None


def get_embedding_cache() -> EmbeddingCacheManager:
    """Return the shared EmbeddingCacheManager, creating it on first use."""
    global _embedding_cache
    if _embedding_cache is not None:
        return _embedding_cache
    # First call: build the manager with its default cache directory
    _embedding_cache = EmbeddingCacheManager()
    return _embedding_cache
# Streamlit-specific caching decorator
@st.cache_data(persist="disk", show_spinner=False)
def get_cached_embeddings(data_hash: str,
                          column_name: str,
                          model_name: str,
                          compute_func: callable,
                          data_values: list) -> np.ndarray:
    """
    Streamlit-cached wrapper around the embedding cache manager.

    Args:
        data_hash: Hash of the data (not read in the body; it only appears
            as an argument of the cached function)
        column_name: Column name
        model_name: Model name
        compute_func: Function to compute embeddings; it is called with a
            plain list of values
        data_values: The actual data values

    Returns:
        Embeddings array
    """
    # NOTE(review): st.cache_data presumably hashes every argument, including
    # compute_func and the full data_values list - confirm that the encoder
    # callables passed in are hashable by Streamlit.

    def _encode_series(series):
        # The cache manager hands over a Series; compute_func expects a list
        return compute_func(series.tolist())

    # Convert back to Series and delegate to the shared cache manager
    return get_embedding_cache().get_embeddings(
        pd.Series(data_values),
        column_name,
        model_name,
        _encode_series,
    )
# Helper function for the UAPAnalyzer integration
def compute_embeddings_with_cache(data: pd.Series,
                                  column_name: str,
                                  model_name: str = "microsoft/harrier-oss-v1-0.6b",
                                  encoder_func: callable = None) -> np.ndarray:
    """
    Compute embeddings for a column, going through the caching layers.

    Args:
        data: Column data
        column_name: Name of the column
        model_name: Embedding model name
        encoder_func: Function that takes a list of texts and returns
            embeddings. When omitted, an encoder is chosen from model_name.

    Returns:
        Embeddings array
    """
    if encoder_func is None:
        # Imported lazily, only when no encoder was supplied
        from uap_analyzer import get_embed_model
        from sentence_transformers import SentenceTransformer

        uses_default_model = model_name == "microsoft/harrier-oss-v1-0.6b"
        if uses_default_model:
            # presumably a shared, already-loaded model - verify in uap_analyzer
            encoder_func = get_embed_model().encode
        else:
            encoder_func = SentenceTransformer(
                model_name, model_kwargs={"dtype": "auto"}
            ).encode

    # Fingerprint of the distinct non-null values, handed to the cached wrapper
    distinct_values = data.dropna().unique()
    fingerprint_source = ''.join(map(str, distinct_values))
    data_hash = hashlib.md5(fingerprint_source.encode()).hexdigest()

    # Use Streamlit caching
    return get_cached_embeddings(
        data_hash,
        column_name,
        model_name,
        encoder_func,
        data.tolist(),
    )