import os
import sys
import asyncio
from typing import List, Dict, Any

import tiktoken
from openai import AsyncOpenAI

# Add the current directory to the path so we can import config
sys.path.insert(0, os.path.dirname(__file__))
from config import OPENAI_API_KEY, OPENAI_BASE_URL, EMBEDDING_MODEL

import logging

logger = logging.getLogger(__name__)


class Embedder:
    """
    Create embeddings for documents through OpenAI's embedding API
    (configured for an OpenRouter-compatible endpoint).

    Handles token counting, truncation of over-long inputs, batched
    requests, and per-text fallback when a whole batch request fails.
    """

    # Hard input limit for most OpenAI embedding models.
    MAX_INPUT_TOKENS = 8192
    # Truncation target; leaves headroom below the hard limit.
    TRUNCATE_TOKENS = 8000

    def __init__(self):
        # Configure OpenAI client for OpenRouter with required headers
        self.client = AsyncOpenAI(
            api_key=OPENAI_API_KEY,
            base_url=OPENAI_BASE_URL,
            default_headers={
                "HTTP-Referer": os.getenv("APP_URL", "http://localhost:3000"),
                "X-Title": os.getenv("APP_NAME", "Physical AI Textbook"),
            },
        )
        # Use cl100k_base encoding which is used by text-embedding-ada-002
        self.encoding = tiktoken.get_encoding("cl100k_base")

    def count_tokens(self, text: str) -> int:
        """Count the number of tokens in a text."""
        return len(self.encoding.encode(text))

    def _truncate_if_needed(self, text: str) -> str:
        """Return *text*, truncated to TRUNCATE_TOKENS tokens if it exceeds
        MAX_INPUT_TOKENS. Encodes the text exactly once."""
        tokens = self.encoding.encode(text)
        if len(tokens) <= self.MAX_INPUT_TOKENS:
            return text
        logger.warning("Text too long (%d tokens), truncating...", len(tokens))
        # Leave some room below the hard limit for potential processing.
        return self.encoding.decode(tokens[:self.TRUNCATE_TOKENS])

    async def create_embedding(self, text: str) -> List[float]:
        """Create an embedding for a single text.

        Over-long input is truncated before the request. Any API error is
        logged and re-raised to the caller.
        """
        try:
            text = self._truncate_if_needed(text)
            response = await self.client.embeddings.create(
                input=text,
                model=EMBEDDING_MODEL,
            )
            return response.data[0].embedding
        except Exception as e:
            logger.error("Error creating embedding: %s", e)
            raise

    async def create_embeddings_batch(self, texts: List[str], batch_size: int = 100) -> List[List[float]]:
        """Create embeddings for a batch of texts.

        Texts are sent in chunks of *batch_size*. If a whole batch request
        fails, each text in that batch is retried individually; a text that
        still fails contributes an empty list as a placeholder so the output
        stays aligned one-to-one with *texts*.
        """
        all_embeddings: List[List[float]] = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            try:
                # Truncate any texts that are too long before sending.
                processed_batch = [self._truncate_if_needed(text) for text in batch]
                response = await self.client.embeddings.create(
                    input=processed_batch,
                    model=EMBEDDING_MODEL,
                )
                all_embeddings.extend(item.embedding for item in response.data)
            except Exception as e:
                logger.error("Error creating batch embeddings: %s", e)
                # If the whole batch failed, try each text individually
                for text in batch:
                    try:
                        all_embeddings.append(await self.create_embedding(text))
                    except Exception as individual_error:
                        logger.error(
                            "Failed to embed individual text: %s", individual_error
                        )
                        all_embeddings.append([])  # Placeholder for failed embedding
        return all_embeddings

    def chunk_text_by_tokens(self, text: str, max_tokens: int = 512) -> List[str]:
        """Split a long text into chunks of at most *max_tokens* tokens each."""
        tokens = self.encoding.encode(text)
        return [
            self.encoding.decode(tokens[i:i + max_tokens])
            for i in range(0, len(tokens), max_tokens)
        ]

    async def embed_documents(self, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Embed a list of documents with their content and metadata.

        Each document must carry a 'content' key; a shallow copy of each
        document is returned with an added 'embedding' key. Returns [] for
        empty input. A document whose embedding failed gets [] (see
        create_embeddings_batch).
        """
        if not documents:
            return []

        # Extract just the content for embedding
        texts = [doc['content'] for doc in documents]
        embeddings = await self.create_embeddings_batch(texts)

        # Pair each original document with its embedding, preserving order.
        embedded_docs = []
        for doc, embedding in zip(documents, embeddings):
            embedded_doc = doc.copy()
            embedded_doc['embedding'] = embedding
            embedded_docs.append(embedded_doc)
        return embedded_docs