dmaheshwar22's picture
deploy: replace template with real demo
0dd7c80 verified
"""LoRA supervised fine-tuning over rejection-sampled code data.
Wraps trl.SFTTrainer with PEFT for efficient adapter-based finetuning.
Loads a YAML config, formats examples with the Qwen chat template (matching
inference-time formatting), trains, and saves adapters.
Single-GPU. Multi-GPU is a Week 4+ concern.
"""
from __future__ import annotations
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, cast
import torch # type: ignore[import-not-found]
import yaml
from datasets import load_dataset # type: ignore[import-untyped]
from peft import LoraConfig, TaskType # type: ignore[import-untyped]
from transformers import AutoModelForCausalLM, AutoTokenizer # type: ignore[import-untyped]
from trl import SFTConfig, SFTTrainer # type: ignore[import-untyped]
# Keep this identical to Proposer.DEFAULT_SYSTEM_PROMPT: training examples and
# inference requests must be rendered with the same chat-template layout.
SYSTEM_PROMPT = (
    "You are an expert Python programmer. "
    "Respond with a single Python code block "
    "containing the requested function and nothing else."
)
@dataclass
class LoraSpec:
    """LoRA adapter hyperparameters, read from the ``lora`` section of the YAML config."""

    r: int  # adapter rank -> LoraConfig(r=...)
    alpha: int  # scaling numerator -> LoraConfig(lora_alpha=...)
    dropout: float  # dropout on the adapter path -> LoraConfig(lora_dropout=...)
    target_modules: list[str]  # module names that receive adapters -> LoraConfig(target_modules=...)
@dataclass
class TrainerSpec:
    """Training-loop settings, read from the ``trainer`` section of the YAML config.

    Apart from ``max_seq_length`` and ``report_to``, every field is forwarded
    unchanged to the ``SFTConfig`` keyword of the same name.
    """

    num_train_epochs: int
    per_device_train_batch_size: int
    gradient_accumulation_steps: int
    learning_rate: float
    lr_scheduler_type: str
    warmup_ratio: float
    weight_decay: float
    bf16: bool
    max_seq_length: int  # maximum tokenized length intended for training examples
    save_strategy: str
    save_total_limit: int
    logging_steps: int
    report_to: list[str]  # reporting integrations; "wandb" is dropped when WANDB_API_KEY is unset
    seed: int
@dataclass
class LoggingSpec:
    """Experiment-tracking metadata, read from the ``logging`` section of the YAML config."""

    wandb_project: str  # exported as WANDB_PROJECT when wandb reporting stays enabled
    run_name: str  # forwarded to SFTConfig(run_name=...)
    tags: list[str]  # tags to attach to the run
@dataclass
class SFTRunConfig:
    """Full run configuration: model and data sources, output location, and sub-specs."""

    model_id: str  # Hub id or local path passed to AutoTokenizer/AutoModelForCausalLM
    dataset_path: str  # data file(s) loaded with datasets.load_dataset("json", ...)
    output_dir: str  # trainer output; the final adapter and tokenizer are saved here too
    lora: LoraSpec  # adapter hyperparameters
    trainer: TrainerSpec  # training-loop settings
    logging: LoggingSpec  # experiment-tracking metadata
def _build_section(cls: type[Any], raw: dict[str, Any], key: str) -> Any:
    """Instantiate dataclass ``cls`` from the mapping stored at ``raw[key]``.

    Fields annotated ``float`` are coerced with ``float()``. PyYAML implements
    YAML 1.1, which only resolves exponent notation as a float when it contains
    a dot (``2.0e-4``). Without coercion, ``learning_rate: 2e-4`` would reach the
    optimizer as the *string* ``"2e-4"``. Booleans are left untouched.

    Raises:
        KeyError: if ``key`` is absent from ``raw``.
        ValueError: if the section is not a mapping, or a float field holds a
            value that cannot be parsed as a number.
        TypeError: from the dataclass constructor on unknown/missing field names.
    """
    from dataclasses import fields

    section = raw[key]
    if not isinstance(section, dict):
        raise ValueError(
            f"config section {key!r} must be a mapping, got {type(section).__name__}"
        )
    kwargs = dict(section)
    for spec_field in fields(cls):
        # With `from __future__ import annotations` the annotation is the string "float".
        if spec_field.type not in ("float", float) or spec_field.name not in kwargs:
            continue
        value = kwargs[spec_field.name]
        if isinstance(value, bool) or not isinstance(value, (int, str)):
            continue
        try:
            kwargs[spec_field.name] = float(value)
        except ValueError as err:
            raise ValueError(
                f"config field {key}.{spec_field.name} must be a number, got {value!r}"
            ) from err
    return cls(**kwargs)


