File size: 11,296 Bytes
1dab660
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5411262
1dab660
 
 
 
5411262
1dab660
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
787b7ff
 
 
 
 
 
 
 
 
1dab660
 
 
5fb63e2
 
 
 
 
 
 
 
 
 
 
1dab660
 
 
 
 
 
 
 
 
 
 
 
 
5411262
1dab660
 
 
 
 
5411262
1dab660
 
 
 
5411262
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1dab660
5411262
1dab660
 
 
 
 
 
 
 
 
 
 
 
 
5411262
1dab660
 
 
 
 
 
 
 
 
 
 
 
 
 
5411262
1dab660
 
 
 
 
 
 
 
5411262
1dab660
5411262
 
1dab660
 
 
 
 
 
5411262
1dab660
 
 
 
5411262
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1dab660
5411262
1dab660
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5411262
1dab660
 
 
5411262
1dab660
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5411262
1dab660
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5411262
1dab660
 
 
 
 
 
 
 
 
 
 
 
 
 
5411262
1dab660
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
"""
OpenAI embeddings client with caching and retry logic.

This module provides an optimized client for generating text embeddings
using OpenAI's API with built-in caching and automatic retry on failures.
"""

import hashlib
import os
import time
from functools import wraps
from typing import List

from loguru import logger
from openai import OpenAI, RateLimitError

from config import get_settings

# diskcache is an optional dependency: when it is missing, CACHE_AVAILABLE
# stays False and EmbeddingsClient falls back to uncached API calls.
try:
    from diskcache import Cache
    CACHE_AVAILABLE = True
except ImportError:
    CACHE_AVAILABLE = False
    logger.warning("diskcache not available - embeddings caching disabled")


