File size: 5,306 Bytes
e90bc49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# /// script
# dependencies = [
#   "trl>=0.20,<0.24",
#   "peft>=0.17,<0.18",
#   "transformers>=4.55,<4.60",
#   "accelerate>=1.7,<2",
#   "datasets>=2.20,<4",
#   "trackio",
#   "kernels>=0.9,<0.10",
# ]
# ///
# Deps are pinned on purpose: gpt-oss is a Mixture-of-Experts model whose
# `kernels` lib must match `transformers`, and "latest of everything" makes them
# clash at import. These caps are the validated pair, and are harmless for dense
# models like Llama. See docs/FINETUNE_MODAL.md for the full story.
"""ProofKit — fine-tune a small model (LoRA SFT) on Hugging Face Jobs.

This script runs ON HUGGING FACE JOBS, not locally. It loads the ProofKit SFT
dataset from the Hub, trains a LoRA adapter on every linear layer (on dense
models that means the attention AND MLP projections; gpt-oss's fused MoE experts
are not adapted), and pushes the adapter back to the Hub. It works for any base
model; the intended HF Jobs target is a small dense model like
meta-llama/Llama-3.2-3B-Instruct — fast and cheap on a T4, and the model that
feeds ProofKit's GGUF / llama.cpp backend (the Llama Champion + Off the Grid
badges). gpt-oss-20b is trained on Modal instead, where its MoE experts can be
adapted on a bigger GPU — see scripts/modal_train_gpt_oss.py and
docs/FINETUNE_MODAL.md.

⚠️  The Jobs container is ephemeral — everything is deleted when the job ends.
    `push_to_hub=True` (+ the HF_TOKEN secret) is what makes the result survive.

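Submit it from your terminal (after uploading this file to a Hub repo):

    hf jobs uv run \\
      --flavor a100-large \\
      --timeout 3h \\
      --secrets HF_TOKEN \\
      "https://huggingface.co/visproj/proofkit-train-scripts/resolve/main/train_gpt_oss.py"

This example passes no `--env`, so it uses the defaults listed below, which
target openai/gpt-oss-20b on an a100-large. For the small-model path described
above, add `--env BASE_MODEL=meta-llama/Llama-3.2-3B-Instruct` and a matching
`--env MODEL_REPO=...` (the default repo name says "gpt-oss"), and pick a T4
flavor.
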
Configuration is via environment variables (pass with `--env KEY=VALUE`):

    BASE_MODEL    base model to tune        (default: openai/gpt-oss-20b)
    DATASET_REPO  Hub dataset to train on   (default: visproj/proofkit-sft)
    MODEL_REPO    Hub repo to push to       (default: visproj/proofkit-gpt-oss-20b-lora)
    EPOCHS        training epochs           (default: 3)
    LR            learning rate             (default: 2e-4)
    MAX_LEN       max sequence length       (default: 1024)

See docs/FINETUNE_HF_JOBS.md for the full runbook.
"""
import os

import torch
from datasets import load_dataset
from peft import LoraConfig, TaskType
from trl import SFTConfig, SFTTrainer

# ── Run configuration: every knob is read from the environment ───────────────
BASE_MODEL = os.getenv("BASE_MODEL", "openai/gpt-oss-20b")
DATASET_REPO = os.getenv("DATASET_REPO", "visproj/proofkit-sft")
MODEL_REPO = os.getenv("MODEL_REPO", "visproj/proofkit-gpt-oss-20b-lora")
# Numeric knobs arrive as strings; parsing them here makes a bad value fail fast.
EPOCHS = float(os.getenv("EPOCHS", "3"))  # float, so fractional epochs are accepted
LR = float(os.getenv("LR", "2e-4"))
MAX_LEN = int(os.getenv("MAX_LEN", "1024"))

# gpt-oss gets special (MXFP4) treatment at load time; match the name case-insensitively.
is_gpt_oss = "gpt-oss" in BASE_MODEL.lower()

for banner_line in (
    f"Base model : {BASE_MODEL}",
    f"Dataset    : {DATASET_REPO}",
    f"Push to    : {MODEL_REPO}",
):
    print(banner_line, flush=True)

# Only the "train" split is loaded; no eval dataset is handed to the trainer.
dataset = load_dataset(DATASET_REPO, split="train")
print(f"Examples   : {len(dataset)}", flush=True)

# Keyword arguments SFTTrainer forwards to `from_pretrained` when it loads
# BASE_MODEL by name (they are passed via `model_init_kwargs` in the SFTConfig).
model_init_kwargs = {
    # NOTE(review): eager attention is presumably what gpt-oss needs (the usual
    # gpt-oss SFT recipe sets it); dense models accept it as well.
    "attn_implementation": "eager",
    "torch_dtype": "auto",  # use the dtype recorded in the checkpoint's config
    # The KV cache only helps generation; during training it is useless and it
    # conflicts with the gradient checkpointing enabled in the SFTConfig.
    "use_cache": False,
}
# Only gpt-oss ships MXFP4-quantized MoE weights that need dequantizing to train.
# Dense models (Llama, Qwen, …) must NOT get a quantization_config — applying one
# to a non-quantized model is meaningless and can error.
if is_gpt_oss:
    try:
        from transformers import Mxfp4Config

        mxfp4_config = Mxfp4Config(dequantize=True)
    except (ImportError, TypeError) as err:
        # Deliberate best-effort fallback, limited to the two expected failures:
        # a transformers build without Mxfp4Config (ImportError) or without its
        # `dequantize` option (TypeError). Any other exception is a real bug and
        # is allowed to propagate instead of being silently swallowed.
        print("MXFP4 dequantize: unavailable (training in native dtype)", flush=True)
        print(f"  reason: {err}", flush=True)
    else:
        model_init_kwargs["quantization_config"] = mxfp4_config
        print("MXFP4 dequantize: on", flush=True)

# LoRA over every linear layer ("all-linear"): the standard, reliable recipe that
# works for any architecture. Note this is NOT attention-only — on dense models
# (Llama, Qwen, …) it adapts the attention projections AND the MLP projections.
# gpt-oss's fused MoE experts are not nn.Linear modules, so "all-linear" leaves
# them untouched (NOTE(review): confirm exactly which gpt-oss modules match).
# Adapting the experts would need `target_parameters`, which fails to match on a
# dense model like Llama and, on gpt-oss, needs a 141 GB GPU. HF Jobs is
# ProofKit's small-model path, so all-linear is the right recipe here.
# (Expert adaptation lives in scripts/modal_train_gpt_oss.py with TUNE_EXPERTS=1.)
#
# Rank and scaling can be overridden with the LORA_R / LORA_ALPHA environment
# variables; the defaults (r=8, alpha=16) are the original recipe.
lora = LoraConfig(
    r=int(os.environ.get("LORA_R", "8")),
    lora_alpha=int(os.environ.get("LORA_ALPHA", "16")),
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    target_modules="all-linear",
)

# Mixed precision. bf16 runs natively only on GPUs with compute capability >= 8
# (Ampere or newer: A10G, L4, A100, H100, …). The T4 this script targets is 7.5:
# depending on the torch version, bf16=True there is either refused when the
# training arguments are validated or silently emulated (slowly). So use fp16
# autocast on such GPUs, loading the weights in fp16 to match, and fall back to
# full precision when there is no GPU at all.
has_cuda = torch.cuda.is_available()
use_bf16 = has_cuda and torch.cuda.get_device_capability()[0] >= 8
use_fp16 = has_cuda and not use_bf16
# A string dtype keeps the config JSON-serializable (it is logged to trackio);
# SFTTrainer resolves "float16" to torch.float16 when it loads the model.
train_model_kwargs = (
    {**model_init_kwargs, "torch_dtype": "float16"} if use_fp16 else model_init_kwargs
)
precision = "bf16" if use_bf16 else ("fp16" if use_fp16 else "fp32")
print(f"Precision  : {precision}", flush=True)

# Name the local output dir and the tracking run after the model actually being
# tuned, so a Llama run is not labelled "gpt-oss-20b". For the default BASE_MODEL
# these resolve to exactly "proofkit-gpt-oss-20b" and "gpt-oss-20b-lora-sft".
model_slug = BASE_MODEL.rstrip("/").rsplit("/", 1)[-1].lower()

args = SFTConfig(
    output_dir=f"proofkit-{model_slug}",  # local only; the container is ephemeral
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,        # effective batch size = 8
    learning_rate=LR,
    max_length=MAX_LEN,
    bf16=use_bf16,
    fp16=use_fp16,
    gradient_checkpointing=True,
    logging_steps=10,
    save_strategy="no",                   # small run — push the final model once at the end
    push_to_hub=True,                     # ← results survive the ephemeral container
    hub_model_id=MODEL_REPO,
    report_to="trackio",                  # live metrics at https://huggingface.co/<you>/trackio
    run_name=f"{model_slug}-lora-sft",
    model_init_kwargs=train_model_kwargs,
)

# ── Train, then publish ──────────────────────────────────────────────────────
trainer = SFTTrainer(
    model=BASE_MODEL,        # a Hub id: SFTTrainer loads it using args.model_init_kwargs
    args=args,
    train_dataset=dataset,
    peft_config=lora,        # wraps the loaded base model with the LoRA adapter
)

print("Training...", flush=True)
trainer.train()

# save_strategy="no" means nothing is uploaded during training, so this single
# explicit push is what publishes the final adapter to MODEL_REPO.
trainer.push_to_hub()
print(f"Done. Adapter pushed to https://huggingface.co/{MODEL_REPO}", flush=True)