# v3_ai_assistant/py/config_loader.py
# Julian Vanecek
# Initial commit: AI Assistant Multi-Agent System for HuggingFace Spaces
# bb80caa
"""
Configuration loader for the AI Assistant
"""
import yaml
import os
import logging
from typing import Dict, Any, Optional, List
from pathlib import Path
logger = logging.getLogger(__name__)
class ConfigLoader:
    """Load and manage configuration from a YAML file.

    Values are read from ``config.yaml`` (next to this module by default).
    The built-in defaults are used when the file is missing, cannot be
    read/parsed, is empty, or does not contain a mapping. Environment
    variables prefixed with ``ASSISTANT_`` are then applied as overrides.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Initialize the loader.

        Args:
            config_path: Path to the YAML config file. Defaults to
                ``config.yaml`` in the same directory as this module.
        """
        if config_path is None:
            config_path = Path(__file__).parent / "config.yaml"
        self.config_path = Path(config_path)
        self.config = self._load_config()
        # Environment variables take precedence over values from the file.
        self._apply_env_overrides()

    def _load_config(self) -> Dict[str, Any]:
        """Load the configuration mapping from ``self.config_path``.

        Returns:
            The parsed mapping. The built-in defaults are returned when the
            file is missing, cannot be read or parsed, is empty, or its top
            level is not a mapping.
        """
        try:
            with open(self.config_path, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f)
        except FileNotFoundError:
            logger.warning(f"Config file not found at {self.config_path}, using defaults")
            return self._get_default_config()
        except Exception as e:
            # Deliberate best-effort: any read/parse failure falls back to the
            # defaults instead of preventing the application from starting.
            logger.error(f"Error loading config: {e}, using defaults")
            return self._get_default_config()

        if not isinstance(config, dict):
            # safe_load yields None for an empty file and a list/scalar for a
            # non-mapping document; everything else in this class needs a dict.
            logger.warning(
                f"Config file {self.config_path} does not contain a mapping, using defaults"
            )
            return self._get_default_config()

        logger.info(f"Loaded configuration from {self.config_path}")
        return config

    def _get_default_config(self) -> Dict[str, Any]:
        """Return the built-in default configuration.

        Used when the config file is missing, unreadable, or not a mapping.
        A fresh dict is built on every call, so callers may mutate it.
        """
        return {
            "models": {
                "available": [
                    {"model_id": "gpt-4o-mini", "display_name": "GPT-4 Omni Mini", "max_tokens": 16384, "default": True},
                    {"model_id": "gpt-4o", "display_name": "GPT-4 Omni", "max_tokens": 128000},
                    {"model_id": "gpt-4", "display_name": "GPT-4", "max_tokens": 8192},
                    {"model_id": "gpt-3.5-turbo", "display_name": "GPT-3.5 Turbo", "max_tokens": 16384}
                ],
                "default_temperature": 0.0,
                "default_max_tokens": 4000
            },
            "products": {
                "available": [
                    {"id": "harmony", "display_name": "Harmony", "versions": ["1.8", "1.6", "1.5", "1.2"], "default_version": "1.8"},
                    {"id": "chorus", "display_name": "Chorus", "versions": ["1.1"], "default_version": "1.1"}
                ],
                "default_product": "harmony"
            },
            "rag": {
                "default_k": 5,
                "max_k": 10
            },
            "tools": {
                "max_tool_calls": 5,
                "tool_timeout": 30
            },
            "agents": {
                "default_agent": "document_reader"
            }
        }

    def _apply_env_overrides(self):
        """Apply ``ASSISTANT_*`` environment variables on top of the config.

        The part of the name after the prefix is lower-cased and split on
        ``_``. The pieces are then matched greedily against keys that already
        exist in the config, so underscores inside key names are preserved.
        For example, ``ASSISTANT_RAG_DEFAULT_K=10`` overrides
        ``rag.default_k``. Pieces that match no existing key each become one
        nesting level.
        """
        prefix = "ASSISTANT_"
        for key, value in os.environ.items():
            if not key.startswith(prefix):
                continue
            tokens = [t for t in key[len(prefix):].lower().split('_') if t]
            if not tokens:
                # Nothing after the prefix, so there is no key to override.
                continue
            self._set_nested_value('.'.join(self._resolve_env_keys(tokens)), value)

    def _resolve_env_keys(self, tokens: List[str]) -> List[str]:
        """Group lower-cased env-var name pieces into config keys.

        At each nesting level, take the longest run of remaining pieces that,
        joined with ``_``, names an existing key. When nothing matches, or the
        current value is not a mapping, a single piece is used as the key.
        """
        keys: List[str] = []
        node: Any = self.config
        i = 0
        while i < len(tokens):
            size = 1
            if isinstance(node, dict):
                # Longest match first so "default_k" wins over "default".
                for j in range(len(tokens), i, -1):
                    if '_'.join(tokens[i:j]) in node:
                        size = j - i
                        break
            key = '_'.join(tokens[i:i + size])
            keys.append(key)
            node = node.get(key) if isinstance(node, dict) else None
            i += size
        return keys

    def _set_nested_value(self, path: str, value: str):
        """Set a config value addressed by a dot-separated ``path``.

        Missing intermediate mappings, or ones that are ``None`` such as an
        empty YAML section, are created. If an intermediate key holds any
        other non-mapping value, the override is skipped with a warning
        instead of raising. ``value`` is coerced by ``_coerce_env_value``.
        """
        keys = path.split('.')
        current = self.config
        for key in keys[:-1]:
            if current.get(key) is None:
                current[key] = {}
            current = current[key]
            if not isinstance(current, dict):
                logger.warning(f"Cannot override config {path}: '{key}' is not a mapping")
                return
        coerced = self._coerce_env_value(value)
        current[keys[-1]] = coerced
        logger.debug(f"Override config {path} = {coerced}")

    @staticmethod
    def _coerce_env_value(value: str) -> Any:
        """Convert an environment-variable string to bool/int/float.

        ``true``/``false`` (any case) become bool. Integer literals, including
        negative ones, become int. Literals containing ``.`` that ``float()``
        accepts become float. Anything else is returned unchanged. Note that
        version-like strings such as ``"1.8"`` become floats, as before.
        """
        text = value.strip()
        lowered = text.lower()
        if lowered in ('true', 'false'):
            return lowered == 'true'
        try:
            return int(text)
        except ValueError:
            pass
        # Requiring a '.' keeps words such as "nan"/"inf" as strings.
        if '.' in text:
            try:
                return float(text)
            except ValueError:
                pass
        return value

    def get(self, path: str, default: Any = None) -> Any:
        """Return the value at dot-separated ``path``, or ``default`` if absent.

        Traversal stops with ``default`` as soon as a segment is missing or
        the current value is not a dict.
        """
        keys = path.split('.')
        current = self.config
        for key in keys:
            if isinstance(current, dict) and key in current:
                current = current[key]
            else:
                return default
        return current

    def get_available_models(self) -> List[Dict[str, Any]]:
        """Return the ``models.available`` list (empty if not configured)."""
        return self.get("models.available", [])

    def get_model_ids(self) -> List[str]:
        """Return the ``model_id`` of every available model."""
        return [m["model_id"] for m in self.get_available_models()]

    def get_default_model(self) -> str:
        """Return the model flagged ``default``, else the first, else ``gpt-4o-mini``."""
        models = self.get_available_models()
        for model in models:
            if model.get("default", False):
                return model["model_id"]
        return models[0]["model_id"] if models else "gpt-4o-mini"

    def get_available_products(self) -> List[Dict[str, Any]]:
        """Return the ``products.available`` list (empty if not configured)."""
        return self.get("products.available", [])

    def get_product_ids(self) -> List[str]:
        """Return the ``id`` of every available product."""
        return [p["id"] for p in self.get_available_products()]

    def get_product_versions(self, product_id: str) -> List[str]:
        """Return the versions listed for ``product_id`` (empty if unknown)."""
        for product in self.get_available_products():
            if product.get("id") == product_id:
                return product.get("versions", [])
        return []

    def get_default_version(self, product_id: str) -> str:
        """Return the default version for ``product_id``.

        Uses ``default_version`` when present, else the first entry of
        ``versions``. Returns ``"1.0"`` when the product is unknown or
        lists no versions.
        """
        for product in self.get_available_products():
            if product.get("id") != product_id:
                continue
            if "default_version" in product:
                return product["default_version"]
            # Only look at versions when needed; they may be missing or empty.
            versions = product.get("versions") or []
            return versions[0] if versions else "1.0"
        return "1.0"

    def get_default_product(self) -> str:
        """Return ``products.default_product`` (``harmony`` if unset)."""
        return self.get("products.default_product", "harmony")

    def get_rag_k(self) -> int:
        """Return ``rag.default_k``, the default number of RAG results (5 if unset)."""
        return self.get("rag.default_k", 5)

    def get_max_tool_calls(self) -> int:
        """Return ``tools.max_tool_calls`` (5 if unset)."""
        return self.get("tools.max_tool_calls", 5)

    def get_agent_config(self, agent_id: str) -> Dict[str, Any]:
        """Return the ``agents.<agent_id>`` section (empty dict if absent)."""
        return self.get(f"agents.{agent_id}", {})

    def is_agent_enabled(self, agent_id: str) -> bool:
        """Return ``agents.<agent_id>.enabled``; agents are enabled unless set otherwise."""
        return self.get(f"agents.{agent_id}.enabled", True)

    def get_ui_config(self) -> Dict[str, Any]:
        """Return the ``ui`` section (empty dict if absent)."""
        return self.get("ui", {})

    def get_logging_config(self) -> Dict[str, Any]:
        """Return the ``logging`` section (empty dict if absent)."""
        return self.get("logging", {})

    def reload(self):
        """Re-read the config file and re-apply environment overrides."""
        self.config = self._load_config()
        self._apply_env_overrides()
        logger.info("Configuration reloaded")

    def save(self, config_path: Optional[str] = None):
        """Write the current configuration as YAML.

        Args:
            config_path: Destination path; defaults to ``self.config_path``.

        Raises:
            Exception: Any error from opening or writing the file is logged
                and re-raised.
        """
        save_path = config_path or self.config_path
        try:
            with open(save_path, 'w', encoding='utf-8') as f:
                yaml.dump(self.config, f, default_flow_style=False, sort_keys=False)
            logger.info(f"Configuration saved to {save_path}")
        except Exception as e:
            logger.error(f"Error saving configuration: {e}")
            raise
# Shared configuration object, created on first call to get_config().
_config = None


def get_config() -> ConfigLoader:
    """Return the process-wide ConfigLoader, constructing it lazily on first use."""
    global _config
    if _config is not None:
        return _config
    _config = ConfigLoader()
    return _config