|
|
""" |
|
|
Training Configuration Classes |
|
|
|
|
|
Contains dataclasses for LoRA and training configurations. |
|
|
""" |
|
|
|
|
|
from dataclasses import asdict, dataclass, field
from typing import List, Optional
|
|
|
|
|
|
|
|
@dataclass
class LoRAConfig:
    """Configuration for LoRA (Low-Rank Adaptation) training.

    Attributes:
        r: LoRA rank (dimension of the low-rank update matrices)
        alpha: LoRA scaling factor (scaling is alpha / r)
        dropout: Dropout probability applied inside the LoRA layers
        target_modules: Names of the modules LoRA adapters are attached to
        bias: Bias training mode ("none", "all", or "lora_only")
    """

    r: int = 8
    alpha: int = 16
    dropout: float = 0.1
    # Attention projection layers are the conventional LoRA targets.
    target_modules: List[str] = field(
        default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
    )
    bias: str = "none"

    def to_dict(self):
        """Return the keyword mapping expected by a PEFT config.

        Note the key renames: ``alpha`` -> ``lora_alpha`` and
        ``dropout`` -> ``lora_dropout``.
        """
        return dict(
            r=self.r,
            lora_alpha=self.alpha,
            lora_dropout=self.dropout,
            target_modules=self.target_modules,
            bias=self.bias,
        )
|
|
|
|
|
|
|
|
@dataclass
class TrainingConfig:
    """Configuration for LoRA training process.

    Training uses:
    - BFloat16 precision (only supported precision)
    - Discrete timesteps from turbo shift=3.0 schedule (8 steps)
    - Randomly samples one of 8 timesteps per training step:
      [1.0, 0.9545, 0.9, 0.8333, 0.75, 0.6429, 0.5, 0.3]

    Attributes:
        shift: Timestep shift factor (fixed at 3.0 for turbo model)
        num_inference_steps: Number of inference steps (fixed at 8 for turbo)
        learning_rate: Initial learning rate
        batch_size: Training batch size
        gradient_accumulation_steps: Number of gradient accumulation steps
        max_epochs: Maximum number of training epochs
        save_every_n_epochs: Save checkpoint every N epochs
        warmup_steps: Number of warmup steps for learning rate scheduler
        weight_decay: Weight decay for optimizer
        max_grad_norm: Maximum gradient norm for clipping
        mixed_precision: Always "bf16" (only supported precision)
        seed: Random seed for reproducibility
        output_dir: Directory to save checkpoints and logs
        num_workers: Number of dataloader worker processes
        pin_memory: Whether the dataloader pins host memory
        log_every_n_steps: Log metrics every N training steps
    """

    shift: float = 3.0
    num_inference_steps: int = 8
    learning_rate: float = 1e-4
    batch_size: int = 1
    gradient_accumulation_steps: int = 4
    max_epochs: int = 100
    save_every_n_epochs: int = 10
    warmup_steps: int = 100
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    mixed_precision: str = "bf16"
    seed: int = 42
    output_dir: str = "./lora_output"

    # Dataloader settings.
    num_workers: int = 4
    pin_memory: bool = True

    # Logging cadence.
    log_every_n_steps: int = 10

    def to_dict(self):
        """Convert to a dictionary (keys match the field names exactly).

        Uses dataclasses.asdict so the mapping stays in sync with the
        declared fields automatically — unlike a hand-written literal,
        a newly added field cannot be forgotten here.
        """
        return asdict(self)
|
|
|