| | |
| | """Fine-tune DeepSeek-Math models on the conjecture-solution corpus.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import argparse |
| | import json |
| | import os |
| | from pathlib import Path |
| | from typing import Any, Dict, Optional, Tuple |
| |
|
| | import torch |
| | import yaml |
| | from datasets import Dataset, DatasetDict, load_dataset |
| | from huggingface_hub import HfApi |
| | from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training |
| | from transformers import ( |
| | AutoModelForCausalLM, |
| | AutoTokenizer, |
| | BitsAndBytesConfig, |
| | DataCollatorForSeq2Seq, |
| | Trainer, |
| | TrainingArguments, |
| | set_seed, |
| | ) |
| |
|
| | DEFAULT_CONFIG_PATH = Path("model_development/configs/deepseek_math.yaml") |
| |
|
| |
|
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the fine-tuning entry point."""
    parser = argparse.ArgumentParser(
        description="Supervised fine-tuning (LoRA/QLoRA) for DeepSeek-Math models."
    )
    # Config file and per-field config overrides.
    parser.add_argument("--config", type=Path, default=DEFAULT_CONFIG_PATH, help="YAML config path.")
    parser.add_argument("--base-model", type=str, default=None, help="Override model.base_model.")
    parser.add_argument("--output-dir", type=Path, default=None, help="Override training.output_dir.")
    parser.add_argument("--max-train-samples", type=int, default=None, help="Optional train subset.")
    parser.add_argument("--max-eval-samples", type=int, default=None, help="Optional eval subset.")
    # Hub and credentials controls.
    parser.add_argument("--repo-id", type=str, default=None, help="Override hub.repo_id.")
    parser.add_argument("--push-to-hub", action="store_true", help="Force push enabled.")
    parser.add_argument("--no-push-to-hub", action="store_true", help="Force push disabled.")
    parser.add_argument("--resume-from-checkpoint", type=str, default=None, help="Path to checkpoint for resume.")
    parser.add_argument("--credentials-path", type=Path, default=None, help="Override credentials.path.")
    return parser.parse_args()
| |
|
| |
|
def as_text(value: Any) -> str:
    """Coerce *value* to a stripped string; ``None`` maps to ``""``."""
    if value is None:
        return ""
    text = value if isinstance(value, str) else str(value)
    return text.strip()
| |
|
| |
|
def load_config(path: Path) -> Dict[str, Any]:
    """Read and validate the YAML training config at *path*.

    Raises:
        FileNotFoundError: If the file does not exist.
        ValueError: If the YAML root is not a mapping or a required
            section (model/data/training) is missing or malformed.
    """
    if not path.exists():
        raise FileNotFoundError(f"Config not found: {path}")
    cfg = yaml.safe_load(path.read_text(encoding="utf-8"))
    if not isinstance(cfg, dict):
        raise ValueError(f"Invalid config format: {path}")
    for key in ("model", "data", "training"):
        if not isinstance(cfg.get(key), dict):
            raise ValueError(f"Config missing section: {key}")
    # Optional sections default to empty mappings so later lookups are safe.
    for optional in ("hub", "credentials"):
        cfg.setdefault(optional, {})
    return cfg
| |
|
| |
|
def apply_overrides(cfg: Dict[str, Any], args: argparse.Namespace) -> None:
    """Mutate *cfg* in place with any CLI overrides present in *args*.

    Raises:
        ValueError: If both --push-to-hub and --no-push-to-hub are set.
    """
    # Fail fast: validate the conflicting-flag case BEFORE mutating cfg, so a
    # rejected call leaves the config untouched (the original raised after
    # applying some overrides, leaving cfg half-modified).
    if args.push_to_hub and args.no_push_to_hub:
        raise ValueError("Cannot set both --push-to-hub and --no-push-to-hub.")
    if args.base_model:
        cfg["model"]["base_model"] = args.base_model
    if args.output_dir is not None:
        cfg["training"]["output_dir"] = str(args.output_dir)
    if args.max_train_samples is not None:
        cfg["data"]["max_train_samples"] = args.max_train_samples
    if args.max_eval_samples is not None:
        cfg["data"]["max_eval_samples"] = args.max_eval_samples
    if args.repo_id:
        cfg.setdefault("hub", {})["repo_id"] = args.repo_id
    if args.credentials_path is not None:
        cfg.setdefault("credentials", {})["path"] = str(args.credentials_path)
    if args.push_to_hub:
        cfg.setdefault("hub", {})["push_to_hub"] = True
    if args.no_push_to_hub:
        cfg.setdefault("hub", {})["push_to_hub"] = False
