| """ | |
| Cold Start Optimization for DittoTalkingHead | |
| Reduces model loading time and I/O overhead | |
| """ | |
| import os | |
| import shutil | |
| import time | |
| from pathlib import Path | |
| from typing import Dict, Any, Optional | |
| import pickle | |
| import torch | |
class ColdStartOptimizer:
    """
    Optimizes cold start time by using persistent storage and efficient loading.
    """

    def __init__(self, persistent_dir: str = "/tmp/persistent_model_cache"):
        """
        Initialize the cold start optimizer.

        Args:
            persistent_dir: Directory for persistent storage (survives restarts)
        """
        self.persistent_dir = Path(persistent_dir)
        self.persistent_dir.mkdir(parents=True, exist_ok=True)

        # Hugging Face Spaces persistent paths
        self.hf_persistent_paths = [
            "/data",            # Primary persistent storage
            "/tmp/persistent",  # Fallback
        ]

        # In-memory model cache and per-model load times
        self.model_cache: Dict[str, Any] = {}
        self.load_times: Dict[str, float] = {}
    def get_persistent_path(self) -> Path:
        """
        Get the best available persistent path.

        Returns:
            Path to persistent storage
        """
        # Check Hugging Face Spaces persistent directories
        for path in self.hf_persistent_paths:
            if os.path.exists(path) and os.access(path, os.W_OK):
                return Path(path) / "model_cache"

        # Fall back to the configured directory
        return self.persistent_dir
    def setup_persistent_model_cache(self, source_dir: str) -> bool:
        """
        Set up the persistent model cache.

        Args:
            source_dir: Source directory containing models

        Returns:
            True if successful
        """
        persistent_path = self.get_persistent_path()
        persistent_path.mkdir(parents=True, exist_ok=True)

        source_path = Path(source_dir)
        if not source_path.exists():
            print(f"Source directory {source_dir} not found")
            return False

        # Copy models to persistent storage if not already there
        model_files = (
            list(source_path.glob("**/*.pth"))
            + list(source_path.glob("**/*.pkl"))
            + list(source_path.glob("**/*.onnx"))
            + list(source_path.glob("**/*.trt"))
        )

        copied = 0
        for model_file in model_files:
            relative_path = model_file.relative_to(source_path)
            target_path = persistent_path / relative_path

            if not target_path.exists():
                target_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(model_file, target_path)
                copied += 1
                print(f"Copied {relative_path} to persistent storage")

        print(f"Persistent cache setup complete. Copied {copied} new files.")
        return True
    def load_model_cached(
        self,
        model_path: str,
        load_func: Callable[[str], Any],
        cache_key: Optional[str] = None,
    ) -> Any:
        """
        Load a model with caching.

        Args:
            model_path: Path to the model file
            load_func: Function that loads the model from a given path
            cache_key: Optional cache key (defaults to model_path)

        Returns:
            Loaded model
        """
        cache_key = cache_key or model_path

        # Check the in-memory cache first
        if cache_key in self.model_cache:
            print(f"✅ Loaded {cache_key} from memory cache")
            return self.model_cache[cache_key]

        # Check persistent storage
        persistent_path = self.get_persistent_path()
        model_name = Path(model_path).name
        persistent_model_path = persistent_path / model_name

        start_time = time.time()

        if persistent_model_path.exists():
            # Load from persistent storage
            print(f"Loading {model_name} from persistent storage...")
            model = load_func(str(persistent_model_path))
        else:
            # Load from the original path
            print(f"Loading {model_name} from original location...")
            model = load_func(model_path)

            # Try to copy to persistent storage for the next cold start
            try:
                persistent_model_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(model_path, persistent_model_path)
                print(f"Cached {model_name} to persistent storage")
            except Exception as e:
                print(f"Warning: Could not cache to persistent storage: {e}")

        load_time = time.time() - start_time
        self.load_times[cache_key] = load_time

        # Cache in memory
        self.model_cache[cache_key] = model

        print(f"✅ Loaded {cache_key} in {load_time:.2f}s")
        return model
    def preload_models(self, model_configs: Dict[str, Dict[str, Any]]):
        """
        Preload multiple models in priority order (highest priority first).

        Args:
            model_configs: Dictionary of model configurations, e.g.
                {
                    'model_name': {
                        'path': 'path/to/model',
                        'load_func': callable,
                        'priority': int (0-10, higher loads first)
                    }
                }
        """
        # Sort by priority, highest first
        sorted_models = sorted(
            model_configs.items(),
            key=lambda x: x[1].get('priority', 5),
            reverse=True,
        )

        for model_name, config in sorted_models:
            try:
                self.load_model_cached(
                    config['path'],
                    config['load_func'],
                    cache_key=model_name,
                )
            except Exception as e:
                print(f"Error preloading {model_name}: {e}")
    def optimize_gradio_settings(self) -> Dict[str, Any]:
        """
        Get optimized Gradio launch settings for faster response.

        Returns:
            Gradio launch parameters
        """
        return {
            'max_threads': 40,         # Increase parallel request handling
            'show_error': True,
            'server_name': '0.0.0.0',
            'server_port': 7860,
            'share': False,            # Disable share link for faster startup
        }
    def get_optimization_stats(self) -> Dict[str, Any]:
        """
        Get cold start optimization statistics.

        Returns:
            Optimization statistics
        """
        persistent_path = self.get_persistent_path()

        # Count cached files and their total size
        cached_files = 0
        total_size = 0
        if persistent_path.exists():
            for file in persistent_path.rglob("*"):
                if file.is_file():
                    cached_files += 1
                    total_size += file.stat().st_size

        return {
            'persistent_path': str(persistent_path),
            'cached_models': len(self.model_cache),
            'cached_files': cached_files,
            'total_cache_size_mb': total_size / (1024 * 1024),
            'load_times': self.load_times,
            'average_load_time': (
                sum(self.load_times.values()) / len(self.load_times)
                if self.load_times else 0
            ),
        }
    def clear_memory_cache(self):
        """Clear the in-memory model cache and free GPU memory."""
        self.model_cache.clear()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print("Memory cache cleared")
    def setup_streaming_response(self) -> Dict[str, Any]:
        """
        Get configuration for streaming responses.

        Returns:
            Streaming configuration
        """
        return {
            'stream_output': True,
            'buffer_size': 8192,        # 8 KB buffer
            'chunk_size': 1024,         # 1 KB chunks
            'enable_compression': True,
            'compression_level': 6,     # Balanced speed/size trade-off
        }