# Mustafa Tag Eldeen
# HF Space: Reasoning Trajectory Demo
# 8675765
"""Configuration management for datasets and models"""
from __future__ import annotations

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional

import yaml
@dataclass
class DatasetConfig:
    """Description of one dataset entry in the configuration file.

    Every field is required when the object is built through ``from_dict``.
    """

    # Dataset name as given in the configuration.
    name: str
    # Free-text description of the dataset.
    description: str
    # Path value as given in the configuration.
    path: str
    # Hugging Face identifier string as given in the configuration.
    huggingface_id: str
    # Split mapping, taken verbatim from the configuration.
    split: Dict[str, str]
    # Cache directory value as given in the configuration.
    cache_dir: str

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> DatasetConfig:
        """Build a DatasetConfig from a parsed configuration mapping.

        Args:
            data: Mapping holding every dataset field; extra keys are ignored.

        Returns:
            A new DatasetConfig populated from ``data``.

        Raises:
            KeyError: If a required field is absent (the first missing key in
                declaration order is reported).
        """
        field_names = (
            "name",
            "description",
            "path",
            "huggingface_id",
            "split",
            "cache_dir",
        )
        values = {key: data[key] for key in field_names}
        return cls(**values)
@dataclass
class ModelConfig:
    """Description of one model entry in the configuration file.

    ``name``, ``description`` and ``model_type`` are required; every other
    field is optional.
    """

    # Model name as given in the configuration.
    name: str
    # Free-text description of the model.
    description: str
    # Model type string as given in the configuration.
    model_type: str
    path: Optional[str] = None
    huggingface_id: Optional[str] = None
    model_id: Optional[str] = None
    cache_dir: Optional[str] = None
    tokenizer_path: Optional[str] = None
    # Extra model options. Each instance gets its own empty dict by default
    # (previously the default was None, contradicting the annotation).
    config: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> ModelConfig:
        """Create a ModelConfig from a parsed configuration mapping.

        Args:
            data: Mapping with the required keys ``name``, ``description`` and
                ``model_type``; optional keys default to None (``config``
                defaults to an empty dict).

        Returns:
            A new ModelConfig populated from ``data``.

        Raises:
            KeyError: If a required key is missing.
        """
        return cls(
            name=data["name"],
            description=data["description"],
            model_type=data["model_type"],
            path=data.get("path"),
            huggingface_id=data.get("huggingface_id"),
            model_id=data.get("model_id"),
            cache_dir=data.get("cache_dir"),
            tokenizer_path=data.get("tokenizer_path"),
            # An empty ``config:`` key in YAML loads as None; normalize to {}.
            config=data.get("config") or {},
        )
class Config:
    """Main configuration class.

    Reads a YAML file and exposes accessors for its top-level sections
    ``datasets``, ``models``, ``output`` and ``settings``. A section that is
    missing or empty is treated as an empty mapping.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Initialize configuration.

        Args:
            config_path: Path to configuration YAML file. If None, uses
                ``config/paths.yaml`` under the project root (two directories
                above this module).

        Raises:
            FileNotFoundError: If the configuration file does not exist.
            ValueError: If the file's top level is not a mapping.
        """
        if config_path is None:
            # Default to config/paths.yaml in the project root
            config_path = self._project_root() / "config" / "paths.yaml"
        self.config_path = Path(config_path)
        self._config = self._load_config()

    @staticmethod
    def _project_root() -> Path:
        """Return the directory two levels above this module."""
        return Path(__file__).parent.parent

    def _load_config(self) -> Dict[str, Any]:
        """Load configuration from the YAML file.

        Returns:
            The parsed top-level mapping, or an empty dict for an empty file.

        Raises:
            FileNotFoundError: If the configuration file does not exist.
            ValueError: If the file's top level is not a mapping.
        """
        if not self.config_path.exists():
            raise FileNotFoundError(f"Configuration file not found: {self.config_path}")
        with open(self.config_path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        # safe_load returns None for an empty document.
        if data is None:
            return {}
        if not isinstance(data, dict):
            raise ValueError(
                f"Configuration file must contain a mapping at the top level: "
                f"{self.config_path}"
            )
        return data

    def _section(self, key: str) -> Dict[str, Any]:
        """Return a top-level section, treating a missing or null one as {}."""
        return self._config.get(key) or {}

    def get_dataset_config(self, dataset_name: str) -> DatasetConfig:
        """Get dataset configuration by name.

        Args:
            dataset_name: Name of the dataset (e.g., 'gsm8k')

        Returns:
            DatasetConfig object

        Raises:
            KeyError: If dataset not found in configuration
        """
        datasets = self._section("datasets")
        if dataset_name not in datasets:
            raise KeyError(
                f"Dataset '{dataset_name}' not found in configuration. "
                f"Available datasets: {list(datasets)}"
            )
        return DatasetConfig.from_dict(datasets[dataset_name])

    def get_model_config(self, model_name: str) -> ModelConfig:
        """Get model configuration by name.

        Args:
            model_name: Name of the model (e.g., 'llama-3.1-8b')

        Returns:
            ModelConfig object

        Raises:
            KeyError: If model not found in configuration
        """
        models = self._section("models")
        if model_name not in models:
            raise KeyError(
                f"Model '{model_name}' not found in configuration. "
                f"Available models: {list(models)}"
            )
        return ModelConfig.from_dict(models[model_name])

    def get_output_dir(self, output_type: str) -> Path:
        """Get an output directory path, creating the directory if needed.

        Relative paths from the ``output`` section are resolved against the
        project root (two directories above this module).

        Args:
            output_type: Type of output (e.g., 'results', 'trajectories', 'logs')

        Returns:
            Path to output directory

        Raises:
            KeyError: If the output type is not in the ``output`` section.
        """
        outputs = self._section("output")
        if output_type not in outputs:
            raise KeyError(
                f"Output type '{output_type}' not found in configuration. "
                f"Available output types: {list(outputs)}"
            )
        output_path = self._project_root() / outputs[output_type]
        output_path.mkdir(parents=True, exist_ok=True)
        return output_path

    def get_settings(self) -> Dict[str, Any]:
        """Get general settings (an empty dict if the section is absent)."""
        return self._section("settings")

    def list_datasets(self) -> list[str]:
        """List all available datasets."""
        return list(self._section("datasets"))

    def list_models(self) -> list[str]:
        """List all available models."""
        return list(self._section("models"))
# Cached Config instance shared by load_config() and the helpers below.
_global_config: Optional[Config] = None


def load_config(config_path: Optional[str] = None) -> Config:
    """Return the shared Config instance, building it when required.

    When ``config_path`` is given, a fresh Config is always loaded from it and
    becomes the shared instance. When it is omitted, the cached instance is
    reused, or the default configuration file is loaded if nothing is cached.

    Args:
        config_path: Path to configuration file. If None, uses default.

    Returns:
        The shared Config instance.
    """
    global _global_config
    must_reload = config_path is not None or _global_config is None
    if must_reload:
        _global_config = Config(config_path)
    return _global_config
def get_dataset_config(dataset_name: str) -> DatasetConfig:
    """Look up a dataset entry in the shared configuration.

    Args:
        dataset_name: Key of the dataset in the ``datasets`` section.

    Returns:
        DatasetConfig object for that entry.
    """
    return load_config().get_dataset_config(dataset_name)
def get_model_config(model_name: str) -> ModelConfig:
    """Look up a model entry in the shared configuration.

    Args:
        model_name: Key of the model in the ``models`` section.

    Returns:
        ModelConfig object for that entry.
    """
    return load_config().get_model_config(model_name)