workbench / training /planner.py
GitHub Actions
Initial ZeroGPU deployment with spaces shim
7f9dfed
Raw
History Blame Contribute Delete
5.54 kB
from __future__ import annotations
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any
import yaml
@dataclass(frozen=True)
class LoraConfig:
"""LoRA settings for dry-run training plans."""
rank: int
alpha: int
dropout: float
@dataclass(frozen=True)
class TrainingConfig:
"""Training settings for dry-run validation."""
epochs: int
batch_size: int
grad_accum: int
lr: float
report_to: str
@dataclass(frozen=True)
class TrainingPlan:
"""Non-executing local training plan."""
dataset_path: str
output_dir: str
lora: LoraConfig
training: TrainingConfig
validation_errors: list[str]
hardware_notes: list[str]
executes_training: bool = False
def as_dict(self) -> dict[str, Any]:
return {
"dataset_path": self.dataset_path,
"output_dir": self.output_dir,
"lora": asdict(self.lora),
"training": asdict(self.training),
"validation_errors": self.validation_errors,
"hardware_notes": self.hardware_notes,
"executes_training": self.executes_training,
}
def as_text(self) -> str:
status = "valid" if not self.validation_errors else "blocked"
lines = [
"Training dry-run plan",
f"Status: {status}",
f"Dataset: {self.dataset_path or '(none selected)'}",
f"Output directory: {self.output_dir}",
f"LoRA rank: {self.lora.rank}",
f"LoRA alpha: {self.lora.alpha}",
f"LoRA dropout: {self.lora.dropout}",
f"Epochs: {self.training.epochs}",
f"Batch size: {self.training.batch_size}",
f"Gradient accumulation: {self.training.grad_accum}",
f"Learning rate: {self.training.lr}",
f"Report to: {self.training.report_to}",
f"Executes training: {self.executes_training}",
"",
"Validation:",
*[f"- {item}" for item in (self.validation_errors or ["passed"])],
"",
"Hardware notes:",
*[f"- {item}" for item in self.hardware_notes],
]
return "\n".join(lines)
def load_training_config(
path: str | Path = "config/training.yaml",
) -> tuple[LoraConfig, TrainingConfig]:
raw = yaml.safe_load(Path(path).read_text(encoding="utf-8")) or {}
lora = raw.get("lora", {})
training = raw.get("training", {})
return (
LoraConfig(
rank=int(lora.get("rank", 16)),
alpha=int(lora.get("alpha", 32)),
dropout=float(lora.get("dropout", 0.05)),
),
TrainingConfig(
epochs=int(training.get("epochs", 1)),
batch_size=int(training.get("batch_size", 2)),
grad_accum=int(training.get("grad_accum", 4)),
lr=float(training.get("lr", 0.0002)),
report_to=str(training.get("report_to", "none")),
),
)
def build_training_plan(
dataset_path: str,
rank: int | None = None,
epochs: int | None = None,
output_root: str | Path = "outputs/checkpoints",
config_path: str | Path = "config/training.yaml",
) -> TrainingPlan:
lora, training = load_training_config(config_path)
if rank is not None:
lora = LoraConfig(rank=rank, alpha=max(rank * 2, lora.alpha), dropout=lora.dropout)
if epochs is not None:
training = TrainingConfig(
epochs=epochs,
batch_size=training.batch_size,
grad_accum=training.grad_accum,
lr=training.lr,
report_to=training.report_to,
)
output_dir = str(Path(output_root) / _safe_output_name(dataset_path))
return TrainingPlan(
dataset_path=dataset_path,
output_dir=output_dir,
lora=lora,
training=training,
validation_errors=validate_training_plan(dataset_path, lora, training),
hardware_notes=training_hardware_notes(),
)
def validate_training_plan(
dataset_path: str,
lora: LoraConfig,
training: TrainingConfig,
) -> list[str]:
errors = []
if not dataset_path:
errors.append("Training dataset path is required.")
elif not Path(dataset_path).exists():
errors.append(f"Training dataset does not exist: {dataset_path}")
if lora.rank <= 0:
errors.append("LoRA rank must be positive.")
if not 0 <= lora.dropout < 1:
errors.append("LoRA dropout must be between 0 and 1.")
if training.epochs <= 0:
errors.append("Epochs must be positive.")
if training.batch_size <= 0:
errors.append("Batch size must be positive.")
if training.grad_accum <= 0:
errors.append("Gradient accumulation must be positive.")
if training.lr <= 0:
errors.append("Learning rate must be positive.")
return errors
def training_hardware_notes() -> list[str]:
return [
"MiniCPM 1B LoRA may run on modest CUDA hardware; CPU-only training will be slow.",
"MiniCPM 8B and vision fine-tuning need substantially more VRAM "
"or quantized/adapted tooling.",
"Install PEFT/TRL, SWIFT, or LLaMA-Factory only after choosing the final training path.",
"Keep checkpoints and model weights out of git.",
]
def _safe_output_name(dataset_path: str) -> str:
if not dataset_path:
return "unselected"
stem = Path(dataset_path).stem or "dataset"
return "".join(char if char.isalnum() or char in {"-", "_"} else "_" for char in stem)