Spaces:
Running
Running
| # config.py | |
| # Configuration management for Dia TTS server | |
| import os | |
| import logging | |
| from dotenv import load_dotenv, find_dotenv, set_key | |
| from typing import Dict, Any, Optional | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Built-in fallback values, used for any key that is not supplied by the .env
# file or the process environment. Every value is stored as a string, the same
# form in which values arrive from the environment.
DEFAULT_CONFIG = {
    # --- Server settings ---
    "HOST": "0.0.0.0",
    "PORT": "8003",
    # --- Model source settings ---
    # Repository to pull the model from (defaults to the safetensors repo).
    "DIA_MODEL_REPO_ID": "ttj/dia-1.6b-safetensors",
    # Standard config filename.
    "DIA_MODEL_CONFIG_FILENAME": "config.json",
    # Weights file (defaults to the BF16 weights).
    "DIA_MODEL_WEIGHTS_FILENAME": "dia-v0_1_bf16.safetensors",
    # --- Path settings ---
    "DIA_MODEL_CACHE_PATH": "./model_cache",
    "REFERENCE_AUDIO_PATH": "./reference_audio",
    "OUTPUT_PATH": "./outputs",
    # --- Default generation parameters ---
    # Users can override these in the UI/API; the UI's
    # "Save Generation Defaults" button writes them to .env.
    # The default speed is slightly slower than 1.0.
    "GEN_DEFAULT_SPEED_FACTOR": "0.90",
    "GEN_DEFAULT_CFG_SCALE": "3.0",
    "GEN_DEFAULT_TEMPERATURE": "1.3",
    "GEN_DEFAULT_TOP_P": "0.95",
    "GEN_DEFAULT_CFG_FILTER_TOP_K": "35",
}
class ConfigManager:
    """Manage the TTS server configuration, backed by a .env file.

    The current settings live in ``self.config``, a plain dict keyed by the
    names in ``DEFAULT_CONFIG``. Values loaded by ``reload`` come from the
    process environment (after the .env file has been loaded into it) or from
    ``DEFAULT_CONFIG``, so they are strings; use ``get_int`` / ``get_float``
    to read numeric settings.
    """

    def __init__(self):
        """Locate the .env file (creating a default one if needed) and load it."""
        self.config = {}
        self.env_file = find_dotenv()
        if not self.env_file:
            # Nothing was found: fall back to a .env in the current working
            # directory and seed it with the built-in defaults.
            self.env_file = os.path.join(os.getcwd(), ".env")
            logger.info(
                f"No .env file found, creating one with defaults at {self.env_file}"
            )
            self._create_default_env_file()
        else:
            logger.info(f"Loading configuration from: {self.env_file}")
        self.reload()

    def _create_default_env_file(self):
        """Write every ``DEFAULT_CONFIG`` entry to ``self.env_file`` as ``KEY=value``.

        I/O failures are logged rather than raised; ``reload`` still fills the
        configuration from ``DEFAULT_CONFIG`` when the file is missing.
        """
        lines = [f"{key}={value}\n" for key, value in DEFAULT_CONFIG.items()]
        try:
            # Explicit UTF-8 so the file does not depend on the platform's
            # default encoding.
            with open(self.env_file, "w", encoding="utf-8") as f:
                f.writelines(lines)
            logger.info("Created default .env file")
        except OSError as e:
            logger.error(f"Failed to create default .env file: {e}")

    def reload(self):
        """Reload the configuration from the .env file and the environment.

        Returns:
            The freshly built configuration dict (also stored in ``self.config``).
        """
        # override=True: values in the .env file replace variables that are
        # already set in the process environment.
        load_dotenv(self.env_file, override=True)
        self.config = {
            key: os.environ.get(key, default_value)
            for key, default_value in DEFAULT_CONFIG.items()
        }
        logger.info("Configuration loaded/reloaded.")
        logger.debug(f"Current config: {self.config}")
        return self.config

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value stored for ``key``, or ``default`` if it is absent."""
        return self.config.get(key, default)

    def set(self, key: str, value: Any) -> None:
        """Set a value in memory only; call ``save`` to persist it.

        Any key is accepted here, but ``save`` only writes keys that exist in
        ``DEFAULT_CONFIG``.
        """
        self.config[key] = value
        logger.debug(f"Configuration value set in memory: {key}={value}")

    def save(self) -> bool:
        """Persist the in-memory configuration to the .env file.

        Keys missing from the in-memory configuration are first filled in from
        ``DEFAULT_CONFIG``. Only keys known to ``DEFAULT_CONFIG`` are written.

        Returns:
            True on success, False if no file path is set or writing failed.
        """
        if not self.env_file:
            logger.error("Cannot save configuration, .env file path not set.")
            return False
        # Broad catch on purpose: set_key is third-party I/O and this method
        # reports failure through its return value instead of raising.
        try:
            for key in DEFAULT_CONFIG:
                if key not in self.config:
                    logger.warning(
                        f"Key '{key}' missing from current config, adding default value before saving."
                    )
                    self.config[key] = DEFAULT_CONFIG[key]
            for key, value in self.config.items():
                if key in DEFAULT_CONFIG:
                    set_key(self.env_file, key, str(value))
            logger.info(f"Configuration saved to {self.env_file}")
            return True
        except Exception as e:
            logger.error(
                f"Failed to save configuration to {self.env_file}: {e}", exc_info=True
            )
            return False

    def get_all(self) -> Dict[str, Any]:
        """Return a shallow copy of all current configuration values."""
        return self.config.copy()

    def update(self, new_config: Dict[str, Any]) -> None:
        """Update several values in memory from ``new_config``.

        Keys that are not in ``DEFAULT_CONFIG`` are ignored with a warning.
        """
        updated_keys = []
        for key, value in new_config.items():
            if key in DEFAULT_CONFIG:
                self.config[key] = value
                updated_keys.append(key)
            else:
                logger.warning(
                    f"Attempted to update unknown config key: {key}. Ignoring."
                )
        if updated_keys:
            logger.debug(
                f"Configuration values updated in memory for keys: {updated_keys}"
            )

    def _get_number(self, key: str, default: Any, caster: type, kind_label: str) -> Any:
        """Shared implementation behind ``get_int`` and ``get_float``.

        Args:
            key: Configuration key to read.
            default: Fallback supplied by the caller (may be None).
            caster: ``int`` or ``float``; converts the stored value and
                defines which type of ``default`` counts as valid.
            kind_label: Word for the target type used in the "invalid value"
                warning ("integer" or "float").

        Returns:
            The converted value; otherwise ``default`` when usable; otherwise
            the zero of the target type.
        """
        zero = caster(0)  # 0 for int, 0.0 for float; also shown in the log text
        raw_value = self.get(key)
        if raw_value is None:  # Key not found at all
            if default is not None:
                logger.warning(
                    f"Config key '{key}' not found, using provided default: {default}"
                )
                return default
            logger.error(
                f"Mandatory config key '{key}' not found and no default provided. Returning {zero}."
            )
            return zero
        try:
            return caster(raw_value)
        except (ValueError, TypeError, OverflowError):
            logger.warning(
                f"Invalid {kind_label} value '{raw_value}' for config key '{key}', using default: {default}"
            )
        if isinstance(default, caster):
            return default
        if default is None:
            logger.error(
                f"Cannot parse '{raw_value}' as {caster.__name__} for key '{key}' and no valid default. Returning {zero}."
            )
        else:  # A default was provided but is not of the target type
            logger.error(
                f"Invalid default value type for key '{key}'. Cannot parse '{raw_value}'. Returning {zero}."
            )
        return zero

    def get_int(self, key: str, default: Optional[int] = None) -> int:
        """Return the value for ``key`` as an integer.

        Falls back to ``default`` when the key is missing or its value cannot
        be parsed; returns 0 (and logs an error) when there is no usable default.
        """
        return self._get_number(key, default, int, "integer")

    def get_float(self, key: str, default: Optional[float] = None) -> float:
        """Return the value for ``key`` as a float.

        Falls back to ``default`` when the key is missing or its value cannot
        be parsed; returns 0.0 (and logs an error) when there is no usable
        default. An integer ``default`` is accepted and converted to float.
        """
        # A plain int is a perfectly good float default; bool is excluded
        # because it is an int subclass but not a meaningful number here.
        if isinstance(default, int) and not isinstance(default, bool):
            default = float(default)
        return self._get_number(key, default, float, "float")
# --- Create a singleton instance for global access ---
# Instantiating ConfigManager at import time runs its __init__, which locates
# the .env file (creating one with default values if none is found) and loads
# the configuration. The module-level getter functions read from this instance.
config_manager = ConfigManager()
| # --- Export common getters for easy access --- | |
| # Server Settings | |
def get_host() -> str:
    """Return the configured server host address."""
    fallback_host = DEFAULT_CONFIG["HOST"]
    return config_manager.get("HOST", fallback_host)
def get_port() -> int:
    """Return the configured server port as an integer."""
    # The built-in default is stored as a string; convert it first so get_int
    # has an integer fallback if the configured value cannot be parsed.
    fallback_port = int(DEFAULT_CONFIG["PORT"])
    return config_manager.get_int("PORT", fallback_port)
| # Model Source Settings | |
def get_model_repo_id() -> str:
    """Return the Hugging Face repository ID the model is loaded from."""
    fallback_repo = DEFAULT_CONFIG["DIA_MODEL_REPO_ID"]
    return config_manager.get("DIA_MODEL_REPO_ID", fallback_repo)
def get_model_config_filename() -> str:
    """Return the name of the model's configuration file inside the repo."""
    fallback_name = DEFAULT_CONFIG["DIA_MODEL_CONFIG_FILENAME"]
    return config_manager.get("DIA_MODEL_CONFIG_FILENAME", fallback_name)
