""" Configuration management for KerdosAI. """ import os import yaml from pathlib import Path from typing import Optional, List, Dict, Any, Union from pydantic import BaseModel, Field, field_validator import logging from exceptions import ConfigurationError logger = logging.getLogger(__name__) class LoRAConfig(BaseModel): """LoRA configuration.""" enabled: bool = True r: int = Field(default=8, ge=1, le=256) alpha: int = Field(default=32, ge=1) dropout: float = Field(default=0.1, ge=0.0, le=1.0) target_modules: Optional[List[str]] = None @field_validator('alpha') @classmethod def validate_alpha(cls, v, info): if 'r' in info.data and v < info.data['r']: logger.warning(f"LoRA alpha ({v}) is less than r ({info.data['r']}), which may reduce effectiveness") return v class QuantizationConfig(BaseModel): """Quantization configuration.""" enabled: bool = False bits: int = Field(default=4, ge=4, le=8) use_double_quant: bool = True quant_type: str = Field(default="nf4", pattern="^(nf4|fp4)$") compute_dtype: str = Field(default="float16", pattern="^(float16|bfloat16|float32)$") @field_validator('bits') @classmethod def validate_bits(cls, v): if v not in [4, 8]: raise ValueError("Only 4-bit and 8-bit quantization are supported") return v class TrainingConfig(BaseModel): """Training configuration.""" epochs: int = Field(default=3, ge=1) batch_size: int = Field(default=4, ge=1) learning_rate: float = Field(default=2e-5, gt=0.0) warmup_steps: int = Field(default=100, ge=0) gradient_accumulation_steps: int = Field(default=1, ge=1) max_grad_norm: float = Field(default=1.0, gt=0.0) weight_decay: float = Field(default=0.01, ge=0.0) logging_steps: int = Field(default=10, ge=1) save_steps: int = Field(default=100, ge=1) eval_steps: int = Field(default=100, ge=1) max_seq_length: int = Field(default=512, ge=1) seed: int = Field(default=42, ge=0) fp16: bool = False bf16: bool = False @field_validator('fp16', 'bf16') @classmethod def validate_precision(cls, v, info): if v and info.field_name == 'bf16' and info.data.get('fp16'): raise ValueError("Cannot use both fp16 and bf16") if v and info.field_name == 'fp16' and info.data.get('bf16'): raise ValueError("Cannot use both fp16 and bf16") return v class DataConfig(BaseModel): """Data configuration.""" train_file: Optional[str] = None validation_file: Optional[str] = None test_file: Optional[str] = None dataset_name: Optional[str] = None dataset_config: Optional[str] = None text_column: str = "text" max_samples: Optional[int] = None preprocessing_num_workers: int = Field(default=4, ge=1) @field_validator('train_file', 'validation_file', 'test_file') @classmethod def validate_file_path(cls, v): if v and not Path(v).exists(): logger.warning(f"File {v} does not exist") return v class DeploymentConfig(BaseModel): """Deployment configuration.""" type: str = Field(default="rest", pattern="^(rest|docker|kubernetes)$") host: str = "0.0.0.0" port: int = Field(default=8000, ge=1, le=65535) workers: int = Field(default=1, ge=1) max_batch_size: int = Field(default=8, ge=1) timeout: int = Field(default=60, ge=1) class MonitoringConfig(BaseModel): """Monitoring configuration.""" enabled: bool = True wandb_project: Optional[str] = None wandb_entity: Optional[str] = None tensorboard_dir: str = "./runs" log_model: bool = False class KerdosConfig(BaseModel): """Main KerdosAI configuration.""" # Model configuration base_model: str model_revision: Optional[str] = None trust_remote_code: bool = False device: Optional[str] = None # Component configurations lora: LoRAConfig = 
class KerdosConfig(BaseModel):
    """Main KerdosAI configuration."""

    # Model configuration
    base_model: str
    model_revision: Optional[str] = None
    trust_remote_code: bool = False
    device: Optional[str] = None

    # Component configurations
    lora: LoRAConfig = Field(default_factory=LoRAConfig)
    quantization: QuantizationConfig = Field(default_factory=QuantizationConfig)
    training: TrainingConfig = Field(default_factory=TrainingConfig)
    data: DataConfig = Field(default_factory=DataConfig)
    deployment: DeploymentConfig = Field(default_factory=DeploymentConfig)
    monitoring: MonitoringConfig = Field(default_factory=MonitoringConfig)

    # Output configuration
    output_dir: str = "./output"
    checkpoint_dir: str = "./checkpoints"

    # Additional settings
    cache_dir: Optional[str] = None

    @classmethod
    def from_yaml(cls, config_path: Union[str, Path]) -> "KerdosConfig":
        """
        Load configuration from a YAML file.

        Args:
            config_path: Path to YAML configuration file

        Returns:
            KerdosConfig instance

        Raises:
            ConfigurationError: If configuration is invalid
        """
        config_path = Path(config_path)

        # Raised outside the try block below so the broad Exception handler
        # does not re-wrap an already-informative ConfigurationError.
        if not config_path.exists():
            raise ConfigurationError(
                f"Configuration file not found: {config_path}",
                {"path": str(config_path)}
            )

        try:
            with open(config_path, 'r') as f:
                # An empty file parses to None; treat it as an empty mapping.
                config_dict = yaml.safe_load(f) or {}

            # Handle environment variable substitution
            config_dict = cls._substitute_env_vars(config_dict)

            return cls(**config_dict)

        except yaml.YAMLError as e:
            raise ConfigurationError(
                f"Error parsing YAML configuration: {str(e)}",
                {"path": str(config_path)}
            )
        except Exception as e:
            raise ConfigurationError(
                f"Error loading configuration: {str(e)}",
                {"path": str(config_path)}
            )

    def to_yaml(self, output_path: Union[str, Path]) -> None:
        """
        Save configuration to a YAML file.

        Args:
            output_path: Path to save YAML configuration
        """
        try:
            output_path = Path(output_path)
            output_path.parent.mkdir(parents=True, exist_ok=True)

            with open(output_path, 'w') as f:
                yaml.dump(
                    self.model_dump(),
                    f,
                    default_flow_style=False,
                    sort_keys=False
                )

            logger.info(f"Configuration saved to {output_path}")

        except Exception as e:
            raise ConfigurationError(
                f"Error saving configuration: {str(e)}",
                {"path": str(output_path)}
            )

    @staticmethod
    def _substitute_env_vars(config: Any) -> Any:
        """
        Recursively substitute environment variables in configuration.

        Environment variables should be specified as ${VAR_NAME}; unset
        variables are left as the literal placeholder.
        """
        if isinstance(config, dict):
            return {k: KerdosConfig._substitute_env_vars(v) for k, v in config.items()}
        elif isinstance(config, list):
            return [KerdosConfig._substitute_env_vars(item) for item in config]
        elif isinstance(config, str) and config.startswith("${") and config.endswith("}"):
            var_name = config[2:-1]
            return os.getenv(var_name, config)
        return config

    def validate_compatibility(self) -> None:
        """
        Validate configuration compatibility.

        Raises:
            ConfigurationError: If configuration has incompatible settings
        """
        # Check quantization and LoRA compatibility
        if self.quantization.enabled and self.lora.enabled:
            logger.info("Using LoRA with quantization - this is recommended for efficiency")

        # Check precision settings
        if self.training.fp16 and self.device == "cpu":
            raise ConfigurationError(
                "fp16 training is not supported on CPU",
                {"device": self.device, "fp16": True}
            )

        # Check data configuration
        if not self.data.train_file and not self.data.dataset_name:
            raise ConfigurationError(
                "Either train_file or dataset_name must be specified",
                {
                    "train_file": self.data.train_file,
                    "dataset_name": self.data.dataset_name,
                }
            )

        logger.info("Configuration validation passed")
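# A minimal round-trip sketch of the class above; both paths are hypothetical,
# not fixtures shipped with the repo:
#
#   config = KerdosConfig.from_yaml("configs/finetune.yaml")
#   config.validate_compatibility()           # cross-field checks beyond pydantic
#   config.training.learning_rate = 1e-4      # programmatic override
#   config.to_yaml("output/resolved_config.yaml")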
def load_config(config_path: Optional[Union[str, Path]] = None) -> KerdosConfig:
    """
    Load configuration from file or create default.

    Args:
        config_path: Optional path to configuration file

    Returns:
        KerdosConfig instance
    """
    if config_path:
        return KerdosConfig.from_yaml(config_path)

    # Return default configuration
    logger.info("No configuration file specified, using defaults")
    return KerdosConfig(base_model="gpt2")  # Default model for testing
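# A minimal smoke test for running this module directly; a sketch rather than
# part of the KerdosAI API, and "train.jsonl" is a hypothetical placeholder.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    config = load_config()                    # falls back to base_model="gpt2"
    config.data.train_file = "train.jsonl"    # satisfies validate_compatibility
    config.validate_compatibility()
    print(config.model_dump_json(indent=2))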