NorthernTribe-Research committed on
Commit 4461ccc · verified · 1 parent: 2754a79

Add SOTA advancement pipeline: multi-stage weighted curriculum + pass@k eval harness.
README.md CHANGED
@@ -24,8 +24,11 @@ model from the merged dataset in `data/releases/v1/`.
 
  - `configs/deepseek_math.yaml`: preset for `DeepSeek-Math`
  - `configs/deepseek_math_v2.yaml`: preset for `DeepSeek-Math-V2`
+ - `configs/deepseek_math_sota.yaml`: multi-stage SOTA advancement recipe
  - `scripts/train_sft.py`: LoRA/QLoRA supervised fine-tuning + optional Hub push
+ - `scripts/train_sota.py`: weighted multi-stage curriculum fine-tuning
  - `scripts/merge_and_push.py`: optional adapter merge into full weights + Hub push
+ - `scripts/eval_sota.py`: self-consistency `pass@1` / `pass@k` evaluation harness
  - `requirements.txt`: model-training dependencies
 
  ## Setup
@@ -48,10 +51,38 @@ model from the merged dataset in `data/releases/v1/`.
    --config model_development/configs/deepseek_math_v2.yaml
  ```
 
+ ## SOTA Advancement Recipe (Multi-stage)
+
+ ```bash
+ .venv/bin/python model_development/scripts/train_sota.py \
+   --config model_development/configs/deepseek_math_sota.yaml
+ ```
+
+ This recipe runs:
+ - Stage 1: broad math bootstrap
+ - Stage 2: conjecture + formal proof specialization
+ - Stage 3: conjecture-core alignment
+
+ and saves a final adapter under:
+ - `model_development/runs/math-conjecture-sota/final_adapter`
+
+ ## Evaluate pass@k with self-consistency
+
+ ```bash
+ .venv/bin/python model_development/scripts/eval_sota.py \
+   --config model_development/configs/deepseek_math_sota.yaml \
+   --adapter-path model_development/runs/math-conjecture-sota/final_adapter \
+   --eval-file data/releases/v1/test.parquet \
+   --k 4 \
+   --max-samples 300
+ ```
+
  ## Important notes
 
  - Both presets point to `data/releases/v1/train.parquet` and
    `data/releases/v1/validation.parquet`.
+ - `deepseek_math_sota.yaml` defaults to `DeepSeek-Math-V2` and pushes to
+   `NorthernTribe-Research/math-conjecture-model`.
  - If your exact v2 checkpoint id differs, update `model.base_model` in
    `model_development/configs/deepseek_math_v2.yaml`.
  - Hub auth uses `HF_TOKEN` first, then `huggingface-api-key.json`.
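The auth precedence noted above (`HF_TOKEN` first, then `huggingface-api-key.json` with a `key` field, as in `resolve_auth` in `scripts/train_sota.py`) can be sketched as a minimal standalone helper; the function name here is illustrative, not part of the repo:

```python
import json
import os
from pathlib import Path
from typing import Optional


def resolve_hub_token(cred_path: str = "huggingface-api-key.json") -> Optional[str]:
    """Mirror the repo's auth order: HF_TOKEN env var wins, then the JSON key file."""
    token = (os.environ.get("HF_TOKEN") or "").strip()
    if token:
        return token
    path = Path(cred_path)
    if path.exists():
        data = json.loads(path.read_text(encoding="utf-8"))
        key = (data.get("key") or "").strip()
        if key:
            return key
    return None
```

If neither source yields a token, Hub pushes should be skipped rather than attempted anonymously.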
configs/deepseek_math_sota.yaml ADDED
@@ -0,0 +1,113 @@
+ global:
+   output_root: model_development/runs/math-conjecture-sota
+   seed: 17
+
+ model:
+   base_model: deepseek-ai/deepseek-math-v2
+   trust_remote_code: true
+   use_bf16: true
+   load_in_4bit: true
+   bnb_4bit_quant_type: nf4
+   bnb_4bit_use_double_quant: true
+   attn_implementation: null
+   lora:
+     r: 96
+     alpha: 192
+     dropout: 0.05
+     bias: none
+     target_modules:
+       - q_proj
+       - k_proj
+       - v_proj
+       - o_proj
+       - gate_proj
+       - up_proj
+       - down_proj
+
+ data:
+   default_train_file: data/releases/v1/train.parquet
+   default_validation_file: data/releases/v1/validation.parquet
+   prompt_field: prompt
+   target_field: target
+   final_answer_field: final_answer
+   proof_field: proof_formal
+   sample_weight_field: sample_weight
+   max_seq_length: 3072
+   min_loss_weight: 0.25
+   max_loss_weight: 6.0
+   family_boost:
+     conjecture_core: 2.5
+     formal_proof: 1.6
+     competition: 1.2
+     structured_reasoning: 1.0
+   system_prompt: |
+     You are a frontier mathematical reasoning model focused on unsolved
+     conjectures. Your outputs must be precise, technically coherent, and explicit
+     about uncertainty. Never claim a full proof unless it is derivable from given
+     assumptions or already established in cited prior results.
+
+ training_defaults:
+   per_device_train_batch_size: 1
+   per_device_eval_batch_size: 1
+   gradient_accumulation_steps: 16
+   weight_decay: 0.01
+   warmup_ratio: 0.03
+   lr_scheduler_type: cosine
+   max_grad_norm: 1.0
+   gradient_checkpointing: true
+   logging_steps: 10
+   save_steps: 400
+   eval_steps: 400
+   save_total_limit: 3
+   dataloader_num_workers: 2
+
+ stages:
+   - name: broad_math_bootstrap
+     max_train_samples: null
+     max_eval_samples: 3000
+     filters:
+       include_families:
+         - competition
+         - structured_reasoning
+         - formal_proof
+         - conjecture_core
+     training:
+       num_train_epochs: 1
+       learning_rate: 2.0e-5
+
+   - name: conjecture_specialization
+     max_train_samples: null
+     max_eval_samples: 2000
+     filters:
+       include_families:
+         - conjecture_core
+         - formal_proof
+       min_sample_weight: 2.0
+     training:
+       num_train_epochs: 2
+       learning_rate: 8.0e-6
+       save_steps: 250
+       eval_steps: 250
+
+   - name: conjecture_alignment
+     max_train_samples: null
+     max_eval_samples: null
+     filters:
+       include_families:
+         - conjecture_core
+       require_conjecture_id: true
+     training:
+       num_train_epochs: 3
+       learning_rate: 5.0e-6
+       save_steps: 100
+       eval_steps: 100
+
+ hub:
+   push_to_hub: true
+   repo_id: NorthernTribe-Research/math-conjecture-model
+   private: false
+   upload_stage_checkpoints: true
+   commit_message: Train multi-stage SOTA curriculum for conjecture reasoning.
+
+ credentials:
+   path: huggingface-api-key.json
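The per-example weighting that this config drives (`sample_weight` scaled by `family_boost`, then clamped to `[min_loss_weight, max_loss_weight]`, as implemented in `compute_loss_weight` in `scripts/train_sota.py`) reduces to a few lines; this is a minimal sketch with the config's defaults:

```python
def loss_weight(sample_weight: float, family: str,
                family_boost: dict,
                min_w: float = 0.25, max_w: float = 6.0) -> float:
    """Base sample weight times the family's boost factor, clamped to [min_w, max_w]."""
    w = sample_weight * family_boost.get(family, 1.0)
    return max(min_w, min(max_w, w))
```

For example, a `conjecture_core` row with `sample_weight: 4.0` would get 4.0 × 2.5 = 10.0, clamped down to the `max_loss_weight` of 6.0.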
requirements.txt CHANGED
@@ -6,3 +6,4 @@ peft>=0.14.0
  bitsandbytes>=0.45.0
  huggingface_hub>=0.26.0
  pyyaml>=6.0.2
+ sentencepiece>=0.2.0
scripts/eval_sota.py ADDED
@@ -0,0 +1,299 @@
+ #!/usr/bin/env python3
+ """Self-consistency evaluation for math-conjecture model checkpoints."""
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import re
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Sequence
+
+ import torch
+ import yaml
+ from datasets import load_dataset
+ from peft import PeftModel
+ from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(description="Run pass@k-style evaluation on held-out split.")
+     parser.add_argument(
+         "--config",
+         type=Path,
+         default=Path("model_development/configs/deepseek_math_sota.yaml"),
+         help="Training config used for prompt formatting defaults.",
+     )
+     parser.add_argument(
+         "--base-model",
+         type=str,
+         default=None,
+         help="Override base model id from config.",
+     )
+     parser.add_argument(
+         "--adapter-path",
+         type=Path,
+         default=None,
+         help="Optional LoRA adapter path to load on top of base model.",
+     )
+     parser.add_argument(
+         "--eval-file",
+         type=Path,
+         default=Path("data/releases/v1/test.parquet"),
+         help="Parquet split used for evaluation.",
+     )
+     parser.add_argument("--max-samples", type=int, default=300, help="Maximum evaluation rows.")
+     parser.add_argument("--k", type=int, default=4, help="Number of sampled generations per prompt.")
+     parser.add_argument("--max-new-tokens", type=int, default=256, help="Generation length cap.")
+     parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature.")
+     parser.add_argument("--top-p", type=float, default=0.95, help="Nucleus sampling p.")
+     parser.add_argument("--seed", type=int, default=17, help="Random seed.")
+     parser.add_argument(
+         "--output-json",
+         type=Path,
+         default=Path("model_development/runs/latest_eval_report.json"),
+         help="Where to write evaluation report.",
+     )
+     return parser.parse_args()
+
+
+ def as_text(value: Any) -> str:
+     if value is None:
+         return ""
+     if isinstance(value, str):
+         return value.strip()
+     return str(value).strip()
+
+
+ def load_config(path: Path) -> Dict[str, Any]:
+     cfg = yaml.safe_load(path.read_text(encoding="utf-8"))
+     if not isinstance(cfg, dict):
+         raise ValueError("Invalid YAML config.")
+     return cfg
+
+
+ def normalize_answer(text: str) -> str:
+     text = text.strip().lower()
+     text = re.sub(r"\s+", " ", text)
+     text = text.replace("$", "")
+     return text
+
+
+ def flatten_expected(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> List[str]:
+     out: List[str] = []
+     final_field = as_text(data_cfg.get("final_answer_field")) or "final_answer"
+     target_field = as_text(data_cfg.get("target_field")) or "target"
+
+     final_answer = row.get(final_field)
+     if final_answer is not None:
+         txt = as_text(final_answer)
+         if txt:
+             out.append(txt)
+
+     target = row.get(target_field)
+     if target is None:
+         return out
+     if isinstance(target, str):
+         stripped = target.strip()
+         if not stripped:
+             return out
+         try:
+             target = json.loads(stripped)
+         except json.JSONDecodeError:
+             out.append(stripped)
+             return out
+
+     if isinstance(target, dict):
+         for value in target.values():
+             if isinstance(value, list):
+                 for item in value:
+                     txt = as_text(item)
+                     if txt:
+                         out.append(txt)
+             else:
+                 txt = as_text(value)
+                 if txt:
+                     out.append(txt)
+     elif isinstance(target, list):
+         for item in target:
+             txt = as_text(item)
+             if txt:
+                 out.append(txt)
+     else:
+         txt = as_text(target)
+         if txt:
+             out.append(txt)
+     return out
+
+
+ def build_user_block(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> str:
+     prompt_field = as_text(data_cfg.get("prompt_field")) or "prompt"
+     prompt = as_text(row.get(prompt_field))
+     if not prompt:
+         prompt = "Solve the math task."
+     meta_fields = [
+         ("task_type", "Task type"),
+         ("family", "Family"),
+         ("difficulty", "Difficulty"),
+         ("source_dataset", "Source"),
+         ("status_as_of", "Status as of"),
+     ]
+     lines = []
+     for key, label in meta_fields:
+         value = as_text(row.get(key))
+         if value:
+             lines.append(f"{label}: {value}")
+     if lines:
+         return f"{prompt}\n\nMetadata:\n" + "\n".join(lines)
+     return prompt
+
+
+ def build_prompt_text(row: Dict[str, Any], tokenizer: AutoTokenizer, data_cfg: Dict[str, Any]) -> str:
+     system_prompt = as_text(data_cfg.get("system_prompt"))
+     if not system_prompt:
+         system_prompt = "You are a rigorous mathematical reasoning assistant."
+     user_block = build_user_block(row, data_cfg)
+     if getattr(tokenizer, "chat_template", None):
+         messages = [
+             {"role": "system", "content": system_prompt},
+             {"role": "user", "content": user_block},
+         ]
+         return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     return f"System:\n{system_prompt}\n\nUser:\n{user_block}\n\nAssistant:\n"
+
+
+ def extract_candidate_text(full_generation: str, prompt_text: str) -> str:
+     if full_generation.startswith(prompt_text):
+         return full_generation[len(prompt_text) :].strip()
+     return full_generation.strip()
+
+
+ def is_match(candidate: str, expected_values: Sequence[str]) -> bool:
+     cand_norm = normalize_answer(candidate)
+     if not cand_norm:
+         return False
+     for expected in expected_values:
+         exp_norm = normalize_answer(expected)
+         if not exp_norm:
+             continue
+         if exp_norm in cand_norm or cand_norm in exp_norm:
+             return True
+         boxed = re.findall(r"\\boxed\{([^{}]+)\}", cand_norm)
+         if boxed and any(exp_norm in item for item in boxed):
+             return True
+     return False
+
+
+ def load_model_and_tokenizer(
+     base_model: str,
+     adapter_path: Optional[Path],
+     trust_remote_code: bool,
+ ) -> tuple[Any, AutoTokenizer]:
+     tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=trust_remote_code, use_fast=True)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
+     if tokenizer.pad_token is None:
+         tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
+
+     model = AutoModelForCausalLM.from_pretrained(
+         base_model,
+         torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+         device_map="auto" if torch.cuda.is_available() else None,
+         trust_remote_code=trust_remote_code,
+     )
+     if adapter_path is not None:
+         model = PeftModel.from_pretrained(model, str(adapter_path))
+     model.eval()
+     return model, tokenizer
+
+
+ def main() -> None:
+     args = parse_args()
+     cfg = load_config(args.config)
+     data_cfg = cfg.get("data", {})
+     model_cfg = cfg.get("model", {})
+     set_seed(args.seed)
+
+     base_model = args.base_model or as_text(model_cfg.get("base_model"))
+     if not base_model:
+         raise ValueError("Base model is required via --base-model or config.model.base_model.")
+
+     model, tokenizer = load_model_and_tokenizer(
+         base_model=base_model,
+         adapter_path=args.adapter_path,
+         trust_remote_code=bool(model_cfg.get("trust_remote_code", False)),
+     )
+
+     if not args.eval_file.exists():
+         raise FileNotFoundError(f"Evaluation file not found: {args.eval_file}")
+     ds = load_dataset("parquet", data_files={"eval": str(args.eval_file)})["eval"]
+
+     if args.max_samples > 0 and args.max_samples < len(ds):
+         ds = ds.select(range(args.max_samples))
+
+     total = 0
+     hit_at_1 = 0
+     hit_at_k = 0
+     records = []
+
+     for row in ds:
+         expected_values = flatten_expected(row, data_cfg)
+         if not expected_values:
+             continue
+         prompt_text = build_prompt_text(row, tokenizer, data_cfg)
+         inputs = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=4096)
+         inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+         with torch.no_grad():
+             output_ids = model.generate(
+                 **inputs,
+                 do_sample=True,
+                 temperature=args.temperature,
+                 top_p=args.top_p,
+                 num_return_sequences=args.k,
+                 max_new_tokens=args.max_new_tokens,
+                 pad_token_id=tokenizer.pad_token_id,
+                 eos_token_id=tokenizer.eos_token_id,
+             )
+         generations = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+         candidates = [extract_candidate_text(text, prompt_text) for text in generations]
+         matches = [is_match(candidate, expected_values) for candidate in candidates]
+         total += 1
+         if matches and matches[0]:
+             hit_at_1 += 1
+         if any(matches):
+             hit_at_k += 1
+
+         records.append(
+             {
+                 "uid": as_text(row.get("uid")),
+                 "prompt": as_text(row.get(as_text(data_cfg.get("prompt_field")) or "prompt")),
+                 "expected_values": expected_values[:5],
+                 "candidates": candidates,
+                 "matches": matches,
+             }
+         )
+
+     pass_at_1 = (hit_at_1 / total) if total else 0.0
+     pass_at_k = (hit_at_k / total) if total else 0.0
+     report = {
+         "base_model": base_model,
+         "adapter_path": str(args.adapter_path) if args.adapter_path is not None else None,
+         "eval_file": str(args.eval_file),
+         "evaluated_rows": total,
+         "k": args.k,
+         "pass_at_1": pass_at_1,
+         "pass_at_k": pass_at_k,
+         "temperature": args.temperature,
+         "top_p": args.top_p,
+         "max_new_tokens": args.max_new_tokens,
+         "samples": records[:30],
+     }
+     args.output_json.parent.mkdir(parents=True, exist_ok=True)
+     args.output_json.write_text(json.dumps(report, ensure_ascii=True, indent=2), encoding="utf-8")
+     print(json.dumps({k: report[k] for k in ("evaluated_rows", "pass_at_1", "pass_at_k", "k")}, indent=2))
+     print(f"Saved report to {args.output_json}")
+
+
+ if __name__ == "__main__":
+     main()
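The harness above reports an empirical pass@k: the fraction of prompts where any of the k sampled generations matches an expected answer. If you want estimates that are comparable across different sample budgets, the unbiased pass@k estimator from the Codex paper (Chen et al., 2021) is the usual alternative; it is not part of this repo, but a minimal sketch given n samples with c correct is:

```python
def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimator: 1 - C(n-c, k) / C(n, k).

    n: total samples drawn per prompt, c: number judged correct, k: budget.
    Computed as a running product to avoid large binomial coefficients.
    """
    if n - c < k:
        return 1.0  # fewer incorrect samples than k, so some correct one is always included
    result = 1.0
    for i in range(n - c + 1, n + 1):
        result *= 1.0 - k / i
    return 1.0 - result
```

Averaging this over prompts (with n > k samples per prompt) gives lower-variance estimates than the any-of-k hit rate.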
scripts/train_sota.py ADDED
@@ -0,0 +1,688 @@
1
+ #!/usr/bin/env python3
2
+ """Multi-stage curriculum SFT for advancing the conjecture math model."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import json
8
+ import os
9
+ from pathlib import Path
10
+ from typing import Any, Dict, Optional, Tuple
11
+
12
+ import torch
13
+ import yaml
14
+ from datasets import Dataset, DatasetDict, load_dataset
15
+ from huggingface_hub import HfApi
16
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
17
+ from torch.utils.data import WeightedRandomSampler
18
+ from transformers import (
19
+ AutoModelForCausalLM,
20
+ AutoTokenizer,
21
+ BitsAndBytesConfig,
22
+ DataCollatorForSeq2Seq,
23
+ Trainer,
24
+ TrainingArguments,
25
+ set_seed,
26
+ )
27
+
28
+ DEFAULT_CONFIG_PATH = Path("model_development/configs/deepseek_math_sota.yaml")
29
+
30
+
31
+ def parse_args() -> argparse.Namespace:
32
+ parser = argparse.ArgumentParser(
33
+ description="Train DeepSeek-Math with a multi-stage SOTA curriculum recipe."
34
+ )
35
+ parser.add_argument(
36
+ "--config",
37
+ type=Path,
38
+ default=DEFAULT_CONFIG_PATH,
39
+ help="Path to multi-stage YAML config.",
40
+ )
41
+ parser.add_argument("--repo-id", type=str, default=None, help="Override hub.repo_id.")
42
+ parser.add_argument("--push-to-hub", action="store_true", help="Force push enabled.")
43
+ parser.add_argument("--no-push-to-hub", action="store_true", help="Force push disabled.")
44
+ parser.add_argument(
45
+ "--start-stage",
46
+ type=int,
47
+ default=1,
48
+ help="1-based stage index to start from.",
49
+ )
50
+ parser.add_argument(
51
+ "--max-stages",
52
+ type=int,
53
+ default=None,
54
+ help="Optional number of stages to run from --start-stage.",
55
+ )
56
+ parser.add_argument(
57
+ "--credentials-path",
58
+ type=Path,
59
+ default=None,
60
+ help="Override credentials.path.",
61
+ )
62
+ return parser.parse_args()
63
+
64
+
65
+ def as_text(value: Any) -> str:
66
+ if value is None:
67
+ return ""
68
+ if isinstance(value, str):
69
+ return value.strip()
70
+ return str(value).strip()
71
+
72
+
73
+ def as_float(value: Any, default: float) -> float:
74
+ if value is None:
75
+ return default
76
+ try:
77
+ return float(value)
78
+ except (TypeError, ValueError):
79
+ return default
80
+
81
+
82
+ def as_int(value: Any, default: int) -> int:
83
+ if value is None:
84
+ return default
85
+ try:
86
+ return int(value)
87
+ except (TypeError, ValueError):
88
+ return default
89
+
90
+
91
+ def load_config(path: Path) -> Dict[str, Any]:
92
+ if not path.exists():
93
+ raise FileNotFoundError(f"Config not found: {path}")
94
+ cfg = yaml.safe_load(path.read_text(encoding="utf-8"))
95
+ if not isinstance(cfg, dict):
96
+ raise ValueError(f"Invalid config format: {path}")
97
+ for key in ("model", "data", "stages"):
98
+ if key not in cfg:
99
+ raise ValueError(f"Missing config section: {key}")
100
+ if not isinstance(cfg["stages"], list) or not cfg["stages"]:
101
+ raise ValueError("Config must contain at least one stage in stages[].")
102
+ cfg.setdefault("global", {})
103
+ cfg.setdefault("training_defaults", {})
104
+ cfg.setdefault("hub", {})
105
+ cfg.setdefault("credentials", {})
106
+ return cfg
107
+
108
+
109
+ def apply_overrides(cfg: Dict[str, Any], args: argparse.Namespace) -> None:
110
+ if args.repo_id:
111
+ cfg.setdefault("hub", {})["repo_id"] = args.repo_id
112
+ if args.credentials_path is not None:
113
+ cfg.setdefault("credentials", {})["path"] = str(args.credentials_path)
114
+ if args.push_to_hub and args.no_push_to_hub:
115
+ raise ValueError("Cannot set both --push-to-hub and --no-push-to-hub.")
116
+ if args.push_to_hub:
117
+ cfg.setdefault("hub", {})["push_to_hub"] = True
118
+ if args.no_push_to_hub:
119
+ cfg.setdefault("hub", {})["push_to_hub"] = False
120
+
121
+
122
+ def resolve_auth(cfg: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
123
+ token = as_text(os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")) or None
124
+ username = as_text(os.environ.get("HF_USERNAME")) or None
125
+ cred_path = as_text(cfg.get("credentials", {}).get("path"))
126
+ if cred_path:
127
+ path = Path(cred_path)
128
+ if path.exists():
129
+ data = json.loads(path.read_text(encoding="utf-8"))
130
+ if token is None:
131
+ token = as_text(data.get("key")) or None
132
+ if username is None:
133
+ username = as_text(data.get("username")) or None
134
+ return token, username
135
+
136
+
137
+ def resolve_repo_id(cfg: Dict[str, Any], username: Optional[str], output_root: Path) -> Optional[str]:
138
+ repo_id = as_text(cfg.get("hub", {}).get("repo_id"))
139
+ if repo_id:
140
+ return repo_id
141
+ if not username:
142
+ return None
143
+ return f"{username}/{output_root.name}"
144
+
145
+
146
+ def stringify_structured(value: Any) -> str:
147
+ if value is None:
148
+ return ""
149
+ if isinstance(value, str):
150
+ text = value.strip()
151
+ if not text:
152
+ return ""
153
+ try:
154
+ parsed = json.loads(text)
155
+ except json.JSONDecodeError:
156
+ return text
157
+ return json.dumps(parsed, ensure_ascii=False, sort_keys=True)
158
+ return json.dumps(value, ensure_ascii=False, sort_keys=True)
159
+
160
+
161
+ def build_user_block(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> str:
162
+ prompt_field = as_text(data_cfg.get("prompt_field")) or "prompt"
163
+ prompt = as_text(row.get(prompt_field))
164
+ if not prompt:
165
+ prompt = "Solve the math task."
166
+ meta_fields = [
167
+ ("task_type", "Task type"),
168
+ ("family", "Family"),
169
+ ("difficulty", "Difficulty"),
170
+ ("source_dataset", "Source"),
171
+ ("status_as_of", "Status as of"),
172
+ ]
173
+ meta_lines = []
174
+ for key, label in meta_fields:
175
+ value = as_text(row.get(key))
176
+ if value:
177
+ meta_lines.append(f"{label}: {value}")
178
+ tags = row.get("topic_tags")
179
+ if isinstance(tags, list) and tags:
180
+ tag_text = ", ".join(as_text(tag) for tag in tags if as_text(tag))
181
+ if tag_text:
182
+ meta_lines.append(f"Tags: {tag_text}")
183
+ if not meta_lines:
184
+ return prompt
185
+ return f"{prompt}\n\nMetadata:\n" + "\n".join(meta_lines)
186
+
187
+
188
+ def build_answer_block(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> str:
189
+ target_field = as_text(data_cfg.get("target_field")) or "target"
190
+ final_answer_field = as_text(data_cfg.get("final_answer_field")) or "final_answer"
191
+ proof_field = as_text(data_cfg.get("proof_field")) or "proof_formal"
192
+
193
+ sections = []
194
+ target_text = stringify_structured(row.get(target_field))
195
+ if target_text:
196
+ sections.append(f"Structured target:\n{target_text}")
197
+
198
+ final_answer = stringify_structured(row.get(final_answer_field))
199
+ if final_answer:
200
+ sections.append(f"Final answer:\n{final_answer}")
201
+
202
+ proof_text = stringify_structured(row.get(proof_field))
203
+ if proof_text:
204
+ sections.append(f"Formal proof snippet:\n{proof_text}")
205
+
206
+ if not sections:
207
+ sections.append("No structured target provided.")
208
+ return "\n\n".join(sections).strip()
209
+
210
+
211
+ def build_prompt_text(row: Dict[str, Any], tokenizer: AutoTokenizer, data_cfg: Dict[str, Any]) -> str:
212
+ system_prompt = as_text(data_cfg.get("system_prompt"))
213
+ if not system_prompt:
214
+ system_prompt = (
215
+ "You are a rigorous mathematical reasoning assistant focused on unsolved "
216
+ "conjectures. Produce checkable reasoning."
217
+ )
218
+ user_block = build_user_block(row, data_cfg)
219
+ if getattr(tokenizer, "chat_template", None):
220
+ messages = [
221
+ {"role": "system", "content": system_prompt},
222
+ {"role": "user", "content": user_block},
223
+ ]
224
+ return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
225
+ return f"System:\n{system_prompt}\n\nUser:\n{user_block}\n\nAssistant:\n"
226
+
227
+
228
+ def compute_loss_weight(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> float:
229
+ sample_weight_field = as_text(data_cfg.get("sample_weight_field")) or "sample_weight"
230
+ base = as_float(row.get(sample_weight_field), 1.0)
231
+ family = as_text(row.get("family"))
232
+ family_boost = data_cfg.get("family_boost", {})
233
+ if isinstance(family_boost, dict):
234
+ base *= as_float(family_boost.get(family), 1.0)
235
+ min_w = as_float(data_cfg.get("min_loss_weight"), 0.1)
236
+ max_w = as_float(data_cfg.get("max_loss_weight"), 8.0)
237
+ if min_w > max_w:
238
+ min_w, max_w = max_w, min_w
239
+ return max(min_w, min(max_w, base))
240
+
241
+
242
+ def stage_split_files(stage_cfg: Dict[str, Any], data_cfg: Dict[str, Any]) -> Dict[str, str]:
243
+ train_file = as_text(stage_cfg.get("train_file")) or as_text(data_cfg.get("default_train_file"))
244
+ valid_file = as_text(stage_cfg.get("validation_file")) or as_text(data_cfg.get("default_validation_file"))
245
+ train_path = Path(train_file)
246
+ valid_path = Path(valid_file)
247
+ if not train_path.exists():
248
+ raise FileNotFoundError(f"Missing train split for stage: {train_path}")
249
+ if not valid_path.exists():
250
+ raise FileNotFoundError(f"Missing validation split for stage: {valid_path}")
251
+ return {"train": str(train_path), "validation": str(valid_path)}
252
+
253
+
254
+ def apply_filters(dataset: Dataset, filter_cfg: Dict[str, Any]) -> Dataset:
255
+ if not filter_cfg:
256
+ return dataset
257
+ include_families = set(filter_cfg.get("include_families", []) or [])
258
+ exclude_families = set(filter_cfg.get("exclude_families", []) or [])
259
+ include_task_types = set(filter_cfg.get("include_task_types", []) or [])
260
+ source_datasets = set(filter_cfg.get("source_datasets", []) or [])
261
+ require_conjecture_id = bool(filter_cfg.get("require_conjecture_id", False))
262
+ min_sample_weight = filter_cfg.get("min_sample_weight")
263
+ min_sample_weight = as_float(min_sample_weight, 0.0) if min_sample_weight is not None else None
264
+
265
+ def _keep(row: Dict[str, Any]) -> bool:
266
+ family = as_text(row.get("family"))
267
+ if include_families and family not in include_families:
268
+ return False
269
+ if exclude_families and family in exclude_families:
270
+ return False
271
+ if include_task_types:
272
+ task_type = as_text(row.get("task_type"))
273
+ if task_type not in include_task_types:
274
+ return False
275
+ if source_datasets:
276
+ source = as_text(row.get("source_dataset"))
277
+ if source not in source_datasets:
278
+ return False
279
+ if require_conjecture_id:
280
+ conjecture_id = as_text(row.get("conjecture_id"))
281
+ if not conjecture_id or conjecture_id.lower() == "null":
282
+ return False
283
+ if min_sample_weight is not None:
284
+ sample_weight = as_float(row.get("sample_weight"), 0.0)
285
+ if sample_weight < min_sample_weight:
286
+ return False
287
+ return True
288
+
289
+ return dataset.filter(_keep, desc="Applying stage filters")
290
+
291
+
292
+ def maybe_select(dataset: Dataset, max_samples: Optional[int]) -> Dataset:
293
+ if max_samples is None:
294
+ return dataset
295
+ if max_samples <= 0:
296
+ raise ValueError("max_samples must be positive.")
297
+ if max_samples >= len(dataset):
298
+ return dataset
299
+ return dataset.select(range(max_samples))
300
+
301
+
+ def tokenize_datasets(raw: DatasetDict, tokenizer: AutoTokenizer, data_cfg: Dict[str, Any]) -> DatasetDict:
+     max_len = as_int(data_cfg.get("max_seq_length"), 2048)
+     if max_len < 64:
+         raise ValueError("data.max_seq_length must be >= 64")
+     eos = tokenizer.eos_token or ""
+     remove_columns = raw["train"].column_names
+
+     def _tokenize(row: Dict[str, Any]) -> Dict[str, Any]:
+         prompt_text = build_prompt_text(row, tokenizer, data_cfg)
+         answer_text = build_answer_block(row, data_cfg)
+         full_text = f"{prompt_text}{answer_text}{eos}"
+         prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
+         full_enc = tokenizer(
+             full_text,
+             add_special_tokens=False,
+             truncation=True,
+             max_length=max_len,
+         )
+         input_ids = full_enc["input_ids"]
+         attention_mask = full_enc["attention_mask"]
+         if not input_ids:
+             fallback = tokenizer.eos_token_id
+             if fallback is None:
+                 fallback = tokenizer.pad_token_id
+             if fallback is None:
+                 fallback = 0
+             input_ids = [fallback]
+             attention_mask = [1]
+             labels = [fallback]
+         else:
+             prompt_len = min(len(prompt_ids), len(input_ids))
+             labels = [-100] * prompt_len + input_ids[prompt_len:]
+             if prompt_len >= len(input_ids):
+                 labels[-1] = input_ids[-1]
+         loss_weight = compute_loss_weight(row, data_cfg)
+         return {
+             "input_ids": input_ids,
+             "attention_mask": attention_mask,
+             "labels": labels,
+             "loss_weight": float(loss_weight),
+         }
+
+     tokenized = raw.map(
+         _tokenize,
+         remove_columns=remove_columns,
+         desc="Tokenizing prompt/answer pairs",
+     )
+     tokenized = tokenized.filter(
+         lambda row: any(token != -100 for token in row["labels"]),
+         desc="Dropping prompt-only rows",
+     )
+     return tokenized
+
+
+ def build_model_and_tokenizer(model_cfg: Dict[str, Any], training_defaults: Dict[str, Any]) -> Tuple[Any, AutoTokenizer]:
+     base_model = as_text(model_cfg.get("base_model"))
+     if not base_model:
+         raise ValueError("model.base_model is required.")
+
+     use_bf16 = bool(model_cfg.get("use_bf16", True))
+     dtype = torch.bfloat16 if use_bf16 else torch.float16
+
+     tokenizer = AutoTokenizer.from_pretrained(
+         base_model,
+         trust_remote_code=bool(model_cfg.get("trust_remote_code", False)),
+         use_fast=True,
+     )
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
+         if tokenizer.pad_token is None:
+             tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
+
+     model_kwargs: Dict[str, Any] = {
+         "trust_remote_code": bool(model_cfg.get("trust_remote_code", False)),
+         "torch_dtype": dtype,
+     }
+     attn_impl = as_text(model_cfg.get("attn_implementation"))
+     if attn_impl:
+         model_kwargs["attn_implementation"] = attn_impl
+
+     load_in_4bit = bool(model_cfg.get("load_in_4bit", True))
+     if load_in_4bit:
+         if not torch.cuda.is_available():
+             raise RuntimeError("4-bit loading requested but CUDA is not available.")
+         model_kwargs["quantization_config"] = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type=as_text(model_cfg.get("bnb_4bit_quant_type")) or "nf4",
+             bnb_4bit_use_double_quant=bool(model_cfg.get("bnb_4bit_use_double_quant", True)),
+             bnb_4bit_compute_dtype=dtype,
+         )
+         model_kwargs["device_map"] = "auto"
+
+     model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
+     if tokenizer.pad_token_id is not None:
+         model.config.pad_token_id = tokenizer.pad_token_id
+     model.config.use_cache = False
+
+     if load_in_4bit:
+         model = prepare_model_for_kbit_training(
+             model,
+             use_gradient_checkpointing=bool(training_defaults.get("gradient_checkpointing", True)),
+         )
+
+     lora_cfg = model_cfg.get("lora", {})
+     peft_cfg = LoraConfig(
+         r=as_int(lora_cfg.get("r"), 64),
+         lora_alpha=as_int(lora_cfg.get("alpha"), 128),
+         lora_dropout=as_float(lora_cfg.get("dropout"), 0.05),
+         bias=as_text(lora_cfg.get("bias")) or "none",
+         task_type="CAUSAL_LM",
+         target_modules=lora_cfg.get("target_modules"),
+     )
+     model = get_peft_model(model, peft_cfg)
+     model.print_trainable_parameters()
+     return model, tokenizer
+
+
+ class WeightedLossCollator:
+     def __init__(self, tokenizer: AutoTokenizer, model: Any) -> None:
+         self.base = DataCollatorForSeq2Seq(
+             tokenizer=tokenizer,
+             model=model,
+             label_pad_token_id=-100,
+             pad_to_multiple_of=8,
+         )
+
+     def __call__(self, features: list[Dict[str, Any]]) -> Dict[str, Any]:
+         weights = [float(feature.pop("loss_weight", 1.0)) for feature in features]
+         batch = self.base(features)
+         batch["loss_weight"] = torch.tensor(weights, dtype=torch.float32)
+         return batch
+
+
+ class WeightedLossTrainer(Trainer):
+     def _get_train_sampler(self):
+         if self.train_dataset is None:
+             return None
+         if "loss_weight" not in self.train_dataset.column_names:
+             return super()._get_train_sampler()
+         weights = self.train_dataset["loss_weight"]
+         if not weights:
+             return super()._get_train_sampler()
+         weight_tensor = torch.tensor(weights, dtype=torch.double)
+         return WeightedRandomSampler(
+             weights=weight_tensor,
+             num_samples=len(weight_tensor),
+             replacement=True,
+         )
+
+     def compute_loss(
+         self,
+         model: Any,
+         inputs: Dict[str, Any],
+         return_outputs: bool = False,
+         num_items_in_batch: Optional[torch.Tensor] = None,
+     ):
+         loss_weight = inputs.pop("loss_weight", None)
+         labels = inputs.get("labels")
+         if labels is None:
+             return super().compute_loss(
+                 model=model,
+                 inputs=inputs,
+                 return_outputs=return_outputs,
+                 num_items_in_batch=num_items_in_batch,
+             )
+
+         model_inputs = {k: v for k, v in inputs.items() if k != "labels"}
+         outputs = model(**model_inputs)
+         logits = outputs.logits
+
+         shift_logits = logits[..., :-1, :].contiguous()
+         shift_labels = labels[..., 1:].contiguous()
+         token_losses = torch.nn.functional.cross_entropy(
+             shift_logits.view(-1, shift_logits.size(-1)),
+             shift_labels.view(-1),
+             ignore_index=-100,
+             reduction="none",
+         ).view(shift_labels.size())
+         token_mask = shift_labels.ne(-100).float()
+         seq_den = token_mask.sum(dim=1).clamp(min=1.0)
+         seq_loss = (token_losses * token_mask).sum(dim=1) / seq_den
+
+         if loss_weight is not None:
+             normalized = loss_weight.to(seq_loss.device).float().clamp(min=0.05)
+             loss = (seq_loss * normalized).sum() / normalized.sum()
+         else:
+             loss = seq_loss.mean()
+
+         if return_outputs:
+             return loss, outputs
+         return loss
+
+
+ def build_training_args(
+     output_dir: Path,
+     training_cfg: Dict[str, Any],
+     use_bf16: bool,
+     has_eval_split: bool,
+ ) -> TrainingArguments:
+     output_dir.mkdir(parents=True, exist_ok=True)
+     return TrainingArguments(
+         output_dir=str(output_dir),
+         num_train_epochs=as_float(training_cfg.get("num_train_epochs"), 1.0),
+         per_device_train_batch_size=as_int(training_cfg.get("per_device_train_batch_size"), 1),
+         per_device_eval_batch_size=as_int(training_cfg.get("per_device_eval_batch_size"), 1),
+         gradient_accumulation_steps=as_int(training_cfg.get("gradient_accumulation_steps"), 1),
+         learning_rate=as_float(training_cfg.get("learning_rate"), 2e-5),
+         weight_decay=as_float(training_cfg.get("weight_decay"), 0.0),
+         warmup_ratio=as_float(training_cfg.get("warmup_ratio"), 0.0),
+         lr_scheduler_type=as_text(training_cfg.get("lr_scheduler_type")) or "cosine",
+         max_grad_norm=as_float(training_cfg.get("max_grad_norm"), 1.0),
+         gradient_checkpointing=bool(training_cfg.get("gradient_checkpointing", True)),
+         logging_steps=as_int(training_cfg.get("logging_steps"), 10),
+         save_steps=as_int(training_cfg.get("save_steps"), 500),
+         save_total_limit=as_int(training_cfg.get("save_total_limit"), 3),
+         dataloader_num_workers=as_int(training_cfg.get("dataloader_num_workers"), 0),
+         seed=as_int(training_cfg.get("seed"), 17),
+         bf16=use_bf16,
+         fp16=not use_bf16,
+         remove_unused_columns=False,
+         report_to="none",
+         evaluation_strategy="steps" if has_eval_split else "no",
+         eval_steps=as_int(training_cfg.get("eval_steps"), 500) if has_eval_split else None,
+     )
+
+
+ def push_folder(
+     api: HfApi,
+     repo_id: str,
+     folder_path: Path,
+     commit_message: str,
+     path_in_repo: Optional[str] = None,
+ ) -> None:
+     kwargs: Dict[str, Any] = {
+         "repo_id": repo_id,
+         "repo_type": "model",
+         "folder_path": str(folder_path),
+         "commit_message": commit_message,
+     }
+     if path_in_repo:
+         kwargs["path_in_repo"] = path_in_repo
+     api.upload_folder(**kwargs)
+
+
+ def main() -> None:
+     args = parse_args()
+     cfg = load_config(args.config)
+     apply_overrides(cfg, args)
+
+     seed = as_int(cfg.get("global", {}).get("seed"), 17)
+     set_seed(seed)
+
+     output_root = Path(as_text(cfg.get("global", {}).get("output_root")) or "model_development/runs/math-conjecture-sota")
+     output_root.mkdir(parents=True, exist_ok=True)
+
+     token, username = resolve_auth(cfg)
+     repo_id = resolve_repo_id(cfg, username=username, output_root=output_root)
+     push_to_hub = bool(cfg.get("hub", {}).get("push_to_hub", False))
+     if push_to_hub:
+         if token is None:
+             raise ValueError("Hub push requested but token is missing.")
+         if repo_id is None:
+             raise ValueError("Hub push requested but repo_id is missing.")
+
+     model, tokenizer = build_model_and_tokenizer(cfg["model"], cfg.get("training_defaults", {}))
+     data_cfg = cfg["data"]
+     stage_reports = []
+
+     start_stage = max(1, args.start_stage)
+     stages = cfg["stages"]
+     end_stage = len(stages)
+     if args.max_stages is not None:
+         if args.max_stages <= 0:
+             raise ValueError("--max-stages must be positive.")
+         end_stage = min(end_stage, start_stage + args.max_stages - 1)
+
+     for index in range(start_stage, end_stage + 1):
+         stage = stages[index - 1]
+         stage_name = as_text(stage.get("name")) or f"stage_{index:02d}"
+         stage_slug = f"{index:02d}_{stage_name.replace(' ', '_')}"
+         stage_output_dir = output_root / stage_slug
+
+         split_files = stage_split_files(stage, data_cfg)
+         raw = load_dataset("parquet", data_files=split_files)
+         filters = stage.get("filters", {})
+         raw["train"] = apply_filters(raw["train"], filters)
+         raw["validation"] = apply_filters(raw["validation"], filters)
+         raw["train"] = maybe_select(raw["train"], stage.get("max_train_samples"))
+         raw["validation"] = maybe_select(raw["validation"], stage.get("max_eval_samples"))
+         if len(raw["train"]) == 0:
+             raise ValueError(f"Stage {stage_slug} has zero train rows after filtering.")
+
+         tokenized = tokenize_datasets(raw, tokenizer, data_cfg)
+         train_dataset = tokenized["train"]
+         eval_dataset = tokenized["validation"] if len(tokenized["validation"]) > 0 else None
+
+         merged_training = dict(cfg.get("training_defaults", {}))
+         merged_training.update(stage.get("training", {}))
+         merged_training["seed"] = seed
+         training_args = build_training_args(
+             output_dir=stage_output_dir,
+             training_cfg=merged_training,
+             use_bf16=bool(cfg["model"].get("use_bf16", True)),
+             has_eval_split=eval_dataset is not None,
+         )
+         collator = WeightedLossCollator(tokenizer=tokenizer, model=model)
+         trainer = WeightedLossTrainer(
+             model=model,
+             args=training_args,
+             train_dataset=train_dataset,
+             eval_dataset=eval_dataset,
+             tokenizer=tokenizer,
+             data_collator=collator,
+         )
+
+         train_result = trainer.train()
+         trainer.log_metrics("train", train_result.metrics)
+         trainer.save_metrics("train", train_result.metrics)
+         trainer.save_state()
+         if eval_dataset is not None:
+             eval_metrics = trainer.evaluate()
+             trainer.log_metrics("eval", eval_metrics)
+             trainer.save_metrics("eval", eval_metrics)
+         trainer.save_model(str(stage_output_dir))
+         tokenizer.save_pretrained(str(stage_output_dir))
+
+         report = {
+             "stage_index": index,
+             "stage_name": stage_name,
+             "output_dir": str(stage_output_dir),
+             "train_rows": len(train_dataset),
+             "eval_rows": len(eval_dataset) if eval_dataset is not None else 0,
+             "train_metrics": train_result.metrics,
+         }
+         stage_reports.append(report)
+
+     final_dir = output_root / "final_adapter"
+     final_dir.mkdir(parents=True, exist_ok=True)
+     model.save_pretrained(str(final_dir))
+     tokenizer.save_pretrained(str(final_dir))
+
+     summary = {
+         "config_path": str(args.config),
+         "repo_id": repo_id,
+         "seed": seed,
+         "stages_ran": stage_reports,
+         "final_adapter_dir": str(final_dir),
+     }
+     summary_path = output_root / "training_summary.json"
+     summary_path.write_text(json.dumps(summary, ensure_ascii=True, indent=2), encoding="utf-8")
+
+     if push_to_hub and repo_id is not None and token is not None:
+         api = HfApi(token=token)
+         api.create_repo(
+             repo_id=repo_id,
+             repo_type="model",
+             private=bool(cfg.get("hub", {}).get("private", False)),
+             exist_ok=True,
+         )
+         commit_message = as_text(cfg.get("hub", {}).get("commit_message")) or "Upload SOTA curriculum adapter."
+         push_folder(api, repo_id, final_dir, commit_message=commit_message)
+         if bool(cfg.get("hub", {}).get("upload_stage_checkpoints", False)):
+             for report in stage_reports:
+                 stage_dir = Path(report["output_dir"])
+                 path_in_repo = f"checkpoints/{stage_dir.name}"
+                 push_folder(
+                     api,
+                     repo_id,
+                     stage_dir,
+                     commit_message=f"Upload stage checkpoint {report['stage_name']}",
+                     path_in_repo=path_in_repo,
+                 )
+         api.upload_file(
+             path_or_fileobj=str(summary_path),
+             path_in_repo="training_summary.json",
+             repo_id=repo_id,
+             repo_type="model",
+             commit_message="Upload training summary for SOTA curriculum run.",
+         )
+         print(f"Pushed training artifacts to https://huggingface.co/{repo_id}")
+
+     print(f"Training complete. Final adapter: {final_dir}")
+     print(f"Training summary: {summary_path}")
+
+
+ if __name__ == "__main__":
+     main()