# math_trainer/scripts/train_sota.py
# Author: NorthernTribe-Research
# Enable CPU fallback training path and preserve live training-loss graph telemetry.
# Revision: a68d3ef (verified)
#!/usr/bin/env python3
"""Multi-stage curriculum SFT for advancing the conjecture math model."""
from __future__ import annotations
import argparse
import gc
import json
import os
import subprocess
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import torch
import yaml
from datasets import Dataset, DatasetDict, load_dataset
from huggingface_hub import HfApi
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch.utils.data import WeightedRandomSampler
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
DataCollatorForSeq2Seq,
Trainer,
TrainingArguments,
set_seed,
)
SCRIPT_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_CONFIG_PATH = SCRIPT_ROOT / "configs" / "deepseek_math_sota.yaml"
DEFAULT_EVAL_SCRIPT = Path(__file__).resolve().with_name("eval_sota.py")
def parse_args() -> argparse.Namespace:
    """Build the CLI for the curriculum trainer and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(
        description="Train DeepSeek-Math with a multi-stage SOTA curriculum recipe."
    )
    # Config and hub targeting.
    parser.add_argument("--config", type=Path, default=DEFAULT_CONFIG_PATH, help="Path to multi-stage YAML config.")
    parser.add_argument("--repo-id", type=str, default=None, help="Override hub.repo_id.")
    parser.add_argument("--push-to-hub", action="store_true", help="Force push enabled.")
    parser.add_argument("--no-push-to-hub", action="store_true", help="Force push disabled.")
    # Post-training evaluation and quality-gate toggles.
    parser.add_argument("--run-post-eval", action="store_true", help="Force post-training evaluation enabled.")
    parser.add_argument("--no-post-eval", action="store_true", help="Force post-training evaluation disabled.")
    parser.add_argument("--skip-quality-gate", action="store_true", help="Disable quality gate checks for this run.")
    # Stage-window selection (1-based, inclusive).
    parser.add_argument("--start-stage", type=int, default=1, help="1-based stage index to start from.")
    parser.add_argument("--max-stages", type=int, default=None, help="Optional number of stages to run from --start-stage.")
    parser.add_argument("--credentials-path", type=Path, default=None, help="Override credentials.path.")
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Validate data/filter/tokenization stages without running training or pushing.",
    )
    return parser.parse_args()
def as_text(value: Any) -> str:
    """Coerce *value* to a whitespace-stripped string; ``None`` becomes ``""``."""
    if value is None:
        return ""
    return value.strip() if isinstance(value, str) else str(value).strip()
def as_float(value: Any, default: float) -> float:
    """Best-effort float conversion; fall back to *default* on None or failure."""
    if value is None:
        return default
    try:
        converted = float(value)
    except (TypeError, ValueError):
        return default
    return converted
def as_int(value: Any, default: int) -> int:
    """Best-effort int conversion; fall back to *default* on None or failure.

    Note: float *strings* such as "4.2" are rejected by int() and yield the
    default, while float *values* such as 7.9 truncate toward zero.
    """
    if value is None:
        return default
    try:
        converted = int(value)
    except (TypeError, ValueError):
        return default
    return converted
def as_bool(value: Any, default: bool = False) -> bool:
    """Interpret common truthy/falsy spellings; fall back to *default* otherwise."""
    if value is None:
        return default
    if isinstance(value, bool):
        return value
    normalized = as_text(value).lower()
    if normalized in {"1", "true", "yes", "y", "on"}:
        return True
    return False if normalized in {"0", "false", "no", "n", "off"} else default
def load_config(path: Path) -> Dict[str, Any]:
    """Load the multi-stage YAML config and validate its required sections.

    Raises FileNotFoundError for a missing file and ValueError for a
    non-mapping document, a missing required section, or an empty stage list.
    Optional sections are defaulted to empty dicts so callers can index freely.
    """
    if not path.exists():
        raise FileNotFoundError(f"Config not found: {path}")
    cfg = yaml.safe_load(path.read_text(encoding="utf-8"))
    if not isinstance(cfg, dict):
        raise ValueError(f"Invalid config format: {path}")
    for required in ("model", "data", "stages"):
        if required not in cfg:
            raise ValueError(f"Missing config section: {required}")
    stages = cfg["stages"]
    if not isinstance(stages, list) or not stages:
        raise ValueError("Config must contain at least one stage in stages[].")
    for optional in ("global", "training_defaults", "hub", "credentials", "post_eval", "quality_gate"):
        cfg.setdefault(optional, {})
    return cfg
def apply_overrides(cfg: Dict[str, Any], args: argparse.Namespace) -> None:
    """Fold CLI flags into *cfg* in place, rejecting contradictory pairs."""
    if args.repo_id:
        cfg.setdefault("hub", {})["repo_id"] = args.repo_id
    if args.credentials_path is not None:
        cfg.setdefault("credentials", {})["path"] = str(args.credentials_path)
    if args.push_to_hub and args.no_push_to_hub:
        raise ValueError("Cannot set both --push-to-hub and --no-push-to-hub.")
    # After the conflict check, at most one flag is set, so its truthiness
    # alone determines the override value.
    if args.push_to_hub or args.no_push_to_hub:
        cfg.setdefault("hub", {})["push_to_hub"] = bool(args.push_to_hub)
    if args.run_post_eval and args.no_post_eval:
        raise ValueError("Cannot set both --run-post-eval and --no-post-eval.")
    if args.run_post_eval or args.no_post_eval:
        cfg.setdefault("post_eval", {})["enabled"] = bool(args.run_post_eval)
    if args.skip_quality_gate:
        cfg.setdefault("quality_gate", {})["enabled"] = False
def resolve_auth(cfg: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
    """Resolve the Hugging Face token and username.

    Environment variables (HF_TOKEN / HUGGINGFACE_HUB_TOKEN, HF_USERNAME) win;
    the optional JSON file at ``credentials.path`` only fills in whichever of
    the two is still missing.

    Returns:
        (token, username); either element may be None.
    """
    token = as_text(os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")) or None
    username = as_text(os.environ.get("HF_USERNAME")) or None
    cred_path = as_text(cfg.get("credentials", {}).get("path"))
    if cred_path:
        path = Path(cred_path)
        if path.exists():
            data = json.loads(path.read_text(encoding="utf-8"))
            # Robustness fix: a credentials file holding a JSON array/scalar used
            # to crash with AttributeError on .get(); ignore non-object payloads
            # and fall back to whatever the environment provided.
            if not isinstance(data, dict):
                data = {}
            if token is None:
                # Accept a handful of common key spellings for the token.
                for key in ("token", "key", "api_key", "hf_token"):
                    candidate = as_text(data.get(key))
                    if candidate:
                        token = candidate
                        break
            if username is None:
                for key in ("username", "user", "owner"):
                    candidate = as_text(data.get(key))
                    if candidate:
                        username = candidate
                        break
    return token, username
def resolve_repo_id(cfg: Dict[str, Any], username: Optional[str], output_root: Path) -> Optional[str]:
    """Pick the hub repo id: explicit config wins, else derive from the username."""
    configured = as_text(cfg.get("hub", {}).get("repo_id"))
    if configured:
        return configured
    # Without a username there is nothing sensible to derive.
    return f"{username}/{output_root.name}" if username else None
def stringify_structured(value: Any) -> str:
    """Render *value* as canonical JSON (sorted keys); '' for None/blank input.

    A string that parses as JSON is re-serialized canonically; a string that
    does not parse is returned stripped but otherwise untouched.
    """
    if value is None:
        return ""
    if not isinstance(value, str):
        return json.dumps(value, ensure_ascii=False, sort_keys=True)
    text = value.strip()
    if not text:
        return ""
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        return text
    return json.dumps(parsed, ensure_ascii=False, sort_keys=True)
def build_user_block(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> str:
    """Compose the user-turn text: the prompt plus an optional metadata footer."""
    prompt_field = as_text(data_cfg.get("prompt_field")) or "prompt"
    prompt = as_text(row.get(prompt_field)) or "Solve the math task."
    labeled_fields = (
        ("task_type", "Task type"),
        ("family", "Family"),
        ("difficulty", "Difficulty"),
        ("source_dataset", "Source"),
        ("status_as_of", "Status as of"),
    )
    # Only render metadata entries that actually have a value.
    footer = [
        f"{label}: {as_text(row.get(key))}"
        for key, label in labeled_fields
        if as_text(row.get(key))
    ]
    tags = row.get("topic_tags")
    if isinstance(tags, list) and tags:
        joined = ", ".join(text for text in (as_text(tag) for tag in tags) if text)
        if joined:
            footer.append(f"Tags: {joined}")
    if not footer:
        return prompt
    return f"{prompt}\n\nMetadata:\n" + "\n".join(footer)
def build_answer_block(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> str:
    """Assemble the assistant-turn target text from the structured row fields."""
    headed_fields = (
        (as_text(data_cfg.get("target_field")) or "target", "Structured target"),
        (as_text(data_cfg.get("final_answer_field")) or "final_answer", "Final answer"),
        (as_text(data_cfg.get("proof_field")) or "proof_formal", "Formal proof snippet"),
    )
    sections = []
    for field, heading in headed_fields:
        rendered = stringify_structured(row.get(field))
        if rendered:
            sections.append(f"{heading}:\n{rendered}")
    if not sections:
        # Keep a non-empty target so downstream label masking has something to train on.
        sections.append("No structured target provided.")
    return "\n\n".join(sections).strip()
def build_prompt_text(row: Dict[str, Any], tokenizer: AutoTokenizer, data_cfg: Dict[str, Any]) -> str:
    """Render the full prompt, via the tokenizer's chat template when one exists."""
    system_prompt = as_text(data_cfg.get("system_prompt")) or (
        "You are a rigorous mathematical reasoning assistant focused on unsolved "
        "conjectures. Produce checkable reasoning."
    )
    user_block = build_user_block(row, data_cfg)
    if not getattr(tokenizer, "chat_template", None):
        # Plain-text fallback for tokenizers without a chat template.
        return f"System:\n{system_prompt}\n\nUser:\n{user_block}\n\nAssistant:\n"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_block},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