| |
|
| |
|
def resolve_auth(cfg: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
    """Resolve (token, username) for the HF Hub.

    Environment variables win; the JSON file at credentials.path only fills
    whichever value is still missing.
    """
    env_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
    token: Optional[str] = as_text(env_token) or None
    username: Optional[str] = as_text(os.environ.get("HF_USERNAME")) or None

    cred_path = as_text(cfg.get("credentials", {}).get("path"))
    if cred_path:
        candidate = Path(cred_path)
        if candidate.exists():
            payload = json.loads(candidate.read_text(encoding="utf-8"))
            # token/username are either non-empty strings or None, so `or`
            # fills only the missing slots.
            token = token or (as_text(payload.get("key")) or None)
            username = username or (as_text(payload.get("username")) or None)
    return token, username
| |
|
| |
|
def load_raw_datasets(data_cfg: Dict[str, Any]) -> DatasetDict:
    """Load the train/validation parquet splits named in *data_cfg*.

    Raises:
        FileNotFoundError: If either split file is absent on disk.
    """
    paths = {
        "train": Path(as_text(data_cfg.get("train_file"))),
        "validation": Path(as_text(data_cfg.get("validation_file"))),
    }
    if not paths["train"].exists():
        raise FileNotFoundError(f"Missing train split: {paths['train']}")
    if not paths["validation"].exists():
        raise FileNotFoundError(f"Missing validation split: {paths['validation']}")

    return load_dataset(
        "parquet",
        data_files={split: str(path) for split, path in paths.items()},
    )
| |
|
| |
|
def maybe_select(dataset: Dataset, max_samples: Optional[int]) -> Dataset:
    """Return *dataset*, truncated to its first *max_samples* rows if needed.

    Raises:
        ValueError: If *max_samples* is zero or negative.
    """
    if max_samples is None:
        return dataset
    if max_samples <= 0:
        raise ValueError("max_samples must be positive.")
    # Only materialize a subset when the cap is actually smaller.
    return dataset if max_samples >= len(dataset) else dataset.select(range(max_samples))
| |
|
| |
|
def stringify_structured(value: Any) -> str:
    """Canonicalize *value* as compact JSON text.

    Strings that parse as JSON are re-serialized with sorted keys; other
    strings are returned stripped. Non-string values are JSON-encoded.
    ``None`` and blank strings map to ``""``.
    """
    if value is None:
        return ""
    if not isinstance(value, str):
        return json.dumps(value, ensure_ascii=False, sort_keys=True)
    text = value.strip()
    if not text:
        return ""
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Free-form text: keep it verbatim (already stripped).
        return text
    return json.dumps(parsed, ensure_ascii=False, sort_keys=True)
| |
|
| |
|
def build_user_block(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> str:
    """Render the user-turn text: the prompt plus an optional metadata footer."""
    prompt_key = as_text(data_cfg.get("prompt_field")) or "prompt"
    prompt = as_text(row.get(prompt_key)) or "Solve the math task."

    labelled = (
        ("task_type", "Task type"),
        ("family", "Family"),
        ("difficulty", "Difficulty"),
        ("source_dataset", "Source"),
        ("status_as_of", "Status as of"),
    )
    lines = [
        f"{label}: {text}"
        for key, label in labelled
        if (text := as_text(row.get(key)))
    ]
    tags = row.get("topic_tags")
    if isinstance(tags, list) and tags:
        joined = ", ".join(t for t in (as_text(tag) for tag in tags) if t)
        if joined:
            lines.append(f"Tags: {joined}")

    if not lines:
        return prompt
    return f"{prompt}\n\nMetadata:\n" + "\n".join(lines)
| |
|
| |
|
def build_answer_block(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> str:
    """Render the supervised answer text from target/final-answer/proof fields."""
    specs = (
        (as_text(data_cfg.get("target_field")) or "target", "Structured target"),
        (as_text(data_cfg.get("final_answer_field")) or "final_answer", "Final answer"),
        (as_text(data_cfg.get("proof_field")) or "proof_formal", "Formal proof snippet"),
    )
    sections = []
    for field, heading in specs:
        body = stringify_structured(row.get(field))
        if body:
            sections.append(f"{heading}:\n{body}")

    if not sections:
        # Keep the row trainable even without structured supervision.
        sections.append("No structured target provided.")
    return "\n\n".join(sections).strip()
| |
|
| |
|
def build_prompt_text(
    row: Dict[str, Any],
    tokenizer: AutoTokenizer,
    data_cfg: Dict[str, Any],
) -> str:
    """Format the prompt half of a training example.

    Uses the tokenizer's chat template when one is defined; otherwise falls
    back to a plain System/User/Assistant layout.
    """
    default_system = (
        "You are a rigorous mathematical reasoning assistant focused on "
        "unsolved conjectures. Produce checkable reasoning."
    )
    system_prompt = as_text(data_cfg.get("system_prompt")) or default_system
    user_block = build_user_block(row, data_cfg)

    if not getattr(tokenizer, "chat_template", None):
        return f"System:\n{system_prompt}\n\nUser:\n{user_block}\n\nAssistant:\n"
    return tokenizer.apply_chat_template(
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_block},
        ],
        tokenize=False,
        add_generation_prompt=True,
    )
| |
|
| |
|
def tokenize_datasets(
    raw: DatasetDict,
    tokenizer: AutoTokenizer,
    data_cfg: Dict[str, Any],
) -> DatasetDict:
    """Tokenize prompt/answer pairs into input_ids/attention_mask/labels.

    Prompt tokens are loss-masked with -100 so only the answer (plus EOS)
    is supervised. Rows whose labels are entirely -100 are filtered out.
    """
    max_len = int(data_cfg.get("max_seq_length", 2048))
    if max_len < 64:
        raise ValueError("data.max_seq_length must be at least 64.")

    eos = tokenizer.eos_token or ""
    remove_columns = raw["train"].column_names

    def _tokenize(row: Dict[str, Any]) -> Dict[str, Any]:
        prompt_text = build_prompt_text(row, tokenizer, data_cfg)
        answer_text = build_answer_block(row, data_cfg)
        full_text = f"{prompt_text}{answer_text}{eos}"

        # The prompt is tokenized separately only to measure how many leading
        # tokens of the full sequence should be masked out of the loss.
        prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
        full_enc = tokenizer(
            full_text,
            add_special_tokens=False,
            truncation=True,
            max_length=max_len,
        )
        input_ids = full_enc["input_ids"]
        attention_mask = full_enc["attention_mask"]

        if not input_ids:
            # Degenerate empty encoding: emit a single supervised token so the
            # row survives the prompt-only filter below.
            fallback = tokenizer.eos_token_id
            if fallback is None:
                fallback = tokenizer.pad_token_id
            if fallback is None:
                fallback = 0
            input_ids = [fallback]
            attention_mask = [1]
            labels = [fallback]
            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels,
            }

        # Mask the prompt span; supervise everything after it.
        prompt_len = min(len(prompt_ids), len(input_ids))
        labels = [-100] * prompt_len + input_ids[prompt_len:]
        if prompt_len >= len(input_ids):
            # Truncation consumed the whole answer: unmask the final token so
            # at least one position carries loss. NOTE(review): this means the
            # filter below can only drop rows in pathological cases — confirm
            # that keeping fully-truncated rows is intended.
            labels[-1] = input_ids[-1]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    tokenized = raw.map(
        _tokenize,
        remove_columns=remove_columns,
        desc="Tokenizing prompt/answer pairs",
    )
    tokenized = tokenized.filter(
        lambda row: any(token != -100 for token in row["labels"]),
        desc="Dropping prompt-only rows",
    )
    return tokenized
