#!/usr/bin/env python3
"""Configuration helpers for centralized LoRA finetuning."""

from __future__ import annotations

from dataclasses import asdict, dataclass, field, fields
from types import SimpleNamespace
from typing import Any, Dict, List, Optional


@dataclass
class LoRATuneConfig:
    """Structured config matching the current loratune.py CLI surface."""

    # Model / runtime
    base_model: str = ""
    output_dir: str = ""
    device: str = "cuda"
    dtype: str = "bfloat16"
    trust_remote_code: bool = False
    seed: int = 42

    # Instruction-tuning dataset
    instruction_dataset: str = "tatsu-lab/alpaca"
    instruction_config: Optional[str] = None
    instruction_split: str = "train"
    instruction_field_instruction: str = "instruction"
    instruction_field_input: str = "input"
    instruction_field_output: str = "output"
    max_samples: int = 0  # 0 means "use the full split"

    # Training
    seq_len: int = 1024
    batch_size: int = 64
    micro_batch_size: int = 4
    epochs: float = 1.0
    learning_rate: float = 1e-4
    weight_decay: float = 0.0
    max_grad_norm: float = 1.0
    log_steps: int = 100

    # Periodic WikiText-2 perplexity evaluation at logging steps
    wikitext2_ppl_on_log: bool = True
    wikitext2_ppl_seq_len: int = 128
    wikitext2_ppl_batch_size: int = 8
    wikitext2_ppl_max_batches: Optional[int] = None

    # LoRA adapter
    lora_rank: int = 8
    lora_alpha: float = 16.0
    lora_dropout: float = 0.0
    lora_target_modules: List[str] = field(
        default_factory=lambda: [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "down_proj",
            "up_proj",
        ]
    )

    @property
    def grad_accum_steps(self) -> int:
        """Gradient-accumulation steps implied by batch_size / micro_batch_size.

        Note: integer division, so a batch_size that is not a multiple of
        micro_batch_size is silently floored.
        """
        if self.batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        if self.micro_batch_size < 1:
            raise ValueError("micro_batch_size must be >= 1")
        if self.batch_size < self.micro_batch_size:
            raise ValueError("batch_size must be >= micro_batch_size")
        return self.batch_size // self.micro_batch_size

    def validate(self) -> "LoRATuneConfig":
        """Raise ValueError on an invalid config; return self for chaining."""
        _ = self.grad_accum_steps  # triggers the batch-size checks above
        if not self.base_model:
            raise ValueError("base_model must be set")
        if not self.output_dir:
            raise ValueError("output_dir must be set")
        return self

    def to_dict(self) -> Dict[str, Any]:
        data = asdict(self)
        data["grad_accum_steps"] = self.grad_accum_steps  # derived, for logging
        return data

    def to_namespace(self) -> SimpleNamespace:
        return SimpleNamespace(**self.to_dict())

    @classmethod
    def from_dict(cls, values: Dict[str, Any]) -> "LoRATuneConfig":
        # Drop derived/unknown keys (e.g. the grad_accum_steps emitted by
        # to_dict) so that from_dict(cfg.to_dict()) round-trips cleanly.
        init_fields = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in values.items() if k in init_fields})
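

# --- Usage sketch ----------------------------------------------------------
# Illustrative only: a minimal demonstration of how the config is built,
# validated, and round-tripped. The model id and output path below are
# hypothetical placeholders, not values taken from loratune.py.
if __name__ == "__main__":
    cfg = LoRATuneConfig(
        base_model="meta-llama/Llama-2-7b-hf",  # hypothetical model id
        output_dir="runs/lora-demo",            # hypothetical path
        micro_batch_size=8,
    ).validate()
    print(cfg.grad_accum_steps)  # 64 // 8 -> 8

    # to_dict() includes the derived grad_accum_steps; from_dict() filters it
    # out, so serialize/deserialize round-trips to an equal config.
    assert LoRATuneConfig.from_dict(cfg.to_dict()) == cfg

    # SimpleNamespace view for code that expects argparse-style attributes.
    args = cfg.to_namespace()
    assert args.lora_rank == 8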