#!/usr/bin/env python3
"""Multi-stage curriculum SFT for advancing the conjecture math model.

The script trains a LoRA adapter on top of ``model.base_model`` through the ordered list of
``stages`` in the YAML config (each stage has its own data files, row filters and training
overrides), saves a final adapter, optionally runs a post-training evaluation script,
applies an optional quality gate, and optionally pushes the artifacts to the Hugging Face Hub.
"""

from __future__ import annotations

import argparse
import gc
import inspect
import json
import os
import re
import subprocess
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
import yaml
from datasets import Dataset, DatasetDict, load_dataset
from huggingface_hub import HfApi
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch.utils.data import WeightedRandomSampler
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
    set_seed,
)

SCRIPT_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_CONFIG_PATH = SCRIPT_ROOT / "configs" / "deepseek_math_sota.yaml"
DEFAULT_EVAL_SCRIPT = Path(__file__).resolve().with_name("eval_sota.py")

# Optional top-level config sections; a missing or explicitly-null section becomes {}.
OPTIONAL_CONFIG_SECTIONS = (
    "global",
    "training_defaults",
    "hub",
    "credentials",
    "post_eval",
    "quality_gate",
)

# Number of leading train rows per stage that --dry-run pushes through tokenization.
DRY_RUN_TOKENIZE_ROWS = 8


def parse_args() -> argparse.Namespace:
    """Parse command-line options; each flag overrides or supplements the YAML config."""
    parser = argparse.ArgumentParser(
        description="Train DeepSeek-Math with a multi-stage SOTA curriculum recipe."
    )
    parser.add_argument(
        "--config",
        type=Path,
        default=DEFAULT_CONFIG_PATH,
        help="Path to multi-stage YAML config.",
    )
    parser.add_argument("--repo-id", type=str, default=None, help="Override hub.repo_id.")
    parser.add_argument("--push-to-hub", action="store_true", help="Force push enabled.")
    parser.add_argument("--no-push-to-hub", action="store_true", help="Force push disabled.")
    parser.add_argument(
        "--run-post-eval",
        action="store_true",
        help="Force post-training evaluation enabled.",
    )
    parser.add_argument(
        "--no-post-eval",
        action="store_true",
        help="Force post-training evaluation disabled.",
    )
    parser.add_argument(
        "--skip-quality-gate",
        action="store_true",
        help="Disable quality gate checks for this run.",
    )
    parser.add_argument(
        "--start-stage",
        type=int,
        default=1,
        help="1-based stage index to start from.",
    )
    parser.add_argument(
        "--max-stages",
        type=int,
        default=None,
        help="Optional number of stages to run from --start-stage.",
    )
    parser.add_argument(
        "--credentials-path",
        type=Path,
        default=None,
        help="Override credentials.path.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Validate data/filter/tokenization stages without running training or pushing.",
    )
    return parser.parse_args()


def as_text(value: Any) -> str:
    """Return ``value`` as a stripped string; ``None`` becomes ``""``."""
    if value is None:
        return ""
    if isinstance(value, str):
        return value.strip()
    return str(value).strip()


def as_float(value: Any, default: float) -> float:
    """Convert ``value`` to float, returning ``default`` for None or unparsable input."""
    if value is None:
        return default
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def as_int(value: Any, default: int) -> int:
    """Convert ``value`` to int, returning ``default`` for None or unparsable input."""
    if value is None:
        return default
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


def as_bool(value: Any, default: bool = False) -> bool:
    """Interpret common boolean spellings ("1"/"0", "true"/"false", "yes"/"no", "on"/"off").

    Real bools are returned unchanged; ``None`` and unrecognised values return ``default``.
    """
    if value is None:
        return default
    if isinstance(value, bool):
        return value
    text = as_text(value).lower()
    if text in {"1", "true", "yes", "y", "on"}:
        return True
    if text in {"0", "false", "no", "n", "off"}:
        return False
    return default


def load_config(path: Path) -> Dict[str, Any]:
    """Load and minimally validate the multi-stage YAML config.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if the file is not a mapping, lacks ``model``/``data``/``stages``,
            ``model``/``data`` are not mappings, or ``stages`` is not a non-empty list.
    """
    if not path.exists():
        raise FileNotFoundError(f"Config not found: {path}")
    cfg = yaml.safe_load(path.read_text(encoding="utf-8"))
    if not isinstance(cfg, dict):
        raise ValueError(f"Invalid config format: {path}")
    for key in ("model", "data", "stages"):
        if key not in cfg:
            raise ValueError(f"Missing config section: {key}")
    for key in ("model", "data"):
        if not isinstance(cfg[key], dict):
            raise ValueError(f"Config section '{key}' must be a mapping.")
    if not isinstance(cfg["stages"], list) or not cfg["stages"]:
        raise ValueError("Config must contain at least one stage in stages[].")
    # setdefault() alone would keep an explicit YAML null (e.g. "hub:" with no body), which
    # later breaks every cfg.get(section, {}).get(...) lookup.
    for key in OPTIONAL_CONFIG_SECTIONS:
        if cfg.get(key) is None:
            cfg[key] = {}
    return cfg


def apply_overrides(cfg: Dict[str, Any], args: argparse.Namespace) -> None:
    """Apply CLI overrides onto ``cfg`` in place.

    Raises:
        ValueError: if both flags of a mutually exclusive pair are set.
    """
    if args.repo_id:
        cfg.setdefault("hub", {})["repo_id"] = args.repo_id
    if args.credentials_path is not None:
        cfg.setdefault("credentials", {})["path"] = str(args.credentials_path)

    if args.push_to_hub and args.no_push_to_hub:
        raise ValueError("Cannot set both --push-to-hub and --no-push-to-hub.")
    if args.push_to_hub:
        cfg.setdefault("hub", {})["push_to_hub"] = True
    if args.no_push_to_hub:
        cfg.setdefault("hub", {})["push_to_hub"] = False

    if args.run_post_eval and args.no_post_eval:
        raise ValueError("Cannot set both --run-post-eval and --no-post-eval.")
    if args.run_post_eval:
        cfg.setdefault("post_eval", {})["enabled"] = True
    if args.no_post_eval:
        cfg.setdefault("post_eval", {})["enabled"] = False

    if args.skip_quality_gate:
        cfg.setdefault("quality_gate", {})["enabled"] = False


def _first_text(data: Dict[str, Any], keys: Tuple[str, ...]) -> Optional[str]:
    """Return the first non-empty text value among ``keys`` in ``data``, else None."""
    for key in keys:
        candidate = as_text(data.get(key))
        if candidate:
            return candidate
    return None


