# ChatterboxTTS-DNXS-Spokenwordv1/modules/fast_batch_processor.py
"""
Fast Batch Processing Module
Optimized batch processing that works around generate_batch issues
"""
import json
import time
import logging
from collections import Counter
from typing import Dict, List, Tuple

import torch


class FastBatchProcessor:
"""Fast batch processing using optimized individual calls"""
def __init__(self, model, tolerance=0.05, min_batch_size=2, max_batch_size=8):
"""
Initialize fast batch processor
Args:
model: ChatterboxTTS model instance
tolerance: Parameter tolerance for grouping (0.05 = 5% variation allowed)
min_batch_size: Minimum chunks to form a batch
max_batch_size: Maximum chunks per batch (memory/performance limit)
"""
self.model = model
self.tolerance = tolerance
self.min_batch_size = min_batch_size
self.max_batch_size = max_batch_size
self.logger = logging.getLogger(__name__)
# Performance tracking
self.stats = {
'total_chunks': 0,
'batched_chunks': 0,
'individual_chunks': 0,
'batch_groups': 0,
'total_time': 0,
'batch_time': 0,
'individual_time': 0,
'parameter_switches': 0
}
def analyze_chunk_distribution(self, chunks: List[Dict]) -> Dict:
"""Analyze parameter distribution and batching potential"""
if not chunks:
return {'error': 'No chunks provided'}
# Extract parameters
param_combos = []
for chunk in chunks:
if 'tts_params' not in chunk:
continue
params = chunk['tts_params']
            # Include top_p: generate() receives it, so it must be part of the grouping key
            combo = (
                round(params.get('exaggeration', 0.5), 3),
                round(params.get('cfg_weight', 0.5), 3),
                round(params.get('temperature', 0.8), 3),
                round(params.get('min_p', 0.05), 3),
                round(params.get('top_p', 1.0), 3),
                round(params.get('repetition_penalty', 1.2), 3)
            )
param_combos.append(combo)
# Count combinations
combo_counts = Counter(param_combos)
unique_combos = len(combo_counts)
# Calculate batching potential (consecutive chunks with same params)
consecutive_groups = self._find_consecutive_groups(param_combos)
batchable = sum(len(group) for group in consecutive_groups if len(group) >= self.min_batch_size)
# Estimate speedup from reduced parameter switches
        # Naive processing re-applies parameters on every call (one switch per chunk);
        # consecutive grouping needs only one switch per group
        naive_param_switches = len(param_combos)
        optimized_switches = len(consecutive_groups)
        switch_reduction = (naive_param_switches - optimized_switches) / max(naive_param_switches, 1)
analysis = {
'total_chunks': len(chunks),
'unique_combinations': unique_combos,
'consecutive_batchable': batchable,
'batch_percentage': (batchable / len(chunks)) * 100,
'consecutive_groups': len(consecutive_groups),
'parameter_switch_reduction': switch_reduction * 100,
'most_common_combos': combo_counts.most_common(5),
'estimated_speedup': self._estimate_speedup(len(chunks), batchable, switch_reduction)
}
return analysis
def _find_consecutive_groups(self, param_combos: List[Tuple]) -> List[List[int]]:
"""Find consecutive chunks with same parameters"""
if not param_combos:
return []
groups = []
current_group = [0]
current_params = param_combos[0]
for i in range(1, len(param_combos)):
if self._params_within_tolerance(current_params, param_combos[i]):
current_group.append(i)
else:
groups.append(current_group)
current_group = [i]
current_params = param_combos[i]
# Add the last group
groups.append(current_group)
return groups
def _params_within_tolerance(self, params1: Tuple, params2: Tuple) -> bool:
"""Check if two parameter sets are within tolerance"""
for p1, p2 in zip(params1, params2):
if abs(p1 - p2) > self.tolerance:
return False
return True
def _estimate_speedup(self, total_chunks: int, batchable_chunks: int, switch_reduction: float) -> Dict:
"""Estimate performance improvements"""
        # Heuristic estimate: up to ~30% gain from avoiding per-chunk parameter switches
        switch_speedup = 1.0 + (switch_reduction * 0.3)
        # Heuristic estimate: up to ~40% gain from fewer allocations/deallocations on batched chunks
        memory_speedup = 1.0 + (batchable_chunks / total_chunks * 0.4)
# Combined speedup
combined_speedup = switch_speedup * memory_speedup
return {
'parameter_switch_speedup': switch_speedup,
'memory_optimization_speedup': memory_speedup,
'combined_speedup': combined_speedup,
'estimated_time_saving': ((combined_speedup - 1.0) / combined_speedup) * 100
}
def process_chunks_fast_batch(self, chunks: List[Dict], use_tolerance: bool = True) -> List[torch.Tensor]:
"""
Process chunks using fast batch optimization
Args:
chunks: List of chunk dictionaries from JSON
use_tolerance: Whether to use parameter tolerance
Returns:
List of audio tensors in original chunk order
"""
if not chunks:
return []
start_time = time.time()
self.stats['total_chunks'] = len(chunks)
# Group consecutive chunks with similar parameters
consecutive_groups = self._group_consecutive_chunks(chunks, use_tolerance)
# Initialize results array
results = [None] * len(chunks)
# Process each group with optimized parameter handling
current_params = None
param_switches = 0
for group in consecutive_groups:
group_start = time.time()
# Extract parameters for this group (use first chunk's params as representative)
_, first_chunk = group['chunks'][0]
target_params = first_chunk.get('tts_params', {})
# Check if we need to update model parameters
if current_params != target_params:
self._update_model_parameters(target_params)
current_params = target_params.copy()
param_switches += 1
# Process all chunks in this group with same parameters
self.logger.info(f"๐Ÿ”ฅ Processing group of {len(group['chunks'])} chunks with same parameters")
for idx, chunk in group['chunks']:
try:
# Generate without parameter overhead (params already set)
audio = self.model.generate(
chunk['text'],
exaggeration=target_params.get('exaggeration', 0.5),
cfg_weight=target_params.get('cfg_weight', 0.5),
temperature=target_params.get('temperature', 0.8),
min_p=target_params.get('min_p', 0.05),
top_p=target_params.get('top_p', 1.0),
repetition_penalty=target_params.get('repetition_penalty', 1.2)
)
results[idx] = audio
self.stats['batched_chunks'] += 1
                except Exception as e:
                    self.logger.error(f"❌ Chunk {idx} failed: {e}")
                    # Fall back to one second of silence at the assumed 24 kHz output rate
                    results[idx] = torch.zeros(1, 24000)
self.stats['individual_chunks'] += 1
group_time = time.time() - group_start
self.logger.info(f" โœ… Group completed in {group_time:.2f}s ({len(group['chunks'])} chunks)")
self.stats['parameter_switches'] = param_switches
self.stats['batch_groups'] = len(consecutive_groups)
self.stats['total_time'] = time.time() - start_time
# Log performance summary
self._log_performance_summary()
return results
def _group_consecutive_chunks(self, chunks: List[Dict], use_tolerance: bool = True) -> List[Dict]:
"""Group consecutive chunks with similar parameters"""
if not chunks:
return []
groups = []
current_group = []
current_params = None
for i, chunk in enumerate(chunks):
if 'tts_params' not in chunk:
# Handle chunks without parameters as individual
if current_group:
groups.append({'chunks': current_group, 'params': current_params})
current_group = []
groups.append({'chunks': [(i, chunk)], 'params': {}})
current_params = None
continue
chunk_params = chunk['tts_params']
            # Convert to a comparable signature; with tolerance enabled, each value is
            # quantized to a tolerance-wide bucket, e.g. round(0.52 / 0.05) * 0.05 = 0.50
            if use_tolerance:
                param_signature = (
                    round(chunk_params.get('exaggeration', 0.5) / self.tolerance) * self.tolerance,
                    round(chunk_params.get('cfg_weight', 0.5) / self.tolerance) * self.tolerance,
                    round(chunk_params.get('temperature', 0.8) / self.tolerance) * self.tolerance,
                    round(chunk_params.get('min_p', 0.05) / self.tolerance) * self.tolerance,
                    round(chunk_params.get('top_p', 1.0) / self.tolerance) * self.tolerance,
                    round(chunk_params.get('repetition_penalty', 1.2) / self.tolerance) * self.tolerance
                )
            else:
                param_signature = (
                    chunk_params.get('exaggeration', 0.5),
                    chunk_params.get('cfg_weight', 0.5),
                    chunk_params.get('temperature', 0.8),
                    chunk_params.get('min_p', 0.05),
                    chunk_params.get('top_p', 1.0),
                    chunk_params.get('repetition_penalty', 1.2)
                )
# Check if this chunk can be grouped with current group
if current_params is None or param_signature == current_params:
current_group.append((i, chunk))
current_params = param_signature
else:
# Start new group
if current_group:
groups.append({'chunks': current_group, 'params': current_params})
current_group = [(i, chunk)]
current_params = param_signature
# Add the last group
if current_group:
groups.append({'chunks': current_group, 'params': current_params})
return groups
def _update_model_parameters(self, params: Dict):
"""Update model with new parameters (placeholder for future optimization)"""
# For now, parameters are passed directly to generate()
# Future optimization: pre-configure model components with parameters
pass
def _log_performance_summary(self):
"""Log performance statistics"""
stats = self.stats
self.logger.info("๐Ÿ“Š FAST BATCH PROCESSING SUMMARY")
self.logger.info("=" * 50)
self.logger.info(f"Total chunks processed: {stats['total_chunks']}")
self.logger.info(f"Batched chunks: {stats['batched_chunks']} ({stats['batched_chunks']/stats['total_chunks']*100:.1f}%)")
self.logger.info(f"Individual chunks: {stats['individual_chunks']} ({stats['individual_chunks']/stats['total_chunks']*100:.1f}%)")
self.logger.info(f"Batch groups: {stats['batch_groups']}")
self.logger.info(f"Parameter switches: {stats['parameter_switches']}")
self.logger.info(f"Total time: {stats['total_time']:.2f}s")
        # Switch efficiency: naive processing would switch parameters once per chunk,
        # so the fraction of chunks that reused already-set parameters is the saving
        if stats['total_chunks'] > 0:
            switch_reduction = (stats['total_chunks'] - stats['parameter_switches']) / stats['total_chunks'] * 100
            self.logger.info(f"Parameter switches reduced by: {switch_reduction:.1f}% (vs. one switch per chunk)")
def load_chunks_from_json(json_path: str) -> List[Dict]:
"""Load chunks from JSON file"""
try:
with open(json_path, 'r', encoding='utf-8') as f:
data = json.load(f)
# Filter out metadata entries
chunks = [item for item in data if isinstance(item, dict) and 'text' in item]
return chunks
except Exception as e:
logging.error(f"Failed to load JSON file {json_path}: {e}")
return []
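

if __name__ == "__main__":
    # Minimal usage sketch, not a tested entry point. It assumes the chatterbox
    # package exposes ChatterboxTTS.from_pretrained (as in Resemble AI's library)
    # and that "chunks.json" is a chunk file produced upstream; adjust to your setup.
    from chatterbox.tts import ChatterboxTTS

    logging.basicConfig(level=logging.INFO)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = ChatterboxTTS.from_pretrained(device=device)
    processor = FastBatchProcessor(model, tolerance=0.05)
    chunks = load_chunks_from_json("chunks.json")
    print(processor.analyze_chunk_distribution(chunks))
    audio_tensors = processor.process_chunks_fast_batch(chunks)
    print(f"Generated audio for {sum(a is not None for a in audio_tensors)} chunks")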