Spaces:
Runtime error
Runtime error
| """ | |
| Parallel Inference Integration for DittoTalkingHead | |
| Integrates parallel processing into the inference pipeline | |
| """ | |
| import asyncio | |
| import time | |
| from typing import Dict, Any, Tuple, Optional | |
| import numpy as np | |
| import torch | |
| from pathlib import Path | |
| from .parallel_processing import ParallelProcessor, PipelineProcessor | |
class ParallelInference:
    """
    Parallel inference wrapper for DittoTalkingHead.

    Wraps a StreamSDK instance and a ParallelProcessor so file loading and
    preprocessing run concurrently before the (sequential) inference call.
    """

    def __init__(self, sdk, parallel_processor: Optional["ParallelProcessor"] = None):
        """
        Initialize parallel inference.

        Args:
            sdk: StreamSDK instance
            parallel_processor: ParallelProcessor instance; a default
                4-thread processor is created when omitted.
        """
        self.sdk = sdk
        self.parallel_processor = parallel_processor or ParallelProcessor(num_threads=4)

        # Pipeline stages in execution order; values are bound methods so a
        # driver can iterate the dict to run the whole pipeline.
        self.pipeline_stages = {
            'load': self._load_files,
            'preprocess': self._preprocess,
            'inference': self._inference,
            'postprocess': self._postprocess
        }

    def _load_files(self, paths: Dict[str, str]) -> Dict[str, Any]:
        """Load audio and image files concurrently.

        Args:
            paths: Dict with 'audio' and 'image' file paths.

        Returns:
            Dict with 'audio_data', 'image_data', and the original 'paths'.
        """
        audio_path = paths['audio']
        image_path = paths['image']

        # Both files are loaded in parallel by the processor's thread pool.
        audio_data, image_data = self.parallel_processor.preprocess_parallel_sync(
            audio_path, image_path
        )

        return {
            'audio_data': audio_data,
            'image_data': image_data,
            'paths': paths
        }

    def _preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Preprocess loaded data.

        Computes the number of output video frames from the audio duration,
        assuming a 25 fps output video.

        Args:
            data: Output of ``_load_files``; ``audio_data`` must contain
                'audio' (sample sequence) and 'sample_rate'.

        Returns:
            Dict with 'audio', 'image', 'num_frames', and 'paths'.
        """
        import math

        audio = data['audio_data']['audio']
        sr = data['audio_data']['sample_rate']

        # BUGFIX: the frame count previously hard-coded a 16 kHz sample rate
        # even though `sr` was read from the data; use the actual rate so
        # non-16 kHz audio yields the correct duration.
        # frames = ceil(duration_seconds * 25 fps)
        num_frames = math.ceil(len(audio) / sr * 25)

        image = data['image_data']['image']

        return {
            'audio': audio,
            'image': image,
            'num_frames': num_frames,
            'paths': data['paths']
        }

    def _inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run inference.

        NOTE(review): placeholder stage — real SDK integration happens in
        ``process_parallel_sync`` via ``inference.run``.
        """
        return {
            'result': 'inference_result',
            'paths': data['paths']
        }

    def _postprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Postprocess results (currently a pass-through)."""
        return data

    async def process_parallel_async(
        self,
        audio_path: str,
        image_path: str,
        output_path: str,
        **kwargs
    ) -> Tuple[str, float]:
        """
        Process with full parallelization (async).

        Args:
            audio_path: Path to audio file
            image_path: Path to image file
            output_path: Output video path
            **kwargs: Additional parameters ('target_size' used here)

        Returns:
            Tuple of (output_path, process_time)
        """
        start_time = time.time()

        # Parallel preprocessing; results are currently unused because the
        # async inference integration is not implemented yet.
        audio_data, image_data = await self.parallel_processor.preprocess_parallel_async(
            audio_path, image_path, kwargs.get('target_size', 320)
        )

        # Inference intentionally omitted here (see process_parallel_sync).
        process_time = time.time() - start_time
        return output_path, process_time

    def process_parallel_sync(
        self,
        audio_path: str,
        image_path: str,
        output_path: str,
        **kwargs
    ) -> Tuple[str, float]:
        """
        Process with parallelization (sync).

        Args:
            audio_path: Path to audio file
            image_path: Path to image file
            output_path: Output video path
            **kwargs: Additional parameters ('target_size', 'seed',
                'more_kwargs' are used here)

        Returns:
            Tuple of (output_path, total process time in seconds)

        Raises:
            Exception: re-raises any error from preprocessing or inference
                after logging it.
        """
        start_time = time.time()

        try:
            # Parallel preprocessing (warms caches / validates inputs before
            # the SDK run below re-reads the files itself).
            print("🔄 Starting parallel preprocessing...")
            preprocess_start = time.time()
            audio_data, image_data = self.parallel_processor.preprocess_parallel_sync(
                audio_path, image_path, kwargs.get('target_size', 320)
            )
            preprocess_time = time.time() - preprocess_start
            print(f"✅ Parallel preprocessing completed in {preprocess_time:.2f}s")

            # Run actual SDK inference via the existing entry point.
            from inference import run, seed_everything

            seed_everything(kwargs.get('seed', 1024))

            inference_start = time.time()
            run(self.sdk, audio_path, image_path, output_path, more_kwargs=kwargs.get('more_kwargs', {}))
            inference_time = time.time() - inference_start
            print(f"✅ Inference completed in {inference_time:.2f}s")

            total_time = time.time() - start_time

            # Performance breakdown
            print(f"""