def resolve_auth(cfg: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
    """Resolve the Hub token and username.

    Environment variables (``HF_TOKEN``/``HUGGINGFACE_HUB_TOKEN`` and ``HF_USERNAME``) win;
    values still missing are read from the JSON object at ``credentials.path`` if that file
    exists. Returns ``(token, username)``; either may be None.
    """
    token = as_text(os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")) or None
    username = as_text(os.environ.get("HF_USERNAME")) or None

    cred_path = as_text((cfg.get("credentials") or {}).get("path"))
    if not cred_path:
        return token, username
    path = Path(cred_path)
    if not path.exists():
        return token, username

    data = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(data, dict):
        # A JSON list/scalar has no named fields to read from.
        print(f"Ignoring credentials file {path}: expected a JSON object.")
        return token, username
    if token is None:
        token = _first_text(data, ("token", "key", "api_key", "hf_token"))
    if username is None:
        username = _first_text(data, ("username", "user", "owner"))
    return token, username


def resolve_repo_id(cfg: Dict[str, Any], username: Optional[str], output_root: Path) -> Optional[str]:
    """Return ``hub.repo_id`` if set, else ``<username>/<output_root name>``, else None."""
    repo_id = as_text(cfg.get("hub", {}).get("repo_id"))
    if repo_id:
        return repo_id
    if not username:
        return None
    return f"{username}/{output_root.name}"


def stringify_structured(value: Any) -> str:
    """Render ``value`` as canonical JSON text (sorted keys).

    Strings that parse as JSON are re-serialized canonically; other strings are returned
    stripped. ``None`` and blank strings become ``""``.
    """
    if value is None:
        return ""
    if isinstance(value, str):
        text = value.strip()
        if not text:
            return ""
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            return text
        return json.dumps(parsed, ensure_ascii=False, sort_keys=True)
    return json.dumps(value, ensure_ascii=False, sort_keys=True)


def build_user_block(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> str:
    """Build the user message: the row's prompt plus a "Metadata:" list of known fields."""
    prompt_field = as_text(data_cfg.get("prompt_field")) or "prompt"
    prompt = as_text(row.get(prompt_field))
    if not prompt:
        prompt = "Solve the math task."

    meta_fields = [
        ("task_type", "Task type"),
        ("family", "Family"),
        ("difficulty", "Difficulty"),
        ("source_dataset", "Source"),
        ("status_as_of", "Status as of"),
    ]
    meta_lines = []
    for key, label in meta_fields:
        value = as_text(row.get(key))
        if value:
            meta_lines.append(f"{label}: {value}")

    tags = row.get("topic_tags")
    if isinstance(tags, list) and tags:
        tag_text = ", ".join(as_text(tag) for tag in tags if as_text(tag))
        if tag_text:
            meta_lines.append(f"Tags: {tag_text}")

    if not meta_lines:
        return prompt
    return f"{prompt}\n\nMetadata:\n" + "\n".join(meta_lines)


def build_answer_block(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> str:
    """Build the supervised answer text from the target, final-answer and proof fields."""
    target_field = as_text(data_cfg.get("target_field")) or "target"
    final_answer_field = as_text(data_cfg.get("final_answer_field")) or "final_answer"
    proof_field = as_text(data_cfg.get("proof_field")) or "proof_formal"

    sections = []
    target_text = stringify_structured(row.get(target_field))
    if target_text:
        sections.append(f"Structured target:\n{target_text}")

    final_answer = stringify_structured(row.get(final_answer_field))
    if final_answer:
        sections.append(f"Final answer:\n{final_answer}")

    proof_text = stringify_structured(row.get(proof_field))
    if proof_text:
        sections.append(f"Formal proof snippet:\n{proof_text}")

    if not sections:
        sections.append("No structured target provided.")
    return "\n\n".join(sections).strip()


def build_prompt_text(row: Dict[str, Any], tokenizer: AutoTokenizer, data_cfg: Dict[str, Any]) -> str:
    """Render the system + user turns, ending where the assistant answer begins.

    Uses the tokenizer's chat template (with the generation prompt appended) when it has
    one; otherwise falls back to a plain "System:/User:/Assistant:" layout.
    """
    system_prompt = as_text(data_cfg.get("system_prompt"))
    if not system_prompt:
        system_prompt = (
            "You are a rigorous mathematical reasoning assistant focused on unsolved "
            "conjectures. Produce checkable reasoning."
        )
    user_block = build_user_block(row, data_cfg)
    if getattr(tokenizer, "chat_template", None):
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_block},
        ]
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return f"System:\n{system_prompt}\n\nUser:\n{user_block}\n\nAssistant:\n"


def compute_loss_weight(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> float:
    """Return the row's loss weight: sample weight x family boost, clipped to [min, max].

    Defaults: sample weight 1.0, boost 1.0, bounds [0.1, 8.0]; swapped bounds are reordered.
    """
    sample_weight_field = as_text(data_cfg.get("sample_weight_field")) or "sample_weight"
    base = as_float(row.get(sample_weight_field), 1.0)
    family = as_text(row.get("family"))
    family_boost = data_cfg.get("family_boost", {})
    if isinstance(family_boost, dict):
        base *= as_float(family_boost.get(family), 1.0)

    min_w = as_float(data_cfg.get("min_loss_weight"), 0.1)
    max_w = as_float(data_cfg.get("max_loss_weight"), 8.0)
    if min_w > max_w:
        min_w, max_w = max_w, min_w
    return max(min_w, min(max_w, base))


def stage_split_files(stage_cfg: Dict[str, Any], data_cfg: Dict[str, Any]) -> Dict[str, str]:
    """Resolve a stage's train/validation parquet paths, falling back to the data defaults.

    Raises:
        ValueError: if no path is configured for a split.
        FileNotFoundError: if a configured path does not exist.
    """
    train_file = as_text(stage_cfg.get("train_file")) or as_text(data_cfg.get("default_train_file"))
    valid_file = as_text(stage_cfg.get("validation_file")) or as_text(
        data_cfg.get("default_validation_file")
    )
    # Path("") is the current directory, which always exists; reject it here rather than
    # letting load_dataset fail on a directory.
    if not train_file:
        raise ValueError(
            "No train file configured: set stages[].train_file or data.default_train_file."
        )
    if not valid_file:
        raise ValueError(
            "No validation file configured: set stages[].validation_file or "
            "data.default_validation_file."
        )
    train_path = Path(train_file)
    valid_path = Path(valid_file)
    if not train_path.exists():
        raise FileNotFoundError(f"Missing train split for stage: {train_path}")
    if not valid_path.exists():
        raise FileNotFoundError(f"Missing validation split for stage: {valid_path}")
    return {"train": str(train_path), "validation": str(valid_path)}


def _text_set(values: Any) -> set:
    """Return the non-empty stripped-string forms of ``values`` (None/empty -> empty set)."""
    return {as_text(value) for value in (values or []) if as_text(value)}


def apply_filters(
    dataset: Dataset,
    filter_cfg: Dict[str, Any],
    sample_weight_field: str = "sample_weight",
) -> Dataset:
    """Keep only rows matching a stage's ``filters`` config.

    Supported keys: ``include_families``, ``exclude_families``, ``include_task_types``,
    ``source_datasets``, ``require_conjecture_id`` (rejects blank or "null" ids) and
    ``min_sample_weight`` (compared against ``sample_weight_field``; missing weights count
    as 0.0). An empty or missing config returns ``dataset`` unchanged.
    """
    if not filter_cfg:
        return dataset

    # Compare as text on both sides so numeric YAML entries match as_text(row values).
    include_families = _text_set(filter_cfg.get("include_families"))
    exclude_families = _text_set(filter_cfg.get("exclude_families"))
    include_task_types = _text_set(filter_cfg.get("include_task_types"))
    source_datasets = _text_set(filter_cfg.get("source_datasets"))
    require_conjecture_id = as_bool(filter_cfg.get("require_conjecture_id"), False)
    min_sample_weight_raw = filter_cfg.get("min_sample_weight")
    min_sample_weight = (
        as_float(min_sample_weight_raw, 0.0) if min_sample_weight_raw is not None else None
    )
    weight_field = sample_weight_field or "sample_weight"

    def _keep(row: Dict[str, Any]) -> bool:
        family = as_text(row.get("family"))
        if include_families and family not in include_families:
            return False
        if exclude_families and family in exclude_families:
            return False
        if include_task_types and as_text(row.get("task_type")) not in include_task_types:
            return False
        if source_datasets and as_text(row.get("source_dataset")) not in source_datasets:
            return False
        if require_conjecture_id:
            conjecture_id = as_text(row.get("conjecture_id"))
            if not conjecture_id or conjecture_id.lower() == "null":
                return False
        if min_sample_weight is not None:
            if as_float(row.get(weight_field), 0.0) < min_sample_weight:
                return False
        return True

    return dataset.filter(_keep, desc="Applying stage filters")


def maybe_select(dataset: Dataset, max_samples: Optional[int]) -> Dataset:
    """Return the first ``max_samples`` rows (all rows when None or larger than the dataset).

    Raises:
        ValueError: if ``max_samples`` is not a positive integer.
    """
    if max_samples is None:
        return dataset
    # YAML may deliver quoted numbers; normalise before comparing/slicing.
    max_samples = int(max_samples)
    if max_samples <= 0:
        raise ValueError("max_samples must be positive.")
    if max_samples >= len(dataset):
        return dataset
    return dataset.select(range(max_samples))


def tokenize_datasets(raw: DatasetDict, tokenizer: AutoTokenizer, data_cfg: Dict[str, Any]) -> DatasetDict:
    """Tokenize every split into ``input_ids``/``attention_mask``/``labels``/``loss_weight``.

    Prompt tokens are masked with -100 in ``labels`` so only the answer (plus EOS) is
    supervised. Sequences are truncated to ``data.max_seq_length`` (default 2048). Rows
    left with no supervised token (e.g. the prompt alone fills the window) are dropped.

    Raises:
        ValueError: if ``data.max_seq_length`` is below 64.
    """
    max_len = as_int(data_cfg.get("max_seq_length"), 2048)
    if max_len < 64:
        raise ValueError("data.max_seq_length must be >= 64")
    eos = tokenizer.eos_token or ""

    def _tokenize(row: Dict[str, Any]) -> Dict[str, Any]:
        prompt_text = build_prompt_text(row, tokenizer, data_cfg)
        answer_text = build_answer_block(row, data_cfg)
        full_text = f"{prompt_text}{answer_text}{eos}"
        # NOTE(review): prompt and full text are tokenized separately; if a token merges
        # across the prompt/answer boundary the mask may be off by one token.
        prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
        full_enc = tokenizer(
            full_text,
            add_special_tokens=False,
            truncation=True,
            max_length=max_len,
        )
        input_ids = full_enc["input_ids"]
        attention_mask = full_enc["attention_mask"]
        if input_ids:
            prompt_len = min(len(prompt_ids), len(input_ids))
            # When truncation cut into (or before) the answer, every label stays -100 and
            # the filter below removes the row instead of training on a prompt token.
            labels = [-100] * prompt_len + input_ids[prompt_len:]
        else:
            fallback = tokenizer.eos_token_id
            if fallback is None:
                fallback = tokenizer.pad_token_id
            if fallback is None:
                fallback = 0
            input_ids = [fallback]
            attention_mask = [1]
            labels = [-100]
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "loss_weight": float(compute_loss_weight(row, data_cfg)),
        }

    # Map each split with its own column list: train and validation parquet files are not
    # guaranteed to share a schema.
    tokenized = DatasetDict(
        {
            split: dataset.map(
                _tokenize,
                remove_columns=dataset.column_names,
                desc="Tokenizing prompt/answer pairs",
            )
            for split, dataset in raw.items()
        }
    )
    return tokenized.filter(
        lambda row: any(token != -100 for token in row["labels"]),
        desc="Dropping prompt-only rows",
    )


def build_tokenizer(model_cfg: Dict[str, Any]) -> AutoTokenizer:
    """Load the base model's fast tokenizer, ensuring a pad token exists.

    The pad token falls back to EOS, then UNK, then a newly added ``<|pad|>`` token.

    Raises:
        ValueError: if ``model.base_model`` is not set.
    """
    base_model = as_text(model_cfg.get("base_model"))
    if not base_model:
        raise ValueError("model.base_model is required.")
    tokenizer = AutoTokenizer.from_pretrained(
        base_model,
        trust_remote_code=as_bool(model_cfg.get("trust_remote_code"), False),
        use_fast=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
    return tokenizer


def _resize_embeddings_for_tokenizer(model: Any, tokenizer: AutoTokenizer) -> None:
    """Grow the input embedding table when the tokenizer has more tokens than the model.

    build_tokenizer may add ``<|pad|>``; without resizing, padded batches would index past
    the embedding table.
    """
    embeddings = model.get_input_embeddings()
    current = getattr(embeddings, "num_embeddings", None)
    if current is not None and len(tokenizer) > current:
        model.resize_token_embeddings(len(tokenizer))


def build_model_and_tokenizer(
    model_cfg: Dict[str, Any], training_defaults: Dict[str, Any]
) -> Tuple[Any, AutoTokenizer]:
    """Load the base causal LM (4-bit on CUDA by default) and wrap it with a LoRA adapter.

    On CUDA the compute dtype is bf16 (``model.use_bf16``, default True) or fp16; on CPU
    it is fp32 and 4-bit loading is disabled. Returns ``(peft_model, tokenizer)``.

    Raises:
        ValueError: if ``model.base_model`` is not set.
    """
    base_model = as_text(model_cfg.get("base_model"))
    if not base_model:
        raise ValueError("model.base_model is required.")

    use_cuda = torch.cuda.is_available()
    requested_bf16 = as_bool(model_cfg.get("use_bf16"), True)
    if use_cuda:
        dtype = torch.bfloat16 if requested_bf16 else torch.float16
    else:
        dtype = torch.float32

    tokenizer = build_tokenizer(model_cfg)
    model_kwargs: Dict[str, Any] = {
        "trust_remote_code": as_bool(model_cfg.get("trust_remote_code"), False),
        "torch_dtype": dtype,
    }
    attn_impl = as_text(model_cfg.get("attn_implementation"))
    if attn_impl:
        model_kwargs["attn_implementation"] = attn_impl

    requested_load_in_4bit = as_bool(model_cfg.get("load_in_4bit"), True)
    load_in_4bit = requested_load_in_4bit and use_cuda
    if requested_load_in_4bit and not load_in_4bit:
        print("CUDA unavailable. Disabling 4-bit loading and using full-precision CPU fallback.")
    if load_in_4bit:
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=as_text(model_cfg.get("bnb_4bit_quant_type")) or "nf4",
            bnb_4bit_use_double_quant=as_bool(model_cfg.get("bnb_4bit_use_double_quant"), True),
            bnb_4bit_compute_dtype=dtype,
        )
        model_kwargs["device_map"] = "auto"

    model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
    _resize_embeddings_for_tokenizer(model, tokenizer)
    if tokenizer.pad_token_id is not None:
        model.config.pad_token_id = tokenizer.pad_token_id
    model.config.use_cache = False

    gradient_checkpointing = as_bool((training_defaults or {}).get("gradient_checkpointing"), True)
    if load_in_4bit:
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=gradient_checkpointing,
        )
    elif hasattr(model, "enable_input_require_grads"):
        # With frozen base weights, checkpointed blocks need embedding outputs that require
        # grad, otherwise backward fails once a stage enables gradient checkpointing.
        # prepare_model_for_kbit_training covers this for the 4-bit path.
        model.enable_input_require_grads()

    lora_cfg = model_cfg.get("lora") or {}
    peft_cfg = LoraConfig(
        r=as_int(lora_cfg.get("r"), 64),
        lora_alpha=as_int(lora_cfg.get("alpha"), 128),
        lora_dropout=as_float(lora_cfg.get("dropout"), 0.05),
        bias=as_text(lora_cfg.get("bias")) or "none",
        task_type="CAUSAL_LM",
        target_modules=lora_cfg.get("target_modules"),
    )
    model = get_peft_model(model, peft_cfg)
    model.print_trainable_parameters()
    return model, tokenizer


class WeightedLossCollator:
    """Pad batches with ``DataCollatorForSeq2Seq`` and carry per-row ``loss_weight`` along.

    Labels are padded with -100 and lengths are padded to a multiple of 8.
    """

    def __init__(self, tokenizer: AutoTokenizer, model: Any) -> None:
        self.base = DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            model=model,
            label_pad_token_id=-100,
            pad_to_multiple_of=8,
        )

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Keep the scalar weight out of the padding collator; re-attach it as a float tensor.
        weights = [float(feature.pop("loss_weight", 1.0)) for feature in features]
        batch = self.base(features)
        batch["loss_weight"] = torch.tensor(weights, dtype=torch.float32)
        return batch