def compute_loss_weight(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> float:
    """Per-sample loss weight: base weight times family boost, clamped to bounds."""
    weight_field = as_text(data_cfg.get("sample_weight_field")) or "sample_weight"
    weight = as_float(row.get(weight_field), 1.0)
    boosts = data_cfg.get("family_boost", {})
    if isinstance(boosts, dict):
        weight *= as_float(boosts.get(as_text(row.get("family"))), 1.0)
    low = as_float(data_cfg.get("min_loss_weight"), 0.1)
    high = as_float(data_cfg.get("max_loss_weight"), 8.0)
    # Tolerate configs that swapped the bounds.
    if low > high:
        low, high = high, low
    return min(high, max(low, weight))
def stage_split_files(stage_cfg: Dict[str, Any], data_cfg: Dict[str, Any]) -> Dict[str, str]:
    """Resolve train/validation parquet paths, with stage overriding data defaults."""
    train_path = Path(
        as_text(stage_cfg.get("train_file")) or as_text(data_cfg.get("default_train_file"))
    )
    valid_path = Path(
        as_text(stage_cfg.get("validation_file")) or as_text(data_cfg.get("default_validation_file"))
    )
    if not train_path.exists():
        raise FileNotFoundError(f"Missing train split for stage: {train_path}")
    if not valid_path.exists():
        raise FileNotFoundError(f"Missing validation split for stage: {valid_path}")
    return {"train": str(train_path), "validation": str(valid_path)}
def apply_filters(dataset: Dataset, filter_cfg: Dict[str, Any]) -> Dataset:
    """Filter *dataset* rows according to the stage's filter config.

    Empty include/exclude sets are no-ops; ``min_sample_weight`` is only
    enforced when the key is present in the config.
    """
    if not filter_cfg:
        return dataset
    wanted_families = set(filter_cfg.get("include_families", []) or [])
    banned_families = set(filter_cfg.get("exclude_families", []) or [])
    wanted_task_types = set(filter_cfg.get("include_task_types", []) or [])
    wanted_sources = set(filter_cfg.get("source_datasets", []) or [])
    need_conjecture_id = bool(filter_cfg.get("require_conjecture_id", False))
    raw_min_weight = filter_cfg.get("min_sample_weight")
    min_weight = None if raw_min_weight is None else as_float(raw_min_weight, 0.0)

    def _keep(row: Dict[str, Any]) -> bool:
        family = as_text(row.get("family"))
        if wanted_families and family not in wanted_families:
            return False
        if family in banned_families:
            return False
        if wanted_task_types and as_text(row.get("task_type")) not in wanted_task_types:
            return False
        if wanted_sources and as_text(row.get("source_dataset")) not in wanted_sources:
            return False
        if need_conjecture_id:
            conjecture_id = as_text(row.get("conjecture_id"))
            # Treat the literal string "null" as missing too.
            if not conjecture_id or conjecture_id.lower() == "null":
                return False
        if min_weight is not None and as_float(row.get("sample_weight"), 0.0) < min_weight:
            return False
        return True

    return dataset.filter(_keep, desc="Applying stage filters")
def maybe_select(dataset: Dataset, max_samples: Optional[int]) -> Dataset:
    """Cap *dataset* at its first *max_samples* rows when the cap is binding.

    Returns the dataset unchanged when no cap is given or the cap is not
    smaller than the dataset; raises ValueError for a non-positive cap.
    """
    if max_samples is None:
        return dataset
    if max_samples <= 0:
        raise ValueError("max_samples must be positive.")
    return dataset if max_samples >= len(dataset) else dataset.select(range(max_samples))
def tokenize_datasets(raw: DatasetDict, tokenizer: AutoTokenizer, data_cfg: Dict[str, Any]) -> DatasetDict:
    """Tokenize prompt/answer pairs into input_ids/labels with prompt masking.

    Prompt tokens are masked to -100 so loss only covers the answer; rows
    whose labels are entirely masked are dropped afterwards. Each row also
    carries a scalar ``loss_weight`` consumed by the weighted trainer.
    """
    max_len = as_int(data_cfg.get("max_seq_length"), 2048)
    if max_len < 64:
        raise ValueError("data.max_seq_length must be >= 64")
    eos = tokenizer.eos_token or ""
    remove_columns = raw["train"].column_names
    def _tokenize(row: Dict[str, Any]) -> Dict[str, Any]:
        prompt_text = build_prompt_text(row, tokenizer, data_cfg)
        answer_text = build_answer_block(row, data_cfg)
        full_text = f"{prompt_text}{answer_text}{eos}"
        # Tokenize the prompt alone to measure how many leading tokens to mask.
        prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
        full_enc = tokenizer(
            full_text,
            add_special_tokens=False,
            truncation=True,
            max_length=max_len,
        )
        input_ids = full_enc["input_ids"]
        attention_mask = full_enc["attention_mask"]
        if not input_ids:
            # Degenerate empty encoding: synthesize a single-token example so
            # downstream collation never sees an empty sequence.
            fallback = tokenizer.eos_token_id
            if fallback is None:
                fallback = tokenizer.pad_token_id
            if fallback is None:
                fallback = 0
            input_ids = [fallback]
            attention_mask = [1]
            labels = [fallback]
        else:
            # Truncation may leave fewer tokens than the prompt alone had.
            prompt_len = min(len(prompt_ids), len(input_ids))
            labels = [-100] * prompt_len + input_ids[prompt_len:]
            if prompt_len >= len(input_ids):
                # Whole answer truncated away: unmask the last token so the row
                # still contributes at least one label (and survives the filter).
                labels[-1] = input_ids[-1]
        loss_weight = compute_loss_weight(row, data_cfg)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "loss_weight": float(loss_weight),
        }
    tokenized = raw.map(
        _tokenize,
        remove_columns=remove_columns,
        desc="Tokenizing prompt/answer pairs",
    )
    # Drop rows that would produce zero supervised tokens.
    tokenized = tokenized.filter(
        lambda row: any(token != -100 for token in row["labels"]),
        desc="Dropping prompt-only rows",
    )
    return tokenized