🎯 Performance Breakdown:
- Preprocessing (parallel): {preprocess_time:.2f}s
- Inference: {inference_time:.2f}s
- Total: {total_time:.2f}s
""")

            return output_path, total_time

        except Exception as e:
            print(f"❌ Error in parallel processing: {e}")
            raise

    def get_performance_stats(self) -> Dict[str, Any]:
        """Get performance statistics (mirrors ParallelProcessor config)."""
        return {
            'num_threads': self.parallel_processor.num_threads,
            'num_processes': self.parallel_processor.num_processes,
            'cuda_streams_enabled': self.parallel_processor.use_cuda_streams
        }
class OptimizedInferenceWrapper:
    """
    Wrapper that combines all optimizations.

    Feature flags select between parallel and sequential processing; the
    cache and GPU flags are recorded and surfaced in the stats dict.
    """

    def __init__(
        self,
        sdk,
        use_parallel: bool = True,
        use_cache: bool = True,
        use_gpu_opt: bool = True
    ):
        """
        Initialize optimized inference wrapper.

        Args:
            sdk: StreamSDK instance
            use_parallel: Enable parallel processing
            use_cache: Enable caching
            use_gpu_opt: Enable GPU optimizations
        """
        self.sdk = sdk
        self.use_parallel = use_parallel
        self.use_cache = use_cache
        self.use_gpu_opt = use_gpu_opt

        # Parallel components exist only when the feature flag is on.
        self.parallel_processor = None
        self.parallel_inference = None
        if use_parallel:
            self.parallel_processor = ParallelProcessor(num_threads=4)
            self.parallel_inference = ParallelInference(sdk, self.parallel_processor)

    def process(
        self,
        audio_path: str,
        image_path: str,
        output_path: str,
        **kwargs
    ) -> Tuple[str, float, Dict[str, Any]]:
        """
        Process with all optimizations.

        Returns:
            Tuple of (output_path, process_time, stats)
        """
        stats = {
            'parallel_enabled': self.use_parallel,
            'cache_enabled': self.use_cache,
            'gpu_opt_enabled': self.use_gpu_opt
        }

        run_parallel = self.use_parallel and self.parallel_inference

        if run_parallel:
            output_path, process_time = self.parallel_inference.process_parallel_sync(
                audio_path, image_path, output_path, **kwargs
            )
            stats['preprocessing'] = 'parallel'
        else:
            # Sequential fallback through the SDK's standard entry point.
            from inference import run, seed_everything

            t0 = time.time()
            seed_everything(kwargs.get('seed', 1024))
            run(
                self.sdk,
                audio_path,
                image_path,
                output_path,
                more_kwargs=kwargs.get('more_kwargs', {}),
            )
            process_time = time.time() - t0
            stats['preprocessing'] = 'sequential'

        stats['process_time'] = process_time
        return output_path, process_time, stats

    def shutdown(self):
        """Cleanup resources."""
        processor = self.parallel_processor
        if processor:
            processor.shutdown()