class WeightedLossTrainer(Trainer):
    """Trainer that samples rows in proportion to ``loss_weight`` and weights their loss.

    NOTE(review): a row's weight therefore counts twice — in how often it is drawn and in
    its share of each batch loss. Behavior kept as-is; confirm this is intended.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # compute_loss returns a per-batch mean and ignores num_items_in_batch. Recent
        # transformers releases skip the 1/gradient_accumulation_steps scaling when the
        # model "accepts loss kwargs" (typically true for HF causal LMs whose forward takes
        # **kwargs), which would inflate gradients; force the Trainer to apply it.
        self.model_accepts_loss_kwargs = False

    def _default_train_sampler(self, train_dataset: Optional[Any]) -> Any:
        """Delegate to Trainer's sampler, matching whichever signature was used to call us."""
        if train_dataset is None:
            return super()._get_train_sampler()
        return super()._get_train_sampler(train_dataset)

    def _get_train_sampler(self, train_dataset: Optional[Any] = None) -> Any:
        """Return a with-replacement ``WeightedRandomSampler`` over ``loss_weight``.

        Newer transformers pass the dataset explicitly; older ones call with no argument.
        Falls back to the default sampler when there is no ``loss_weight`` column.
        """
        dataset = train_dataset if train_dataset is not None else self.train_dataset
        if dataset is None:
            return None
        column_names = getattr(dataset, "column_names", None) or []
        if "loss_weight" not in column_names:
            return self._default_train_sampler(train_dataset)
        # list() normalises lazy column objects returned by newer `datasets` releases.
        weights = list(dataset["loss_weight"])
        if not weights:
            return self._default_train_sampler(train_dataset)
        weight_tensor = torch.tensor(weights, dtype=torch.double)
        return WeightedRandomSampler(
            weights=weight_tensor,
            num_samples=len(weight_tensor),
            replacement=True,
        )

    def compute_loss(
        self,
        model: Any,
        inputs: Dict[str, Any],
        return_outputs: bool = False,
        num_items_in_batch: Optional[torch.Tensor] = None,
    ):
        """Weighted causal-LM loss.

        Each sequence's loss is its mean token cross-entropy over supervised positions
        (labels != -100 after the one-token shift). Sequence losses are averaged with
        weights ``loss_weight`` clamped to >= 0.05, or plainly when no weights are given.
        ``num_items_in_batch`` is accepted for signature compatibility but unused.
        """
        loss_weight = inputs.pop("loss_weight", None)
        labels = inputs.get("labels")
        if labels is None:
            return super().compute_loss(
                model=model,
                inputs=inputs,
                return_outputs=return_outputs,
                num_items_in_batch=num_items_in_batch,
            )

        # Labels are withheld so the model does not compute its own unweighted loss.
        model_inputs = {k: v for k, v in inputs.items() if k != "labels"}
        outputs = model(**model_inputs)
        logits = outputs.logits
        shift_logits = logits[..., :-1, :].contiguous()
        # With device_map="auto" over several GPUs, logits may live on another device.
        shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
        token_losses = torch.nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
            reduction="none",
        ).view(shift_labels.size())
        token_mask = shift_labels.ne(-100).float()
        # clamp avoids 0/0 for sequences without supervised tokens.
        seq_den = token_mask.sum(dim=1).clamp(min=1.0)
        seq_loss = (token_losses * token_mask).sum(dim=1) / seq_den

        if loss_weight is not None:
            normalized = loss_weight.to(seq_loss.device).float().clamp(min=0.05)
            loss = (seq_loss * normalized).sum() / normalized.sum()
        else:
            loss = seq_loss.mean()

        if return_outputs:
            return loss, outputs
        return loss


