visproj commited on
Commit
e90bc49
·
verified ·
1 Parent(s): beace15

Upload train_gpt_oss.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_gpt_oss.py +130 -0
train_gpt_oss.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = [
3
+ # "trl>=0.20,<0.24",
4
+ # "peft>=0.17,<0.18",
5
+ # "transformers>=4.55,<4.60",
6
+ # "accelerate>=1.7,<2",
7
+ # "datasets>=2.20,<4",
8
+ # "trackio",
9
+ # "kernels>=0.9,<0.10",
10
+ # ]
11
+ # ///
12
+ # Deps are pinned on purpose: gpt-oss is a Mixture-of-Experts model whose
13
+ # `kernels` lib must match `transformers`, and "latest of everything" makes them
14
+ # clash at import. These caps are the validated pair, and are harmless for dense
15
+ # models like Llama. See docs/FINETUNE_MODAL.md for the full story.
16
+ """ProofKit — fine-tune a small model (LoRA SFT) on Hugging Face Jobs.
17
+
18
+ This script runs ON HUGGING FACE JOBS, not locally. It loads the ProofKit SFT
19
+ dataset from the Hub, trains an attention-only LoRA adapter, and pushes it back
20
+ to the Hub. It works for any base model; the intended HF Jobs target is a small
21
+ dense model like meta-llama/Llama-3.2-3B-Instruct — fast and cheap on a T4, and
22
+ the model that feeds ProofKit's GGUF / llama.cpp backend (the Llama Champion +
23
+ Off the Grid badges). gpt-oss-20b is trained on Modal instead, where its MoE
24
+ experts can be adapted on a bigger GPU — see scripts/modal_train_gpt_oss.py and
25
+ docs/FINETUNE_MODAL.md.
26
+
27
+ ⚠️ The Jobs container is ephemeral — everything is deleted when the job ends.
28
+ `push_to_hub=True` (+ the HF_TOKEN secret) is what makes the result survive.
29
+
30
+ Submit it from your terminal (after uploading this file to a Hub repo):
31
+
32
+ hf jobs uv run \\
33
+ --flavor a100-large \\
34
+ --timeout 3h \\
35
+ --secrets HF_TOKEN \\
36
+ "https://huggingface.co/visproj/proofkit-train-scripts/resolve/main/train_gpt_oss.py"
37
+
38
+ Configuration is via environment variables (pass with `--env KEY=VALUE`):
39
+
40
+ BASE_MODEL base model to tune (default: openai/gpt-oss-20b)
41
+ DATASET_REPO Hub dataset to train on (default: visproj/proofkit-sft)
42
+ MODEL_REPO Hub repo to push to (default: visproj/proofkit-gpt-oss-20b-lora)
43
+ EPOCHS training epochs (default: 3)
44
+ LR learning rate (default: 2e-4)
45
+ MAX_LEN max sequence length (default: 1024)
46
+
47
+ See docs/FINETUNE_HF_JOBS.md for the full runbook.
48
+ """
49
+ import os
50
+
51
+ from datasets import load_dataset
52
+ from peft import LoraConfig, TaskType
53
+ from trl import SFTConfig, SFTTrainer
54
+
55
+ BASE_MODEL = os.environ.get("BASE_MODEL", "openai/gpt-oss-20b")
56
+ DATASET_REPO = os.environ.get("DATASET_REPO", "visproj/proofkit-sft")
57
+ MODEL_REPO = os.environ.get("MODEL_REPO", "visproj/proofkit-gpt-oss-20b-lora")
58
+ EPOCHS = float(os.environ.get("EPOCHS", "3"))
59
+ LR = float(os.environ.get("LR", "2e-4"))
60
+ MAX_LEN = int(os.environ.get("MAX_LEN", "1024"))
61
+ is_gpt_oss = "gpt-oss" in BASE_MODEL.lower()
62
+
63
+ print(f"Base model : {BASE_MODEL}", flush=True)
64
+ print(f"Dataset : {DATASET_REPO}", flush=True)
65
+ print(f"Push to : {MODEL_REPO}", flush=True)
66
+
67
+ dataset = load_dataset(DATASET_REPO, split="train")
68
+ print(f"Examples : {len(dataset)}", flush=True)
69
+
70
+ model_init_kwargs = {
71
+ "attn_implementation": "eager",
72
+ "torch_dtype": "auto",
73
+ "use_cache": False,
74
+ }
75
+ # Only gpt-oss ships MXFP4-quantized MoE weights that need dequantizing to train.
76
+ # Dense models (Llama, Qwen, …) must NOT get a quantization_config — applying one
77
+ # to a non-quantized model is meaningless and can error.
78
+ if is_gpt_oss:
79
+ try:
80
+ from transformers import Mxfp4Config
81
+
82
+ model_init_kwargs["quantization_config"] = Mxfp4Config(dequantize=True)
83
+ print("MXFP4 dequantize: on", flush=True)
84
+ except Exception:
85
+ print("MXFP4 dequantize: unavailable (training in native dtype)", flush=True)
86
+
87
+ # Attention-only LoRA over all linear layers — the standard, reliable recipe that
88
+ # works for any architecture (attention projections + MLP/router linears). We do
89
+ # NOT adapt gpt-oss's fused MoE experts here: `target_parameters` would fail to
90
+ # match on a dense model like Llama, and on gpt-oss it needs a 141 GB GPU. HF Jobs
91
+ # is ProofKit's small-model path, so attention-only is exactly the right recipe.
92
+ # (Expert adaptation lives in scripts/modal_train_gpt_oss.py with TUNE_EXPERTS=1.)
93
+ lora = LoraConfig(
94
+ r=8,
95
+ lora_alpha=16,
96
+ lora_dropout=0.05,
97
+ bias="none",
98
+ task_type=TaskType.CAUSAL_LM,
99
+ target_modules="all-linear",
100
+ )
101
+
102
+ args = SFTConfig(
103
+ output_dir="proofkit-gpt-oss-20b",
104
+ num_train_epochs=EPOCHS,
105
+ per_device_train_batch_size=1,
106
+ gradient_accumulation_steps=8, # effective batch size = 8
107
+ learning_rate=LR,
108
+ max_length=MAX_LEN,
109
+ bf16=True,
110
+ gradient_checkpointing=True,
111
+ logging_steps=10,
112
+ save_strategy="no", # small run — push the final model once at the end
113
+ push_to_hub=True, # ← results survive the ephemeral container
114
+ hub_model_id=MODEL_REPO,
115
+ report_to="trackio", # live metrics at https://huggingface.co/<you>/trackio
116
+ run_name="gpt-oss-20b-lora-sft",
117
+ model_init_kwargs=model_init_kwargs,
118
+ )
119
+
120
+ trainer = SFTTrainer(
121
+ model=BASE_MODEL,
122
+ train_dataset=dataset,
123
+ peft_config=lora,
124
+ args=args,
125
+ )
126
+
127
+ print("Training...", flush=True)
128
+ trainer.train()
129
+ trainer.push_to_hub()
130
+ print(f"Done. Adapter pushed to https://huggingface.co/{MODEL_REPO}", flush=True)