# kerdosai/config.py
# Author: Anonymous Hunter
# Commit f21249a — feat: Add robust configuration management, Docker support,
# initial testing, and quickstart documentation.
"""
Configuration management for KerdosAI.
"""
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import yaml
from pydantic import BaseModel, Field, field_validator, model_validator

from exceptions import ConfigurationError
# Module-level logger, namespaced to this module per stdlib convention.
logger = logging.getLogger(__name__)
class LoRAConfig(BaseModel):
    """Settings for LoRA (low-rank adaptation) fine-tuning."""

    enabled: bool = True                                  # toggle LoRA on/off
    r: int = Field(default=8, ge=1, le=256)               # rank of the update matrices
    alpha: int = Field(default=32, ge=1)                  # LoRA scaling factor
    dropout: float = Field(default=0.1, ge=0.0, le=1.0)   # dropout applied to LoRA layers
    target_modules: Optional[List[str]] = None            # modules to adapt; None = library default

    @field_validator('alpha')
    @classmethod
    def validate_alpha(cls, v, info):
        """Warn (without rejecting) when alpha is smaller than the rank r."""
        # `r` may be absent from info.data if its own validation failed.
        rank = info.data.get('r')
        if rank is not None and v < rank:
            logger.warning(f"LoRA alpha ({v}) is less than r ({rank}), which may reduce effectiveness")
        return v
class QuantizationConfig(BaseModel):
    """Settings for model weight quantization."""

    enabled: bool = False
    bits: int = Field(default=4, ge=4, le=8)
    use_double_quant: bool = True
    quant_type: str = Field(default="nf4", pattern="^(nf4|fp4)$")
    compute_dtype: str = Field(default="float16", pattern="^(float16|bfloat16|float32)$")

    @field_validator('bits')
    @classmethod
    def validate_bits(cls, value):
        """Reject bit widths other than 4 or 8.

        The Field range (ge=4, le=8) alone would admit 5-7, so this
        validator narrows the accepted set to exactly {4, 8}.
        """
        if value not in (4, 8):
            raise ValueError("Only 4-bit and 8-bit quantization are supported")
        return value
class TrainingConfig(BaseModel):
    """Training hyperparameters for fine-tuning runs."""

    epochs: int = Field(default=3, ge=1)
    batch_size: int = Field(default=4, ge=1)
    learning_rate: float = Field(default=2e-5, gt=0.0)
    warmup_steps: int = Field(default=100, ge=0)
    gradient_accumulation_steps: int = Field(default=1, ge=1)
    max_grad_norm: float = Field(default=1.0, gt=0.0)
    weight_decay: float = Field(default=0.01, ge=0.0)
    logging_steps: int = Field(default=10, ge=1)
    save_steps: int = Field(default=100, ge=1)
    eval_steps: int = Field(default=100, ge=1)
    max_seq_length: int = Field(default=512, ge=1)
    seed: int = Field(default=42, ge=0)
    fp16: bool = False
    bf16: bool = False

    @model_validator(mode='after')
    def validate_precision(self):
        """Reject configurations that enable both fp16 and bf16.

        A model-level validator is used here (rather than per-field
        validators) because fields are validated in declaration order:
        while validating ``fp16``, ``bf16`` is not yet available, so a
        field-level cross-check in that direction is dead code. Checking
        once after the whole model is built covers both orderings.
        """
        if self.fp16 and self.bf16:
            raise ValueError("Cannot use both fp16 and bf16")
        return self
class DataConfig(BaseModel):
    """Settings describing where training/evaluation data comes from."""

    train_file: Optional[str] = None
    validation_file: Optional[str] = None
    test_file: Optional[str] = None
    dataset_name: Optional[str] = None      # Hub dataset name, alternative to local files
    dataset_config: Optional[str] = None
    text_column: str = "text"               # column holding the raw text
    max_samples: Optional[int] = None       # None = use the full dataset
    preprocessing_num_workers: int = Field(default=4, ge=1)

    @field_validator('train_file', 'validation_file', 'test_file')
    @classmethod
    def validate_file_path(cls, value):
        """Warn (without failing validation) when a configured file is missing."""
        if value and not Path(value).exists():
            logger.warning(f"File {value} does not exist")
        return value
class DeploymentConfig(BaseModel):
    """Deployment configuration."""
    # Serving backend; the pattern restricts it to the three supported modes.
    type: str = Field(default="rest", pattern="^(rest|docker|kubernetes)$")
    # Bind address; 0.0.0.0 exposes the server on all interfaces.
    host: str = "0.0.0.0"
    # TCP port, constrained to the valid port range.
    port: int = Field(default=8000, ge=1, le=65535)
    # Number of server worker processes.
    workers: int = Field(default=1, ge=1)
    # Upper bound on requests batched into a single inference call.
    max_batch_size: int = Field(default=8, ge=1)
    # Request timeout in seconds — presumably; unit not stated here, confirm against server code.
    timeout: int = Field(default=60, ge=1)
class MonitoringConfig(BaseModel):
    """Monitoring configuration."""
    # Master switch for all monitoring integrations.
    enabled: bool = True
    # Weights & Biases project/entity; None disables W&B logging — TODO confirm in trainer.
    wandb_project: Optional[str] = None
    wandb_entity: Optional[str] = None
    # Directory for TensorBoard event files.
    tensorboard_dir: str = "./runs"
    # Whether to upload the trained model artifact to the tracker.
    log_model: bool = False
