File size: 6,324 Bytes
0dd7c80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""LoRA supervised fine-tuning over rejection-sampled code data.

Wraps trl.SFTTrainer with PEFT for efficient adapter-based finetuning.
Loads a YAML config, formats examples with the Qwen chat template (matching
inference-time formatting), trains, and saves adapters.

Single-GPU. Multi-GPU is a Week 4+ concern.
"""

from __future__ import annotations

import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, cast

import torch  # type: ignore[import-not-found]
import yaml
from datasets import load_dataset  # type: ignore[import-untyped]
from peft import LoraConfig, TaskType  # type: ignore[import-untyped]
from transformers import AutoModelForCausalLM, AutoTokenizer  # type: ignore[import-untyped]
from trl import SFTConfig, SFTTrainer  # type: ignore[import-untyped]

# System-turn text that `_format_example` places at the start of every
# training example.
# Must match Proposer.DEFAULT_SYSTEM_PROMPT so training and inference see
# the same chat-template layout; keep the two strings in sync when editing.
SYSTEM_PROMPT = (
    "You are an expert Python programmer. Respond with a single Python code "
    "block containing the requested function and nothing else."
)


@dataclass
class LoraSpec:
    """LoRA adapter hyperparameters, populated from the `lora:` YAML section.

    `run_sft_training` forwards these values to `peft.LoraConfig`.
    """

    # Passed to LoraConfig as `r`.
    r: int
    # Passed to LoraConfig as `lora_alpha`.
    alpha: int
    # Passed to LoraConfig as `lora_dropout`.
    dropout: float
    # Passed to LoraConfig as `target_modules` (copied into a fresh list).
    target_modules: list[str]


@dataclass
class TrainerSpec:
    """Training-loop settings, populated from the `trainer:` YAML section.

    Every YAML key in that section must correspond to a field below, because
    `load_config` constructs this class with `TrainerSpec(**raw["trainer"])`.
    Most fields are forwarded to `trl.SFTConfig` under the same name; the
    exceptions are noted on the field.
    """

    num_train_epochs: int
    per_device_train_batch_size: int
    gradient_accumulation_steps: int
    learning_rate: float
    lr_scheduler_type: str
    warmup_ratio: float
    weight_decay: float
    bf16: bool
    # Not passed to SFTConfig: applied to `tokenizer.model_max_length` instead.
    max_seq_length: int
    save_strategy: str
    save_total_limit: int
    logging_steps: int
    # Reporting backends; "wandb" is dropped at runtime when WANDB_API_KEY
    # is not set in the environment.
    report_to: list[str]
    seed: int


@dataclass
class LoggingSpec:
    """Experiment-tracking settings, populated from the `logging:` YAML section."""

    # Exported as the WANDB_PROJECT environment variable when wandb
    # reporting is active.
    wandb_project: str
    # Passed to SFTConfig as `run_name`.
    run_name: str
    # List of tag strings for the run, as declared in the YAML config.
    tags: list[str]


@dataclass
class SFTRunConfig:
    """Top-level run configuration assembled by `load_config` from one YAML file."""

    # Identifier handed to AutoTokenizer / AutoModelForCausalLM.from_pretrained.
    model_id: str
    # Passed as `data_files` to `load_dataset("json", ...)`; rows are read
    # for their "prompt" and "solution" fields.
    dataset_path: str
    # Used as SFTConfig `output_dir` and as the save target for the adapter
    # and tokenizer.
    output_dir: str
    # Nested sections, each parsed into its own dataclass.
    lora: LoraSpec
    trainer: TrainerSpec
    logging: LoggingSpec


def load_config(path: str | Path) -> SFTRunConfig:
    """Parse a YAML config file into typed dataclasses.

    Args:
        path: Location of the YAML file. Its top level must be a mapping with
            the keys `model_id`, `dataset_path`, `output_dir`, `lora`,
            `trainer` and `logging`.

    Returns:
        The populated `SFTRunConfig`. The three scalar keys are coerced to
        `str`; the nested sections are expanded as keyword arguments into
        `LoraSpec`, `TrainerSpec` and `LoggingSpec`.

    Raises:
        TypeError: If the document is empty or not a mapping at the top
            level, or if a nested section has missing/unexpected fields
            (raised by the dataclass constructors).
        KeyError: If a required top-level key is absent.
    """
    # Decode explicitly as UTF-8 rather than the platform's locale encoding,
    # so the same config file parses identically on every machine.
    text = Path(path).read_text(encoding="utf-8")
    raw = yaml.safe_load(text)
    # An empty file yields None and a top-level list/scalar is also possible;
    # fail with a clear message instead of an opaque "not subscriptable".
    if not isinstance(raw, dict):
        raise TypeError(
            f"{path}: expected a YAML mapping at the top level, "
            f"got {type(raw).__name__}"
        )
    return SFTRunConfig(
        model_id=str(raw["model_id"]),
        dataset_path=str(raw["dataset_path"]),
        output_dir=str(raw["output_dir"]),
        lora=LoraSpec(**raw["lora"]),
        trainer=TrainerSpec(**raw["trainer"]),
        logging=LoggingSpec(**raw["logging"]),
    )


def _format_example(sample: dict[str, Any]) -> dict[str, list[dict[str, str]]]:
    """Build the chat-style training record for one dataset row.

    The result is a `{"messages": [...]}` dict holding three turns, in order:
    the shared system prompt, the row's task prompt as the user turn, and the
    row's solution wrapped in a fenced Python code block as the assistant
    turn. Trailing whitespace is stripped from the solution before fencing.

    No chat template is applied here: the `messages` column is handed to
    trl's SFTTrainer as-is, which is expected to template it with the
    tokenizer, so `dataset_text_field` does not need to be set.
    """
    user_text = str(sample["prompt"])
    code = str(sample["solution"]).rstrip()
    fenced = "```python\n" + code + "\n```"

    roles = ("system", "user", "assistant")
    contents = (SYSTEM_PROMPT, user_text, fenced)
    turns = [{"role": role, "content": content} for role, content in zip(roles, contents)]
    return {"messages": turns}


def _resolve_report_to(config: SFTRunConfig) -> list[str]:
    """Return the reporting backends to use and export W&B settings to the environment."""
    report_to = list(config.trainer.report_to)
    # Skip W&B gracefully when the key is absent — training should still work.
    if "wandb" in report_to and not os.environ.get("WANDB_API_KEY"):
        print("==> WANDB_API_KEY unset; disabling wandb reporting", flush=True)
        report_to = [r for r in report_to if r != "wandb"]
    if "wandb" in report_to:
        os.environ["WANDB_PROJECT"] = config.logging.wandb_project
        # Honor the configured tags as well; wandb reads them from the
        # comma-separated WANDB_TAGS variable. Skipped when no tags are set.
        if config.logging.tags:
            os.environ["WANDB_TAGS"] = ",".join(str(t) for t in config.logging.tags)
    return report_to


def _load_tokenizer_and_model(config: SFTRunConfig) -> tuple[Any, Any]:
    """Load the tokenizer (with pad token and length cap set) and the bf16 base model."""
    print(f"==> loading tokenizer + model {config.model_id}", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(config.model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Reuse EOS for padding when the tokenizer ships without a pad token.
        tokenizer.pad_token = tokenizer.eos_token
    # trl truncates per tokenizer.model_max_length; cap via the config value.
    # NOTE(review): newer trl releases also expose an explicit length setting
    # on SFTConfig with its own default — confirm this cap is still honored
    # for the pinned trl version.
    tokenizer.model_max_length = config.trainer.max_seq_length

    model = AutoModelForCausalLM.from_pretrained(
        config.model_id,
        dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    return tokenizer, model


def _load_train_dataset(config: SFTRunConfig) -> Any:
    """Load the JSON dataset and convert every row to a `messages` record."""
    print(f"==> loading dataset from {config.dataset_path}", flush=True)
    raw_ds = cast(
        "Any",
        load_dataset("json", data_files=config.dataset_path, split="train"),
    )
    # Drop all original columns so only the `messages` column remains.
    train_ds = raw_ds.map(_format_example, remove_columns=raw_ds.column_names)
    print(f"    {len(train_ds)} examples", flush=True)
    return train_ds


def run_sft_training(config_path: str | Path) -> None:
    """Run LoRA SFT end-to-end from a YAML config.

    Steps, in order: parse the config, resolve reporting backends (disabling
    wandb when WANDB_API_KEY is unset), load tokenizer and model, load and
    format the dataset, train with trl's SFTTrainer using a PEFT LoRA config,
    then save the adapter and tokenizer to the configured output directory.

    Args:
        config_path: Path to the YAML file understood by `load_config`.

    Returns:
        None. Results are written to `output_dir` from the config; progress
        lines are printed to stdout.

    Side effects:
        May set the WANDB_PROJECT and WANDB_TAGS environment variables when
        wandb reporting is active.
    """
    config = load_config(config_path)
    report_to = _resolve_report_to(config)
    tokenizer, model = _load_tokenizer_and_model(config)
    train_ds = _load_train_dataset(config)

    lora_config = LoraConfig(
        r=config.lora.r,
        lora_alpha=config.lora.alpha,
        lora_dropout=config.lora.dropout,
        target_modules=list(config.lora.target_modules),
        task_type=TaskType.CAUSAL_LM,
    )

    # `dataset_text_field` / `max_seq_length` are intentionally not passed —
    # trl >= 0.12 autodetects chat-formatted datasets from the `messages`
    # column; truncation relies on the tokenizer cap set at load time.
    trainer_cfg = config.trainer
    sft_config = SFTConfig(
        output_dir=config.output_dir,
        num_train_epochs=trainer_cfg.num_train_epochs,
        per_device_train_batch_size=trainer_cfg.per_device_train_batch_size,
        gradient_accumulation_steps=trainer_cfg.gradient_accumulation_steps,
        learning_rate=trainer_cfg.learning_rate,
        lr_scheduler_type=trainer_cfg.lr_scheduler_type,
        warmup_ratio=trainer_cfg.warmup_ratio,
        weight_decay=trainer_cfg.weight_decay,
        bf16=trainer_cfg.bf16,
        save_strategy=trainer_cfg.save_strategy,
        save_total_limit=trainer_cfg.save_total_limit,
        logging_steps=trainer_cfg.logging_steps,
        report_to=report_to,  # filtered list, not the raw config value
        seed=trainer_cfg.seed,
        run_name=config.logging.run_name,
    )

    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        train_dataset=train_ds,
        processing_class=tokenizer,  # trl 0.12+ renamed from `tokenizer=`
        peft_config=lora_config,
    )

    print("==> starting training", flush=True)
    trainer.train()

    print(f"==> saving adapter + tokenizer to {config.output_dir}", flush=True)
    trainer.save_model(config.output_dir)
    tokenizer.save_pretrained(config.output_dir)
    print("==> done", flush=True)