| |
|
| |
|
def build_model_and_tokenizer(
    model_cfg: Dict[str, Any],
    training_cfg: Dict[str, Any],
) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load the base model + tokenizer and wrap the model with LoRA adapters.

    Quantizes to 4-bit (QLoRA) when model.load_in_4bit is true.

    Raises:
        ValueError: If model.base_model is missing.
        RuntimeError: If 4-bit loading is requested without CUDA.
    """
    base_model = as_text(model_cfg.get("base_model"))
    if not base_model:
        raise ValueError("model.base_model is required.")

    use_bf16 = bool(model_cfg.get("use_bf16", True))
    dtype = torch.bfloat16 if use_bf16 else torch.float16

    tokenizer = AutoTokenizer.from_pretrained(
        base_model,
        trust_remote_code=bool(model_cfg.get("trust_remote_code", False)),
        use_fast=True,
    )
    added_tokens = 0
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
        if tokenizer.pad_token is None:
            # No existing token can serve as padding: register a new one and
            # remember to grow the model's embedding table accordingly.
            added_tokens = tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

    model_kwargs: Dict[str, Any] = {
        "trust_remote_code": bool(model_cfg.get("trust_remote_code", False)),
        "torch_dtype": dtype,
    }
    attn_impl = as_text(model_cfg.get("attn_implementation"))
    if attn_impl:
        model_kwargs["attn_implementation"] = attn_impl

    load_in_4bit = bool(model_cfg.get("load_in_4bit", True))
    if load_in_4bit:
        if not torch.cuda.is_available():
            raise RuntimeError("4-bit loading requested but CUDA is not available.")
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=as_text(model_cfg.get("bnb_4bit_quant_type")) or "nf4",
            bnb_4bit_use_double_quant=bool(model_cfg.get("bnb_4bit_use_double_quant", True)),
            bnb_4bit_compute_dtype=dtype,
        )
        model_kwargs["device_map"] = "auto"

    model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
    if added_tokens:
        # Fix: the original never resized the embeddings after adding the new
        # "<|pad|>" token, leaving pad_token_id outside the embedding table.
        model.resize_token_embeddings(len(tokenizer))
    if tokenizer.pad_token_id is not None:
        model.config.pad_token_id = tokenizer.pad_token_id
    # KV cache is incompatible with gradient checkpointing during training.
    model.config.use_cache = False

    if load_in_4bit:
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=bool(training_cfg.get("gradient_checkpointing", True)),
        )

    lora_cfg = model_cfg.get("lora", {})
    peft_cfg = LoraConfig(
        r=int(lora_cfg.get("r", 64)),
        lora_alpha=int(lora_cfg.get("alpha", 128)),
        lora_dropout=float(lora_cfg.get("dropout", 0.05)),
        bias=as_text(lora_cfg.get("bias")) or "none",
        task_type="CAUSAL_LM",
        target_modules=lora_cfg.get("target_modules"),
    )
    model = get_peft_model(model, peft_cfg)
    model.print_trainable_parameters()
    return model, tokenizer
| |
|
| |
|
def build_training_args(
    cfg: Dict[str, Any],
    has_eval_split: bool,
) -> TrainingArguments:
    """Translate the config's model/training sections into TrainingArguments."""
    model_cfg = cfg["model"]
    training_cfg = cfg["training"]

    use_bf16 = bool(model_cfg.get("use_bf16", True))
    output_dir = Path(as_text(training_cfg.get("output_dir")))
    output_dir.mkdir(parents=True, exist_ok=True)

    kwargs: Dict[str, Any] = {
        "output_dir": str(output_dir),
        "num_train_epochs": float(training_cfg.get("num_train_epochs", 1)),
        "per_device_train_batch_size": int(training_cfg.get("per_device_train_batch_size", 1)),
        "per_device_eval_batch_size": int(training_cfg.get("per_device_eval_batch_size", 1)),
        "gradient_accumulation_steps": int(training_cfg.get("gradient_accumulation_steps", 1)),
        "learning_rate": float(training_cfg.get("learning_rate", 2e-5)),
        "weight_decay": float(training_cfg.get("weight_decay", 0.0)),
        "warmup_ratio": float(training_cfg.get("warmup_ratio", 0.0)),
        "lr_scheduler_type": as_text(training_cfg.get("lr_scheduler_type")) or "cosine",
        "max_grad_norm": float(training_cfg.get("max_grad_norm", 1.0)),
        "gradient_checkpointing": bool(training_cfg.get("gradient_checkpointing", True)),
        "logging_steps": int(training_cfg.get("logging_steps", 10)),
        "save_steps": int(training_cfg.get("save_steps", 250)),
        "save_total_limit": int(training_cfg.get("save_total_limit", 3)),
        "dataloader_num_workers": int(training_cfg.get("dataloader_num_workers", 0)),
        "seed": int(training_cfg.get("seed", 17)),
        # Exactly one of bf16/fp16 is enabled, driven by model.use_bf16.
        "bf16": use_bf16,
        "fp16": not use_bf16,
        "remove_unused_columns": False,
        "report_to": "none",
        # Evaluation is only scheduled when a non-empty validation split exists.
        "evaluation_strategy": "steps" if has_eval_split else "no",
        "eval_steps": int(training_cfg.get("eval_steps", 250)) if has_eval_split else None,
    }
    return TrainingArguments(**kwargs)
