# Spaces: Sleeping / Sleeping  (hosting-page status text captured with this file)
"""
Configuration loader for the AI Assistant.

Provides ``ConfigLoader`` (settings read from a YAML file, with ``ASSISTANT_*``
environment-variable overrides) and ``get_config()``, which returns a lazily
created shared instance.
"""
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

import yaml

logger = logging.getLogger(__name__)
class ConfigLoader:
    """Loads and manages configuration from a YAML file.

    Settings are read from ``config.yaml`` (next to this module by default).
    If that file is missing, unreadable, or not a YAML mapping, the built-in
    defaults from ``_get_default_config()`` are used instead. Environment
    variables prefixed with ``ASSISTANT_`` are then applied on top.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Initialize the loader and load configuration immediately.

        Args:
            config_path: Path to the YAML file. Defaults to ``config.yaml`` in
                the directory containing this module. The value is passed to
                ``pathlib.Path``, so a ``Path`` works as well as a ``str``.
        """
        if config_path is None:
            config_path = Path(__file__).parent / "config.yaml"
        self.config_path = Path(config_path)
        self.config = self._load_config()
        # Environment variables take precedence over values from the file.
        self._apply_env_overrides()

    def _load_config(self) -> Dict[str, Any]:
        """Load configuration from the YAML file, falling back to defaults.

        Returns:
            The parsed top-level mapping. Returns ``_get_default_config()``
            instead when the file is missing, cannot be read or parsed, or
            does not contain a mapping. An empty file parses to ``None``,
            which counts as "not a mapping".
        """
        try:
            # YAML is UTF-8 by spec; don't depend on the platform locale encoding.
            with open(self.config_path, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f)
        except FileNotFoundError:
            logger.warning(f"Config file not found at {self.config_path}, using defaults")
            return self._get_default_config()
        except Exception as e:
            # Deliberate best-effort: any read/parse failure falls back to defaults.
            logger.error(f"Error loading config: {e}, using defaults")
            return self._get_default_config()
        if not isinstance(config, dict):
            # Every accessor (and the env-override logic) expects a mapping at
            # the top level; None, lists and scalars would break them.
            logger.error(
                f"Config file {self.config_path} does not contain a mapping "
                f"(got {type(config).__name__}), using defaults"
            )
            return self._get_default_config()
        logger.info(f"Loaded configuration from {self.config_path}")
        return config

    def _get_default_config(self) -> Dict[str, Any]:
        """Return the built-in default configuration.

        A new dict is built on every call, so callers may mutate the result.
        """
        return {
            "models": {
                "available": [
                    {"model_id": "gpt-4o-mini", "display_name": "GPT-4 Omni Mini", "max_tokens": 16384, "default": True},
                    {"model_id": "gpt-4o", "display_name": "GPT-4 Omni", "max_tokens": 128000},
                    {"model_id": "gpt-4", "display_name": "GPT-4", "max_tokens": 8192},
                    {"model_id": "gpt-3.5-turbo", "display_name": "GPT-3.5 Turbo", "max_tokens": 16384}
                ],
                "default_temperature": 0.0,
                "default_max_tokens": 4000
            },
            "products": {
                "available": [
                    {"id": "harmony", "display_name": "Harmony", "versions": ["1.8", "1.6", "1.5", "1.2"], "default_version": "1.8"},
                    {"id": "chorus", "display_name": "Chorus", "versions": ["1.1"], "default_version": "1.1"}
                ],
                "default_product": "harmony"
            },
            "rag": {
                "default_k": 5,
                "max_k": 10
            },
            "tools": {
                "max_tool_calls": 5,
                "tool_timeout": 30
            },
            "agents": {
                "default_agent": "document_reader"
            }
        }

    def _apply_env_overrides(self):
        """Apply ``ASSISTANT_*`` environment variables on top of the config.

        Example: ``ASSISTANT_RAG_DEFAULT_K=10`` sets ``rag.default_k`` to ``10``.
        ``_resolve_env_key`` maps the variable name to a config path, and
        ``_coerce_env_value`` converts the string value to a typed value.
        """
        prefix = "ASSISTANT_"
        # Sorted so precedence between overlapping variables does not depend
        # on the order of the process environment.
        for name, value in sorted(os.environ.items()):
            if not name.startswith(prefix):
                continue
            suffix = name[len(prefix):]
            if not suffix:
                continue  # a bare "ASSISTANT_" names no setting
            self._set_nested_value(self._resolve_env_key(suffix), value)

    def _resolve_env_key(self, suffix: str) -> str:
        """Map an env-var suffix such as ``RAG_DEFAULT_K`` to a dotted path.

        Underscores are ambiguous: they separate nesting levels but also occur
        inside key names (``default_k``). At each level, the longest run of
        tokens naming a key that already exists is taken. A key counts as
        existing if it is in the loaded config or in the built-in defaults.
        Any tokens that match nothing become one nesting level each, which
        preserves the legacy behaviour for unknown keys.

        Args:
            suffix: Variable name with the ``ASSISTANT_`` prefix removed.

        Returns:
            Lower-cased dotted path, e.g. ``"rag.default_k"``.
        """
        tokens = suffix.lower().split("_")
        known: List[Any] = [self.config, self._get_default_config()]
        parts: List[str] = []
        start = 0
        while start < len(tokens):
            known = [node for node in known if isinstance(node, dict)]
            match = None
            # Longest candidate first so "default_k" wins over "default".
            for end in range(len(tokens), start, -1):
                candidate = "_".join(tokens[start:end])
                if any(candidate in node for node in known):
                    match = (candidate, end)
                    break
            if match is None:
                break
            key, start = match
            parts.append(key)
            known = [node[key] for node in known if key in node]
        parts.extend(tokens[start:])
        return ".".join(parts)

    @staticmethod
    def _coerce_env_value(value: str) -> Any:
        """Convert an environment-variable string to a bool, int, or float.

        Conversion rules:

        - ``"true"`` / ``"false"`` in any case become booleans.
        - Optionally signed ASCII integers such as ``"-3"`` become ``int``.
        - Optionally signed decimals with a single dot such as ``"0.25"``
          become ``float``.
        - Anything else is returned unchanged.
        """
        def is_digits(text: str) -> bool:
            # ASCII only: str.isdigit() also accepts characters such as "²"
            # that int()/float() reject.
            return bool(text) and all(ch in "0123456789" for ch in text)

        lowered = value.lower()
        if lowered in ("true", "false"):
            return lowered == "true"
        unsigned = value[1:] if value[:1] in ("+", "-") else value
        if is_digits(unsigned):
            return int(value)
        whole, dot, fraction = unsigned.partition(".")
        if dot and is_digits(whole + fraction):
            return float(value)
        return value

    def _set_nested_value(self, path: str, value: str):
        """Set a nested configuration value using dot notation.

        Missing intermediate sections are created as dicts. If an intermediate
        key already holds a non-dict value (e.g. a list or a scalar), the
        override is skipped with a warning instead of raising.

        Args:
            path: Dotted key path such as ``"rag.default_k"``.
            value: Raw string value; converted by ``_coerce_env_value``.
        """
        keys = path.split('.')
        node = self.config
        for key in keys[:-1]:
            child = node.setdefault(key, {})
            if not isinstance(child, dict):
                logger.warning(
                    f"Cannot override config {path}: '{key}' is a "
                    f"{type(child).__name__}, not a section"
                )
                return
            node = child
        typed_value = self._coerce_env_value(value)
        node[keys[-1]] = typed_value
        # NOTE(review): the value is logged at DEBUG; mask it if secrets are
        # ever supplied through ASSISTANT_* variables.
        logger.debug(f"Override config {path} = {typed_value}")

    def get(self, path: str, default: Any = None) -> Any:
        """Get a configuration value using dot notation.

        Returns ``default`` if any segment of ``path`` is missing or the
        value at that point is not a dict.
        """
        keys = path.split('.')
        current = self.config
        for key in keys:
            if isinstance(current, dict) and key in current:
                current = current[key]
            else:
                return default
        return current

    def get_available_models(self) -> List[Dict[str, Any]]:
        """Get the list of configured model entries (``models.available``)."""
        return self.get("models.available", [])

    def get_model_ids(self) -> List[str]:
        """Get the ``model_id`` of every configured model, in order."""
        return [m["model_id"] for m in self.get_available_models()]

    def get_default_model(self) -> str:
        """Get the default model ID.

        Returns the first model flagged ``default: true``. Otherwise returns
        the first listed model, or ``"gpt-4o-mini"`` if no models are
        configured.
        """
        models = self.get_available_models()
        for model in models:
            if model.get("default", False):
                return model["model_id"]
        return models[0]["model_id"] if models else "gpt-4o-mini"

    def get_available_products(self) -> List[Dict[str, Any]]:
        """Get the list of configured product entries (``products.available``)."""
        return self.get("products.available", [])

    def get_product_ids(self) -> List[str]:
        """Get the ``id`` of every configured product, in order."""
        return [p["id"] for p in self.get_available_products()]

    def _find_product(self, product_id: str) -> Optional[Dict[str, Any]]:
        """Return the first product entry whose ``id`` equals ``product_id``, else None."""
        for product in self.get_available_products():
            if isinstance(product, dict) and product.get("id") == product_id:
                return product
        return None

    def get_product_versions(self, product_id: str) -> List[str]:
        """Get the versions listed for a product, or ``[]`` if it is unknown."""
        product = self._find_product(product_id)
        if product is None:
            return []
        return product.get("versions", [])

    def get_default_product(self) -> str:
        """Get the default product ID (``"harmony"`` if not configured)."""
        return self.get("products.default_product", "harmony")

    def get_default_version(self, product_id: str) -> str:
        """Get the default version for a product.

        Uses the product's ``default_version`` if present, otherwise its first
        listed version. Returns ``"1.0"`` for an unknown product, or for one
        that has neither a default nor any versions.
        """
        product = self._find_product(product_id)
        if product is None:
            return "1.0"
        if "default_version" in product:
            return product["default_version"]
        # Only consult "versions" when needed, so a product that declares a
        # default_version without a versions list never raises.
        versions = product.get("versions") or []
        return versions[0] if versions else "1.0"

    def get_rag_k(self) -> int:
        """Get the default number of RAG results (``rag.default_k``, else 5)."""
        return self.get("rag.default_k", 5)

    def get_max_tool_calls(self) -> int:
        """Get the maximum number of tool calls (``tools.max_tool_calls``, else 5)."""
        return self.get("tools.max_tool_calls", 5)

    def get_agent_config(self, agent_id: str) -> Dict[str, Any]:
        """Get the configuration section ``agents.<agent_id>`` (``{}`` if absent)."""
        return self.get(f"agents.{agent_id}", {})

    def is_agent_enabled(self, agent_id: str) -> bool:
        """Check ``agents.<agent_id>.enabled``; agents are enabled unless set otherwise."""
        return self.get(f"agents.{agent_id}.enabled", True)

    def get_ui_config(self) -> Dict[str, Any]:
        """Get the ``ui`` section (``{}`` if absent)."""
        return self.get("ui", {})

    def get_logging_config(self) -> Dict[str, Any]:
        """Get the ``logging`` section (``{}`` if absent)."""
        return self.get("logging", {})

    def reload(self):
        """Re-read the YAML file and re-apply environment overrides."""
        self.config = self._load_config()
        self._apply_env_overrides()
        logger.info("Configuration reloaded")

    def save(self, config_path: Optional[str] = None):
        """Write the current configuration to a YAML file.

        Args:
            config_path: Destination path; defaults to the path the
                configuration was loaded from.

        Raises:
            Exception: Any error from opening or writing the file is logged
                and re-raised.

        Note:
            ``self.config`` already contains environment overrides, so those
            values are persisted as well.
        """
        save_path = config_path or self.config_path
        try:
            with open(save_path, 'w', encoding='utf-8') as f:
                yaml.dump(self.config, f, default_flow_style=False, sort_keys=False)
            logger.info(f"Configuration saved to {save_path}")
        except Exception as e:
            logger.error(f"Error saving configuration: {e}")
            raise
# Shared ConfigLoader, created on the first get_config() call.
_config = None


def get_config() -> ConfigLoader:
    """Return the process-wide ConfigLoader, creating it on first use.

    The first call constructs ``ConfigLoader()`` with its default path;
    every later call returns that same object.
    NOTE(review): there is no locking, so two threads making the first call
    at the same time could each build a loader — confirm callers.
    """
    global _config
    if _config is not None:
        return _config
    _config = ConfigLoader()
    return _config