# Docker_Deploy/src/python/embedder.py
import logging
import os
import sys
from typing import Any, Dict, List

import tiktoken
from openai import AsyncOpenAI

# Add the current directory to the path so we can import config
sys.path.insert(0, os.path.dirname(__file__))
from config import OPENAI_API_KEY, OPENAI_BASE_URL, EMBEDDING_MODEL

logger = logging.getLogger(__name__)

# OpenAI's embedding models accept roughly 8k input tokens per request;
# truncation targets a lower count to leave headroom for any processing.
MAX_EMBEDDING_TOKENS = 8192
TRUNCATE_TOKENS = 8000
class Embedder:
    """Handle document embedding using OpenAI's embedding API."""

    def __init__(self):
        # Configure the OpenAI client for OpenRouter, which expects the
        # HTTP-Referer and X-Title headers on every request.
        self.client = AsyncOpenAI(
            api_key=OPENAI_API_KEY,
            base_url=OPENAI_BASE_URL,
            default_headers={
                "HTTP-Referer": os.getenv("APP_URL", "http://localhost:3000"),
                "X-Title": os.getenv("APP_NAME", "Physical AI Textbook"),
            },
        )
        # Use the cl100k_base encoding, which matches text-embedding-ada-002.
        self.encoding = tiktoken.get_encoding("cl100k_base")
    def count_tokens(self, text: str) -> int:
        """Count the number of tokens in a text."""
        return len(self.encoding.encode(text))

    def _truncate(self, text: str) -> str:
        """Truncate a text that exceeds the embedding model's token limit."""
        token_count = self.count_tokens(text)
        if token_count > MAX_EMBEDDING_TOKENS:
            logger.warning(f"Text too long ({token_count} tokens), truncating...")
            tokens = self.encoding.encode(text)[:TRUNCATE_TOKENS]
            text = self.encoding.decode(tokens)
        return text

    async def create_embedding(self, text: str) -> List[float]:
        """Create an embedding for a single text."""
        try:
            text = self._truncate(text)
            response = await self.client.embeddings.create(
                input=text,
                model=EMBEDDING_MODEL,
            )
            return response.data[0].embedding
        except Exception as e:
            logger.error(f"Error creating embedding: {str(e)}")
            raise
    async def create_embeddings_batch(self, texts: List[str], batch_size: int = 100) -> List[List[float]]:
        """Create embeddings for a list of texts in batches."""
        all_embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            try:
                # Truncate any texts that exceed the model's token limit.
                processed_batch = [self._truncate(text) for text in batch]
                response = await self.client.embeddings.create(
                    input=processed_batch,
                    model=EMBEDDING_MODEL,
                )
                all_embeddings.extend(item.embedding for item in response.data)
            except Exception as e:
                logger.error(f"Error creating batch embeddings: {str(e)}")
                # If the whole batch failed, retry each text individually so one
                # bad input does not lose embeddings for the rest of the batch.
                for text in batch:
                    try:
                        all_embeddings.append(await self.create_embedding(text))
                    except Exception as individual_error:
                        logger.error(f"Failed to embed individual text: {str(individual_error)}")
                        all_embeddings.append([])  # Placeholder for the failed embedding
        return all_embeddings
    def chunk_text_by_tokens(self, text: str, max_tokens: int = 512) -> List[str]:
        """Split a long text into chunks of at most max_tokens tokens."""
        tokens = self.encoding.encode(text)
        chunks = []
        for i in range(0, len(tokens), max_tokens):
            chunk_tokens = tokens[i:i + max_tokens]
            chunks.append(self.encoding.decode(chunk_tokens))
        return chunks
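
    # Note: the chunks above do not overlap. For example, a 1,200-token text
    # with max_tokens=512 yields chunks of 512, 512, and 176 tokens.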
    async def embed_documents(self, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Embed a list of documents, each a dict with 'content' plus metadata."""
        if not documents:
            return []
        # Extract just the content for embedding
        texts = [doc['content'] for doc in documents]
        embeddings = await self.create_embeddings_batch(texts)
        # Attach each embedding to a copy of its source document. Documents
        # whose embedding failed carry an empty list (see create_embeddings_batch).
        embedded_docs = []
        for doc, embedding in zip(documents, embeddings):
            embedded_doc = doc.copy()
            embedded_doc['embedding'] = embedding
            embedded_docs.append(embedded_doc)
        return embedded_docs
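

# A minimal usage sketch, assuming config.py supplies valid OpenRouter
# credentials and that EMBEDDING_MODEL names an available model; the sample
# text and the "demo.md" source key are placeholders, not part of the deployment.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        embedder = Embedder()
        # Chunk a long text, wrap each chunk as a document, then embed them all.
        docs = [
            {"content": chunk, "source": "demo.md"}
            for chunk in embedder.chunk_text_by_tokens("Physical AI " * 300, max_tokens=128)
        ]
        embedded = await embedder.embed_documents(docs)
        for doc in embedded:
            # Failed embeddings come back as empty lists; flag them here.
            status = len(doc["embedding"]) or "failed"
            print(doc["source"], status)

    asyncio.run(_demo())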