Anonymous Hunter
feat: Add robust configuration management, Docker support, initial testing, and quickstart documentation.
f21249a
| """ | |
| Configuration management for KerdosAI. | |
| """ | |
| import os | |
| import yaml | |
| from pathlib import Path | |
| from typing import Optional, List, Dict, Any, Union | |
| from pydantic import BaseModel, Field, field_validator | |
| import logging | |
| from exceptions import ConfigurationError | |
| logger = logging.getLogger(__name__) | |
class LoRAConfig(BaseModel):
    """LoRA (Low-Rank Adaptation) configuration.

    Attributes:
        enabled: Whether LoRA adapters are applied at all.
        r: Rank of the low-rank update matrices (1-256).
        alpha: Scaling factor; conventionally >= r.
        dropout: Dropout probability applied to the LoRA layers.
        target_modules: Module names to adapt; None defers to downstream defaults.
    """

    enabled: bool = True
    r: int = Field(default=8, ge=1, le=256)
    alpha: int = Field(default=32, ge=1)
    dropout: float = Field(default=0.1, ge=0.0, le=1.0)
    target_modules: Optional[List[str]] = None

    @field_validator("alpha")
    @classmethod
    def validate_alpha(cls, v, info):
        """Warn when alpha < r, a setting that typically weakens the adapter."""
        # Fix: without the @field_validator decorator this method was never
        # invoked by pydantic, so the warning could never fire.
        if 'r' in info.data and v < info.data['r']:
            logger.warning(f"LoRA alpha ({v}) is less than r ({info.data['r']}), which may reduce effectiveness")
        return v
class QuantizationConfig(BaseModel):
    """Quantization configuration (bitsandbytes-style settings).

    Attributes:
        enabled: Whether weight quantization is applied.
        bits: Quantization bit width; only 4 or 8 are accepted (see validator).
        use_double_quant: Enable nested/double quantization of the constants.
        quant_type: 4-bit data type, "nf4" or "fp4".
        compute_dtype: Dtype used for matmul compute during quantized inference.
    """

    enabled: bool = False
    # Field bounds admit 4-8; validate_bits narrows this to exactly {4, 8}.
    bits: int = Field(default=4, ge=4, le=8)
    use_double_quant: bool = True
    quant_type: str = Field(default="nf4", pattern="^(nf4|fp4)$")
    compute_dtype: str = Field(default="float16", pattern="^(float16|bfloat16|float32)$")

    @field_validator("bits")
    @classmethod
    def validate_bits(cls, v):
        """Reject bit widths other than 4 or 8."""
        # Fix: without the @field_validator decorator this check never ran,
        # so e.g. bits=6 slipped through the Field(ge=4, le=8) bounds.
        if v not in [4, 8]:
            raise ValueError("Only 4-bit and 8-bit quantization are supported")
        return v
class TrainingConfig(BaseModel):
    """Training hyperparameter configuration.

    All step counts and sizes are validated to be positive (or non-negative
    where zero is meaningful, e.g. warmup_steps and seed).
    """

    epochs: int = Field(default=3, ge=1)
    batch_size: int = Field(default=4, ge=1)
    learning_rate: float = Field(default=2e-5, gt=0.0)
    warmup_steps: int = Field(default=100, ge=0)
    gradient_accumulation_steps: int = Field(default=1, ge=1)
    max_grad_norm: float = Field(default=1.0, gt=0.0)
    weight_decay: float = Field(default=0.01, ge=0.0)
    logging_steps: int = Field(default=10, ge=1)
    save_steps: int = Field(default=100, ge=1)
    eval_steps: int = Field(default=100, ge=1)
    max_seq_length: int = Field(default=512, ge=1)
    seed: int = Field(default=42, ge=0)
    # Mixed-precision flags; mutually exclusive (enforced by validate_precision).
    fp16: bool = False
    bf16: bool = False

    @field_validator("fp16", "bf16")
    @classmethod
    def validate_precision(cls, v, info):
        """Reject configurations that enable both fp16 and bf16."""
        # Fix: without the @field_validator decorator this check never ran,
        # so fp16=True together with bf16=True was silently accepted.
        # Fields validate in declaration order, so the bf16 branch is the one
        # that actually catches the conflict; the fp16 branch is kept for
        # robustness should the field order change.
        if v and info.field_name == 'bf16' and info.data.get('fp16'):
            raise ValueError("Cannot use both fp16 and bf16")
        if v and info.field_name == 'fp16' and info.data.get('bf16'):
            raise ValueError("Cannot use both fp16 and bf16")
        return v
