Spaces:
Sleeping
Sleeping
| """Configuration management for datasets and models""" | |
from __future__ import annotations

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional

import yaml
@dataclass
class DatasetConfig:
    """Dataset configuration, built from one entry of the ``datasets`` YAML section.

    Attributes:
        name: Dataset name.
        description: Human-readable description.
        path: Dataset path as written in the configuration file.
        huggingface_id: Hugging Face Hub identifier of the dataset.
        split: Split mapping (``Dict[str, str]``) copied verbatim from the entry.
        cache_dir: Cache directory as written in the configuration file.
    """

    name: str
    description: str
    path: str
    huggingface_id: str
    split: Dict[str, str]
    cache_dir: str

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> DatasetConfig:
        """Create a DatasetConfig from a raw configuration mapping.

        Args:
            data: Mapping that must contain every field name of this class as a key.

        Returns:
            A new DatasetConfig populated from ``data``.

        Raises:
            KeyError: If any required key is missing from ``data``.
        """
        return cls(
            name=data["name"],
            description=data["description"],
            path=data["path"],
            huggingface_id=data["huggingface_id"],
            split=data["split"],
            cache_dir=data["cache_dir"],
        )
@dataclass
class ModelConfig:
    """Model configuration, built from one entry of the ``models`` YAML section.

    Only ``name``, ``description`` and ``model_type`` are required; every other
    field defaults to None, except ``config``, which defaults to an empty dict.

    Attributes:
        name: Model name.
        description: Human-readable description.
        model_type: Model type string as written in the configuration file.
        path: Optional model path.
        huggingface_id: Optional Hugging Face Hub identifier.
        model_id: Optional model identifier.
        cache_dir: Optional cache directory.
        tokenizer_path: Optional tokenizer path.
        config: Extra model-specific settings; always a dict (empty when absent).
    """

    name: str
    description: str
    model_type: str
    path: Optional[str] = None
    huggingface_id: Optional[str] = None
    model_id: Optional[str] = None
    cache_dir: Optional[str] = None
    tokenizer_path: Optional[str] = None
    # default_factory gives each instance its own dict; the old ``= None``
    # default contradicted the Dict annotation.
    config: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> ModelConfig:
        """Create a ModelConfig from a raw configuration mapping.

        Args:
            data: Mapping with at least ``name``, ``description`` and
                ``model_type`` keys; all other keys are optional.

        Returns:
            A new ModelConfig populated from ``data``.

        Raises:
            KeyError: If a required key is missing from ``data``.
        """
        return cls(
            name=data["name"],
            description=data["description"],
            model_type=data["model_type"],
            path=data.get("path"),
            huggingface_id=data.get("huggingface_id"),
            model_id=data.get("model_id"),
            cache_dir=data.get("cache_dir"),
            tokenizer_path=data.get("tokenizer_path"),
            # ``or {}`` also covers an explicit ``config:`` with no value,
            # which YAML loads as None.
            config=data.get("config") or {},
        )