def build_tokenizer(model_cfg: Dict[str, Any]) -> AutoTokenizer:
    """Load the base model's tokenizer and guarantee it has a pad token."""
    base_model = as_text(model_cfg.get("base_model"))
    if not base_model:
        raise ValueError("model.base_model is required.")
    tokenizer = AutoTokenizer.from_pretrained(
        base_model,
        trust_remote_code=bool(model_cfg.get("trust_remote_code", False)),
        use_fast=True,
    )
    if tokenizer.pad_token is None:
        # Reuse EOS/UNK as padding when available; otherwise register a new token.
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
        if tokenizer.pad_token is None:
            tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
    return tokenizer
def build_model_and_tokenizer(model_cfg: Dict[str, Any], training_defaults: Dict[str, Any]) -> Tuple[Any, AutoTokenizer]:
    """Load the base causal LM (4-bit on CUDA when requested) and wrap it in LoRA.

    On machines without CUDA the 4-bit request is dropped and the model runs
    in full-precision fp32 on CPU. Returns the PEFT-wrapped model and tokenizer.
    """
    base_model = as_text(model_cfg.get("base_model"))
    if not base_model:
        raise ValueError("model.base_model is required.")
    use_cuda = torch.cuda.is_available()
    requested_bf16 = bool(model_cfg.get("use_bf16", True))
    if use_cuda:
        dtype = torch.bfloat16 if requested_bf16 else torch.float16
    else:
        # CPU fallback: half precision on CPU is slow/unsupported for training.
        dtype = torch.float32
    tokenizer = build_tokenizer(model_cfg)
    model_kwargs: Dict[str, Any] = {
        "trust_remote_code": bool(model_cfg.get("trust_remote_code", False)),
        "torch_dtype": dtype,
    }
    attn_impl = as_text(model_cfg.get("attn_implementation"))
    if attn_impl:
        model_kwargs["attn_implementation"] = attn_impl
    requested_load_in_4bit = bool(model_cfg.get("load_in_4bit", True))
    # bitsandbytes 4-bit quantization requires CUDA.
    load_in_4bit = requested_load_in_4bit and use_cuda
    if requested_load_in_4bit and not load_in_4bit:
        print("CUDA unavailable. Disabling 4-bit loading and using full-precision CPU fallback.")
    if load_in_4bit:
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=as_text(model_cfg.get("bnb_4bit_quant_type")) or "nf4",
            bnb_4bit_use_double_quant=bool(model_cfg.get("bnb_4bit_use_double_quant", True)),
            bnb_4bit_compute_dtype=dtype,
        )
        # device_map is only set in the quantized path; CPU fallback loads normally.
        model_kwargs["device_map"] = "auto"
    model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
    if tokenizer.pad_token_id is not None:
        model.config.pad_token_id = tokenizer.pad_token_id
    # KV cache is incompatible with gradient checkpointing during training.
    model.config.use_cache = False
    if load_in_4bit:
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=bool(training_defaults.get("gradient_checkpointing", True)),
        )
    lora_cfg = model_cfg.get("lora", {})
    peft_cfg = LoraConfig(
        r=as_int(lora_cfg.get("r"), 64),
        lora_alpha=as_int(lora_cfg.get("alpha"), 128),
        lora_dropout=as_float(lora_cfg.get("dropout"), 0.05),
        bias=as_text(lora_cfg.get("bias")) or "none",
        task_type="CAUSAL_LM",
        target_modules=lora_cfg.get("target_modules"),
    )
    model = get_peft_model(model, peft_cfg)
    model.print_trainable_parameters()
    return model, tokenizer
