NorthernTribe-Research committed
Commit 0be512a · verified · 1 Parent(s): f505b3a

Upgrade training pipeline with post-eval quality gates and tactical UI controls.

Files changed (5):
  1. README.md +35 -35
  2. app.py +171 -7
  3. configs/deepseek_math_sota.yaml +40 -2
  4. scripts/eval_sota.py +344 -61
  5. scripts/train_sota.py +387 -47
README.md CHANGED
@@ -9,52 +9,52 @@ pinned: false
9
 
10
  # Math Conjecture Trainer Space
11
 
12
- This Space is the tactical training console for the project: it pulls released
13
- training corpus splits, builds runtime config from the SOTA curriculum YAML,
14
- executes multi-stage `DeepSeek-Math` fine-tuning, optionally evaluates
15
- self-consistency, and can publish adapters/checkpoints/training summaries to:
16
 
17
- - `NorthernTribe-Research/math-conjecture-model` (when push is enabled)
18
 
19
- ## What this Space does
 
20
 
21
- 1. Downloads dataset parquet splits from:
22
- `NorthernTribe-Research/math-conjecture-training-corpus`
23
- 2. Builds a runtime config from `configs/deepseek_math_sota.yaml`
24
- 3. Runs `scripts/train_sota.py` for staged curriculum training
25
- 4. Optionally runs `scripts/eval_sota.py`
26
- 5. Streams logs and run summary JSON in the UI
27
 
28
- ## Authentication mode
29
 
30
- The app is autonomous and does not require entering an HF token in the UI.
31
- It resolves auth in this order:
32
 
33
- 1. `HF_TOKEN` environment variable
34
- 2. `HUGGINGFACE_HUB_TOKEN` environment variable
35
- 3. `huggingface-api-key.json` (if present)
36
 
37
- If no token is found, training can still run on public datasets, and hub push is
38
- automatically disabled for that run.
39
 
40
- ## Operational controls in UI
 
 
41
 
42
- - `Preflight Only (No Training)`: validates data/config/stage pipeline using
43
- `train_sota.py --dry-run`.
44
- - `Push Adapter to Hub`: controls whether `hub.push_to_hub` is enabled in the
45
- runtime config.
46
- - `Force Dataset Redownload`: bypasses cached local parquet files.
47
- - `Stop Active Run`: requests cancellation and terminates active subprocesses.
48
- - `Run Summary (JSON)`: structured output with config, status, and metrics.
49
 
50
- ## Default training config
51
 
52
- - `configs/deepseek_math_sota.yaml`
53
- - base model default: `deepseek-ai/deepseek-math-v2`
54
- - output root: `workspace/runs/math-conjecture-sota`
55
 
56
  ## Notes
57
 
58
- - Full training expects GPU hardware.
59
- - Runtime config generated by the app is stored at:
60
- `workspace/runtime/deepseek_math_sota.runtime.yaml`.
 
9
 
10
  # Math Conjecture Trainer Space
11
 
12
+ Launch multi-stage DeepSeek-Math fine-tuning on Space GPU and push adapters to your model repo.
 
 
 
13
 
14
+ This Space is the tactical operations console for `maths-conjuncture-solutions` and is wired to:
15
 
16
+ - dataset: `NorthernTribe-Research/math-conjecture-training-corpus`
17
+ - model repo: `NorthernTribe-Research/math-conjecture-model`
18
 
19
+ ## End-to-end flow
20
 
21
+ 1. Download released parquet splits (`train/validation/test`).
22
+ 2. Build runtime config from `configs/deepseek_math_sota.yaml`.
23
+ 3. Run 4-stage curriculum LoRA fine-tuning with `scripts/train_sota.py`.
24
+ 4. Run post-train evaluation (`pass@1`, `pass@k`, exact/boxed, family metrics).
25
+ 5. Apply quality gate thresholds before hub push.
26
+ 6. Emit `training_summary.json` + `post_eval_report.json` and stream live telemetry in UI.
27
 
28
+ ## Autonomous authentication
 
29
 
30
+ No token input is required in the UI.
 
 
31
 
32
+ Resolution order:
 
33
 
34
+ 1. `HF_TOKEN`
35
+ 2. `HUGGINGFACE_HUB_TOKEN`
36
+ 3. `huggingface-api-key.json`
37
 
38
+ If no token is available, public dataset training still works and push is automatically disabled.
39
 
40
+ ## Runtime controls
41
 
42
+ - `Run Evaluation After Training`: toggles post-train eval in runtime config.
43
+ - `Enforce Quality Gate`: enables/disables promotion gate checks.
44
+ - `Gate Min pass@1`, `Gate Min pass@k`, `Gate Min Rows`: runtime gate thresholds.
45
+ - `Preflight Only (No Training)`: validates pipeline with `--dry-run`.
46
+ - `Force Dataset Redownload`: bypasses cached parquet files.
47
+ - `Abort Active Run`: cancels active subprocess tree.
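The `Abort Active Run` control cancels the whole subprocess tree. One common way to implement such a control on POSIX (a sketch of the general technique, not necessarily how `app.py` does it) is to launch each run in its own process group and signal the entire group on cancel:

```python
import os
import signal
import subprocess
import sys


def start_run(cmd: list) -> subprocess.Popen:
    # start_new_session=True puts the child in its own process group,
    # so a later group signal reaches the child and anything it spawned.
    return subprocess.Popen(cmd, start_new_session=True)


def abort_run(proc: subprocess.Popen, timeout: float = 10.0) -> None:
    """Terminate the run's process group, escalating to SIGKILL on timeout."""
    if proc.poll() is not None:
        return  # already finished
    try:
        # With start_new_session=True the child's pgid equals its pid.
        os.killpg(proc.pid, signal.SIGTERM)
    except (ProcessLookupError, PermissionError):
        proc.terminate()
    try:
        proc.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        os.killpg(proc.pid, signal.SIGKILL)
```

On Windows, `os.killpg` is unavailable and a `CREATE_NEW_PROCESS_GROUP` / `taskkill /T` approach would be needed instead.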
48
+
49
+ ## Artifacts
50
+
51
+ - runtime config: `workspace/runtime/deepseek_math_sota.runtime.yaml`
52
+ - run output root: `workspace/runs/math-conjecture-sota`
53
+ - final adapter: `workspace/runs/math-conjecture-sota/final_adapter`
54
+ - training summary: `workspace/runs/math-conjecture-sota/training_summary.json`
55
+ - post-eval report: `workspace/runs/math-conjecture-sota/post_eval_report.json`
56
 
57
  ## Notes
58
 
59
+ - Full training requires GPU hardware.
60
+ - App handles Gradio copy-button compatibility across versions automatically.
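The three-step token resolution order described in the README can be sketched as follows; the assumption that `huggingface-api-key.json` stores the token under a `"token"` key is illustrative, since the file's actual layout is not shown in this diff:

```python
import json
import os
from pathlib import Path
from typing import Optional


def resolve_hf_token(key_file: Path = Path("huggingface-api-key.json")) -> Optional[str]:
    """Resolve an HF token: env vars first, then an optional key file.

    Returns None when nothing is found, in which case the app disables
    hub push for the run and proceeds with public datasets only.
    """
    for env_name in ("HF_TOKEN", "HUGGINGFACE_HUB_TOKEN"):
        token = os.environ.get(env_name, "").strip()
        if token:
            return token
    if key_file.exists():
        try:
            data = json.loads(key_file.read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError):
            return None
        if isinstance(data, dict):
            # The "token" key name is an assumption about the file layout.
            token = str(data.get("token", "")).strip()
            return token or None
    return None
```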
 
app.py CHANGED
@@ -198,15 +198,58 @@ PROJECT_DESCRIPTION = """
198
  # Math Conjecture Trainer
199
  This console runs the full training operations lane for the `maths-conjuncture-solutions` project:
200
 
 
 
201
  1. Pull released parquet splits from `NorthernTribe-Research/math-conjecture-training-corpus`.
202
  2. Build runtime training configuration from `configs/deepseek_math_sota.yaml`.
203
  3. Execute multi-stage DeepSeek-Math curriculum fine-tuning via `scripts/train_sota.py`.
204
- 4. Optionally evaluate adapters with pass@k-style sampling via `scripts/eval_sota.py`.
205
- 5. Auto-resolve Hugging Face credentials, push adapters/checkpoints/summary when allowed, and stream live logs.
206
- 6. Support preflight validation, abort control, cache strategy, and structured run-summary telemetry in one UI.
207
  """
208
 
209
 
210
  def now_ts() -> str:
211
  return dt.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC")
212
 
@@ -390,7 +433,15 @@ def write_runtime_config(
390
  model_repo_id: str,
391
  train_file: str,
392
  validation_file: str,
393
  push_to_hub: bool,
394
  ) -> Path:
395
  cfg = yaml.safe_load(CONFIG_TEMPLATE.read_text(encoding="utf-8"))
396
  cfg["model"]["base_model"] = base_model_id
@@ -399,6 +450,21 @@ def write_runtime_config(
399
  cfg["data"]["default_train_file"] = train_file
400
  cfg["data"]["default_validation_file"] = validation_file
401
  cfg["global"]["output_root"] = str(TRAIN_OUTPUT_DIR)
402
  runtime_path = RUNTIME_DIR / "deepseek_math_sota.runtime.yaml"
403
  runtime_path.write_text(yaml.safe_dump(cfg, sort_keys=False), encoding="utf-8")
404
  return runtime_path
@@ -511,6 +577,10 @@ def run_pipeline(
511
  run_eval: bool,
512
  eval_k: int,
513
  eval_samples: int,
514
  push_to_hub: bool,
515
  force_redownload: bool,
516
  preflight_only: bool,
@@ -542,14 +612,25 @@ def run_pipeline(
542
  stage_count = int(max_stages)
543
  eval_k = int(eval_k)
544
  eval_samples = int(eval_samples)
 
 
 
545
  if stage_start < 1:
546
  raise ValueError("Start stage must be >= 1.")
 
 
547
  if stage_count < 1:
548
  raise ValueError("How many stages must be >= 1.")
549
  if eval_k < 1:
550
  raise ValueError("Eval K must be >= 1.")
551
  if eval_samples < 1:
552
  raise ValueError("Eval max samples must be >= 1.")
553
 
554
  for required_path in (CONFIG_TEMPLATE, TRAIN_SCRIPT):
555
  if not required_path.exists():
@@ -570,6 +651,10 @@ def run_pipeline(
570
  "run_eval": bool(run_eval),
571
  "eval_k": eval_k,
572
  "eval_samples": eval_samples,
573
  "push_to_hub": bool(push_to_hub),
574
  "force_redownload": bool(force_redownload),
575
  "preflight_only": bool(preflight_only),
@@ -633,7 +718,15 @@ def run_pipeline(
633
  model_repo_id=model_repo_id,
634
  train_file=train_file,
635
  validation_file=validation_file,
636
  push_to_hub=effective_push_to_hub,
637
  )
638
  summary["runtime_config"] = str(runtime_cfg)
639
  append_log(log_lines, f"Wrote runtime config: {runtime_cfg}")
@@ -701,15 +794,50 @@ def run_pipeline(
701
  return
702
 
703
  training_summary_path = TRAIN_OUTPUT_DIR / "training_summary.json"
 
704
  if training_summary_path.exists():
705
  try:
706
  summary["training_summary_path"] = str(training_summary_path)
707
- summary["training_summary"] = json.loads(training_summary_path.read_text(encoding="utf-8"))
708
  except json.JSONDecodeError:
709
  summary["training_summary_path"] = str(training_summary_path)
710
  summary["training_summary"] = {"warning": "Unable to parse training summary JSON."}
711
 
712
- if run_eval:
713
  eval_report = WORKSPACE_DIR / "runs" / "latest_eval_report.json"
714
  eval_cmd = [
715
  sys.executable,
@@ -763,9 +891,12 @@ def run_pipeline(
763
  if eval_report.exists():
764
  report = json.loads(eval_report.read_text(encoding="utf-8"))
765
  summary["evaluation"] = {
 
766
  "evaluated_rows": report.get("evaluated_rows"),
767
  "pass_at_1": report.get("pass_at_1"),
768
  "pass_at_k": report.get("pass_at_k"),
 
 
769
  "k": report.get("k"),
770
  "report_path": str(eval_report),
771
  }
@@ -807,12 +938,41 @@ with gr.Blocks(title="Math Conjecture Trainer Space") as demo:
807
  value="deepseek-ai/deepseek-math-v2",
808
  )
809
  with gr.Row():
810
- start_stage = gr.Slider(label="Stage Start", minimum=1, maximum=3, step=1, value=1)
811
- max_stages = gr.Slider(label="Stage Count", minimum=1, maximum=3, step=1, value=3)
812
  run_eval = gr.Checkbox(label="Run Evaluation After Training", value=True)
813
  with gr.Row():
814
  eval_k = gr.Slider(label="Evaluation K", minimum=1, maximum=8, step=1, value=4)
815
  eval_samples = gr.Slider(label="Evaluation Max Samples", minimum=50, maximum=1000, step=50, value=300)
816
  with gr.Row():
817
  push_to_hub = gr.Checkbox(label="Push Adapter to Hub", value=True)
818
  force_redownload = gr.Checkbox(label="Force Dataset Redownload", value=False)
@@ -843,6 +1003,10 @@ with gr.Blocks(title="Math Conjecture Trainer Space") as demo:
843
  run_eval,
844
  eval_k,
845
  eval_samples,
846
  push_to_hub,
847
  force_redownload,
848
  preflight_only,
 
198
  # Math Conjecture Trainer
199
  This console runs the full training operations lane for the `maths-conjuncture-solutions` project:
200
 
201
+ Launch multi-stage DeepSeek-Math fine-tuning on Space GPU and push adapters to your model repo.
202
+
203
  1. Pull released parquet splits from `NorthernTribe-Research/math-conjecture-training-corpus`.
204
  2. Build runtime training configuration from `configs/deepseek_math_sota.yaml`.
205
  3. Execute multi-stage DeepSeek-Math curriculum fine-tuning via `scripts/train_sota.py`.
206
+ 4. Run post-training evaluation with pass@k-style sampling and family-level metrics.
207
+ 5. Enforce autonomous quality gates before adapter promotion/push.
208
+ 6. Auto-resolve Hugging Face credentials, stream live telemetry, and emit structured run summaries.
209
  """
210
 
211
 
212
+ def _safe_float(value: Any, default: float) -> float:
213
+ try:
214
+ return float(value)
215
+ except (TypeError, ValueError):
216
+ return default
217
+
218
+
219
+ def _safe_int(value: Any, default: int) -> int:
220
+ try:
221
+ return int(value)
222
+ except (TypeError, ValueError):
223
+ return default
224
+
225
+
226
+ def load_template_defaults() -> Dict[str, Any]:
227
+ if not CONFIG_TEMPLATE.exists():
228
+ return {}
229
+ try:
230
+ cfg = yaml.safe_load(CONFIG_TEMPLATE.read_text(encoding="utf-8"))
231
+ except Exception:
232
+ return {}
233
+ if not isinstance(cfg, dict):
234
+ return {}
235
+ return cfg
236
+
237
+
238
+ TEMPLATE_CFG = load_template_defaults()
239
+ TEMPLATE_STAGE_COUNT = max(1, len(TEMPLATE_CFG.get("stages", []) or [None]))
240
+ TEMPLATE_QUALITY_GATE = TEMPLATE_CFG.get("quality_gate", {})
241
+ if not isinstance(TEMPLATE_QUALITY_GATE, dict):
242
+ TEMPLATE_QUALITY_GATE = {}
243
+ _raw_gate_enabled = TEMPLATE_QUALITY_GATE.get("enabled", True)
244
+ if isinstance(_raw_gate_enabled, bool):
245
+ DEFAULT_GATE_ENABLED = _raw_gate_enabled
246
+ else:
247
+ DEFAULT_GATE_ENABLED = str(_raw_gate_enabled).strip().lower() in {"1", "true", "yes", "y", "on"}
248
+ DEFAULT_GATE_MIN_ROWS = max(1, _safe_int(TEMPLATE_QUALITY_GATE.get("min_evaluated_rows"), 120))
249
+ DEFAULT_GATE_MIN_PASS_AT_1 = max(0.0, _safe_float(TEMPLATE_QUALITY_GATE.get("min_pass_at_1"), 0.01))
250
+ DEFAULT_GATE_MIN_PASS_AT_K = max(0.0, _safe_float(TEMPLATE_QUALITY_GATE.get("min_pass_at_k"), 0.06))
251
+
252
+
253
  def now_ts() -> str:
254
  return dt.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC")
255
 
 
433
  model_repo_id: str,
434
  train_file: str,
435
  validation_file: str,
436
+ test_file: str,
437
+ run_eval: bool,
438
+ eval_k: int,
439
+ eval_samples: int,
440
  push_to_hub: bool,
441
+ enforce_quality_gate: bool,
442
+ gate_min_pass_at_1: float,
443
+ gate_min_pass_at_k: float,
444
+ gate_min_rows: int,
445
  ) -> Path:
446
  cfg = yaml.safe_load(CONFIG_TEMPLATE.read_text(encoding="utf-8"))
447
  cfg["model"]["base_model"] = base_model_id
 
450
  cfg["data"]["default_train_file"] = train_file
451
  cfg["data"]["default_validation_file"] = validation_file
452
  cfg["global"]["output_root"] = str(TRAIN_OUTPUT_DIR)
453
+
454
+ cfg.setdefault("post_eval", {})
455
+ cfg["post_eval"]["enabled"] = bool(run_eval)
456
+ cfg["post_eval"]["eval_file"] = test_file
457
+ cfg["post_eval"]["k"] = int(eval_k)
458
+ cfg["post_eval"]["max_samples"] = int(eval_samples)
459
+ cfg["post_eval"]["output_json"] = str(TRAIN_OUTPUT_DIR / "post_eval_report.json")
460
+
461
+ cfg.setdefault("quality_gate", {})
462
+ cfg["quality_gate"]["enabled"] = bool(enforce_quality_gate)
463
+ cfg["quality_gate"]["min_evaluated_rows"] = int(gate_min_rows)
464
+ cfg["quality_gate"]["min_pass_at_1"] = float(gate_min_pass_at_1)
465
+ cfg["quality_gate"]["min_pass_at_k"] = float(gate_min_pass_at_k)
466
+ cfg["quality_gate"]["require_post_eval"] = bool(enforce_quality_gate and run_eval)
467
+
468
  runtime_path = RUNTIME_DIR / "deepseek_math_sota.runtime.yaml"
469
  runtime_path.write_text(yaml.safe_dump(cfg, sort_keys=False), encoding="utf-8")
470
  return runtime_path
 
577
  run_eval: bool,
578
  eval_k: int,
579
  eval_samples: int,
580
+ enforce_quality_gate: bool,
581
+ gate_min_pass_at_1: float,
582
+ gate_min_pass_at_k: float,
583
+ gate_min_rows: int,
584
  push_to_hub: bool,
585
  force_redownload: bool,
586
  preflight_only: bool,
 
612
  stage_count = int(max_stages)
613
  eval_k = int(eval_k)
614
  eval_samples = int(eval_samples)
615
+ gate_min_rows = int(gate_min_rows)
616
+ gate_min_pass_at_1 = float(gate_min_pass_at_1)
617
+ gate_min_pass_at_k = float(gate_min_pass_at_k)
618
  if stage_start < 1:
619
  raise ValueError("Start stage must be >= 1.")
620
+ if stage_start > TEMPLATE_STAGE_COUNT:
621
+ raise ValueError(f"Start stage must be <= {TEMPLATE_STAGE_COUNT}.")
622
  if stage_count < 1:
623
  raise ValueError("How many stages must be >= 1.")
624
  if eval_k < 1:
625
  raise ValueError("Eval K must be >= 1.")
626
  if eval_samples < 1:
627
  raise ValueError("Eval max samples must be >= 1.")
628
+ if gate_min_rows < 1:
629
+ raise ValueError("Gate minimum rows must be >= 1.")
630
+ if not 0.0 <= gate_min_pass_at_1 <= 1.0:
631
+ raise ValueError("Gate min pass@1 must be between 0 and 1.")
632
+ if not 0.0 <= gate_min_pass_at_k <= 1.0:
633
+ raise ValueError("Gate min pass@k must be between 0 and 1.")
634
 
635
  for required_path in (CONFIG_TEMPLATE, TRAIN_SCRIPT):
636
  if not required_path.exists():
 
651
  "run_eval": bool(run_eval),
652
  "eval_k": eval_k,
653
  "eval_samples": eval_samples,
654
+ "enforce_quality_gate": bool(enforce_quality_gate),
655
+ "gate_min_rows": gate_min_rows,
656
+ "gate_min_pass_at_1": gate_min_pass_at_1,
657
+ "gate_min_pass_at_k": gate_min_pass_at_k,
658
  "push_to_hub": bool(push_to_hub),
659
  "force_redownload": bool(force_redownload),
660
  "preflight_only": bool(preflight_only),
 
718
  model_repo_id=model_repo_id,
719
  train_file=train_file,
720
  validation_file=validation_file,
721
+ test_file=test_file,
722
+ run_eval=bool(run_eval),
723
+ eval_k=eval_k,
724
+ eval_samples=eval_samples,
725
  push_to_hub=effective_push_to_hub,
726
+ enforce_quality_gate=bool(enforce_quality_gate),
727
+ gate_min_pass_at_1=gate_min_pass_at_1,
728
+ gate_min_pass_at_k=gate_min_pass_at_k,
729
+ gate_min_rows=gate_min_rows,
730
  )
731
  summary["runtime_config"] = str(runtime_cfg)
732
  append_log(log_lines, f"Wrote runtime config: {runtime_cfg}")
 
794
  return
795
 
796
  training_summary_path = TRAIN_OUTPUT_DIR / "training_summary.json"
797
+ training_summary: Optional[Dict[str, Any]] = None
798
  if training_summary_path.exists():
799
  try:
800
  summary["training_summary_path"] = str(training_summary_path)
801
+ loaded_summary = json.loads(training_summary_path.read_text(encoding="utf-8"))
802
+ if isinstance(loaded_summary, dict):
803
+ training_summary = loaded_summary
804
+ summary["training_summary"] = loaded_summary
805
+ else:
806
+ summary["training_summary"] = {"warning": "Training summary JSON is not an object."}
807
  except json.JSONDecodeError:
808
  summary["training_summary_path"] = str(training_summary_path)
809
  summary["training_summary"] = {"warning": "Unable to parse training summary JSON."}
810
 
811
+ if isinstance(training_summary, dict):
812
+ quality_gate = training_summary.get("quality_gate")
813
+ if isinstance(quality_gate, dict):
814
+ summary["quality_gate"] = quality_gate
815
+ append_log(
816
+ log_lines,
817
+ f"Quality gate: passed={quality_gate.get('passed')} enabled={quality_gate.get('enabled')}",
818
+ )
819
+ push_report = training_summary.get("push")
820
+ if isinstance(push_report, dict):
821
+ summary["push"] = push_report
822
+ append_log(
823
+ log_lines,
824
+ f"Push decision: requested={push_report.get('requested')} performed={push_report.get('performed')}",
825
+ )
826
+ post_eval_report = training_summary.get("post_eval")
827
+ if run_eval and isinstance(post_eval_report, dict):
828
+ summary["evaluation"] = {
829
+ "source": "train_post_eval",
830
+ "evaluated_rows": post_eval_report.get("evaluated_rows"),
831
+ "pass_at_1": post_eval_report.get("pass_at_1"),
832
+ "pass_at_k": post_eval_report.get("pass_at_k"),
833
+ "exact_at_k": post_eval_report.get("exact_at_k"),
834
+ "composite_score": post_eval_report.get("composite_score"),
835
+ "k": post_eval_report.get("k"),
836
+ "report_path": post_eval_report.get("report_path"),
837
+ }
838
+ append_log(log_lines, "Using post-eval metrics emitted by training run.")
839
+
840
+ if run_eval and "evaluation" not in summary:
841
  eval_report = WORKSPACE_DIR / "runs" / "latest_eval_report.json"
842
  eval_cmd = [
843
  sys.executable,
 
891
  if eval_report.exists():
892
  report = json.loads(eval_report.read_text(encoding="utf-8"))
893
  summary["evaluation"] = {
894
+ "source": "fallback_eval",
895
  "evaluated_rows": report.get("evaluated_rows"),
896
  "pass_at_1": report.get("pass_at_1"),
897
  "pass_at_k": report.get("pass_at_k"),
898
+ "exact_at_k": report.get("exact_at_k"),
899
+ "composite_score": report.get("composite_score"),
900
  "k": report.get("k"),
901
  "report_path": str(eval_report),
902
  }
 
938
  value="deepseek-ai/deepseek-math-v2",
939
  )
940
  with gr.Row():
941
+ start_stage = gr.Slider(label="Stage Start", minimum=1, maximum=TEMPLATE_STAGE_COUNT, step=1, value=1)
942
+ max_stages = gr.Slider(
943
+ label="Stage Count",
944
+ minimum=1,
945
+ maximum=TEMPLATE_STAGE_COUNT,
946
+ step=1,
947
+ value=TEMPLATE_STAGE_COUNT,
948
+ )
949
  run_eval = gr.Checkbox(label="Run Evaluation After Training", value=True)
950
  with gr.Row():
951
  eval_k = gr.Slider(label="Evaluation K", minimum=1, maximum=8, step=1, value=4)
952
  eval_samples = gr.Slider(label="Evaluation Max Samples", minimum=50, maximum=1000, step=50, value=300)
953
+ with gr.Row():
954
+ enforce_quality_gate = gr.Checkbox(label="Enforce Quality Gate", value=DEFAULT_GATE_ENABLED)
955
+ gate_min_pass_at_1 = gr.Slider(
956
+ label="Gate Min pass@1",
957
+ minimum=0.0,
958
+ maximum=0.5,
959
+ step=0.005,
960
+ value=min(max(DEFAULT_GATE_MIN_PASS_AT_1, 0.0), 0.5),
961
+ )
962
+ gate_min_pass_at_k = gr.Slider(
963
+ label="Gate Min pass@k",
964
+ minimum=0.0,
965
+ maximum=1.0,
966
+ step=0.01,
967
+ value=min(max(DEFAULT_GATE_MIN_PASS_AT_K, 0.0), 1.0),
968
+ )
969
+ gate_min_rows = gr.Slider(
970
+ label="Gate Min Rows",
971
+ minimum=10,
972
+ maximum=2000,
973
+ step=10,
974
+ value=min(max(DEFAULT_GATE_MIN_ROWS, 10), 2000),
975
+ )
976
  with gr.Row():
977
  push_to_hub = gr.Checkbox(label="Push Adapter to Hub", value=True)
978
  force_redownload = gr.Checkbox(label="Force Dataset Redownload", value=False)
 
1003
  run_eval,
1004
  eval_k,
1005
  eval_samples,
1006
+ enforce_quality_gate,
1007
+ gate_min_pass_at_1,
1008
+ gate_min_pass_at_k,
1009
+ gate_min_rows,
1010
  push_to_hub,
1011
  force_redownload,
1012
  preflight_only,
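The gate thresholds wired through the UI above feed a promotion decision before any hub push. A minimal sketch of how such a check could work — the field names match the runtime config in this diff, but the decision logic itself is illustrative, not the exact `train_sota.py` implementation:

```python
from typing import Any, Dict, List, Tuple


def check_quality_gate(gate: Dict[str, Any], report: Dict[str, Any]) -> Tuple[bool, List[str]]:
    """Return (passed, failures) for a post-eval report checked against gate thresholds."""
    if not gate.get("enabled", True):
        return True, []
    failures: List[str] = []
    # (report metric, gate threshold key, default threshold)
    checks = (
        ("evaluated_rows", "min_evaluated_rows", 1),
        ("pass_at_1", "min_pass_at_1", 0.0),
        ("pass_at_k", "min_pass_at_k", 0.0),
    )
    for metric, threshold_key, default in checks:
        threshold = gate.get(threshold_key, default)
        value = report.get(metric)
        if value is None or value < threshold:
            failures.append(f"{metric}={value!r} below {threshold_key}={threshold}")
    return not failures, failures
```

A failing gate would then force `push_to_hub` off for the run, which matches the "quality gate thresholds before hub push" step in the README flow.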
configs/deepseek_math_sota.yaml CHANGED
@@ -97,17 +97,55 @@ stages:
97
  - conjecture_core
98
  require_conjecture_id: true
99
  training:
100
- num_train_epochs: 3
101
  learning_rate: 5.0e-6
102
  save_steps: 100
103
  eval_steps: 100
104
 
105
  hub:
106
  push_to_hub: true
107
  repo_id: NorthernTribe-Research/math-conjecture-model
108
  private: false
109
  upload_stage_checkpoints: true
110
- commit_message: Train multi-stage SOTA curriculum for conjecture reasoning.
111
 
112
  credentials:
113
  path: huggingface-api-key.json
 
97
  - conjecture_core
98
  require_conjecture_id: true
99
  training:
100
+ num_train_epochs: 2
101
  learning_rate: 5.0e-6
102
  save_steps: 100
103
  eval_steps: 100
104
 
105
+ - name: hard_case_polish
106
+ max_train_samples: 60000
107
+ max_eval_samples: 2000
108
+ filters:
109
+ include_families:
110
+ - conjecture_core
111
+ - formal_proof
112
+ require_conjecture_id: true
113
+ min_sample_weight: 3.0
114
+ training:
115
+ num_train_epochs: 1
116
+ learning_rate: 3.0e-6
117
+ gradient_accumulation_steps: 24
118
+ save_steps: 80
119
+ eval_steps: 80
120
+
121
+ post_eval:
122
+ enabled: true
123
+ eval_file: workspace/data/releases/v1/test.parquet
124
+ max_samples: 240
125
+ k: 6
126
+ max_new_tokens: 320
127
+ temperature: 0.7
128
+ top_p: 0.95
129
+ seed: 17
130
+ output_json: workspace/runs/math-conjecture-sota/post_eval_report.json
131
+
132
+ quality_gate:
133
+ enabled: true
134
+ require_post_eval: true
135
+ min_evaluated_rows: 120
136
+ min_pass_at_1: 0.01
137
+ min_pass_at_k: 0.06
138
+ max_final_eval_loss: 2.6
139
+ required_family_pass_at_k:
140
+ conjecture_core: 0.06
141
+ formal_proof: 0.03
142
+
143
  hub:
144
  push_to_hub: true
145
  repo_id: NorthernTribe-Research/math-conjecture-model
146
  private: false
147
  upload_stage_checkpoints: true
148
+ commit_message: Launch multi-stage DeepSeek-Math fine-tuning on Space GPU and push adapters to your model repo.
149
 
150
  credentials:
151
  path: huggingface-api-key.json
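The `min_pass_at_1` / `min_pass_at_k` thresholds in the `quality_gate` block above compare against the empirical rates that `scripts/eval_sota.py` computes: for each prompt, `k` generations are sampled, pass@1 counts a hit on the first sample, and pass@k a hit on any sample. A minimal sketch of that aggregation:

```python
from typing import List, Tuple


def pass_rates(matches_per_row: List[List[bool]]) -> Tuple[float, float]:
    """Empirical pass@1 / pass@k over per-row match lists.

    matches_per_row[i][j] is True when sample j for prompt i matched an
    expected answer; sample 0 is the first generation for that prompt.
    """
    total = len(matches_per_row)
    if total == 0:
        return 0.0, 0.0
    hit_at_1 = sum(1 for m in matches_per_row if m and m[0])
    hit_at_k = sum(1 for m in matches_per_row if any(m))
    return hit_at_1 / total, hit_at_k / total
```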
scripts/eval_sota.py CHANGED
@@ -7,7 +7,7 @@ import argparse
7
  import json
8
  import re
9
  from pathlib import Path
10
- from typing import Any, Dict, List, Optional, Sequence
11
 
12
  import torch
13
  import yaml
@@ -15,13 +15,20 @@ from datasets import load_dataset
15
  from peft import PeftModel
16
  from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
17
 
 
19
  def parse_args() -> argparse.Namespace:
20
  parser = argparse.ArgumentParser(description="Run pass@k-style evaluation on held-out split.")
21
  parser.add_argument(
22
  "--config",
23
  type=Path,
24
- default=Path("configs/deepseek_math_sota.yaml"),
25
  help="Training config used for prompt formatting defaults.",
26
  )
27
  parser.add_argument(
@@ -39,19 +46,32 @@ def parse_args() -> argparse.Namespace:
39
  parser.add_argument(
40
  "--eval-file",
41
  type=Path,
42
- default=Path("data/releases/v1/test.parquet"),
43
- help="Parquet split used for evaluation.",
44
  )
45
  parser.add_argument("--max-samples", type=int, default=300, help="Maximum evaluation rows.")
46
  parser.add_argument("--k", type=int, default=4, help="Number of sampled generations per prompt.")
47
  parser.add_argument("--max-new-tokens", type=int, default=256, help="Generation length cap.")
 
48
  parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature.")
49
  parser.add_argument("--top-p", type=float, default=0.95, help="Nucleus sampling p.")
50
  parser.add_argument("--seed", type=int, default=17, help="Random seed.")
 
51
  parser.add_argument(
52
  "--output-json",
53
  type=Path,
54
- default=Path("workspace/runs/latest_eval_report.json"),
55
  help="Where to write evaluation report.",
56
  )
57
  return parser.parse_args()
@@ -65,6 +85,24 @@ def as_text(value: Any) -> str:
65
  return str(value).strip()
66
 
67
 
68
  def load_config(path: Path) -> Dict[str, Any]:
69
  cfg = yaml.safe_load(path.read_text(encoding="utf-8"))
70
  if not isinstance(cfg, dict):
@@ -74,9 +112,124 @@ def load_config(path: Path) -> Dict[str, Any]:
74
 
75
  def normalize_answer(text: str) -> str:
76
  text = text.strip().lower()
77
- text = re.sub(r"\s+", " ", text)
78
  text = text.replace("$", "")
79
- return text
80
 
81
 
82
  def flatten_expected(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> List[str]:
@@ -168,27 +321,11 @@ def extract_candidate_text(full_generation: str, prompt_text: str) -> str:
168
  return full_generation.strip()
169
 
170
 
171
- def is_match(candidate: str, expected_values: Sequence[str]) -> bool:
172
- cand_norm = normalize_answer(candidate)
173
- if not cand_norm:
174
- return False
175
- for expected in expected_values:
176
- exp_norm = normalize_answer(expected)
177
- if not exp_norm:
178
- continue
179
- if exp_norm in cand_norm or cand_norm in exp_norm:
180
- return True
181
- boxed = re.findall(r"\\boxed\{([^{}]+)\}", cand_norm)
182
- if boxed and any(exp_norm in item for item in boxed):
183
- return True
184
- return False
185
-
186
-
187
  def load_model_and_tokenizer(
188
  base_model: str,
189
  adapter_path: Optional[Path],
190
  trust_remote_code: bool,
191
- ) -> tuple[Any, AutoTokenizer]:
192
  tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=trust_remote_code, use_fast=True)
193
  if tokenizer.pad_token is None:
194
  tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
@@ -207,55 +344,126 @@ def load_model_and_tokenizer(
207
  return model, tokenizer
208
 
209
 
210
- def main() -> None:
211
- args = parse_args()
212
- cfg = load_config(args.config)
213
  data_cfg = cfg.get("data", {})
214
- model_cfg = cfg.get("model", {})
215
- set_seed(args.seed)
216
 
 
 
217
  if args.k < 1:
218
  raise ValueError("--k must be >= 1.")
219
  if args.max_samples < 1:
220
  raise ValueError("--max-samples must be >= 1.")
221
  if args.max_new_tokens < 1:
222
  raise ValueError("--max-new-tokens must be >= 1.")
 
 
223
  if args.temperature <= 0:
224
  raise ValueError("--temperature must be > 0.")
225
  if not 0 < args.top_p <= 1:
226
  raise ValueError("--top-p must be in (0, 1].")
227
228
  base_model = args.base_model or as_text(model_cfg.get("base_model"))
229
  if not base_model:
230
  raise ValueError("Base model is required via --base-model or config.model.base_model.")
231
  if args.adapter_path is not None and not args.adapter_path.exists():
232
  raise FileNotFoundError(f"Adapter path not found: {args.adapter_path}")
233
234
  model, tokenizer = load_model_and_tokenizer(
235
  base_model=base_model,
236
  adapter_path=args.adapter_path,
237
  trust_remote_code=bool(model_cfg.get("trust_remote_code", False)),
238
  )
239
 
240
- if not args.eval_file.exists():
241
- raise FileNotFoundError(f"Evaluation file not found: {args.eval_file}")
242
- ds = load_dataset("parquet", data_files={"eval": str(args.eval_file)})["eval"]
243
-
244
  if args.max_samples > 0 and args.max_samples < len(ds):
245
  ds = ds.select(range(args.max_samples))
246
 
247
- total = 0
248
- hit_at_1 = 0
249
- hit_at_k = 0
250
- records = []
251
 
252
  for row in ds:
253
  expected_values = flatten_expected(row, data_cfg)
254
  if not expected_values:
 
255
  continue
 
256
  prompt_text = build_prompt_text(row, tokenizer, data_cfg)
257
- inputs = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=4096)
258
- model_device = next(model.parameters()).device
259
  inputs = {k: v.to(model_device) for k, v in inputs.items()}
260
 
261
  with torch.no_grad():
@@ -269,44 +477,119 @@ def main() -> None:
269
  pad_token_id=tokenizer.pad_token_id,
270
  eos_token_id=tokenizer.eos_token_id,
271
  )
 
272
  generations = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
273
  candidates = [extract_candidate_text(text, prompt_text) for text in generations]
274
- matches = [is_match(candidate, expected_values) for candidate in candidates]
275
- total += 1
276
- if matches and matches[0]:
277
- hit_at_1 += 1
278
- if any(matches):
279
- hit_at_k += 1
280
-
281
- records.append(
282
- {
283
- "uid": as_text(row.get("uid")),
284
- "prompt": as_text(row.get(as_text(data_cfg.get("prompt_field")) or "prompt")),
285
- "expected_values": expected_values[:5],
286
- "candidates": candidates,
287
- "matches": matches,
288
- }
 
289
  )
290
 
291
- pass_at_1 = (hit_at_1 / total) if total else 0.0
292
- pass_at_k = (hit_at_k / total) if total else 0.0
293
- report = {
294
  "base_model": base_model,
295
  "adapter_path": str(args.adapter_path) if args.adapter_path is not None else None,
296
- "eval_file": str(args.eval_file),
297
- "evaluated_rows": total,
 
 
 
298
  "k": args.k,
299
  "pass_at_1": pass_at_1,
300
  "pass_at_k": pass_at_k,
301
  "temperature": args.temperature,
302
  "top_p": args.top_p,
303
  "max_new_tokens": args.max_new_tokens,
304
- "samples": records[:30],
305
  }
 
306
  args.output_json.parent.mkdir(parents=True, exist_ok=True)
307
  args.output_json.write_text(json.dumps(report, ensure_ascii=True, indent=2), encoding="utf-8")
308
- print(json.dumps({k: report[k] for k in ("evaluated_rows", "pass_at_1", "pass_at_k", "k")}, indent=2))
309
  print(f"Saved report to {args.output_json}")
310
 
311
 
312
  if __name__ == "__main__":
 
  import json
  import re
  from pathlib import Path
+ from typing import Any, Dict, List, Optional, Sequence, Tuple

  import torch
  import yaml

  from peft import PeftModel
  from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

+ SCRIPT_ROOT = Path(__file__).resolve().parents[1]
+ DEFAULT_CONFIG_PATH = SCRIPT_ROOT / "configs" / "deepseek_math_sota.yaml"
+ DEFAULT_OUTPUT_JSON = SCRIPT_ROOT / "runs" / "latest_eval_report.json"
+
+ BOXED_RE = re.compile(r"\\boxed\{([^{}]+)\}")
+ SPACE_RE = re.compile(r"\s+")
+

  def parse_args() -> argparse.Namespace:
      parser = argparse.ArgumentParser(description="Run pass@k-style evaluation on held-out split.")
      parser.add_argument(
          "--config",
          type=Path,
+         default=DEFAULT_CONFIG_PATH,
          help="Training config used for prompt formatting defaults.",
      )
      parser.add_argument(

      parser.add_argument(
          "--eval-file",
          type=Path,
+         default=None,
+         help="Parquet split used for evaluation (defaults to post_eval.eval_file or data.default_validation_file).",
      )
      parser.add_argument("--max-samples", type=int, default=300, help="Maximum evaluation rows.")
      parser.add_argument("--k", type=int, default=4, help="Number of sampled generations per prompt.")
      parser.add_argument("--max-new-tokens", type=int, default=256, help="Generation length cap.")
+     parser.add_argument("--max-input-length", type=int, default=4096, help="Prompt tokenization length cap.")
      parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature.")
      parser.add_argument("--top-p", type=float, default=0.95, help="Nucleus sampling p.")
      parser.add_argument("--seed", type=int, default=17, help="Random seed.")
+     parser.add_argument(
+         "--progress-every",
+         type=int,
+         default=25,
+         help="Print progress every N evaluated rows (0 disables).",
+     )
+     parser.add_argument(
+         "--sample-records",
+         type=int,
+         default=30,
+         help="How many sample records to store in report.",
+     )
      parser.add_argument(
          "--output-json",
          type=Path,
+         default=DEFAULT_OUTPUT_JSON,
          help="Where to write evaluation report.",
      )
      return parser.parse_args()
 
      return str(value).strip()


+ def as_float(value: Any, default: float) -> float:
+     if value is None:
+         return default
+     try:
+         return float(value)
+     except (TypeError, ValueError):
+         return default
+
+
+ def as_int(value: Any, default: int) -> int:
+     if value is None:
+         return default
+     try:
+         return int(value)
+     except (TypeError, ValueError):
+         return default
+
+
  def load_config(path: Path) -> Dict[str, Any]:
      cfg = yaml.safe_load(path.read_text(encoding="utf-8"))
      if not isinstance(cfg, dict):

  def normalize_answer(text: str) -> str:
      text = text.strip().lower()
      text = text.replace("$", "")
+     text = text.replace("\\left", "").replace("\\right", "")
+     text = text.replace("\\,", "").replace("\\!", "").replace("\\;", "")
+     text = SPACE_RE.sub(" ", text)
+     return text.strip(" .")
+
+
+ def extract_boxed_values(text: str) -> List[str]:
+     return [normalize_answer(match) for match in BOXED_RE.findall(text or "") if normalize_answer(match)]
+
+
+ def parse_numeric_value(text: str) -> Optional[float]:
+     normalized = normalize_answer(text)
+     if not normalized:
+         return None
+     candidate = normalized.replace(",", "")
+     if re.fullmatch(r"[-+]?\d+\s*/\s*[-+]?\d+", candidate):
+         left, right = candidate.split("/", maxsplit=1)
+         try:
+             numerator = float(left.strip())
+             denominator = float(right.strip())
+         except ValueError:
+             return None
+         if denominator == 0:
+             return None
+         return numerator / denominator
+     if re.fullmatch(r"[-+]?(?:\d+\.\d*|\d*\.\d+|\d+)(?:[eE][-+]?\d+)?", candidate):
+         try:
+             return float(candidate)
+         except ValueError:
+             return None
+     return None
+
+
+ def approximately_equal(left: float, right: float) -> bool:
+     tolerance = 1e-6 * max(1.0, abs(left), abs(right))
+     return abs(left - right) <= tolerance
+
+
+ def match_candidate(candidate: str, expected_values: Sequence[str]) -> Dict[str, Any]:
+     cand_norm = normalize_answer(candidate)
+     if not cand_norm:
+         return {
+             "match": False,
+             "exact": False,
+             "boxed": False,
+             "numeric": False,
+             "reason": "empty_candidate",
+         }
+
+     cand_boxed = extract_boxed_values(candidate)
+     cand_num = parse_numeric_value(cand_norm)
+
+     substring_hit = False
+     boxed_hit = False
+     numeric_hit = False
+
+     for expected in expected_values:
+         exp_norm = normalize_answer(expected)
+         if not exp_norm:
+             continue
+
+         if cand_norm == exp_norm:
+             return {
+                 "match": True,
+                 "exact": True,
+                 "boxed": exp_norm in cand_boxed,
+                 "numeric": False,
+                 "reason": "exact",
+             }
+
+         if exp_norm in cand_norm or cand_norm in exp_norm:
+             substring_hit = True
+
+         expected_boxed = extract_boxed_values(expected)
+         for cand_box in cand_boxed:
+             if cand_box == exp_norm or exp_norm in cand_box or cand_box in exp_norm:
+                 boxed_hit = True
+         for exp_box in expected_boxed:
+             if cand_norm == exp_box or exp_box in cand_norm or cand_norm in exp_box:
+                 boxed_hit = True
+
+         exp_num = parse_numeric_value(exp_norm)
+         if cand_num is not None and exp_num is not None and approximately_equal(cand_num, exp_num):
+             numeric_hit = True
+
+     if boxed_hit:
+         return {
+             "match": True,
+             "exact": False,
+             "boxed": True,
+             "numeric": numeric_hit,
+             "reason": "boxed",
+         }
+     if numeric_hit:
+         return {
+             "match": True,
+             "exact": False,
+             "boxed": False,
+             "numeric": True,
+             "reason": "numeric",
+         }
+     if substring_hit:
+         return {
+             "match": True,
+             "exact": False,
+             "boxed": False,
+             "numeric": False,
+             "reason": "substring",
+         }
+
+     return {
+         "match": False,
+         "exact": False,
+         "boxed": False,
+         "numeric": False,
+         "reason": "no_match",
+     }
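The new `match_candidate` resolves a candidate against the expected answers in a fixed priority order: exact normalized equality, then `\boxed{}` agreement, then numeric closeness, then substring overlap. A condensed, self-contained sketch of that ordering (function names here are illustrative, not the module's API):

```python
import re

BOXED_RE = re.compile(r"\\boxed\{([^{}]+)\}")


def normalize(text: str) -> str:
    # Simplified analogue of normalize_answer: lowercase, drop $, collapse spacing.
    text = text.strip().lower().replace("$", "")
    return re.sub(r"\s+", " ", text).strip(" .")


def match_reason(candidate: str, expected: str) -> str:
    # Same priority order as the cascade above: exact > boxed > numeric > substring.
    cand, exp = normalize(candidate), normalize(expected)
    if cand == exp:
        return "exact"
    if any(normalize(box) == exp for box in BOXED_RE.findall(candidate)):
        return "boxed"
    try:
        if abs(float(cand) - float(exp)) <= 1e-6 * max(1.0, abs(float(exp))):
            return "numeric"
    except ValueError:
        pass
    if exp in cand or cand in exp:
        return "substring"
    return "no_match"


print(match_reason("The answer is \\boxed{42}", "42"))  # boxed
```

The ordering matters: a `\boxed{}` hit is trusted over a loose substring hit, so a solution that boxes the wrong value but mentions the right one in passing is not credited as boxed.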


  def flatten_expected(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> List[str]:
 
      return full_generation.strip()


  def load_model_and_tokenizer(
      base_model: str,
      adapter_path: Optional[Path],
      trust_remote_code: bool,
+ ) -> Tuple[Any, AutoTokenizer]:
      tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=trust_remote_code, use_fast=True)
      if tokenizer.pad_token is None:
          tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token

      return model, tokenizer

+ def make_bucket() -> Dict[str, Any]:
+     return {
+         "evaluated_rows": 0,
+         "pass_at_1_hits": 0,
+         "pass_at_k_hits": 0,
+         "exact_at_1_hits": 0,
+         "exact_at_k_hits": 0,
+         "boxed_at_k_hits": 0,
+     }
+
+
+ def update_bucket(bucket: Dict[str, Any], hit1: bool, hitk: bool, exact1: bool, exactk: bool, boxedk: bool) -> None:
+     bucket["evaluated_rows"] += 1
+     if hit1:
+         bucket["pass_at_1_hits"] += 1
+     if hitk:
+         bucket["pass_at_k_hits"] += 1
+     if exact1:
+         bucket["exact_at_1_hits"] += 1
+     if exactk:
+         bucket["exact_at_k_hits"] += 1
+     if boxedk:
+         bucket["boxed_at_k_hits"] += 1
+
+
+ def finalize_bucket(bucket: Dict[str, Any]) -> Dict[str, Any]:
+     total = max(int(bucket.get("evaluated_rows", 0)), 1)
+     rows = int(bucket.get("evaluated_rows", 0))
+     return {
+         "evaluated_rows": rows,
+         "pass_at_1": float(bucket.get("pass_at_1_hits", 0)) / total,
+         "pass_at_k": float(bucket.get("pass_at_k_hits", 0)) / total,
+         "exact_at_1": float(bucket.get("exact_at_1_hits", 0)) / total,
+         "exact_at_k": float(bucket.get("exact_at_k_hits", 0)) / total,
+         "boxed_at_k": float(bucket.get("boxed_at_k_hits", 0)) / total,
+     }
+
384
+
385
+ def resolve_eval_file(arg_eval_file: Optional[Path], cfg: Dict[str, Any]) -> Path:
386
+ if arg_eval_file is not None:
387
+ return arg_eval_file
388
+ post_eval_cfg = cfg.get("post_eval", {})
389
  data_cfg = cfg.get("data", {})
390
+ for candidate in (
391
+ as_text(post_eval_cfg.get("eval_file")),
392
+ as_text(data_cfg.get("default_validation_file")),
393
+ "data/releases/v1/test.parquet",
394
+ "workspace/data/releases/v1/test.parquet",
395
+ ):
396
+ if not candidate:
397
+ continue
398
+ path = Path(candidate)
399
+ if path.exists():
400
+ return path
401
+ return Path("data/releases/v1/test.parquet")
402
 
403
+
404
+ def run_evaluation(args: argparse.Namespace) -> Dict[str, Any]:
405
  if args.k < 1:
406
  raise ValueError("--k must be >= 1.")
407
  if args.max_samples < 1:
408
  raise ValueError("--max-samples must be >= 1.")
409
  if args.max_new_tokens < 1:
410
  raise ValueError("--max-new-tokens must be >= 1.")
411
+ if args.max_input_length < 128:
412
+ raise ValueError("--max-input-length must be >= 128.")
413
  if args.temperature <= 0:
414
  raise ValueError("--temperature must be > 0.")
415
  if not 0 < args.top_p <= 1:
416
  raise ValueError("--top-p must be in (0, 1].")
417
 
418
+ cfg = load_config(args.config)
419
+ data_cfg = cfg.get("data", {})
420
+ model_cfg = cfg.get("model", {})
421
+ set_seed(args.seed)
422
+
423
  base_model = args.base_model or as_text(model_cfg.get("base_model"))
424
  if not base_model:
425
  raise ValueError("Base model is required via --base-model or config.model.base_model.")
426
  if args.adapter_path is not None and not args.adapter_path.exists():
427
  raise FileNotFoundError(f"Adapter path not found: {args.adapter_path}")
428
 
429
+ eval_file = resolve_eval_file(args.eval_file, cfg)
430
+ if not eval_file.exists():
431
+ raise FileNotFoundError(f"Evaluation file not found: {eval_file}")
432
+
433
  model, tokenizer = load_model_and_tokenizer(
434
  base_model=base_model,
435
  adapter_path=args.adapter_path,
436
  trust_remote_code=bool(model_cfg.get("trust_remote_code", False)),
437
  )
438
 
439
+ ds = load_dataset("parquet", data_files={"eval": str(eval_file)})["eval"]
 
 
 
440
  if args.max_samples > 0 and args.max_samples < len(ds):
441
  ds = ds.select(range(args.max_samples))
442
 
+     totals = make_bucket()
+     family_buckets: Dict[str, Dict[str, Any]] = {}
+     difficulty_buckets: Dict[str, Dict[str, Any]] = {}
+
+     processed_rows = 0
+     skipped_no_expected = 0
+     samples: List[Dict[str, Any]] = []
+
+     model_device = next(model.parameters()).device
+     prompt_field = as_text(data_cfg.get("prompt_field")) or "prompt"

      for row in ds:
          expected_values = flatten_expected(row, data_cfg)
          if not expected_values:
+             skipped_no_expected += 1
              continue
+
          prompt_text = build_prompt_text(row, tokenizer, data_cfg)
+         inputs = tokenizer(
+             prompt_text,
+             return_tensors="pt",
+             truncation=True,
+             max_length=args.max_input_length,
+         )
          inputs = {k: v.to(model_device) for k, v in inputs.items()}

          with torch.no_grad():

              pad_token_id=tokenizer.pad_token_id,
              eos_token_id=tokenizer.eos_token_id,
          )
+
          generations = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
          candidates = [extract_candidate_text(text, prompt_text) for text in generations]
+         details = [match_candidate(candidate, expected_values) for candidate in candidates]
+
+         matches = [bool(item["match"]) for item in details]
+         exacts = [bool(item["exact"]) for item in details]
+         boxed = [bool(item["boxed"]) for item in details]
+
+         hit1 = bool(matches and matches[0])
+         hitk = bool(any(matches))
+         exact1 = bool(exacts and exacts[0])
+         exactk = bool(any(exacts))
+         boxedk = bool(any(boxed))
+
+         update_bucket(totals, hit1=hit1, hitk=hitk, exact1=exact1, exactk=exactk, boxedk=boxedk)
+
+         family = as_text(row.get("family")) or "__unknown__"
+         if family not in family_buckets:
+             family_buckets[family] = make_bucket()
+         update_bucket(family_buckets[family], hit1=hit1, hitk=hitk, exact1=exact1, exactk=exactk, boxedk=boxedk)
+
+         difficulty = as_text(row.get("difficulty")) or "__unknown__"
+         if difficulty not in difficulty_buckets:
+             difficulty_buckets[difficulty] = make_bucket()
+         update_bucket(
+             difficulty_buckets[difficulty],
+             hit1=hit1,
+             hitk=hitk,
+             exact1=exact1,
+             exactk=exactk,
+             boxedk=boxedk,
          )

+         processed_rows += 1
+         if args.progress_every > 0 and processed_rows % args.progress_every == 0:
+             print(f"Progress: evaluated_rows={processed_rows} latest_family={family}")
+
+         if len(samples) < args.sample_records:
+             samples.append(
+                 {
+                     "uid": as_text(row.get("uid")),
+                     "family": family,
+                     "difficulty": difficulty,
+                     "prompt": as_text(row.get(prompt_field)),
+                     "expected_values": expected_values[:5],
+                     "candidates": candidates,
+                     "match_details": details,
+                     "matches": matches,
+                 }
+             )
+     total_eval = int(totals.get("evaluated_rows", 0))
+     denominator = max(total_eval, 1)
+
+     pass_at_1 = float(totals.get("pass_at_1_hits", 0)) / denominator
+     pass_at_k = float(totals.get("pass_at_k_hits", 0)) / denominator
+     exact_at_1 = float(totals.get("exact_at_1_hits", 0)) / denominator
+     exact_at_k = float(totals.get("exact_at_k_hits", 0)) / denominator
+     boxed_at_k = float(totals.get("boxed_at_k_hits", 0)) / denominator
+
+     composite_score = 0.30 * pass_at_1 + 0.50 * pass_at_k + 0.20 * exact_at_k
+
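`composite_score` is a fixed weighted sum: 30% pass@1, 50% pass@k, 20% exact@k, so sampled coverage dominates the single headline number. A quick arithmetic check of that weighting:

```python
def composite(pass_at_1: float, pass_at_k: float, exact_at_k: float) -> float:
    # Same weights as the commit: pass@k (sampled coverage) carries half the score.
    return 0.30 * pass_at_1 + 0.50 * pass_at_k + 0.20 * exact_at_k


# 0.30*0.40 + 0.50*0.60 + 0.20*0.30 = 0.12 + 0.30 + 0.06
print(round(composite(0.40, 0.60, 0.30), 4))  # 0.48
```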
+     report: Dict[str, Any] = {
          "base_model": base_model,
          "adapter_path": str(args.adapter_path) if args.adapter_path is not None else None,
+         "eval_file": str(eval_file),
+         "config": str(args.config),
+         "evaluated_rows": total_eval,
+         "skipped_rows_without_targets": skipped_no_expected,
+         "requested_rows": len(ds),
          "k": args.k,
          "pass_at_1": pass_at_1,
          "pass_at_k": pass_at_k,
+         "exact_at_1": exact_at_1,
+         "exact_at_k": exact_at_k,
+         "boxed_at_k": boxed_at_k,
+         "composite_score": composite_score,
          "temperature": args.temperature,
          "top_p": args.top_p,
          "max_new_tokens": args.max_new_tokens,
+         "max_input_length": args.max_input_length,
+         "seed": args.seed,
+         "family_metrics": {
+             key: finalize_bucket(family_buckets[key])
+             for key in sorted(family_buckets.keys())
+         },
+         "difficulty_metrics": {
+             key: finalize_bucket(difficulty_buckets[key])
+             for key in sorted(difficulty_buckets.keys())
+         },
+         "samples": samples,
      }
+
      args.output_json.parent.mkdir(parents=True, exist_ok=True)
      args.output_json.write_text(json.dumps(report, ensure_ascii=True, indent=2), encoding="utf-8")
+
+     summary_view = {
+         "evaluated_rows": total_eval,
+         "pass_at_1": pass_at_1,
+         "pass_at_k": pass_at_k,
+         "exact_at_k": exact_at_k,
+         "composite_score": composite_score,
+         "k": args.k,
+     }
+     print(json.dumps(summary_view, indent=2))
      print(f"Saved report to {args.output_json}")
+     return report
+
+
+ def main() -> None:
+     args = parse_args()
+     run_evaluation(args)


  if __name__ == "__main__":
scripts/train_sota.py CHANGED
@@ -4,10 +4,13 @@
  from __future__ import annotations

  import argparse
  import json
  import os
  from pathlib import Path
- from typing import Any, Dict, Optional, Tuple

  import torch
  import yaml
@@ -25,7 +28,9 @@ from transformers import (
      set_seed,
  )

- DEFAULT_CONFIG_PATH = Path("configs/deepseek_math_sota.yaml")


  def parse_args() -> argparse.Namespace:
@@ -41,6 +46,21 @@ def parse_args() -> argparse.Namespace:
      parser.add_argument("--repo-id", type=str, default=None, help="Override hub.repo_id.")
      parser.add_argument("--push-to-hub", action="store_true", help="Force push enabled.")
      parser.add_argument("--no-push-to-hub", action="store_true", help="Force push disabled.")
      parser.add_argument(
          "--start-stage",
          type=int,
@@ -93,6 +113,19 @@ def as_int(value: Any, default: int) -> int:
      return default


  def load_config(path: Path) -> Dict[str, Any]:
      if not path.exists():
          raise FileNotFoundError(f"Config not found: {path}")
@@ -108,6 +141,8 @@ def load_config(path: Path) -> Dict[str, Any]:
      cfg.setdefault("training_defaults", {})
      cfg.setdefault("hub", {})
      cfg.setdefault("credentials", {})
      return cfg


@@ -123,6 +158,16 @@ def apply_overrides(cfg: Dict[str, Any], args: argparse.Namespace) -> None:
      if args.no_push_to_hub:
          cfg.setdefault("hub", {})["push_to_hub"] = False


  def resolve_auth(cfg: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
      token = as_text(os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")) or None
@@ -133,9 +178,17 @@ def resolve_auth(cfg: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
      if path.exists():
          data = json.loads(path.read_text(encoding="utf-8"))
          if token is None:
-             token = as_text(data.get("key")) or None
          if username is None:
-             username = as_text(data.get("username")) or None
      return token, username


@@ -556,6 +609,228 @@ def push_folder(
      api.upload_folder(**kwargs)


  def main() -> None:
      args = parse_args()
      cfg = load_config(args.config)
@@ -564,17 +839,17 @@ def main() -> None:
      seed = as_int(cfg.get("global", {}).get("seed"), 17)
      set_seed(seed)

-     output_root = Path(as_text(cfg.get("global", {}).get("output_root")) or "model_development/runs/math-conjecture-sota")
      output_root.mkdir(parents=True, exist_ok=True)

      token, username = resolve_auth(cfg)
      repo_id = resolve_repo_id(cfg, username=username, output_root=output_root)
-     push_to_hub = bool(cfg.get("hub", {}).get("push_to_hub", False))
-     if args.dry_run and push_to_hub:
          print("Dry-run enabled. Disabling push_to_hub for this run.")
-     if args.dry_run:
          push_to_hub = False
-     if push_to_hub:
          if token is None:
              raise ValueError("Hub push requested but token is missing.")
          if repo_id is None:
@@ -585,8 +860,9 @@ def main() -> None:
          model = None
      else:
          model, tokenizer = build_model_and_tokenizer(cfg["model"], cfg.get("training_defaults", {}))
      data_cfg = cfg["data"]
-     stage_reports = []

      start_stage = max(1, args.start_stage)
      stages = cfg["stages"]
@@ -607,15 +883,18 @@ def main() -> None:
          raw = load_dataset("parquet", data_files=split_files)
          train_rows_before = len(raw["train"])
          valid_rows_before = len(raw["validation"])
          filters = stage.get("filters", {})
          raw["train"] = apply_filters(raw["train"], filters)
          raw["validation"] = apply_filters(raw["validation"], filters)
          train_rows_after_filter = len(raw["train"])
          valid_rows_after_filter = len(raw["validation"])
          raw["train"] = maybe_select(raw["train"], stage.get("max_train_samples"))
          raw["validation"] = maybe_select(raw["validation"], stage.get("max_eval_samples"))
          train_rows_selected = len(raw["train"])
          valid_rows_selected = len(raw["validation"])
          print(
              f"[stage {index}] rows train: {train_rows_before} -> {train_rows_after_filter} -> {train_rows_selected}; "
              f"validation: {valid_rows_before} -> {valid_rows_after_filter} -> {valid_rows_selected}"
@@ -627,19 +906,20 @@ def main() -> None:
          sample_row = raw["train"][0]
          _ = build_prompt_text(sample_row, tokenizer, data_cfg)
          _ = build_answer_block(sample_row, data_cfg)
-         report = {
-             "stage_index": index,
-             "stage_name": stage_name,
-             "stage_slug": stage_slug,
-             "mode": "dry_run",
-             "train_rows_before_filter": train_rows_before,
-             "validation_rows_before_filter": valid_rows_before,
-             "train_rows_after_filter": train_rows_after_filter,
-             "validation_rows_after_filter": valid_rows_after_filter,
-             "train_rows_selected": train_rows_selected,
-             "validation_rows_selected": valid_rows_selected,
-         }
-         stage_reports.append(report)
          print(f"[stage {index}] Dry-run checks passed.")
          continue
@@ -670,33 +950,36 @@ def main() -> None:
          trainer.log_metrics("train", train_result.metrics)
          trainer.save_metrics("train", train_result.metrics)
          trainer.save_state()
          eval_metrics = None
          if eval_dataset is not None:
              eval_metrics = trainer.evaluate()
              trainer.log_metrics("eval", eval_metrics)
              trainer.save_metrics("eval", eval_metrics)
          trainer.save_model(str(stage_output_dir))
          tokenizer.save_pretrained(str(stage_output_dir))

-         report = {
-             "stage_index": index,
-             "stage_name": stage_name,
-             "output_dir": str(stage_output_dir),
-             "train_rows_before_filter": train_rows_before,
-             "validation_rows_before_filter": valid_rows_before,
-             "train_rows_after_filter": train_rows_after_filter,
-             "validation_rows_after_filter": valid_rows_after_filter,
-             "train_rows_selected": train_rows_selected,
-             "validation_rows_selected": valid_rows_selected,
-             "train_rows": len(train_dataset),
-             "eval_rows": len(eval_dataset) if eval_dataset is not None else 0,
-             "train_metrics": train_result.metrics,
-             "eval_metrics": eval_metrics,
-         }
-         stage_reports.append(report)
          print(
-             f"[stage {index}] Completed: train_rows={report['train_rows']} "
-             f"eval_rows={report['eval_rows']} output={stage_output_dir}"
          )

      if args.dry_run:
@@ -720,17 +1003,59 @@ def main() -> None:
      model.save_pretrained(str(final_dir))
      tokenizer.save_pretrained(str(final_dir))

-     summary = {
          "config_path": str(args.config),
          "repo_id": repo_id,
          "seed": seed,
          "stages_ran": stage_reports,
          "final_adapter_dir": str(final_dir),
      }
      summary_path = output_root / "training_summary.json"
      summary_path.write_text(json.dumps(summary, ensure_ascii=True, indent=2), encoding="utf-8")

-     if push_to_hub and repo_id is not None and token is not None:
          api = HfApi(token=token)
          api.create_repo(
              repo_id=repo_id,
@@ -740,17 +1065,22 @@ def main() -> None:
          )
          commit_message = as_text(cfg.get("hub", {}).get("commit_message")) or "Upload SOTA curriculum adapter."
          push_folder(api, repo_id, final_dir, commit_message=commit_message)
          if bool(cfg.get("hub", {}).get("upload_stage_checkpoints", False)):
              for report in stage_reports:
-                 stage_dir = Path(report["output_dir"])
-                 path_in_repo = f"checkpoints/{Path(report['output_dir']).name}"
                  push_folder(
                      api,
                      repo_id,
                      stage_dir,
-                     commit_message=f"Upload stage checkpoint {report['stage_name']}",
                      path_in_repo=path_in_repo,
                  )
          api.upload_file(
              path_or_fileobj=str(summary_path),
              path_in_repo="training_summary.json",
@@ -758,6 +1088,16 @@ def main() -> None:
              repo_type="model",
              commit_message="Upload training summary for SOTA curriculum run.",
          )
          print(f"Pushed training artifacts to https://huggingface.co/{repo_id}")

      print(f"Training complete. Final adapter: {final_dir}")
 
  from __future__ import annotations

  import argparse
+ import gc
  import json
  import os
+ import subprocess
+ import sys
  from pathlib import Path
+ from typing import Any, Dict, List, Optional, Tuple

  import torch
  import yaml

      set_seed,
  )

+ SCRIPT_ROOT = Path(__file__).resolve().parents[1]
+ DEFAULT_CONFIG_PATH = SCRIPT_ROOT / "configs" / "deepseek_math_sota.yaml"
+ DEFAULT_EVAL_SCRIPT = Path(__file__).resolve().with_name("eval_sota.py")


  def parse_args() -> argparse.Namespace:

      parser.add_argument("--repo-id", type=str, default=None, help="Override hub.repo_id.")
      parser.add_argument("--push-to-hub", action="store_true", help="Force push enabled.")
      parser.add_argument("--no-push-to-hub", action="store_true", help="Force push disabled.")
+     parser.add_argument(
+         "--run-post-eval",
+         action="store_true",
+         help="Force post-training evaluation enabled.",
+     )
+     parser.add_argument(
+         "--no-post-eval",
+         action="store_true",
+         help="Force post-training evaluation disabled.",
+     )
+     parser.add_argument(
+         "--skip-quality-gate",
+         action="store_true",
+         help="Disable quality gate checks for this run.",
+     )
      parser.add_argument(
          "--start-stage",
          type=int,
      return default


+ def as_bool(value: Any, default: bool = False) -> bool:
+     if value is None:
+         return default
+     if isinstance(value, bool):
+         return value
+     text = as_text(value).lower()
+     if text in {"1", "true", "yes", "y", "on"}:
+         return True
+     if text in {"0", "false", "no", "n", "off"}:
+         return False
+     return default
+
+
  def load_config(path: Path) -> Dict[str, Any]:
      if not path.exists():
          raise FileNotFoundError(f"Config not found: {path}")

      cfg.setdefault("training_defaults", {})
      cfg.setdefault("hub", {})
      cfg.setdefault("credentials", {})
+     cfg.setdefault("post_eval", {})
+     cfg.setdefault("quality_gate", {})
      return cfg


      if args.no_push_to_hub:
          cfg.setdefault("hub", {})["push_to_hub"] = False

+     if args.run_post_eval and args.no_post_eval:
+         raise ValueError("Cannot set both --run-post-eval and --no-post-eval.")
+     if args.run_post_eval:
+         cfg.setdefault("post_eval", {})["enabled"] = True
+     if args.no_post_eval:
+         cfg.setdefault("post_eval", {})["enabled"] = False
+
+     if args.skip_quality_gate:
+         cfg.setdefault("quality_gate", {})["enabled"] = False


  def resolve_auth(cfg: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
      token = as_text(os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")) or None

      if path.exists():
          data = json.loads(path.read_text(encoding="utf-8"))
          if token is None:
+             for key in ("token", "key", "api_key", "hf_token"):
+                 candidate = as_text(data.get(key))
+                 if candidate:
+                     token = candidate
+                     break
          if username is None:
+             for key in ("username", "user", "owner"):
+                 candidate = as_text(data.get(key))
+                 if candidate:
+                     username = candidate
+                     break
      return token, username
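The widened credential-file scan above tolerates several key spellings instead of only `key`/`username`. The same first-non-empty lookup can be sketched standalone (`first_present` is a hypothetical helper and the JSON payload is made up):

```python
import json
from typing import Optional


def first_present(data: dict, keys: tuple) -> Optional[str]:
    # Return the first non-empty string value among the candidate key names.
    for key in keys:
        value = str(data.get(key) or "").strip()
        if value:
            return value
    return None


creds = json.loads('{"api_key": "hf_example", "user": "northern-tribe"}')
token = first_present(creds, ("token", "key", "api_key", "hf_token"))
username = first_present(creds, ("username", "user", "owner"))
print(token, username)  # hf_example northern-tribe
```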


      api.upload_folder(**kwargs)


+ def extract_final_eval_loss(stage_reports: List[Dict[str, Any]]) -> Optional[float]:
+     for report in reversed(stage_reports):
+         eval_metrics = report.get("eval_metrics")
+         if not isinstance(eval_metrics, dict):
+             continue
+         value = eval_metrics.get("eval_loss")
+         if value is None:
+             continue
+         try:
+             return float(value)
+         except (TypeError, ValueError):
+             continue
+     return None
+
+
+ def release_model_memory(model: Any) -> None:
+     try:
+         model.to("cpu")
+     except Exception:
+         pass
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+     gc.collect()
+
+
+ def run_post_eval(
+     cfg: Dict[str, Any],
+     config_path: Path,
+     output_root: Path,
+     final_adapter_dir: Path,
+ ) -> Optional[Dict[str, Any]]:
+     post_cfg = cfg.get("post_eval", {})
+     if not as_bool(post_cfg.get("enabled"), False):
+         return None
+
+     eval_script = DEFAULT_EVAL_SCRIPT
+     if not eval_script.exists():
+         raise FileNotFoundError(f"Post-eval enabled but eval script is missing: {eval_script}")
+
+     data_cfg = cfg.get("data", {})
+     eval_file = Path(
+         as_text(post_cfg.get("eval_file"))
+         or as_text(data_cfg.get("default_validation_file"))
+         or "data/releases/v1/test.parquet"
+     )
+     if not eval_file.exists():
+         raise FileNotFoundError(f"Post-eval file not found: {eval_file}")
+
+     output_json = Path(as_text(post_cfg.get("output_json")) or str(output_root / "post_eval_report.json"))
+     base_model = as_text(cfg.get("model", {}).get("base_model"))
+     if not base_model:
+         raise ValueError("model.base_model is required for post-eval.")
+
+     cmd = [
+         sys.executable,
+         str(eval_script),
+         "--config",
+         str(config_path),
+         "--base-model",
+         base_model,
+         "--adapter-path",
+         str(final_adapter_dir),
+         "--eval-file",
+         str(eval_file),
+         "--max-samples",
+         str(as_int(post_cfg.get("max_samples"), 300)),
+         "--k",
+         str(as_int(post_cfg.get("k"), 4)),
+         "--max-new-tokens",
+         str(as_int(post_cfg.get("max_new_tokens"), 256)),
+         "--temperature",
+         str(as_float(post_cfg.get("temperature"), 0.7)),
+         "--top-p",
+         str(as_float(post_cfg.get("top_p"), 0.95)),
+         "--seed",
+         str(as_int(post_cfg.get("seed"), as_int(cfg.get("global", {}).get("seed"), 17))),
+         "--output-json",
+         str(output_json),
+     ]
+     print(f"Running post-training eval: {' '.join(cmd)}")
+     completed = subprocess.run(cmd, check=False)
+     if completed.returncode != 0:
+         raise RuntimeError(f"Post-training evaluation failed with exit code {completed.returncode}.")
+
+     if not output_json.exists():
+         raise FileNotFoundError(f"Post-eval report was not created: {output_json}")
+
+     report = json.loads(output_json.read_text(encoding="utf-8"))
+     return {
+         "enabled": True,
+         "report_path": str(output_json),
+         "report": report,
+         "command": cmd,
+     }
+ }
706
+
707
+
708
+ def evaluate_quality_gate(
709
+ stage_reports: List[Dict[str, Any]],
710
+ post_eval_result: Optional[Dict[str, Any]],
711
+ gate_cfg: Dict[str, Any],
712
+ ) -> Dict[str, Any]:
713
+ enabled = as_bool(gate_cfg.get("enabled"), False)
714
+ result: Dict[str, Any] = {
715
+ "enabled": enabled,
716
+ "passed": True,
717
+ "violations": [],
718
+ "checks": [],
719
+ }
720
+ if not enabled:
721
+ return result
722
+
723
+ violations: List[str] = []
724
+ checks: List[Dict[str, Any]] = []
725
+
726
+ final_eval_loss = extract_final_eval_loss(stage_reports)
727
+ max_final_eval_loss = gate_cfg.get("max_final_eval_loss")
728
+ if max_final_eval_loss is not None:
729
+ threshold = as_float(max_final_eval_loss, 0.0)
730
+ checks.append(
731
+ {
732
+ "name": "max_final_eval_loss",
733
+ "actual": final_eval_loss,
734
+ "threshold": threshold,
735
+ }
736
+ )
737
+ if final_eval_loss is None:
738
+ violations.append("Final stage eval_loss is missing for max_final_eval_loss gate.")
739
+ elif final_eval_loss > threshold:
740
+ violations.append(
741
+ f"Final eval_loss {final_eval_loss:.4f} exceeds threshold {threshold:.4f}."
742
+ )
743
+
744
+ report: Optional[Dict[str, Any]] = None
745
+ if isinstance(post_eval_result, dict):
746
+ loaded = post_eval_result.get("report")
747
+ if isinstance(loaded, dict):
748
+ report = loaded
749
+
750
+ require_post_eval = as_bool(gate_cfg.get("require_post_eval"), False)
751
+ if report is None:
752
+ if require_post_eval:
753
+ violations.append("Quality gate requires post-eval metrics, but post-eval report is missing.")
754
+ else:
755
+ evaluated_rows = as_int(report.get("evaluated_rows"), 0)
756
+ min_rows = as_int(gate_cfg.get("min_evaluated_rows"), 0)
757
+ checks.append(
758
+ {
759
+ "name": "min_evaluated_rows",
760
+ "actual": evaluated_rows,
761
+ "threshold": min_rows,
762
+ }
763
+ )
764
+ if evaluated_rows < min_rows:
765
+ violations.append(
766
+ f"Post-eval rows {evaluated_rows} is below minimum required {min_rows}."
767
+ )
768
+
769
+ min_pass_at_1_raw = gate_cfg.get("min_pass_at_1")
770
+ if min_pass_at_1_raw is not None:
771
+ min_pass_at_1 = as_float(min_pass_at_1_raw, 0.0)
772
+ pass_at_1 = as_float(report.get("pass_at_1"), 0.0)
773
+ checks.append(
774
+ {
775
+ "name": "min_pass_at_1",
776
+ "actual": pass_at_1,
777
+ "threshold": min_pass_at_1,
778
+ }
779
+ )
780
+ if pass_at_1 < min_pass_at_1:
781
+ violations.append(
782
+ f"pass@1 {pass_at_1:.4f} is below threshold {min_pass_at_1:.4f}."
783
+ )
784
+
785
+ min_pass_at_k_raw = gate_cfg.get("min_pass_at_k")
786
+ if min_pass_at_k_raw is not None:
787
+ min_pass_at_k = as_float(min_pass_at_k_raw, 0.0)
788
+ pass_at_k = as_float(report.get("pass_at_k"), 0.0)
789
+ checks.append(
790
+ {
791
+ "name": "min_pass_at_k",
792
+ "actual": pass_at_k,
793
+ "threshold": min_pass_at_k,
794
+ }
795
+ )
796
+ if pass_at_k < min_pass_at_k:
797
+ violations.append(
798
+ f"pass@k {pass_at_k:.4f} is below threshold {min_pass_at_k:.4f}."
799
+ )
800
+
801
+ family_requirements = gate_cfg.get("required_family_pass_at_k", {})
802
+ family_metrics = report.get("family_metrics", {})
803
+ if isinstance(family_requirements, dict):
804
+ for family, threshold_raw in family_requirements.items():
805
+ threshold = as_float(threshold_raw, 0.0)
806
+ actual = None
807
+ if isinstance(family_metrics, dict):
808
+ family_row = family_metrics.get(family)
809
+ if isinstance(family_row, dict):
810
+ try:
811
+ actual = float(family_row.get("pass_at_k"))
812
+ except (TypeError, ValueError):
813
+ actual = None
814
+ checks.append(
815
+ {
816
+ "name": f"family_pass_at_k:{family}",
817
+ "actual": actual,
818
+ "threshold": threshold,
819
+ }
820
+ )
821
+ if actual is None:
822
+ violations.append(f"Missing pass@k metric for required family '{family}'.")
823
+ elif actual < threshold:
824
+ violations.append(
825
+ f"Family '{family}' pass@k {actual:.4f} is below threshold {threshold:.4f}."
826
+ )
827
+
828
+ result["violations"] = violations
829
+ result["checks"] = checks
830
+ result["passed"] = len(violations) == 0
831
+ return result
832
+
833
+
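The gate reads its thresholds from a `quality_gate` block in the run config. The key names below match the `gate_cfg.get(...)` lookups in the function; the threshold values and the `enabled` flag are illustrative assumptions, not values from this commit:

```yaml
quality_gate:
  enabled: true                  # assumed flag; the function returns early when the gate is disabled
  max_final_eval_loss: 1.25      # upper bound on the final stage's eval_loss (illustrative value)
  require_post_eval: true        # fail the gate if the post-eval report is missing
  min_evaluated_rows: 200        # minimum rows the post-eval must have scored
  min_pass_at_1: 0.35            # self-consistency thresholds (illustrative values)
  min_pass_at_k: 0.55
  required_family_pass_at_k:     # per-family pass@k floors, keyed by family name
    number_theory: 0.50
    combinatorics: 0.45
```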
 def main() -> None:
     args = parse_args()
     cfg = load_config(args.config)
 
     seed = as_int(cfg.get("global", {}).get("seed"), 17)
     set_seed(seed)
 
+    output_root = Path(as_text(cfg.get("global", {}).get("output_root")) or "runs/math-conjecture-sota")
     output_root.mkdir(parents=True, exist_ok=True)
 
     token, username = resolve_auth(cfg)
     repo_id = resolve_repo_id(cfg, username=username, output_root=output_root)
+    push_to_hub_requested = bool(cfg.get("hub", {}).get("push_to_hub", False))
+    if args.dry_run and push_to_hub_requested:
         print("Dry-run enabled. Disabling push_to_hub for this run.")
+    push_to_hub_requested = push_to_hub_requested and not args.dry_run
+
+    if push_to_hub_requested:
         if token is None:
             raise ValueError("Hub push requested but token is missing.")
         if repo_id is None:
 
         model = None
     else:
         model, tokenizer = build_model_and_tokenizer(cfg["model"], cfg.get("training_defaults", {}))
+
     data_cfg = cfg["data"]
+    stage_reports: List[Dict[str, Any]] = []
 
     start_stage = max(1, args.start_stage)
     stages = cfg["stages"]
 
         raw = load_dataset("parquet", data_files=split_files)
         train_rows_before = len(raw["train"])
         valid_rows_before = len(raw["validation"])
+
         filters = stage.get("filters", {})
         raw["train"] = apply_filters(raw["train"], filters)
         raw["validation"] = apply_filters(raw["validation"], filters)
         train_rows_after_filter = len(raw["train"])
         valid_rows_after_filter = len(raw["validation"])
+
         raw["train"] = maybe_select(raw["train"], stage.get("max_train_samples"))
         raw["validation"] = maybe_select(raw["validation"], stage.get("max_eval_samples"))
         train_rows_selected = len(raw["train"])
         valid_rows_selected = len(raw["validation"])
+
         print(
             f"[stage {index}] rows train: {train_rows_before} -> {train_rows_after_filter} -> {train_rows_selected}; "
             f"validation: {valid_rows_before} -> {valid_rows_after_filter} -> {valid_rows_selected}"
 
             sample_row = raw["train"][0]
             _ = build_prompt_text(sample_row, tokenizer, data_cfg)
             _ = build_answer_block(sample_row, data_cfg)
+            stage_reports.append(
+                {
+                    "stage_index": index,
+                    "stage_name": stage_name,
+                    "stage_slug": stage_slug,
+                    "mode": "dry_run",
+                    "train_rows_before_filter": train_rows_before,
+                    "validation_rows_before_filter": valid_rows_before,
+                    "train_rows_after_filter": train_rows_after_filter,
+                    "validation_rows_after_filter": valid_rows_after_filter,
+                    "train_rows_selected": train_rows_selected,
+                    "validation_rows_selected": valid_rows_selected,
+                }
+            )
             print(f"[stage {index}] Dry-run checks passed.")
             continue
 
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
+
         eval_metrics = None
         if eval_dataset is not None:
             eval_metrics = trainer.evaluate()
             trainer.log_metrics("eval", eval_metrics)
             trainer.save_metrics("eval", eval_metrics)
+
         trainer.save_model(str(stage_output_dir))
         tokenizer.save_pretrained(str(stage_output_dir))
 
+        stage_reports.append(
+            {
+                "stage_index": index,
+                "stage_name": stage_name,
+                "output_dir": str(stage_output_dir),
+                "train_rows_before_filter": train_rows_before,
+                "validation_rows_before_filter": valid_rows_before,
+                "train_rows_after_filter": train_rows_after_filter,
+                "validation_rows_after_filter": valid_rows_after_filter,
+                "train_rows_selected": train_rows_selected,
+                "validation_rows_selected": valid_rows_selected,
+                "train_rows": len(train_dataset),
+                "eval_rows": len(eval_dataset) if eval_dataset is not None else 0,
+                "train_metrics": train_result.metrics,
+                "eval_metrics": eval_metrics,
+            }
+        )
         print(
+            f"[stage {index}] Completed: train_rows={len(train_dataset)} "
+            f"eval_rows={len(eval_dataset) if eval_dataset is not None else 0} output={stage_output_dir}"
         )
 
     if args.dry_run:
 
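The `max_final_eval_loss` gate depends on an `extract_final_eval_loss` helper defined elsewhere in the script. Its actual implementation is not shown in this diff; a plausible minimal sketch, assuming each stage report stores its Trainer metrics under `eval_metrics["eval_loss"]` as the `stage_reports.append(...)` calls above do:

```python
from typing import Any, Dict, List, Optional


def extract_final_eval_loss(stage_reports: List[Dict[str, Any]]) -> Optional[float]:
    """Return eval_loss from the last stage that recorded one, else None.

    Walks the stage reports newest-first so the final curriculum stage wins;
    dry-run entries (no eval_metrics dict) are skipped.
    """
    for report in reversed(stage_reports):
        metrics = report.get("eval_metrics")
        if isinstance(metrics, dict) and "eval_loss" in metrics:
            try:
                return float(metrics["eval_loss"])
            except (TypeError, ValueError):
                return None
    return None
```

Scanning in reverse keeps the gate tied to the final stage's validation loss even when earlier stages also logged eval metrics.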
     model.save_pretrained(str(final_dir))
     tokenizer.save_pretrained(str(final_dir))
 
+    release_model_memory(model)
+    del model
+
+    post_eval_result = run_post_eval(
+        cfg=cfg,
+        config_path=args.config,
+        output_root=output_root,
+        final_adapter_dir=final_dir,
+    )
+
+    quality_gate = evaluate_quality_gate(
+        stage_reports=stage_reports,
+        post_eval_result=post_eval_result,
+        gate_cfg=cfg.get("quality_gate", {}),
+    )
+
+    push_to_hub_performed = push_to_hub_requested
+    push_block_reason: Optional[str] = None
+    if push_to_hub_requested and not quality_gate.get("passed", True):
+        push_to_hub_performed = False
+        push_block_reason = "quality_gate_failed"
+        print("Quality gate failed; skipping hub push for this run.")
+
+    summary: Dict[str, Any] = {
         "config_path": str(args.config),
         "repo_id": repo_id,
         "seed": seed,
         "stages_ran": stage_reports,
         "final_adapter_dir": str(final_dir),
+        "quality_gate": quality_gate,
+        "push": {
+            "requested": bool(push_to_hub_requested),
+            "performed": bool(push_to_hub_performed),
+            "block_reason": push_block_reason,
+        },
     }
+
+    if post_eval_result is not None:
+        report = post_eval_result.get("report", {})
+        summary["post_eval"] = {
+            "report_path": post_eval_result.get("report_path"),
+            "evaluated_rows": report.get("evaluated_rows"),
+            "k": report.get("k"),
+            "pass_at_1": report.get("pass_at_1"),
+            "pass_at_k": report.get("pass_at_k"),
+            "exact_at_k": report.get("exact_at_k"),
+            "composite_score": report.get("composite_score"),
+        }
+
     summary_path = output_root / "training_summary.json"
     summary_path.write_text(json.dumps(summary, ensure_ascii=True, indent=2), encoding="utf-8")
 
+    if push_to_hub_performed and repo_id is not None and token is not None:
         api = HfApi(token=token)
         api.create_repo(
             repo_id=repo_id,
 
         )
         commit_message = as_text(cfg.get("hub", {}).get("commit_message")) or "Upload SOTA curriculum adapter."
         push_folder(api, repo_id, final_dir, commit_message=commit_message)
+
         if bool(cfg.get("hub", {}).get("upload_stage_checkpoints", False)):
             for report in stage_reports:
+                stage_dir_raw = report.get("output_dir")
+                if not stage_dir_raw:
+                    continue
+                stage_dir = Path(stage_dir_raw)
+                path_in_repo = f"checkpoints/{stage_dir.name}"
                 push_folder(
                     api,
                     repo_id,
                     stage_dir,
+                    commit_message=f"Upload stage checkpoint {report.get('stage_name', stage_dir.name)}",
                     path_in_repo=path_in_repo,
                 )
+
         api.upload_file(
             path_or_fileobj=str(summary_path),
             path_in_repo="training_summary.json",
 
             repo_type="model",
             commit_message="Upload training summary for SOTA curriculum run.",
         )
+
+        if post_eval_result is not None and post_eval_result.get("report_path"):
+            api.upload_file(
+                path_or_fileobj=str(post_eval_result["report_path"]),
+                path_in_repo="post_eval_report.json",
+                repo_id=repo_id,
+                repo_type="model",
+                commit_message="Upload post-training evaluation report.",
+            )
+
         print(f"Pushed training artifacts to https://huggingface.co/{repo_id}")
 
     print(f"Training complete. Final adapter: {final_dir}")
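Downstream tooling can key off the new `quality_gate` and `push` fields in `training_summary.json` to decide whether a run produced a publishable adapter. A hypothetical consumer sketch (the field names match the summary dict written by this commit; the helper name and call site are assumptions):

```python
import json
from pathlib import Path


def run_is_publishable(summary_path: Path) -> bool:
    """True when the quality gate passed and the hub push actually happened."""
    summary = json.loads(summary_path.read_text(encoding="utf-8"))
    gate_passed = bool(summary.get("quality_gate", {}).get("passed", False))
    push_performed = bool(summary.get("push", {}).get("performed", False))
    return gate_passed and push_performed
```

Because `push.performed` is forced to `False` whenever the gate fails, checking both fields distinguishes "gate passed but push was never requested" from "gate blocked the push".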