""" Training configuration schemas — Pydantic v2. All training jobs are validated against these models before execution. No raw dicts escape into the pipeline; everything is typed and constrained. """ from __future__ import annotations from enum import StrEnum from typing import Annotated from pydantic import BaseModel, Field, HttpUrl, model_validator from pydantic import PositiveFloat, PositiveInt # --------------------------------------------------------------------------- # Enums # --------------------------------------------------------------------------- class EvalStrategy(StrEnum): NO = "no" STEPS = "steps" EPOCH = "epoch" class Precision(StrEnum): FP32 = "fp32" FP16 = "fp16" BF16 = "bf16" INT8 = "int8" class OptimizerType(StrEnum): ADAMW = "adamw_torch" ADAMW_8BIT = "adamw_8bit" PAGED_ADAMW_8BIT = "paged_adamw_8bit" SGD = "sgd" class EvalMetric(StrEnum): PASS_AT_1 = "pass_at_1" PASS_AT_10 = "pass_at_10" BLEU = "bleu" EXECUTION_ACCURACY = "execution_accuracy" EXACT_MATCH = "exact_match" # --------------------------------------------------------------------------- # Sub-configs # --------------------------------------------------------------------------- class LoRAConfig(BaseModel): """LoRA adapter configuration. Omit to disable LoRA (full fine-tune).""" enabled: bool = True r: Annotated[int, Field(ge=1, le=256)] = 16 alpha: Annotated[int, Field(ge=1)] = 32 dropout: Annotated[float, Field(ge=0.0, lt=1.0)] = 0.05 target_modules: list[str] = Field( default_factory=lambda: ["q_proj", "v_proj"], min_length=1, ) bias: str = "none" @model_validator(mode="after") def alpha_geq_r(self) -> "LoRAConfig": if self.alpha < self.r: raise ValueError(f"lora.alpha ({self.alpha}) should be >= lora.r ({self.r})") return self class TrainingHyperparams(BaseModel): num_epochs: Annotated[int, Field(ge=1, le=100)] = 3 batch_size: Annotated[int, Field(ge=1, le=256)] = 8 gradient_accumulation_steps: Annotated[int, Field(ge=1, le=128)] = 4 learning_rate: Annotated[float, Field(gt=0.0, lt=1.0)] = 2e-5 weight_decay: Annotated[float, Field(ge=0.0, lt=1.0)] = 0.01 warmup_ratio: Annotated[float, Field(ge=0.0, lt=1.0)] = 0.1 max_seq_length: Annotated[int, Field(ge=64, le=32768)] = 1024 max_grad_norm: Annotated[float, Field(gt=0.0)] = 1.0 optimizer: OptimizerType = OptimizerType.ADAMW precision: Precision = Precision.BF16 lr_scheduler: str = "cosine" seed: int = 42 dataloader_num_workers: Annotated[int, Field(ge=0, le=32)] = 4 @property def effective_batch_size(self) -> int: return self.batch_size * self.gradient_accumulation_steps class EvaluationConfig(BaseModel): enabled: bool = True strategy: EvalStrategy = EvalStrategy.EPOCH eval_steps: PositiveInt | None = None # required when strategy=STEPS metrics: list[EvalMetric] = Field( default_factory=lambda: [EvalMetric.PASS_AT_1, EvalMetric.BLEU] ) num_samples_per_problem: Annotated[int, Field(ge=1, le=200)] = 10 timeout_seconds: Annotated[int, Field(ge=1, le=60)] = 10 load_best_model_at_end: bool = True metric_for_best_model: EvalMetric = EvalMetric.PASS_AT_1 greater_is_better: bool = True @model_validator(mode="after") def eval_steps_required_for_steps_strategy(self) -> "EvaluationConfig": if self.strategy == EvalStrategy.STEPS and self.eval_steps is None: raise ValueError("evaluation.eval_steps is required when strategy='steps'") return self class CheckpointConfig(BaseModel): save_strategy: EvalStrategy = EvalStrategy.EPOCH save_steps: PositiveInt | None = None save_total_limit: Annotated[int, Field(ge=1, le=20)] = 3 output_dir: str = "./checkpoints" resume_from_checkpoint: str | None = None @model_validator(mode="after") def save_steps_required_for_steps_strategy(self) -> "CheckpointConfig": if self.save_strategy == EvalStrategy.STEPS and self.save_steps is None: raise ValueError("checkpoint.save_steps required when save_strategy='steps'") return self class HubConfig(BaseModel): push_to_hub: bool = False repo_id: str | None = None private: bool = True commit_message: str = "Training checkpoint" @model_validator(mode="after") def repo_id_required_if_pushing(self) -> "HubConfig": if self.push_to_hub and not self.repo_id: raise ValueError("hub.repo_id is required when hub.push_to_hub=true") return self class DatasetConfig(BaseModel): dataset_id: str # internal UUID or HF Hub dataset path split_ratio: Annotated[float, Field(gt=0.0, lt=1.0)] = 0.9 # train split max_samples: PositiveInt | None = None # None = use all text_column: str = "content" shuffle: bool = True shuffle_seed: int = 42 # --------------------------------------------------------------------------- # Root job config # --------------------------------------------------------------------------- class TrainingJobConfig(BaseModel): """ Complete training job specification. Validated at job submission time. If validation passes, the job is guaranteed to reach the pipeline with a coherent configuration. """ job_name: Annotated[str, Field(min_length=1, max_length=128, pattern=r"^[\w\-]+$")] base_model: str = Field( description="HuggingFace model ID or local path", examples=["Salesforce/codegen-350M-mono", "deepseek-ai/deepseek-coder-1.3b-base"], ) dataset: DatasetConfig training: TrainingHyperparams = Field(default_factory=TrainingHyperparams) lora: LoRAConfig | None = Field(default_factory=LoRAConfig) evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig) checkpoint: CheckpointConfig = Field(default_factory=CheckpointConfig) hub: HubConfig = Field(default_factory=HubConfig) tags: list[str] = Field(default_factory=list, max_length=20) notes: str | None = None model_config = { "json_schema_extra": { "examples": [ { "job_name": "codegen-finetune-v1", "base_model": "Salesforce/codegen-350M-mono", "dataset": {"dataset_id": "ds_abc123"}, "training": { "num_epochs": 3, "batch_size": 8, "learning_rate": 2e-5, }, "hub": { "push_to_hub": True, "repo_id": "your-org/codegen-finetune-v1", }, } ] } } # --------------------------------------------------------------------------- # Inference config (served separately but validated here for consistency) # --------------------------------------------------------------------------- class InferenceConfig(BaseModel): model_id: str max_new_tokens: Annotated[int, Field(ge=1, le=4096)] = 256 temperature: Annotated[float, Field(ge=0.0, le=2.0)] = 0.2 top_p: Annotated[float, Field(ge=0.0, le=1.0)] = 0.95 top_k: Annotated[int, Field(ge=0, le=1000)] = 50 do_sample: bool = True num_return_sequences: Annotated[int, Field(ge=1, le=200)] = 1 stop_sequences: list[str] = Field(default_factory=list) precision: Precision = Precision.BF16