def _eval_strategy_argument_name() -> str:
    """Return the TrainingArguments keyword that sets the evaluation schedule.

    transformers renamed ``evaluation_strategy`` to ``eval_strategy`` and later removed the
    old spelling, so use whichever the installed version accepts.
    """
    params = inspect.signature(TrainingArguments.__init__).parameters
    return "eval_strategy" if "eval_strategy" in params else "evaluation_strategy"


def _trainer_tokenizer_kwargs(tokenizer: AutoTokenizer) -> Dict[str, Any]:
    """Pass the tokenizer as ``processing_class`` when supported, else as ``tokenizer``."""
    params = inspect.signature(Trainer.__init__).parameters
    if "processing_class" in params:
        return {"processing_class": tokenizer}
    return {"tokenizer": tokenizer}


def build_training_args(
    output_dir: Path,
    training_cfg: Dict[str, Any],
    use_bf16: bool,
    has_eval_split: bool,
) -> TrainingArguments:
    """Build ``TrainingArguments`` for one stage (creates ``output_dir``).

    Mixed precision is only enabled on CUDA: bf16 when requested, otherwise fp16.
    Step-based evaluation is enabled only when a validation split exists.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    use_cuda = torch.cuda.is_available()
    bf16_runtime = bool(use_cuda and use_bf16)
    fp16_runtime = bool(use_cuda and not bf16_runtime)

    kwargs: Dict[str, Any] = {
        "output_dir": str(output_dir),
        "num_train_epochs": as_float(training_cfg.get("num_train_epochs"), 1.0),
        "per_device_train_batch_size": as_int(training_cfg.get("per_device_train_batch_size"), 1),
        "per_device_eval_batch_size": as_int(training_cfg.get("per_device_eval_batch_size"), 1),
        "gradient_accumulation_steps": as_int(training_cfg.get("gradient_accumulation_steps"), 1),
        "learning_rate": as_float(training_cfg.get("learning_rate"), 2e-5),
        "weight_decay": as_float(training_cfg.get("weight_decay"), 0.0),
        "warmup_ratio": as_float(training_cfg.get("warmup_ratio"), 0.0),
        "lr_scheduler_type": as_text(training_cfg.get("lr_scheduler_type")) or "cosine",
        "max_grad_norm": as_float(training_cfg.get("max_grad_norm"), 1.0),
        "gradient_checkpointing": as_bool(training_cfg.get("gradient_checkpointing"), True),
        "logging_steps": as_int(training_cfg.get("logging_steps"), 10),
        "save_steps": as_int(training_cfg.get("save_steps"), 500),
        "save_total_limit": as_int(training_cfg.get("save_total_limit"), 3),
        "dataloader_num_workers": as_int(training_cfg.get("dataloader_num_workers"), 0),
        "seed": as_int(training_cfg.get("seed"), 17),
        "bf16": bf16_runtime,
        "fp16": fp16_runtime,
        # loss_weight must survive to the collator/compute_loss.
        "remove_unused_columns": False,
        "report_to": "none",
        "eval_steps": as_int(training_cfg.get("eval_steps"), 500) if has_eval_split else None,
    }
    kwargs[_eval_strategy_argument_name()] = "steps" if has_eval_split else "no"
    return TrainingArguments(**kwargs)


def push_folder(
    api: HfApi,
    repo_id: str,
    folder_path: Path,
    commit_message: str,
    path_in_repo: Optional[str] = None,
) -> None:
    """Upload ``folder_path`` to the model repo, optionally under ``path_in_repo``."""
    kwargs: Dict[str, Any] = {
        "repo_id": repo_id,
        "repo_type": "model",
        "folder_path": str(folder_path),
        "commit_message": commit_message,
    }
    if path_in_repo:
        kwargs["path_in_repo"] = path_in_repo
    api.upload_folder(**kwargs)


def extract_final_eval_loss(stage_reports: List[Dict[str, Any]]) -> Optional[float]:
    """Return the most recent parsable ``eval_loss`` across stage reports, or None.

    NOTE(review): this falls back to an earlier stage when the last stage had no eval.
    """
    for report in reversed(stage_reports):
        eval_metrics = report.get("eval_metrics")
        if not isinstance(eval_metrics, dict):
            continue
        value = eval_metrics.get("eval_loss")
        if value is None:
            continue
        try:
            return float(value)
        except (TypeError, ValueError):
            continue
    return None


def _collect_garbage() -> None:
    """Run the Python GC, then hand cached CUDA blocks back to the driver (if CUDA exists)."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def release_model_memory(model: Any) -> None:
    """Best-effort: move ``model`` to CPU and reclaim accelerator memory.

    Failures to move the model (e.g. for some quantized models) are reported, not raised.
    Memory held by ``model`` is only freed once the caller drops its own references too.
    """
    try:
        model.to("cpu")
    except Exception as exc:  # best effort by design; never block the rest of the run
        print(f"Could not move model to CPU before release ({exc}); continuing.")
    # Collect first: empty_cache can only return blocks whose tensors are already freed.
    _collect_garbage()


