Spaces:
Paused
Paused
"""Embedding Manager for Starfish.

This module provides embedding functionality using FAISS and SentenceTransformers
for semantic similarity search and data deduplication.
"""
| import numpy as np | |
| from typing import List, Dict, Any, Optional, Tuple, Union | |
| import faiss | |
| from sentence_transformers import SentenceTransformer | |
| import pickle | |
| import os | |
| from pathlib import Path | |
| from starfish.common.logger import get_logger | |
# Module-level logger obtained from the project's logging helper.
logger = get_logger(__name__)
class EmbeddingManager:
    """
    Manages embeddings using SentenceTransformers and FAISS for efficient similarity search.

    Features:
    - Text embedding using pre-trained SentenceTransformers models
    - Fast similarity search using FAISS indexing
    - Persistent storage and loading of embeddings
    - Configurable similarity thresholds
    - Support for both exact and approximate nearest neighbor search
    """

    def __init__(
        self,
        model_name: str = "all-MiniLM-L6-v2",
        index_type: str = "flat",
        similarity_threshold: float = 0.85,
        cache_dir: Optional[str] = None,
        device: Optional[str] = None,
    ):
        """
        Initialize the EmbeddingManager.

        Args:
            model_name: SentenceTransformers model name or path
            index_type: Type of FAISS index ('flat', 'ivf', 'hnsw')
            similarity_threshold: Threshold for determining similar items (0-1)
            cache_dir: Directory to cache embeddings and models
            device: Device to run model on ('cpu', 'cuda', 'mps')
        """
        self.model_name = model_name
        self.index_type = index_type
        self.similarity_threshold = similarity_threshold
        self.cache_dir = Path(cache_dir) if cache_dir else Path.home() / ".starfish" / "embeddings"
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Initialize SentenceTransformer model
        logger.info(f"Loading SentenceTransformer model: {model_name}")
        self.model = SentenceTransformer(model_name, device=device)
        self.embedding_dim = self.model.get_sentence_embedding_dimension()

        # FAISS index is created lazily on the first add_texts() call
        self.index = None
        self.metadata = []  # One dict per indexed vector: original text + caller metadata
        self.id_to_index = {}  # Map custom IDs to FAISS row indices

        logger.info(f"EmbeddingManager initialized with {model_name}, dim={self.embedding_dim}")

    def _create_index(self, dimension: int) -> faiss.Index:
        """Create a FAISS index based on the specified type.

        Args:
            dimension: Dimensionality of the vectors the index will hold.

        Returns:
            A new, empty FAISS index.

        Raises:
            ValueError: If ``self.index_type`` is not 'flat', 'ivf', or 'hnsw'.
        """
        if self.index_type == "flat":
            # Exact search with L2 (Euclidean) distance
            return faiss.IndexFlatL2(dimension)
        if self.index_type == "ivf":
            # Inverted file index for faster approximate search (requires training)
            quantizer = faiss.IndexFlatL2(dimension)
            return faiss.IndexIVFFlat(quantizer, dimension, 100)  # 100 clusters
        if self.index_type == "hnsw":
            # Hierarchical Navigable Small World for very fast approximate search
            return faiss.IndexHNSWFlat(dimension, 32)
        raise ValueError(f"Unsupported index type: {self.index_type}")

    def embed_texts(self, texts: List[str], show_progress: bool = True) -> np.ndarray:
        """
        Embed a list of texts using SentenceTransformers.

        Args:
            texts: List of texts to embed
            show_progress: Whether to show progress bar

        Returns:
            float32 numpy array of embeddings with shape (len(texts), embedding_dim).
            Vectors are L2-normalized, so L2 distance maps directly to cosine similarity.
        """
        if not texts:
            # BUG FIX: keep dtype consistent with the non-empty path
            # (np.array([]).reshape(...) produced float64).
            return np.empty((0, self.embedding_dim), dtype=np.float32)

        logger.info(f"Embedding {len(texts)} texts...")
        embeddings = self.model.encode(
            texts,
            convert_to_numpy=True,
            show_progress_bar=show_progress,
            normalize_embeddings=True,  # Normalize for cosine similarity
        )
        return embeddings.astype(np.float32)

    def add_texts(self, texts: List[str], metadata: Optional[List[Dict[str, Any]]] = None, ids: Optional[List[str]] = None) -> List[int]:
        """
        Add texts to the embedding index.

        Args:
            texts: List of texts to add
            metadata: Optional metadata for each text; must have the same length as texts
            ids: Optional custom IDs for each text; must have the same length as texts

        Returns:
            List of internal indices assigned to the texts

        Raises:
            ValueError: If metadata or ids are provided with a length that does not
                match len(texts) (previously this silently misaligned entries).
        """
        if not texts:
            return []
        if metadata is not None and len(metadata) != len(texts):
            raise ValueError(f"metadata length {len(metadata)} does not match texts length {len(texts)}")
        if ids is not None and len(ids) != len(texts):
            raise ValueError(f"ids length {len(ids)} does not match texts length {len(texts)}")

        # Generate embeddings
        embeddings = self.embed_texts(texts)

        # Initialize index lazily on the first insert
        if self.index is None:
            self.index = self._create_index(self.embedding_dim)
            if self.index_type == "ivf":
                # IVF indexes must be trained before vectors can be added
                if len(embeddings) >= 100:  # Need at least as many points as clusters
                    self.index.train(embeddings)
                else:
                    logger.warning("Not enough data to train IVF index, using flat index instead")
                    self.index = faiss.IndexFlatL2(self.embedding_dim)

        # Add to index
        start_idx = self.index.ntotal
        self.index.add(embeddings)

        # Store metadata. BUG FIX: copy caller-supplied dicts instead of mutating
        # and aliasing them — previously the caller's dicts were modified in place
        # and shared with our internal list.
        if metadata is None:
            metadata = [{"text": text} for text in texts]
        else:
            metadata = [dict(meta) for meta in metadata]
            for i, meta in enumerate(metadata):
                # Ensure metadata includes the original text
                meta.setdefault("text", texts[i])
        self.metadata.extend(metadata)

        # Map custom IDs to the FAISS row indices just assigned
        indices = list(range(start_idx, start_idx + len(texts)))
        if ids:
            for custom_id, row in zip(ids, indices):
                self.id_to_index[custom_id] = row

        logger.info(f"Added {len(texts)} texts to index. Total: {self.index.ntotal}")
        return indices

    def search_similar(self, query_text: str, k: int = 5, threshold: Optional[float] = None) -> List[Dict[str, Any]]:
        """
        Search for similar texts in the index.

        Args:
            query_text: Text to search for
            k: Number of similar items to return
            threshold: Similarity threshold (overrides default)

        Returns:
            List of dictionaries containing similar items with scores and metadata
        """
        if self.index is None or self.index.ntotal == 0:
            logger.warning("Index is empty or not initialized")
            return []

        # Embed query
        query_embedding = self.embed_texts([query_text], show_progress=False)

        # Search
        if self.index_type == "ivf" and hasattr(self.index, "nprobe"):
            self.index.nprobe = min(10, self.index.nlist)  # Search in 10 clusters

        scores, indices = self.index.search(query_embedding, k)

        # IndexFlatL2 reports squared L2 distances; for unit vectors
        # d^2 = 2 - 2*cos, hence cos = 1 - d^2 / 2.
        similarities = 1 - (scores[0] / 2)

        # BUG FIX: `threshold or self.similarity_threshold` discarded an explicit
        # threshold of 0.0 (falsy); compare against None instead.
        if threshold is None:
            threshold = self.similarity_threshold

        results = []
        for idx, similarity in zip(indices[0], similarities):
            if idx != -1 and similarity >= threshold:  # -1 indicates no match found
                result = {
                    "index": int(idx),
                    "similarity": float(similarity),
                    "metadata": self.metadata[idx].copy() if idx < len(self.metadata) else {},
                    "text": self.metadata[idx].get("text", "") if idx < len(self.metadata) else "",
                }
                results.append(result)

        logger.debug(f"Found {len(results)} similar items for query (threshold={threshold})")
        return results

    def find_duplicates(self, texts: List[str], threshold: Optional[float] = None) -> List[List[int]]:
        """
        Find groups of duplicate/similar texts.

        Args:
            texts: List of texts to check for duplicates
            threshold: Similarity threshold for considering items duplicates

        Returns:
            List of lists, where each inner list contains indices of similar texts
        """
        # Same falsy-zero fix as in search_similar: respect an explicit 0.0
        if threshold is None:
            threshold = self.similarity_threshold
        if not texts:
            return []

        # Embed all texts
        embeddings = self.embed_texts(texts, show_progress=True)

        # Temporary exact index so duplicate detection never touches the main index
        temp_index = faiss.IndexFlatL2(self.embedding_dim)
        temp_index.add(embeddings)

        duplicate_groups = []
        processed = set()

        for i, embedding in enumerate(embeddings):
            if i in processed:
                continue

            # Search all rows for items similar to text i
            query_embedding = embedding.reshape(1, -1)
            scores, indices = temp_index.search(query_embedding, len(texts))

            # Squared L2 -> cosine similarity for normalized vectors
            similarities = 1 - (scores[0] / 2)

            similar_indices = []
            for idx, similarity in zip(indices[0], similarities):
                if idx != -1 and similarity >= threshold and idx not in processed:
                    # BUG FIX: cast to plain int so the return type matches
                    # List[List[int]] instead of leaking numpy int64.
                    similar_indices.append(int(idx))
                    processed.add(int(idx))

            # A group of one is just the text matching itself — not a duplicate set
            if len(similar_indices) > 1:
                duplicate_groups.append(similar_indices)

        logger.info(f"Found {len(duplicate_groups)} groups of duplicates")
        return duplicate_groups

    def save_index(self, filepath: str) -> None:
        """Save the FAISS index and metadata to disk.

        Writes two sibling files: ``<filepath>.faiss`` (the raw FAISS index)
        and ``<filepath>.pkl`` (metadata plus configuration).
        """
        if self.index is None:
            logger.warning("No index to save")
            return

        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)

        # Save FAISS index
        faiss.write_index(self.index, str(filepath.with_suffix(".faiss")))

        # Save metadata and configuration
        metadata_file = filepath.with_suffix(".pkl")
        with open(metadata_file, "wb") as f:
            pickle.dump(
                {
                    "metadata": self.metadata,
                    "id_to_index": self.id_to_index,
                    "model_name": self.model_name,
                    "index_type": self.index_type,
                    "similarity_threshold": self.similarity_threshold,
                    "embedding_dim": self.embedding_dim,
                },
                f,
            )

        logger.info(f"Saved index to {filepath}")

    def load_index(self, filepath: str) -> None:
        """Load a FAISS index and metadata from disk.

        SECURITY NOTE: the metadata sidecar is deserialized with pickle, which
        can execute arbitrary code — only load files from trusted sources.

        Raises:
            FileNotFoundError: If ``<filepath>.faiss`` does not exist.
        """
        filepath = Path(filepath)

        # Load FAISS index
        index_file = filepath.with_suffix(".faiss")
        if not index_file.exists():
            raise FileNotFoundError(f"Index file not found: {index_file}")
        self.index = faiss.read_index(str(index_file))

        # Load metadata and configuration (the sidecar is optional)
        metadata_file = filepath.with_suffix(".pkl")
        if metadata_file.exists():
            with open(metadata_file, "rb") as f:
                data = pickle.load(f)
            self.metadata = data.get("metadata", [])
            self.id_to_index = data.get("id_to_index", {})

            # Verify model compatibility
            saved_model = data.get("model_name", self.model_name)
            if saved_model != self.model_name:
                logger.warning(f"Model mismatch: saved={saved_model}, current={self.model_name}")
            # Verify dimension compatibility — a mismatch makes every search meaningless
            saved_dim = data.get("embedding_dim", self.embedding_dim)
            if saved_dim != self.embedding_dim:
                logger.warning(f"Dimension mismatch: saved={saved_dim}, current={self.embedding_dim}")

        logger.info(f"Loaded index from {filepath} ({self.index.ntotal} items)")

    def get_embedding_by_id(self, custom_id: str) -> Optional[np.ndarray]:
        """Get embedding vector by custom ID.

        Returns None when the ID is unknown or out of range for the index.

        NOTE(review): FAISS ``reconstruct`` may require a direct map on IVF
        indexes (``make_direct_map``) — confirm before using this with 'ivf'.
        """
        if custom_id not in self.id_to_index:
            return None

        idx = self.id_to_index[custom_id]
        if self.index is None or idx >= self.index.ntotal:
            return None

        return self.index.reconstruct(idx)

    def get_stats(self) -> Dict[str, Any]:
        """Get statistics about the current index."""
        return {
            "model_name": self.model_name,
            "index_type": self.index_type,
            "embedding_dimension": self.embedding_dim,
            "total_items": self.index.ntotal if self.index else 0,
            "similarity_threshold": self.similarity_threshold,
            "metadata_count": len(self.metadata),
            "custom_ids_count": len(self.id_to_index),
        }