def load_config(path: str | Path) -> SFTRunConfig:
    """Parse a YAML config file into typed dataclasses.

    Args:
        path: Location of a UTF-8 YAML file. It must have top-level keys
            ``model_id``, ``dataset_path``, ``output_dir`` and the mapping
            sections ``lora``, ``trainer``, ``logging``.

    Returns:
        The populated ``SFTRunConfig``. Float-typed fields are coerced from YAML
        strings such as ``2e-4``.

    Raises:
        ValueError: if the document is empty or not a mapping, a section is not
            a mapping, or a float field is not numeric.
        KeyError: if a required top-level key or section is missing.
        TypeError: if a section has unknown or missing field names.
    """
    raw: Any = yaml.safe_load(Path(path).read_text(encoding="utf-8"))
    # safe_load returns None for an empty file and a list/scalar for other documents.
    if not isinstance(raw, dict):
        raise ValueError(
            f"{path}: expected a YAML mapping at the top level, got {type(raw).__name__}"
        )
    return SFTRunConfig(
        model_id=str(raw["model_id"]),
        dataset_path=str(raw["dataset_path"]),
        output_dir=str(raw["output_dir"]),
        lora=_build_section(LoraSpec, raw, "lora"),
        trainer=_build_section(TrainerSpec, raw, "trainer"),
        logging=_build_section(LoggingSpec, raw, "logging"),
    )
def _format_example(sample: dict[str, Any]) -> dict[str, list[dict[str, str]]]:
    """Convert one dataset row into a conversational ``{"messages": [...]}`` record.

    The conversation has three turns:
    - system: ``SYSTEM_PROMPT``
    - user: the row's ``prompt``
    - assistant: the row's ``solution``, right-stripped and wrapped in a
      ```` ```python ```` fence.

    trl's SFTTrainer recognises the ``messages`` column and applies the
    tokenizer's chat template itself, so no pre-templating or
    ``dataset_text_field`` is needed.
    """
    user_text = str(sample["prompt"])
    code = str(sample["solution"]).rstrip()
    fenced_code = "```python\n" + code + "\n```"
    turns = (
        ("system", SYSTEM_PROMPT),
        ("user", user_text),
        ("assistant", fenced_code),
    )
    return {"messages": [{"role": role, "content": text} for role, text in turns]}
def _resolve_report_to(config: SFTRunConfig) -> list[str]:
    """Return the reporting integrations to enable, configuring W&B when it stays on.

    Drops "wandb" if WANDB_API_KEY is unset so training still runs without
    credentials. When W&B remains enabled, exports WANDB_PROJECT and, if any
    tags are configured, WANDB_TAGS.
    """
    report_to = list(config.trainer.report_to)
    if "wandb" in report_to and not os.environ.get("WANDB_API_KEY"):
        print("==> WANDB_API_KEY unset; disabling wandb reporting", flush=True)
        report_to = [r for r in report_to if r != "wandb"]
    if "wandb" in report_to:
        os.environ["WANDB_PROJECT"] = config.logging.wandb_project
        if config.logging.tags:
            # wandb reads WANDB_TAGS as a comma-separated list when the run starts.
            os.environ["WANDB_TAGS"] = ",".join(config.logging.tags)
    return report_to


