"""Parallel audio processing for generating multiple audio chunks concurrently."""

import asyncio
import concurrent.futures
from typing import Callable, List, Optional, Tuple

import numpy as np

try:
    import gradio as gr
except ImportError:  # pragma: no cover - depends on the runtime environment
    # gradio is only used here to raise UI-visible errors. Keeping the import
    # optional lets the processor be imported and exercised without the UI
    # stack; failures are then reported as RuntimeError instead.
    gr = None


def _user_error(message: str) -> Exception:
    """Build the exception used to report a processing failure.

    Returns ``gr.Error`` when gradio is importable, ``RuntimeError`` otherwise.
    """
    if gr is not None:
        return gr.Error(message)
    return RuntimeError(message)


class ParallelAudioProcessor:
    """Handles parallel processing of multiple audio chunks."""

    def __init__(self, max_workers: int = 4):
        """
        Initialize the parallel processor.

        Args:
            max_workers: Maximum number of concurrent workers for audio generation
        """
        self.max_workers = max_workers

    def _effective_workers(self, chunk_count: int) -> int:
        """Return how many chunks can actually run at once (always >= 1)."""
        return max(1, min(chunk_count, self.max_workers))

    def process_chunks_parallel(
        self,
        text_chunks: List[str],
        audio_generator_func: Callable,
        progress_callback: Optional[Callable] = None
    ) -> List[Tuple[int, np.ndarray]]:
        """
        Process multiple text chunks in parallel (thread pool) to generate audio.

        ``audio_generator_func`` is called as
        ``audio_generator_func(text_chunk, None, progress=<callable>)``.

        Args:
            text_chunks: List of text chunks to process
            audio_generator_func: Function to generate audio from text
            progress_callback: Optional callback for progress updates, called as
                ``progress_callback(fraction, desc=...)``. It may be invoked
                from worker threads as well as from the calling thread.

        Returns:
            The generator's result for each chunk, in the same order as
            ``text_chunks``.

        Raises:
            gr.Error: If any chunk fails or produces ``None``
                (``RuntimeError`` when gradio is not installed).
        """
        if not text_chunks:
            return []

        total_chunks = len(text_chunks)
        completed_chunks = 0
        results: List[Optional[Tuple[int, np.ndarray]]] = [None] * total_chunks

        def process_single_chunk(chunk_index: int, text_chunk: str):
            """Generate audio for one chunk and return it tagged with its index."""

            def chunk_progress(progress: float, desc: str = ""):
                # The generator's own fraction is not forwarded; overall
                # progress is reported in terms of completed chunks.
                if progress_callback:
                    detail = f"Chunk {chunk_index + 1}: {desc}"
                    progress_callback(
                        completed_chunks / total_chunks,
                        desc=f"Processing chunk {completed_chunks + 1}/{total_chunks}: {detail}",
                    )

            try:
                audio_result = audio_generator_func(text_chunk, None, progress=chunk_progress)
            except Exception as e:
                raise RuntimeError(
                    f"Error processing chunk {chunk_index + 1}: {str(e)}"
                ) from e
            return chunk_index, audio_result

        worker_count = self._effective_workers(total_chunks)
        with concurrent.futures.ThreadPoolExecutor(max_workers=worker_count) as executor:
            future_to_index = {
                executor.submit(process_single_chunk, i, chunk): i
                for i, chunk in enumerate(text_chunks)
            }

            # Collect results as they complete; the index restores input order.
            for future in concurrent.futures.as_completed(future_to_index):
                chunk_index = future_to_index[future]
                try:
                    index, audio_result = future.result()
                except Exception as e:
                    # Drop chunks that have not started yet so the error is
                    # reported without first processing the rest of the queue.
                    for pending in future_to_index:
                        pending.cancel()
                    raise _user_error(
                        f"Failed to process chunk {chunk_index + 1}: {str(e)}"
                    ) from e

                results[index] = audio_result
                completed_chunks += 1
                if progress_callback:
                    progress_callback(
                        completed_chunks / total_chunks,
                        desc=f"Completed {completed_chunks}/{total_chunks} audio chunks",
                    )

        # A generator that returned None for a chunk counts as a failed chunk.
        valid_results = [result for result in results if result is not None]
        if len(valid_results) != total_chunks:
            raise _user_error(
                f"Only {len(valid_results)} out of {total_chunks} chunks processed successfully"
            )

        return valid_results

    async def process_chunks_async(
        self,
        text_chunks: List[str],
        audio_generator_func: Callable,
        progress_callback: Optional[Callable] = None
    ) -> List[Tuple[int, np.ndarray]]:
        """
        Async version of parallel chunk processing.

        Each chunk runs ``audio_generator_func`` in the event loop's default
        executor; at most ``max_workers`` chunks are in flight at a time.

        Args:
            text_chunks: List of text chunks to process
            audio_generator_func: Function to generate audio from text
            progress_callback: Optional callback for progress updates, called as
                ``progress_callback(fraction, desc=...)`` from executor threads.

        Returns:
            The generator's result for each chunk, in the same order as
            ``text_chunks``.

        Raises:
            gr.Error: If any chunk fails (``RuntimeError`` when gradio is not
                installed).
        """
        if not text_chunks:
            return []

        total_chunks = len(text_chunks)
        loop = asyncio.get_running_loop()
        # Bounds concurrency to max_workers; the default executor alone would
        # run as many chunks at once as it has threads.
        slots = asyncio.Semaphore(self._effective_workers(total_chunks))

        async def process_chunk_async(chunk_index: int, text_chunk: str):
            """Process a single chunk in the executor once a slot is free."""

            def chunk_progress(progress: float, desc: str = ""):
                if progress_callback:
                    progress_callback(
                        (chunk_index + progress) / total_chunks,
                        desc=f"Chunk {chunk_index + 1}: {desc}",
                    )

            async with slots:
                return await loop.run_in_executor(
                    None,
                    lambda: audio_generator_func(text_chunk, None, progress=chunk_progress),
                )

        tasks = [
            process_chunk_async(i, chunk)
            for i, chunk in enumerate(text_chunks)
        ]

        try:
            # gather() returns results in the order the awaitables were passed,
            # so the output already lines up with text_chunks.
            return list(await asyncio.gather(*tasks))
        except Exception as e:
            raise _user_error(f"Error in async processing: {str(e)}") from e

    def estimate_processing_time(self, text_chunks: List[str], avg_time_per_char: float = 0.1) -> float:
        """
        Estimate total processing time for all chunks.

        The sequential time (total characters * ``avg_time_per_char``) is
        divided by the number of chunks that can run at once. The estimate is
        never lower than the time for the longest single chunk, because one
        chunk is not split across workers.

        Args:
            text_chunks: List of text chunks
            avg_time_per_char: Average processing time per character (seconds)

        Returns:
            Estimated processing time in seconds (0.0 for no chunks)
        """
        if not text_chunks:
            return 0.0

        chunk_lengths = [len(chunk) for chunk in text_chunks]
        sequential_time = sum(chunk_lengths) * avg_time_per_char
        longest_chunk_time = max(chunk_lengths) * avg_time_per_char

        workers = self._effective_workers(len(text_chunks))
        return max(sequential_time / workers, longest_chunk_time)