class DataConfig(BaseModel):
    """Data source configuration.

    Either local files (train_file/validation_file/test_file) or a hub
    dataset (dataset_name, optionally dataset_config) may be used; the
    cross-field requirement is enforced in KerdosConfig.validate_compatibility.
    """

    train_file: Optional[str] = None
    validation_file: Optional[str] = None
    test_file: Optional[str] = None
    dataset_name: Optional[str] = None
    dataset_config: Optional[str] = None
    text_column: str = "text"
    # None means "use every sample"; otherwise truncate to this many.
    max_samples: Optional[int] = None
    preprocessing_num_workers: int = Field(default=4, ge=1)

    @field_validator("train_file", "validation_file", "test_file")
    @classmethod
    def validate_file_path(cls, v):
        """Warn (but do not fail) when a configured file path does not exist."""
        # Fix: without the @field_validator decorator this check never ran.
        # Only a warning on purpose: the file may be produced later by a
        # preprocessing step.
        if v and not Path(v).exists():
            logger.warning(f"File {v} does not exist")
        return v
class DeploymentConfig(BaseModel):
    """Settings that control how a trained model is served.

    Attributes:
        type: Serving mode — "rest", "docker", or "kubernetes".
        host: Bind address for the server.
        port: TCP port (1-65535).
        workers: Number of server worker processes.
        max_batch_size: Upper bound on requests batched per inference call.
        timeout: Request timeout in seconds.
    """

    type: str = Field(default="rest", pattern="^(rest|docker|kubernetes)$")
    host: str = "0.0.0.0"
    port: int = Field(default=8000, ge=1, le=65535)
    workers: int = Field(default=1, ge=1)
    max_batch_size: int = Field(default=8, ge=1)
    timeout: int = Field(default=60, ge=1)
class MonitoringConfig(BaseModel):
    """Experiment tracking and logging configuration.

    Attributes:
        enabled: Master switch for all monitoring integrations.
        wandb_project: Weights & Biases project name (None disables W&B).
        wandb_entity: Weights & Biases team/user entity.
        tensorboard_dir: Directory for TensorBoard event files.
        log_model: Whether to upload model artifacts to the tracker.
    """

    enabled: bool = True
    wandb_project: Optional[str] = None
    wandb_entity: Optional[str] = None
    tensorboard_dir: str = "./runs"
    log_model: bool = False