| |
|
| |
|
def resolve_repo_id(
    cfg: Dict[str, Any],
    username: Optional[str],
) -> Optional[str]:
    """Pick the hub repo id: explicit hub.repo_id, else username/output-dir name."""
    explicit = as_text(cfg.get("hub", {}).get("repo_id"))
    if explicit:
        return explicit
    if username:
        output_name = Path(as_text(cfg["training"].get("output_dir"))).name
        return f"{username}/{output_name}"
    return None
| |
|
| |
|
def push_output_to_hub(
    output_dir: Path,
    repo_id: str,
    token: str,
    private: bool,
    commit_message: str,
) -> None:
    """Create the model repo if needed and upload *output_dir* to it."""
    client = HfApi(token=token)
    # exist_ok makes re-runs against an existing repo idempotent.
    client.create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True)
    client.upload_folder(
        repo_id=repo_id,
        repo_type="model",
        folder_path=str(output_dir),
        commit_message=commit_message,
    )
| |
|
| |
|
def save_resolved_config(
    cfg: Dict[str, Any],
    output_dir: Path,
    config_path: Path,
) -> None:
    """Write the fully-resolved config next to the model for reproducibility."""
    # Round-trip through JSON to deep-copy and catch non-serializable values.
    snapshot = json.loads(json.dumps(cfg))
    snapshot["resolved_from"] = str(config_path)
    target = output_dir / "resolved_training_config.json"
    target.write_text(
        json.dumps(snapshot, ensure_ascii=True, indent=2),
        encoding="utf-8",
    )