class EmbeddingsClient:
    """
    Client for generating embeddings with caching and retry logic.
    
    Features:
    - Persistent disk cache for embeddings (using diskcache)
    - Automatic retry with exponential backoff
    - Batch processing support
    - Text truncation for long inputs
    
    Attributes:
        client: OpenAI client instance
        model: Embedding model name
        settings: Application settings
        cache: Disk cache for embeddings, or None when caching is unavailable
        
    Examples:
        >>> client = EmbeddingsClient()
        >>> embedding = client.get_embedding("Hello world")
        >>> len(embedding)
        3072
        
        >>> # Second call uses cache
        >>> embedding2 = client.get_embedding("Hello world")
        >>> embedding == embedding2
        True
    """
    
    # Hard character cap applied to every input before caching/embedding.
    # NOTE(review): this is a character-count proxy for the model's token
    # limit; texts with long tokens could still exceed the true limit —
    # confirm against the embedding model's documented maximum.
    MAX_INPUT_CHARS = 8191
    
    def __init__(self, cache_dir: str = "./cache/embeddings"):
        """
        Initialize OpenAI client with configured endpoint and cache.
        
        Args:
            cache_dir: Directory for persistent embedding cache
        """
        self.settings = get_settings()
        
        logger.info(f"Initializing EmbeddingsClient with {self.settings.llm_base_url}")
        self.client = OpenAI(
            api_key=self.settings.openai_api_key,
            base_url=self.settings.llm_base_url
        )
        
        self.model = self.settings.embedding_model
        
        # Caching is best-effort: fall back to uncached operation when the
        # diskcache package is missing or the directory cannot be created
        # (e.g. read-only filesystem).
        self.cache = None
        if CACHE_AVAILABLE:
            try:
                os.makedirs(cache_dir, exist_ok=True)
                self.cache = Cache(cache_dir)
            except (OSError, PermissionError) as e:
                logger.warning(f"Could not create cache directory (read-only?): {e}")
        cache_state = "enabled" if self.cache is not None else "disabled"
        logger.info(f"✅ EmbeddingsClient initialized (model: {self.model}, cache: {cache_state})")
    
    def close(self):
        """Close cache and clean up resources."""
        try:
            if self.cache is not None:
                self.cache.close()
                logger.info("EmbeddingsClient cache closed")
            if hasattr(self.client, 'close'):
                self.client.close()
        except Exception as e:
            logger.warning(f"Error closing EmbeddingsClient: {e}")
    
    def _get_cache_key(self, text: str) -> str:
        """
        Generate cache key for text.
        
        The model name is mixed into the hash so the same text embedded
        with different models never collides in the cache.
        
        Args:
            text: Input text
            
        Returns:
            MD5 hash of "<model>:<text>" as cache key (non-cryptographic use)
        """
        return hashlib.md5(f"{self.model}:{text}".encode()).hexdigest()
    
    def _create_with_retry(self, input_data, retry_count: int):
        """
        Call the embeddings API, retrying on rate limits (internal method).
        
        Uses exponential backoff (2s, 4s, 8s, ...). The final failed
        attempt re-raises the RateLimitError to the caller.
        
        Args:
            input_data: Single string or list of strings to embed
            retry_count: Maximum number of attempts
            
        Returns:
            Raw embeddings API response
        """
        for attempt in range(retry_count):
            try:
                return self.client.embeddings.create(
                    model=self.model,
                    input=input_data
                )
            except RateLimitError:
                if attempt == retry_count - 1:
                    raise
                wait_time = (2 ** attempt) * 2
                logger.warning(f"Rate limited. Retrying in {wait_time}s (attempt {attempt + 1}/{retry_count})")
                time.sleep(wait_time)
        # Unreachable in practice (loop either returns or re-raises), kept
        # as a defensive guard.
        raise RuntimeError(f"Failed after {retry_count} retries")
    
    def _get_embedding_uncached(self, text: str, retry_count: int = 5) -> List[float]:
        """
        Generate embedding without cache (internal method).
        
        Args:
            text: Input text (already truncated)
            retry_count: Number of retries on rate limit
            
        Returns:
            Embedding vector
        """
        response = self._create_with_retry(text, retry_count)
        embedding = response.data[0].embedding
        logger.debug(f"Generated embedding (dim={len(embedding)})")
        return embedding
    
    def get_embedding(self, text: str) -> List[float]:
        """
        Generate embedding for text with caching.
        
        Uses disk cache to avoid regenerating embeddings for the same text.
        Automatically truncates long texts to MAX_INPUT_CHARS characters.
        
        Args:
            text: Input text to embed
            
        Returns:
            List of float values representing the embedding
            
        Examples:
            >>> embedding = client.get_embedding("Hello world")
            >>> len(embedding)
            3072
        """
        # Truncate BEFORE hashing so cache lookups and stores agree
        text = text[:self.MAX_INPUT_CHARS]
        
        # Compute the key once; reused for both lookup and store
        cache_key = self._get_cache_key(text) if self.cache is not None else None
        
        if cache_key is not None and cache_key in self.cache:
            logger.debug("Cache hit for embedding")
            return self.cache[cache_key]
        
        embedding = self._get_embedding_uncached(text)
        
        if cache_key is not None:
            self.cache[cache_key] = embedding
        
        return embedding
    
    def _get_embeddings_batch_uncached(
        self, 
        texts: List[str],
        retry_count: int = 3
    ) -> List[List[float]]:
        """
        Generate embeddings for batch without cache (internal method).
        
        Args:
            texts: List of texts (already truncated)
            retry_count: Number of retries on rate limit
            
        Returns:
            List of embedding vectors, in input order
        """
        response = self._create_with_retry(texts, retry_count)
        # Sort by index to maintain input order regardless of response order
        batch_embeddings = sorted(response.data, key=lambda x: x.index)
        return [e.embedding for e in batch_embeddings]
    
    def get_embeddings_batch(
        self, 
        texts: List[str], 
        batch_size: int = 100
    ) -> List[List[float]]:
        """
        Generate embeddings for multiple texts efficiently with caching.
        
        Processes texts in batches and uses cache when available.
        Each text is checked against cache individually; only the misses
        are sent to the API.
        
        Args:
            texts: List of texts to embed
            batch_size: Number of texts per API call
            
        Returns:
            List of embeddings (one per input text)
            
        Examples:
            >>> texts = ["Hello", "World", "!"]
            >>> embeddings = client.get_embeddings_batch(texts)
            >>> len(embeddings)
            3
        """
        all_embeddings = []
        
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            
            num_batches = (len(texts) + batch_size - 1) // batch_size
            current_batch_num = i // batch_size + 1
            
            logger.info(
                f"      - Processing batch {current_batch_num}/{num_batches} "
                f"({len(batch_texts)} texts)"
            )
            
            # Partition the batch into cache hits (kept with their original
            # index) and misses (to be generated in one API call)
            batch_embeddings = []
            texts_to_generate = []
            indices_to_generate = []
            
            if self.cache is not None:
                for idx, text in enumerate(batch_texts):
                    text_truncated = text[:self.MAX_INPUT_CHARS]
                    cache_key = self._get_cache_key(text_truncated)
                    
                    if cache_key in self.cache:
                        batch_embeddings.append((idx, self.cache[cache_key]))
                    else:
                        texts_to_generate.append(text_truncated)
                        indices_to_generate.append(idx)
                
                if texts_to_generate:
                    logger.debug(
                        f"Cache: {len(batch_embeddings)} hits, "
                        f"{len(texts_to_generate)} misses"
                    )
            else:
                # No cache - generate all
                texts_to_generate = [t[:self.MAX_INPUT_CHARS] for t in batch_texts]
                indices_to_generate = list(range(len(batch_texts)))
            
            # Generate embeddings for cache misses
            if texts_to_generate:
                try:
                    generated = self._get_embeddings_batch_uncached(
                        texts_to_generate
                    )
                    
                    # Store in cache and add to results
                    for idx, text, embedding in zip(
                        indices_to_generate, 
                        texts_to_generate, 
                        generated
                    ):
                        batch_embeddings.append((idx, embedding))
                        
                        if self.cache is not None:
                            cache_key = self._get_cache_key(text)
                            self.cache[cache_key] = embedding
                    
                except Exception as e:
                    logger.error(f"      - Batch embedding failed: {e}")
                    raise
            
            # Sort by original index and extract embeddings
            batch_embeddings.sort(key=lambda x: x[0])
            all_embeddings.extend([emb for _, emb in batch_embeddings])
            
            # Small delay between batches to avoid rate limiting
            if num_batches > 1 and current_batch_num < num_batches:
                time.sleep(0.5)
        
        logger.success(f"   🧠 Generated {len(all_embeddings)} embeddings total.")
        return all_embeddings


def get_embeddings_client() -> EmbeddingsClient:
    """
    Get or create embeddings client (singleton pattern).
    
    The shared instance is lazily constructed on first call and stored
    as an attribute of this function; subsequent calls return it.
    
    Returns:
        Shared EmbeddingsClient instance
        
    Examples:
        >>> client = get_embeddings_client()
        >>> embedding = client.get_embedding("test")
    """
    try:
        return get_embeddings_client._instance
    except AttributeError:
        shared = EmbeddingsClient()
        get_embeddings_client._instance = shared
        return shared