Spaces:
Sleeping
Sleeping
File size: 3,282 Bytes
8abf329 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 | import time
import numpy as np
from tqdm import tqdm
from openai import OpenAI
from typing import List, Dict, Any, Optional
import os
# Runtime configuration for the embedder.
# The API key is taken from the OPENAI_API_KEY environment variable and is an
# empty string when that variable is not set.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
# Default embedding model and number of texts sent per embeddings request.
EMBEDDING_MODEL = "text-embedding-3-small"
EMBEDDING_BATCH_SIZE = 10
class TextEmbedder:
    """Generate embeddings for text chunks with OpenAI's embeddings API.

    API failures are handled best-effort: the error is printed and a zero
    vector of length ``embedding_dim`` is substituted, so callers always get
    one vector per input text.
    """

    # Output width of known OpenAI embedding models. Used only to size the
    # zero-vector fallback; unlisted models fall back to 1536 unless an
    # explicit ``embedding_dim`` is passed to the constructor.
    _MODEL_DIMENSIONS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }

    def __init__(
        self,
        model: str = EMBEDDING_MODEL,
        batch_size: int = EMBEDDING_BATCH_SIZE,
        embedding_dim: Optional[int] = None,
    ):
        """
        Initialize the TextEmbedder with the specified embedding model and batch size.

        Args:
            model: The OpenAI embedding model to use
            batch_size: Number of chunks to embed per API call; must be >= 1
            embedding_dim: Length of the zero vector returned when an API call
                fails. Defaults to the known width of ``model`` (1536 if the
                model is not listed in ``_MODEL_DIMENSIONS``).

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        # A non-positive step would make range() raise (0) or silently yield
        # no batches at all (negative) in get_embeddings_for_texts.
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")
        self.model = model
        self.batch_size = batch_size
        self.client = OpenAI(api_key=OPENAI_API_KEY)
        if embedding_dim is None:
            embedding_dim = self._MODEL_DIMENSIONS.get(model, 1536)
        self.embedding_dim = embedding_dim

    def _embed(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` in a single API call and return vectors in input order.

        Raises:
            ValueError: If the response does not contain exactly one
                embedding per input text.
            Exception: Any error raised by the OpenAI client propagates.
        """
        response = self.client.embeddings.create(input=texts, model=self.model)
        # Each returned item carries the position of its input; sort on it so
        # the vectors line up with ``texts`` regardless of response ordering.
        items = sorted(response.data, key=lambda item: item.index)
        if len(items) != len(texts):
            # Guard against silently shifting every later embedding by one.
            raise ValueError(f"expected {len(texts)} embeddings, got {len(items)}")
        return [item.embedding for item in items]

    def get_embedding_for_text(self, text: str) -> List[float]:
        """Generate the embedding for a single text.

        Returns a zero vector of length ``embedding_dim`` if the API call fails.
        """
        try:
            return self._embed([text])[0]
        except Exception as e:
            print(f"Error generating embedding: {e}")
            return [0.0] * self.embedding_dim

    def get_embeddings_for_texts(self, texts: List[str]) -> List[List[float]]:
        """
        Compute embeddings for a list of texts using batched API calls.

        Args:
            texts: List of text chunks to embed

        Returns:
            List of embedding vectors, one per input text and in the same
            order. Texts in a batch whose request fails get zero vectors of
            length ``embedding_dim``.
        """
        embeddings: List[List[float]] = []
        total = len(texts)
        for i in tqdm(range(0, total, self.batch_size), desc="Embedding chunks"):
            batch = texts[i:i + self.batch_size]
            try:
                embeddings.extend(self._embed(batch))
            except Exception as e:
                print(f"Error embedding batch starting at index {i}: {e}")
                # Placeholder zero vectors keep output aligned with the input.
                embeddings.extend([0.0] * self.embedding_dim for _ in batch)
            # Brief pause between requests to avoid rate limits; nothing to
            # wait for after the final batch.
            if i + self.batch_size < total:
                time.sleep(0.2)
        return embeddings

    def get_query_embedding(self, query: str) -> np.ndarray:
        """
        Generate embedding for a query string and return as numpy array.

        Args:
            query: The query text to embed

        Returns:
            float32 numpy array of shape (1, dim). On API failure it is an
            all-zero array of shape (1, embedding_dim).
        """
        try:
            vector = self._embed([query])[0]
            return np.array(vector, dtype='float32').reshape(1, -1)
        except Exception as e:
            print(f"Error creating embedding for query: {e}")
            return np.zeros((1, self.embedding_dim), dtype='float32')