def run_post_eval(
    cfg: Dict[str, Any],
    config_path: Path,
    output_root: Path,
    final_adapter_dir: Path,
) -> Optional[Dict[str, Any]]:
    """Run the post-training evaluation script in a subprocess and load its JSON report.

    Returns None when ``post_eval.enabled`` is false; otherwise a dict with ``enabled``,
    ``report_path``, ``report`` (parsed JSON) and ``command`` (the argv that was run).

    Raises:
        FileNotFoundError: if the eval script, the eval file, or the report is missing.
        ValueError: if ``model.base_model`` is not set.
        RuntimeError: if the eval subprocess exits non-zero.
    """
    post_cfg = cfg.get("post_eval", {})
    if not as_bool(post_cfg.get("enabled"), False):
        return None

    eval_script = DEFAULT_EVAL_SCRIPT
    if not eval_script.exists():
        raise FileNotFoundError(f"Post-eval enabled but eval script is missing: {eval_script}")

    data_cfg = cfg.get("data", {})
    eval_file = Path(
        as_text(post_cfg.get("eval_file"))
        or as_text(data_cfg.get("default_validation_file"))
        or "data/releases/v1/test.parquet"
    )
    if not eval_file.exists():
        raise FileNotFoundError(f"Post-eval file not found: {eval_file}")

    output_json = Path(as_text(post_cfg.get("output_json")) or str(output_root / "post_eval_report.json"))
    base_model = as_text(cfg.get("model", {}).get("base_model"))
    if not base_model:
        raise ValueError("model.base_model is required for post-eval.")

    cmd = [
        sys.executable,
        str(eval_script),
        "--config",
        str(config_path),
        "--base-model",
        base_model,
        "--adapter-path",
        str(final_adapter_dir),
        "--eval-file",
        str(eval_file),
        "--max-samples",
        str(as_int(post_cfg.get("max_samples"), 300)),
        "--k",
        str(as_int(post_cfg.get("k"), 4)),
        "--max-new-tokens",
        str(as_int(post_cfg.get("max_new_tokens"), 256)),
        "--temperature",
        str(as_float(post_cfg.get("temperature"), 0.7)),
        "--top-p",
        str(as_float(post_cfg.get("top_p"), 0.95)),
        "--seed",
        str(as_int(post_cfg.get("seed"), as_int(cfg.get("global", {}).get("seed"), 17))),
        "--output-json",
        str(output_json),
    ]

    # Remove a report left over from an earlier run so the existence check below proves
    # that this invocation produced it.
    output_json.parent.mkdir(parents=True, exist_ok=True)
    if output_json.exists():
        output_json.unlink()

    print(f"Running post-training eval: {' '.join(cmd)}")
    completed = subprocess.run(cmd, check=False)
    if completed.returncode != 0:
        raise RuntimeError(f"Post-training evaluation failed with exit code {completed.returncode}.")
    if not output_json.exists():
        raise FileNotFoundError(f"Post-eval report was not created: {output_json}")

    report = json.loads(output_json.read_text(encoding="utf-8"))
    return {
        "enabled": True,
        "report_path": str(output_json),
        "report": report,
        "command": cmd,
    }