class KerdosConfig(BaseModel):
    """Main KerdosAI configuration aggregating all component configs."""

    # Model configuration
    base_model: str
    model_revision: Optional[str] = None
    trust_remote_code: bool = False
    # None means auto-detect; validate_compatibility special-cases "cpu".
    device: Optional[str] = None

    # Component configurations
    lora: LoRAConfig = Field(default_factory=LoRAConfig)
    quantization: QuantizationConfig = Field(default_factory=QuantizationConfig)
    training: TrainingConfig = Field(default_factory=TrainingConfig)
    data: DataConfig = Field(default_factory=DataConfig)
    deployment: DeploymentConfig = Field(default_factory=DeploymentConfig)
    monitoring: MonitoringConfig = Field(default_factory=MonitoringConfig)

    # Output configuration
    output_dir: str = "./output"
    checkpoint_dir: str = "./checkpoints"

    # Additional settings
    cache_dir: Optional[str] = None

    @classmethod
    def from_yaml(cls, config_path: Union[str, Path]) -> "KerdosConfig":
        """
        Load configuration from YAML file.

        Args:
            config_path: Path to YAML configuration file

        Returns:
            KerdosConfig instance

        Raises:
            ConfigurationError: If configuration is invalid
        """
        # Fix: this method takes `cls` but was missing @classmethod, so the
        # documented call style KerdosConfig.from_yaml(path) passed the path
        # string as `cls` and crashed.
        try:
            config_path = Path(config_path)
            if not config_path.exists():
                raise ConfigurationError(
                    f"Configuration file not found: {config_path}",
                    {"path": str(config_path)}
                )
            with open(config_path, 'r') as f:
                # An empty YAML file parses to None; treat it as {} so the
                # defaults apply instead of crashing on cls(**None).
                config_dict = yaml.safe_load(f) or {}
            # Handle environment variable substitution
            config_dict = cls._substitute_env_vars(config_dict)
            return cls(**config_dict)
        except ConfigurationError:
            # Fix: re-raise our own error untouched instead of letting the
            # broad Exception handler double-wrap it.
            raise
        except yaml.YAMLError as e:
            raise ConfigurationError(
                f"Error parsing YAML configuration: {str(e)}",
                {"path": str(config_path)}
            ) from e
        except Exception as e:
            raise ConfigurationError(
                f"Error loading configuration: {str(e)}",
                {"path": str(config_path)}
            ) from e

    def to_yaml(self, output_path: Union[str, Path]) -> None:
        """
        Save configuration to YAML file.

        Args:
            output_path: Path to save YAML configuration

        Raises:
            ConfigurationError: If the file cannot be written.
        """
        try:
            output_path = Path(output_path)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            with open(output_path, 'w') as f:
                yaml.dump(
                    self.model_dump(),
                    f,
                    default_flow_style=False,
                    sort_keys=False
                )
            logger.info(f"Configuration saved to {output_path}")
        except Exception as e:
            raise ConfigurationError(
                f"Error saving configuration: {str(e)}",
                {"path": str(output_path)}
            ) from e

    @staticmethod
    def _substitute_env_vars(config: Any) -> Any:
        """
        Recursively substitute environment variables in configuration.

        Environment variables should be specified as ${VAR_NAME}. An unset
        variable leaves the ${VAR_NAME} placeholder in place.
        """
        # Fix: this helper takes no `self`/`cls` and is called statically; it
        # was missing @staticmethod. The annotation is Any because the
        # recursion also handles lists and scalar strings, not just dicts.
        if isinstance(config, dict):
            return {k: KerdosConfig._substitute_env_vars(v) for k, v in config.items()}
        elif isinstance(config, list):
            return [KerdosConfig._substitute_env_vars(item) for item in config]
        elif isinstance(config, str) and config.startswith("${") and config.endswith("}"):
            var_name = config[2:-1]
            return os.getenv(var_name, config)
        return config

    def validate_compatibility(self) -> None:
        """
        Validate configuration compatibility.

        Raises:
            ConfigurationError: If configuration has incompatible settings
        """
        # Check quantization and LoRA compatibility
        if self.quantization.enabled and self.lora.enabled:
            logger.info("Using LoRA with quantization - this is recommended for efficiency")
        # Check precision settings
        if self.training.fp16 and self.device == "cpu":
            raise ConfigurationError(
                "fp16 training is not supported on CPU",
                {"device": self.device, "fp16": True}
            )
        # Check data configuration
        if not self.data.train_file and not self.data.dataset_name:
            raise ConfigurationError(
                "Either train_file or dataset_name must be specified",
                {"train_file": self.data.train_file, "dataset_name": self.data.dataset_name}
            )
        logger.info("Configuration validation passed")
def load_config(config_path: Optional[Union[str, Path]] = None) -> KerdosConfig:
    """
    Load configuration from file or create default.

    Args:
        config_path: Optional path to configuration file

    Returns:
        KerdosConfig instance
    """
    # Falsy path (None or empty string) means "use built-in defaults".
    if not config_path:
        logger.info("No configuration file specified, using defaults")
        return KerdosConfig(base_model="gpt2")  # Default model for testing
    return KerdosConfig.from_yaml(config_path)