| |
|
| |
|
def main() -> None:
    """End-to-end SFT entry point: config -> auth -> data -> train -> save -> push."""
    args = parse_args()
    cfg = load_config(args.config)
    apply_overrides(cfg, args)

    training_cfg = cfg["training"]
    seed = int(training_cfg.get("seed", 17))
    set_seed(seed)

    # Resolve hub credentials up front so a misconfigured push fails before
    # any expensive model loading or training work.
    token, username = resolve_auth(cfg)
    push_to_hub = bool(cfg.get("hub", {}).get("push_to_hub", False))
    repo_id = resolve_repo_id(cfg, username)
    if push_to_hub:
        if token is None:
            raise ValueError(
                "Hub push requested but no token found. Set HF_TOKEN or credentials.path."
            )
        if repo_id is None:
            raise ValueError(
                "Hub push requested but repo_id is empty and username is unavailable."
            )

    model, tokenizer = build_model_and_tokenizer(cfg["model"], training_cfg)

    # Optional subsetting happens before tokenization to keep the map cheap.
    raw = load_raw_datasets(cfg["data"])
    raw["train"] = maybe_select(raw["train"], cfg["data"].get("max_train_samples"))
    raw["validation"] = maybe_select(raw["validation"], cfg["data"].get("max_eval_samples"))

    tokenized = tokenize_datasets(raw, tokenizer, cfg["data"])
    train_dataset = tokenized["train"]
    # An empty validation split disables evaluation entirely.
    eval_dataset = tokenized["validation"] if len(tokenized["validation"]) > 0 else None

    training_args = build_training_args(cfg, has_eval_split=eval_dataset is not None)
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        label_pad_token_id=-100,  # pad labels with the loss-ignore index
        pad_to_multiple_of=8,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    train_result = trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
    trainer.log_metrics("train", train_result.metrics)
    trainer.save_metrics("train", train_result.metrics)
    trainer.save_state()

    if eval_dataset is not None:
        eval_metrics = trainer.evaluate()
        trainer.log_metrics("eval", eval_metrics)
        trainer.save_metrics("eval", eval_metrics)

    trainer.save_model(training_args.output_dir)
    tokenizer.save_pretrained(training_args.output_dir)

    # Persist the resolved config alongside the weights for reproducibility.
    output_dir = Path(training_args.output_dir)
    save_resolved_config(cfg, output_dir, args.config)

    if push_to_hub and repo_id is not None and token is not None:
        commit_message = as_text(cfg.get("hub", {}).get("commit_message")) or "Upload fine-tuned model."
        private = bool(cfg.get("hub", {}).get("private", False))
        push_output_to_hub(output_dir, repo_id, token, private, commit_message)
        print(f"Pushed model artifacts to https://huggingface.co/{repo_id}")

    print(f"Training finished. Output saved to: {output_dir}")
| |
|
| |
|
# Script entry point: run the full fine-tuning pipeline.
if __name__ == "__main__":
    main()
| |
|