| """LoRA supervised fine-tuning over rejection-sampled code data. |
| |
| Wraps trl.SFTTrainer with PEFT for efficient adapter-based finetuning. |
| Loads a YAML config, formats examples with the Qwen chat template (matching |
| inference-time formatting), trains, and saves adapters. |
| |
| Single-GPU. Multi-GPU is a Week 4+ concern. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Any, cast |
|
|
| import torch |
| import yaml |
| from datasets import load_dataset |
| from peft import LoraConfig, TaskType |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from trl import SFTConfig, SFTTrainer |
|
|
| |
| |
# System turn that `_format_example` places first in every training example.
SYSTEM_PROMPT = (
    "You are an expert Python programmer. "
    "Respond with a single Python code block containing the requested "
    "function and nothing else."
)
|
|
|
|
@dataclass
class LoraSpec:
    """LoRA adapter hyperparameters, built from the `lora:` section of the YAML.

    `run_sft_training` forwards these to `peft.LoraConfig` as `r`,
    `lora_alpha`, `lora_dropout` and `target_modules` respectively.
    """

    r: int
    alpha: int
    dropout: float
    target_modules: list[str]

    def __post_init__(self) -> None:
        """Normalise `dropout` to a real float.

        PyYAML's default resolver only recognises floats that contain a
        decimal point, so a config line such as `dropout: 5e-2` arrives here
        as the string "5e-2". Coercing up front turns that into 0.05 (and an
        integer `0` into 0.0) instead of failing later inside the adapter.

        Raises:
            ValueError: If `dropout` is a string that is not a valid number.
            TypeError: If `dropout` is not a number or string at all.
        """
        self.dropout = float(self.dropout)
|
|
|
|
@dataclass
class TrainerSpec:
    """Trainer hyperparameters, built from the `trainer:` section of the YAML.

    `run_sft_training` passes most fields straight through to
    `trl.SFTConfig` under the same name. Two are handled differently there:
    `max_seq_length` is written to `tokenizer.model_max_length`, and
    `report_to` is filtered (wandb is dropped when no API key is set) before
    being forwarded.
    """

    num_train_epochs: int
    per_device_train_batch_size: int
    gradient_accumulation_steps: int
    learning_rate: float
    lr_scheduler_type: str
    warmup_ratio: float
    weight_decay: float
    bf16: bool
    max_seq_length: int
    save_strategy: str
    save_total_limit: int
    logging_steps: int
    report_to: list[str]
    seed: int

    def __post_init__(self) -> None:
        """Normalise the float-typed hyperparameters to real floats.

        PyYAML's default resolver only recognises floats that contain a
        decimal point, so the very common `learning_rate: 2e-4` is loaded as
        the string "2e-4". Coercing here makes such configs work instead of
        failing deep inside the optimizer/scheduler setup.

        Raises:
            ValueError: If a field is a string that is not a valid number.
            TypeError: If a field is not a number or string at all.
        """
        self.learning_rate = float(self.learning_rate)
        self.warmup_ratio = float(self.warmup_ratio)
        self.weight_decay = float(self.weight_decay)
|
|
|
|
@dataclass
class LoggingSpec:
    """Experiment-tracking settings, built from the `logging:` YAML section."""

    # Exported as the WANDB_PROJECT environment variable by
    # `run_sft_training` when wandb reporting is active.
    wandb_project: str
    # Forwarded to `SFTConfig(run_name=...)` by `run_sft_training`.
    run_name: str
    # List of tag strings taken from the config.
    tags: list[str]
|
|
|
|
@dataclass
class SFTRunConfig:
    """Top-level run configuration assembled by `load_config`.

    Holds the three scalar settings read from the root of the YAML file plus
    one typed sub-config per nested section.
    """

    # Identifier handed to `from_pretrained` for both tokenizer and model.
    model_id: str
    # Passed as `data_files` to `load_dataset("json", ...)`.
    dataset_path: str
    # Trainer output directory; adapter and tokenizer are saved here too.
    output_dir: str
    # Nested sections of the YAML file.
    lora: LoraSpec
    trainer: TrainerSpec
    logging: LoggingSpec
|
|
|
|
def load_config(path: str | Path) -> SFTRunConfig:
    """Parse a YAML config file into typed dataclasses.

    Args:
        path: Location of the YAML file. Its top level must be a mapping with
            the scalar keys `model_id`, `dataset_path`, `output_dir` and the
            nested mappings `lora`, `trainer`, `logging`.

    Returns:
        The populated `SFTRunConfig`.

    Raises:
        TypeError: If the file is empty or its top level is not a mapping, or
            if a nested section carries unexpected or missing fields.
        KeyError: If a required top-level key is absent.
    """
    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # default locale encoding.
    raw = yaml.safe_load(Path(path).read_text(encoding="utf-8"))
    if not isinstance(raw, dict):
        # An empty file yields None; fail with a message that names the file
        # instead of an opaque "'NoneType' object is not subscriptable".
        raise TypeError(
            f"{path}: expected a YAML mapping at the top level, "
            f"got {type(raw).__name__}"
        )
    return SFTRunConfig(
        model_id=str(raw["model_id"]),
        dataset_path=str(raw["dataset_path"]),
        output_dir=str(raw["output_dir"]),
        lora=LoraSpec(**raw["lora"]),
        trainer=TrainerSpec(**raw["trainer"]),
        logging=LoggingSpec(**raw["logging"]),
    )
|
|
|
|
def _format_example(sample: dict[str, Any]) -> dict[str, list[dict[str, str]]]:
    """Convert one raw record into a `{"messages": [...]}` chat record.

    The conversation has three turns: the shared system prompt, the task
    prompt as the user turn, and the solution (trailing whitespace stripped)
    wrapped in a fenced python block as the assistant turn. trl's SFTTrainer
    detects the `messages` column and applies the tokenizer's chat template
    itself, so no pre-templating or `dataset_text_field` is needed.

    Args:
        sample: Record with `prompt` and `solution` entries; both are passed
            through `str()`.

    Returns:
        A dict whose single `messages` key holds the three role/content turns.
    """
    user_text = str(sample["prompt"])
    code = str(sample["solution"]).rstrip()
    turns = (
        ("system", SYSTEM_PROMPT),
        ("user", user_text),
        ("assistant", f"```python\n{code}\n```"),
    )
    return {"messages": [{"role": role, "content": text} for role, text in turns]}
|
|
|
|
def run_sft_training(config_path: str | Path) -> None:
    """Run LoRA SFT end-to-end from a YAML config.

    Loads the config, sets up wandb reporting, loads tokenizer and model,
    maps the JSON dataset into chat-format records, trains with
    `trl.SFTTrainer` + a PEFT LoRA config, and finally saves the adapter and
    tokenizer into `output_dir`.

    Args:
        config_path: Path to the YAML file understood by `load_config`.
    """
    config = load_config(config_path)

    report_to = _resolve_report_to(config.trainer.report_to)
    if "wandb" in report_to:
        os.environ["WANDB_PROJECT"] = config.logging.wandb_project
        # The configured tags were previously parsed but never used; wandb
        # reads a comma-separated tag list from WANDB_TAGS.
        if config.logging.tags:
            os.environ["WANDB_TAGS"] = ",".join(str(t) for t in config.logging.tags)

    print(f"==> loading tokenizer + model {config.model_id}", flush=True)
    # NOTE: trust_remote_code=True lets the model repository execute its own
    # Python code on load; only point model_id at repositories you trust.
    tokenizer = AutoTokenizer.from_pretrained(config.model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # No dedicated pad token: reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): max_seq_length is applied to the tokenizer only and is not
    # passed to SFTConfig below. Whether the trainer honours
    # tokenizer.model_max_length for truncation depends on the installed trl
    # version — TODO confirm against SFTConfig's own length setting.
    tokenizer.model_max_length = config.trainer.max_seq_length

    model = AutoModelForCausalLM.from_pretrained(
        config.model_id,
        dtype=torch.bfloat16,
        trust_remote_code=True,
    )

    print(f"==> loading dataset from {config.dataset_path}", flush=True)
    raw_ds = cast(
        "Any",
        load_dataset("json", data_files=config.dataset_path, split="train"),
    )
    # Replace every original column with the single chat-format column.
    train_ds = raw_ds.map(
        lambda s: _format_example(cast("dict[str, Any]", s)),
        remove_columns=raw_ds.column_names,
    )
    print(f" {len(train_ds)} examples", flush=True)

    lora_config = LoraConfig(
        r=config.lora.r,
        lora_alpha=config.lora.alpha,
        lora_dropout=config.lora.dropout,
        target_modules=list(config.lora.target_modules),
        task_type=TaskType.CAUSAL_LM,
    )

    sft_config = SFTConfig(
        output_dir=config.output_dir,
        num_train_epochs=config.trainer.num_train_epochs,
        per_device_train_batch_size=config.trainer.per_device_train_batch_size,
        gradient_accumulation_steps=config.trainer.gradient_accumulation_steps,
        learning_rate=config.trainer.learning_rate,
        lr_scheduler_type=config.trainer.lr_scheduler_type,
        warmup_ratio=config.trainer.warmup_ratio,
        weight_decay=config.trainer.weight_decay,
        bf16=config.trainer.bf16,
        save_strategy=config.trainer.save_strategy,
        save_total_limit=config.trainer.save_total_limit,
        logging_steps=config.trainer.logging_steps,
        report_to=report_to,
        seed=config.trainer.seed,
        run_name=config.logging.run_name,
    )

    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        train_dataset=train_ds,
        processing_class=tokenizer,
        peft_config=lora_config,
    )

    print("==> starting training", flush=True)
    trainer.train()

    print(f"==> saving adapter + tokenizer to {config.output_dir}", flush=True)
    trainer.save_model(config.output_dir)
    tokenizer.save_pretrained(config.output_dir)
    print("==> done", flush=True)


def _resolve_report_to(requested: list[str]) -> list[str]:
    """Return a copy of `requested`, minus "wandb" when WANDB_API_KEY is unset."""
    if "wandb" in requested and not os.environ.get("WANDB_API_KEY"):
        print("==> WANDB_API_KEY unset; disabling wandb reporting", flush=True)
        return [r for r in requested if r != "wandb"]
    return list(requested)
|
|