# NOTE(review): the original capture of this file carried a Hugging Face
# Spaces "Runtime error" banner here; it is preserved as a comment so the
# module parses as valid Python.
"""
Configuration Module for TTS Pipeline
=====================================

Centralized configuration management for all pipeline components.
"""
| import os | |
| from dataclasses import dataclass | |
| from typing import Optional, Dict, Any | |
| import torch | |
@dataclass
class TextProcessingConfig:
    """Configuration for text processing components.

    Fix: the ``@dataclass`` decorator was missing even though ``dataclass``
    is imported at the top of the file and ``ConfigManager`` instantiates
    this class with keyword arguments; without the decorator that call
    raises ``TypeError`` at import time.
    """

    max_chunk_length: int = 200  # max length per text chunk — presumably characters (TODO confirm)
    overlap_words: int = 5  # words shared between consecutive chunks
    translation_timeout: int = 10  # presumably seconds — confirm against the translation client
    enable_caching: bool = True
    cache_size: int = 1000  # max cached entries
@dataclass
class ModelConfig:
    """Configuration for TTS model components.

    Fix: added the missing ``@dataclass`` decorator — ``ConfigManager``
    constructs this class with keyword arguments, which fails on a plain
    class.
    """

    checkpoint: str = "Edmon02/TTS_NB_2"  # main TTS model checkpoint
    vocoder_checkpoint: str = "microsoft/speecht5_hifigan"
    device: Optional[str] = None  # None means auto-detect at load time
    use_mixed_precision: bool = True
    cache_embeddings: bool = True
    max_text_positions: int = 600  # presumably the model's max input positions — TODO confirm
@dataclass
class AudioProcessingConfig:
    """Configuration for audio processing components.

    Fix: added the missing ``@dataclass`` decorator so keyword-argument
    construction in ``ConfigManager`` works.
    """

    crossfade_duration: float = 0.1  # presumably seconds — TODO confirm
    sample_rate: int = 16000  # Hz
    apply_noise_gate: bool = True
    normalize_audio: bool = True
    noise_gate_threshold_db: float = -40.0  # dB below which audio is gated
    target_peak: float = 0.95  # normalization target peak amplitude
@dataclass
class PipelineConfig:
    """Main pipeline configuration.

    Fix: added the missing ``@dataclass`` decorator so keyword-argument
    construction in ``ConfigManager`` works.
    """

    enable_chunking: bool = True
    apply_audio_processing: bool = True
    enable_performance_tracking: bool = True
    max_concurrent_requests: int = 5
    warmup_on_init: bool = True
@dataclass
class DeploymentConfig:
    """Deployment-specific configuration.

    Fix: added the missing ``@dataclass`` decorator so keyword-argument
    construction in ``ConfigManager`` works.
    """

    environment: str = "production"  # one of: development, staging, production
    log_level: str = "INFO"
    enable_health_checks: bool = True
    max_memory_mb: int = 2000  # soft memory budget in MB
    gpu_memory_fraction: float = 0.8  # fraction of GPU memory to allow
