"""LoRA supervised fine-tuning over rejection-sampled code data. Wraps trl.SFTTrainer with PEFT for efficient adapter-based finetuning. Loads a YAML config, formats examples with the Qwen chat template (matching inference-time formatting), trains, and saves adapters. Single-GPU. Multi-GPU is a Week 4+ concern. """ from __future__ import annotations import os from dataclasses import dataclass from pathlib import Path from typing import Any, cast import torch # type: ignore[import-not-found] import yaml from datasets import load_dataset # type: ignore[import-untyped] from peft import LoraConfig, TaskType # type: ignore[import-untyped] from transformers import AutoModelForCausalLM, AutoTokenizer # type: ignore[import-untyped] from trl import SFTConfig, SFTTrainer # type: ignore[import-untyped] # Must match Proposer.DEFAULT_SYSTEM_PROMPT so training and inference see # the same chat-template layout. SYSTEM_PROMPT = ( "You are an expert Python programmer. Respond with a single Python code " "block containing the requested function and nothing else." ) @dataclass class LoraSpec: r: int alpha: int dropout: float target_modules: list[str] @dataclass class TrainerSpec: num_train_epochs: int per_device_train_batch_size: int gradient_accumulation_steps: int learning_rate: float lr_scheduler_type: str warmup_ratio: float weight_decay: float bf16: bool max_seq_length: int save_strategy: str save_total_limit: int logging_steps: int report_to: list[str] seed: int @dataclass class LoggingSpec: wandb_project: str run_name: str tags: list[str] @dataclass class SFTRunConfig: model_id: str dataset_path: str output_dir: str lora: LoraSpec trainer: TrainerSpec logging: LoggingSpec def load_config(path: str | Path) -> SFTRunConfig: """Parse a YAML config into typed dataclasses.""" raw = cast("dict[str, Any]", yaml.safe_load(Path(path).read_text())) return SFTRunConfig( model_id=str(raw["model_id"]), dataset_path=str(raw["dataset_path"]), output_dir=str(raw["output_dir"]), lora=LoraSpec(**raw["lora"]), trainer=TrainerSpec(**raw["trainer"]), logging=LoggingSpec(**raw["logging"]), ) def _format_example(sample: dict[str, Any]) -> dict[str, list[dict[str, str]]]: """Return a `{"messages": [...]}` record for trl's chat-format auto-handler. Three-turn: system + user (task prompt) + assistant (code block around the rejection-sampled solution). trl's SFTTrainer detects the `messages` column and applies the tokenizer's chat template internally — no need to pre-template ourselves or set `dataset_text_field`. """ prompt = str(sample["prompt"]) solution = str(sample["solution"]).rstrip() return { "messages": [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": prompt}, {"role": "assistant", "content": f"```python\n{solution}\n```"}, ] } def run_sft_training(config_path: str | Path) -> None: """Run LoRA SFT end-to-end from a YAML config.""" config = load_config(config_path) # Skip W&B gracefully when the key is absent — training should still work. report_to = list(config.trainer.report_to) if "wandb" in report_to and not os.environ.get("WANDB_API_KEY"): print("==> WANDB_API_KEY unset; disabling wandb reporting", flush=True) report_to = [r for r in report_to if r != "wandb"] if "wandb" in report_to: os.environ["WANDB_PROJECT"] = config.logging.wandb_project print(f"==> loading tokenizer + model {config.model_id}", flush=True) tokenizer = AutoTokenizer.from_pretrained(config.model_id, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # trl truncates per tokenizer.model_max_length; cap via the config value. tokenizer.model_max_length = config.trainer.max_seq_length model = AutoModelForCausalLM.from_pretrained( config.model_id, dtype=torch.bfloat16, trust_remote_code=True, ) print(f"==> loading dataset from {config.dataset_path}", flush=True) raw_ds = cast( "Any", load_dataset("json", data_files=config.dataset_path, split="train"), ) train_ds = raw_ds.map( lambda s: _format_example(cast("dict[str, Any]", s)), remove_columns=raw_ds.column_names, ) print(f" {len(train_ds)} examples", flush=True) lora_config = LoraConfig( r=config.lora.r, lora_alpha=config.lora.alpha, lora_dropout=config.lora.dropout, target_modules=list(config.lora.target_modules), task_type=TaskType.CAUSAL_LM, ) # Drop `dataset_text_field` / `max_seq_length` — trl >= 0.12 autodetects # chat-formatted datasets from the `messages` column and handles tokenizer # truncation via tokenizer.model_max_length by default. sft_config = SFTConfig( output_dir=config.output_dir, num_train_epochs=config.trainer.num_train_epochs, per_device_train_batch_size=config.trainer.per_device_train_batch_size, gradient_accumulation_steps=config.trainer.gradient_accumulation_steps, learning_rate=config.trainer.learning_rate, lr_scheduler_type=config.trainer.lr_scheduler_type, warmup_ratio=config.trainer.warmup_ratio, weight_decay=config.trainer.weight_decay, bf16=config.trainer.bf16, save_strategy=config.trainer.save_strategy, save_total_limit=config.trainer.save_total_limit, logging_steps=config.trainer.logging_steps, report_to=report_to, seed=config.trainer.seed, run_name=config.logging.run_name, ) trainer = SFTTrainer( model=model, args=sft_config, train_dataset=train_ds, processing_class=tokenizer, # trl 0.12+ renamed from `tokenizer=` peft_config=lora_config, ) print("==> starting training", flush=True) trainer.train() print(f"==> saving adapter + tokenizer to {config.output_dir}", flush=True) trainer.save_model(config.output_dir) tokenizer.save_pretrained(config.output_dir) print("==> done", flush=True)