"""
Fast Batch Processing Module

Optimized batch processing that works around generate_batch issues.
"""
import json
import time
import logging
from collections import Counter
from typing import Dict, List, Tuple

import torch


class FastBatchProcessor:
    """Fast batch processing using optimized individual generate() calls."""

    def __init__(self, model, tolerance=0.05, min_batch_size=2, max_batch_size=8):
        """
        Initialize the fast batch processor.

        Args:
            model: ChatterboxTTS model instance
            tolerance: Absolute per-parameter tolerance for grouping
                (0.05 allows values within 0.05 of each other to be grouped)
            min_batch_size: Minimum chunks to form a batch
            max_batch_size: Maximum chunks per batch (memory/performance limit)
        """
        self.model = model
        self.tolerance = tolerance
        self.min_batch_size = min_batch_size
        self.max_batch_size = max_batch_size
        self.logger = logging.getLogger(__name__)

        self.stats = {
            'total_chunks': 0,
            'batched_chunks': 0,
            'individual_chunks': 0,
            'batch_groups': 0,
            'total_time': 0,
            'batch_time': 0,
            'individual_time': 0,
            'parameter_switches': 0
        }

    def analyze_chunk_distribution(self, chunks: List[Dict]) -> Dict:
        """Analyze parameter distribution and batching potential."""
        if not chunks:
            return {'error': 'No chunks provided'}

        # Collect one rounded parameter tuple per chunk that carries TTS params.
        param_combos = []
        for chunk in chunks:
            if 'tts_params' not in chunk:
                continue

            params = chunk['tts_params']
            combo = (
                round(params.get('exaggeration', 0.5), 3),
                round(params.get('cfg_weight', 0.5), 3),
                round(params.get('temperature', 0.8), 3),
                round(params.get('min_p', 0.05), 3),
                round(params.get('repetition_penalty', 1.2), 3)
            )
            param_combos.append(combo)

        combo_counts = Counter(param_combos)
        unique_combos = len(combo_counts)

        # Runs of consecutive chunks whose parameters match within tolerance.
        consecutive_groups = self._find_consecutive_groups(param_combos)
        batchable = sum(len(group) for group in consecutive_groups if len(group) >= self.min_batch_size)

        # Naive processing may switch parameters on every chunk; grouping
        # reduces that to at most one switch per consecutive group.
        naive_switches = len(param_combos)
        optimized_switches = len(consecutive_groups)
        switch_reduction = (naive_switches - optimized_switches) / max(naive_switches, 1)

        analysis = {
            'total_chunks': len(chunks),
            'unique_combinations': unique_combos,
            'consecutive_batchable': batchable,
            'batch_percentage': (batchable / len(chunks)) * 100,
            'consecutive_groups': len(consecutive_groups),
            'parameter_switch_reduction': switch_reduction * 100,
            'most_common_combos': combo_counts.most_common(5),
            'estimated_speedup': self._estimate_speedup(len(chunks), batchable, switch_reduction)
        }

        return analysis
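
    # Illustrative arithmetic (hypothetical numbers): 10 chunks that form 4
    # consecutive groups give naive_switches=10, optimized_switches=4, so
    # switch_reduction = (10 - 4) / 10 = 0.6 (a 60% reduction).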

    def _find_consecutive_groups(self, param_combos: List[Tuple]) -> List[List[int]]:
        """Find runs of consecutive chunks with matching parameters."""
        if not param_combos:
            return []

        groups = []
        current_group = [0]
        current_params = param_combos[0]

        for i in range(1, len(param_combos)):
            if self._params_within_tolerance(current_params, param_combos[i]):
                current_group.append(i)
            else:
                groups.append(current_group)
                current_group = [i]
                current_params = param_combos[i]

        groups.append(current_group)
        return groups

    def _params_within_tolerance(self, params1: Tuple, params2: Tuple) -> bool:
        """Check whether two parameter tuples match within the absolute tolerance."""
        for p1, p2 in zip(params1, params2):
            if abs(p1 - p2) > self.tolerance:
                return False
        return True
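
    # Example (hypothetical values): with tolerance=0.05, the tuples
    # (0.50, 0.50, 0.80, 0.05, 1.20) and (0.53, 0.50, 0.80, 0.05, 1.20)
    # match (|0.50 - 0.53| = 0.03 <= 0.05), while a first element of 0.56
    # would not (0.06 > 0.05).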

    def _estimate_speedup(self, total_chunks: int, batchable_chunks: int, switch_reduction: float) -> Dict:
        """Estimate performance improvements from grouping (heuristic model)."""
        # Heuristic: assume each avoided parameter switch recovers up to 30% overhead.
        switch_speedup = 1.0 + (switch_reduction * 0.3)

        # Heuristic: batchable chunks benefit from memory reuse between calls.
        memory_speedup = 1.0 + (batchable_chunks / total_chunks * 0.4)

        combined_speedup = switch_speedup * memory_speedup

        return {
            'parameter_switch_speedup': switch_speedup,
            'memory_optimization_speedup': memory_speedup,
            'combined_speedup': combined_speedup,
            'estimated_time_saving': ((combined_speedup - 1.0) / combined_speedup) * 100
        }
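
    # Worked example (hypothetical numbers): 100 chunks, 60 batchable, 50%
    # switch reduction -> switch_speedup = 1.15, memory_speedup = 1.24,
    # combined_speedup ~= 1.43, estimated_time_saving ~= 30%.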

    def process_chunks_fast_batch(self, chunks: List[Dict], use_tolerance: bool = True) -> List[torch.Tensor]:
        """
        Process chunks using fast batch optimization.

        Args:
            chunks: List of chunk dictionaries from JSON
            use_tolerance: Whether to use parameter tolerance when grouping

        Returns:
            List of audio tensors in original chunk order
        """
        if not chunks:
            return []

        start_time = time.time()
        self.stats['total_chunks'] = len(chunks)

        consecutive_groups = self._group_consecutive_chunks(chunks, use_tolerance)

        # Results are filled by original index so output order matches input order.
        results = [None] * len(chunks)

        current_params = None
        param_switches = 0

        for group in consecutive_groups:
            group_start = time.time()

            # All chunks in a group share parameters, so read them from the first.
            _, first_chunk = group['chunks'][0]
            target_params = first_chunk.get('tts_params', {})

            # Only switch model parameters when the group actually differs.
            if current_params != target_params:
                self._update_model_parameters(target_params)
                current_params = target_params.copy()
                param_switches += 1

            self.logger.info(f"🔥 Processing group of {len(group['chunks'])} chunks with same parameters")

            for idx, chunk in group['chunks']:
                try:
                    audio = self.model.generate(
                        chunk['text'],
                        exaggeration=target_params.get('exaggeration', 0.5),
                        cfg_weight=target_params.get('cfg_weight', 0.5),
                        temperature=target_params.get('temperature', 0.8),
                        min_p=target_params.get('min_p', 0.05),
                        top_p=target_params.get('top_p', 1.0),
                        repetition_penalty=target_params.get('repetition_penalty', 1.2)
                    )
                    results[idx] = audio
                    self.stats['batched_chunks'] += 1

                except Exception as e:
                    self.logger.error(f"❌ Chunk {idx} failed: {e}")
                    # Substitute ~1 s of silence (24 kHz output rate) and count
                    # the failure under 'individual_chunks'.
                    results[idx] = torch.zeros(1, 24000)
                    self.stats['individual_chunks'] += 1

            group_time = time.time() - group_start
            self.logger.info(f"   ✅ Group completed in {group_time:.2f}s ({len(group['chunks'])} chunks)")

        self.stats['parameter_switches'] = param_switches
        self.stats['batch_groups'] = len(consecutive_groups)
        self.stats['total_time'] = time.time() - start_time

        self._log_performance_summary()

        return results

    def _group_consecutive_chunks(self, chunks: List[Dict], use_tolerance: bool = True) -> List[Dict]:
        """Group consecutive chunks with similar parameters."""
        if not chunks:
            return []

        groups = []
        current_group = []
        current_params = None

        for i, chunk in enumerate(chunks):
            if 'tts_params' not in chunk:
                # A chunk without parameters breaks the current run and is
                # placed in its own single-chunk group.
                if current_group:
                    groups.append({'chunks': current_group, 'params': current_params})
                    current_group = []
                groups.append({'chunks': [(i, chunk)], 'params': {}})
                current_params = None
                continue

            chunk_params = chunk['tts_params']

            # Quantize each parameter to tolerance-sized buckets so that
            # near-identical values produce the same signature.
            if use_tolerance:
                param_signature = (
                    round(chunk_params.get('exaggeration', 0.5) / self.tolerance) * self.tolerance,
                    round(chunk_params.get('cfg_weight', 0.5) / self.tolerance) * self.tolerance,
                    round(chunk_params.get('temperature', 0.8) / self.tolerance) * self.tolerance,
                    round(chunk_params.get('min_p', 0.05) / self.tolerance) * self.tolerance,
                    round(chunk_params.get('repetition_penalty', 1.2) / self.tolerance) * self.tolerance
                )
            else:
                param_signature = (
                    chunk_params.get('exaggeration', 0.5),
                    chunk_params.get('cfg_weight', 0.5),
                    chunk_params.get('temperature', 0.8),
                    chunk_params.get('min_p', 0.05),
                    chunk_params.get('repetition_penalty', 1.2)
                )

            if current_params is None or param_signature == current_params:
                current_group.append((i, chunk))
                current_params = param_signature
            else:
                # Parameters changed: close the current group and start a new one.
                if current_group:
                    groups.append({'chunks': current_group, 'params': current_params})
                current_group = [(i, chunk)]
                current_params = param_signature

        if current_group:
            groups.append({'chunks': current_group, 'params': current_params})

        return groups
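
    # Illustrative example (hypothetical values): with tolerance=0.05,
    # exaggeration values 0.50 and 0.52 both quantize to bucket 10
    # (round(0.52 / 0.05) = 10) and stay in one group, while 0.58 quantizes
    # to bucket 12 and starts a new group.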

    def _update_model_parameters(self, params: Dict):
        """Update model state for new parameters (placeholder for future optimization)."""
        # Currently a no-op: parameters are passed per generate() call, so
        # there is no persistent model state to update yet.
        pass

    def _log_performance_summary(self):
        """Log performance statistics."""
        stats = self.stats
        total = max(stats['total_chunks'], 1)  # guard against division by zero

        self.logger.info("📊 FAST BATCH PROCESSING SUMMARY")
        self.logger.info("=" * 50)
        self.logger.info(f"Total chunks processed: {stats['total_chunks']}")
        self.logger.info(f"Batched chunks: {stats['batched_chunks']} ({stats['batched_chunks'] / total * 100:.1f}%)")
        self.logger.info(f"Individual chunks: {stats['individual_chunks']} ({stats['individual_chunks'] / total * 100:.1f}%)")
        self.logger.info(f"Batch groups: {stats['batch_groups']}")
        self.logger.info(f"Parameter switches: {stats['parameter_switches']}")
        self.logger.info(f"Total time: {stats['total_time']:.2f}s")

        if stats['total_chunks'] > 0:
            switch_efficiency = (stats['total_chunks'] - stats['parameter_switches']) / stats['total_chunks'] * 100
            self.logger.info(f"Parameter switch efficiency: {switch_efficiency:.1f}%")

            # Naive processing would switch parameters once per chunk.
            estimated_naive_switches = stats['total_chunks']
            switch_reduction = (estimated_naive_switches - stats['parameter_switches']) / estimated_naive_switches * 100
            self.logger.info(f"Parameter switches reduced by: {switch_reduction:.1f}%")


def load_chunks_from_json(json_path: str) -> List[Dict]:
    """Load chunks from a JSON file."""
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Keep only well-formed chunk entries that carry text.
        chunks = [item for item in data if isinstance(item, dict) and 'text' in item]
        return chunks

    except Exception as e:
        logging.error(f"Failed to load JSON file {json_path}: {e}")
        return []
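

# Minimal usage sketch (hypothetical paths and model loading; assumes the
# `chatterbox` package exposes ChatterboxTTS.from_pretrained as in upstream
# Chatterbox, and that "chunks.json" follows the schema expected above):
#
#     from chatterbox.tts import ChatterboxTTS
#
#     model = ChatterboxTTS.from_pretrained(device="cuda")
#     chunks = load_chunks_from_json("chunks.json")
#
#     processor = FastBatchProcessor(model, tolerance=0.05)
#     analysis = processor.analyze_chunk_distribution(chunks)
#     logging.info(f"Estimated speedup: {analysis['estimated_speedup']}")
#
#     audio_tensors = processor.process_chunks_fast_batch(chunks)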