File size: 3,847 Bytes
8bf4d58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""Embedding generation using OpenAI."""

import logging
from functools import lru_cache
from typing import Dict, List, Optional

from openai import OpenAI

from src.core.config import get_settings

logger = logging.getLogger(__name__)


class EmbeddingGenerator:
    """Generate text embeddings via the OpenAI embeddings API.

    Results are memoized in an in-memory dict keyed by the exact input
    text, so repeated embeddings of identical strings skip the API call.

    NOTE(review): the cache is unbounded; a long-running process that
    embeds many distinct texts will grow memory without limit. Call
    clear_cache() periodically if that matters for your deployment.
    """

    def __init__(self, client: Optional[OpenAI] = None):
        """Initialize the embedding generator.

        Args:
            client: Optional pre-configured OpenAI client. When omitted,
                one is built from the application settings.
        """
        self.settings = get_settings()
        self.client = client or OpenAI(**self.settings.get_openai_client_kwargs())
        self.model = self.settings.openai_embedding_model
        # Maps input text -> embedding vector (unbounded; see class NOTE).
        self._cache: Dict[str, List[float]] = {}

    def generate_embedding(self, text: str, use_cache: bool = True) -> List[float]:
        """
        Generate embedding for a single text.

        Args:
            text: Input text to embed
            use_cache: Whether to read from / write to the in-memory cache

        Returns:
            Embedding vector as a list of floats

        Raises:
            Exception: Re-raises any error from the OpenAI client after
                logging it with the traceback.
        """
        if use_cache and text in self._cache:
            return self._cache[text]

        # Keep the try body minimal: only the API call can raise here.
        try:
            response = self.client.embeddings.create(
                model=self.model,
                input=text,
            )
        except Exception as e:
            # logger.exception records the traceback; lazy %-args avoid
            # string formatting when the log level is disabled.
            logger.exception("Error generating embedding: %s", e)
            raise

        embedding = response.data[0].embedding
        if use_cache:
            self._cache[text] = embedding
        return embedding

    def generate_embeddings_batch(
        self, texts: List[str], use_cache: bool = True
    ) -> List[List[float]]:
        """
        Generate embeddings for multiple texts in one API request.

        Cached texts are served from memory; only the uncached ones are
        sent to the API (a single batched call). The returned list is in
        the same order as ``texts``.

        Args:
            texts: List of input texts to embed
            use_cache: Whether to read from / write to the in-memory cache

        Returns:
            List of embedding vectors, one per input text, input order

        Raises:
            Exception: Re-raises any error from the OpenAI client after
                logging it with the traceback.
        """
        # Split inputs into already-cached (by position) and still-to-embed.
        cached_embeddings: Dict[int, List[float]] = {}
        texts_to_embed: List[str] = []
        indices: List[int] = []

        for i, text in enumerate(texts):
            if use_cache and text in self._cache:
                cached_embeddings[i] = self._cache[text]
            else:
                texts_to_embed.append(text)
                indices.append(i)

        if not texts_to_embed:
            # Everything was cached (or `texts` was empty).
            return [cached_embeddings[i] for i in range(len(texts))]

        # Keep the try body minimal: only the API call can raise here.
        try:
            response = self.client.embeddings.create(
                model=self.model,
                input=texts_to_embed,
            )
        except Exception as e:
            logger.exception("Error generating batch embeddings: %s", e)
            raise

        # response.data is ordered to match the `input` list, so zip with
        # the original positions of the uncached texts.
        new_embeddings = {
            idx: item.embedding for idx, item in zip(indices, response.data)
        }

        if use_cache:
            for idx, text in zip(indices, texts_to_embed):
                self._cache[text] = new_embeddings[idx]

        # Merge cached and freshly generated vectors back into input order.
        return [
            cached_embeddings[i] if i in cached_embeddings else new_embeddings[i]
            for i in range(len(texts))
        ]

    def clear_cache(self) -> None:
        """Clear the embedding cache."""
        self._cache.clear()

    def get_cache_size(self) -> int:
        """Get the number of cached embeddings."""
        return len(self._cache)


# Process-wide singleton slot; populated lazily by get_embedding_generator().
_embedding_generator: Optional[EmbeddingGenerator] = None


def get_embedding_generator() -> EmbeddingGenerator:
    """Return the process-wide EmbeddingGenerator, creating it on first use."""
    global _embedding_generator
    if _embedding_generator is not None:
        return _embedding_generator
    _embedding_generator = EmbeddingGenerator()
    return _embedding_generator