def _max_length_kwargs(max_seq_length: int) -> dict[str, int]:
    """Return the SFTConfig keyword that caps tokenized example length.

    trl names this field ``max_seq_length`` up to 0.15 and ``max_length`` from
    0.16 on. If it is left unset, trl caps examples at 1024 tokens; older
    releases use ``min(tokenizer.model_max_length, 1024)``. Longer configured
    lengths would therefore be truncated silently unless passed explicitly.
    """
    from dataclasses import fields

    names = {f.name for f in fields(SFTConfig)}
    if "max_length" in names:
        return {"max_length": max_seq_length}
    if "max_seq_length" in names:
        return {"max_seq_length": max_seq_length}
    return {}


def _build_sft_config(config: SFTRunConfig, report_to: list[str]) -> Any:
    """Translate the typed run config into trl's ``SFTConfig``.

    No ``dataset_text_field`` is set: trl autodetects the conversational
    ``messages`` column produced by ``_format_example``.
    """
    spec = config.trainer
    return SFTConfig(
        output_dir=config.output_dir,
        num_train_epochs=spec.num_train_epochs,
        per_device_train_batch_size=spec.per_device_train_batch_size,
        gradient_accumulation_steps=spec.gradient_accumulation_steps,
        learning_rate=spec.learning_rate,
        lr_scheduler_type=spec.lr_scheduler_type,
        warmup_ratio=spec.warmup_ratio,
        weight_decay=spec.weight_decay,
        bf16=spec.bf16,
        save_strategy=spec.save_strategy,
        save_total_limit=spec.save_total_limit,
        logging_steps=spec.logging_steps,
        report_to=report_to,
        seed=spec.seed,
        run_name=config.logging.run_name,
        **_max_length_kwargs(spec.max_seq_length),
    )


def run_sft_training(config_path: str | Path) -> None:
    """Run LoRA SFT end-to-end from a YAML config.

    Steps:
    1. Load the config and the tokenizer and model (bf16 weights).
    2. Load and format the JSON dataset.
    3. Train LoRA adapters with trl's SFTTrainer.
    4. Save the adapter and tokenizer to ``output_dir``.

    Raises:
        ValueError: if the dataset contains no examples (or on config errors
            surfaced by ``load_config``).
    """
    config = load_config(config_path)
    # Skip W&B gracefully when the key is absent — training should still work.
    report_to = _resolve_report_to(config)

    print(f"==> loading tokenizer + model {config.model_id}", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(config.model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Record the cap on the (saved) tokenizer too. trl's truncation itself is
    # driven by the SFTConfig length kwarg from _max_length_kwargs.
    tokenizer.model_max_length = config.trainer.max_seq_length
    model = AutoModelForCausalLM.from_pretrained(
        config.model_id,
        dtype=torch.bfloat16,
        trust_remote_code=True,
    )

    print(f"==> loading dataset from {config.dataset_path}", flush=True)
    raw_ds = cast(
        "Any",
        load_dataset("json", data_files=config.dataset_path, split="train"),
    )
    train_ds = raw_ds.map(
        lambda s: _format_example(cast("dict[str, Any]", s)),
        remove_columns=raw_ds.column_names,
    )
    if len(train_ds) == 0:
        # Fail before building the trainer instead of erroring deep inside it.
        raise ValueError(f"no training examples found in {config.dataset_path}")
    print(f" {len(train_ds)} examples", flush=True)

    lora_config = LoraConfig(
        r=config.lora.r,
        lora_alpha=config.lora.alpha,
        lora_dropout=config.lora.dropout,
        target_modules=list(config.lora.target_modules),
        task_type=TaskType.CAUSAL_LM,
    )
    trainer = SFTTrainer(
        model=model,
        args=_build_sft_config(config, report_to),
        train_dataset=train_ds,
        processing_class=tokenizer,  # trl 0.12+ renamed from `tokenizer=`
        peft_config=lora_config,
    )

    print("==> starting training", flush=True)
    trainer.train()
    print(f"==> saving adapter + tokenizer to {config.output_dir}", flush=True)
    # With a peft_config the trainer wraps the model in PEFT, so this saves the adapter.
    trainer.save_model(config.output_dir)
    tokenizer.save_pretrained(config.output_dir)
    print("==> done", flush=True)