File size: 13,242 Bytes
3cb0dc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
"""
Fast Batch Processing Module
Optimized batch processing that works around generate_batch issues
"""

import json
import time
import logging
from collections import defaultdict, Counter
from pathlib import Path
from typing import List, Dict, Tuple, Any
import torch

class FastBatchProcessor:
    """Fast batch processing using optimized individual calls"""
    
    def __init__(self, model, tolerance=0.05, min_batch_size=2, max_batch_size=8):
        """
        Initialize fast batch processor
        
        Args:
            model: ChatterboxTTS model instance
            tolerance: Parameter tolerance for grouping (0.05 = 5% variation allowed)
            min_batch_size: Minimum chunks to form a batch
            max_batch_size: Maximum chunks per batch (memory/performance limit)
        """
        self.model = model
        self.tolerance = tolerance
        self.min_batch_size = min_batch_size
        self.max_batch_size = max_batch_size
        self.logger = logging.getLogger(__name__)
        
        # Performance tracking
        self.stats = {
            'total_chunks': 0,
            'batched_chunks': 0,
            'individual_chunks': 0,
            'batch_groups': 0,
            'total_time': 0,
            'batch_time': 0,
            'individual_time': 0,
            'parameter_switches': 0
        }
    
    def analyze_chunk_distribution(self, chunks: List[Dict]) -> Dict:
        """Analyze parameter distribution and batching potential"""
        if not chunks:
            return {'error': 'No chunks provided'}
        
        # Extract parameters
        param_combos = []
        for chunk in chunks:
            if 'tts_params' not in chunk:
                continue
                
            params = chunk['tts_params']
            combo = (
                round(params.get('exaggeration', 0.5), 3),
                round(params.get('cfg_weight', 0.5), 3),
                round(params.get('temperature', 0.8), 3),
                round(params.get('min_p', 0.05), 3),
                round(params.get('repetition_penalty', 1.2), 3)
            )
            param_combos.append(combo)
        
        # Count combinations
        combo_counts = Counter(param_combos)
        unique_combos = len(combo_counts)
        
        # Calculate batching potential (consecutive chunks with same params)
        consecutive_groups = self._find_consecutive_groups(param_combos)
        batchable = sum(len(group) for group in consecutive_groups if len(group) >= self.min_batch_size)
        
        # Estimate speedup from reduced parameter switches
        total_param_switches = len(set(param_combos))
        optimized_switches = len(consecutive_groups)
        switch_reduction = (total_param_switches - optimized_switches) / max(total_param_switches, 1)
        
        analysis = {
            'total_chunks': len(chunks),
            'unique_combinations': unique_combos,
            'consecutive_batchable': batchable,
            'batch_percentage': (batchable / len(chunks)) * 100,
            'consecutive_groups': len(consecutive_groups),
            'parameter_switch_reduction': switch_reduction * 100,
            'most_common_combos': combo_counts.most_common(5),
            'estimated_speedup': self._estimate_speedup(len(chunks), batchable, switch_reduction)
        }
        
        return analysis
    
    def _find_consecutive_groups(self, param_combos: List[Tuple]) -> List[List[int]]:
        """Find consecutive chunks with same parameters"""
        if not param_combos:
            return []
        
        groups = []
        current_group = [0]
        current_params = param_combos[0]
        
        for i in range(1, len(param_combos)):
            if self._params_within_tolerance(current_params, param_combos[i]):
                current_group.append(i)
            else:
                groups.append(current_group)
                current_group = [i]
                current_params = param_combos[i]
        
        # Add the last group
        groups.append(current_group)
        return groups
    
    def _params_within_tolerance(self, params1: Tuple, params2: Tuple) -> bool:
        """Check if two parameter sets are within tolerance"""
        for p1, p2 in zip(params1, params2):
            if abs(p1 - p2) > self.tolerance:
                return False
        return True
    
    def _estimate_speedup(self, total_chunks: int, batchable_chunks: int, switch_reduction: float) -> Dict:
        """Estimate performance improvements"""
        # Parameter switching overhead reduction
        switch_speedup = 1.0 + (switch_reduction * 0.3)  # 30% speedup from fewer switches
        
        # Memory optimization speedup (fewer allocations/deallocations)
        memory_speedup = 1.0 + (batchable_chunks / total_chunks * 0.4)  # Up to 40% from memory optimization
        
        # Combined speedup
        combined_speedup = switch_speedup * memory_speedup
        
        return {
            'parameter_switch_speedup': switch_speedup,
            'memory_optimization_speedup': memory_speedup, 
            'combined_speedup': combined_speedup,
            'estimated_time_saving': ((combined_speedup - 1.0) / combined_speedup) * 100
        }
    
    def process_chunks_fast_batch(self, chunks: List[Dict], use_tolerance: bool = True) -> List[torch.Tensor]:
        """
        Process chunks using fast batch optimization
        
        Args:
            chunks: List of chunk dictionaries from JSON
            use_tolerance: Whether to use parameter tolerance
            
        Returns:
            List of audio tensors in original chunk order
        """
        if not chunks:
            return []
        
        start_time = time.time()
        self.stats['total_chunks'] = len(chunks)
        
        # Group consecutive chunks with similar parameters
        consecutive_groups = self._group_consecutive_chunks(chunks, use_tolerance)
        
        # Initialize results array
        results = [None] * len(chunks)
        
        # Process each group with optimized parameter handling
        current_params = None
        param_switches = 0
        
        for group in consecutive_groups:
            group_start = time.time()
            
            # Extract parameters for this group (use first chunk's params as representative)
            _, first_chunk = group['chunks'][0]
            target_params = first_chunk.get('tts_params', {})
            
            # Check if we need to update model parameters
            if current_params != target_params:
                self._update_model_parameters(target_params)
                current_params = target_params.copy()
                param_switches += 1
            
            # Process all chunks in this group with same parameters
            self.logger.info(f"πŸ”₯ Processing group of {len(group['chunks'])} chunks with same parameters")
            
            for idx, chunk in group['chunks']:
                try:
                    # Generate without parameter overhead (params already set)
                    audio = self.model.generate(
                        chunk['text'],
                        exaggeration=target_params.get('exaggeration', 0.5),
                        cfg_weight=target_params.get('cfg_weight', 0.5), 
                        temperature=target_params.get('temperature', 0.8),
                        min_p=target_params.get('min_p', 0.05),
                        top_p=target_params.get('top_p', 1.0),
                        repetition_penalty=target_params.get('repetition_penalty', 1.2)
                    )
                    results[idx] = audio
                    self.stats['batched_chunks'] += 1
                    
                except Exception as e:
                    self.logger.error(f"❌ Chunk {idx} failed: {e}")
                    # Create silent audio as fallback
                    results[idx] = torch.zeros(1, 24000)
                    self.stats['individual_chunks'] += 1
            
            group_time = time.time() - group_start
            self.logger.info(f"   βœ… Group completed in {group_time:.2f}s ({len(group['chunks'])} chunks)")
        
        self.stats['parameter_switches'] = param_switches
        self.stats['batch_groups'] = len(consecutive_groups)
        self.stats['total_time'] = time.time() - start_time
        
        # Log performance summary
        self._log_performance_summary()
        
        return results
    
    def _group_consecutive_chunks(self, chunks: List[Dict], use_tolerance: bool = True) -> List[Dict]:
        """Group consecutive chunks with similar parameters"""
        if not chunks:
            return []
        
        groups = []
        current_group = []
        current_params = None
        
        for i, chunk in enumerate(chunks):
            if 'tts_params' not in chunk:
                # Handle chunks without parameters as individual
                if current_group:
                    groups.append({'chunks': current_group, 'params': current_params})
                    current_group = []
                groups.append({'chunks': [(i, chunk)], 'params': {}})
                current_params = None
                continue
            
            chunk_params = chunk['tts_params']
            
            # Convert to comparable format
            if use_tolerance:
                param_signature = (
                    round(chunk_params.get('exaggeration', 0.5) / self.tolerance) * self.tolerance,
                    round(chunk_params.get('cfg_weight', 0.5) / self.tolerance) * self.tolerance,
                    round(chunk_params.get('temperature', 0.8) / self.tolerance) * self.tolerance,
                    round(chunk_params.get('min_p', 0.05) / self.tolerance) * self.tolerance,
                    round(chunk_params.get('repetition_penalty', 1.2) / self.tolerance) * self.tolerance
                )
            else:
                param_signature = (
                    chunk_params.get('exaggeration', 0.5),
                    chunk_params.get('cfg_weight', 0.5),
                    chunk_params.get('temperature', 0.8),
                    chunk_params.get('min_p', 0.05),
                    chunk_params.get('repetition_penalty', 1.2)
                )
            
            # Check if this chunk can be grouped with current group
            if current_params is None or param_signature == current_params:
                current_group.append((i, chunk))
                current_params = param_signature
            else:
                # Start new group
                if current_group:
                    groups.append({'chunks': current_group, 'params': current_params})
                current_group = [(i, chunk)]
                current_params = param_signature
        
        # Add the last group
        if current_group:
            groups.append({'chunks': current_group, 'params': current_params})
        
        return groups
    
    def _update_model_parameters(self, params: Dict):
        """Update model with new parameters (placeholder for future optimization)"""
        # For now, parameters are passed directly to generate()
        # Future optimization: pre-configure model components with parameters
        pass
    
    def _log_performance_summary(self):
        """Log performance statistics"""
        stats = self.stats
        
        self.logger.info("πŸ“Š FAST BATCH PROCESSING SUMMARY")
        self.logger.info("=" * 50)
        self.logger.info(f"Total chunks processed: {stats['total_chunks']}")
        self.logger.info(f"Batched chunks: {stats['batched_chunks']} ({stats['batched_chunks']/stats['total_chunks']*100:.1f}%)")
        self.logger.info(f"Individual chunks: {stats['individual_chunks']} ({stats['individual_chunks']/stats['total_chunks']*100:.1f}%)")
        self.logger.info(f"Batch groups: {stats['batch_groups']}")
        self.logger.info(f"Parameter switches: {stats['parameter_switches']}")
        self.logger.info(f"Total time: {stats['total_time']:.2f}s")
        
        # Calculate optimization efficiency
        if stats['total_chunks'] > 0:
            switch_efficiency = (stats['total_chunks'] - stats['parameter_switches']) / stats['total_chunks'] * 100
            self.logger.info(f"Parameter switch efficiency: {switch_efficiency:.1f}%")
            
            # Estimate what naive processing would have taken (one switch per chunk)
            estimated_naive_switches = stats['total_chunks']
            switch_reduction = (estimated_naive_switches - stats['parameter_switches']) / estimated_naive_switches * 100
            self.logger.info(f"Parameter switches reduced by: {switch_reduction:.1f}%")

def load_chunks_from_json(json_path: str) -> List[Dict]:
    """Load chunks from JSON file"""
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        
        # Filter out metadata entries
        chunks = [item for item in data if isinstance(item, dict) and 'text' in item]
        return chunks
        
    except Exception as e:
        logging.error(f"Failed to load JSON file {json_path}: {e}")
        return []