vichudo's picture
fix
254ca68
Raw
History Blame Contribute Delete
5.29 kB
import time
import numpy as np
from tqdm import tqdm
from openai import OpenAI
from typing import List, Dict, Any, Optional
from src.utils.config import EMBEDDING_MODEL, EMBEDDING_BATCH_SIZE, OPENAI_API_KEY
class TextEmbedder:
    """Generate embeddings for document chunks and queries with OpenAI's embeddings API.

    Every vector produced by this class has exactly ``self.embedding_dim``
    components: API results of a different length are truncated or
    zero-padded, and inputs whose request fails are represented by all-zero
    vectors, so callers always get one fixed-size vector per input.
    """

    # Output dimensions of known OpenAI embedding models; unknown models fall
    # back to _DEFAULT_DIMENSION.
    _MODEL_DIMENSIONS: Dict[str, int] = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
        # Add other models if needed
    }
    _DEFAULT_DIMENSION = 1536

    def __init__(self, model: str = EMBEDDING_MODEL, batch_size: int = EMBEDDING_BATCH_SIZE):
        """
        Initialize the TextEmbedder with the specified embedding model and batch size.

        Args:
            model: The OpenAI embedding model to use
            batch_size: Number of chunks to embed per API call (must be >= 1)

        Raises:
            ValueError: If batch_size is smaller than 1.
        """
        # Validate before creating the client: a zero step would make range()
        # fail later, and a negative one would silently embed nothing.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.model = model
        self.batch_size = batch_size
        self.client = OpenAI(api_key=OPENAI_API_KEY)
        # Length every returned vector is fitted to; can be overridden with set_dimension().
        self.embedding_dim = self._get_model_dimension(model)
        print(f"Initialized TextEmbedder with model {model}, dimension {self.embedding_dim}")

    def _get_model_dimension(self, model_name: str) -> int:
        """Return the embedding dimension for *model_name*, defaulting to 1536 for unknown models."""
        return self._MODEL_DIMENSIONS.get(model_name, self._DEFAULT_DIMENSION)

    def set_dimension(self, dimension: int) -> None:
        """
        Set the embedding dimension explicitly.

        Use this to ensure compatibility with existing FAISS indices. All
        subsequently returned vectors (documents and queries alike) are
        truncated or zero-padded to this length.

        Args:
            dimension: Target vector length (must be >= 1)

        Raises:
            ValueError: If dimension is smaller than 1.
        """
        if dimension < 1:
            raise ValueError(f"dimension must be a positive integer, got {dimension}")
        self.embedding_dim = dimension
        print(f"Explicitly set embedding dimension to {dimension}")

    def _fit_dimension(self, embedding: List[float], log: bool = False) -> List[float]:
        """
        Truncate or zero-pad *embedding* to exactly ``self.embedding_dim`` values.

        Args:
            embedding: Vector returned by the API
            log: If True, print a warning describing any adjustment made

        Returns:
            A list of length ``self.embedding_dim``. Truncation keeps the
            leading components; padding appends zeros.
        """
        actual_dim = len(embedding)
        if actual_dim == self.embedding_dim:
            return list(embedding)
        if log:
            print(f"Warning: OpenAI returned embedding of dimension {actual_dim}, expected {self.embedding_dim}")
        if actual_dim > self.embedding_dim:
            if log:
                print(f"Truncating embedding from {actual_dim} to {self.embedding_dim}")
            # NOTE(review): truncated vectors are not re-normalized; if the index
            # relies on unit-length vectors (e.g. inner product), consider normalizing.
            return list(embedding[:self.embedding_dim])
        if log:
            print(f"Padding embedding from {actual_dim} to {self.embedding_dim}")
        return list(embedding) + [0.0] * (self.embedding_dim - actual_dim)

    def get_embedding_for_text(self, text: str) -> List[float]:
        """
        Generate an embedding for a single text.

        Args:
            text: The text to embed

        Returns:
            The embedding as a list of ``self.embedding_dim`` floats, or an
            all-zero vector if the API call fails.
        """
        try:
            response = self.client.embeddings.create(
                input=[text],
                model=self.model
            )
            return self._fit_dimension(response.data[0].embedding, log=True)
        except Exception as e:
            print(f"Error generating embedding: {e}")
            return [0.0] * self.embedding_dim

    def get_embeddings_for_texts(self, texts: List[str]) -> List[List[float]]:
        """
        Compute embeddings for a list of texts using batched API calls.

        Args:
            texts: List of text chunks to embed

        Returns:
            One embedding vector (length ``self.embedding_dim``) per input text,
            in input order. Texts in a batch whose request fails get all-zero
            placeholder vectors so positions stay aligned with ``texts``.
        """
        embeddings: List[List[float]] = []
        total = len(texts)
        mismatch_reported = False
        for start in tqdm(range(0, total, self.batch_size), desc="Embedding chunks"):
            batch = texts[start:start + self.batch_size]
            try:
                response = self.client.embeddings.create(
                    input=batch,
                    model=self.model
                )
                vectors = [item.embedding for item in response.data]
                # A short or long response would shift every later embedding onto
                # the wrong text, so treat it as a failed batch instead.
                if len(vectors) != len(batch):
                    raise ValueError(f"expected {len(batch)} embeddings, got {len(vectors)}")
            except Exception as e:
                print(f"Error embedding batch starting at index {start}: {e}")
                # Append placeholder zero vectors for failed texts
                embeddings.extend([0.0] * self.embedding_dim for _ in batch)
            else:
                for vector in vectors:
                    # Report a dimension mismatch once per call rather than per vector.
                    needs_fit = len(vector) != self.embedding_dim
                    embeddings.append(self._fit_dimension(vector, log=needs_fit and not mismatch_reported))
                    mismatch_reported = mismatch_reported or needs_fit
            # Brief pause between calls to avoid rate limits; none is needed after the last batch.
            if start + self.batch_size < total:
                time.sleep(0.2)
        return embeddings

    def get_query_embedding(self, query: str) -> np.ndarray:
        """
        Generate embedding for a query string and return as numpy array.

        Args:
            query: The query text to embed

        Returns:
            A float32 array of shape ``(1, self.embedding_dim)``; all zeros if
            the API call fails.
        """
        try:
            q_response = self.client.embeddings.create(
                input=[query],
                model=self.model
            )
            vector = self._fit_dimension(q_response.data[0].embedding, log=True)
            # Single-row 2D float32 matrix, presumably as expected by the FAISS search caller.
            return np.asarray(vector, dtype='float32').reshape(1, -1)
        except Exception as e:
            print(f"Error creating embedding for query: {e}")
            import traceback
            traceback.print_exc()
            return np.zeros((1, self.embedding_dim), dtype='float32')