class Config:
    """Main configuration object backed by a single YAML file.

    The file is read once, at construction time. The top-level sections read by
    this class are ``datasets``, ``models``, ``output`` and ``settings``. A
    section that is missing or empty (null) is treated as an empty mapping.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Initialize configuration.

        Args:
            config_path: Path to the configuration YAML file. If None, uses
                ``config/paths.yaml`` under the project root (the parent of
                the directory containing this module).

        Raises:
            FileNotFoundError: If the configuration file does not exist.
            ValueError: If the YAML document's top level is not a mapping.
        """
        if config_path is None:
            config_path = self._project_root() / "config" / "paths.yaml"
        self.config_path = Path(config_path)
        self._config = self._load_config()

    @staticmethod
    def _project_root() -> Path:
        """Return the project root: the parent of this module's directory."""
        return Path(__file__).parent.parent

    def _load_config(self) -> Dict[str, Any]:
        """Load and parse the YAML configuration file.

        Returns:
            The parsed top-level mapping, or an empty dict for an empty file.

        Raises:
            FileNotFoundError: If ``self.config_path`` does not exist.
            ValueError: If the document's top level is not a mapping.
        """
        if not self.config_path.exists():
            raise FileNotFoundError(f"Configuration file not found: {self.config_path}")
        # YAML is UTF-8 by spec; don't depend on the platform's default encoding.
        with open(self.config_path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        # safe_load returns None for an empty document; treat it as "no sections"
        # instead of failing later with "'NoneType' object is not subscriptable".
        if data is None:
            return {}
        if not isinstance(data, dict):
            raise ValueError(
                f"Configuration file {self.config_path} must contain a mapping at "
                f"the top level, got {type(data).__name__}"
            )
        return data

    def _section(self, name: str) -> Dict[str, Any]:
        """Return top-level section ``name``, or an empty dict if missing or null."""
        return self._config.get(name) or {}

    def get_dataset_config(self, dataset_name: str) -> DatasetConfig:
        """Get dataset configuration by name.

        Args:
            dataset_name: Name of the dataset (e.g., 'gsm8k').

        Returns:
            DatasetConfig object.

        Raises:
            KeyError: If the dataset is not found in the configuration.
        """
        datasets = self._section("datasets")
        if dataset_name not in datasets:
            available = list(datasets.keys())
            raise KeyError(
                f"Dataset '{dataset_name}' not found in configuration. "
                f"Available datasets: {available}"
            )
        return DatasetConfig.from_dict(datasets[dataset_name])

    def get_model_config(self, model_name: str) -> ModelConfig:
        """Get model configuration by name.

        Args:
            model_name: Name of the model (e.g., 'llama-3.1-8b').

        Returns:
            ModelConfig object.

        Raises:
            KeyError: If the model is not found in the configuration.
        """
        models = self._section("models")
        if model_name not in models:
            available = list(models.keys())
            raise KeyError(
                f"Model '{model_name}' not found in configuration. "
                f"Available models: {available}"
            )
        return ModelConfig.from_dict(models[model_name])

    def get_output_dir(self, output_type: str) -> Path:
        """Get an output directory path, creating it if necessary.

        The configured entry is joined to the project root with ``/``. The
        directory, including any missing parents, is created if it does not
        exist.

        Args:
            output_type: Key under the ``output`` section (e.g., 'results',
                'trajectories', 'logs').

        Returns:
            Path to the output directory.

        Raises:
            KeyError: If ``output_type`` is not found in the ``output`` section.
        """
        outputs = self._section("output")
        if output_type not in outputs:
            available = list(outputs.keys())
            raise KeyError(
                f"Output type '{output_type}' not found in configuration. "
                f"Available output types: {available}"
            )
        output_path = self._project_root() / outputs[output_type]
        output_path.mkdir(parents=True, exist_ok=True)
        return output_path

    def get_settings(self) -> Dict[str, Any]:
        """Get general settings (the ``settings`` section; empty if absent)."""
        return self._section("settings")

    def list_datasets(self) -> list[str]:
        """List the names of all configured datasets."""
        return list(self._section("datasets").keys())

    def list_models(self) -> list[str]:
        """List the names of all configured models."""
        return list(self._section("models").keys())
# Process-wide Config instance, created lazily by load_config()
_global_config: Optional[Config] = None


def load_config(config_path: Optional[str] = None) -> Config:
    """Return the shared Config, building a new one when required.

    A fresh Config is constructed on the very first call. A fresh Config is also
    constructed on any call that passes an explicit ``config_path``, and it
    replaces the cached instance. Otherwise the cached instance is returned as-is.

    Args:
        config_path: Optional path to a configuration file. None means "reuse
            the cached instance", or the default path if nothing is cached yet.

    Returns:
        The shared Config instance.
    """
    global _global_config
    current = _global_config
    if current is None or config_path is not None:
        current = Config(config_path)
        _global_config = current
    return current
def get_dataset_config(dataset_name: str) -> DatasetConfig:
    """Look up a dataset entry in the shared configuration.

    Shortcut for ``load_config().get_dataset_config(dataset_name)``.

    Args:
        dataset_name: Name of the dataset to look up.

    Returns:
        The matching DatasetConfig.
    """
    return load_config().get_dataset_config(dataset_name)
def get_model_config(model_name: str) -> ModelConfig:
    """Look up a model entry in the shared configuration.

    Shortcut for ``load_config().get_model_config(model_name)``.

    Args:
        model_name: Name of the model to look up.

    Returns:
        The matching ModelConfig.
    """
    return load_config().get_model_config(model_name)