File size: 5,290 Bytes
6a4bd6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254ca68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a4bd6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254ca68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a4bd6f
 
254ca68
 
6a4bd6f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import time
import numpy as np
from tqdm import tqdm
from openai import OpenAI
from typing import List, Dict, Any, Optional

from src.utils.config import EMBEDDING_MODEL, EMBEDDING_BATCH_SIZE, OPENAI_API_KEY

class TextEmbedder:
    """Class for generating embeddings for document chunks using OpenAI's embeddings API.

    Error handling is best-effort: when an API call fails, the public embedding
    methods print the error and return zero vectors of length
    ``self.embedding_dim`` instead of raising, so callers always get exactly
    one vector per input text.

    Note: only ``get_query_embedding`` truncates/pads vectors to
    ``self.embedding_dim``; ``get_embedding_for_text`` and
    ``get_embeddings_for_texts`` return the vectors exactly as the API
    produced them.
    """

    # Output dimensions of the known OpenAI embedding models.
    _MODEL_DIMENSIONS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
        # Add other models if needed
    }

    # Dimension assumed for models missing from ``_MODEL_DIMENSIONS``.
    _DEFAULT_DIMENSION = 1536

    # Pause between consecutive batch requests to avoid rate limits.
    _BATCH_PAUSE_SECONDS = 0.2

    def __init__(self, model: str = EMBEDDING_MODEL, batch_size: int = EMBEDDING_BATCH_SIZE):
        """
        Initialize the TextEmbedder with the specified embedding model and batch size.

        Args:
            model: The OpenAI embedding model to use
            batch_size: Number of chunks to embed per API call
        """
        self.model = model
        self.batch_size = batch_size
        self.client = OpenAI(api_key=OPENAI_API_KEY)

        # Expected vector length; also the length of zero-vector fallbacks.
        self.embedding_dim = self._get_model_dimension(model)
        print(f"Initialized TextEmbedder with model {model}, dimension {self.embedding_dim}")

    def _get_model_dimension(self, model_name: str) -> int:
        """Return the embedding dimension for ``model_name`` (1536 if the model is unknown)."""
        return self._MODEL_DIMENSIONS.get(model_name, self._DEFAULT_DIMENSION)

    def set_dimension(self, dimension: int) -> None:
        """
        Set the embedding dimension explicitly.
        Use this to ensure compatibility with existing FAISS indices.

        Args:
            dimension: The vector length to use for query embeddings and
                zero-vector fallbacks from now on
        """
        self.embedding_dim = dimension
        print(f"Explicitly set embedding dimension to {dimension}")

    def _request_embeddings(self, inputs: List[str]) -> List[List[float]]:
        """
        Call the embeddings API once and return one vector per input, in order.

        Args:
            inputs: Texts to embed in a single API call

        Returns:
            List of embedding vectors, same length as ``inputs``

        Raises:
            ValueError: If the API returns a different number of embeddings
                than inputs were sent
            Exception: Whatever the OpenAI client raises on a failed request
        """
        response = self.client.embeddings.create(
            input=inputs,
            model=self.model
        )
        vectors = [item.embedding for item in response.data]
        # Guard the "one vector per input" invariant callers rely on.
        if len(vectors) != len(inputs):
            raise ValueError(
                f"Expected {len(inputs)} embeddings but received {len(vectors)}"
            )
        return vectors

    def _fit_to_dimension(self, embedding: np.ndarray) -> np.ndarray:
        """Truncate or zero-pad a 1D ``embedding`` so its length equals ``self.embedding_dim``."""
        actual_dim = embedding.shape[0]
        if actual_dim == self.embedding_dim:
            return embedding

        print(f"Warning: OpenAI returned embedding of dimension {actual_dim}, expected {self.embedding_dim}")

        if actual_dim > self.embedding_dim:
            # Keep only the leading components.
            print(f"Truncating embedding from {actual_dim} to {self.embedding_dim}")
            return embedding[:self.embedding_dim]

        # Too short: append zeros up to the expected length.
        print(f"Padding embedding from {actual_dim} to {self.embedding_dim}")
        padding = np.zeros(self.embedding_dim - actual_dim, dtype='float32')
        return np.concatenate([embedding, padding])

    def get_embedding_for_text(self, text: str) -> List[float]:
        """
        Generate embedding for a single text.

        Args:
            text: The text to embed

        Returns:
            The embedding vector, or a zero vector of length
            ``self.embedding_dim`` if the request fails
        """
        try:
            return self._request_embeddings([text])[0]
        except Exception as e:
            # Deliberate best-effort: report and fall back to a placeholder.
            print(f"Error generating embedding: {e}")
            return [0.0] * self.embedding_dim

    def get_embeddings_for_texts(self, texts: List[str]) -> List[List[float]]:
        """
        Compute embeddings for a list of texts using batched API calls.

        Each batch is all-or-nothing: if a batch request fails, every text in
        that batch gets a zero vector, so the result always lines up with
        ``texts`` index by index.

        Args:
            texts: List of text chunks to embed

        Returns:
            List of embedding vectors, one per input text

        Raises:
            ValueError: If ``self.batch_size`` is smaller than 1
        """
        # A negative step would make range() empty and silently drop every text.
        if self.batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {self.batch_size}"
            )

        embeddings: List[List[float]] = []
        batch_starts = range(0, len(texts), self.batch_size)
        final_start = batch_starts[-1] if batch_starts else None

        for i in tqdm(batch_starts, desc="Embedding chunks"):
            batch = texts[i:i + self.batch_size]
            try:
                embeddings.extend(self._request_embeddings(batch))
            except Exception as e:
                print(f"Error embedding batch starting at index {i}: {e}")
                # Append placeholder zero vectors for failed texts
                embeddings.extend([0.0] * self.embedding_dim for _ in batch)
            # Brief pause to avoid rate limits; pointless after the last batch.
            if i != final_start:
                time.sleep(self._BATCH_PAUSE_SECONDS)

        return embeddings

    def get_query_embedding(self, query: str) -> np.ndarray:
        """
        Generate embedding for a query string and return as numpy array.

        The vector is truncated or zero-padded to ``self.embedding_dim`` when
        the API returns a different length.

        Args:
            query: The query text to embed

        Returns:
            float32 numpy array of shape ``(1, self.embedding_dim)``; all zeros
            if the request fails
        """
        try:
            vector = self._request_embeddings([query])[0]
            embedding = self._fit_to_dimension(np.array(vector, dtype='float32'))

            # Return the embedding as a 2D array
            return embedding.reshape(1, -1)
        except Exception as e:
            print(f"Error creating embedding for query: {e}")
            import traceback
            traceback.print_exc()
            return np.zeros((1, self.embedding_dim), dtype='float32')