File size: 13,449 Bytes
0a4529c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
# DEPENDENCIES
import numpy as np
from typing import List
from typing import Callable
from typing import Optional
from numpy.typing import NDArray
from config.settings import get_settings
from config.logging_config import get_logger
from utils.error_handler import handle_errors
from utils.error_handler import EmbeddingError
from chunking.token_counter import get_token_counter
from sentence_transformers import SentenceTransformer
from utils.helpers import BatchProcessor as BaseBatchProcessor


# Setup Settings and Logging
settings = get_settings()
logger   = get_logger(__name__)


class BatchProcessor:
    """
    Efficient batch processing for embeddings: Handles large batches with memory optimization and progress tracking
    """
    def __init__(self):
        self.logger           = logger
        self.base_processor   = BaseBatchProcessor()
        
        # Batch processing statistics
        self.total_batches    = 0  # Batches submitted for encoding (ceil of texts / batch_size per call)
        self.total_texts      = 0  # Texts successfully submitted for embedding
        self.failed_batches   = 0  # Calls to process_embeddings_batch that raised
    

    @handle_errors(error_type = EmbeddingError, log_error = True, reraise = True)
    def process_embeddings_batch(self, model: SentenceTransformer, texts: List[str], batch_size: Optional[int] = None, normalize: bool = True, **kwargs) -> List[NDArray]:
        """
        Process embeddings in optimized batches
        
        Arguments:
        ----------
            model      { SentenceTransformer } : Embedding model

            texts             { list }         : List of texts to embed
            
            batch_size        { int }          : Batch size (default from settings.EMBEDDING_BATCH_SIZE)
            
            normalize         { bool }         : Normalize embeddings to unit length
            
            **kwargs                           : Additional model.encode parameters
        
        Returns:
        --------
                      { list }                 : List of embedding vectors (one per input text)
        
        Raises:
        -------
            EmbeddingError                     : If the underlying encode call fails for any reason
        """
        if not texts:
            return []
        
        batch_size = batch_size or settings.EMBEDDING_BATCH_SIZE
        
        self.logger.debug(f"Processing {len(texts)} texts in batches of {batch_size}")
        
        try:
            # Use model's built-in batching with optimization
            embeddings          = model.encode(texts,
                                               batch_size           = batch_size,
                                               normalize_embeddings = normalize,
                                               show_progress_bar    = False,
                                               convert_to_numpy     = True,
                                               **kwargs
                                              )
            
            # Update statistics (ceiling division = number of internal batches encode used)
            self.total_batches += ((len(texts) + batch_size - 1) // batch_size)
            self.total_texts   += len(texts)
            
            self.logger.debug(f"Successfully generated {len(embeddings)} embeddings")
            
            # Convert to list of arrays
            return list(embeddings)  
            
        except Exception as e:
            self.failed_batches += 1
            self.logger.error(f"Batch embedding failed: {repr(e)}")
            # Chain the original exception so the root cause stays visible in tracebacks
            raise EmbeddingError(f"Batch processing failed: {repr(e)}") from e
    

    def process_embeddings_with_fallback(self, model: SentenceTransformer, texts: List[str], batch_size: Optional[int] = None, normalize: bool = True) -> List[NDArray]:
        """
        Process embeddings with automatic batch size reduction on failure
        
        Arguments:
        ----------
            model      { SentenceTransformer } : Embedding model

            texts      { list }                : List of texts
            
            batch_size { int }                 : Initial batch size
            
            normalize  { bool }                : Normalize embeddings
        
        Returns:
        --------
                 { list }                      : List of embeddings
        
        Raises:
        -------
            EmbeddingError                     : If the retry with the reduced batch size also fails
        """
        batch_size = batch_size or settings.EMBEDDING_BATCH_SIZE
        
        try:
            return self.process_embeddings_batch(model      = model,
                                                 texts      = texts,
                                                 batch_size = batch_size,
                                                 normalize  = normalize,
                                                )
        
        # NOTE: process_embeddings_batch wraps every failure (including MemoryError /
        # RuntimeError) in EmbeddingError, so EmbeddingError must be caught here too —
        # otherwise this fallback would never trigger
        except (MemoryError, RuntimeError, EmbeddingError):
            # Halve the batch size but never go below 1 (a batch size of 0 would break encode)
            reduced_batch_size = max(1, batch_size // 2)
            
            self.logger.warning(f"Batch size {batch_size} failed, reducing to {reduced_batch_size}")
            
            # Reduce batch size and retry once
            return self.process_embeddings_batch(model      = model,
                                                 texts      = texts,
                                                 batch_size = reduced_batch_size,
                                                 normalize  = normalize,
                                                )
    

    def split_into_optimal_batches(self, texts: List[str], target_batch_size: int, max_batch_size: int = 1000) -> List[List[str]]:
        """
        Split texts into optimal batches considering token counts
        
        Arguments:
        ----------
            texts            { list } : List of texts
        
            target_batch_size { int } : Target batch size in texts
            
            max_batch_size    { int } : Maximum batch size to allow
        
        Returns:
        --------
                       { list }       : List of text batches (order of texts is preserved)
        """
        if not texts:
            return []
        
        token_counter  = get_token_counter()
        batches        = list()
        current_batch  = list()
        current_tokens = 0
        
        # Estimate tokens per text (average of first 10 or all if less);
        # fall back to 100 tokens/text if the sample is empty
        sample_size    = min(10, len(texts))
        sample_tokens  = [token_counter.count_tokens(text) for text in texts[:sample_size]]
        avg_tokens     = sum(sample_tokens) / len(sample_tokens) if sample_tokens else 100
        
        # Target tokens per batch (approximate)
        target_tokens  = target_batch_size * avg_tokens
        
        for text in texts:
            text_tokens = token_counter.count_tokens(text)
            
            # If single text is too large (>80% of a whole batch's budget), put it in its own batch
            if (text_tokens > (target_tokens * 0.8)):
                if current_batch:
                    batches.append(current_batch)
                    current_batch  = list()
                    current_tokens = 0
                
                batches.append([text])
                continue
            
            # Flush the current batch if adding this text would exceed the token
            # budget, or if the batch has already hit the hard size cap
            if (((current_tokens + text_tokens) > target_tokens) and current_batch) or (len(current_batch) >= max_batch_size):
                batches.append(current_batch)
                current_batch  = list()
                current_tokens = 0
            
            current_batch.append(text)
            current_tokens += text_tokens
        
        # Add final batch
        if current_batch:
            batches.append(current_batch)
        
        self.logger.debug(f"Split {len(texts)} texts into {len(batches)} optimal batches")
        
        return batches
    

    def process_batches_with_progress(self, model: SentenceTransformer, texts: List[str], batch_size: Optional[int] = None, progress_callback: Optional[Callable] = None, **kwargs) -> List[NDArray]:
        """
        Process batches with progress reporting
        
        Arguments:
        ----------
            model            { SentenceTransformer } : Embedding model
        
            texts            { list }                : List of texts
            
            batch_size       { int }                 : Batch size
            
            progress_callback { callable }           : Callback invoked as callback(percent, message)
            
            **kwargs                                 : Additional parameters forwarded to model.encode
        
        Returns:
        --------
                         { list }                    : List of embeddings; entries are None for
                                                       texts belonging to a batch that failed
        """
        if not texts:
            return []
        
        batch_size     = batch_size or settings.EMBEDDING_BATCH_SIZE
        
        # Split into batches
        batches        = self.split_into_optimal_batches(texts             = texts, 
                                                         target_batch_size = batch_size,
                                                        )
        
        all_embeddings = list()
        
        for i, batch_texts in enumerate(batches):
            if progress_callback:
                progress = (i / len(batches)) * 100
                progress_callback(progress, f"Processing batch {i + 1}/{len(batches)}")
            
            try:
                batch_embeddings = self.process_embeddings_batch(model      = model,
                                                                 texts      = batch_texts,
                                                                 batch_size = len(batch_texts),
                                                                 **kwargs
                                                                )
                
                all_embeddings.extend(batch_embeddings)
                
                self.logger.debug(f"Processed batch {i + 1}/{len(batches)}: {len(batch_texts)} texts")
            
            except Exception as e:
                self.logger.error(f"Failed to process batch {i + 1}: {repr(e)}")
                
                # Add None placeholders for failed batch so output stays aligned with input
                all_embeddings.extend([None] * len(batch_texts))
        
        if progress_callback:
            progress_callback(100, "Embedding complete")
        
        return all_embeddings
    

    def validate_embeddings_batch(self, embeddings: List[NDArray], expected_count: int) -> bool:
        """
        Validate a batch of embeddings
        
        Arguments:
        ----------
            embeddings     { list } : List of embedding vectors
        
            expected_count { int }  : Expected number of embeddings
        
        Returns:
        --------
                   { bool }         : True if counts match and at least 90% of entries
                                      are valid 1-D numpy arrays without NaNs
        """
        if (len(embeddings) != expected_count):
            self.logger.error(f"Embedding count mismatch: expected {expected_count}, got {len(embeddings)}")
            return False
        
        # An empty batch trivially matches an expected count of zero;
        # guard here to avoid a ZeroDivisionError in the ratio below
        if (expected_count == 0):
            return True
        
        valid_count = 0
        
        for i, emb in enumerate(embeddings):
            if emb is None:
                self.logger.warning(f"None embedding at index {i}")
                continue
            
            if not isinstance(emb, np.ndarray):
                self.logger.warning(f"Invalid embedding type at index {i}: {type(emb)}")
                continue
            
            if (emb.ndim != 1):
                self.logger.warning(f"Invalid embedding dimension at index {i}: {emb.ndim}")
                continue
            
            if np.any(np.isnan(emb)):
                self.logger.warning(f"NaN values in embedding at index {i}")
                continue
            
            valid_count += 1
        
        validity_ratio = valid_count / expected_count
        
        # Require at least 90% of the batch to be valid
        if (validity_ratio < 0.9):
            self.logger.warning(f"Low embedding validity: {valid_count}/{expected_count} ({validity_ratio:.1%})")
            return False
        
        return True
    

    def get_processing_stats(self) -> dict:
        """
        Get batch processing statistics
        
        Returns:
        --------
            { dict }    : Statistics dictionary with keys: total_batches, total_texts,
                          failed_batches, success_rate (percent), avg_batch_size
        """
        success_rate = ((self.total_batches - self.failed_batches) / self.total_batches * 100) if (self.total_batches > 0) else 100
        
        stats        = {"total_batches"    : self.total_batches,
                        "total_texts"      : self.total_texts,
                        "failed_batches"   : self.failed_batches,
                        "success_rate"     : success_rate,
                        "avg_batch_size"   : self.total_texts / self.total_batches if (self.total_batches > 0) else 0,
                       }
        
        return stats
    

    def reset_stats(self):
        """
        Reset processing statistics
        """
        self.total_batches  = 0
        self.total_texts    = 0
        self.failed_batches = 0
        
        self.logger.debug("Reset batch processing statistics")


# Global batch processor instance (lazily created by get_batch_processor)
_batch_processor: Optional[BatchProcessor] = None


def get_batch_processor() -> BatchProcessor:
    """
    Get global batch processor instance, creating it lazily on first use
    
    Returns:
    --------
        { BatchProcessor } : BatchProcessor instance
    """
    global _batch_processor

    processor = _batch_processor

    if processor is None:
        processor        = BatchProcessor()
        _batch_processor = processor

    return processor


def process_embeddings_batch(model: SentenceTransformer, texts: List[str], **kwargs) -> List[NDArray]:
    """
    Convenience function for batch embedding via the shared BatchProcessor
    
    Arguments:
    ----------
        model { SentenceTransformer } : Embedding model
    
        texts { list }                : List of texts
        
        **kwargs                      : Additional arguments
    
    Returns:
    --------
             { list }                 : List of embeddings
    """
    return get_batch_processor().process_embeddings_batch(model, texts, **kwargs)