class KerdosConfig(BaseModel):
    """Main KerdosAI configuration.

    Aggregates the per-component configs and provides YAML (de)serialization
    with ``${VAR}``-style environment-variable substitution.
    """
    # Model configuration
    base_model: str                      # required: model identifier or local path
    model_revision: Optional[str] = None
    trust_remote_code: bool = False
    device: Optional[str] = None         # e.g. "cpu"/"cuda"; None = decided downstream
    # Component configurations
    lora: LoRAConfig = Field(default_factory=LoRAConfig)
    quantization: QuantizationConfig = Field(default_factory=QuantizationConfig)
    training: TrainingConfig = Field(default_factory=TrainingConfig)
    data: DataConfig = Field(default_factory=DataConfig)
    deployment: DeploymentConfig = Field(default_factory=DeploymentConfig)
    monitoring: MonitoringConfig = Field(default_factory=MonitoringConfig)
    # Output configuration
    output_dir: str = "./output"
    checkpoint_dir: str = "./checkpoints"
    # Additional settings
    cache_dir: Optional[str] = None

    @classmethod
    def from_yaml(cls, config_path: Union[str, Path]) -> "KerdosConfig":
        """
        Load configuration from YAML file.

        Args:
            config_path: Path to YAML configuration file

        Returns:
            KerdosConfig instance

        Raises:
            ConfigurationError: If the file is missing, malformed, or invalid
        """
        config_path = Path(config_path)
        # Raised outside the try: the previous version raised this inside the
        # try and the blanket `except Exception` re-wrapped it, obscuring the
        # "file not found" message.
        if not config_path.exists():
            raise ConfigurationError(
                f"Configuration file not found: {config_path}",
                {"path": str(config_path)}
            )
        try:
            with open(config_path, 'r') as f:
                config_dict = yaml.safe_load(f)
            # safe_load returns None for an empty document; fail with a clear
            # message instead of crashing on cls(**None).
            if config_dict is None:
                raise ConfigurationError(
                    f"Configuration file is empty: {config_path}",
                    {"path": str(config_path)}
                )
            if not isinstance(config_dict, dict):
                raise ConfigurationError(
                    f"Configuration root must be a mapping: {config_path}",
                    {"path": str(config_path)}
                )
            # Handle environment variable substitution
            config_dict = cls._substitute_env_vars(config_dict)
            return cls(**config_dict)
        except ConfigurationError:
            # Already our error type — propagate untouched.
            raise
        except yaml.YAMLError as e:
            raise ConfigurationError(
                f"Error parsing YAML configuration: {str(e)}",
                {"path": str(config_path)}
            ) from e
        except Exception as e:
            raise ConfigurationError(
                f"Error loading configuration: {str(e)}",
                {"path": str(config_path)}
            ) from e

    def to_yaml(self, output_path: Union[str, Path]) -> None:
        """
        Save configuration to YAML file.

        Args:
            output_path: Path to save YAML configuration

        Raises:
            ConfigurationError: If the file cannot be written
        """
        try:
            output_path = Path(output_path)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            with open(output_path, 'w') as f:
                yaml.dump(
                    self.model_dump(),
                    f,
                    default_flow_style=False,
                    sort_keys=False
                )
            logger.info(f"Configuration saved to {output_path}")
        except Exception as e:
            raise ConfigurationError(
                f"Error saving configuration: {str(e)}",
                {"path": str(output_path)}
            ) from e

    @staticmethod
    def _substitute_env_vars(config: Dict[str, Any]) -> Dict[str, Any]:
        """
        Recursively substitute environment variables in configuration.

        Environment variables should be specified as ${VAR_NAME}. Only
        strings that are *entirely* a ${...} reference are substituted;
        if the variable is unset, the literal "${VAR_NAME}" is kept.
        """
        if isinstance(config, dict):
            return {k: KerdosConfig._substitute_env_vars(v) for k, v in config.items()}
        elif isinstance(config, list):
            return [KerdosConfig._substitute_env_vars(item) for item in config]
        elif isinstance(config, str) and config.startswith("${") and config.endswith("}"):
            var_name = config[2:-1]
            return os.getenv(var_name, config)
        return config

    def validate_compatibility(self) -> None:
        """
        Validate configuration compatibility.

        Raises:
            ConfigurationError: If configuration has incompatible settings
        """
        # Check quantization and LoRA compatibility
        if self.quantization.enabled and self.lora.enabled:
            logger.info("Using LoRA with quantization - this is recommended for efficiency")
        # Check precision settings
        if self.training.fp16 and self.device == "cpu":
            raise ConfigurationError(
                "fp16 training is not supported on CPU",
                {"device": self.device, "fp16": True}
            )
        # Check data configuration
        if not self.data.train_file and not self.data.dataset_name:
            raise ConfigurationError(
                "Either train_file or dataset_name must be specified",
                {"train_file": self.data.train_file, "dataset_name": self.data.dataset_name}
            )
        logger.info("Configuration validation passed")
def load_config(config_path: Optional[Union[str, Path]] = None) -> KerdosConfig:
    """
    Load configuration from file or create default.

    Args:
        config_path: Optional path to configuration file

    Returns:
        KerdosConfig instance
    """
    # Guard clause: with no path (None or empty string), fall back to defaults.
    if not config_path:
        logger.info("No configuration file specified, using defaults")
        return KerdosConfig(base_model="gpt2")  # Default model for testing
    return KerdosConfig.from_yaml(config_path)