Upgrade training pipeline with post-eval quality gates and tactical UI controls.
Files changed:

- README.md +35 -35
- app.py +171 -7
- configs/deepseek_math_sota.yaml +40 -2
- scripts/eval_sota.py +344 -61
- scripts/train_sota.py +387 -47
README.md
CHANGED
@@ -9,52 +9,52 @@ pinned: false
 
 # Math Conjecture Trainer Space
 
-
-training corpus splits, builds runtime config from the SOTA curriculum YAML,
-executes multi-stage `DeepSeek-Math` fine-tuning, optionally evaluates
-self-consistency, and can publish adapters/checkpoints/training summaries to:
-
-
-
-
-
-   `NorthernTribe-Research/math-conjecture-training-corpus`
-2. Builds a runtime config from `configs/deepseek_math_sota.yaml`
-3. Runs `scripts/train_sota.py` for staged curriculum training
-4. Optionally runs `scripts/eval_sota.py`
-5. Streams logs and run summary JSON in the UI
-
-
-
-It resolves auth in this order:
-
-
-2. `HUGGINGFACE_HUB_TOKEN` environment variable
-3. `huggingface-api-key.json` (if present)
-
-
-   automatically disabled for that run.
-
-
-
-
-   `train_sota.py --dry-run`.
-- `Push Adapter to Hub`: controls whether `hub.push_to_hub` is enabled in the
-  runtime config.
-- `Force Dataset Redownload`: bypasses cached local parquet files.
-- `Stop Active Run`: requests cancellation and terminates active subprocesses.
-- `Run Summary (JSON)`: structured output with config, status, and metrics.
-
-##
-
-- `
--
--
+Launch multi-stage DeepSeek-Math fine-tuning on Space GPU and push adapters to your model repo.
+
+This Space is the tactical operations console for `maths-conjuncture-solutions` and is wired to:
+
+- dataset: `NorthernTribe-Research/math-conjecture-training-corpus`
+- model repo: `NorthernTribe-Research/math-conjecture-model`
+
+## End-to-end flow
+
+1. Download released parquet splits (`train/validation/test`).
+2. Build runtime config from `configs/deepseek_math_sota.yaml`.
+3. Run 4-stage curriculum LoRA fine-tuning with `scripts/train_sota.py`.
+4. Run post-train evaluation (`pass@1`, `pass@k`, exact/boxed, family metrics).
+5. Apply quality gate thresholds before hub push.
+6. Emit `training_summary.json` + `post_eval_report.json` and stream live telemetry in UI.
+
+## Autonomous authentication
+
+No token input is required in the UI.
+
+Resolution order:
+
+1. `HF_TOKEN`
+2. `HUGGINGFACE_HUB_TOKEN`
+3. `huggingface-api-key.json`
+
+If no token is available, public dataset training still works and push is automatically disabled.
+
+## Runtime controls
+
+- `Run Evaluation After Training`: toggles post-train eval in runtime config.
+- `Enforce Quality Gate`: enables/disables promotion gate checks.
+- `Gate Min pass@1`, `Gate Min pass@k`, `Gate Min Rows`: runtime gate thresholds.
+- `Preflight Only (No Training)`: validates pipeline with `--dry-run`.
+- `Force Dataset Redownload`: bypasses cached parquet files.
+- `Abort Active Run`: cancels active subprocess tree.
+
+## Artifacts
+
+- runtime config: `workspace/runtime/deepseek_math_sota.runtime.yaml`
+- run output root: `workspace/runs/math-conjecture-sota`
+- final adapter: `workspace/runs/math-conjecture-sota/final_adapter`
+- training summary: `workspace/runs/math-conjecture-sota/training_summary.json`
+- post-eval report: `workspace/runs/math-conjecture-sota/post_eval_report.json`
 
 ## Notes
 
-- Full training
--
-  `workspace/runtime/deepseek_math_sota.runtime.yaml`.
+- Full training requires GPU hardware.
+- App handles Gradio copy-button compatibility across versions automatically.
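The resolution order above matches the `resolve_auth` helper visible in the `scripts/train_sota.py` diff further down. A minimal standalone sketch of the same order follows; the JSON key name `"token"` is an assumption for illustration, since the key file's schema is not shown in this commit:

```python
import json
import os
from pathlib import Path
from typing import Optional


def resolve_hf_token(key_file: Path = Path("huggingface-api-key.json")) -> Optional[str]:
    """Sketch of the documented resolution order: HF_TOKEN, then
    HUGGINGFACE_HUB_TOKEN, then the local key file, else None."""
    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if token:
        return token
    if key_file.exists():
        data = json.loads(key_file.read_text(encoding="utf-8"))
        return data.get("token")  # assumed key name; not confirmed by the diff
    return None  # no token: training still runs, hub push is disabled
```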
app.py
CHANGED
@@ -198,15 +198,58 @@ PROJECT_DESCRIPTION = """
 # Math Conjecture Trainer
 This console runs the full training operations lane for the `maths-conjuncture-solutions` project:
 
+Launch multi-stage DeepSeek-Math fine-tuning on Space GPU and push adapters to your model repo.
+
 1. Pull released parquet splits from `NorthernTribe-Research/math-conjecture-training-corpus`.
 2. Build runtime training configuration from `configs/deepseek_math_sota.yaml`.
 3. Execute multi-stage DeepSeek-Math curriculum fine-tuning via `scripts/train_sota.py`.
-4.
-5.
-6.
+4. Run post-training evaluation with pass@k-style sampling and family-level metrics.
+5. Enforce autonomous quality gates before adapter promotion/push.
+6. Auto-resolve Hugging Face credentials, stream live telemetry, and emit structured run summaries.
 """
 
 
+def _safe_float(value: Any, default: float) -> float:
+    try:
+        return float(value)
+    except (TypeError, ValueError):
+        return default
+
+
+def _safe_int(value: Any, default: int) -> int:
+    try:
+        return int(value)
+    except (TypeError, ValueError):
+        return default
+
+
+def load_template_defaults() -> Dict[str, Any]:
+    if not CONFIG_TEMPLATE.exists():
+        return {}
+    try:
+        cfg = yaml.safe_load(CONFIG_TEMPLATE.read_text(encoding="utf-8"))
+    except Exception:
+        return {}
+    if not isinstance(cfg, dict):
+        return {}
+    return cfg
+
+
+TEMPLATE_CFG = load_template_defaults()
+TEMPLATE_STAGE_COUNT = max(1, len(TEMPLATE_CFG.get("stages", []) or [None]))
+TEMPLATE_QUALITY_GATE = TEMPLATE_CFG.get("quality_gate", {})
+if not isinstance(TEMPLATE_QUALITY_GATE, dict):
+    TEMPLATE_QUALITY_GATE = {}
+_raw_gate_enabled = TEMPLATE_QUALITY_GATE.get("enabled", True)
+if isinstance(_raw_gate_enabled, bool):
+    DEFAULT_GATE_ENABLED = _raw_gate_enabled
+else:
+    DEFAULT_GATE_ENABLED = str(_raw_gate_enabled).strip().lower() in {"1", "true", "yes", "y", "on"}
+DEFAULT_GATE_MIN_ROWS = max(1, _safe_int(TEMPLATE_QUALITY_GATE.get("min_evaluated_rows"), 120))
+DEFAULT_GATE_MIN_PASS_AT_1 = max(0.0, _safe_float(TEMPLATE_QUALITY_GATE.get("min_pass_at_1"), 0.01))
+DEFAULT_GATE_MIN_PASS_AT_K = max(0.0, _safe_float(TEMPLATE_QUALITY_GATE.get("min_pass_at_k"), 0.06))
+
+
 def now_ts() -> str:
     return dt.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC")
 
@@ -390,7 +433,15 @@ def write_runtime_config(
     model_repo_id: str,
     train_file: str,
     validation_file: str,
+    test_file: str,
+    run_eval: bool,
+    eval_k: int,
+    eval_samples: int,
     push_to_hub: bool,
+    enforce_quality_gate: bool,
+    gate_min_pass_at_1: float,
+    gate_min_pass_at_k: float,
+    gate_min_rows: int,
 ) -> Path:
     cfg = yaml.safe_load(CONFIG_TEMPLATE.read_text(encoding="utf-8"))
     cfg["model"]["base_model"] = base_model_id
@@ -399,6 +450,21 @@ def write_runtime_config(
     cfg["data"]["default_train_file"] = train_file
     cfg["data"]["default_validation_file"] = validation_file
     cfg["global"]["output_root"] = str(TRAIN_OUTPUT_DIR)
+
+    cfg.setdefault("post_eval", {})
+    cfg["post_eval"]["enabled"] = bool(run_eval)
+    cfg["post_eval"]["eval_file"] = test_file
+    cfg["post_eval"]["k"] = int(eval_k)
+    cfg["post_eval"]["max_samples"] = int(eval_samples)
+    cfg["post_eval"]["output_json"] = str(TRAIN_OUTPUT_DIR / "post_eval_report.json")
+
+    cfg.setdefault("quality_gate", {})
+    cfg["quality_gate"]["enabled"] = bool(enforce_quality_gate)
+    cfg["quality_gate"]["min_evaluated_rows"] = int(gate_min_rows)
+    cfg["quality_gate"]["min_pass_at_1"] = float(gate_min_pass_at_1)
+    cfg["quality_gate"]["min_pass_at_k"] = float(gate_min_pass_at_k)
+    cfg["quality_gate"]["require_post_eval"] = bool(enforce_quality_gate and run_eval)
+
     runtime_path = RUNTIME_DIR / "deepseek_math_sota.runtime.yaml"
     runtime_path.write_text(yaml.safe_dump(cfg, sort_keys=False), encoding="utf-8")
     return runtime_path
@@ -511,6 +577,10 @@ def run_pipeline(
     run_eval: bool,
     eval_k: int,
    eval_samples: int,
+    enforce_quality_gate: bool,
+    gate_min_pass_at_1: float,
+    gate_min_pass_at_k: float,
+    gate_min_rows: int,
     push_to_hub: bool,
     force_redownload: bool,
     preflight_only: bool,
@@ -542,14 +612,25 @@ def run_pipeline(
     stage_count = int(max_stages)
     eval_k = int(eval_k)
     eval_samples = int(eval_samples)
+    gate_min_rows = int(gate_min_rows)
+    gate_min_pass_at_1 = float(gate_min_pass_at_1)
+    gate_min_pass_at_k = float(gate_min_pass_at_k)
     if stage_start < 1:
         raise ValueError("Start stage must be >= 1.")
+    if stage_start > TEMPLATE_STAGE_COUNT:
+        raise ValueError(f"Start stage must be <= {TEMPLATE_STAGE_COUNT}.")
     if stage_count < 1:
         raise ValueError("How many stages must be >= 1.")
     if eval_k < 1:
         raise ValueError("Eval K must be >= 1.")
     if eval_samples < 1:
         raise ValueError("Eval max samples must be >= 1.")
+    if gate_min_rows < 1:
+        raise ValueError("Gate minimum rows must be >= 1.")
+    if not 0.0 <= gate_min_pass_at_1 <= 1.0:
+        raise ValueError("Gate min pass@1 must be between 0 and 1.")
+    if not 0.0 <= gate_min_pass_at_k <= 1.0:
+        raise ValueError("Gate min pass@k must be between 0 and 1.")
 
     for required_path in (CONFIG_TEMPLATE, TRAIN_SCRIPT):
         if not required_path.exists():
@@ -570,6 +651,10 @@ def run_pipeline(
         "run_eval": bool(run_eval),
         "eval_k": eval_k,
         "eval_samples": eval_samples,
+        "enforce_quality_gate": bool(enforce_quality_gate),
+        "gate_min_rows": gate_min_rows,
+        "gate_min_pass_at_1": gate_min_pass_at_1,
+        "gate_min_pass_at_k": gate_min_pass_at_k,
         "push_to_hub": bool(push_to_hub),
         "force_redownload": bool(force_redownload),
         "preflight_only": bool(preflight_only),
@@ -633,7 +718,15 @@ def run_pipeline(
         model_repo_id=model_repo_id,
         train_file=train_file,
         validation_file=validation_file,
+        test_file=test_file,
+        run_eval=bool(run_eval),
+        eval_k=eval_k,
+        eval_samples=eval_samples,
         push_to_hub=effective_push_to_hub,
+        enforce_quality_gate=bool(enforce_quality_gate),
+        gate_min_pass_at_1=gate_min_pass_at_1,
+        gate_min_pass_at_k=gate_min_pass_at_k,
+        gate_min_rows=gate_min_rows,
     )
     summary["runtime_config"] = str(runtime_cfg)
     append_log(log_lines, f"Wrote runtime config: {runtime_cfg}")
@@ -701,15 +794,50 @@ def run_pipeline(
         return
 
     training_summary_path = TRAIN_OUTPUT_DIR / "training_summary.json"
+    training_summary: Optional[Dict[str, Any]] = None
    if training_summary_path.exists():
        try:
             summary["training_summary_path"] = str(training_summary_path)
-
+            loaded_summary = json.loads(training_summary_path.read_text(encoding="utf-8"))
+            if isinstance(loaded_summary, dict):
+                training_summary = loaded_summary
+                summary["training_summary"] = loaded_summary
+            else:
+                summary["training_summary"] = {"warning": "Training summary JSON is not an object."}
         except json.JSONDecodeError:
             summary["training_summary_path"] = str(training_summary_path)
             summary["training_summary"] = {"warning": "Unable to parse training summary JSON."}
 
-    if
+    if isinstance(training_summary, dict):
+        quality_gate = training_summary.get("quality_gate")
+        if isinstance(quality_gate, dict):
+            summary["quality_gate"] = quality_gate
+            append_log(
+                log_lines,
+                f"Quality gate: passed={quality_gate.get('passed')} enabled={quality_gate.get('enabled')}",
+            )
+        push_report = training_summary.get("push")
+        if isinstance(push_report, dict):
+            summary["push"] = push_report
+            append_log(
+                log_lines,
+                f"Push decision: requested={push_report.get('requested')} performed={push_report.get('performed')}",
+            )
+        post_eval_report = training_summary.get("post_eval")
+        if run_eval and isinstance(post_eval_report, dict):
+            summary["evaluation"] = {
+                "source": "train_post_eval",
+                "evaluated_rows": post_eval_report.get("evaluated_rows"),
+                "pass_at_1": post_eval_report.get("pass_at_1"),
+                "pass_at_k": post_eval_report.get("pass_at_k"),
+                "exact_at_k": post_eval_report.get("exact_at_k"),
+                "composite_score": post_eval_report.get("composite_score"),
+                "k": post_eval_report.get("k"),
+                "report_path": post_eval_report.get("report_path"),
+            }
+            append_log(log_lines, "Using post-eval metrics emitted by training run.")
+
+    if run_eval and "evaluation" not in summary:
         eval_report = WORKSPACE_DIR / "runs" / "latest_eval_report.json"
         eval_cmd = [
             sys.executable,
@@ -763,9 +891,12 @@ def run_pipeline(
         if eval_report.exists():
             report = json.loads(eval_report.read_text(encoding="utf-8"))
             summary["evaluation"] = {
+                "source": "fallback_eval",
                 "evaluated_rows": report.get("evaluated_rows"),
                 "pass_at_1": report.get("pass_at_1"),
                 "pass_at_k": report.get("pass_at_k"),
+                "exact_at_k": report.get("exact_at_k"),
+                "composite_score": report.get("composite_score"),
                 "k": report.get("k"),
                 "report_path": str(eval_report),
             }
@@ -807,12 +938,41 @@ with gr.Blocks(title="Math Conjecture Trainer Space") as demo:
         value="deepseek-ai/deepseek-math-v2",
     )
     with gr.Row():
-        start_stage = gr.Slider(label="Stage Start", minimum=1, maximum=
-        max_stages = gr.Slider(
+        start_stage = gr.Slider(label="Stage Start", minimum=1, maximum=TEMPLATE_STAGE_COUNT, step=1, value=1)
+        max_stages = gr.Slider(
+            label="Stage Count",
+            minimum=1,
+            maximum=TEMPLATE_STAGE_COUNT,
+            step=1,
+            value=TEMPLATE_STAGE_COUNT,
+        )
     run_eval = gr.Checkbox(label="Run Evaluation After Training", value=True)
     with gr.Row():
         eval_k = gr.Slider(label="Evaluation K", minimum=1, maximum=8, step=1, value=4)
         eval_samples = gr.Slider(label="Evaluation Max Samples", minimum=50, maximum=1000, step=50, value=300)
+    with gr.Row():
+        enforce_quality_gate = gr.Checkbox(label="Enforce Quality Gate", value=DEFAULT_GATE_ENABLED)
+        gate_min_pass_at_1 = gr.Slider(
+            label="Gate Min pass@1",
+            minimum=0.0,
+            maximum=0.5,
+            step=0.005,
+            value=min(max(DEFAULT_GATE_MIN_PASS_AT_1, 0.0), 0.5),
+        )
+        gate_min_pass_at_k = gr.Slider(
+            label="Gate Min pass@k",
+            minimum=0.0,
+            maximum=1.0,
+            step=0.01,
+            value=min(max(DEFAULT_GATE_MIN_PASS_AT_K, 0.0), 1.0),
+        )
+        gate_min_rows = gr.Slider(
+            label="Gate Min Rows",
+            minimum=10,
+            maximum=2000,
+            step=10,
+            value=min(max(DEFAULT_GATE_MIN_ROWS, 10), 2000),
+        )
     with gr.Row():
         push_to_hub = gr.Checkbox(label="Push Adapter to Hub", value=True)
         force_redownload = gr.Checkbox(label="Force Dataset Redownload", value=False)
@@ -843,6 +1003,10 @@ with gr.Blocks(title="Math Conjecture Trainer Space") as demo:
             run_eval,
             eval_k,
             eval_samples,
+            enforce_quality_gate,
+            gate_min_pass_at_1,
+            gate_min_pass_at_k,
+            gate_min_rows,
             push_to_hub,
             force_redownload,
             preflight_only,
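The summary-parsing logic above expects `training_summary.json` to carry `quality_gate`, `push`, and `post_eval` objects. The writer lives in `scripts/train_sota.py`, whose new summary-emitting code is truncated in this view, so the following is only an illustrative shape consistent with the keys `app.py` reads; every value is invented for the example:

```python
# Illustrative training_summary.json payload; field names mirror the keys
# app.py reads, values are made up for the example.
training_summary = {
    "quality_gate": {"enabled": True, "passed": True},
    "push": {"requested": True, "performed": True},
    "post_eval": {
        "evaluated_rows": 240,
        "pass_at_1": 0.02,
        "pass_at_k": 0.09,
        "exact_at_k": 0.04,
        "composite_score": 0.059,  # 0.30*0.02 + 0.50*0.09 + 0.20*0.04
        "k": 6,
        "report_path": "workspace/runs/math-conjecture-sota/post_eval_report.json",
    },
}
```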
configs/deepseek_math_sota.yaml
CHANGED
@@ -97,17 +97,55 @@ stages:
         - conjecture_core
       require_conjecture_id: true
     training:
-      num_train_epochs:
+      num_train_epochs: 2
       learning_rate: 5.0e-6
       save_steps: 100
       eval_steps: 100
 
+  - name: hard_case_polish
+    max_train_samples: 60000
+    max_eval_samples: 2000
+    filters:
+      include_families:
+        - conjecture_core
+        - formal_proof
+      require_conjecture_id: true
+      min_sample_weight: 3.0
+    training:
+      num_train_epochs: 1
+      learning_rate: 3.0e-6
+      gradient_accumulation_steps: 24
+      save_steps: 80
+      eval_steps: 80
+
+post_eval:
+  enabled: true
+  eval_file: workspace/data/releases/v1/test.parquet
+  max_samples: 240
+  k: 6
+  max_new_tokens: 320
+  temperature: 0.7
+  top_p: 0.95
+  seed: 17
+  output_json: workspace/runs/math-conjecture-sota/post_eval_report.json
+
+quality_gate:
+  enabled: true
+  require_post_eval: true
+  min_evaluated_rows: 120
+  min_pass_at_1: 0.01
+  min_pass_at_k: 0.06
+  max_final_eval_loss: 2.6
+  required_family_pass_at_k:
+    conjecture_core: 0.06
+    formal_proof: 0.03
+
 hub:
   push_to_hub: true
   repo_id: NorthernTribe-Research/math-conjecture-model
   private: false
   upload_stage_checkpoints: true
-  commit_message:
+  commit_message: Launch multi-stage DeepSeek-Math fine-tuning on Space GPU and push adapters to your model repo.
 
 credentials:
   path: huggingface-api-key.json
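To make the gate semantics concrete, here is a minimal sketch of how the `quality_gate` block above could be checked against a post-eval report. The function name and the evaluation code are illustrative assumptions, since the actual enforcement lives in the (partially truncated) `scripts/train_sota.py` below:

```python
from typing import Any, Dict


def quality_gate_passes(report: Dict[str, Any], gate: Dict[str, Any]) -> bool:
    """Illustrative check of the thresholds configured above."""
    if not gate.get("enabled", True):
        return True  # gate disabled: always promote
    if report.get("evaluated_rows", 0) < gate.get("min_evaluated_rows", 120):
        return False
    if report.get("pass_at_1", 0.0) < gate.get("min_pass_at_1", 0.01):
        return False
    if report.get("pass_at_k", 0.0) < gate.get("min_pass_at_k", 0.06):
        return False
    # Per-family floors, e.g. conjecture_core pass@k >= 0.06.
    family_metrics = report.get("family_metrics", {})
    for family, floor in gate.get("required_family_pass_at_k", {}).items():
        if family_metrics.get(family, {}).get("pass_at_k", 0.0) < floor:
            return False
    return True
```

On this reading, a failing gate blocks promotion even when a push was requested, which is consistent with the `requested`/`performed` push report the app logs.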
scripts/eval_sota.py
CHANGED
@@ -7,7 +7,7 @@ import argparse
 import json
 import re
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Sequence
+from typing import Any, Dict, List, Optional, Sequence, Tuple
 
 import torch
 import yaml
@@ -15,13 +15,20 @@ from datasets import load_dataset
 from peft import PeftModel
 from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
 
+SCRIPT_ROOT = Path(__file__).resolve().parents[1]
+DEFAULT_CONFIG_PATH = SCRIPT_ROOT / "configs" / "deepseek_math_sota.yaml"
+DEFAULT_OUTPUT_JSON = SCRIPT_ROOT / "runs" / "latest_eval_report.json"
+
+BOXED_RE = re.compile(r"\\boxed\{([^{}]+)\}")
+SPACE_RE = re.compile(r"\s+")
+
 
 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(description="Run pass@k-style evaluation on held-out split.")
     parser.add_argument(
         "--config",
         type=Path,
-        default=
+        default=DEFAULT_CONFIG_PATH,
         help="Training config used for prompt formatting defaults.",
     )
     parser.add_argument(
@@ -39,19 +46,32 @@
     parser.add_argument(
         "--eval-file",
         type=Path,
-        default=
-        help="Parquet split used for evaluation.",
+        default=None,
+        help="Parquet split used for evaluation (defaults to post_eval.eval_file or data.default_validation_file).",
     )
     parser.add_argument("--max-samples", type=int, default=300, help="Maximum evaluation rows.")
     parser.add_argument("--k", type=int, default=4, help="Number of sampled generations per prompt.")
     parser.add_argument("--max-new-tokens", type=int, default=256, help="Generation length cap.")
+    parser.add_argument("--max-input-length", type=int, default=4096, help="Prompt tokenization length cap.")
     parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature.")
     parser.add_argument("--top-p", type=float, default=0.95, help="Nucleus sampling p.")
     parser.add_argument("--seed", type=int, default=17, help="Random seed.")
+    parser.add_argument(
+        "--progress-every",
+        type=int,
+        default=25,
+        help="Print progress every N evaluated rows (0 disables).",
+    )
+    parser.add_argument(
+        "--sample-records",
+        type=int,
+        default=30,
+        help="How many sample records to store in report.",
+    )
     parser.add_argument(
         "--output-json",
         type=Path,
-        default=
+        default=DEFAULT_OUTPUT_JSON,
         help="Where to write evaluation report.",
     )
     return parser.parse_args()
@@ -65,6 +85,24 @@
     return str(value).strip()
 
 
+def as_float(value: Any, default: float) -> float:
+    if value is None:
+        return default
+    try:
+        return float(value)
+    except (TypeError, ValueError):
+        return default
+
+
+def as_int(value: Any, default: int) -> int:
+    if value is None:
+        return default
+    try:
+        return int(value)
+    except (TypeError, ValueError):
+        return default
+
+
 def load_config(path: Path) -> Dict[str, Any]:
     cfg = yaml.safe_load(path.read_text(encoding="utf-8"))
     if not isinstance(cfg, dict):
@@ -74,9 +112,124 @@
 
 def normalize_answer(text: str) -> str:
     text = text.strip().lower()
-    text = re.sub(r"\s+", " ", text)
     text = text.replace("$", "")
-
+    text = text.replace("\\left", "").replace("\\right", "")
+    text = text.replace("\\,", "").replace("\\!", "").replace("\\;", "")
+    text = SPACE_RE.sub(" ", text)
+    return text.strip(" .")
+
+
+def extract_boxed_values(text: str) -> List[str]:
+    return [normalize_answer(match) for match in BOXED_RE.findall(text or "") if normalize_answer(match)]
+
+
+def parse_numeric_value(text: str) -> Optional[float]:
+    normalized = normalize_answer(text)
+    if not normalized:
+        return None
+    candidate = normalized.replace(",", "")
+    if re.fullmatch(r"[-+]?\d+\s*/\s*[-+]?\d+", candidate):
+        left, right = candidate.split("/", maxsplit=1)
+        try:
+            numerator = float(left.strip())
+            denominator = float(right.strip())
+        except ValueError:
+            return None
+        if denominator == 0:
+            return None
+        return numerator / denominator
+    if re.fullmatch(r"[-+]?(?:\d+\.\d*|\d*\.\d+|\d+)(?:[eE][-+]?\d+)?", candidate):
+        try:
+            return float(candidate)
+        except ValueError:
+            return None
+    return None
+
+
+def approximately_equal(left: float, right: float) -> bool:
+    tolerance = 1e-6 * max(1.0, abs(left), abs(right))
+    return abs(left - right) <= tolerance
+
+
+def match_candidate(candidate: str, expected_values: Sequence[str]) -> Dict[str, Any]:
+    cand_norm = normalize_answer(candidate)
+    if not cand_norm:
+        return {
+            "match": False,
+            "exact": False,
+            "boxed": False,
+            "numeric": False,
+            "reason": "empty_candidate",
+        }
+
+    cand_boxed = extract_boxed_values(candidate)
+    cand_num = parse_numeric_value(cand_norm)
+
+    substring_hit = False
+    boxed_hit = False
+    numeric_hit = False
+
+    for expected in expected_values:
+        exp_norm = normalize_answer(expected)
+        if not exp_norm:
+            continue
+
+        if cand_norm == exp_norm:
+            return {
+                "match": True,
+                "exact": True,
+                "boxed": exp_norm in cand_boxed,
+                "numeric": False,
+                "reason": "exact",
+            }
+
+        if exp_norm in cand_norm or cand_norm in exp_norm:
+            substring_hit = True
+
+        expected_boxed = extract_boxed_values(expected)
+        for cand_box in cand_boxed:
+            if cand_box == exp_norm or exp_norm in cand_box or cand_box in exp_norm:
+                boxed_hit = True
+        for exp_box in expected_boxed:
+            if cand_norm == exp_box or exp_box in cand_norm or cand_norm in exp_box:
+                boxed_hit = True
+
+        exp_num = parse_numeric_value(exp_norm)
+        if cand_num is not None and exp_num is not None and approximately_equal(cand_num, exp_num):
+            numeric_hit = True
+
+    if boxed_hit:
+        return {
+            "match": True,
+            "exact": False,
+            "boxed": True,
+            "numeric": numeric_hit,
+            "reason": "boxed",
+        }
+    if numeric_hit:
+        return {
+            "match": True,
+            "exact": False,
+            "boxed": False,
+            "numeric": True,
+            "reason": "numeric",
+        }
+    if substring_hit:
+        return {
+            "match": True,
+            "exact": False,
+            "boxed": False,
+            "numeric": False,
+            "reason": "substring",
+        }
+
+    return {
+        "match": False,
+        "exact": False,
+        "boxed": False,
+        "numeric": False,
+        "reason": "no_match",
+    }
 
 
 def flatten_expected(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> List[str]:
@@ -168,27 +321,11 @@
     return full_generation.strip()
 
 
-def is_match(candidate: str, expected_values: Sequence[str]) -> bool:
-    cand_norm = normalize_answer(candidate)
-    if not cand_norm:
-        return False
-    for expected in expected_values:
-        exp_norm = normalize_answer(expected)
-        if not exp_norm:
-            continue
-        if exp_norm in cand_norm or cand_norm in exp_norm:
-            return True
-        boxed = re.findall(r"\\boxed\{([^{}]+)\}", cand_norm)
-        if boxed and any(exp_norm in item for item in boxed):
-            return True
-    return False
-
-
 def load_model_and_tokenizer(
     base_model: str,
     adapter_path: Optional[Path],
     trust_remote_code: bool,
-) ->
+) -> Tuple[Any, AutoTokenizer]:
     tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=trust_remote_code, use_fast=True)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
@@ -207,55 +344,126 @@
     return model, tokenizer
 
 
-def
-
-
+def make_bucket() -> Dict[str, Any]:
+    return {
+        "evaluated_rows": 0,
+        "pass_at_1_hits": 0,
+        "pass_at_k_hits": 0,
+        "exact_at_1_hits": 0,
+        "exact_at_k_hits": 0,
+        "boxed_at_k_hits": 0,
+    }
+
+
+def update_bucket(bucket: Dict[str, Any], hit1: bool, hitk: bool, exact1: bool, exactk: bool, boxedk: bool) -> None:
+    bucket["evaluated_rows"] += 1
+    if hit1:
+        bucket["pass_at_1_hits"] += 1
+    if hitk:
+        bucket["pass_at_k_hits"] += 1
+    if exact1:
+        bucket["exact_at_1_hits"] += 1
+    if exactk:
+        bucket["exact_at_k_hits"] += 1
+    if boxedk:
+        bucket["boxed_at_k_hits"] += 1
+
+
+def finalize_bucket(bucket: Dict[str, Any]) -> Dict[str, Any]:
+    total = max(int(bucket.get("evaluated_rows", 0)), 1)
+    rows = int(bucket.get("evaluated_rows", 0))
+    return {
+        "evaluated_rows": rows,
+        "pass_at_1": float(bucket.get("pass_at_1_hits", 0)) / total,
+        "pass_at_k": float(bucket.get("pass_at_k_hits", 0)) / total,
+        "exact_at_1": float(bucket.get("exact_at_1_hits", 0)) / total,
+        "exact_at_k": float(bucket.get("exact_at_k_hits", 0)) / total,
+        "boxed_at_k": float(bucket.get("boxed_at_k_hits", 0)) / total,
+    }
+
+
+def resolve_eval_file(arg_eval_file: Optional[Path], cfg: Dict[str, Any]) -> Path:
+    if arg_eval_file is not None:
+        return arg_eval_file
+    post_eval_cfg = cfg.get("post_eval", {})
     data_cfg = cfg.get("data", {})
-
-
+    for candidate in (
+        as_text(post_eval_cfg.get("eval_file")),
+        as_text(data_cfg.get("default_validation_file")),
+        "data/releases/v1/test.parquet",
+        "workspace/data/releases/v1/test.parquet",
+    ):
+        if not candidate:
+            continue
+        path = Path(candidate)
+        if path.exists():
+            return path
+    return Path("data/releases/v1/test.parquet")
 
+
+def run_evaluation(args: argparse.Namespace) -> Dict[str, Any]:
     if args.k < 1:
         raise ValueError("--k must be >= 1.")
     if args.max_samples < 1:
         raise ValueError("--max-samples must be >= 1.")
     if args.max_new_tokens < 1:
         raise ValueError("--max-new-tokens must be >= 1.")
+    if args.max_input_length < 128:
+        raise ValueError("--max-input-length must be >= 128.")
     if args.temperature <= 0:
         raise ValueError("--temperature must be > 0.")
     if not 0 < args.top_p <= 1:
         raise ValueError("--top-p must be in (0, 1].")
 
+    cfg = load_config(args.config)
+    data_cfg = cfg.get("data", {})
+    model_cfg = cfg.get("model", {})
+    set_seed(args.seed)
+
     base_model = args.base_model or as_text(model_cfg.get("base_model"))
     if not base_model:
         raise ValueError("Base model is required via --base-model or config.model.base_model.")
     if args.adapter_path is not None and not args.adapter_path.exists():
         raise FileNotFoundError(f"Adapter path not found: {args.adapter_path}")
 
+    eval_file = resolve_eval_file(args.eval_file, cfg)
+    if not eval_file.exists():
+        raise FileNotFoundError(f"Evaluation file not found: {eval_file}")
+
     model, tokenizer = load_model_and_tokenizer(
         base_model=base_model,
         adapter_path=args.adapter_path,
         trust_remote_code=bool(model_cfg.get("trust_remote_code", False)),
     )
 
-
-        raise FileNotFoundError(f"Evaluation file not found: {args.eval_file}")
-    ds = load_dataset("parquet", data_files={"eval": str(args.eval_file)})["eval"]
-
+    ds = load_dataset("parquet", data_files={"eval": str(eval_file)})["eval"]
     if args.max_samples > 0 and args.max_samples < len(ds):
         ds = ds.select(range(args.max_samples))
 
-
-
-
-
+    totals = make_bucket()
+    family_buckets: Dict[str, Dict[str, Any]] = {}
+    difficulty_buckets: Dict[str, Dict[str, Any]] = {}
+
+    processed_rows = 0
+    skipped_no_expected = 0
+    samples: List[Dict[str, Any]] = []
+
+    model_device = next(model.parameters()).device
+    prompt_field = as_text(data_cfg.get("prompt_field")) or "prompt"
 
     for row in ds:
         expected_values = flatten_expected(row, data_cfg)
         if not expected_values:
+            skipped_no_expected += 1
             continue
+
         prompt_text = build_prompt_text(row, tokenizer, data_cfg)
-        inputs = tokenizer(
-
+        inputs = tokenizer(
+            prompt_text,
+            return_tensors="pt",
+            truncation=True,
+            max_length=args.max_input_length,
+        )
         inputs = {k: v.to(model_device) for k, v in inputs.items()}
 
         with torch.no_grad():
@@ -269,44 +477,119 @@ def main() -> None:
                 pad_token_id=tokenizer.pad_token_id,
                 eos_token_id=tokenizer.eos_token_id,
             )
+
         generations = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         candidates = [extract_candidate_text(text, prompt_text) for text in generations]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        details = [match_candidate(candidate, expected_values) for candidate in candidates]
+
+        matches = [bool(item["match"]) for item in details]
+        exacts = [bool(item["exact"]) for item in details]
+        boxed = [bool(item["boxed"]) for item in details]
+
+        hit1 = bool(matches and matches[0])
+        hitk = bool(any(matches))
+        exact1 = bool(exacts and exacts[0])
+        exactk = bool(any(exacts))
+        boxedk = bool(any(boxed))
+
+        update_bucket(totals, hit1=hit1, hitk=hitk, exact1=exact1, exactk=exactk, boxedk=boxedk)
+
+        family = as_text(row.get("family")) or "__unknown__"
+        if family not in family_buckets:
+            family_buckets[family] = make_bucket()
+        update_bucket(family_buckets[family], hit1=hit1, hitk=hitk, exact1=exact1, exactk=exactk, boxedk=boxedk)
+
+        difficulty = as_text(row.get("difficulty")) or "__unknown__"
+        if difficulty not in difficulty_buckets:
+            difficulty_buckets[difficulty] = make_bucket()
+        update_bucket(
+            difficulty_buckets[difficulty],
+            hit1=hit1,
+            hitk=hitk,
+            exact1=exact1,
+            exactk=exactk,
+            boxedk=boxedk,
         )
 
-
-
-
+        processed_rows += 1
+        if args.progress_every > 0 and processed_rows % args.progress_every == 0:
+            print(f"Progress: evaluated_rows={processed_rows} latest_family={family}")
+
+        if len(samples) < args.sample_records:
+            samples.append(
+                {
+                    "uid": as_text(row.get("uid")),
+                    "family": family,
+                    "difficulty": difficulty,
+                    "prompt": as_text(row.get(prompt_field)),
+                    "expected_values": expected_values[:5],
+                    "candidates": candidates,
+                    "match_details": details,
+                    "matches": matches,
+                }
+            )
+
+    total_eval = int(totals.get("evaluated_rows", 0))
+    denominator = max(total_eval, 1)
+
+    pass_at_1 = float(totals.get("pass_at_1_hits", 0)) / denominator
+    pass_at_k = float(totals.get("pass_at_k_hits", 0)) / denominator
+    exact_at_1 = float(totals.get("exact_at_1_hits", 0)) / denominator
+    exact_at_k = float(totals.get("exact_at_k_hits", 0)) / denominator
+    boxed_at_k = float(totals.get("boxed_at_k_hits", 0)) / denominator
+
+    composite_score = 0.30 * pass_at_1 + 0.50 * pass_at_k + 0.20 * exact_at_k
+
+    report: Dict[str, Any] = {
         "base_model": base_model,
         "adapter_path": str(args.adapter_path) if args.adapter_path is not None else None,
-        "eval_file": str(
-        "
+        "eval_file": str(eval_file),
+        "config": str(args.config),
+        "evaluated_rows": total_eval,
+        "skipped_rows_without_targets": skipped_no_expected,
+        "requested_rows": len(ds),
         "k": args.k,
         "pass_at_1": pass_at_1,
         "pass_at_k": pass_at_k,
+        "exact_at_1": exact_at_1,
+        "exact_at_k": exact_at_k,
+        "boxed_at_k": boxed_at_k,
+        "composite_score": composite_score,
         "temperature": args.temperature,
        "top_p": args.top_p,
         "max_new_tokens": args.max_new_tokens,
-        "
+        "max_input_length": args.max_input_length,
+        "seed": args.seed,
+        "family_metrics": {
+            key: finalize_bucket(family_buckets[key])
+            for key in sorted(family_buckets.keys())
+        },
+        "difficulty_metrics": {
+            key: finalize_bucket(difficulty_buckets[key])
+            for key in sorted(difficulty_buckets.keys())
+        },
+        "samples": samples,
     }
+
     args.output_json.parent.mkdir(parents=True, exist_ok=True)
     args.output_json.write_text(json.dumps(report, ensure_ascii=True, indent=2), encoding="utf-8")
-
+
+    summary_view = {
+        "evaluated_rows": total_eval,
+        "pass_at_1": pass_at_1,
+        "pass_at_k": pass_at_k,
+        "exact_at_k": exact_at_k,
+        "composite_score": composite_score,
+        "k": args.k,
+    }
+    print(json.dumps(summary_view, indent=2))
     print(f"Saved report to {args.output_json}")
+    return report
+
+
+def main() -> None:
+    args = parse_args()
+    run_evaluation(args)
 
 
 if __name__ == "__main__":
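Two things in the report are worth pinning down. `pass@k` here is the empirical any-hit rate over the `k` sampled generations per prompt (a row counts once at least one candidate matches), not the unbiased combinatorial pass@k estimator, and the composite score is the fixed weighting hard-coded above:

```latex
\text{pass@}k \;=\; \frac{1}{N}\sum_{i=1}^{N}\mathbf{1}\!\left[\exists\, j \le k :\ \text{match}(c_{i,j})\right],
\qquad
\text{composite} \;=\; 0.30\,\text{pass@}1 \;+\; 0.50\,\text{pass@}k \;+\; 0.20\,\text{exact@}k .
```

A typical standalone invocation against the trained adapter would be along the lines of `python scripts/eval_sota.py --adapter-path workspace/runs/math-conjecture-sota/final_adapter --k 6 --max-samples 240`; the `--adapter-path` and `--base-model` flag names are inferred from `args.adapter_path` and `args.base_model` in the code rather than from parser lines visible in this diff.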
scripts/train_sota.py
CHANGED
@@ -4,10 +4,13 @@
 from __future__ import annotations
 
 import argparse
 import json
 import os
 from pathlib import Path
-from typing import Any, Dict, Optional, Tuple
 
 import torch
 import yaml
@@ -25,7 +28,9 @@ from transformers import (
     set_seed,
 )
 
-
 
 
 def parse_args() -> argparse.Namespace:
@@ -41,6 +46,21 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument("--repo-id", type=str, default=None, help="Override hub.repo_id.")
     parser.add_argument("--push-to-hub", action="store_true", help="Force push enabled.")
     parser.add_argument("--no-push-to-hub", action="store_true", help="Force push disabled.")
     parser.add_argument(
         "--start-stage",
         type=int,
@@ -93,6 +113,19 @@ def as_int(value: Any, default: int) -> int:
     return default
 
 
 def load_config(path: Path) -> Dict[str, Any]:
     if not path.exists():
         raise FileNotFoundError(f"Config not found: {path}")
@@ -108,6 +141,8 @@ def load_config(path: Path) -> Dict[str, Any]:
     cfg.setdefault("training_defaults", {})
     cfg.setdefault("hub", {})
     cfg.setdefault("credentials", {})
     return cfg
 
 
@@ -123,6 +158,16 @@ def apply_overrides(cfg: Dict[str, Any], args: argparse.Namespace) -> None:
     if args.no_push_to_hub:
         cfg.setdefault("hub", {})["push_to_hub"] = False
 
 
 def resolve_auth(cfg: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
     token = as_text(os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")) or None
@@ -133,9 +178,17 @@ def resolve_auth(cfg: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
     if path.exists():
         data = json.loads(path.read_text(encoding="utf-8"))
         if token is None:
-
         if username is None:
-
     return token, username
 
 
@@ -556,6 +609,228 @@ def push_folder(
     api.upload_folder(**kwargs)
 
 
 def main() -> None:
     args = parse_args()
     cfg = load_config(args.config)
@@ -564,17 +839,17 @@ def main() -> None:
     seed = as_int(cfg.get("global", {}).get("seed"), 17)
     set_seed(seed)
 
-    output_root = Path(as_text(cfg.get("global", {}).get("output_root")) or "
     output_root.mkdir(parents=True, exist_ok=True)
 
     token, username = resolve_auth(cfg)
     repo_id = resolve_repo_id(cfg, username=username, output_root=output_root)
-
-    if args.dry_run and
         print("Dry-run enabled. Disabling push_to_hub for this run.")
-
-
-    if
         if token is None:
             raise ValueError("Hub push requested but token is missing.")
         if repo_id is None:
@@ -585,8 +860,9 @@ def main() -> None:
         model = None
     else:
         model, tokenizer = build_model_and_tokenizer(cfg["model"], cfg.get("training_defaults", {}))
     data_cfg = cfg["data"]
-    stage_reports = []
 
     start_stage = max(1, args.start_stage)
     stages = cfg["stages"]
@@ -607,15 +883,18 @@ def main() -> None:
         raw = load_dataset("parquet", data_files=split_files)
         train_rows_before = len(raw["train"])
         valid_rows_before = len(raw["validation"])
         filters = stage.get("filters", {})
         raw["train"] = apply_filters(raw["train"], filters)
         raw["validation"] = apply_filters(raw["validation"], filters)
         train_rows_after_filter = len(raw["train"])
         valid_rows_after_filter = len(raw["validation"])
         raw["train"] = maybe_select(raw["train"], stage.get("max_train_samples"))
         raw["validation"] = maybe_select(raw["validation"], stage.get("max_eval_samples"))
         train_rows_selected = len(raw["train"])
         valid_rows_selected = len(raw["validation"])
         print(
             f"[stage {index}] rows train: {train_rows_before} -> {train_rows_after_filter} -> {train_rows_selected}; "
             f"validation: {valid_rows_before} -> {valid_rows_after_filter} -> {valid_rows_selected}"
@@ -627,19 +906,20 @@ def main() -> None:
             sample_row = raw["train"][0]
             _ = build_prompt_text(sample_row, tokenizer, data_cfg)
             _ = build_answer_block(sample_row, data_cfg)
-
-
-
-
-
-
-
-
-
-
-
-
-
             print(f"[stage {index}] Dry-run checks passed.")
             continue
@@ -670,33 +950,36 @@ def main() -> None:
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         eval_metrics = None
         if eval_dataset is not None:
             eval_metrics = trainer.evaluate()
             trainer.log_metrics("eval", eval_metrics)
             trainer.save_metrics("eval", eval_metrics)
         trainer.save_model(str(stage_output_dir))
         tokenizer.save_pretrained(str(stage_output_dir))
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
         print(
-            f"[stage {index}] Completed: train_rows={
-            f"eval_rows={
         )
 
         if args.dry_run:
@@ -720,17 +1003,59 @@ def main() -> None:
     model.save_pretrained(str(final_dir))
     tokenizer.save_pretrained(str(final_dir))
 
-
         "config_path": str(args.config),
         "repo_id": repo_id,
|
@@ -720,17 +1003,59 @@ def main() -> None:
|
|
| 720 |
model.save_pretrained(str(final_dir))
|
| 721 |
tokenizer.save_pretrained(str(final_dir))
|
| 722 |
|
| 723 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 724 |
"config_path": str(args.config),
|
| 725 |
"repo_id": repo_id,
|
| 726 |
"seed": seed,
|
| 727 |
"stages_ran": stage_reports,
|
| 728 |
"final_adapter_dir": str(final_dir),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 729 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 730 |
summary_path = output_root / "training_summary.json"
|
| 731 |
summary_path.write_text(json.dumps(summary, ensure_ascii=True, indent=2), encoding="utf-8")
|
| 732 |
|
| 733 |
-
if
|
| 734 |
api = HfApi(token=token)
|
| 735 |
api.create_repo(
|
| 736 |
repo_id=repo_id,
|
|
@@ -740,17 +1065,22 @@ def main() -> None:
|
|
| 740 |
)
|
| 741 |
commit_message = as_text(cfg.get("hub", {}).get("commit_message")) or "Upload SOTA curriculum adapter."
|
| 742 |
push_folder(api, repo_id, final_dir, commit_message=commit_message)
|
|
|
|
| 743 |
if bool(cfg.get("hub", {}).get("upload_stage_checkpoints", False)):
|
| 744 |
for report in stage_reports:
|
| 745 |
-
|
| 746 |
-
|
|
|
|
|
|
|
|
|
|
| 747 |
push_folder(
|
| 748 |
api,
|
| 749 |
repo_id,
|
| 750 |
stage_dir,
|
| 751 |
-
commit_message=f"Upload stage checkpoint {report
|
| 752 |
path_in_repo=path_in_repo,
|
| 753 |
)
|
|
|
|
| 754 |
api.upload_file(
|
| 755 |
path_or_fileobj=str(summary_path),
|
| 756 |
path_in_repo="training_summary.json",
|
|
@@ -758,6 +1088,16 @@ def main() -> None:
|
|
| 758 |
repo_type="model",
|
| 759 |
commit_message="Upload training summary for SOTA curriculum run.",
|
| 760 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 761 |
print(f"Pushed training artifacts to https://huggingface.co/{repo_id}")
|
| 762 |
|
| 763 |
print(f"Training complete. Final adapter: {final_dir}")
|
|
|
|
| 4 |
from __future__ import annotations
|
| 5 |
|
| 6 |
import argparse
|
| 7 |
+
import gc
|
| 8 |
import json
|
| 9 |
import os
|
| 10 |
+
import subprocess
|
| 11 |
+
import sys
|
| 12 |
from pathlib import Path
|
| 13 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 14 |
|
| 15 |
import torch
|
| 16 |
import yaml
|
|
|
|
| 28 |
set_seed,
|
| 29 |
)
|
| 30 |
|
| 31 |
+
SCRIPT_ROOT = Path(__file__).resolve().parents[1]
|
| 32 |
+
DEFAULT_CONFIG_PATH = SCRIPT_ROOT / "configs" / "deepseek_math_sota.yaml"
|
| 33 |
+
DEFAULT_EVAL_SCRIPT = Path(__file__).resolve().with_name("eval_sota.py")
|
| 34 |
|
| 35 |
|
| 36 |
def parse_args() -> argparse.Namespace:
|
|
|
|
| 46 |
parser.add_argument("--repo-id", type=str, default=None, help="Override hub.repo_id.")
|
| 47 |
parser.add_argument("--push-to-hub", action="store_true", help="Force push enabled.")
|
| 48 |
parser.add_argument("--no-push-to-hub", action="store_true", help="Force push disabled.")
|
| 49 |
+
parser.add_argument(
|
| 50 |
+
"--run-post-eval",
|
| 51 |
+
action="store_true",
|
| 52 |
+
help="Force post-training evaluation enabled.",
|
| 53 |
+
)
|
| 54 |
+
parser.add_argument(
|
| 55 |
+
"--no-post-eval",
|
| 56 |
+
action="store_true",
|
| 57 |
+
help="Force post-training evaluation disabled.",
|
| 58 |
+
)
|
| 59 |
+
parser.add_argument(
|
| 60 |
+
"--skip-quality-gate",
|
| 61 |
+
action="store_true",
|
| 62 |
+
help="Disable quality gate checks for this run.",
|
| 63 |
+
)
|
| 64 |
parser.add_argument(
|
| 65 |
"--start-stage",
|
| 66 |
type=int,
|
|
|
|
| 113 |
return default
|
| 114 |
|
| 115 |
|
| 116 |
+
def as_bool(value: Any, default: bool = False) -> bool:
|
| 117 |
+
if value is None:
|
| 118 |
+
return default
|
| 119 |
+
if isinstance(value, bool):
|
| 120 |
+
return value
|
| 121 |
+
text = as_text(value).lower()
|
| 122 |
+
if text in {"1", "true", "yes", "y", "on"}:
|
| 123 |
+
return True
|
| 124 |
+
if text in {"0", "false", "no", "n", "off"}:
|
| 125 |
+
return False
|
| 126 |
+
return default
|
| 127 |
+
|
| 128 |
+
|
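The new `as_bool` helper normalizes the truthy/falsy spellings that tend to show up in YAML and environment variables; anything unrecognized falls back to the caller's default. A doctest-style sketch of that contract (assuming `as_bool` from this module is in scope):

```python
# Contract of as_bool as defined above: booleans pass through, common
# string spellings are matched case-insensitively, everything else
# returns the supplied default.
assert as_bool(True) is True
assert as_bool("YES") is True             # lowered to "yes" before matching
assert as_bool("off") is False
assert as_bool(None, default=True) is True
assert as_bool("maybe", default=False) is False
```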
| 129 |
def load_config(path: Path) -> Dict[str, Any]:
|
| 130 |
if not path.exists():
|
| 131 |
raise FileNotFoundError(f"Config not found: {path}")
|
|
|
|
| 141 |
cfg.setdefault("training_defaults", {})
|
| 142 |
cfg.setdefault("hub", {})
|
| 143 |
cfg.setdefault("credentials", {})
|
| 144 |
+
cfg.setdefault("post_eval", {})
|
| 145 |
+
cfg.setdefault("quality_gate", {})
|
| 146 |
return cfg
|
| 147 |
|
| 148 |
|
|
|
|
| 158 |
if args.no_push_to_hub:
|
| 159 |
cfg.setdefault("hub", {})["push_to_hub"] = False
|
| 160 |
|
| 161 |
+
if args.run_post_eval and args.no_post_eval:
|
| 162 |
+
raise ValueError("Cannot set both --run-post-eval and --no-post-eval.")
|
| 163 |
+
if args.run_post_eval:
|
| 164 |
+
cfg.setdefault("post_eval", {})["enabled"] = True
|
| 165 |
+
if args.no_post_eval:
|
| 166 |
+
cfg.setdefault("post_eval", {})["enabled"] = False
|
| 167 |
+
|
| 168 |
+
if args.skip_quality_gate:
|
| 169 |
+
cfg.setdefault("quality_gate", {})["enabled"] = False
|
| 170 |
+
|
| 171 |
|
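Together, `--run-post-eval`/`--no-post-eval` and `--skip-quality-gate` map directly onto `post_eval.enabled` and `quality_gate.enabled`, with the contradictory pair rejected before anything runs. A minimal restatement of that mapping in isolation (plain booleans stand in for the parsed `args`):

```python
# The new override branches from apply_overrides, restated standalone.
run_post_eval, no_post_eval, skip_quality_gate = True, False, True

cfg: dict = {}
if run_post_eval and no_post_eval:
    raise ValueError("Cannot set both --run-post-eval and --no-post-eval.")
if run_post_eval:
    cfg.setdefault("post_eval", {})["enabled"] = True
if no_post_eval:
    cfg.setdefault("post_eval", {})["enabled"] = False
if skip_quality_gate:
    cfg.setdefault("quality_gate", {})["enabled"] = False

assert cfg == {"post_eval": {"enabled": True}, "quality_gate": {"enabled": False}}
```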
| 172 |
def resolve_auth(cfg: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
|
| 173 |
token = as_text(os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")) or None
|
|
|
|
| 178 |
if path.exists():
|
| 179 |
data = json.loads(path.read_text(encoding="utf-8"))
|
| 180 |
if token is None:
|
| 181 |
+
for key in ("token", "key", "api_key", "hf_token"):
|
| 182 |
+
candidate = as_text(data.get(key))
|
| 183 |
+
if candidate:
|
| 184 |
+
token = candidate
|
| 185 |
+
break
|
| 186 |
if username is None:
|
| 187 |
+
for key in ("username", "user", "owner"):
|
| 188 |
+
candidate = as_text(data.get(key))
|
| 189 |
+
if candidate:
|
| 190 |
+
username = candidate
|
| 191 |
+
break
|
| 192 |
return token, username
|
| 193 |
|
| 194 |
|
|
|
|
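The credential-file fallback now scans several common field names (`token`/`key`/`api_key`/`hf_token` and `username`/`user`/`owner`), so hand-written files keep working. A sketch of a `huggingface-api-key.json` this logic accepts (the values are placeholders, not real credentials):

```python
# Write a credentials file in one of the shapes resolve_auth() scans.
import json
from pathlib import Path

Path("huggingface-api-key.json").write_text(
    json.dumps(
        {
            "token": "hf_xxxxxxxx",          # placeholder token
            "username": "your-hf-username",  # placeholder account name
        },
        indent=2,
    ),
    encoding="utf-8",
)
```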
| 609 |
api.upload_folder(**kwargs)
|
| 610 |
|
| 611 |
|
| 612 |
+
def extract_final_eval_loss(stage_reports: List[Dict[str, Any]]) -> Optional[float]:
|
| 613 |
+
for report in reversed(stage_reports):
|
| 614 |
+
eval_metrics = report.get("eval_metrics")
|
| 615 |
+
if not isinstance(eval_metrics, dict):
|
| 616 |
+
continue
|
| 617 |
+
value = eval_metrics.get("eval_loss")
|
| 618 |
+
if value is None:
|
| 619 |
+
continue
|
| 620 |
+
try:
|
| 621 |
+
return float(value)
|
| 622 |
+
except (TypeError, ValueError):
|
| 623 |
+
continue
|
| 624 |
+
return None
|
| 625 |
+
|
| 626 |
+
|
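`extract_final_eval_loss` walks the stage reports in reverse and returns the loss from the latest stage that actually recorded one, so dry-run entries and stages trained without a validation split are skipped. A quick check of that behavior (assuming the helper above is in scope):

```python
# The reverse scan picks the eval_loss of the most recent stage that has one.
reports = [
    {"eval_metrics": {"eval_loss": 1.9}},   # earlier stage
    {"eval_metrics": None},                 # stage without an eval split
    {"eval_metrics": {"eval_loss": 0.9}},   # latest stage with metrics
]
assert extract_final_eval_loss(reports) == 0.9
assert extract_final_eval_loss([{"eval_metrics": None}]) is None
```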
| 627 |
+
def release_model_memory(model: Any) -> None:
|
| 628 |
+
try:
|
| 629 |
+
model.to("cpu")
|
| 630 |
+
except Exception:
|
| 631 |
+
pass
|
| 632 |
+
if torch.cuda.is_available():
|
| 633 |
+
torch.cuda.empty_cache()
|
| 634 |
+
gc.collect()
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
def run_post_eval(
|
| 638 |
+
cfg: Dict[str, Any],
|
| 639 |
+
config_path: Path,
|
| 640 |
+
output_root: Path,
|
| 641 |
+
final_adapter_dir: Path,
|
| 642 |
+
) -> Optional[Dict[str, Any]]:
|
| 643 |
+
post_cfg = cfg.get("post_eval", {})
|
| 644 |
+
if not as_bool(post_cfg.get("enabled"), False):
|
| 645 |
+
return None
|
| 646 |
+
|
| 647 |
+
eval_script = DEFAULT_EVAL_SCRIPT
|
| 648 |
+
if not eval_script.exists():
|
| 649 |
+
raise FileNotFoundError(f"Post-eval enabled but eval script is missing: {eval_script}")
|
| 650 |
+
|
| 651 |
+
data_cfg = cfg.get("data", {})
|
| 652 |
+
eval_file = Path(
|
| 653 |
+
as_text(post_cfg.get("eval_file"))
|
| 654 |
+
or as_text(data_cfg.get("default_validation_file"))
|
| 655 |
+
or "data/releases/v1/test.parquet"
|
| 656 |
+
)
|
| 657 |
+
if not eval_file.exists():
|
| 658 |
+
raise FileNotFoundError(f"Post-eval file not found: {eval_file}")
|
| 659 |
+
|
| 660 |
+
output_json = Path(as_text(post_cfg.get("output_json")) or str(output_root / "post_eval_report.json"))
|
| 661 |
+
base_model = as_text(cfg.get("model", {}).get("base_model"))
|
| 662 |
+
if not base_model:
|
| 663 |
+
raise ValueError("model.base_model is required for post-eval.")
|
| 664 |
+
|
| 665 |
+
cmd = [
|
| 666 |
+
sys.executable,
|
| 667 |
+
str(eval_script),
|
| 668 |
+
"--config",
|
| 669 |
+
str(config_path),
|
| 670 |
+
"--base-model",
|
| 671 |
+
base_model,
|
| 672 |
+
"--adapter-path",
|
| 673 |
+
str(final_adapter_dir),
|
| 674 |
+
"--eval-file",
|
| 675 |
+
str(eval_file),
|
| 676 |
+
"--max-samples",
|
| 677 |
+
str(as_int(post_cfg.get("max_samples"), 300)),
|
| 678 |
+
"--k",
|
| 679 |
+
str(as_int(post_cfg.get("k"), 4)),
|
| 680 |
+
"--max-new-tokens",
|
| 681 |
+
str(as_int(post_cfg.get("max_new_tokens"), 256)),
|
| 682 |
+
"--temperature",
|
| 683 |
+
str(as_float(post_cfg.get("temperature"), 0.7)),
|
| 684 |
+
"--top-p",
|
| 685 |
+
str(as_float(post_cfg.get("top_p"), 0.95)),
|
| 686 |
+
"--seed",
|
| 687 |
+
str(as_int(post_cfg.get("seed"), as_int(cfg.get("global", {}).get("seed"), 17))),
|
| 688 |
+
"--output-json",
|
| 689 |
+
str(output_json),
|
| 690 |
+
]
|
| 691 |
+
print(f"Running post-training eval: {' '.join(cmd)}")
|
| 692 |
+
completed = subprocess.run(cmd, check=False)
|
| 693 |
+
if completed.returncode != 0:
|
| 694 |
+
raise RuntimeError(f"Post-training evaluation failed with exit code {completed.returncode}.")
|
| 695 |
+
|
| 696 |
+
if not output_json.exists():
|
| 697 |
+
raise FileNotFoundError(f"Post-eval report was not created: {output_json}")
|
| 698 |
+
|
| 699 |
+
report = json.loads(output_json.read_text(encoding="utf-8"))
|
| 700 |
+
return {
|
| 701 |
+
"enabled": True,
|
| 702 |
+
"report_path": str(output_json),
|
| 703 |
+
"report": report,
|
| 704 |
+
"command": cmd,
|
| 705 |
+
}
|
| 706 |
+
|
| 707 |
+
|
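`run_post_eval` is driven entirely by the `post_eval` block of the runtime config; every generation knob is forwarded to `eval_sota.py` as a CLI flag. The keys it reads, shown here as the Python dict the YAML block deserializes to, with the fallback values the code applies when a key is absent (`data.default_validation_file` is consulted before the parquet fallback; these are defaults, not recommendations):

```python
# post_eval keys consumed by run_post_eval(), with in-code fallbacks.
post_eval = {
    "enabled": True,                               # must be set; eval is off by default
    "eval_file": "data/releases/v1/test.parquet",  # last-resort fallback path
    "output_json": "runs/math-conjecture-sota/post_eval_report.json",
    "max_samples": 300,
    "k": 4,                     # completions per problem for pass@k
    "max_new_tokens": 256,
    "temperature": 0.7,
    "top_p": 0.95,
    "seed": 17,                 # falls back to global.seed
}
```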
| 708 |
+
def evaluate_quality_gate(
|
| 709 |
+
stage_reports: List[Dict[str, Any]],
|
| 710 |
+
post_eval_result: Optional[Dict[str, Any]],
|
| 711 |
+
gate_cfg: Dict[str, Any],
|
| 712 |
+
) -> Dict[str, Any]:
|
| 713 |
+
enabled = as_bool(gate_cfg.get("enabled"), False)
|
| 714 |
+
result: Dict[str, Any] = {
|
| 715 |
+
"enabled": enabled,
|
| 716 |
+
"passed": True,
|
| 717 |
+
"violations": [],
|
| 718 |
+
"checks": [],
|
| 719 |
+
}
|
| 720 |
+
if not enabled:
|
| 721 |
+
return result
|
| 722 |
+
|
| 723 |
+
violations: List[str] = []
|
| 724 |
+
checks: List[Dict[str, Any]] = []
|
| 725 |
+
|
| 726 |
+
final_eval_loss = extract_final_eval_loss(stage_reports)
|
| 727 |
+
max_final_eval_loss = gate_cfg.get("max_final_eval_loss")
|
| 728 |
+
if max_final_eval_loss is not None:
|
| 729 |
+
threshold = as_float(max_final_eval_loss, 0.0)
|
| 730 |
+
checks.append(
|
| 731 |
+
{
|
| 732 |
+
"name": "max_final_eval_loss",
|
| 733 |
+
"actual": final_eval_loss,
|
| 734 |
+
"threshold": threshold,
|
| 735 |
+
}
|
| 736 |
+
)
|
| 737 |
+
if final_eval_loss is None:
|
| 738 |
+
violations.append("Final stage eval_loss is missing for max_final_eval_loss gate.")
|
| 739 |
+
elif final_eval_loss > threshold:
|
| 740 |
+
violations.append(
|
| 741 |
+
f"Final eval_loss {final_eval_loss:.4f} exceeds threshold {threshold:.4f}."
|
| 742 |
+
)
|
| 743 |
+
|
| 744 |
+
report: Optional[Dict[str, Any]] = None
|
| 745 |
+
if isinstance(post_eval_result, dict):
|
| 746 |
+
loaded = post_eval_result.get("report")
|
| 747 |
+
if isinstance(loaded, dict):
|
| 748 |
+
report = loaded
|
| 749 |
+
|
| 750 |
+
require_post_eval = as_bool(gate_cfg.get("require_post_eval"), False)
|
| 751 |
+
if report is None:
|
| 752 |
+
if require_post_eval:
|
| 753 |
+
violations.append("Quality gate requires post-eval metrics, but post-eval report is missing.")
|
| 754 |
+
else:
|
| 755 |
+
evaluated_rows = as_int(report.get("evaluated_rows"), 0)
|
| 756 |
+
min_rows = as_int(gate_cfg.get("min_evaluated_rows"), 0)
|
| 757 |
+
checks.append(
|
| 758 |
+
{
|
| 759 |
+
"name": "min_evaluated_rows",
|
| 760 |
+
"actual": evaluated_rows,
|
| 761 |
+
"threshold": min_rows,
|
| 762 |
+
}
|
| 763 |
+
)
|
| 764 |
+
if evaluated_rows < min_rows:
|
| 765 |
+
violations.append(
|
| 766 |
+
f"Post-eval rows {evaluated_rows} is below minimum required {min_rows}."
|
| 767 |
+
)
|
| 768 |
+
|
| 769 |
+
min_pass_at_1_raw = gate_cfg.get("min_pass_at_1")
|
| 770 |
+
if min_pass_at_1_raw is not None:
|
| 771 |
+
min_pass_at_1 = as_float(min_pass_at_1_raw, 0.0)
|
| 772 |
+
pass_at_1 = as_float(report.get("pass_at_1"), 0.0)
|
| 773 |
+
checks.append(
|
| 774 |
+
{
|
| 775 |
+
"name": "min_pass_at_1",
|
| 776 |
+
"actual": pass_at_1,
|
| 777 |
+
"threshold": min_pass_at_1,
|
| 778 |
+
}
|
| 779 |
+
)
|
| 780 |
+
if pass_at_1 < min_pass_at_1:
|
| 781 |
+
violations.append(
|
| 782 |
+
f"pass@1 {pass_at_1:.4f} is below threshold {min_pass_at_1:.4f}."
|
| 783 |
+
)
|
| 784 |
+
|
| 785 |
+
min_pass_at_k_raw = gate_cfg.get("min_pass_at_k")
|
| 786 |
+
if min_pass_at_k_raw is not None:
|
| 787 |
+
min_pass_at_k = as_float(min_pass_at_k_raw, 0.0)
|
| 788 |
+
pass_at_k = as_float(report.get("pass_at_k"), 0.0)
|
| 789 |
+
checks.append(
|
| 790 |
+
{
|
| 791 |
+
"name": "min_pass_at_k",
|
| 792 |
+
"actual": pass_at_k,
|
| 793 |
+
"threshold": min_pass_at_k,
|
| 794 |
+
}
|
| 795 |
+
)
|
| 796 |
+
if pass_at_k < min_pass_at_k:
|
| 797 |
+
violations.append(
|
| 798 |
+
f"pass@k {pass_at_k:.4f} is below threshold {min_pass_at_k:.4f}."
|
| 799 |
+
)
|
| 800 |
+
|
| 801 |
+
family_requirements = gate_cfg.get("required_family_pass_at_k", {})
|
| 802 |
+
family_metrics = report.get("family_metrics", {})
|
| 803 |
+
if isinstance(family_requirements, dict):
|
| 804 |
+
for family, threshold_raw in family_requirements.items():
|
| 805 |
+
threshold = as_float(threshold_raw, 0.0)
|
| 806 |
+
actual = None
|
| 807 |
+
if isinstance(family_metrics, dict):
|
| 808 |
+
family_row = family_metrics.get(family)
|
| 809 |
+
if isinstance(family_row, dict):
|
| 810 |
+
try:
|
| 811 |
+
actual = float(family_row.get("pass_at_k"))
|
| 812 |
+
except (TypeError, ValueError):
|
| 813 |
+
actual = None
|
| 814 |
+
checks.append(
|
| 815 |
+
{
|
| 816 |
+
"name": f"family_pass_at_k:{family}",
|
| 817 |
+
"actual": actual,
|
| 818 |
+
"threshold": threshold,
|
| 819 |
+
}
|
| 820 |
+
)
|
| 821 |
+
if actual is None:
|
| 822 |
+
violations.append(f"Missing pass@k metric for required family '{family}'.")
|
| 823 |
+
elif actual < threshold:
|
| 824 |
+
violations.append(
|
| 825 |
+
f"Family '{family}' pass@k {actual:.4f} is below threshold {threshold:.4f}."
|
| 826 |
+
)
|
| 827 |
+
|
| 828 |
+
result["violations"] = violations
|
| 829 |
+
result["checks"] = checks
|
| 830 |
+
result["passed"] = len(violations) == 0
|
| 831 |
+
return result
|
| 832 |
+
|
| 833 |
+
|
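`evaluate_quality_gate` never raises; it returns a report with `passed`, the individual `checks`, and human-readable `violations`, and the caller decides whether to block the push. A small driver sketch with a hand-built post-eval report (the thresholds and family name are illustrative only; assumes the function above is in scope):

```python
# Illustrative gate run: one trained stage plus a synthetic post-eval report.
stage_reports = [{"eval_metrics": {"eval_loss": 0.82}}]
post_eval_result = {
    "report": {
        "evaluated_rows": 300,
        "pass_at_1": 0.41,
        "pass_at_k": 0.63,
        "family_metrics": {"number_theory": {"pass_at_k": 0.55}},  # example family
    }
}
gate_cfg = {
    "enabled": True,
    "max_final_eval_loss": 1.0,
    "require_post_eval": True,
    "min_evaluated_rows": 200,
    "min_pass_at_1": 0.35,
    "min_pass_at_k": 0.60,
    "required_family_pass_at_k": {"number_theory": 0.50},
}
result = evaluate_quality_gate(stage_reports, post_eval_result, gate_cfg)
assert result["passed"], result["violations"]
```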
| 834 |
def main() -> None:
|
| 835 |
args = parse_args()
|
| 836 |
cfg = load_config(args.config)
|
|
|
|
| 839 |
seed = as_int(cfg.get("global", {}).get("seed"), 17)
|
| 840 |
set_seed(seed)
|
| 841 |
|
| 842 |
+
output_root = Path(as_text(cfg.get("global", {}).get("output_root")) or "runs/math-conjecture-sota")
|
| 843 |
output_root.mkdir(parents=True, exist_ok=True)
|
| 844 |
|
| 845 |
token, username = resolve_auth(cfg)
|
| 846 |
repo_id = resolve_repo_id(cfg, username=username, output_root=output_root)
|
| 847 |
+
push_to_hub_requested = bool(cfg.get("hub", {}).get("push_to_hub", False))
|
| 848 |
+
if args.dry_run and push_to_hub_requested:
|
| 849 |
print("Dry-run enabled. Disabling push_to_hub for this run.")
|
| 850 |
+
push_to_hub_requested = push_to_hub_requested and not args.dry_run
|
| 851 |
+
|
| 852 |
+
if push_to_hub_requested:
|
| 853 |
if token is None:
|
| 854 |
raise ValueError("Hub push requested but token is missing.")
|
| 855 |
if repo_id is None:
|
|
|
|
| 860 |
model = None
|
| 861 |
else:
|
| 862 |
model, tokenizer = build_model_and_tokenizer(cfg["model"], cfg.get("training_defaults", {}))
|
| 863 |
+
|
| 864 |
data_cfg = cfg["data"]
|
| 865 |
+
stage_reports: List[Dict[str, Any]] = []
|
| 866 |
|
| 867 |
start_stage = max(1, args.start_stage)
|
| 868 |
stages = cfg["stages"]
|
|
|
|
| 883 |
raw = load_dataset("parquet", data_files=split_files)
|
| 884 |
train_rows_before = len(raw["train"])
|
| 885 |
valid_rows_before = len(raw["validation"])
|
| 886 |
+
|
| 887 |
filters = stage.get("filters", {})
|
| 888 |
raw["train"] = apply_filters(raw["train"], filters)
|
| 889 |
raw["validation"] = apply_filters(raw["validation"], filters)
|
| 890 |
train_rows_after_filter = len(raw["train"])
|
| 891 |
valid_rows_after_filter = len(raw["validation"])
|
| 892 |
+
|
| 893 |
raw["train"] = maybe_select(raw["train"], stage.get("max_train_samples"))
|
| 894 |
raw["validation"] = maybe_select(raw["validation"], stage.get("max_eval_samples"))
|
| 895 |
train_rows_selected = len(raw["train"])
|
| 896 |
valid_rows_selected = len(raw["validation"])
|
| 897 |
+
|
| 898 |
print(
|
| 899 |
f"[stage {index}] rows train: {train_rows_before} -> {train_rows_after_filter} -> {train_rows_selected}; "
|
| 900 |
f"validation: {valid_rows_before} -> {valid_rows_after_filter} -> {valid_rows_selected}"
|
|
|
|
| 906 |
sample_row = raw["train"][0]
|
| 907 |
_ = build_prompt_text(sample_row, tokenizer, data_cfg)
|
| 908 |
_ = build_answer_block(sample_row, data_cfg)
|
| 909 |
+
stage_reports.append(
|
| 910 |
+
{
|
| 911 |
+
"stage_index": index,
|
| 912 |
+
"stage_name": stage_name,
|
| 913 |
+
"stage_slug": stage_slug,
|
| 914 |
+
"mode": "dry_run",
|
| 915 |
+
"train_rows_before_filter": train_rows_before,
|
| 916 |
+
"validation_rows_before_filter": valid_rows_before,
|
| 917 |
+
"train_rows_after_filter": train_rows_after_filter,
|
| 918 |
+
"validation_rows_after_filter": valid_rows_after_filter,
|
| 919 |
+
"train_rows_selected": train_rows_selected,
|
| 920 |
+
"validation_rows_selected": valid_rows_selected,
|
| 921 |
+
}
|
| 922 |
+
)
|
| 923 |
print(f"[stage {index}] Dry-run checks passed.")
|
| 924 |
continue
|
| 925 |
|
|
|
|
| 950 |
trainer.log_metrics("train", train_result.metrics)
|
| 951 |
trainer.save_metrics("train", train_result.metrics)
|
| 952 |
trainer.save_state()
|
| 953 |
+
|
| 954 |
eval_metrics = None
|
| 955 |
if eval_dataset is not None:
|
| 956 |
eval_metrics = trainer.evaluate()
|
| 957 |
trainer.log_metrics("eval", eval_metrics)
|
| 958 |
trainer.save_metrics("eval", eval_metrics)
|
| 959 |
+
|
| 960 |
trainer.save_model(str(stage_output_dir))
|
| 961 |
tokenizer.save_pretrained(str(stage_output_dir))
|
| 962 |
|
| 963 |
+
stage_reports.append(
|
| 964 |
+
{
|
| 965 |
+
"stage_index": index,
|
| 966 |
+
"stage_name": stage_name,
|
| 967 |
+
"output_dir": str(stage_output_dir),
|
| 968 |
+
"train_rows_before_filter": train_rows_before,
|
| 969 |
+
"validation_rows_before_filter": valid_rows_before,
|
| 970 |
+
"train_rows_after_filter": train_rows_after_filter,
|
| 971 |
+
"validation_rows_after_filter": valid_rows_after_filter,
|
| 972 |
+
"train_rows_selected": train_rows_selected,
|
| 973 |
+
"validation_rows_selected": valid_rows_selected,
|
| 974 |
+
"train_rows": len(train_dataset),
|
| 975 |
+
"eval_rows": len(eval_dataset) if eval_dataset is not None else 0,
|
| 976 |
+
"train_metrics": train_result.metrics,
|
| 977 |
+
"eval_metrics": eval_metrics,
|
| 978 |
+
}
|
| 979 |
+
)
|
| 980 |
print(
|
| 981 |
+
f"[stage {index}] Completed: train_rows={len(train_dataset)} "
|
| 982 |
+
f"eval_rows={len(eval_dataset) if eval_dataset is not None else 0} output={stage_output_dir}"
|
| 983 |
)
|
| 984 |
|
| 985 |
if args.dry_run:
|
|
|
|
| 1003 |
model.save_pretrained(str(final_dir))
|
| 1004 |
tokenizer.save_pretrained(str(final_dir))
|
| 1005 |
|
| 1006 |
+
release_model_memory(model)
|
| 1007 |
+
del model
|
| 1008 |
+
|
| 1009 |
+
post_eval_result = run_post_eval(
|
| 1010 |
+
cfg=cfg,
|
| 1011 |
+
config_path=args.config,
|
| 1012 |
+
output_root=output_root,
|
| 1013 |
+
final_adapter_dir=final_dir,
|
| 1014 |
+
)
|
| 1015 |
+
|
| 1016 |
+
quality_gate = evaluate_quality_gate(
|
| 1017 |
+
stage_reports=stage_reports,
|
| 1018 |
+
post_eval_result=post_eval_result,
|
| 1019 |
+
gate_cfg=cfg.get("quality_gate", {}),
|
| 1020 |
+
)
|
| 1021 |
+
|
| 1022 |
+
push_to_hub_performed = push_to_hub_requested
|
| 1023 |
+
push_block_reason: Optional[str] = None
|
| 1024 |
+
if push_to_hub_requested and not quality_gate.get("passed", True):
|
| 1025 |
+
push_to_hub_performed = False
|
| 1026 |
+
push_block_reason = "quality_gate_failed"
|
| 1027 |
+
print("Quality gate failed; skipping hub push for this run.")
|
| 1028 |
+
|
| 1029 |
+
summary: Dict[str, Any] = {
|
| 1030 |
"config_path": str(args.config),
|
| 1031 |
"repo_id": repo_id,
|
| 1032 |
"seed": seed,
|
| 1033 |
"stages_ran": stage_reports,
|
| 1034 |
"final_adapter_dir": str(final_dir),
|
| 1035 |
+
"quality_gate": quality_gate,
|
| 1036 |
+
"push": {
|
| 1037 |
+
"requested": bool(push_to_hub_requested),
|
| 1038 |
+
"performed": bool(push_to_hub_performed),
|
| 1039 |
+
"block_reason": push_block_reason,
|
| 1040 |
+
},
|
| 1041 |
}
|
| 1042 |
+
|
| 1043 |
+
if post_eval_result is not None:
|
| 1044 |
+
report = post_eval_result.get("report", {})
|
| 1045 |
+
summary["post_eval"] = {
|
| 1046 |
+
"report_path": post_eval_result.get("report_path"),
|
| 1047 |
+
"evaluated_rows": report.get("evaluated_rows"),
|
| 1048 |
+
"k": report.get("k"),
|
| 1049 |
+
"pass_at_1": report.get("pass_at_1"),
|
| 1050 |
+
"pass_at_k": report.get("pass_at_k"),
|
| 1051 |
+
"exact_at_k": report.get("exact_at_k"),
|
| 1052 |
+
"composite_score": report.get("composite_score"),
|
| 1053 |
+
}
|
| 1054 |
+
|
| 1055 |
summary_path = output_root / "training_summary.json"
|
| 1056 |
summary_path.write_text(json.dumps(summary, ensure_ascii=True, indent=2), encoding="utf-8")
|
| 1057 |
|
| 1058 |
+
if push_to_hub_performed and repo_id is not None and token is not None:
|
| 1059 |
api = HfApi(token=token)
|
| 1060 |
api.create_repo(
|
| 1061 |
repo_id=repo_id,
|
|
|
|
| 1065 |
)
|
| 1066 |
commit_message = as_text(cfg.get("hub", {}).get("commit_message")) or "Upload SOTA curriculum adapter."
|
| 1067 |
push_folder(api, repo_id, final_dir, commit_message=commit_message)
|
| 1068 |
+
|
| 1069 |
if bool(cfg.get("hub", {}).get("upload_stage_checkpoints", False)):
|
| 1070 |
for report in stage_reports:
|
| 1071 |
+
stage_dir_raw = report.get("output_dir")
|
| 1072 |
+
if not stage_dir_raw:
|
| 1073 |
+
continue
|
| 1074 |
+
stage_dir = Path(stage_dir_raw)
|
| 1075 |
+
path_in_repo = f"checkpoints/{stage_dir.name}"
|
| 1076 |
push_folder(
|
| 1077 |
api,
|
| 1078 |
repo_id,
|
| 1079 |
stage_dir,
|
| 1080 |
+
commit_message=f"Upload stage checkpoint {report.get('stage_name', stage_dir.name)}",
|
| 1081 |
path_in_repo=path_in_repo,
|
| 1082 |
)
|
| 1083 |
+
|
| 1084 |
api.upload_file(
|
| 1085 |
path_or_fileobj=str(summary_path),
|
| 1086 |
path_in_repo="training_summary.json",
|
|
|
|
| 1088 |
repo_type="model",
|
| 1089 |
commit_message="Upload training summary for SOTA curriculum run.",
|
| 1090 |
)
|
| 1091 |
+
|
| 1092 |
+
if post_eval_result is not None and post_eval_result.get("report_path"):
|
| 1093 |
+
api.upload_file(
|
| 1094 |
+
path_or_fileobj=str(post_eval_result["report_path"]),
|
| 1095 |
+
path_in_repo="post_eval_report.json",
|
| 1096 |
+
repo_id=repo_id,
|
| 1097 |
+
repo_type="model",
|
| 1098 |
+
commit_message="Upload post-training evaluation report.",
|
| 1099 |
+
)
|
| 1100 |
+
|
| 1101 |
print(f"Pushed training artifacts to https://huggingface.co/{repo_id}")
|
| 1102 |
|
| 1103 |
print(f"Training complete. Final adapter: {final_dir}")
|
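Because `training_summary.json` now embeds the gate verdict and the push outcome (`requested`, `performed`, `block_reason`), downstream automation can branch on it without re-parsing logs. A small consumer sketch (the path assumes the default `output_root` above):

```python
# Read the emitted run summary and surface gate verdict and push outcome.
import json
from pathlib import Path

summary = json.loads(
    Path("runs/math-conjecture-sota/training_summary.json").read_text(encoding="utf-8")
)
gate = summary["quality_gate"]
push = summary["push"]
print("quality gate passed:", gate["passed"])
for violation in gate["violations"]:
    print("violation:", violation)
print("push requested:", push["requested"], "performed:", push["performed"])
if push["block_reason"]:
    print("push blocked:", push["block_reason"])
```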