#!/usr/bin/env python3
"""Fine-tune DeepSeek-Math models on the conjecture-solution corpus."""

from __future__ import annotations

import argparse
import json
import os
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import torch
import yaml
from datasets import Dataset, DatasetDict, load_dataset
from huggingface_hub import HfApi
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
    set_seed,
)

DEFAULT_CONFIG_PATH = Path("model_development/configs/deepseek_math.yaml")


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Supervised fine-tuning (LoRA/QLoRA) for DeepSeek-Math models."
    )
    parser.add_argument(
        "--config",
        type=Path,
        default=DEFAULT_CONFIG_PATH,
        help="YAML config path.",
    )
    parser.add_argument("--base-model", type=str, default=None, help="Override model.base_model.")
    parser.add_argument("--output-dir", type=Path, default=None, help="Override training.output_dir.")
    parser.add_argument("--max-train-samples", type=int, default=None, help="Optional train subset.")
    parser.add_argument("--max-eval-samples", type=int, default=None, help="Optional eval subset.")
    parser.add_argument("--repo-id", type=str, default=None, help="Override hub.repo_id.")
    parser.add_argument("--push-to-hub", action="store_true", help="Force push enabled.")
    parser.add_argument("--no-push-to-hub", action="store_true", help="Force push disabled.")
    parser.add_argument(
        "--resume-from-checkpoint",
        type=str,
        default=None,
        help="Path to checkpoint for resume.",
    )
    parser.add_argument(
        "--credentials-path",
        type=Path,
        default=None,
        help="Override credentials.path.",
    )
    return parser.parse_args()


def as_text(value: Any) -> str:
    if value is None:
        return ""
    if isinstance(value, str):
        return value.strip()
    return str(value).strip()


def load_config(path: Path) -> Dict[str, Any]:
    if not path.exists():
        raise FileNotFoundError(f"Config not found: {path}")
    cfg = yaml.safe_load(path.read_text(encoding="utf-8"))
    if not isinstance(cfg, dict):
        raise ValueError(f"Invalid config format: {path}")
    for key in ("model", "data", "training"):
        if key not in cfg or not isinstance(cfg[key], dict):
            raise ValueError(f"Config missing section: {key}")
    cfg.setdefault("hub", {})
    cfg.setdefault("credentials", {})
    return cfg


def apply_overrides(cfg: Dict[str, Any], args: argparse.Namespace) -> None:
    if args.base_model:
        cfg["model"]["base_model"] = args.base_model
    if args.output_dir is not None:
        cfg["training"]["output_dir"] = str(args.output_dir)
    if args.max_train_samples is not None:
        cfg["data"]["max_train_samples"] = args.max_train_samples
    if args.max_eval_samples is not None:
        cfg["data"]["max_eval_samples"] = args.max_eval_samples
    if args.repo_id:
        cfg.setdefault("hub", {})["repo_id"] = args.repo_id
    if args.credentials_path is not None:
        cfg.setdefault("credentials", {})["path"] = str(args.credentials_path)
    if args.push_to_hub and args.no_push_to_hub:
        raise ValueError("Cannot set both --push-to-hub and --no-push-to-hub.")
    if args.push_to_hub:
        cfg.setdefault("hub", {})["push_to_hub"] = True
    if args.no_push_to_hub:
        cfg.setdefault("hub", {})["push_to_hub"] = False
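# A minimal config sketch matching the sections and keys this script reads.
# The key names come from the code in this file; the values below are
# illustrative assumptions, not the project's actual settings:
#
#   model:
#     base_model: deepseek-ai/deepseek-math-7b-base
#     load_in_4bit: true
#     lora:
#       r: 64
#       alpha: 128
#   data:
#     train_file: data/train.parquet
#     validation_file: data/validation.parquet
#     max_seq_length: 2048
#   training:
#     output_dir: outputs/deepseek-math-sft
#     num_train_epochs: 1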
def resolve_auth(cfg: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
    token = as_text(os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")) or None
    username = as_text(os.environ.get("HF_USERNAME")) or None
    cred_path = as_text(cfg.get("credentials", {}).get("path"))
    if cred_path:
        path = Path(cred_path)
        if path.exists():
            data = json.loads(path.read_text(encoding="utf-8"))
            if token is None:
                token = as_text(data.get("key")) or None
            if username is None:
                username = as_text(data.get("username")) or None
    return token, username


def load_raw_datasets(data_cfg: Dict[str, Any]) -> DatasetDict:
    train_path = Path(as_text(data_cfg.get("train_file")))
    valid_path = Path(as_text(data_cfg.get("validation_file")))
    if not train_path.exists():
        raise FileNotFoundError(f"Missing train split: {train_path}")
    if not valid_path.exists():
        raise FileNotFoundError(f"Missing validation split: {valid_path}")
    files = {"train": str(train_path), "validation": str(valid_path)}
    return load_dataset("parquet", data_files=files)


def maybe_select(dataset: Dataset, max_samples: Optional[int]) -> Dataset:
    if max_samples is None:
        return dataset
    if max_samples <= 0:
        raise ValueError("max_samples must be positive.")
    if max_samples >= len(dataset):
        return dataset
    return dataset.select(range(max_samples))


def stringify_structured(value: Any) -> str:
    if value is None:
        return ""
    if isinstance(value, str):
        text = value.strip()
        if not text:
            return ""
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            return text
        return json.dumps(parsed, ensure_ascii=False, sort_keys=True)
    return json.dumps(value, ensure_ascii=False, sort_keys=True)


def build_user_block(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> str:
    prompt_field = as_text(data_cfg.get("prompt_field")) or "prompt"
    prompt = as_text(row.get(prompt_field))
    if not prompt:
        prompt = "Solve the math task."
    meta_fields = [
        ("task_type", "Task type"),
        ("family", "Family"),
        ("difficulty", "Difficulty"),
        ("source_dataset", "Source"),
        ("status_as_of", "Status as of"),
    ]
    meta_lines = []
    for key, label in meta_fields:
        value = as_text(row.get(key))
        if value:
            meta_lines.append(f"{label}: {value}")
    tags = row.get("topic_tags")
    if isinstance(tags, list) and tags:
        tag_text = ", ".join(as_text(tag) for tag in tags if as_text(tag))
        if tag_text:
            meta_lines.append(f"Tags: {tag_text}")
    if not meta_lines:
        return prompt
    return f"{prompt}\n\nMetadata:\n" + "\n".join(meta_lines)


def build_answer_block(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> str:
    target_field = as_text(data_cfg.get("target_field")) or "target"
    final_answer_field = as_text(data_cfg.get("final_answer_field")) or "final_answer"
    proof_field = as_text(data_cfg.get("proof_field")) or "proof_formal"
    sections = []
    target_text = stringify_structured(row.get(target_field))
    if target_text:
        sections.append(f"Structured target:\n{target_text}")
    final_answer = stringify_structured(row.get(final_answer_field))
    if final_answer:
        sections.append(f"Final answer:\n{final_answer}")
    proof_text = stringify_structured(row.get(proof_field))
    if proof_text:
        sections.append(f"Formal proof snippet:\n{proof_text}")
    if not sections:
        sections.append("No structured target provided.")
    return "\n\n".join(sections).strip()
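# For a row with a prompt and metadata fields populated, build_user_block
# renders text shaped like the following (field values are illustrative):
#
#   Prove or refute the stated conjecture.
#
#   Metadata:
#   Task type: conjecture
#   Difficulty: hard
#   Tags: number theory, primes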
def build_prompt_text(
    row: Dict[str, Any],
    tokenizer: AutoTokenizer,
    data_cfg: Dict[str, Any],
) -> str:
    system_prompt = as_text(data_cfg.get("system_prompt"))
    if not system_prompt:
        system_prompt = (
            "You are a rigorous mathematical reasoning assistant focused on "
            "unsolved conjectures. Produce checkable reasoning."
        )
    user_block = build_user_block(row, data_cfg)
    if getattr(tokenizer, "chat_template", None):
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_block},
        ]
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    # Plain-text fallback for tokenizers without a chat template.
    return f"System:\n{system_prompt}\n\nUser:\n{user_block}\n\nAssistant:\n"


def tokenize_datasets(
    raw: DatasetDict,
    tokenizer: AutoTokenizer,
    data_cfg: Dict[str, Any],
) -> DatasetDict:
    max_len = int(data_cfg.get("max_seq_length", 2048))
    if max_len < 64:
        raise ValueError("data.max_seq_length must be at least 64.")
    eos = tokenizer.eos_token or ""
    remove_columns = raw["train"].column_names

    def _tokenize(row: Dict[str, Any]) -> Dict[str, Any]:
        prompt_text = build_prompt_text(row, tokenizer, data_cfg)
        answer_text = build_answer_block(row, data_cfg)
        full_text = f"{prompt_text}{answer_text}{eos}"
        prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
        full_enc = tokenizer(
            full_text,
            add_special_tokens=False,
            truncation=True,
            max_length=max_len,
        )
        input_ids = full_enc["input_ids"]
        attention_mask = full_enc["attention_mask"]
        if not input_ids:
            # Degenerate row: emit a single token so the batch stays well formed.
            fallback = tokenizer.eos_token_id
            if fallback is None:
                fallback = tokenizer.pad_token_id
            if fallback is None:
                fallback = 0
            input_ids = [fallback]
            attention_mask = [1]
            labels = [fallback]
            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels,
            }
        # Mask prompt tokens with -100 so the loss covers only the answer.
        prompt_len = min(len(prompt_ids), len(input_ids))
        labels = [-100] * prompt_len + input_ids[prompt_len:]
        if prompt_len >= len(input_ids):
            # Truncation swallowed the whole answer; keep one supervised token.
            labels[-1] = input_ids[-1]
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    tokenized = raw.map(
        _tokenize,
        remove_columns=remove_columns,
        desc="Tokenizing prompt/answer pairs",
    )
    tokenized = tokenized.filter(
        lambda row: any(token != -100 for token in row["labels"]),
        desc="Dropping prompt-only rows",
    )
    return tokenized
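# Label layout produced by _tokenize (schematic token values):
#
#   input_ids: [p1,   p2,   p3,   a1, a2, eos]
#   labels:    [-100, -100, -100, a1, a2, eos]
#
# Cross-entropy ignores positions labeled -100, so gradients flow only
# through the answer (and EOS) tokens.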
def build_model_and_tokenizer(
    model_cfg: Dict[str, Any],
    training_cfg: Dict[str, Any],
) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    base_model = as_text(model_cfg.get("base_model"))
    if not base_model:
        raise ValueError("model.base_model is required.")
    use_bf16 = bool(model_cfg.get("use_bf16", True))
    dtype = torch.bfloat16 if use_bf16 else torch.float16
    tokenizer = AutoTokenizer.from_pretrained(
        base_model,
        trust_remote_code=bool(model_cfg.get("trust_remote_code", False)),
        use_fast=True,
    )
    added_pad_token = False
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
        if tokenizer.pad_token is None:
            tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
            added_pad_token = True
    model_kwargs: Dict[str, Any] = {
        "trust_remote_code": bool(model_cfg.get("trust_remote_code", False)),
        "torch_dtype": dtype,
    }
    attn_impl = as_text(model_cfg.get("attn_implementation"))
    if attn_impl:
        model_kwargs["attn_implementation"] = attn_impl
    load_in_4bit = bool(model_cfg.get("load_in_4bit", True))
    if load_in_4bit:
        if not torch.cuda.is_available():
            raise RuntimeError("4-bit loading requested but CUDA is not available.")
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=as_text(model_cfg.get("bnb_4bit_quant_type")) or "nf4",
            bnb_4bit_use_double_quant=bool(model_cfg.get("bnb_4bit_use_double_quant", True)),
            bnb_4bit_compute_dtype=dtype,
        )
        model_kwargs["device_map"] = "auto"
    model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
    if added_pad_token:
        # A brand-new pad token needs a matching embedding row, otherwise
        # pad_token_id indexes past the embedding matrix.
        model.resize_token_embeddings(len(tokenizer))
    if tokenizer.pad_token_id is not None:
        model.config.pad_token_id = tokenizer.pad_token_id
    # The KV cache is incompatible with gradient checkpointing during training.
    model.config.use_cache = False
    if load_in_4bit:
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=bool(training_cfg.get("gradient_checkpointing", True)),
        )
    lora_cfg = model_cfg.get("lora", {})
    peft_cfg = LoraConfig(
        r=int(lora_cfg.get("r", 64)),
        lora_alpha=int(lora_cfg.get("alpha", 128)),
        lora_dropout=float(lora_cfg.get("dropout", 0.05)),
        bias=as_text(lora_cfg.get("bias")) or "none",
        task_type="CAUSAL_LM",
        target_modules=lora_cfg.get("target_modules"),
    )
    model = get_peft_model(model, peft_cfg)
    model.print_trainable_parameters()
    return model, tokenizer


def build_training_args(
    cfg: Dict[str, Any],
    has_eval_split: bool,
) -> TrainingArguments:
    model_cfg = cfg["model"]
    training_cfg = cfg["training"]
    use_bf16 = bool(model_cfg.get("use_bf16", True))
    output_dir = Path(as_text(training_cfg.get("output_dir")))
    output_dir.mkdir(parents=True, exist_ok=True)
    return TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=float(training_cfg.get("num_train_epochs", 1)),
        per_device_train_batch_size=int(training_cfg.get("per_device_train_batch_size", 1)),
        per_device_eval_batch_size=int(training_cfg.get("per_device_eval_batch_size", 1)),
        gradient_accumulation_steps=int(training_cfg.get("gradient_accumulation_steps", 1)),
        learning_rate=float(training_cfg.get("learning_rate", 2e-5)),
        weight_decay=float(training_cfg.get("weight_decay", 0.0)),
        warmup_ratio=float(training_cfg.get("warmup_ratio", 0.0)),
        lr_scheduler_type=as_text(training_cfg.get("lr_scheduler_type")) or "cosine",
        max_grad_norm=float(training_cfg.get("max_grad_norm", 1.0)),
        gradient_checkpointing=bool(training_cfg.get("gradient_checkpointing", True)),
        logging_steps=int(training_cfg.get("logging_steps", 10)),
        save_steps=int(training_cfg.get("save_steps", 250)),
        save_total_limit=int(training_cfg.get("save_total_limit", 3)),
        dataloader_num_workers=int(training_cfg.get("dataloader_num_workers", 0)),
        seed=int(training_cfg.get("seed", 17)),
        bf16=use_bf16,
        fp16=not use_bf16,
        remove_unused_columns=False,
        report_to="none",
        evaluation_strategy="steps" if has_eval_split else "no",
        eval_steps=int(training_cfg.get("eval_steps", 250)) if has_eval_split else None,
    )


def resolve_repo_id(
    cfg: Dict[str, Any],
    username: Optional[str],
) -> Optional[str]:
    repo_id = as_text(cfg.get("hub", {}).get("repo_id"))
    if repo_id:
        return repo_id
    if not username:
        return None
    output_dir = Path(as_text(cfg["training"].get("output_dir")))
    return f"{username}/{output_dir.name}"


def push_output_to_hub(
    output_dir: Path,
    repo_id: str,
    token: str,
    private: bool,
    commit_message: str,
) -> None:
    api = HfApi(token=token)
    api.create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True)
    api.upload_folder(
        repo_id=repo_id,
        repo_type="model",
        folder_path=str(output_dir),
        commit_message=commit_message,
    )


def save_resolved_config(
    cfg: Dict[str, Any],
    output_dir: Path,
    config_path: Path,
) -> None:
    serializable = json.loads(json.dumps(cfg))
    serializable["resolved_from"] = str(config_path)
    out_path = output_dir / "resolved_training_config.json"
    out_path.write_text(json.dumps(serializable, ensure_ascii=True, indent=2), encoding="utf-8")
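# Because training uses PEFT, the output directory holds LoRA adapter weights
# rather than merged full weights. A minimal inference sketch (assumes `peft`
# is installed; the path is illustrative):
#
#   from peft import AutoPeftModelForCausalLM
#   model = AutoPeftModelForCausalLM.from_pretrained("outputs/deepseek-math-sft")
#   merged = model.merge_and_unload()  # optional: fold adapters into the base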
def main() -> None:
    args = parse_args()
    cfg = load_config(args.config)
    apply_overrides(cfg, args)
    training_cfg = cfg["training"]
    seed = int(training_cfg.get("seed", 17))
    set_seed(seed)
    token, username = resolve_auth(cfg)
    push_to_hub = bool(cfg.get("hub", {}).get("push_to_hub", False))
    repo_id = resolve_repo_id(cfg, username)
    if push_to_hub:
        if token is None:
            raise ValueError(
                "Hub push requested but no token found. Set HF_TOKEN or credentials.path."
            )
        if repo_id is None:
            raise ValueError(
                "Hub push requested but repo_id is empty and username is unavailable."
            )
    model, tokenizer = build_model_and_tokenizer(cfg["model"], training_cfg)
    raw = load_raw_datasets(cfg["data"])
    raw["train"] = maybe_select(raw["train"], cfg["data"].get("max_train_samples"))
    raw["validation"] = maybe_select(raw["validation"], cfg["data"].get("max_eval_samples"))
    tokenized = tokenize_datasets(raw, tokenizer, cfg["data"])
    train_dataset = tokenized["train"]
    eval_dataset = tokenized["validation"] if len(tokenized["validation"]) > 0 else None
    training_args = build_training_args(cfg, has_eval_split=eval_dataset is not None)
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        label_pad_token_id=-100,
        pad_to_multiple_of=8,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    train_result = trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    trainer.log_metrics("train", train_result.metrics)
    trainer.save_metrics("train", train_result.metrics)
    trainer.save_state()
    if eval_dataset is not None:
        eval_metrics = trainer.evaluate()
        trainer.log_metrics("eval", eval_metrics)
        trainer.save_metrics("eval", eval_metrics)
    trainer.save_model(training_args.output_dir)
    tokenizer.save_pretrained(training_args.output_dir)
    output_dir = Path(training_args.output_dir)
    save_resolved_config(cfg, output_dir, args.config)
    if push_to_hub and repo_id is not None and token is not None:
        commit_message = as_text(cfg.get("hub", {}).get("commit_message")) or "Upload fine-tuned model."
        private = bool(cfg.get("hub", {}).get("private", False))
        push_output_to_hub(output_dir, repo_id, token, private, commit_message)
        print(f"Pushed model artifacts to https://huggingface.co/{repo_id}")
    print(f"Training finished. Output saved to: {output_dir}")


if __name__ == "__main__":
    main()
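# Example invocations (a sketch; the script filename and model id are
# illustrative, the flags are those defined in parse_args):
#
#   python finetune_deepseek_math.py
#   python finetune_deepseek_math.py --base-model deepseek-ai/deepseek-math-7b-instruct \
#       --max-train-samples 1000 --no-push-to-hub
#   python finetune_deepseek_math.py --resume-from-checkpoint outputs/checkpoint-500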