File size: 3,765 Bytes
9f5c8f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
"""
Training Configuration Classes
Contains dataclasses for LoRA and training configurations.
"""
from dataclasses import asdict, dataclass, field
from typing import List, Optional
@dataclass
class LoRAConfig:
    """Settings for LoRA (Low-Rank Adaptation) fine-tuning.

    Attributes:
        r: Rank of the low-rank decomposition matrices.
        alpha: Scaling factor; the effective scale applied is alpha / r.
        dropout: Dropout probability used inside the LoRA layers.
        target_modules: Names of the modules the LoRA adapters attach to.
        bias: Bias training mode — "none", "all", or "lora_only".
    """
    r: int = 8
    alpha: int = 16
    dropout: float = 0.1
    target_modules: List[str] = field(
        default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
    )
    bias: str = "none"

    def to_dict(self):
        """Build the keyword mapping expected by a PEFT config.

        Note the key renames: ``alpha`` -> ``lora_alpha`` and
        ``dropout`` -> ``lora_dropout``; the remaining keys match
        the field names directly.
        """
        mapping = {"r": self.r}
        mapping["lora_alpha"] = self.alpha
        mapping["lora_dropout"] = self.dropout
        mapping["target_modules"] = self.target_modules
        mapping["bias"] = self.bias
        return mapping
@dataclass
class TrainingConfig:
    """Configuration for LoRA training process.

    Training uses:
    - BFloat16 precision (only supported precision)
    - Discrete timesteps from turbo shift=3.0 schedule (8 steps)
    - Randomly samples one of 8 timesteps per training step:
      [1.0, 0.9545, 0.9, 0.8333, 0.75, 0.6429, 0.5, 0.3]

    Attributes:
        shift: Timestep shift factor (fixed at 3.0 for turbo model)
        num_inference_steps: Number of inference steps (fixed at 8 for turbo)
        learning_rate: Initial learning rate
        batch_size: Training batch size
        gradient_accumulation_steps: Number of gradient accumulation steps
        max_epochs: Maximum number of training epochs
        save_every_n_epochs: Save checkpoint every N epochs
        warmup_steps: Number of warmup steps for learning rate scheduler
        weight_decay: Weight decay for optimizer
        max_grad_norm: Maximum gradient norm for clipping
        mixed_precision: Always "bf16" (only supported precision)
        seed: Random seed for reproducibility
        output_dir: Directory to save checkpoints and logs
        num_workers: Number of DataLoader worker processes
        pin_memory: Whether the DataLoader pins host memory
        log_every_n_steps: Emit a log record every N optimizer steps
    """
    # Fixed for turbo model
    shift: float = 3.0  # Fixed: turbo uses shift=3.0
    num_inference_steps: int = 8  # Fixed: turbo uses 8 steps
    learning_rate: float = 1e-4
    batch_size: int = 1
    gradient_accumulation_steps: int = 4
    max_epochs: int = 100
    save_every_n_epochs: int = 10
    warmup_steps: int = 100
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    mixed_precision: str = "bf16"  # Fixed: only bf16 supported
    seed: int = 42
    output_dir: str = "./lora_output"
    # Data loading
    num_workers: int = 4
    pin_memory: bool = True
    # Logging
    log_every_n_steps: int = 10

    def to_dict(self):
        """Convert to a plain dict with one key per field, in field order.

        Uses dataclasses.asdict instead of a hand-maintained literal so
        newly added fields are included automatically (the previous
        manual enumeration would silently omit them). All fields are
        immutable scalars/strings, so asdict's copying semantics make
        no observable difference here.
        """
        return asdict(self)
|