def get_model_weights_filename() -> str:
    """Return the name of the model's weights file inside the repo."""
    fallback_name = DEFAULT_CONFIG["DIA_MODEL_WEIGHTS_FILENAME"]
    return config_manager.get("DIA_MODEL_WEIGHTS_FILENAME", fallback_name)
| # Path Settings | |
def get_model_cache_path() -> str:
    """Return the absolute path of the directory used to cache downloaded models."""
    configured_dir = config_manager.get(
        "DIA_MODEL_CACHE_PATH", DEFAULT_CONFIG["DIA_MODEL_CACHE_PATH"]
    )
    return os.path.abspath(configured_dir)
def get_reference_audio_path() -> str:
    """Return the absolute path of the directory holding reference audio for cloning."""
    configured_dir = config_manager.get(
        "REFERENCE_AUDIO_PATH", DEFAULT_CONFIG["REFERENCE_AUDIO_PATH"]
    )
    return os.path.abspath(configured_dir)
def get_output_path() -> str:
    """Return the absolute path of the directory where generated audio is saved."""
    configured_dir = config_manager.get("OUTPUT_PATH", DEFAULT_CONFIG["OUTPUT_PATH"])
    return os.path.abspath(configured_dir)
| # Default Generation Parameter Getters | |
def get_gen_default_speed_factor() -> float:
    """Return the default speed factor used for generation."""
    fallback = float(DEFAULT_CONFIG["GEN_DEFAULT_SPEED_FACTOR"])
    return config_manager.get_float("GEN_DEFAULT_SPEED_FACTOR", fallback)
def get_gen_default_cfg_scale() -> float:
    """Return the default CFG scale used for generation."""
    fallback = float(DEFAULT_CONFIG["GEN_DEFAULT_CFG_SCALE"])
    return config_manager.get_float("GEN_DEFAULT_CFG_SCALE", fallback)
def get_gen_default_temperature() -> float:
    """Return the default temperature used for generation."""
    fallback = float(DEFAULT_CONFIG["GEN_DEFAULT_TEMPERATURE"])
    return config_manager.get_float("GEN_DEFAULT_TEMPERATURE", fallback)
def get_gen_default_top_p() -> float:
    """Return the default top_p used for generation."""
    fallback = float(DEFAULT_CONFIG["GEN_DEFAULT_TOP_P"])
    return config_manager.get_float("GEN_DEFAULT_TOP_P", fallback)
def get_gen_default_cfg_filter_top_k() -> int:
    """Return the default CFG filter top_k used for generation."""
    fallback = int(DEFAULT_CONFIG["GEN_DEFAULT_CFG_FILTER_TOP_K"])
    return config_manager.get_int("GEN_DEFAULT_CFG_FILTER_TOP_K", fallback)