class WeightedLossCollator:
    """Seq2seq collator that carries per-sample loss weights through padding.

    Pops the scalar ``loss_weight`` off each feature before delegating to
    ``DataCollatorForSeq2Seq`` (which would reject the extra key), then
    re-attaches the weights to the batch as a float32 tensor.
    """

    def __init__(self, tokenizer: AutoTokenizer, model: Any) -> None:
        self.base = DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            model=model,
            label_pad_token_id=-100,
            pad_to_multiple_of=8,
        )

    def __call__(self, features: list[Dict[str, Any]]) -> Dict[str, Any]:
        extracted = []
        for feature in features:
            extracted.append(float(feature.pop("loss_weight", 1.0)))
        batch = self.base(features)
        batch["loss_weight"] = torch.tensor(extracted, dtype=torch.float32)
        return batch
class WeightedLossTrainer(Trainer):
    """Trainer with weight-proportional sampling and per-sequence weighted loss."""

    def _get_train_sampler(self):
        """Sample training rows with probability proportional to loss_weight.

        Falls back to the default sampler when the dataset has no usable
        loss_weight column. Uses replacement, so heavy rows can repeat
        within an epoch.
        """
        if self.train_dataset is None:
            return None
        if "loss_weight" not in self.train_dataset.column_names:
            return super()._get_train_sampler()
        weights = self.train_dataset["loss_weight"]
        if not weights:
            return super()._get_train_sampler()
        weight_tensor = torch.tensor(weights, dtype=torch.double)
        return WeightedRandomSampler(
            weights=weight_tensor,
            num_samples=len(weight_tensor),
            replacement=True,
        )

    def compute_loss(
        self,
        model: Any,
        inputs: Dict[str, Any],
        return_outputs: bool = False,
        num_items_in_batch: Optional[torch.Tensor] = None,
    ):
        """Weighted causal-LM loss: mean token CE per sequence, then a
        loss_weight-weighted average across the batch.

        Without labels (or weights) this defers to the stock Trainer loss.
        """
        loss_weight = inputs.pop("loss_weight", None)
        labels = inputs.get("labels")
        if labels is None:
            return super().compute_loss(
                model=model,
                inputs=inputs,
                return_outputs=return_outputs,
                num_items_in_batch=num_items_in_batch,
            )
        # Run the model without labels so it does not compute its own loss.
        model_inputs = {k: v for k, v in inputs.items() if k != "labels"}
        outputs = model(**model_inputs)
        logits = outputs.logits
        # Standard causal shift: token t predicts token t+1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        token_losses = torch.nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
            reduction="none",
        ).view(shift_labels.size())
        token_mask = shift_labels.ne(-100).float()
        # clamp avoids division by zero for fully-masked sequences.
        seq_den = token_mask.sum(dim=1).clamp(min=1.0)
        seq_loss = (token_losses * token_mask).sum(dim=1) / seq_den
        if loss_weight is not None:
            # Floor at 0.05 so no sample is weighted to (near) zero.
            normalized = loss_weight.to(seq_loss.device).float().clamp(min=0.05)
            loss = (seq_loss * normalized).sum() / normalized.sum()
        else:
            loss = seq_loss.mean()
        if return_outputs:
            return loss, outputs
        return loss
