# agentic-defensor / embedder.py
# Author: vichudo
# Commit 8abf329: "Fix module import issues and data loading bugs"
# Size: 3.28 kB
import time
import numpy as np
from tqdm import tqdm
from openai import OpenAI
from typing import List, Dict, Any, Optional
import os
# OpenAI credentials come from the environment; an unset variable becomes "".
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
# Default embedding model, and how many chunks go into a single API request.
EMBEDDING_MODEL = "text-embedding-3-small"
EMBEDDING_BATCH_SIZE = 10
class TextEmbedder:
    """Class for generating embeddings for document chunks using OpenAI's embeddings API.

    API failures are not raised to the caller: the affected texts receive
    zero vectors of length ``embedding_dim``, so every method always returns
    one vector per input text.
    """

    # Default output dimensionality of known OpenAI embedding models. Used to
    # size the zero-vector placeholders so they match real embeddings.
    _MODEL_DIMENSIONS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }
    # Used for models not listed above until a real response reveals the size.
    _FALLBACK_DIMENSION = 1536

    def __init__(self, model: str = EMBEDDING_MODEL, batch_size: int = EMBEDDING_BATCH_SIZE):
        """
        Initialize the TextEmbedder with the specified embedding model and batch size.

        Args:
            model: The OpenAI embedding model to use
            batch_size: Number of chunks to embed per API call (must be >= 1)

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            # A zero step makes range() raise; a negative step silently
            # yields no batches at all, so reject both up front.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.model = model
        self.batch_size = batch_size
        # NOTE(review): OPENAI_API_KEY defaults to "" when the env var is
        # unset; presumably every request then fails and yields zero vectors.
        self.client = OpenAI(api_key=OPENAI_API_KEY)
        # Placeholder size; refreshed from the first successful response.
        self.embedding_dim = self._MODEL_DIMENSIONS.get(model, self._FALLBACK_DIMENSION)

    def _zero_vector(self) -> List[float]:
        """Return a placeholder embedding of the current dimensionality."""
        return [0.0] * self.embedding_dim

    def _request_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed non-empty ``texts`` in one API call and return vectors in input order.

        Raises:
            ValueError: If the API returns a different number of vectors than
                texts sent.
            Exception: Any error raised by the OpenAI client propagates.
        """
        response = self.client.embeddings.create(
            input=texts,
            model=self.model
        )
        # Order by the index the API reports rather than trusting list order.
        ordered = sorted(response.data, key=lambda item: item.index)
        vectors = [item.embedding for item in ordered]
        if len(vectors) != len(texts):
            raise ValueError(
                f"expected {len(texts)} embeddings, received {len(vectors)}"
            )
        if vectors:
            # Keep placeholder vectors the same length as real embeddings.
            self.embedding_dim = len(vectors[0])
        return vectors

    def _embed_batch(self, batch: List[str], start: int) -> List[List[float]]:
        """Embed one batch, substituting zero vectors for empty or failed texts."""
        # The API rejects empty strings; leave them out of the request so one
        # empty chunk cannot make the whole batch fail.
        positions = [pos for pos, text in enumerate(batch) if text]
        vectors: List[Optional[List[float]]] = [None] * len(batch)
        if positions:
            try:
                fetched = self._request_embeddings([batch[pos] for pos in positions])
            except Exception as e:
                print(f"Error embedding batch starting at index {start}: {e}")
            else:
                for pos, vector in zip(positions, fetched):
                    vectors[pos] = vector
        return [vector if vector is not None else self._zero_vector() for vector in vectors]

    def get_embedding_for_text(self, text: str) -> List[float]:
        """Generate embedding for a single text.

        Returns a zero vector if ``text`` is empty or the API call fails.
        """
        if not text:
            # The API rejects empty input, so skip the doomed request.
            return self._zero_vector()
        try:
            return self._request_embeddings([text])[0]
        except Exception as e:
            print(f"Error generating embedding: {e}")
            return self._zero_vector()

    def get_embeddings_for_texts(self, texts: List[str]) -> List[List[float]]:
        """
        Compute embeddings for a list of texts using batched API calls.

        Args:
            texts: List of text chunks to embed

        Returns:
            List of embedding vectors, aligned one-to-one with ``texts``.
            Empty texts and texts in a failed batch get zero vectors.
        """
        embeddings: List[List[float]] = []
        total = len(texts)
        for start in tqdm(range(0, total, self.batch_size), desc="Embedding chunks"):
            batch = texts[start:start + self.batch_size]
            embeddings.extend(self._embed_batch(batch, start))
            # Brief pause to avoid rate limits; pointless after the last batch.
            if start + self.batch_size < total:
                time.sleep(0.2)
        return embeddings

    def get_query_embedding(self, query: str) -> np.ndarray:
        """
        Generate embedding for a query string and return as numpy array.

        Args:
            query: The query text to embed

        Returns:
            float32 numpy array of shape (1, embedding_dim); all zeros if the
            query is empty or the API call fails.
        """
        if not query:
            # The API rejects empty input, so skip the doomed request.
            return np.zeros((1, self.embedding_dim), dtype='float32')
        try:
            vector = self._request_embeddings([query])[0]
        except Exception as e:
            print(f"Error creating embedding for query: {e}")
            return np.zeros((1, self.embedding_dim), dtype='float32')
        return np.array(vector, dtype='float32').reshape(1, -1)