File size: 3,282 Bytes
8abf329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import time
import numpy as np
from tqdm import tqdm
from openai import OpenAI
from typing import List, Dict, Any, Optional
import os

# API key is read from the environment; an unset variable yields an empty string.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
EMBEDDING_MODEL = "text-embedding-3-small"
EMBEDDING_BATCH_SIZE = 10


class TextEmbedder:
    """Generate embeddings for document chunks using OpenAI's embeddings API.

    The embedding methods are best-effort: when an API call fails, the error
    is printed and a zero vector of length ``embedding_dim`` is returned in
    place of each affected embedding, so callers always get one vector per
    input.
    """

    # Default output sizes of the OpenAI embedding models, used to size the
    # zero-vector fallback before any real embedding has been received.
    _KNOWN_DIMENSIONS: Dict[str, int] = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }
    # Size assumed for models missing from the table above.
    _FALLBACK_DIMENSION = 1536
    # Seconds to wait between consecutive batch requests to avoid rate limits.
    _PAUSE_BETWEEN_BATCHES = 0.2

    def __init__(self, model: str = EMBEDDING_MODEL, batch_size: int = EMBEDDING_BATCH_SIZE):
        """
        Initialize the TextEmbedder with the specified embedding model and batch size.

        Args:
            model: The OpenAI embedding model to use
            batch_size: Number of chunks to embed per API call (must be >= 1)

        Raises:
            ValueError: If batch_size is smaller than 1.
        """
        if batch_size < 1:
            # A zero step would crash range() later and a negative one would
            # silently embed nothing, so reject both up front.
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.model = model
        self.batch_size = batch_size
        self.client = OpenAI(api_key=OPENAI_API_KEY)
        # Initial guess; replaced by the real size after the first successful call.
        self.embedding_dim = self._dimension_for_model(model)

    @classmethod
    def _dimension_for_model(cls, model: str) -> int:
        """Return the expected embedding size for *model* (1536 if unknown)."""
        return cls._KNOWN_DIMENSIONS.get(model, cls._FALLBACK_DIMENSION)

    def _request_embeddings(self, inputs: List[str]) -> List[List[float]]:
        """
        Call the embeddings endpoint once and return one vector per input.

        Also records the size of the returned vectors in ``embedding_dim`` so
        that later zero-vector fallbacks match the real embeddings.

        Raises:
            ValueError: If the response does not contain one embedding per input.
            Exception: Whatever the client raises for a failed request.
        """
        response = self.client.embeddings.create(input=inputs, model=self.model)
        vectors = [item.embedding for item in response.data]
        if len(vectors) != len(inputs):
            # A short or long response would misalign embeddings with texts.
            raise ValueError(
                f"expected {len(inputs)} embeddings, received {len(vectors)}"
            )
        if vectors:
            self.embedding_dim = len(vectors[0])
        return vectors

    def get_embedding_for_text(self, text: str) -> List[float]:
        """
        Generate the embedding for a single text.

        Args:
            text: The text to embed

        Returns:
            The embedding vector, or a zero vector of length ``embedding_dim``
            if the request fails.
        """
        try:
            return self._request_embeddings([text])[0]
        except Exception as e:
            print(f"Error generating embedding: {e}")
            return [0.0] * self.embedding_dim

    def get_embeddings_for_texts(self, texts: List[str]) -> List[List[float]]:
        """
        Compute embeddings for a list of texts using batched API calls.

        Args:
            texts: List of text chunks to embed

        Returns:
            List of embedding vectors, one per input text and in the same
            order. Texts belonging to a failed batch get a zero vector.
        """
        if not texts:
            return []

        # None marks a text whose batch failed; it is filled in at the end so
        # every placeholder uses the final, most accurate dimension.
        results: List[Optional[List[float]]] = []
        batch_starts = range(0, len(texts), self.batch_size)
        for start in tqdm(batch_starts, desc="Embedding chunks"):
            if start:
                # Brief pause between requests to avoid rate limits.
                time.sleep(self._PAUSE_BETWEEN_BATCHES)
            batch = texts[start:start + self.batch_size]
            try:
                results.extend(self._request_embeddings(batch))
            except Exception as e:
                print(f"Error embedding batch starting at index {start}: {e}")
                results.extend([None] * len(batch))

        return [
            vector if vector is not None else [0.0] * self.embedding_dim
            for vector in results
        ]

    def get_query_embedding(self, query: str) -> np.ndarray:
        """
        Generate embedding for a query string and return as numpy array.

        Args:
            query: The query text to embed

        Returns:
            float32 numpy array of shape (1, dim) holding the embedding, or
            zeros of shape (1, embedding_dim) if the request fails.
        """
        try:
            vector = self._request_embeddings([query])[0]
            return np.array(vector, dtype='float32').reshape(1, -1)
        except Exception as e:
            print(f"Error creating embedding for query: {e}")
            return np.zeros((1, self.embedding_dim), dtype='float32')