def build_training_args(
    output_dir: Path,
    training_cfg: Dict[str, Any],
    use_bf16: bool,
    has_eval_split: bool,
) -> TrainingArguments:
    """Translate the merged stage/default config into HF TrainingArguments."""
    output_dir.mkdir(parents=True, exist_ok=True)
    cuda = torch.cuda.is_available()
    # Mixed precision only on GPU; the CPU fallback trains in fp32.
    run_bf16 = bool(cuda and use_bf16)
    run_fp16 = bool(cuda and not run_bf16)
    kwargs: Dict[str, Any] = dict(
        output_dir=str(output_dir),
        num_train_epochs=as_float(training_cfg.get("num_train_epochs"), 1.0),
        per_device_train_batch_size=as_int(training_cfg.get("per_device_train_batch_size"), 1),
        per_device_eval_batch_size=as_int(training_cfg.get("per_device_eval_batch_size"), 1),
        gradient_accumulation_steps=as_int(training_cfg.get("gradient_accumulation_steps"), 1),
        learning_rate=as_float(training_cfg.get("learning_rate"), 2e-5),
        weight_decay=as_float(training_cfg.get("weight_decay"), 0.0),
        warmup_ratio=as_float(training_cfg.get("warmup_ratio"), 0.0),
        lr_scheduler_type=as_text(training_cfg.get("lr_scheduler_type")) or "cosine",
        max_grad_norm=as_float(training_cfg.get("max_grad_norm"), 1.0),
        gradient_checkpointing=bool(training_cfg.get("gradient_checkpointing", True)),
        logging_steps=as_int(training_cfg.get("logging_steps"), 10),
        save_steps=as_int(training_cfg.get("save_steps"), 500),
        save_total_limit=as_int(training_cfg.get("save_total_limit"), 3),
        dataloader_num_workers=as_int(training_cfg.get("dataloader_num_workers"), 0),
        seed=as_int(training_cfg.get("seed"), 17),
        bf16=run_bf16,
        fp16=run_fp16,
        # The weighted collator carries a non-model 'loss_weight' column.
        remove_unused_columns=False,
        report_to="none",
        evaluation_strategy="steps" if has_eval_split else "no",
        eval_steps=as_int(training_cfg.get("eval_steps"), 500) if has_eval_split else None,
    )
    return TrainingArguments(**kwargs)
def push_folder(
    api: HfApi,
    repo_id: str,
    folder_path: Path,
    commit_message: str,
    path_in_repo: Optional[str] = None,
) -> None:
    """Upload *folder_path* to the model repo, optionally under a repo subpath."""
    upload_kwargs: Dict[str, Any] = dict(
        repo_id=repo_id,
        repo_type="model",
        folder_path=str(folder_path),
        commit_message=commit_message,
    )
    if path_in_repo:
        upload_kwargs["path_in_repo"] = path_in_repo
    api.upload_folder(**upload_kwargs)
def extract_final_eval_loss(stage_reports: List[Dict[str, Any]]) -> Optional[float]:
    """Return the eval_loss of the latest stage that reported a usable one."""
    for report in reversed(stage_reports):
        metrics = report.get("eval_metrics")
        if isinstance(metrics, dict) and metrics.get("eval_loss") is not None:
            try:
                return float(metrics["eval_loss"])
            except (TypeError, ValueError):
                # Unparseable value: keep scanning earlier stages.
                pass
    return None
def release_model_memory(model: Any) -> None:
    """Best-effort teardown: move *model* to CPU and reclaim GPU/host memory."""
    try:
        model.to("cpu")
    except Exception:
        # Moving can fail for quantized/offloaded models; cache release still helps.
        pass
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
def run_post_eval(
    cfg: Dict[str, Any],
    config_path: Path,
    output_root: Path,
    final_adapter_dir: Path,
) -> Optional[Dict[str, Any]]:
    """Run the external eval_sota.py as a subprocess and load its JSON report.

    Returns None when post-eval is disabled in the config; otherwise returns a
    dict with the report payload, its path, and the command that produced it.
    Raises on a missing script/eval file, a non-zero exit code, or a missing
    report file.
    """
    post_cfg = cfg.get("post_eval", {})
    if not as_bool(post_cfg.get("enabled"), False):
        return None
    eval_script = DEFAULT_EVAL_SCRIPT
    if not eval_script.exists():
        raise FileNotFoundError(f"Post-eval enabled but eval script is missing: {eval_script}")
    data_cfg = cfg.get("data", {})
    # Fallback chain: explicit post_eval file -> data default validation -> release test split.
    eval_file = Path(
        as_text(post_cfg.get("eval_file"))
        or as_text(data_cfg.get("default_validation_file"))
        or "data/releases/v1/test.parquet"
    )
    if not eval_file.exists():
        raise FileNotFoundError(f"Post-eval file not found: {eval_file}")
    output_json = Path(as_text(post_cfg.get("output_json")) or str(output_root / "post_eval_report.json"))
    base_model = as_text(cfg.get("model", {}).get("base_model"))
    if not base_model:
        raise ValueError("model.base_model is required for post-eval.")
    # Build the child-process command; all numeric knobs have config defaults.
    cmd = [
        sys.executable,
        str(eval_script),
        "--config",
        str(config_path),
        "--base-model",
        base_model,
        "--adapter-path",
        str(final_adapter_dir),
        "--eval-file",
        str(eval_file),
        "--max-samples",
        str(as_int(post_cfg.get("max_samples"), 300)),
        "--k",
        str(as_int(post_cfg.get("k"), 4)),
        "--max-new-tokens",
        str(as_int(post_cfg.get("max_new_tokens"), 256)),
        "--temperature",
        str(as_float(post_cfg.get("temperature"), 0.7)),
        "--top-p",
        str(as_float(post_cfg.get("top_p"), 0.95)),
        "--seed",
        # Post-eval seed falls back to the run's global seed, then 17.
        str(as_int(post_cfg.get("seed"), as_int(cfg.get("global", {}).get("seed"), 17))),
        "--output-json",
        str(output_json),
    ]
    print(f"Running post-training eval: {' '.join(cmd)}")
    # check=False so we can raise our own, more descriptive error.
    completed = subprocess.run(cmd, check=False)
    if completed.returncode != 0:
        raise RuntimeError(f"Post-training evaluation failed with exit code {completed.returncode}.")
    if not output_json.exists():
        raise FileNotFoundError(f"Post-eval report was not created: {output_json}")
    report = json.loads(output_json.read_text(encoding="utf-8"))
    return {
        "enabled": True,
        "report_path": str(output_json),
        "report": report,
        "command": cmd,
    }