def evaluate_quality_gate(
    stage_reports: List[Dict[str, Any]],
    post_eval_result: Optional[Dict[str, Any]],
    gate_cfg: Dict[str, Any],
) -> Dict[str, Any]:
    """Check training/eval metrics against ``quality_gate`` thresholds.

    Returns ``{"enabled", "passed", "violations", "checks"}``; a disabled gate always passes.
    Checks: ``max_final_eval_loss``, ``require_post_eval``, ``min_evaluated_rows``,
    ``min_pass_at_1``, ``min_pass_at_k`` and ``required_family_pass_at_k`` (per family).
    """
    enabled = as_bool(gate_cfg.get("enabled"), False)
    result: Dict[str, Any] = {
        "enabled": enabled,
        "passed": True,
        "violations": [],
        "checks": [],
    }
    if not enabled:
        return result

    violations: List[str] = []
    checks: List[Dict[str, Any]] = []

    final_eval_loss = extract_final_eval_loss(stage_reports)
    max_final_eval_loss = gate_cfg.get("max_final_eval_loss")
    if max_final_eval_loss is not None:
        threshold = as_float(max_final_eval_loss, 0.0)
        checks.append(
            {
                "name": "max_final_eval_loss",
                "actual": final_eval_loss,
                "threshold": threshold,
            }
        )
        if final_eval_loss is None:
            violations.append("Final stage eval_loss is missing for max_final_eval_loss gate.")
        elif final_eval_loss > threshold:
            violations.append(
                f"Final eval_loss {final_eval_loss:.4f} exceeds threshold {threshold:.4f}."
            )

    report: Optional[Dict[str, Any]] = None
    if isinstance(post_eval_result, dict):
        loaded = post_eval_result.get("report")
        if isinstance(loaded, dict):
            report = loaded

    require_post_eval = as_bool(gate_cfg.get("require_post_eval"), False)
    if report is None:
        if require_post_eval:
            violations.append("Quality gate requires post-eval metrics, but post-eval report is missing.")
    else:
        evaluated_rows = as_int(report.get("evaluated_rows"), 0)
        min_rows = as_int(gate_cfg.get("min_evaluated_rows"), 0)
        checks.append(
            {
                "name": "min_evaluated_rows",
                "actual": evaluated_rows,
                "threshold": min_rows,
            }
        )
        if evaluated_rows < min_rows:
            violations.append(
                f"Post-eval rows {evaluated_rows} is below minimum required {min_rows}."
            )

        min_pass_at_1_raw = gate_cfg.get("min_pass_at_1")
        if min_pass_at_1_raw is not None:
            min_pass_at_1 = as_float(min_pass_at_1_raw, 0.0)
            pass_at_1 = as_float(report.get("pass_at_1"), 0.0)
            checks.append(
                {
                    "name": "min_pass_at_1",
                    "actual": pass_at_1,
                    "threshold": min_pass_at_1,
                }
            )
            if pass_at_1 < min_pass_at_1:
                violations.append(
                    f"pass@1 {pass_at_1:.4f} is below threshold {min_pass_at_1:.4f}."
                )

        min_pass_at_k_raw = gate_cfg.get("min_pass_at_k")
        if min_pass_at_k_raw is not None:
            min_pass_at_k = as_float(min_pass_at_k_raw, 0.0)
            pass_at_k = as_float(report.get("pass_at_k"), 0.0)
            checks.append(
                {
                    "name": "min_pass_at_k",
                    "actual": pass_at_k,
                    "threshold": min_pass_at_k,
                }
            )
            if pass_at_k < min_pass_at_k:
                violations.append(
                    f"pass@k {pass_at_k:.4f} is below threshold {min_pass_at_k:.4f}."
                )

        family_requirements = gate_cfg.get("required_family_pass_at_k", {})
        family_metrics = report.get("family_metrics", {})
        if isinstance(family_requirements, dict):
            for family, threshold_raw in family_requirements.items():
                threshold = as_float(threshold_raw, 0.0)
                actual = None
                if isinstance(family_metrics, dict):
                    family_row = family_metrics.get(family)
                    if isinstance(family_row, dict):
                        try:
                            actual = float(family_row.get("pass_at_k"))
                        except (TypeError, ValueError):
                            actual = None
                checks.append(
                    {
                        "name": f"family_pass_at_k:{family}",
                        "actual": actual,
                        "threshold": threshold,
                    }
                )
                if actual is None:
                    violations.append(f"Missing pass@k metric for required family '{family}'.")
                elif actual < threshold:
                    violations.append(
                        f"Family '{family}' pass@k {actual:.4f} is below threshold {threshold:.4f}."
                    )

    result["violations"] = violations
    result["checks"] = checks
    result["passed"] = len(violations) == 0
    return result


def _resolve_stage_range(
    start_stage_arg: int, max_stages: Optional[int], stage_count: int
) -> Tuple[int, int]:
    """Return the inclusive 1-based ``(start, end)`` stage window to run.

    ``--start-stage`` values below 1 are clamped to 1.

    Raises:
        ValueError: if the start is past the last stage or ``--max-stages`` is not positive.
    """
    start_stage = max(1, start_stage_arg)
    if start_stage > stage_count:
        # Otherwise no stage runs and the untouched base adapter would be saved/pushed.
        raise ValueError(
            f"--start-stage {start_stage} exceeds the number of configured stages ({stage_count})."
        )
    end_stage = stage_count
    if max_stages is not None:
        if max_stages <= 0:
            raise ValueError("--max-stages must be positive.")
        end_stage = min(end_stage, start_stage + max_stages - 1)
    return start_stage, end_stage


def _stage_path_component(stage_name: str) -> str:
    """Make a stage name safe as one path component (whitespace/separators -> "_")."""
    return re.sub(r"[\s/\\]", "_", stage_name)