class ConfigManager:
    """Centralized configuration manager.

    Builds one config object per section (text processing, model, audio
    processing, pipeline, deployment), choosing defaults by the
    ``environment`` name given at construction time. Unknown environment
    names fall back to the production profile.
    """

    def __init__(self, environment: str = "production"):
        self.environment = environment
        self._load_environment_config()

    def _load_environment_config(self):
        """Dispatch to the loader matching ``self.environment``."""
        loaders = {
            "development": self._load_dev_config,
            "staging": self._load_staging_config,
        }
        # Anything unrecognized (including "production") loads production.
        loaders.get(self.environment, self._load_production_config)()

    def _load_production_config(self):
        """Production environment configuration."""
        self.text_processing = TextProcessingConfig(
            max_chunk_length=200,
            overlap_words=5,
            translation_timeout=10,
            enable_caching=True,
            cache_size=1000,
        )
        self.model = ModelConfig(
            device=self._auto_detect_device(),
            use_mixed_precision=torch.cuda.is_available(),
            cache_embeddings=True,
        )
        self.audio_processing = AudioProcessingConfig(
            crossfade_duration=0.1,
            apply_noise_gate=True,
            normalize_audio=True,
        )
        self.pipeline = PipelineConfig(
            enable_chunking=True,
            apply_audio_processing=True,
            enable_performance_tracking=True,
            max_concurrent_requests=5,
        )
        self.deployment = DeploymentConfig(
            environment="production",
            log_level="INFO",
            enable_health_checks=True,
            max_memory_mb=2000,
        )

    def _load_dev_config(self):
        """Development environment configuration."""
        # Smaller chunks, shorter timeout and a tiny cache keep dev runs fast.
        self.text_processing = TextProcessingConfig(
            max_chunk_length=100,
            translation_timeout=5,
            cache_size=100,
        )
        # Force CPU so dev results are consistent across machines.
        self.model = ModelConfig(device="cpu", use_mixed_precision=False)
        self.audio_processing = AudioProcessingConfig(crossfade_duration=0.05)
        self.pipeline = PipelineConfig(max_concurrent_requests=2)
        self.deployment = DeploymentConfig(
            environment="development",
            log_level="DEBUG",
            max_memory_mb=1000,
        )

    def _load_staging_config(self):
        """Staging: the production profile with more logging, smaller limits."""
        self._load_production_config()
        self.deployment.log_level = "DEBUG"
        self.deployment.max_memory_mb = 1500
        self.pipeline.max_concurrent_requests = 3

    def _auto_detect_device(self) -> str:
        """Pick the best available torch device: cuda > mps > cpu."""
        if torch.cuda.is_available():
            return "cuda"
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"  # Apple Silicon
        return "cpu"

    def get_all_config(self) -> Dict[str, Any]:
        """Return every configuration section as a plain nested dict."""
        section_names = (
            "text_processing",
            "model",
            "audio_processing",
            "pipeline",
            "deployment",
        )
        return {name: getattr(self, name).__dict__ for name in section_names}

    def update_from_env(self):
        """Apply overrides from TTS_* environment variables.

        Unset or empty variables are ignored. Integer/float values are
        converted with int()/float(); TTS_USE_MIXED_PRECISION is true only
        for the (case-insensitive) string "true".
        """
        int_targets = (
            ("TTS_MAX_CHUNK_LENGTH", "text_processing", "max_chunk_length"),
            ("TTS_TRANSLATION_TIMEOUT", "text_processing", "translation_timeout"),
            ("TTS_MAX_CONCURRENT", "pipeline", "max_concurrent_requests"),
            ("TTS_MAX_MEMORY_MB", "deployment", "max_memory_mb"),
        )
        for var, section, attr in int_targets:
            raw = os.getenv(var)
            if raw:
                setattr(getattr(self, section), attr, int(raw))

        str_targets = (
            ("TTS_MODEL_CHECKPOINT", "model", "checkpoint"),
            ("TTS_DEVICE", "model", "device"),
            ("TTS_LOG_LEVEL", "deployment", "log_level"),
        )
        for var, section, attr in str_targets:
            raw = os.getenv(var)
            if raw:
                setattr(getattr(self, section), attr, raw)

        raw = os.getenv("TTS_CROSSFADE_DURATION")
        if raw:
            self.audio_processing.crossfade_duration = float(raw)

        raw = os.getenv("TTS_USE_MIXED_PRECISION")
        if raw:
            self.model.use_mixed_precision = raw.lower() == "true"
# Global config instance
# Created at import time with the default ("production") environment.
config = ConfigManager()
# Environment variable overrides: TTS_* vars take precedence over defaults.
config.update_from_env()
def get_config() -> ConfigManager:
    """Return the shared, module-level :class:`ConfigManager` instance."""
    return config
def update_config(environment: str):
    """Rebuild the global config for the given environment and return it.

    The new manager is published to the module-level ``config`` before
    environment-variable overrides are applied, mirroring the module's
    import-time sequence.
    """
    global config
    manager = ConfigManager(environment)
    config = manager
    manager.update_from_env()
    return manager