def evaluate_quality_gate(
    stage_reports: List[Dict[str, Any]],
    post_eval_result: Optional[Dict[str, Any]],
    gate_cfg: Dict[str, Any],
) -> Dict[str, Any]:
    """Check configured quality thresholds against stage and post-eval metrics.

    Returns a dict with ``enabled``, ``passed``, the list of human-readable
    ``violations``, and the individual ``checks`` (name/actual/threshold).
    A disabled gate always passes.
    """
    enabled = as_bool(gate_cfg.get("enabled"), False)
    result: Dict[str, Any] = {
        "enabled": enabled,
        "passed": True,
        "violations": [],
        "checks": [],
    }
    if not enabled:
        return result
    violations: List[str] = []
    checks: List[Dict[str, Any]] = []
    # Gate 1: final-stage eval loss ceiling (only when configured).
    final_eval_loss = extract_final_eval_loss(stage_reports)
    max_final_eval_loss = gate_cfg.get("max_final_eval_loss")
    if max_final_eval_loss is not None:
        threshold = as_float(max_final_eval_loss, 0.0)
        checks.append(
            {
                "name": "max_final_eval_loss",
                "actual": final_eval_loss,
                "threshold": threshold,
            }
        )
        if final_eval_loss is None:
            violations.append("Final stage eval_loss is missing for max_final_eval_loss gate.")
        elif final_eval_loss > threshold:
            violations.append(
                f"Final eval_loss {final_eval_loss:.4f} exceeds threshold {threshold:.4f}."
            )
    # Unwrap the post-eval report dict, if one was produced.
    report: Optional[Dict[str, Any]] = None
    if isinstance(post_eval_result, dict):
        loaded = post_eval_result.get("report")
        if isinstance(loaded, dict):
            report = loaded
    require_post_eval = as_bool(gate_cfg.get("require_post_eval"), False)
    if report is None:
        if require_post_eval:
            violations.append("Quality gate requires post-eval metrics, but post-eval report is missing.")
    else:
        # Gate 2: minimum number of evaluated rows (defaults to 0 = always passes).
        evaluated_rows = as_int(report.get("evaluated_rows"), 0)
        min_rows = as_int(gate_cfg.get("min_evaluated_rows"), 0)
        checks.append(
            {
                "name": "min_evaluated_rows",
                "actual": evaluated_rows,
                "threshold": min_rows,
            }
        )
        if evaluated_rows < min_rows:
            violations.append(
                f"Post-eval rows {evaluated_rows} is below minimum required {min_rows}."
            )
        # Gate 3: overall pass@1 floor (only when configured).
        min_pass_at_1_raw = gate_cfg.get("min_pass_at_1")
        if min_pass_at_1_raw is not None:
            min_pass_at_1 = as_float(min_pass_at_1_raw, 0.0)
            pass_at_1 = as_float(report.get("pass_at_1"), 0.0)
            checks.append(
                {
                    "name": "min_pass_at_1",
                    "actual": pass_at_1,
                    "threshold": min_pass_at_1,
                }
            )
            if pass_at_1 < min_pass_at_1:
                violations.append(
                    f"pass@1 {pass_at_1:.4f} is below threshold {min_pass_at_1:.4f}."
                )
        # Gate 4: overall pass@k floor (only when configured).
        min_pass_at_k_raw = gate_cfg.get("min_pass_at_k")
        if min_pass_at_k_raw is not None:
            min_pass_at_k = as_float(min_pass_at_k_raw, 0.0)
            pass_at_k = as_float(report.get("pass_at_k"), 0.0)
            checks.append(
                {
                    "name": "min_pass_at_k",
                    "actual": pass_at_k,
                    "threshold": min_pass_at_k,
                }
            )
            if pass_at_k < min_pass_at_k:
                violations.append(
                    f"pass@k {pass_at_k:.4f} is below threshold {min_pass_at_k:.4f}."
                )
        # Gate 5: per-family pass@k floors; a family missing from the report
        # counts as a violation.
        family_requirements = gate_cfg.get("required_family_pass_at_k", {})
        family_metrics = report.get("family_metrics", {})
        if isinstance(family_requirements, dict):
            for family, threshold_raw in family_requirements.items():
                threshold = as_float(threshold_raw, 0.0)
                actual = None
                if isinstance(family_metrics, dict):
                    family_row = family_metrics.get(family)
                    if isinstance(family_row, dict):
                        try:
                            actual = float(family_row.get("pass_at_k"))
                        except (TypeError, ValueError):
                            actual = None
                checks.append(
                    {
                        "name": f"family_pass_at_k:{family}",
                        "actual": actual,
                        "threshold": threshold,
                    }
                )
                if actual is None:
                    violations.append(f"Missing pass@k metric for required family '{family}'.")
                elif actual < threshold:
                    violations.append(
                        f"Family '{family}' pass@k {actual:.4f} is below threshold {threshold:.4f}."
                    )
    result["violations"] = violations
    result["checks"] = checks
    result["passed"] = len(violations) == 0
    return result