def _load_stage_data(
    index: int,
    stage: Dict[str, Any],
    stage_slug: str,
    data_cfg: Dict[str, Any],
) -> Tuple[DatasetDict, Dict[str, int]]:
    """Load, filter and subsample one stage's splits; return them plus report row counts.

    Raises:
        ValueError: if no train rows remain after filtering/selection.
    """
    split_files = stage_split_files(stage, data_cfg)
    raw = load_dataset("parquet", data_files=split_files)
    train_rows_before = len(raw["train"])
    valid_rows_before = len(raw["validation"])

    filters = stage.get("filters", {})
    sample_weight_field = as_text(data_cfg.get("sample_weight_field")) or "sample_weight"
    raw["train"] = apply_filters(raw["train"], filters, sample_weight_field)
    raw["validation"] = apply_filters(raw["validation"], filters, sample_weight_field)
    train_rows_after_filter = len(raw["train"])
    valid_rows_after_filter = len(raw["validation"])

    raw["train"] = maybe_select(raw["train"], stage.get("max_train_samples"))
    raw["validation"] = maybe_select(raw["validation"], stage.get("max_eval_samples"))
    train_rows_selected = len(raw["train"])
    valid_rows_selected = len(raw["validation"])

    print(
        f"[stage {index}] rows train: {train_rows_before} -> {train_rows_after_filter} -> {train_rows_selected}; "
        f"validation: {valid_rows_before} -> {valid_rows_after_filter} -> {valid_rows_selected}"
    )
    if train_rows_selected == 0:
        raise ValueError(f"Stage {stage_slug} has zero train rows after filtering.")

    row_counts = {
        "train_rows_before_filter": train_rows_before,
        "validation_rows_before_filter": valid_rows_before,
        "train_rows_after_filter": train_rows_after_filter,
        "validation_rows_after_filter": valid_rows_after_filter,
        "train_rows_selected": train_rows_selected,
        "validation_rows_selected": valid_rows_selected,
    }
    return raw, row_counts


def _dry_run_stage(
    index: int,
    stage_name: str,
    stage_slug: str,
    raw: DatasetDict,
    row_counts: Dict[str, int],
    tokenizer: AutoTokenizer,
    data_cfg: Dict[str, Any],
) -> Dict[str, Any]:
    """Validate one stage without training; return its dry-run report entry.

    Renders the first train row and tokenizes a small prefix of the train split.
    """
    sample_row = raw["train"][0]
    _ = build_prompt_text(sample_row, tokenizer, data_cfg)
    _ = build_answer_block(sample_row, data_cfg)
    # --dry-run promises to validate tokenization; exercise it on a few rows only.
    probe_size = min(len(raw["train"]), DRY_RUN_TOKENIZE_ROWS)
    probe = DatasetDict({"train": raw["train"].select(range(probe_size))})
    probe_tokenized = tokenize_datasets(probe, tokenizer, data_cfg)
    report = {
        "stage_index": index,
        "stage_name": stage_name,
        "stage_slug": stage_slug,
        "mode": "dry_run",
        **row_counts,
        "tokenization_probe_rows": probe_size,
        "tokenization_probe_rows_with_targets": len(probe_tokenized["train"]),
    }
    print(f"[stage {index}] Dry-run checks passed.")
    return report


def _train_stage(
    index: int,
    stage: Dict[str, Any],
    stage_name: str,
    stage_slug: str,
    stage_output_dir: Path,
    raw: DatasetDict,
    row_counts: Dict[str, int],
    model: Any,
    tokenizer: AutoTokenizer,
    cfg: Dict[str, Any],
    seed: int,
) -> Dict[str, Any]:
    """Tokenize, train and (if a validation split exists) evaluate one stage.

    Saves the stage adapter and tokenizer to ``stage_output_dir`` and returns the stage
    report. The Trainer is local to this call, so its optimizer state becomes collectable
    as soon as the function returns.

    Raises:
        ValueError: if no train row keeps a supervised token after tokenization.
    """
    tokenized = tokenize_datasets(raw, tokenizer, cfg["data"])
    train_dataset = tokenized["train"]
    if len(train_dataset) == 0:
        raise ValueError(
            f"Stage {stage_slug} has zero train rows with answer tokens after tokenization "
            "(prompts may exceed data.max_seq_length)."
        )
    eval_dataset = tokenized["validation"] if len(tokenized["validation"]) > 0 else None
    eval_rows = len(eval_dataset) if eval_dataset is not None else 0

    merged_training = dict(cfg.get("training_defaults") or {})
    merged_training.update(stage.get("training") or {})
    merged_training["seed"] = seed

    training_args = build_training_args(
        output_dir=stage_output_dir,
        training_cfg=merged_training,
        use_bf16=as_bool(cfg["model"].get("use_bf16"), True),
        has_eval_split=eval_dataset is not None,
    )
    trainer = WeightedLossTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=WeightedLossCollator(tokenizer=tokenizer, model=model),
        **_trainer_tokenizer_kwargs(tokenizer),
    )

    train_result = trainer.train()
    trainer.log_metrics("train", train_result.metrics)
    trainer.save_metrics("train", train_result.metrics)
    trainer.save_state()

    eval_metrics = None
    if eval_dataset is not None:
        eval_metrics = trainer.evaluate()
        trainer.log_metrics("eval", eval_metrics)
        trainer.save_metrics("eval", eval_metrics)

    trainer.save_model(str(stage_output_dir))
    tokenizer.save_pretrained(str(stage_output_dir))

    print(
        f"[stage {index}] Completed: train_rows={len(train_dataset)} "
        f"eval_rows={eval_rows} output={stage_output_dir}"
    )
    return {
        "stage_index": index,
        "stage_name": stage_name,
        "output_dir": str(stage_output_dir),
        **row_counts,
        "train_rows": len(train_dataset),
        "eval_rows": eval_rows,
        "train_metrics": train_result.metrics,
        "eval_metrics": eval_metrics,
    }


