from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from typing import List, Dict, Any
import gc

class FAQEmbedder:
    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """
        Initialize the FAQ embedder with a sentence transformer model.
        The encoder stays on CPU so GPU memory remains free for the LLM.
        """
        print(f"Initializing FAQ Embedder with model: {model_name}")
        # Use CPU for embedding model to save GPU memory for LLM
        self.device = "cpu"
        self.model = SentenceTransformer(model_name, device=self.device)
        self.index = None
        self.faqs = None
        self.embeddings = None
    
    def create_embeddings(self, faqs: List[Dict[str, Any]], batch_size: int = 32) -> None:
        """
        Create embeddings for all FAQ questions and build a FAISS index,
        batching the encoding to limit peak memory usage.
        """
        self.faqs = faqs
        print(f"Creating embeddings for {len(faqs)} FAQs in batches of {batch_size}...")
        
        # Extract questions for embedding
        questions = [faq['question'] for faq in faqs]
        
        # Process in batches to reduce memory usage
        all_embeddings = []
        for i in range(0, len(questions), batch_size):
            batch = questions[i:i+batch_size]
            print(f"Processing batch {i//batch_size + 1}/{(len(questions) + batch_size - 1)//batch_size}")
            
            # Create embeddings for this batch
            batch_embeddings = self.model.encode(batch, show_progress_bar=False, convert_to_numpy=True)
            all_embeddings.append(batch_embeddings)
            
        # Combine all batches
        self.embeddings = np.vstack(all_embeddings).astype('float32')
        
        # Clear memory explicitly
        all_embeddings = None
        gc.collect()
        
        # Create FAISS index
        dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dimension)
        self.index.add(self.embeddings)
        
        print(f"Created embeddings of shape {self.embeddings.shape}")
        print(f"FAISS index contains {self.index.ntotal} vectors")
    
    def retrieve_relevant_faqs(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
        """
        Retrieve top-k relevant FAQs for a given query
        """
        if self.index is None or self.faqs is None or self.embeddings is None:
            raise ValueError("Embeddings not created yet. Call create_embeddings first.")
        
        # Embed the query
        query_embedding = self.model.encode([query], convert_to_numpy=True).astype('float32')
        
        # Search in FAISS
        distances, indices = self.index.search(query_embedding, k)
        
        # Get the relevant FAQs with their similarity scores
        relevant_faqs = []
        for i, idx in enumerate(indices[0]):
            if 0 <= idx < len(self.faqs):  # FAISS pads results with -1 when fewer than k vectors exist
                faq = self.faqs[idx].copy()
                # Convert L2 distance to similarity score (higher is better)
                similarity = 1.0 / (1.0 + distances[0][i])
                faq['similarity'] = float(similarity)  # plain float, easier to serialize
                relevant_faqs.append(faq)
        
        return relevant_faqs
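
# Minimal usage sketch (illustrative; assumes each FAQ dict carries a
# 'question' key, as create_embeddings requires. The sample FAQs below
# are made-up placeholders, not real data).
if __name__ == "__main__":
    sample_faqs = [
        {"question": "How do I reset my password?",
         "answer": "Use the 'Forgot password' link on the login page."},
        {"question": "How can I contact support?",
         "answer": "Email the support team or use the in-app chat."},
        {"question": "Which payment methods are accepted?",
         "answer": "Major credit cards and PayPal."},
    ]

    embedder = FAQEmbedder()
    embedder.create_embeddings(sample_faqs)

    # Query the index; 'similarity' is the inverse-distance score added above
    for faq in embedder.retrieve_relevant_faqs("I forgot my password", k=2):
        print(f"[{faq['similarity']:.3f}] {faq['question']} -> {faq['answer']}")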