def main() -> None:
    """Orchestrate the full curriculum run.

    Phases: config load + overrides, auth/repo resolution, model/tokenizer
    setup (or tokenizer-only in dry-run), the per-stage train loop, final
    adapter export, optional post-eval, quality gate, and (gated) hub push.
    """
    args = parse_args()
    cfg = load_config(args.config)
    apply_overrides(cfg, args)
    seed = as_int(cfg.get("global", {}).get("seed"), 17)
    set_seed(seed)
    output_root = Path(as_text(cfg.get("global", {}).get("output_root")) or "runs/math-conjecture-sota")
    output_root.mkdir(parents=True, exist_ok=True)
    # Resolve credentials and the destination repo before any heavy work.
    token, username = resolve_auth(cfg)
    repo_id = resolve_repo_id(cfg, username=username, output_root=output_root)
    push_to_hub_requested = bool(cfg.get("hub", {}).get("push_to_hub", False))
    if args.dry_run and push_to_hub_requested:
        print("Dry-run enabled. Disabling push_to_hub for this run.")
    push_to_hub_requested = push_to_hub_requested and not args.dry_run
    if push_to_hub_requested:
        # Fail fast so we don't train for hours and then fail to push.
        if token is None:
            raise ValueError("Hub push requested but token is missing.")
        if repo_id is None:
            raise ValueError("Hub push requested but repo_id is missing.")
    if args.dry_run:
        # Dry-run only validates data/prompt plumbing; no model weights loaded.
        tokenizer = build_tokenizer(cfg["model"])
        model = None
    else:
        model, tokenizer = build_model_and_tokenizer(cfg["model"], cfg.get("training_defaults", {}))
    if torch.cuda.is_available():
        print("Compute mode: GPU")
    else:
        print("Compute mode: CPU fallback (no CUDA detected)")
    data_cfg = cfg["data"]
    stage_reports: List[Dict[str, Any]] = []
    # Stage window: 1-based inclusive [start_stage, end_stage].
    start_stage = max(1, args.start_stage)
    stages = cfg["stages"]
    end_stage = len(stages)
    if args.max_stages is not None:
        if args.max_stages <= 0:
            raise ValueError("--max-stages must be positive.")
        end_stage = min(end_stage, start_stage + args.max_stages - 1)
    for index in range(start_stage, end_stage + 1):
        stage = stages[index - 1]
        stage_name = as_text(stage.get("name")) or f"stage_{index:02d}"
        stage_slug = f"{index:02d}_{stage_name.replace(' ', '_')}"
        stage_output_dir = output_root / stage_slug
        print(f"[stage {index}] Starting: {stage_name}")
        # Load, filter, and optionally cap this stage's splits.
        split_files = stage_split_files(stage, data_cfg)
        raw = load_dataset("parquet", data_files=split_files)
        train_rows_before = len(raw["train"])
        valid_rows_before = len(raw["validation"])
        filters = stage.get("filters", {})
        raw["train"] = apply_filters(raw["train"], filters)
        raw["validation"] = apply_filters(raw["validation"], filters)
        train_rows_after_filter = len(raw["train"])
        valid_rows_after_filter = len(raw["validation"])
        raw["train"] = maybe_select(raw["train"], stage.get("max_train_samples"))
        raw["validation"] = maybe_select(raw["validation"], stage.get("max_eval_samples"))
        train_rows_selected = len(raw["train"])
        valid_rows_selected = len(raw["validation"])
        print(
            f"[stage {index}] rows train: {train_rows_before} -> {train_rows_after_filter} -> {train_rows_selected}; "
            f"validation: {valid_rows_before} -> {valid_rows_after_filter} -> {valid_rows_selected}"
        )
        if len(raw["train"]) == 0:
            raise ValueError(f"Stage {stage_slug} has zero train rows after filtering.")
        if args.dry_run:
            # Exercise prompt/answer construction on one row, record counts, skip training.
            sample_row = raw["train"][0]
            _ = build_prompt_text(sample_row, tokenizer, data_cfg)
            _ = build_answer_block(sample_row, data_cfg)
            stage_reports.append(
                {
                    "stage_index": index,
                    "stage_name": stage_name,
                    "stage_slug": stage_slug,
                    "mode": "dry_run",
                    "train_rows_before_filter": train_rows_before,
                    "validation_rows_before_filter": valid_rows_before,
                    "train_rows_after_filter": train_rows_after_filter,
                    "validation_rows_after_filter": valid_rows_after_filter,
                    "train_rows_selected": train_rows_selected,
                    "validation_rows_selected": valid_rows_selected,
                }
            )
            print(f"[stage {index}] Dry-run checks passed.")
            continue
        tokenized = tokenize_datasets(raw, tokenizer, data_cfg)
        train_dataset = tokenized["train"]
        eval_dataset = tokenized["validation"] if len(tokenized["validation"]) > 0 else None
        # Stage-level training config overrides the shared defaults.
        merged_training = dict(cfg.get("training_defaults", {}))
        merged_training.update(stage.get("training", {}))
        merged_training["seed"] = seed
        training_args = build_training_args(
            output_dir=stage_output_dir,
            training_cfg=merged_training,
            use_bf16=bool(cfg["model"].get("use_bf16", True)),
            has_eval_split=eval_dataset is not None,
        )
        collator = WeightedLossCollator(tokenizer=tokenizer, model=model)
        trainer = WeightedLossTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            data_collator=collator,
        )
        # NOTE: the same PEFT model instance carries over between stages, so
        # later stages continue from the earlier stages' adapter weights.
        train_result = trainer.train()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        eval_metrics = None
        if eval_dataset is not None:
            eval_metrics = trainer.evaluate()
            trainer.log_metrics("eval", eval_metrics)
            trainer.save_metrics("eval", eval_metrics)
        trainer.save_model(str(stage_output_dir))
        tokenizer.save_pretrained(str(stage_output_dir))
        stage_reports.append(
            {
                "stage_index": index,
                "stage_name": stage_name,
                "output_dir": str(stage_output_dir),
                "train_rows_before_filter": train_rows_before,
                "validation_rows_before_filter": valid_rows_before,
                "train_rows_after_filter": train_rows_after_filter,
                "validation_rows_after_filter": valid_rows_after_filter,
                "train_rows_selected": train_rows_selected,
                "validation_rows_selected": valid_rows_selected,
                "train_rows": len(train_dataset),
                "eval_rows": len(eval_dataset) if eval_dataset is not None else 0,
                "train_metrics": train_result.metrics,
                "eval_metrics": eval_metrics,
            }
        )
        print(
            f"[stage {index}] Completed: train_rows={len(train_dataset)} "
            f"eval_rows={len(eval_dataset) if eval_dataset is not None else 0} output={stage_output_dir}"
        )
    if args.dry_run:
        # Dry-run ends here: write the summary and skip export/eval/push.
        summary = {
            "mode": "dry_run",
            "config_path": str(args.config),
            "seed": seed,
            "start_stage": start_stage,
            "end_stage": end_stage,
            "stages_ran": stage_reports,
        }
        summary_path = output_root / "dry_run_summary.json"
        summary_path.write_text(json.dumps(summary, ensure_ascii=True, indent=2), encoding="utf-8")
        print("Dry-run complete. No training or model push was performed.")
        print(f"Dry-run summary: {summary_path}")
        return
    # Export the final adapter, then free the model before post-eval spawns
    # its own process (which loads the model again).
    final_dir = output_root / "final_adapter"
    final_dir.mkdir(parents=True, exist_ok=True)
    assert model is not None
    model.save_pretrained(str(final_dir))
    tokenizer.save_pretrained(str(final_dir))
    release_model_memory(model)
    del model
    post_eval_result = run_post_eval(
        cfg=cfg,
        config_path=args.config,
        output_root=output_root,
        final_adapter_dir=final_dir,
    )
    quality_gate = evaluate_quality_gate(
        stage_reports=stage_reports,
        post_eval_result=post_eval_result,
        gate_cfg=cfg.get("quality_gate", {}),
    )
    # A failed quality gate blocks the hub push but not the local artifacts.
    push_to_hub_performed = push_to_hub_requested
    push_block_reason: Optional[str] = None
    if push_to_hub_requested and not quality_gate.get("passed", True):
        push_to_hub_performed = False
        push_block_reason = "quality_gate_failed"
        print("Quality gate failed; skipping hub push for this run.")
    summary: Dict[str, Any] = {
        "config_path": str(args.config),
        "repo_id": repo_id,
        "seed": seed,
        "stages_ran": stage_reports,
        "final_adapter_dir": str(final_dir),
        "quality_gate": quality_gate,
        "push": {
            "requested": bool(push_to_hub_requested),
            "performed": bool(push_to_hub_performed),
            "block_reason": push_block_reason,
        },
    }
    if post_eval_result is not None:
        report = post_eval_result.get("report", {})
        summary["post_eval"] = {
            "report_path": post_eval_result.get("report_path"),
            "evaluated_rows": report.get("evaluated_rows"),
            "k": report.get("k"),
            "pass_at_1": report.get("pass_at_1"),
            "pass_at_k": report.get("pass_at_k"),
            "exact_at_k": report.get("exact_at_k"),
            "composite_score": report.get("composite_score"),
        }
    summary_path = output_root / "training_summary.json"
    summary_path.write_text(json.dumps(summary, ensure_ascii=True, indent=2), encoding="utf-8")
    if push_to_hub_performed and repo_id is not None and token is not None:
        api = HfApi(token=token)
        api.create_repo(
            repo_id=repo_id,
            repo_type="model",
            private=bool(cfg.get("hub", {}).get("private", False)),
            exist_ok=True,
        )
        commit_message = as_text(cfg.get("hub", {}).get("commit_message")) or "Upload SOTA curriculum adapter."
        push_folder(api, repo_id, final_dir, commit_message=commit_message)
        # Optionally mirror each stage checkpoint under checkpoints/ in the repo.
        if bool(cfg.get("hub", {}).get("upload_stage_checkpoints", False)):
            for report in stage_reports:
                stage_dir_raw = report.get("output_dir")
                if not stage_dir_raw:
                    continue
                stage_dir = Path(stage_dir_raw)
                path_in_repo = f"checkpoints/{stage_dir.name}"
                push_folder(
                    api,
                    repo_id,
                    stage_dir,
                    commit_message=f"Upload stage checkpoint {report.get('stage_name', stage_dir.name)}",
                    path_in_repo=path_in_repo,
                )
        api.upload_file(
            path_or_fileobj=str(summary_path),
            path_in_repo="training_summary.json",
            repo_id=repo_id,
            repo_type="model",
            commit_message="Upload training summary for SOTA curriculum run.",
        )
        if post_eval_result is not None and post_eval_result.get("report_path"):
            api.upload_file(
                path_or_fileobj=str(post_eval_result["report_path"]),
                path_in_repo="post_eval_report.json",
                repo_id=repo_id,
                repo_type="model",
                commit_message="Upload post-training evaluation report.",
            )
        print(f"Pushed training artifacts to https://huggingface.co/{repo_id}")
    print(f"Training complete. Final adapter: {final_dir}")
    print(f"Training summary: {summary_path}")
# Script entry point.
if __name__ == "__main__":
    main()