def _push_artifacts(
    cfg: Dict[str, Any],
    repo_id: str,
    token: str,
    final_dir: Path,
    summary_path: Path,
    stage_reports: List[Dict[str, Any]],
    post_eval_result: Optional[Dict[str, Any]],
) -> None:
    """Create the Hub model repo if needed and upload adapter, checkpoints and reports."""
    hub_cfg = cfg.get("hub", {})
    api = HfApi(token=token)
    api.create_repo(
        repo_id=repo_id,
        repo_type="model",
        private=as_bool(hub_cfg.get("private"), False),
        exist_ok=True,
    )
    commit_message = as_text(hub_cfg.get("commit_message")) or "Upload SOTA curriculum adapter."
    push_folder(api, repo_id, final_dir, commit_message=commit_message)

    if as_bool(hub_cfg.get("upload_stage_checkpoints"), False):
        for report in stage_reports:
            stage_dir_raw = report.get("output_dir")
            if not stage_dir_raw:
                continue
            stage_dir = Path(stage_dir_raw)
            push_folder(
                api,
                repo_id,
                stage_dir,
                commit_message=f"Upload stage checkpoint {report.get('stage_name', stage_dir.name)}",
                path_in_repo=f"checkpoints/{stage_dir.name}",
            )

    api.upload_file(
        path_or_fileobj=str(summary_path),
        path_in_repo="training_summary.json",
        repo_id=repo_id,
        repo_type="model",
        commit_message="Upload training summary for SOTA curriculum run.",
    )
    if post_eval_result is not None and post_eval_result.get("report_path"):
        api.upload_file(
            path_or_fileobj=str(post_eval_result["report_path"]),
            path_in_repo="post_eval_report.json",
            repo_id=repo_id,
            repo_type="model",
            commit_message="Upload post-training evaluation report.",
        )
    print(f"Pushed training artifacts to https://huggingface.co/{repo_id}")


def main() -> None:
    """Run the configured curriculum end to end, or validate it with ``--dry-run``."""
    args = parse_args()
    cfg = load_config(args.config)
    apply_overrides(cfg, args)

    seed = as_int(cfg.get("global", {}).get("seed"), 17)
    set_seed(seed)

    output_root = Path(as_text(cfg.get("global", {}).get("output_root")) or "runs/math-conjecture-sota")
    output_root.mkdir(parents=True, exist_ok=True)

    token, username = resolve_auth(cfg)
    repo_id = resolve_repo_id(cfg, username=username, output_root=output_root)
    push_to_hub_requested = as_bool(cfg.get("hub", {}).get("push_to_hub"), False)
    if args.dry_run and push_to_hub_requested:
        print("Dry-run enabled. Disabling push_to_hub for this run.")
    push_to_hub_requested = push_to_hub_requested and not args.dry_run
    if push_to_hub_requested:
        if token is None:
            raise ValueError("Hub push requested but token is missing.")
        if repo_id is None:
            raise ValueError("Hub push requested but repo_id is missing.")

    stages = cfg["stages"]
    # Validate the stage window before the expensive model load.
    start_stage, end_stage = _resolve_stage_range(args.start_stage, args.max_stages, len(stages))

    if args.dry_run:
        tokenizer = build_tokenizer(cfg["model"])
        model = None
    else:
        model, tokenizer = build_model_and_tokenizer(cfg["model"], cfg.get("training_defaults", {}))

    if torch.cuda.is_available():
        print("Compute mode: GPU")
    else:
        print("Compute mode: CPU fallback (no CUDA detected)")

    data_cfg = cfg["data"]
    stage_reports: List[Dict[str, Any]] = []
    for index in range(start_stage, end_stage + 1):
        stage = stages[index - 1]
        stage_name = as_text(stage.get("name")) or f"stage_{index:02d}"
        stage_slug = f"{index:02d}_{_stage_path_component(stage_name)}"
        stage_output_dir = output_root / stage_slug
        print(f"[stage {index}] Starting: {stage_name}")

        raw, row_counts = _load_stage_data(index, stage, stage_slug, data_cfg)
        if args.dry_run:
            stage_reports.append(
                _dry_run_stage(index, stage_name, stage_slug, raw, row_counts, tokenizer, data_cfg)
            )
            continue

        stage_reports.append(
            _train_stage(
                index=index,
                stage=stage,
                stage_name=stage_name,
                stage_slug=stage_slug,
                stage_output_dir=stage_output_dir,
                raw=raw,
                row_counts=row_counts,
                model=model,
                tokenizer=tokenizer,
                cfg=cfg,
                seed=seed,
            )
        )
        # The finished stage's Trainer (optimizer state, accelerator) is unreachable now;
        # reclaim it before the next stage allocates its own.
        _collect_garbage()

    if args.dry_run:
        summary = {
            "mode": "dry_run",
            "config_path": str(args.config),
            "seed": seed,
            "start_stage": start_stage,
            "end_stage": end_stage,
            "stages_ran": stage_reports,
        }
        summary_path = output_root / "dry_run_summary.json"
        summary_path.write_text(json.dumps(summary, ensure_ascii=True, indent=2), encoding="utf-8")
        print("Dry-run complete. No training or model push was performed.")
        print(f"Dry-run summary: {summary_path}")
        return

    if model is None:
        raise RuntimeError("Model was not built for a training run.")
    final_dir = output_root / "final_adapter"
    final_dir.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(str(final_dir))
    tokenizer.save_pretrained(str(final_dir))
    release_model_memory(model)
    # Drop main's reference too, then collect, so the post-eval subprocess gets the GPU.
    del model
    _collect_garbage()

    post_eval_result = run_post_eval(
        cfg=cfg,
        config_path=args.config,
        output_root=output_root,
        final_adapter_dir=final_dir,
    )
    quality_gate = evaluate_quality_gate(
        stage_reports=stage_reports,
        post_eval_result=post_eval_result,
        gate_cfg=cfg.get("quality_gate", {}),
    )

    push_to_hub_performed = push_to_hub_requested
    push_block_reason: Optional[str] = None
    if push_to_hub_requested and not quality_gate.get("passed", True):
        push_to_hub_performed = False
        push_block_reason = "quality_gate_failed"
        print("Quality gate failed; skipping hub push for this run.")

    summary: Dict[str, Any] = {
        "config_path": str(args.config),
        "repo_id": repo_id,
        "seed": seed,
        "stages_ran": stage_reports,
        "final_adapter_dir": str(final_dir),
        "quality_gate": quality_gate,
        "push": {
            "requested": bool(push_to_hub_requested),
            "performed": bool(push_to_hub_performed),
            "block_reason": push_block_reason,
        },
    }
    if post_eval_result is not None:
        report = post_eval_result.get("report")
        if not isinstance(report, dict):
            report = {}
        summary["post_eval"] = {
            "report_path": post_eval_result.get("report_path"),
            "evaluated_rows": report.get("evaluated_rows"),
            "k": report.get("k"),
            "pass_at_1": report.get("pass_at_1"),
            "pass_at_k": report.get("pass_at_k"),
            "exact_at_k": report.get("exact_at_k"),
            "composite_score": report.get("composite_score"),
        }

    summary_path = output_root / "training_summary.json"
    summary_path.write_text(json.dumps(summary, ensure_ascii=True, indent=2), encoding="utf-8")

    if push_to_hub_performed and repo_id is not None and token is not None:
        _push_artifacts(
            cfg=cfg,
            repo_id=repo_id,
            token=token,
            final_dir=final_dir,
            summary_path=summary_path,
            stage_reports=stage_reports,
            post_eval_result=post_eval_result,
        )

    print(f"Training complete. Final adapter: {final_dir}")
    print(f"Training summary: {summary_path}")


if __name__ == "__main__":
    main()