NorthernTribe-Research committed on
Commit a86edac · verified · 1 parent: 4461ccc

Upgrade SOTA curriculum: 4-stage training, post-eval metrics, and quality-gated promotion.

README.md CHANGED
@@ -17,19 +17,19 @@ datasets:
17
 
18
  # NorthernTribe-Research/math-conjecture-model
19
 
20
- This folder contains the fine-tuning pipeline for building a conjecture-solution
21
- model from the merged dataset in `data/releases/v1/`.
22
 
23
- ## What is included
24
 
25
- - `configs/deepseek_math.yaml`: preset for `DeepSeek-Math`
26
- - `configs/deepseek_math_v2.yaml`: preset for `DeepSeek-Math-V2`
27
- - `configs/deepseek_math_sota.yaml`: multi-stage SOTA advancement recipe
28
- - `scripts/train_sft.py`: LoRA/QLoRA supervised fine-tuning + optional Hub push
29
- - `scripts/train_sota.py`: weighted multi-stage curriculum fine-tuning
30
- - `scripts/merge_and_push.py`: optional adapter merge into full weights + Hub push
31
- - `scripts/eval_sota.py`: self-consistency `pass@1` / `pass@k` evaluation harness
32
- - `requirements.txt`: model-training dependencies
 
33
 
34
  ## Setup
35
 
@@ -37,64 +37,54 @@ model from the merged dataset in `data/releases/v1/`.
37
  .venv/bin/python -m pip install -r model_development/requirements.txt
38
  ```
39
 
40
- ## Fine-tune DeepSeek-Math
41
 
42
  ```bash
43
- .venv/bin/python model_development/scripts/train_sft.py \
44
- --config model_development/configs/deepseek_math.yaml
45
  ```
46
 
47
- ## Fine-tune DeepSeek-Math-V2
48
 
49
  ```bash
50
- .venv/bin/python model_development/scripts/train_sft.py \
51
- --config model_development/configs/deepseek_math_v2.yaml
52
- ```
53
-
54
- ## SOTA Advancement Recipe (Multi-stage)
55
 
56
- ```bash
57
  .venv/bin/python model_development/scripts/train_sota.py \
58
- --config model_development/configs/deepseek_math_sota.yaml
 
59
  ```
60
 
61
- This recipe runs:
62
- - Stage 1: broad math bootstrap
63
- - Stage 2: conjecture + formal proof specialization
64
- - Stage 3: conjecture-core alignment
65
-
66
- and saves a final adapter under:
67
- - `model_development/runs/math-conjecture-sota/final_adapter`
68
-
69
- ## Evaluate pass@k with self-consistency
70
 
71
  ```bash
72
  .venv/bin/python model_development/scripts/eval_sota.py \
73
  --config model_development/configs/deepseek_math_sota.yaml \
74
  --adapter-path model_development/runs/math-conjecture-sota/final_adapter \
75
  --eval-file data/releases/v1/test.parquet \
76
- --k 4 \
77
- --max-samples 300
78
  ```
79
 
80
- ## Important notes
81
 
82
- - Both presets point to `data/releases/v1/train.parquet` and
83
- `data/releases/v1/validation.parquet`.
84
- - `deepseek_math_sota.yaml` defaults to `DeepSeek-Math-V2` and pushes to
85
- `NorthernTribe-Research/math-conjecture-model`.
86
- - If your exact v2 checkpoint id differs, update `model.base_model` in
87
- `model_development/configs/deepseek_math_v2.yaml`.
88
- - Hub auth uses `HF_TOKEN` first, then `huggingface-api-key.json`.
89
- - If `hub.repo_id` is empty, repo name defaults to
90
- `<username>/<output_dir_name>`.
91
 
92
- ## Optional: merge LoRA adapter into full model
93
 
94
- ```bash
95
- .venv/bin/python model_development/scripts/merge_and_push.py \
96
- --adapter-path model_development/runs/deepseek-math-lora \
97
- --output-dir model_development/merged/math-conjecture-model \
98
- --push-to-hub \
99
- --repo-id NorthernTribe-Research/math-conjecture-model
100
- ```
 
17
 
18
  # NorthernTribe-Research/math-conjecture-model
19
 
20
+ Launch multi-stage DeepSeek-Math fine-tuning on Space GPU and push adapters to your model repo.
 
21
 
22
+ This folder contains the autonomous training/evaluation stack used by the Space and local runs.
23
 
24
+ ## Included
25
+
26
+ - `configs/deepseek_math.yaml`: DeepSeek-Math baseline preset
27
+ - `configs/deepseek_math_v2.yaml`: DeepSeek-Math-V2 baseline preset
28
+ - `configs/deepseek_math_sota.yaml`: 4-stage SOTA curriculum + post-eval + quality gate
29
+ - `scripts/train_sft.py`: single-stage LoRA/QLoRA SFT
30
+ - `scripts/train_sota.py`: staged weighted curriculum with autonomous post-eval and gated push
31
+ - `scripts/eval_sota.py`: pass@k + exact/boxed + family/difficulty metrics
32
+ - `scripts/merge_and_push.py`: optional adapter merge into full model weights
33
 
34
  ## Setup
35
 
 
37
  .venv/bin/python -m pip install -r model_development/requirements.txt
38
  ```
39
 
40
+ ## Run SOTA curriculum
41
 
42
  ```bash
43
+ .venv/bin/python model_development/scripts/train_sota.py \
44
+ --config model_development/configs/deepseek_math_sota.yaml
45
  ```
46
 
47
+ Optional controls:
48
 
49
  ```bash
50
+ # Validate stages only
51
+ .venv/bin/python model_development/scripts/train_sota.py \
52
+ --config model_development/configs/deepseek_math_sota.yaml \
53
+ --dry-run
 
54
 
55
+ # Force skip quality gate for one run
56
  .venv/bin/python model_development/scripts/train_sota.py \
57
+ --config model_development/configs/deepseek_math_sota.yaml \
58
+ --skip-quality-gate
59
  ```
60
 
61
+ ## Evaluate adapters
 
62
 
63
  ```bash
64
  .venv/bin/python model_development/scripts/eval_sota.py \
65
  --config model_development/configs/deepseek_math_sota.yaml \
66
  --adapter-path model_development/runs/math-conjecture-sota/final_adapter \
67
  --eval-file data/releases/v1/test.parquet \
68
+ --k 6 \
69
+ --max-samples 240
70
  ```
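
For reference, the pass@1 / pass@k numbers the harness reports come down to simple counting over per-prompt match lists. A minimal sketch of that aggregation (an illustration of the counting done in `eval_sota.py`, not the script itself):

```python
# Illustrative only: each prompt contributes one list of k booleans,
# one per sampled generation, in generation order.

def pass_metrics(match_lists):
    """match_lists: per-prompt lists of per-sample match booleans (length k)."""
    total = len(match_lists)
    if total == 0:
        return {"pass_at_1": 0.0, "pass_at_k": 0.0}
    # pass@1: first sampled generation matched; pass@k: any of the k matched.
    hit1 = sum(1 for matches in match_lists if matches and matches[0])
    hitk = sum(1 for matches in match_lists if any(matches))
    return {"pass_at_1": hit1 / total, "pass_at_k": hitk / total}
```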
71
 
72
+ ## Outputs
73
 
74
+ - final adapter: `model_development/runs/math-conjecture-sota/final_adapter`
75
+ - training summary: `model_development/runs/math-conjecture-sota/training_summary.json`
76
+ - post-eval report: `model_development/runs/math-conjecture-sota/post_eval_report.json`
 
77
 
78
+ ## Quality gate behavior
79
 
80
+ When enabled in the config or via runtime flags:
81
+
82
+ - validates minimum evaluation coverage
83
+ - enforces `pass@1` / `pass@k` thresholds
84
+ - enforces required family-level `pass@k` thresholds
85
+ - can enforce max final stage `eval_loss`
86
+ - blocks hub push if gate fails
87
+
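The gate checks above can be sketched as a single predicate over the post-eval report and the `quality_gate` config. This is a simplified illustration under assumed dict shapes; the real logic lives in `scripts/train_sota.py` and may differ in detail:

```python
# Hypothetical sketch of the quality-gate decision: report is the
# post-eval metrics dict, gate is the quality_gate config mapping.

def gate_passes(report, gate):
    # Minimum evaluation coverage.
    if report.get("evaluated_rows", 0) < gate.get("min_evaluated_rows", 0):
        return False
    # Aggregate pass@1 / pass@k floors.
    if report.get("pass_at_1", 0.0) < gate.get("min_pass_at_1", 0.0):
        return False
    if report.get("pass_at_k", 0.0) < gate.get("min_pass_at_k", 0.0):
        return False
    # Per-family pass@k floors (e.g. conjecture_core, formal_proof).
    for family, floor in gate.get("required_family_pass_at_k", {}).items():
        fam = report.get("family_metrics", {}).get(family, {})
        if fam.get("pass_at_k", 0.0) < floor:
            return False
    return True
```

If any check fails, the run keeps its local artifacts but the Hub push is blocked.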
88
+ ## Auth
89
+
90
+ Hub auth resolves from environment first (`HF_TOKEN` / `HUGGINGFACE_HUB_TOKEN`) and can fall back to `huggingface-api-key.json`.
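
That resolution order can be sketched as follows. The `key` field name mirrors the pre-change `resolve_auth` code in `scripts/train_sota.py`; treat the exact file schema as an assumption:

```python
import json
import os
from pathlib import Path

# Sketch of env-first token resolution with a JSON key-file fallback.
# The "key" field is assumed from the repo's huggingface-api-key.json layout.

def resolve_token(key_file="huggingface-api-key.json"):
    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if token and token.strip():
        return token.strip()
    path = Path(key_file)
    if path.exists():
        data = json.loads(path.read_text(encoding="utf-8"))
        return (data.get("key") or "").strip() or None
    return None
```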
configs/deepseek_math_sota.yaml CHANGED
@@ -97,17 +97,55 @@ stages:
97
  - conjecture_core
98
  require_conjecture_id: true
99
  training:
100
- num_train_epochs: 3
101
  learning_rate: 5.0e-6
102
  save_steps: 100
103
  eval_steps: 100
104
 
105
  hub:
106
  push_to_hub: true
107
  repo_id: NorthernTribe-Research/math-conjecture-model
108
  private: false
109
  upload_stage_checkpoints: true
110
- commit_message: Train multi-stage SOTA curriculum for conjecture reasoning.
111
 
112
  credentials:
113
  path: huggingface-api-key.json
 
97
  - conjecture_core
98
  require_conjecture_id: true
99
  training:
100
+ num_train_epochs: 2
101
  learning_rate: 5.0e-6
102
  save_steps: 100
103
  eval_steps: 100
104
 
105
+ - name: hard_case_polish
106
+ max_train_samples: 60000
107
+ max_eval_samples: 2000
108
+ filters:
109
+ include_families:
110
+ - conjecture_core
111
+ - formal_proof
112
+ require_conjecture_id: true
113
+ min_sample_weight: 3.0
114
+ training:
115
+ num_train_epochs: 1
116
+ learning_rate: 3.0e-6
117
+ gradient_accumulation_steps: 24
118
+ save_steps: 80
119
+ eval_steps: 80
120
+
121
+ post_eval:
122
+ enabled: true
123
+ eval_file: data/releases/v1/test.parquet
124
+ max_samples: 240
125
+ k: 6
126
+ max_new_tokens: 320
127
+ temperature: 0.7
128
+ top_p: 0.95
129
+ seed: 17
130
+ output_json: model_development/runs/math-conjecture-sota/post_eval_report.json
131
+
132
+ quality_gate:
133
+ enabled: true
134
+ require_post_eval: true
135
+ min_evaluated_rows: 120
136
+ min_pass_at_1: 0.01
137
+ min_pass_at_k: 0.06
138
+ max_final_eval_loss: 2.6
139
+ required_family_pass_at_k:
140
+ conjecture_core: 0.06
141
+ formal_proof: 0.03
142
+
143
  hub:
144
  push_to_hub: true
145
  repo_id: NorthernTribe-Research/math-conjecture-model
146
  private: false
147
  upload_stage_checkpoints: true
148
+ commit_message: Launch multi-stage DeepSeek-Math fine-tuning on Space GPU and push adapters to your model repo.
149
 
150
  credentials:
151
  path: huggingface-api-key.json
scripts/eval_sota.py CHANGED
@@ -7,7 +7,7 @@ import argparse
7
  import json
8
  import re
9
  from pathlib import Path
10
- from typing import Any, Dict, List, Optional, Sequence
11
 
12
  import torch
13
  import yaml
@@ -15,13 +15,20 @@ from datasets import load_dataset
15
  from peft import PeftModel
16
  from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
17
 
18
 
19
  def parse_args() -> argparse.Namespace:
20
  parser = argparse.ArgumentParser(description="Run pass@k-style evaluation on held-out split.")
21
  parser.add_argument(
22
  "--config",
23
  type=Path,
24
- default=Path("model_development/configs/deepseek_math_sota.yaml"),
25
  help="Training config used for prompt formatting defaults.",
26
  )
27
  parser.add_argument(
@@ -39,19 +46,32 @@ def parse_args() -> argparse.Namespace:
39
  parser.add_argument(
40
  "--eval-file",
41
  type=Path,
42
- default=Path("data/releases/v1/test.parquet"),
43
- help="Parquet split used for evaluation.",
44
  )
45
  parser.add_argument("--max-samples", type=int, default=300, help="Maximum evaluation rows.")
46
  parser.add_argument("--k", type=int, default=4, help="Number of sampled generations per prompt.")
47
  parser.add_argument("--max-new-tokens", type=int, default=256, help="Generation length cap.")
 
48
  parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature.")
49
  parser.add_argument("--top-p", type=float, default=0.95, help="Nucleus sampling p.")
50
  parser.add_argument("--seed", type=int, default=17, help="Random seed.")
 
51
  parser.add_argument(
52
  "--output-json",
53
  type=Path,
54
- default=Path("model_development/runs/latest_eval_report.json"),
55
  help="Where to write evaluation report.",
56
  )
57
  return parser.parse_args()
@@ -65,6 +85,24 @@ def as_text(value: Any) -> str:
65
  return str(value).strip()
66
 
67
 
68
  def load_config(path: Path) -> Dict[str, Any]:
69
  cfg = yaml.safe_load(path.read_text(encoding="utf-8"))
70
  if not isinstance(cfg, dict):
@@ -74,9 +112,124 @@ def load_config(path: Path) -> Dict[str, Any]:
74
 
75
  def normalize_answer(text: str) -> str:
76
  text = text.strip().lower()
77
- text = re.sub(r"\s+", " ", text)
78
  text = text.replace("$", "")
79
- return text
 
80
 
81
 
82
  def flatten_expected(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> List[str]:
@@ -168,27 +321,11 @@ def extract_candidate_text(full_generation: str, prompt_text: str) -> str:
168
  return full_generation.strip()
169
 
170
 
171
- def is_match(candidate: str, expected_values: Sequence[str]) -> bool:
172
- cand_norm = normalize_answer(candidate)
173
- if not cand_norm:
174
- return False
175
- for expected in expected_values:
176
- exp_norm = normalize_answer(expected)
177
- if not exp_norm:
178
- continue
179
- if exp_norm in cand_norm or cand_norm in exp_norm:
180
- return True
181
- boxed = re.findall(r"\\boxed\{([^{}]+)\}", cand_norm)
182
- if boxed and any(exp_norm in item for item in boxed):
183
- return True
184
- return False
185
-
186
-
187
  def load_model_and_tokenizer(
188
  base_model: str,
189
  adapter_path: Optional[Path],
190
  trust_remote_code: bool,
191
- ) -> tuple[Any, AutoTokenizer]:
192
  tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=trust_remote_code, use_fast=True)
193
  if tokenizer.pad_token is None:
194
  tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
@@ -207,8 +344,77 @@ def load_model_and_tokenizer(
207
  return model, tokenizer
208
 
209
 
210
- def main() -> None:
211
- args = parse_args()
 
212
  cfg = load_config(args.config)
213
  data_cfg = cfg.get("data", {})
214
  model_cfg = cfg.get("model", {})
@@ -217,6 +423,12 @@ def main() -> None:
217
  base_model = args.base_model or as_text(model_cfg.get("base_model"))
218
  if not base_model:
219
  raise ValueError("Base model is required via --base-model or config.model.base_model.")
 
220
 
221
  model, tokenizer = load_model_and_tokenizer(
222
  base_model=base_model,
@@ -224,25 +436,35 @@ def main() -> None:
224
  trust_remote_code=bool(model_cfg.get("trust_remote_code", False)),
225
  )
226
 
227
- if not args.eval_file.exists():
228
- raise FileNotFoundError(f"Evaluation file not found: {args.eval_file}")
229
- ds = load_dataset("parquet", data_files={"eval": str(args.eval_file)})["eval"]
230
-
231
  if args.max_samples > 0 and args.max_samples < len(ds):
232
  ds = ds.select(range(args.max_samples))
233
 
234
- total = 0
235
- hit_at_1 = 0
236
- hit_at_k = 0
237
- records = []
 
238
 
239
  for row in ds:
240
  expected_values = flatten_expected(row, data_cfg)
241
  if not expected_values:
 
242
  continue
 
243
  prompt_text = build_prompt_text(row, tokenizer, data_cfg)
244
- inputs = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=4096)
245
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
246
 
247
  with torch.no_grad():
248
  output_ids = model.generate(
@@ -255,44 +477,119 @@ def main() -> None:
255
  pad_token_id=tokenizer.pad_token_id,
256
  eos_token_id=tokenizer.eos_token_id,
257
  )
 
258
  generations = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
259
  candidates = [extract_candidate_text(text, prompt_text) for text in generations]
260
- matches = [is_match(candidate, expected_values) for candidate in candidates]
261
- total += 1
262
- if matches and matches[0]:
263
- hit_at_1 += 1
264
- if any(matches):
265
- hit_at_k += 1
266
-
267
- records.append(
268
- {
269
- "uid": as_text(row.get("uid")),
270
- "prompt": as_text(row.get(as_text(data_cfg.get("prompt_field")) or "prompt")),
271
- "expected_values": expected_values[:5],
272
- "candidates": candidates,
273
- "matches": matches,
274
- }
 
275
  )
276
 
277
- pass_at_1 = (hit_at_1 / total) if total else 0.0
278
- pass_at_k = (hit_at_k / total) if total else 0.0
279
- report = {
 
280
  "base_model": base_model,
281
  "adapter_path": str(args.adapter_path) if args.adapter_path is not None else None,
282
- "eval_file": str(args.eval_file),
283
- "evaluated_rows": total,
 
 
 
284
  "k": args.k,
285
  "pass_at_1": pass_at_1,
286
  "pass_at_k": pass_at_k,
 
287
  "temperature": args.temperature,
288
  "top_p": args.top_p,
289
  "max_new_tokens": args.max_new_tokens,
290
- "samples": records[:30],
 
291
  }
 
292
  args.output_json.parent.mkdir(parents=True, exist_ok=True)
293
  args.output_json.write_text(json.dumps(report, ensure_ascii=True, indent=2), encoding="utf-8")
294
- print(json.dumps({k: report[k] for k in ("evaluated_rows", "pass_at_1", "pass_at_k", "k")}, indent=2))
 
295
  print(f"Saved report to {args.output_json}")
 
296
 
297
 
298
  if __name__ == "__main__":
 
7
  import json
8
  import re
9
  from pathlib import Path
10
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
11
 
12
  import torch
13
  import yaml
 
15
  from peft import PeftModel
16
  from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
17
 
18
+ SCRIPT_ROOT = Path(__file__).resolve().parents[1]
19
+ DEFAULT_CONFIG_PATH = SCRIPT_ROOT / "configs" / "deepseek_math_sota.yaml"
20
+ DEFAULT_OUTPUT_JSON = SCRIPT_ROOT / "runs" / "latest_eval_report.json"
21
+
22
+ BOXED_RE = re.compile(r"\\boxed\{([^{}]+)\}")
23
+ SPACE_RE = re.compile(r"\s+")
24
+
25
 
26
  def parse_args() -> argparse.Namespace:
27
  parser = argparse.ArgumentParser(description="Run pass@k-style evaluation on held-out split.")
28
  parser.add_argument(
29
  "--config",
30
  type=Path,
31
+ default=DEFAULT_CONFIG_PATH,
32
  help="Training config used for prompt formatting defaults.",
33
  )
34
  parser.add_argument(
 
46
  parser.add_argument(
47
  "--eval-file",
48
  type=Path,
49
+ default=None,
50
+ help="Parquet split used for evaluation (defaults to post_eval.eval_file or data.default_validation_file).",
51
  )
52
  parser.add_argument("--max-samples", type=int, default=300, help="Maximum evaluation rows.")
53
  parser.add_argument("--k", type=int, default=4, help="Number of sampled generations per prompt.")
54
  parser.add_argument("--max-new-tokens", type=int, default=256, help="Generation length cap.")
55
+ parser.add_argument("--max-input-length", type=int, default=4096, help="Prompt tokenization length cap.")
56
  parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature.")
57
  parser.add_argument("--top-p", type=float, default=0.95, help="Nucleus sampling p.")
58
  parser.add_argument("--seed", type=int, default=17, help="Random seed.")
59
+ parser.add_argument(
60
+ "--progress-every",
61
+ type=int,
62
+ default=25,
63
+ help="Print progress every N evaluated rows (0 disables).",
64
+ )
65
+ parser.add_argument(
66
+ "--sample-records",
67
+ type=int,
68
+ default=30,
69
+ help="How many sample records to store in report.",
70
+ )
71
  parser.add_argument(
72
  "--output-json",
73
  type=Path,
74
+ default=DEFAULT_OUTPUT_JSON,
75
  help="Where to write evaluation report.",
76
  )
77
  return parser.parse_args()
 
85
  return str(value).strip()
86
 
87
 
88
+ def as_float(value: Any, default: float) -> float:
89
+ if value is None:
90
+ return default
91
+ try:
92
+ return float(value)
93
+ except (TypeError, ValueError):
94
+ return default
95
+
96
+
97
+ def as_int(value: Any, default: int) -> int:
98
+ if value is None:
99
+ return default
100
+ try:
101
+ return int(value)
102
+ except (TypeError, ValueError):
103
+ return default
104
+
105
+
106
  def load_config(path: Path) -> Dict[str, Any]:
107
  cfg = yaml.safe_load(path.read_text(encoding="utf-8"))
108
  if not isinstance(cfg, dict):
 
112
 
113
  def normalize_answer(text: str) -> str:
114
  text = text.strip().lower()
 
115
  text = text.replace("$", "")
116
+ text = text.replace("\\left", "").replace("\\right", "")
117
+ text = text.replace("\\,", "").replace("\\!", "").replace("\\;", "")
118
+ text = SPACE_RE.sub(" ", text)
119
+ return text.strip(" .")
120
+
121
+
122
+ def extract_boxed_values(text: str) -> List[str]:
123
+ return [normalize_answer(match) for match in BOXED_RE.findall(text or "") if normalize_answer(match)]
124
+
125
+
126
+ def parse_numeric_value(text: str) -> Optional[float]:
127
+ normalized = normalize_answer(text)
128
+ if not normalized:
129
+ return None
130
+ candidate = normalized.replace(",", "")
131
+ if re.fullmatch(r"[-+]?\d+\s*/\s*[-+]?\d+", candidate):
132
+ left, right = candidate.split("/", maxsplit=1)
133
+ try:
134
+ numerator = float(left.strip())
135
+ denominator = float(right.strip())
136
+ except ValueError:
137
+ return None
138
+ if denominator == 0:
139
+ return None
140
+ return numerator / denominator
141
+ if re.fullmatch(r"[-+]?(?:\d+\.\d*|\d*\.\d+|\d+)(?:[eE][-+]?\d+)?", candidate):
142
+ try:
143
+ return float(candidate)
144
+ except ValueError:
145
+ return None
146
+ return None
147
+
148
+
149
+ def approximately_equal(left: float, right: float) -> bool:
150
+ tolerance = 1e-6 * max(1.0, abs(left), abs(right))
151
+ return abs(left - right) <= tolerance
152
+
153
+
154
+ def match_candidate(candidate: str, expected_values: Sequence[str]) -> Dict[str, Any]:
155
+ cand_norm = normalize_answer(candidate)
156
+ if not cand_norm:
157
+ return {
158
+ "match": False,
159
+ "exact": False,
160
+ "boxed": False,
161
+ "numeric": False,
162
+ "reason": "empty_candidate",
163
+ }
164
+
165
+ cand_boxed = extract_boxed_values(candidate)
166
+ cand_num = parse_numeric_value(cand_norm)
167
+
168
+ substring_hit = False
169
+ boxed_hit = False
170
+ numeric_hit = False
171
+
172
+ for expected in expected_values:
173
+ exp_norm = normalize_answer(expected)
174
+ if not exp_norm:
175
+ continue
176
+
177
+ if cand_norm == exp_norm:
178
+ return {
179
+ "match": True,
180
+ "exact": True,
181
+ "boxed": exp_norm in cand_boxed,
182
+ "numeric": False,
183
+ "reason": "exact",
184
+ }
185
+
186
+ if exp_norm in cand_norm or cand_norm in exp_norm:
187
+ substring_hit = True
188
+
189
+ expected_boxed = extract_boxed_values(expected)
190
+ for cand_box in cand_boxed:
191
+ if cand_box == exp_norm or exp_norm in cand_box or cand_box in exp_norm:
192
+ boxed_hit = True
193
+ for exp_box in expected_boxed:
194
+ if cand_norm == exp_box or exp_box in cand_norm or cand_norm in exp_box:
195
+ boxed_hit = True
196
+
197
+ exp_num = parse_numeric_value(exp_norm)
198
+ if cand_num is not None and exp_num is not None and approximately_equal(cand_num, exp_num):
199
+ numeric_hit = True
200
+
201
+ if boxed_hit:
202
+ return {
203
+ "match": True,
204
+ "exact": False,
205
+ "boxed": True,
206
+ "numeric": numeric_hit,
207
+ "reason": "boxed",
208
+ }
209
+ if numeric_hit:
210
+ return {
211
+ "match": True,
212
+ "exact": False,
213
+ "boxed": False,
214
+ "numeric": True,
215
+ "reason": "numeric",
216
+ }
217
+ if substring_hit:
218
+ return {
219
+ "match": True,
220
+ "exact": False,
221
+ "boxed": False,
222
+ "numeric": False,
223
+ "reason": "substring",
224
+ }
225
+
226
+ return {
227
+ "match": False,
228
+ "exact": False,
229
+ "boxed": False,
230
+ "numeric": False,
231
+ "reason": "no_match",
232
+ }
233
 
234
 
235
  def flatten_expected(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> List[str]:
 
321
  return full_generation.strip()
322
 
323
 
324
  def load_model_and_tokenizer(
325
  base_model: str,
326
  adapter_path: Optional[Path],
327
  trust_remote_code: bool,
328
+ ) -> Tuple[Any, AutoTokenizer]:
329
  tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=trust_remote_code, use_fast=True)
330
  if tokenizer.pad_token is None:
331
  tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
 
344
  return model, tokenizer
345
 
346
 
347
+ def make_bucket() -> Dict[str, Any]:
348
+ return {
349
+ "evaluated_rows": 0,
350
+ "pass_at_1_hits": 0,
351
+ "pass_at_k_hits": 0,
352
+ "exact_at_1_hits": 0,
353
+ "exact_at_k_hits": 0,
354
+ "boxed_at_k_hits": 0,
355
+ }
356
+
357
+
358
+ def update_bucket(bucket: Dict[str, Any], hit1: bool, hitk: bool, exact1: bool, exactk: bool, boxedk: bool) -> None:
359
+ bucket["evaluated_rows"] += 1
360
+ if hit1:
361
+ bucket["pass_at_1_hits"] += 1
362
+ if hitk:
363
+ bucket["pass_at_k_hits"] += 1
364
+ if exact1:
365
+ bucket["exact_at_1_hits"] += 1
366
+ if exactk:
367
+ bucket["exact_at_k_hits"] += 1
368
+ if boxedk:
369
+ bucket["boxed_at_k_hits"] += 1
370
+
371
+
372
+ def finalize_bucket(bucket: Dict[str, Any]) -> Dict[str, Any]:
373
+ total = max(int(bucket.get("evaluated_rows", 0)), 1)
374
+ rows = int(bucket.get("evaluated_rows", 0))
375
+ return {
376
+ "evaluated_rows": rows,
377
+ "pass_at_1": float(bucket.get("pass_at_1_hits", 0)) / total,
378
+ "pass_at_k": float(bucket.get("pass_at_k_hits", 0)) / total,
379
+ "exact_at_1": float(bucket.get("exact_at_1_hits", 0)) / total,
380
+ "exact_at_k": float(bucket.get("exact_at_k_hits", 0)) / total,
381
+ "boxed_at_k": float(bucket.get("boxed_at_k_hits", 0)) / total,
382
+ }
383
+
384
+
385
+ def resolve_eval_file(arg_eval_file: Optional[Path], cfg: Dict[str, Any]) -> Path:
386
+ if arg_eval_file is not None:
387
+ return arg_eval_file
388
+ post_eval_cfg = cfg.get("post_eval", {})
389
+ data_cfg = cfg.get("data", {})
390
+ for candidate in (
391
+ as_text(post_eval_cfg.get("eval_file")),
392
+ as_text(data_cfg.get("default_validation_file")),
393
+ "data/releases/v1/test.parquet",
394
+ "workspace/data/releases/v1/test.parquet",
395
+ ):
396
+ if not candidate:
397
+ continue
398
+ path = Path(candidate)
399
+ if path.exists():
400
+ return path
401
+ return Path("data/releases/v1/test.parquet")
402
+
403
+
404
+ def run_evaluation(args: argparse.Namespace) -> Dict[str, Any]:
405
+ if args.k < 1:
406
+ raise ValueError("--k must be >= 1.")
407
+ if args.max_samples < 1:
408
+ raise ValueError("--max-samples must be >= 1.")
409
+ if args.max_new_tokens < 1:
410
+ raise ValueError("--max-new-tokens must be >= 1.")
411
+ if args.max_input_length < 128:
412
+ raise ValueError("--max-input-length must be >= 128.")
413
+ if args.temperature <= 0:
414
+ raise ValueError("--temperature must be > 0.")
415
+ if not 0 < args.top_p <= 1:
416
+ raise ValueError("--top-p must be in (0, 1].")
417
+
418
  cfg = load_config(args.config)
419
  data_cfg = cfg.get("data", {})
420
  model_cfg = cfg.get("model", {})
 
423
  base_model = args.base_model or as_text(model_cfg.get("base_model"))
424
  if not base_model:
425
  raise ValueError("Base model is required via --base-model or config.model.base_model.")
426
+ if args.adapter_path is not None and not args.adapter_path.exists():
427
+ raise FileNotFoundError(f"Adapter path not found: {args.adapter_path}")
428
+
429
+ eval_file = resolve_eval_file(args.eval_file, cfg)
430
+ if not eval_file.exists():
431
+ raise FileNotFoundError(f"Evaluation file not found: {eval_file}")
432
 
433
  model, tokenizer = load_model_and_tokenizer(
434
  base_model=base_model,
 
436
  trust_remote_code=bool(model_cfg.get("trust_remote_code", False)),
437
  )
438
 
439
+ ds = load_dataset("parquet", data_files={"eval": str(eval_file)})["eval"]
 
 
 
440
  if args.max_samples > 0 and args.max_samples < len(ds):
441
  ds = ds.select(range(args.max_samples))
442
 
443
+ totals = make_bucket()
444
+ family_buckets: Dict[str, Dict[str, Any]] = {}
445
+ difficulty_buckets: Dict[str, Dict[str, Any]] = {}
446
+
447
+ processed_rows = 0
448
+ skipped_no_expected = 0
449
+ samples: List[Dict[str, Any]] = []
450
+
451
+ model_device = next(model.parameters()).device
452
+ prompt_field = as_text(data_cfg.get("prompt_field")) or "prompt"
453
 
454
  for row in ds:
455
  expected_values = flatten_expected(row, data_cfg)
456
  if not expected_values:
457
+ skipped_no_expected += 1
458
  continue
459
+
460
  prompt_text = build_prompt_text(row, tokenizer, data_cfg)
461
+ inputs = tokenizer(
462
+ prompt_text,
463
+ return_tensors="pt",
464
+ truncation=True,
465
+ max_length=args.max_input_length,
466
+ )
467
+ inputs = {k: v.to(model_device) for k, v in inputs.items()}
468
 
469
  with torch.no_grad():
470
  output_ids = model.generate(
 
477
  pad_token_id=tokenizer.pad_token_id,
478
  eos_token_id=tokenizer.eos_token_id,
479
  )
480
+
481
  generations = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
482
  candidates = [extract_candidate_text(text, prompt_text) for text in generations]
483
+ details = [match_candidate(candidate, expected_values) for candidate in candidates]
484
+
485
+ matches = [bool(item["match"]) for item in details]
486
+ exacts = [bool(item["exact"]) for item in details]
487
+ boxed = [bool(item["boxed"]) for item in details]
488
+
489
+ hit1 = bool(matches and matches[0])
490
+ hitk = bool(any(matches))
491
+ exact1 = bool(exacts and exacts[0])
492
+ exactk = bool(any(exacts))
493
+ boxedk = bool(any(boxed))
494
+
495
+ update_bucket(totals, hit1=hit1, hitk=hitk, exact1=exact1, exactk=exactk, boxedk=boxedk)
496
+
497
+ family = as_text(row.get("family")) or "__unknown__"
498
+ if family not in family_buckets:
499
+ family_buckets[family] = make_bucket()
500
+ update_bucket(family_buckets[family], hit1=hit1, hitk=hitk, exact1=exact1, exactk=exactk, boxedk=boxedk)
501
+
502
+ difficulty = as_text(row.get("difficulty")) or "__unknown__"
503
+ if difficulty not in difficulty_buckets:
504
+ difficulty_buckets[difficulty] = make_bucket()
505
+ update_bucket(
506
+ difficulty_buckets[difficulty],
507
+ hit1=hit1,
508
+ hitk=hitk,
509
+ exact1=exact1,
510
+ exactk=exactk,
511
+ boxedk=boxedk,
512
  )
513
 
514
+ processed_rows += 1
515
+ if args.progress_every > 0 and processed_rows % args.progress_every == 0:
516
+ print(f"Progress: evaluated_rows={processed_rows} latest_family={family}")
517
+
518
+ if len(samples) < args.sample_records:
519
+ samples.append(
520
+ {
521
+ "uid": as_text(row.get("uid")),
522
+ "family": family,
523
+ "difficulty": difficulty,
524
+ "prompt": as_text(row.get(prompt_field)),
525
+ "expected_values": expected_values[:5],
526
+ "candidates": candidates,
527
+ "match_details": details,
528
+ "matches": matches,
529
+ }
530
+ )
531
+
532
+ total_eval = int(totals.get("evaluated_rows", 0))
533
+ denominator = max(total_eval, 1)
534
+
535
+ pass_at_1 = float(totals.get("pass_at_1_hits", 0)) / denominator
536
+ pass_at_k = float(totals.get("pass_at_k_hits", 0)) / denominator
537
+ exact_at_1 = float(totals.get("exact_at_1_hits", 0)) / denominator
538
+ exact_at_k = float(totals.get("exact_at_k_hits", 0)) / denominator
539
+ boxed_at_k = float(totals.get("boxed_at_k_hits", 0)) / denominator
540
+
541
+ composite_score = 0.30 * pass_at_1 + 0.50 * pass_at_k + 0.20 * exact_at_k
542
+
543
+ report: Dict[str, Any] = {
544
  "base_model": base_model,
545
  "adapter_path": str(args.adapter_path) if args.adapter_path is not None else None,
546
+ "eval_file": str(eval_file),
547
+ "config": str(args.config),
548
+ "evaluated_rows": total_eval,
549
+ "skipped_rows_without_targets": skipped_no_expected,
550
+ "requested_rows": len(ds),
551
  "k": args.k,
552
  "pass_at_1": pass_at_1,
553
  "pass_at_k": pass_at_k,
554
+ "exact_at_1": exact_at_1,
555
+ "exact_at_k": exact_at_k,
556
+ "boxed_at_k": boxed_at_k,
557
+ "composite_score": composite_score,
558
  "temperature": args.temperature,
559
  "top_p": args.top_p,
560
  "max_new_tokens": args.max_new_tokens,
561
+ "max_input_length": args.max_input_length,
562
+ "seed": args.seed,
563
+ "family_metrics": {
564
+ key: finalize_bucket(family_buckets[key])
565
+ for key in sorted(family_buckets.keys())
566
+ },
567
+ "difficulty_metrics": {
568
+ key: finalize_bucket(difficulty_buckets[key])
569
+ for key in sorted(difficulty_buckets.keys())
570
+ },
571
+ "samples": samples,
572
  }
573
+
574
  args.output_json.parent.mkdir(parents=True, exist_ok=True)
575
  args.output_json.write_text(json.dumps(report, ensure_ascii=True, indent=2), encoding="utf-8")
576
+
577
+ summary_view = {
578
+ "evaluated_rows": total_eval,
579
+ "pass_at_1": pass_at_1,
580
+ "pass_at_k": pass_at_k,
581
+ "exact_at_k": exact_at_k,
582
+ "composite_score": composite_score,
583
+ "k": args.k,
584
+ }
585
+ print(json.dumps(summary_view, indent=2))
586
  print(f"Saved report to {args.output_json}")
587
+ return report
588
+
589
+
590
+ def main() -> None:
591
+ args = parse_args()
592
+ run_evaluation(args)
593
 
594
 
595
  if __name__ == "__main__":
scripts/train_sota.py CHANGED
@@ -4,10 +4,13 @@
4
  from __future__ import annotations
5
 
6
  import argparse
 
7
  import json
8
  import os
 
 
9
  from pathlib import Path
10
- from typing import Any, Dict, Optional, Tuple
11
 
12
  import torch
13
  import yaml
@@ -25,7 +28,9 @@ from transformers import (
25
  set_seed,
26
  )
27
 
28
- DEFAULT_CONFIG_PATH = Path("model_development/configs/deepseek_math_sota.yaml")
 
 
29
 
30
 
31
  def parse_args() -> argparse.Namespace:
@@ -41,6 +46,21 @@ def parse_args() -> argparse.Namespace:
41
  parser.add_argument("--repo-id", type=str, default=None, help="Override hub.repo_id.")
42
  parser.add_argument("--push-to-hub", action="store_true", help="Force push enabled.")
43
  parser.add_argument("--no-push-to-hub", action="store_true", help="Force push disabled.")
 
44
  parser.add_argument(
45
  "--start-stage",
46
  type=int,
@@ -59,6 +79,11 @@ def parse_args() -> argparse.Namespace:
59
  default=None,
60
  help="Override credentials.path.",
61
  )
 
 
 
 
 
62
  return parser.parse_args()
63
 
64
 
@@ -88,6 +113,19 @@ def as_int(value: Any, default: int) -> int:
88
  return default
89
 
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  def load_config(path: Path) -> Dict[str, Any]:
92
  if not path.exists():
93
  raise FileNotFoundError(f"Config not found: {path}")
@@ -103,6 +141,8 @@ def load_config(path: Path) -> Dict[str, Any]:
103
  cfg.setdefault("training_defaults", {})
104
  cfg.setdefault("hub", {})
105
  cfg.setdefault("credentials", {})
 
 
106
  return cfg
107
 
108
 
@@ -118,6 +158,16 @@ def apply_overrides(cfg: Dict[str, Any], args: argparse.Namespace) -> None:
118
  if args.no_push_to_hub:
119
  cfg.setdefault("hub", {})["push_to_hub"] = False
120
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  def resolve_auth(cfg: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
123
  token = as_text(os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")) or None
@@ -128,9 +178,17 @@ def resolve_auth(cfg: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
128
  if path.exists():
129
  data = json.loads(path.read_text(encoding="utf-8"))
130
  if token is None:
131
- token = as_text(data.get("key")) or None
 
 
 
 
132
  if username is None:
133
- username = as_text(data.get("username")) or None
 
 
 
 
134
  return token, username
135
 
136
 
@@ -353,14 +411,10 @@ def tokenize_datasets(raw: DatasetDict, tokenizer: AutoTokenizer, data_cfg: Dict
353
  return tokenized
354
 
355
 
356
- def build_model_and_tokenizer(model_cfg: Dict[str, Any], training_defaults: Dict[str, Any]) -> Tuple[Any, AutoTokenizer]:
357
  base_model = as_text(model_cfg.get("base_model"))
358
  if not base_model:
359
  raise ValueError("model.base_model is required.")
360
-
361
- use_bf16 = bool(model_cfg.get("use_bf16", True))
362
- dtype = torch.bfloat16 if use_bf16 else torch.float16
363
-
364
  tokenizer = AutoTokenizer.from_pretrained(
365
  base_model,
366
  trust_remote_code=bool(model_cfg.get("trust_remote_code", False)),
@@ -370,6 +424,18 @@ def build_model_and_tokenizer(model_cfg: Dict[str, Any], training_defaults: Dict
370
  tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
371
  if tokenizer.pad_token is None:
372
  tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
 
 
 
 
 
 
 
 
 
 
 
 
373
 
374
  model_kwargs: Dict[str, Any] = {
375
  "trust_remote_code": bool(model_cfg.get("trust_remote_code", False)),
@@ -543,6 +609,228 @@ def push_folder(
543
  api.upload_folder(**kwargs)
544
 
545
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
546
  def main() -> None:
547
  args = parse_args()
548
  cfg = load_config(args.config)
@@ -551,21 +839,30 @@ def main() -> None:
551
  seed = as_int(cfg.get("global", {}).get("seed"), 17)
552
  set_seed(seed)
553
 
554
- output_root = Path(as_text(cfg.get("global", {}).get("output_root")) or "model_development/runs/math-conjecture-sota")
555
  output_root.mkdir(parents=True, exist_ok=True)
556
 
557
  token, username = resolve_auth(cfg)
558
  repo_id = resolve_repo_id(cfg, username=username, output_root=output_root)
559
- push_to_hub = bool(cfg.get("hub", {}).get("push_to_hub", False))
560
- if push_to_hub:
 
 
 
 
561
  if token is None:
562
  raise ValueError("Hub push requested but token is missing.")
563
  if repo_id is None:
564
  raise ValueError("Hub push requested but repo_id is missing.")
565
 
566
- model, tokenizer = build_model_and_tokenizer(cfg["model"], cfg.get("training_defaults", {}))
 
 
 
 
 
567
  data_cfg = cfg["data"]
568
- stage_reports = []
569
 
570
  start_stage = max(1, args.start_stage)
571
  stages = cfg["stages"]
@@ -580,17 +877,52 @@ def main() -> None:
580
  stage_name = as_text(stage.get("name")) or f"stage_{index:02d}"
581
  stage_slug = f"{index:02d}_{stage_name.replace(' ', '_')}"
582
  stage_output_dir = output_root / stage_slug
 
583
 
584
  split_files = stage_split_files(stage, data_cfg)
585
  raw = load_dataset("parquet", data_files=split_files)
 
 
 
586
  filters = stage.get("filters", {})
587
  raw["train"] = apply_filters(raw["train"], filters)
588
  raw["validation"] = apply_filters(raw["validation"], filters)
 
 
 
589
  raw["train"] = maybe_select(raw["train"], stage.get("max_train_samples"))
590
  raw["validation"] = maybe_select(raw["validation"], stage.get("max_eval_samples"))
 
 
 
 
 
 
 
591
  if len(raw["train"]) == 0:
592
  raise ValueError(f"Stage {stage_slug} has zero train rows after filtering.")
593
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
594
  tokenized = tokenize_datasets(raw, tokenizer, data_cfg)
595
  train_dataset = tokenized["train"]
596
  eval_dataset = tokenized["validation"] if len(tokenized["validation"]) > 0 else None
@@ -618,39 +950,112 @@ def main() -> None:
618
  trainer.log_metrics("train", train_result.metrics)
619
  trainer.save_metrics("train", train_result.metrics)
620
  trainer.save_state()
 
 
621
  if eval_dataset is not None:
622
  eval_metrics = trainer.evaluate()
623
  trainer.log_metrics("eval", eval_metrics)
624
  trainer.save_metrics("eval", eval_metrics)
 
625
  trainer.save_model(str(stage_output_dir))
626
  tokenizer.save_pretrained(str(stage_output_dir))
627
 
628
- report = {
629
- "stage_index": index,
630
- "stage_name": stage_name,
631
- "output_dir": str(stage_output_dir),
632
- "train_rows": len(train_dataset),
633
- "eval_rows": len(eval_dataset) if eval_dataset is not None else 0,
634
- "train_metrics": train_result.metrics,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
635
  }
636
- stage_reports.append(report)
 
 
 
 
637
 
638
  final_dir = output_root / "final_adapter"
639
  final_dir.mkdir(parents=True, exist_ok=True)
 
640
  model.save_pretrained(str(final_dir))
641
  tokenizer.save_pretrained(str(final_dir))
642
 
643
- summary = {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
644
  "config_path": str(args.config),
645
  "repo_id": repo_id,
646
  "seed": seed,
647
  "stages_ran": stage_reports,
648
  "final_adapter_dir": str(final_dir),
 
 
 
 
 
 
649
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
650
  summary_path = output_root / "training_summary.json"
651
  summary_path.write_text(json.dumps(summary, ensure_ascii=True, indent=2), encoding="utf-8")
652
 
653
- if push_to_hub and repo_id is not None and token is not None:
654
  api = HfApi(token=token)
655
  api.create_repo(
656
  repo_id=repo_id,
@@ -660,17 +1065,22 @@ def main() -> None:
660
  )
661
  commit_message = as_text(cfg.get("hub", {}).get("commit_message")) or "Upload SOTA curriculum adapter."
662
  push_folder(api, repo_id, final_dir, commit_message=commit_message)
 
663
  if bool(cfg.get("hub", {}).get("upload_stage_checkpoints", False)):
664
  for report in stage_reports:
665
- stage_dir = Path(report["output_dir"])
666
- path_in_repo = f"checkpoints/{Path(report['output_dir']).name}"
 
 
 
667
  push_folder(
668
  api,
669
  repo_id,
670
  stage_dir,
671
- commit_message=f"Upload stage checkpoint {report['stage_name']}",
672
  path_in_repo=path_in_repo,
673
  )
 
674
  api.upload_file(
675
  path_or_fileobj=str(summary_path),
676
  path_in_repo="training_summary.json",
@@ -678,6 +1088,16 @@ def main() -> None:
678
  repo_type="model",
679
  commit_message="Upload training summary for SOTA curriculum run.",
680
  )
 
 
 
 
 
 
 
 
 
 
681
  print(f"Pushed training artifacts to https://huggingface.co/{repo_id}")
682
 
683
  print(f"Training complete. Final adapter: {final_dir}")
 
4
  from __future__ import annotations
5
 
6
  import argparse
7
+ import gc
8
  import json
9
  import os
10
+ import subprocess
11
+ import sys
12
  from pathlib import Path
13
+ from typing import Any, Dict, List, Optional, Tuple
14
 
15
  import torch
16
  import yaml
 
28
  set_seed,
29
  )
30
 
31
+ SCRIPT_ROOT = Path(__file__).resolve().parents[1]
32
+ DEFAULT_CONFIG_PATH = SCRIPT_ROOT / "configs" / "deepseek_math_sota.yaml"
33
+ DEFAULT_EVAL_SCRIPT = Path(__file__).resolve().with_name("eval_sota.py")
34
 
35
 
36
  def parse_args() -> argparse.Namespace:
 
46
  parser.add_argument("--repo-id", type=str, default=None, help="Override hub.repo_id.")
47
  parser.add_argument("--push-to-hub", action="store_true", help="Force push enabled.")
48
  parser.add_argument("--no-push-to-hub", action="store_true", help="Force push disabled.")
49
+ parser.add_argument(
50
+ "--run-post-eval",
51
+ action="store_true",
52
+ help="Force post-training evaluation enabled.",
53
+ )
54
+ parser.add_argument(
55
+ "--no-post-eval",
56
+ action="store_true",
57
+ help="Force post-training evaluation disabled.",
58
+ )
59
+ parser.add_argument(
60
+ "--skip-quality-gate",
61
+ action="store_true",
62
+ help="Disable quality gate checks for this run.",
63
+ )
64
  parser.add_argument(
65
  "--start-stage",
66
  type=int,
 
79
  default=None,
80
  help="Override credentials.path.",
81
  )
82
+ parser.add_argument(
83
+ "--dry-run",
84
+ action="store_true",
85
+ help="Validate data/filter/tokenization stages without running training or pushing.",
86
+ )
87
  return parser.parse_args()
88
 
89
 
 
113
  return default
114
 
115
 
116
+ def as_bool(value: Any, default: bool = False) -> bool:
117
+ if value is None:
118
+ return default
119
+ if isinstance(value, bool):
120
+ return value
121
+ text = as_text(value).lower()
122
+ if text in {"1", "true", "yes", "y", "on"}:
123
+ return True
124
+ if text in {"0", "false", "no", "n", "off"}:
125
+ return False
126
+ return default
127
+
128
+
129
  def load_config(path: Path) -> Dict[str, Any]:
130
  if not path.exists():
131
  raise FileNotFoundError(f"Config not found: {path}")
 
141
  cfg.setdefault("training_defaults", {})
142
  cfg.setdefault("hub", {})
143
  cfg.setdefault("credentials", {})
144
+ cfg.setdefault("post_eval", {})
145
+ cfg.setdefault("quality_gate", {})
146
  return cfg
147
 
148
 
 
158
  if args.no_push_to_hub:
159
  cfg.setdefault("hub", {})["push_to_hub"] = False
160
 
161
+ if args.run_post_eval and args.no_post_eval:
162
+ raise ValueError("Cannot set both --run-post-eval and --no-post-eval.")
163
+ if args.run_post_eval:
164
+ cfg.setdefault("post_eval", {})["enabled"] = True
165
+ if args.no_post_eval:
166
+ cfg.setdefault("post_eval", {})["enabled"] = False
167
+
168
+ if args.skip_quality_gate:
169
+ cfg.setdefault("quality_gate", {})["enabled"] = False
170
+
171
 
172
  def resolve_auth(cfg: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
173
  token = as_text(os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")) or None
 
178
  if path.exists():
179
  data = json.loads(path.read_text(encoding="utf-8"))
180
  if token is None:
181
+ for key in ("token", "key", "api_key", "hf_token"):
182
+ candidate = as_text(data.get(key))
183
+ if candidate:
184
+ token = candidate
185
+ break
186
  if username is None:
187
+ for key in ("username", "user", "owner"):
188
+ candidate = as_text(data.get(key))
189
+ if candidate:
190
+ username = candidate
191
+ break
192
  return token, username
193
 
194
 
 
411
  return tokenized
412
 
413
 
414
+ def build_tokenizer(model_cfg: Dict[str, Any]) -> AutoTokenizer:
415
  base_model = as_text(model_cfg.get("base_model"))
416
  if not base_model:
417
  raise ValueError("model.base_model is required.")
 
 
 
 
418
  tokenizer = AutoTokenizer.from_pretrained(
419
  base_model,
420
  trust_remote_code=bool(model_cfg.get("trust_remote_code", False)),
 
424
  tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
425
  if tokenizer.pad_token is None:
426
  tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
427
+ return tokenizer
428
+
429
+
430
+ def build_model_and_tokenizer(model_cfg: Dict[str, Any], training_defaults: Dict[str, Any]) -> Tuple[Any, AutoTokenizer]:
431
+ base_model = as_text(model_cfg.get("base_model"))
432
+ if not base_model:
433
+ raise ValueError("model.base_model is required.")
434
+
435
+ use_bf16 = bool(model_cfg.get("use_bf16", True))
436
+ dtype = torch.bfloat16 if use_bf16 else torch.float16
437
+
438
+ tokenizer = build_tokenizer(model_cfg)
439
 
440
  model_kwargs: Dict[str, Any] = {
441
  "trust_remote_code": bool(model_cfg.get("trust_remote_code", False)),
 
609
  api.upload_folder(**kwargs)
610
 
611
 
612
+ def extract_final_eval_loss(stage_reports: List[Dict[str, Any]]) -> Optional[float]:
613
+ for report in reversed(stage_reports):
614
+ eval_metrics = report.get("eval_metrics")
615
+ if not isinstance(eval_metrics, dict):
616
+ continue
617
+ value = eval_metrics.get("eval_loss")
618
+ if value is None:
619
+ continue
620
+ try:
621
+ return float(value)
622
+ except (TypeError, ValueError):
623
+ continue
624
+ return None
625
+
626
+
627
+ def release_model_memory(model: Any) -> None:
628
+ try:
629
+ model.to("cpu")
630
+ except Exception:
631
+ pass
632
+ if torch.cuda.is_available():
633
+ torch.cuda.empty_cache()
634
+ gc.collect()
635
+
636
+
637
+ def run_post_eval(
638
+ cfg: Dict[str, Any],
639
+ config_path: Path,
640
+ output_root: Path,
641
+ final_adapter_dir: Path,
642
+ ) -> Optional[Dict[str, Any]]:
643
+ post_cfg = cfg.get("post_eval", {})
644
+ if not as_bool(post_cfg.get("enabled"), False):
645
+ return None
646
+
647
+ eval_script = DEFAULT_EVAL_SCRIPT
648
+ if not eval_script.exists():
649
+ raise FileNotFoundError(f"Post-eval enabled but eval script is missing: {eval_script}")
650
+
651
+ data_cfg = cfg.get("data", {})
652
+ eval_file = Path(
653
+ as_text(post_cfg.get("eval_file"))
654
+ or as_text(data_cfg.get("default_validation_file"))
655
+ or "data/releases/v1/test.parquet"
656
+ )
657
+ if not eval_file.exists():
658
+ raise FileNotFoundError(f"Post-eval file not found: {eval_file}")
659
+
660
+ output_json = Path(as_text(post_cfg.get("output_json")) or str(output_root / "post_eval_report.json"))
661
+ base_model = as_text(cfg.get("model", {}).get("base_model"))
662
+ if not base_model:
663
+ raise ValueError("model.base_model is required for post-eval.")
664
+
665
+ cmd = [
666
+ sys.executable,
667
+ str(eval_script),
668
+ "--config",
669
+ str(config_path),
670
+ "--base-model",
671
+ base_model,
672
+ "--adapter-path",
673
+ str(final_adapter_dir),
674
+ "--eval-file",
675
+ str(eval_file),
676
+ "--max-samples",
677
+ str(as_int(post_cfg.get("max_samples"), 300)),
678
+ "--k",
679
+ str(as_int(post_cfg.get("k"), 4)),
680
+ "--max-new-tokens",
681
+ str(as_int(post_cfg.get("max_new_tokens"), 256)),
682
+ "--temperature",
683
+ str(as_float(post_cfg.get("temperature"), 0.7)),
684
+ "--top-p",
685
+ str(as_float(post_cfg.get("top_p"), 0.95)),
686
+ "--seed",
687
+ str(as_int(post_cfg.get("seed"), as_int(cfg.get("global", {}).get("seed"), 17))),
688
+ "--output-json",
689
+ str(output_json),
690
+ ]
691
+ print(f"Running post-training eval: {' '.join(cmd)}")
692
+ completed = subprocess.run(cmd, check=False)
693
+ if completed.returncode != 0:
694
+ raise RuntimeError(f"Post-training evaluation failed with exit code {completed.returncode}.")
695
+
696
+ if not output_json.exists():
697
+ raise FileNotFoundError(f"Post-eval report was not created: {output_json}")
698
+
699
+ report = json.loads(output_json.read_text(encoding="utf-8"))
700
+ return {
701
+ "enabled": True,
702
+ "report_path": str(output_json),
703
+ "report": report,
704
+ "command": cmd,
705
+ }
706
+
707
+
708
+ def evaluate_quality_gate(
709
+ stage_reports: List[Dict[str, Any]],
710
+ post_eval_result: Optional[Dict[str, Any]],
711
+ gate_cfg: Dict[str, Any],
712
+ ) -> Dict[str, Any]:
713
+ enabled = as_bool(gate_cfg.get("enabled"), False)
714
+ result: Dict[str, Any] = {
715
+ "enabled": enabled,
716
+ "passed": True,
717
+ "violations": [],
718
+ "checks": [],
719
+ }
720
+ if not enabled:
721
+ return result
722
+
723
+ violations: List[str] = []
724
+ checks: List[Dict[str, Any]] = []
725
+
726
+ final_eval_loss = extract_final_eval_loss(stage_reports)
727
+ max_final_eval_loss = gate_cfg.get("max_final_eval_loss")
728
+ if max_final_eval_loss is not None:
729
+ threshold = as_float(max_final_eval_loss, 0.0)
730
+ checks.append(
731
+ {
732
+ "name": "max_final_eval_loss",
733
+ "actual": final_eval_loss,
734
+ "threshold": threshold,
735
+ }
736
+ )
737
+ if final_eval_loss is None:
738
+ violations.append("Final stage eval_loss is missing for max_final_eval_loss gate.")
739
+ elif final_eval_loss > threshold:
740
+ violations.append(
741
+ f"Final eval_loss {final_eval_loss:.4f} exceeds threshold {threshold:.4f}."
742
+ )
743
+
744
+ report: Optional[Dict[str, Any]] = None
745
+ if isinstance(post_eval_result, dict):
746
+ loaded = post_eval_result.get("report")
747
+ if isinstance(loaded, dict):
748
+ report = loaded
749
+
750
+ require_post_eval = as_bool(gate_cfg.get("require_post_eval"), False)
751
+ if report is None:
752
+ if require_post_eval:
753
+ violations.append("Quality gate requires post-eval metrics, but post-eval report is missing.")
754
+ else:
755
+ evaluated_rows = as_int(report.get("evaluated_rows"), 0)
756
+ min_rows = as_int(gate_cfg.get("min_evaluated_rows"), 0)
757
+ checks.append(
758
+ {
759
+ "name": "min_evaluated_rows",
760
+ "actual": evaluated_rows,
761
+ "threshold": min_rows,
762
+ }
763
+ )
764
+ if evaluated_rows < min_rows:
765
+ violations.append(
766
+ f"Post-eval rows {evaluated_rows} is below minimum required {min_rows}."
767
+ )
768
+
769
+ min_pass_at_1_raw = gate_cfg.get("min_pass_at_1")
770
+ if min_pass_at_1_raw is not None:
771
+ min_pass_at_1 = as_float(min_pass_at_1_raw, 0.0)
772
+ pass_at_1 = as_float(report.get("pass_at_1"), 0.0)
773
+ checks.append(
774
+ {
775
+ "name": "min_pass_at_1",
776
+ "actual": pass_at_1,
777
+ "threshold": min_pass_at_1,
778
+ }
779
+ )
780
+ if pass_at_1 < min_pass_at_1:
781
+ violations.append(
782
+ f"pass@1 {pass_at_1:.4f} is below threshold {min_pass_at_1:.4f}."
783
+ )
784
+
785
+ min_pass_at_k_raw = gate_cfg.get("min_pass_at_k")
786
+ if min_pass_at_k_raw is not None:
787
+ min_pass_at_k = as_float(min_pass_at_k_raw, 0.0)
788
+ pass_at_k = as_float(report.get("pass_at_k"), 0.0)
789
+ checks.append(
790
+ {
791
+ "name": "min_pass_at_k",
792
+ "actual": pass_at_k,
793
+ "threshold": min_pass_at_k,
794
+ }
795
+ )
796
+ if pass_at_k < min_pass_at_k:
797
+ violations.append(
798
+ f"pass@k {pass_at_k:.4f} is below threshold {min_pass_at_k:.4f}."
799
+ )
800
+
801
+ family_requirements = gate_cfg.get("required_family_pass_at_k", {})
802
+ family_metrics = report.get("family_metrics", {})
803
+ if isinstance(family_requirements, dict):
804
+ for family, threshold_raw in family_requirements.items():
805
+ threshold = as_float(threshold_raw, 0.0)
806
+ actual = None
807
+ if isinstance(family_metrics, dict):
808
+ family_row = family_metrics.get(family)
809
+ if isinstance(family_row, dict):
810
+ try:
811
+ actual = float(family_row.get("pass_at_k"))
812
+ except (TypeError, ValueError):
813
+ actual = None
814
+ checks.append(
815
+ {
816
+ "name": f"family_pass_at_k:{family}",
817
+ "actual": actual,
818
+ "threshold": threshold,
819
+ }
820
+ )
821
+ if actual is None:
822
+ violations.append(f"Missing pass@k metric for required family '{family}'.")
823
+ elif actual < threshold:
824
+ violations.append(
825
+ f"Family '{family}' pass@k {actual:.4f} is below threshold {threshold:.4f}."
826
+ )
827
+
828
+ result["violations"] = violations
829
+ result["checks"] = checks
830
+ result["passed"] = len(violations) == 0
831
+ return result
832
+
833
+
834
  def main() -> None:
835
  args = parse_args()
836
  cfg = load_config(args.config)
 
839
  seed = as_int(cfg.get("global", {}).get("seed"), 17)
840
  set_seed(seed)
841
 
842
+ output_root = Path(as_text(cfg.get("global", {}).get("output_root")) or "runs/math-conjecture-sota")
843
  output_root.mkdir(parents=True, exist_ok=True)
844
 
845
  token, username = resolve_auth(cfg)
846
  repo_id = resolve_repo_id(cfg, username=username, output_root=output_root)
847
+ push_to_hub_requested = bool(cfg.get("hub", {}).get("push_to_hub", False))
848
+ if args.dry_run and push_to_hub_requested:
849
+ print("Dry-run enabled. Disabling push_to_hub for this run.")
850
+ push_to_hub_requested = push_to_hub_requested and not args.dry_run
851
+
852
+ if push_to_hub_requested:
853
  if token is None:
854
  raise ValueError("Hub push requested but token is missing.")
855
  if repo_id is None:
856
  raise ValueError("Hub push requested but repo_id is missing.")
857
 
858
+ if args.dry_run:
859
+ tokenizer = build_tokenizer(cfg["model"])
860
+ model = None
861
+ else:
862
+ model, tokenizer = build_model_and_tokenizer(cfg["model"], cfg.get("training_defaults", {}))
863
+
864
  data_cfg = cfg["data"]
865
+ stage_reports: List[Dict[str, Any]] = []
866
 
867
  start_stage = max(1, args.start_stage)
868
  stages = cfg["stages"]
 
877
  stage_name = as_text(stage.get("name")) or f"stage_{index:02d}"
878
  stage_slug = f"{index:02d}_{stage_name.replace(' ', '_')}"
879
  stage_output_dir = output_root / stage_slug
880
+ print(f"[stage {index}] Starting: {stage_name}")
881
 
882
  split_files = stage_split_files(stage, data_cfg)
883
  raw = load_dataset("parquet", data_files=split_files)
884
+ train_rows_before = len(raw["train"])
885
+ valid_rows_before = len(raw["validation"])
886
+
887
  filters = stage.get("filters", {})
888
  raw["train"] = apply_filters(raw["train"], filters)
889
  raw["validation"] = apply_filters(raw["validation"], filters)
890
+ train_rows_after_filter = len(raw["train"])
891
+ valid_rows_after_filter = len(raw["validation"])
892
+
893
  raw["train"] = maybe_select(raw["train"], stage.get("max_train_samples"))
894
  raw["validation"] = maybe_select(raw["validation"], stage.get("max_eval_samples"))
895
+ train_rows_selected = len(raw["train"])
896
+ valid_rows_selected = len(raw["validation"])
897
+
898
+ print(
899
+ f"[stage {index}] rows train: {train_rows_before} -> {train_rows_after_filter} -> {train_rows_selected}; "
900
+ f"validation: {valid_rows_before} -> {valid_rows_after_filter} -> {valid_rows_selected}"
901
+ )
902
  if len(raw["train"]) == 0:
903
  raise ValueError(f"Stage {stage_slug} has zero train rows after filtering.")
904
 
905
+ if args.dry_run:
906
+ sample_row = raw["train"][0]
907
+ _ = build_prompt_text(sample_row, tokenizer, data_cfg)
908
+ _ = build_answer_block(sample_row, data_cfg)
909
+ stage_reports.append(
910
+ {
911
+ "stage_index": index,
912
+ "stage_name": stage_name,
913
+ "stage_slug": stage_slug,
914
+ "mode": "dry_run",
915
+ "train_rows_before_filter": train_rows_before,
916
+ "validation_rows_before_filter": valid_rows_before,
917
+ "train_rows_after_filter": train_rows_after_filter,
918
+ "validation_rows_after_filter": valid_rows_after_filter,
919
+ "train_rows_selected": train_rows_selected,
920
+ "validation_rows_selected": valid_rows_selected,
921
+ }
922
+ )
923
+ print(f"[stage {index}] Dry-run checks passed.")
924
+ continue
925
+
926
  tokenized = tokenize_datasets(raw, tokenizer, data_cfg)
927
  train_dataset = tokenized["train"]
928
  eval_dataset = tokenized["validation"] if len(tokenized["validation"]) > 0 else None
 
950
  trainer.log_metrics("train", train_result.metrics)
951
  trainer.save_metrics("train", train_result.metrics)
952
  trainer.save_state()
953
+
954
+ eval_metrics = None
955
  if eval_dataset is not None:
956
  eval_metrics = trainer.evaluate()
957
  trainer.log_metrics("eval", eval_metrics)
958
  trainer.save_metrics("eval", eval_metrics)
959
+
960
  trainer.save_model(str(stage_output_dir))
961
  tokenizer.save_pretrained(str(stage_output_dir))
962
 
963
+ stage_reports.append(
964
+ {
965
+ "stage_index": index,
966
+ "stage_name": stage_name,
967
+ "output_dir": str(stage_output_dir),
968
+ "train_rows_before_filter": train_rows_before,
969
+ "validation_rows_before_filter": valid_rows_before,
970
+ "train_rows_after_filter": train_rows_after_filter,
971
+ "validation_rows_after_filter": valid_rows_after_filter,
972
+ "train_rows_selected": train_rows_selected,
973
+ "validation_rows_selected": valid_rows_selected,
974
+ "train_rows": len(train_dataset),
975
+ "eval_rows": len(eval_dataset) if eval_dataset is not None else 0,
976
+ "train_metrics": train_result.metrics,
977
+ "eval_metrics": eval_metrics,
978
+ }
979
+ )
980
+ print(
981
+ f"[stage {index}] Completed: train_rows={len(train_dataset)} "
982
+ f"eval_rows={len(eval_dataset) if eval_dataset is not None else 0} output={stage_output_dir}"
983
+ )
984
+
985
+ if args.dry_run:
986
+ summary = {
987
+ "mode": "dry_run",
988
+ "config_path": str(args.config),
989
+ "seed": seed,
990
+ "start_stage": start_stage,
991
+ "end_stage": end_stage,
992
+ "stages_ran": stage_reports,
993
  }
994
+ summary_path = output_root / "dry_run_summary.json"
995
+ summary_path.write_text(json.dumps(summary, ensure_ascii=True, indent=2), encoding="utf-8")
996
+ print("Dry-run complete. No training or model push was performed.")
997
+ print(f"Dry-run summary: {summary_path}")
998
+ return
999
 
1000
  final_dir = output_root / "final_adapter"
1001
  final_dir.mkdir(parents=True, exist_ok=True)
1002
+ assert model is not None
1003
  model.save_pretrained(str(final_dir))
1004
  tokenizer.save_pretrained(str(final_dir))
1005
 
1006
+ release_model_memory(model)
1007
+ del model
1008
+
1009
+ post_eval_result = run_post_eval(
1010
+ cfg=cfg,
1011
+ config_path=args.config,
1012
+ output_root=output_root,
1013
+ final_adapter_dir=final_dir,
1014
+ )
1015
+
1016
+ quality_gate = evaluate_quality_gate(
1017
+ stage_reports=stage_reports,
1018
+ post_eval_result=post_eval_result,
1019
+ gate_cfg=cfg.get("quality_gate", {}),
1020
+ )
1021
+
1022
+ push_to_hub_performed = push_to_hub_requested
1023
+ push_block_reason: Optional[str] = None
1024
+ if push_to_hub_requested and not quality_gate.get("passed", True):
1025
+ push_to_hub_performed = False
1026
+ push_block_reason = "quality_gate_failed"
1027
+ print("Quality gate failed; skipping hub push for this run.")
1028
+
1029
+ summary: Dict[str, Any] = {
1030
  "config_path": str(args.config),
1031
  "repo_id": repo_id,
1032
  "seed": seed,
1033
  "stages_ran": stage_reports,
1034
  "final_adapter_dir": str(final_dir),
1035
+ "quality_gate": quality_gate,
1036
+ "push": {
1037
+ "requested": bool(push_to_hub_requested),
1038
+ "performed": bool(push_to_hub_performed),
1039
+ "block_reason": push_block_reason,
1040
+ },
1041
  }
1042
+
1043
+ if post_eval_result is not None:
1044
+ report = post_eval_result.get("report", {})
1045
+ summary["post_eval"] = {
1046
+ "report_path": post_eval_result.get("report_path"),
1047
+ "evaluated_rows": report.get("evaluated_rows"),
1048
+ "k": report.get("k"),
1049
+ "pass_at_1": report.get("pass_at_1"),
1050
+ "pass_at_k": report.get("pass_at_k"),
1051
+ "exact_at_k": report.get("exact_at_k"),
1052
+ "composite_score": report.get("composite_score"),
1053
+ }
1054
+
1055
  summary_path = output_root / "training_summary.json"
1056
  summary_path.write_text(json.dumps(summary, ensure_ascii=True, indent=2), encoding="utf-8")
1057
 
1058
+ if push_to_hub_performed and repo_id is not None and token is not None:
1059
  api = HfApi(token=token)
1060
  api.create_repo(
1061
  repo_id=repo_id,
 
1065
  )
1066
  commit_message = as_text(cfg.get("hub", {}).get("commit_message")) or "Upload SOTA curriculum adapter."
1067
  push_folder(api, repo_id, final_dir, commit_message=commit_message)
1068
+
1069
  if bool(cfg.get("hub", {}).get("upload_stage_checkpoints", False)):
1070
  for report in stage_reports:
1071
+ stage_dir_raw = report.get("output_dir")
1072
+ if not stage_dir_raw:
1073
+ continue
1074
+ stage_dir = Path(stage_dir_raw)
1075
+ path_in_repo = f"checkpoints/{stage_dir.name}"
1076
  push_folder(
1077
  api,
1078
  repo_id,
1079
  stage_dir,
1080
+ commit_message=f"Upload stage checkpoint {report.get('stage_name', stage_dir.name)}",
1081
  path_in_repo=path_in_repo,
1082
  )
1083
+
1084
  api.upload_file(
1085
  path_or_fileobj=str(summary_path),
1086
  path_in_repo="training_summary.json",
 
1088
  repo_type="model",
1089
  commit_message="Upload training summary for SOTA curriculum run.",
1090
  )
1091
+
1092
+ if post_eval_result is not None and post_eval_result.get("report_path"):
1093
+ api.upload_file(
1094
+ path_or_fileobj=str(post_eval_result["report_path"]),
1095
+ path_in_repo="post_eval_report.json",
1096
+ repo_id=repo_id,
1097
+ repo_type="model",
1098
+ commit_message="Upload post-training evaluation report.",
1099
+ )
1100
+
1101
  print(f"Pushed training artifacts to https://huggingface.co/{repo_id}")
1102
 
1103
  print(f"Training complete. Final adapter: {final_dir}")
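The promotion logic added in this commit follows a simple pattern: collect threshold checks, record each actual/threshold pair, and block the hub push only when at least one check fails. A stripped-down sketch of that decision in isolation (metric and config key names follow the diff; the thresholds and the `gate` helper itself are invented for illustration):

```python
from typing import Any, Dict, List


def gate(report: Dict[str, Any], gate_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Compare report metrics against configured 'min_*' thresholds.

    Returns a verdict dict instead of raising, so the caller can decide
    what to do (e.g. skip the hub push but still save artifacts locally).
    """
    violations: List[str] = []
    for name, threshold in gate_cfg.items():
        actual = report.get(name.removeprefix("min_"))
        if actual is None:
            violations.append(f"missing metric for {name}")
        elif actual < threshold:
            violations.append(f"{name}: {actual} < {threshold}")
    return {"passed": not violations, "violations": violations}


report = {"pass_at_1": 0.41, "pass_at_k": 0.63}
verdict = gate(report, {"min_pass_at_1": 0.35, "min_pass_at_k": 0.70})
print(verdict)
```

Returning a structured verdict rather than raising mirrors the diff's behavior: a failed gate still writes `training_summary.json` with the `quality_gate` and `push.block_reason` fields, which keeps failed runs auditable.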