File size: 4,906 Bytes
bec06d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import sys
import asyncio
from typing import List, Dict, Any
import tiktoken
from openai import AsyncOpenAI

# Add the current directory to the path so we can import config
sys.path.insert(0, os.path.dirname(__file__))
from config import OPENAI_API_KEY, OPENAI_BASE_URL, EMBEDDING_MODEL

import logging

logger = logging.getLogger(__name__)

class Embedder:
    """
    Create vector embeddings for documents via OpenAI's embeddings API.

    Wraps an ``AsyncOpenAI`` client (configured for an OpenRouter-style
    gateway via ``OPENAI_BASE_URL`` and extra headers) and a tiktoken
    ``cl100k_base`` encoder used for token counting, truncation, and
    chunking.
    """

    # Per-input token limit enforced before calling the API (OpenAI's
    # limit for most embedding models, per the original code's comment).
    MAX_INPUT_TOKENS = 8192
    # Truncation target; leaves headroom below the hard limit.
    TRUNCATE_TO_TOKENS = 8000

    def __init__(self):
        # Configure OpenAI client for OpenRouter with required headers.
        # APP_URL / APP_NAME come from the environment, with local-dev
        # defaults.
        self.client = AsyncOpenAI(
            api_key=OPENAI_API_KEY,
            base_url=OPENAI_BASE_URL,
            default_headers={
                "HTTP-Referer": os.getenv("APP_URL", "http://localhost:3000"),
                "X-Title": os.getenv("APP_NAME", "Physical AI Textbook")
            }
        )

        # cl100k_base is the encoding used by text-embedding-ada-002;
        # used here only for counting/truncating/chunking, independent of
        # EMBEDDING_MODEL. NOTE(review): if EMBEDDING_MODEL uses a
        # different tokenizer, counts are approximate — confirm.
        self.encoding = tiktoken.get_encoding("cl100k_base")

    def count_tokens(self, text: str) -> int:
        """Return the number of cl100k_base tokens in *text*."""
        return len(self.encoding.encode(text))

    def _truncate_if_needed(self, text: str) -> str:
        """Return *text* truncated to ``TRUNCATE_TO_TOKENS`` tokens if it
        exceeds ``MAX_INPUT_TOKENS``; otherwise return it unchanged.

        Encodes the text exactly once (the original code encoded up to
        three times on the truncation path).
        """
        tokens = self.encoding.encode(text)
        if len(tokens) <= self.MAX_INPUT_TOKENS:
            return text
        logger.warning("Text too long (%d tokens), truncating...", len(tokens))
        # Leave some room below the hard limit for potential processing.
        return self.encoding.decode(tokens[:self.TRUNCATE_TO_TOKENS])

    async def create_embedding(self, text: str) -> List[float]:
        """Create an embedding for a single text.

        Over-long input is silently truncated (with a warning) rather than
        rejected. Re-raises any API error after logging it.
        """
        try:
            text = self._truncate_if_needed(text)
            response = await self.client.embeddings.create(
                input=text,
                model=EMBEDDING_MODEL
            )
            return response.data[0].embedding
        except Exception as e:
            logger.error("Error creating embedding: %s", e)
            raise

    async def create_embeddings_batch(self, texts: List[str], batch_size: int = 100) -> List[List[float]]:
        """Create embeddings for *texts*, calling the API in batches.

        If a whole batch fails, each text in it is retried individually;
        texts that still fail contribute an empty list ``[]`` as a
        placeholder, so the result always has ``len(texts)`` entries in
        input order.
        """
        all_embeddings = []

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]

            try:
                # Truncate any texts that are too long before the API call.
                processed_batch = [self._truncate_if_needed(text) for text in batch]

                response = await self.client.embeddings.create(
                    input=processed_batch,
                    model=EMBEDDING_MODEL
                )

                all_embeddings.extend(item.embedding for item in response.data)

            except Exception as e:
                logger.error("Error creating batch embeddings: %s", e)
                # If the whole batch failed, try each text individually so
                # one bad item doesn't lose the rest of the batch.
                for text in batch:
                    try:
                        embedding = await self.create_embedding(text)
                        all_embeddings.append(embedding)
                    except Exception as individual_error:
                        logger.error("Failed to embed individual text: %s", individual_error)
                        all_embeddings.append([])  # Placeholder for failed embedding

        return all_embeddings

    def chunk_text_by_tokens(self, text: str, max_tokens: int = 512) -> List[str]:
        """Split *text* into chunks of at most *max_tokens* tokens each.

        Returns ``[]`` for empty input; the last chunk may be shorter.
        """
        tokens = self.encoding.encode(text)
        chunks = []

        for i in range(0, len(tokens), max_tokens):
            chunks.append(self.encoding.decode(tokens[i:i + max_tokens]))

        return chunks

    async def embed_documents(self, documents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Embed documents (dicts with a ``'content'`` key) in place order.

        Returns shallow copies of the input dicts with an ``'embedding'``
        key added; a failed embedding appears as ``[]`` (see
        :meth:`create_embeddings_batch`).
        """
        if not documents:
            return []

        # Extract just the content for embedding.
        texts = [doc['content'] for doc in documents]

        embeddings = await self.create_embeddings_batch(texts)

        # Pair each document with its embedding; lengths always match
        # because create_embeddings_batch emits a placeholder on failure.
        embedded_docs = []
        for doc, embedding in zip(documents, embeddings):
            embedded_doc = doc.copy()
            embedded_doc['embedding'] = embedding
            embedded_docs.append(embedded_doc)

        return embedded_docs