from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from typing import List, Dict, Any
import torch
import gc

class FAQEmbedder:
    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """
        Initialize the FAQ embedder with a sentence-transformer model.
        Optimized for memory efficiency.
        """
        print(f"Initializing FAQ Embedder with model: {model_name}")
        # Keep the embedding model on CPU to save GPU memory for the LLM
        self.device = "cpu"
        self.model = SentenceTransformer(model_name, device=self.device)
        self.index = None
        self.faqs = None
        self.embeddings = None

    def create_embeddings(self, faqs: List[Dict[str, Any]], batch_size: int = 32) -> None:
        """
        Create embeddings for all FAQs and build a FAISS index.
        Batching keeps peak memory usage low.
        """
        self.faqs = faqs
        print(f"Creating embeddings for {len(faqs)} FAQs in batches of {batch_size}...")

        # Extract the question text for embedding
        questions = [faq['question'] for faq in faqs]

        # Process in batches to reduce memory usage
        all_embeddings = []
        for i in range(0, len(questions), batch_size):
            batch = questions[i:i + batch_size]
            print(f"Processing batch {i // batch_size + 1}/{(len(questions) + batch_size - 1) // batch_size}")

            # Create embeddings for this batch
            batch_embeddings = self.model.encode(batch, show_progress_bar=False, convert_to_numpy=True)
            all_embeddings.append(batch_embeddings)

        # Combine all batches into one float32 matrix (FAISS expects float32)
        self.embeddings = np.vstack(all_embeddings).astype('float32')

        # Release the per-batch arrays explicitly
        all_embeddings = None
        gc.collect()

        # Build the FAISS index (exact L2 search)
        dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dimension)
        self.index.add(self.embeddings)

        print(f"Created embeddings of shape {self.embeddings.shape}")
        print(f"FAISS index contains {self.index.ntotal} vectors")

    def retrieve_relevant_faqs(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
        """
        Retrieve the top-k most relevant FAQs for a given query.
        """
        if self.index is None or self.faqs is None or self.embeddings is None:
            raise ValueError("Embeddings not created yet. Call create_embeddings first.")

        # Embed the query
        query_embedding = self.model.encode([query], convert_to_numpy=True).astype('float32')

        # Search the FAISS index
        distances, indices = self.index.search(query_embedding, k)

        # Collect the matching FAQs with similarity scores
        relevant_faqs = []
        for i, idx in enumerate(indices[0]):
            # FAISS pads results with -1 when fewer than k vectors exist,
            # so guard both ends of the index range
            if 0 <= idx < len(self.faqs):
                faq = self.faqs[idx].copy()
                # Convert L2 distance to a similarity score (higher is better)
                similarity = 1.0 / (1.0 + distances[0][i])
                faq['similarity'] = similarity
                relevant_faqs.append(faq)

        return relevant_faqs
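

# Usage sketch: a minimal, hypothetical driver for the class above. The sample
# FAQ dicts and the query string are illustrative assumptions, not data from
# the original app.
if __name__ == "__main__":
    sample_faqs = [
        {"question": "How do I reset my password?",
         "answer": "Use the 'Forgot password' link on the login page."},
        {"question": "How can I contact support?",
         "answer": "Email support@example.com."},
    ]

    embedder = FAQEmbedder()  # downloads all-MiniLM-L6-v2 on first use
    embedder.create_embeddings(sample_faqs)

    for faq in embedder.retrieve_relevant_faqs("I forgot my password", k=2):
        print(f"{faq['similarity']:.3f}  {faq['question']} -> {faq['answer']}")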