NorthernTribe-Research committed on
Commit 90dacf5 · verified · 1 Parent(s): 6c608fd

Rename model repo target to math-conjecture-model and upload pipeline.

.gitignore ADDED
@@ -0,0 +1,5 @@
+ runs/
+ merged/
+ cache/
+ __pycache__/
+ scripts/__pycache__/
README.md ADDED
@@ -0,0 +1,69 @@
+ ---
+ language:
+ - en
+ library_name: transformers
+ pipeline_tag: text-generation
+ tags:
+ - mathematics
+ - conjecture-reasoning
+ - deepseek-math
+ - lora
+ base_model:
+ - deepseek-ai/deepseek-math-7b-instruct
+ - deepseek-ai/deepseek-math-v2
+ datasets:
+ - NorthernTribe-Research/math-conjecture-training-corpus
+ ---
+
+ # Model Development (DeepSeek-Math)
+
+ This folder contains the fine-tuning pipeline for building a conjecture-solution
+ model from the merged dataset in `data/releases/v1/`.
+
+ ## What is included
+
+ - `configs/deepseek_math.yaml`: preset for `DeepSeek-Math`
+ - `configs/deepseek_math_v2.yaml`: preset for `DeepSeek-Math-V2`
+ - `scripts/train_sft.py`: LoRA/QLoRA supervised fine-tuning + optional Hub push
+ - `scripts/merge_and_push.py`: optional adapter merge into full weights + Hub push
+ - `requirements.txt`: model-training dependencies
+
+ ## Setup
+
+ ```bash
+ .venv/bin/python -m pip install -r model_development/requirements.txt
+ ```
+
+ ## Fine-tune DeepSeek-Math
+
+ ```bash
+ .venv/bin/python model_development/scripts/train_sft.py \
+     --config model_development/configs/deepseek_math.yaml
+ ```
+
+ ## Fine-tune DeepSeek-Math-V2
+
+ ```bash
+ .venv/bin/python model_development/scripts/train_sft.py \
+     --config model_development/configs/deepseek_math_v2.yaml
+ ```
+
+ ## Important notes
+
+ - Both presets point to `data/releases/v1/train.parquet` and
+   `data/releases/v1/validation.parquet`.
+ - If your exact v2 checkpoint id differs, update `model.base_model` in
+   `model_development/configs/deepseek_math_v2.yaml`.
+ - Hub auth uses `HF_TOKEN` first, then `huggingface-api-key.json`.
+ - If `hub.repo_id` is empty, the repo name defaults to
+   `<username>/<output_dir_name>`.
+
+ ## Optional: merge LoRA adapter into full model
+
+ ```bash
+ .venv/bin/python model_development/scripts/merge_and_push.py \
+     --adapter-path model_development/runs/deepseek-math-lora \
+     --output-dir model_development/merged/math-conjecture-model \
+     --push-to-hub \
+     --repo-id NorthernTribe-Research/math-conjecture-model
+ ```
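The README's auth and naming rules (env token first, then the credentials JSON; repo id derived from the output directory name when `hub.repo_id` is empty) can be sketched as a small helper. This is illustrative only — `resolve_hub_target` is not one of the shipped scripts; the real logic lives in `resolve_auth` and `resolve_repo_id` inside `scripts/train_sft.py`:

```python
import json
import os
from pathlib import Path


def resolve_hub_target(output_dir: str, repo_id: str = "",
                       credentials_path: str = "huggingface-api-key.json"):
    """Illustrative precedence: HF_TOKEN env var wins, then the JSON file;
    an empty repo_id falls back to <username>/<output_dir_name>."""
    token = os.environ.get("HF_TOKEN") or None
    username = None
    path = Path(credentials_path)
    if path.exists():
        data = json.loads(path.read_text(encoding="utf-8"))
        token = token or data.get("key")
        username = data.get("username")
    if not repo_id and username:
        repo_id = f"{username}/{Path(output_dir).name}"
    return token, repo_id or None
```

With `output_dir` set to `model_development/runs/deepseek-math-lora` and an empty repo id, the default becomes `<username>/deepseek-math-lora`.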
configs/deepseek_math.yaml ADDED
@@ -0,0 +1,64 @@
+ model:
+   base_model: deepseek-ai/deepseek-math-7b-instruct
+   trust_remote_code: true
+   use_bf16: true
+   load_in_4bit: true
+   bnb_4bit_quant_type: nf4
+   bnb_4bit_use_double_quant: true
+   attn_implementation: null
+   lora:
+     r: 64
+     alpha: 128
+     dropout: 0.05
+     bias: none
+     target_modules:
+       - q_proj
+       - k_proj
+       - v_proj
+       - o_proj
+       - gate_proj
+       - up_proj
+       - down_proj
+
+ data:
+   train_file: data/releases/v1/train.parquet
+   validation_file: data/releases/v1/validation.parquet
+   prompt_field: prompt
+   target_field: target
+   final_answer_field: final_answer
+   proof_field: proof_formal
+   max_seq_length: 2048
+   max_train_samples: null
+   max_eval_samples: 2000
+   system_prompt: |
+     You are a rigorous mathematical reasoning assistant specialized in unsolved
+     conjectures. Produce clear, checkable reasoning and avoid claiming a full
+     proof unless it is explicitly available in the task context.
+
+ training:
+   output_dir: model_development/runs/deepseek-math-lora
+   num_train_epochs: 1
+   per_device_train_batch_size: 1
+   per_device_eval_batch_size: 1
+   gradient_accumulation_steps: 16
+   learning_rate: 2.0e-5
+   weight_decay: 0.01
+   warmup_ratio: 0.03
+   lr_scheduler_type: cosine
+   max_grad_norm: 1.0
+   gradient_checkpointing: true
+   logging_steps: 10
+   save_steps: 250
+   eval_steps: 250
+   save_total_limit: 3
+   dataloader_num_workers: 2
+   seed: 17
+
+ hub:
+   push_to_hub: true
+   repo_id: NorthernTribe-Research/math-conjecture-model
+   private: false
+   commit_message: Train DeepSeek-Math LoRA on conjecture corpus.
+
+ credentials:
+   path: huggingface-api-key.json
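One consequence of this preset worth making explicit: with a per-device batch of 1 and 16 gradient-accumulation steps, each optimizer step aggregates 16 examples per device. A quick check of that arithmetic (the single-device count is an assumption for a typical single-GPU QLoRA run, not something the config states):

```python
# Values copied from the training section of configs/deepseek_math.yaml.
per_device_train_batch_size = 1
gradient_accumulation_steps = 16
num_devices = 1  # assumption: single-GPU run

# Examples contributing to one optimizer update.
effective_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * num_devices
)
print(effective_batch_size)  # 16
```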
configs/deepseek_math_v2.yaml ADDED
@@ -0,0 +1,64 @@
+ model:
+   base_model: deepseek-ai/deepseek-math-v2
+   trust_remote_code: true
+   use_bf16: true
+   load_in_4bit: true
+   bnb_4bit_quant_type: nf4
+   bnb_4bit_use_double_quant: true
+   attn_implementation: null
+   lora:
+     r: 64
+     alpha: 128
+     dropout: 0.05
+     bias: none
+     target_modules:
+       - q_proj
+       - k_proj
+       - v_proj
+       - o_proj
+       - gate_proj
+       - up_proj
+       - down_proj
+
+ data:
+   train_file: data/releases/v1/train.parquet
+   validation_file: data/releases/v1/validation.parquet
+   prompt_field: prompt
+   target_field: target
+   final_answer_field: final_answer
+   proof_field: proof_formal
+   max_seq_length: 2048
+   max_train_samples: null
+   max_eval_samples: 2000
+   system_prompt: |
+     You are a rigorous mathematical reasoning assistant specialized in unsolved
+     conjectures. Focus on conjecture-aware strategy, partial progress, and
+     precise formal statements.
+
+ training:
+   output_dir: model_development/runs/deepseek-math-v2-lora
+   num_train_epochs: 1
+   per_device_train_batch_size: 1
+   per_device_eval_batch_size: 1
+   gradient_accumulation_steps: 16
+   learning_rate: 2.0e-5
+   weight_decay: 0.01
+   warmup_ratio: 0.03
+   lr_scheduler_type: cosine
+   max_grad_norm: 1.0
+   gradient_checkpointing: true
+   logging_steps: 10
+   save_steps: 250
+   eval_steps: 250
+   save_total_limit: 3
+   dataloader_num_workers: 2
+   seed: 17
+
+ hub:
+   push_to_hub: true
+   repo_id: NorthernTribe-Research/math-conjecture-model
+   private: false
+   commit_message: Train DeepSeek-Math-V2 LoRA on conjecture corpus.
+
+ credentials:
+   path: huggingface-api-key.json
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch>=2.3.0
+ transformers>=4.48.0
+ accelerate>=1.1.0
+ datasets>=2.21.0
+ peft>=0.14.0
+ bitsandbytes>=0.45.0
+ huggingface_hub>=0.26.0
+ pyyaml>=6.0.2
scripts/merge_and_push.py ADDED
@@ -0,0 +1,147 @@
+ #!/usr/bin/env python3
+ """Merge a LoRA adapter into a full model and optionally push to Hugging Face."""
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import os
+ from pathlib import Path
+ from typing import Optional, Tuple
+
+ import torch
+ from huggingface_hub import HfApi
+ from peft import AutoPeftModelForCausalLM
+ from transformers import AutoTokenizer
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(
+         description="Merge a PEFT adapter into base weights and publish the merged model."
+     )
+     parser.add_argument(
+         "--adapter-path",
+         type=Path,
+         required=True,
+         help="Directory containing adapter_model.safetensors + adapter_config.json.",
+     )
+     parser.add_argument(
+         "--output-dir",
+         type=Path,
+         required=True,
+         help="Directory where merged weights are saved.",
+     )
+     parser.add_argument("--repo-id", type=str, default=None, help="Hub model repo id.")
+     parser.add_argument("--push-to-hub", action="store_true", help="Upload merged model to Hub.")
+     parser.add_argument("--private", action="store_true", help="Create private repo on Hub.")
+     parser.add_argument(
+         "--commit-message",
+         type=str,
+         default="Upload merged DeepSeek-Math conjecture model.",
+     )
+     parser.add_argument(
+         "--credentials-path",
+         type=Path,
+         default=Path("huggingface-api-key.json"),
+         help="Path to JSON credentials with {username, key}.",
+     )
+     parser.add_argument(
+         "--max-shard-size",
+         type=str,
+         default="5GB",
+         help="Shard size passed to save_pretrained.",
+     )
+     parser.add_argument(
+         "--trust-remote-code",
+         action="store_true",
+         help="Enable trust_remote_code for tokenizer/model loading.",
+     )
+     parser.add_argument(
+         "--bf16",
+         action="store_true",
+         help="Load adapter in bfloat16 before merge (default float16).",
+     )
+     return parser.parse_args()
+
+
+ def as_text(value: object) -> str:
+     if value is None:
+         return ""
+     if isinstance(value, str):
+         return value.strip()
+     return str(value).strip()
+
+
+ def resolve_auth(credentials_path: Path) -> Tuple[Optional[str], Optional[str]]:
+     token = as_text(os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")) or None
+     username = as_text(os.environ.get("HF_USERNAME")) or None
+     if credentials_path.exists():
+         data = json.loads(credentials_path.read_text(encoding="utf-8"))
+         if token is None:
+             token = as_text(data.get("key")) or None
+         if username is None:
+             username = as_text(data.get("username")) or None
+     return token, username
+
+
+ def merge_adapter(args: argparse.Namespace) -> None:
+     if not args.adapter_path.exists():
+         raise FileNotFoundError(f"Adapter path not found: {args.adapter_path}")
+
+     dtype = torch.bfloat16 if args.bf16 else torch.float16
+     model = AutoPeftModelForCausalLM.from_pretrained(
+         str(args.adapter_path),
+         torch_dtype=dtype,
+         device_map="auto",
+         trust_remote_code=args.trust_remote_code,
+     )
+     merged = model.merge_and_unload()
+
+     tokenizer = AutoTokenizer.from_pretrained(
+         str(args.adapter_path),
+         trust_remote_code=args.trust_remote_code,
+     )
+
+     args.output_dir.mkdir(parents=True, exist_ok=True)
+     merged.save_pretrained(
+         str(args.output_dir),
+         safe_serialization=True,
+         max_shard_size=args.max_shard_size,
+     )
+     tokenizer.save_pretrained(str(args.output_dir))
+
+     print(f"Merged model saved to: {args.output_dir}")
+
+
+ def push_merged(args: argparse.Namespace, token: str, repo_id: str) -> None:
+     api = HfApi(token=token)
+     api.create_repo(repo_id=repo_id, repo_type="model", private=args.private, exist_ok=True)
+     api.upload_folder(
+         repo_id=repo_id,
+         repo_type="model",
+         folder_path=str(args.output_dir),
+         commit_message=args.commit_message,
+     )
+     print(f"Pushed merged model to https://huggingface.co/{repo_id}")
+
+
+ def main() -> None:
+     args = parse_args()
+     merge_adapter(args)
+
+     if not args.push_to_hub:
+         return
+
+     token, username = resolve_auth(args.credentials_path)
+     if token is None:
+         raise ValueError("Missing HF token. Set HF_TOKEN or provide credentials JSON.")
+     repo_id = as_text(args.repo_id)
+     if not repo_id:
+         if not username:
+             raise ValueError("repo_id missing and username unavailable.")
+         repo_id = f"{username}/{args.output_dir.name}"
+     push_merged(args, token=token, repo_id=repo_id)
+
+
+ if __name__ == "__main__":
+     main()
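The heavy lifting in this script is peft's `merge_and_unload()`, which folds each adapter back into its base weight: per target module it computes W' = W + (alpha / r) · B · A, where A is the (r × in) down-projection and B the (out × r) up-projection. A toy sketch of that update in plain Python with tiny matrices (this mirrors the math only, not the peft implementation):

```python
def matmul(a, b):
    """Naive matrix multiply for small nested-list matrices."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(w, a, b, r, alpha):
    """Return W + (alpha / r) * (B @ A), the merged weight matrix."""
    scale = alpha / r
    delta = matmul(b, a)  # B is (out, r), A is (r, in)
    return [
        [w[i][j] + scale * delta[i][j] for j in range(len(w[0]))]
        for i in range(len(w))
    ]


# 2x2 base weight, rank-1 adapter, alpha/r = 2.0 (same ratio as r=64, alpha=128).
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 2.0]]         # (r=1, in=2)
B = [[0.5], [0.25]]      # (out=2, r=1)
print(merge_lora(W, A, B, r=1, alpha=2))  # -> [[2.0, 2.0], [0.5, 2.0]]
```

After the merge the adapter contributes nothing extra at inference time, which is why the script can save and upload plain full weights.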
scripts/train_sft.py ADDED
@@ -0,0 +1,519 @@
+ #!/usr/bin/env python3
+ """Fine-tune DeepSeek-Math models on the conjecture-solution corpus."""
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import os
+ from pathlib import Path
+ from typing import Any, Dict, Optional, Tuple
+
+ import torch
+ import yaml
+ from datasets import Dataset, DatasetDict, load_dataset
+ from huggingface_hub import HfApi
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     BitsAndBytesConfig,
+     DataCollatorForSeq2Seq,
+     Trainer,
+     TrainingArguments,
+     set_seed,
+ )
+
+ DEFAULT_CONFIG_PATH = Path("model_development/configs/deepseek_math.yaml")
+
+
+ def parse_args() -> argparse.Namespace:
+     parser = argparse.ArgumentParser(
+         description="Supervised fine-tuning (LoRA/QLoRA) for DeepSeek-Math models."
+     )
+     parser.add_argument(
+         "--config",
+         type=Path,
+         default=DEFAULT_CONFIG_PATH,
+         help="YAML config path.",
+     )
+     parser.add_argument("--base-model", type=str, default=None, help="Override model.base_model.")
+     parser.add_argument("--output-dir", type=Path, default=None, help="Override training.output_dir.")
+     parser.add_argument("--max-train-samples", type=int, default=None, help="Optional train subset.")
+     parser.add_argument("--max-eval-samples", type=int, default=None, help="Optional eval subset.")
+     parser.add_argument("--repo-id", type=str, default=None, help="Override hub.repo_id.")
+     parser.add_argument("--push-to-hub", action="store_true", help="Force push enabled.")
+     parser.add_argument("--no-push-to-hub", action="store_true", help="Force push disabled.")
+     parser.add_argument(
+         "--resume-from-checkpoint",
+         type=str,
+         default=None,
+         help="Path to checkpoint for resume.",
+     )
+     parser.add_argument(
+         "--credentials-path",
+         type=Path,
+         default=None,
+         help="Override credentials.path.",
+     )
+     return parser.parse_args()
+
+
+ def as_text(value: Any) -> str:
+     if value is None:
+         return ""
+     if isinstance(value, str):
+         return value.strip()
+     return str(value).strip()
+
+
+ def load_config(path: Path) -> Dict[str, Any]:
+     if not path.exists():
+         raise FileNotFoundError(f"Config not found: {path}")
+     cfg = yaml.safe_load(path.read_text(encoding="utf-8"))
+     if not isinstance(cfg, dict):
+         raise ValueError(f"Invalid config format: {path}")
+     for key in ("model", "data", "training"):
+         if key not in cfg or not isinstance(cfg[key], dict):
+             raise ValueError(f"Config missing section: {key}")
+     cfg.setdefault("hub", {})
+     cfg.setdefault("credentials", {})
+     return cfg
+
+
+ def apply_overrides(cfg: Dict[str, Any], args: argparse.Namespace) -> None:
+     if args.base_model:
+         cfg["model"]["base_model"] = args.base_model
+     if args.output_dir is not None:
+         cfg["training"]["output_dir"] = str(args.output_dir)
+     if args.max_train_samples is not None:
+         cfg["data"]["max_train_samples"] = args.max_train_samples
+     if args.max_eval_samples is not None:
+         cfg["data"]["max_eval_samples"] = args.max_eval_samples
+     if args.repo_id:
+         cfg.setdefault("hub", {})["repo_id"] = args.repo_id
+     if args.credentials_path is not None:
+         cfg.setdefault("credentials", {})["path"] = str(args.credentials_path)
+     if args.push_to_hub and args.no_push_to_hub:
+         raise ValueError("Cannot set both --push-to-hub and --no-push-to-hub.")
+     if args.push_to_hub:
+         cfg.setdefault("hub", {})["push_to_hub"] = True
+     if args.no_push_to_hub:
+         cfg.setdefault("hub", {})["push_to_hub"] = False
+
+
+ def resolve_auth(cfg: Dict[str, Any]) -> Tuple[Optional[str], Optional[str]]:
+     token = as_text(os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")) or None
+     username = as_text(os.environ.get("HF_USERNAME")) or None
+
+     cred_path = as_text(cfg.get("credentials", {}).get("path"))
+     if cred_path:
+         path = Path(cred_path)
+         if path.exists():
+             data = json.loads(path.read_text(encoding="utf-8"))
+             if token is None:
+                 token = as_text(data.get("key")) or None
+             if username is None:
+                 username = as_text(data.get("username")) or None
+     return token, username
+
+
+ def load_raw_datasets(data_cfg: Dict[str, Any]) -> DatasetDict:
+     train_path = Path(as_text(data_cfg.get("train_file")))
+     valid_path = Path(as_text(data_cfg.get("validation_file")))
+     if not train_path.exists():
+         raise FileNotFoundError(f"Missing train split: {train_path}")
+     if not valid_path.exists():
+         raise FileNotFoundError(f"Missing validation split: {valid_path}")
+
+     files = {"train": str(train_path), "validation": str(valid_path)}
+     return load_dataset("parquet", data_files=files)
+
+
+ def maybe_select(dataset: Dataset, max_samples: Optional[int]) -> Dataset:
+     if max_samples is None:
+         return dataset
+     if max_samples <= 0:
+         raise ValueError("max_samples must be positive.")
+     if max_samples >= len(dataset):
+         return dataset
+     return dataset.select(range(max_samples))
+
+
+ def stringify_structured(value: Any) -> str:
+     if value is None:
+         return ""
+     if isinstance(value, str):
+         text = value.strip()
+         if not text:
+             return ""
+         try:
+             parsed = json.loads(text)
+         except json.JSONDecodeError:
+             return text
+         return json.dumps(parsed, ensure_ascii=False, sort_keys=True)
+     return json.dumps(value, ensure_ascii=False, sort_keys=True)
+
+
+ def build_user_block(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> str:
+     prompt_field = as_text(data_cfg.get("prompt_field")) or "prompt"
+     prompt = as_text(row.get(prompt_field))
+     if not prompt:
+         prompt = "Solve the math task."
+
+     meta_fields = [
+         ("task_type", "Task type"),
+         ("family", "Family"),
+         ("difficulty", "Difficulty"),
+         ("source_dataset", "Source"),
+         ("status_as_of", "Status as of"),
+     ]
+     meta_lines = []
+     for key, label in meta_fields:
+         value = as_text(row.get(key))
+         if value:
+             meta_lines.append(f"{label}: {value}")
+     tags = row.get("topic_tags")
+     if isinstance(tags, list) and tags:
+         tag_text = ", ".join(as_text(tag) for tag in tags if as_text(tag))
+         if tag_text:
+             meta_lines.append(f"Tags: {tag_text}")
+
+     if not meta_lines:
+         return prompt
+     return f"{prompt}\n\nMetadata:\n" + "\n".join(meta_lines)
+
+
+ def build_answer_block(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> str:
+     target_field = as_text(data_cfg.get("target_field")) or "target"
+     final_answer_field = as_text(data_cfg.get("final_answer_field")) or "final_answer"
+     proof_field = as_text(data_cfg.get("proof_field")) or "proof_formal"
+
+     sections = []
+     target_text = stringify_structured(row.get(target_field))
+     if target_text:
+         sections.append(f"Structured target:\n{target_text}")
+
+     final_answer = stringify_structured(row.get(final_answer_field))
+     if final_answer:
+         sections.append(f"Final answer:\n{final_answer}")
+
+     proof_text = stringify_structured(row.get(proof_field))
+     if proof_text:
+         sections.append(f"Formal proof snippet:\n{proof_text}")
+
+     if not sections:
+         sections.append("No structured target provided.")
+     return "\n\n".join(sections).strip()
+
+
+ def build_prompt_text(
+     row: Dict[str, Any],
+     tokenizer: AutoTokenizer,
+     data_cfg: Dict[str, Any],
+ ) -> str:
+     system_prompt = as_text(data_cfg.get("system_prompt"))
+     if not system_prompt:
+         system_prompt = (
+             "You are a rigorous mathematical reasoning assistant focused on "
+             "unsolved conjectures. Produce checkable reasoning."
+         )
+     user_block = build_user_block(row, data_cfg)
+     if getattr(tokenizer, "chat_template", None):
+         messages = [
+             {"role": "system", "content": system_prompt},
+             {"role": "user", "content": user_block},
+         ]
+         return tokenizer.apply_chat_template(
+             messages,
+             tokenize=False,
+             add_generation_prompt=True,
+         )
+     return f"System:\n{system_prompt}\n\nUser:\n{user_block}\n\nAssistant:\n"
+
+
+ def tokenize_datasets(
+     raw: DatasetDict,
+     tokenizer: AutoTokenizer,
+     data_cfg: Dict[str, Any],
+ ) -> DatasetDict:
+     max_len = int(data_cfg.get("max_seq_length", 2048))
+     if max_len < 64:
+         raise ValueError("data.max_seq_length must be at least 64.")
+
+     eos = tokenizer.eos_token or ""
+     remove_columns = raw["train"].column_names
+
+     def _tokenize(row: Dict[str, Any]) -> Dict[str, Any]:
+         prompt_text = build_prompt_text(row, tokenizer, data_cfg)
+         answer_text = build_answer_block(row, data_cfg)
+         full_text = f"{prompt_text}{answer_text}{eos}"
+
+         prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
+         full_enc = tokenizer(
+             full_text,
+             add_special_tokens=False,
+             truncation=True,
+             max_length=max_len,
+         )
+         input_ids = full_enc["input_ids"]
+         attention_mask = full_enc["attention_mask"]
+
+         if not input_ids:
+             fallback = tokenizer.eos_token_id
+             if fallback is None:
+                 fallback = tokenizer.pad_token_id
+             if fallback is None:
+                 fallback = 0
+             input_ids = [fallback]
+             attention_mask = [1]
+             labels = [fallback]
+             return {
+                 "input_ids": input_ids,
+                 "attention_mask": attention_mask,
+                 "labels": labels,
+             }
+
+         prompt_len = min(len(prompt_ids), len(input_ids))
+         labels = [-100] * prompt_len + input_ids[prompt_len:]
+         if prompt_len >= len(input_ids):
+             labels[-1] = input_ids[-1]
+
+         return {
+             "input_ids": input_ids,
+             "attention_mask": attention_mask,
+             "labels": labels,
+         }
+
+     tokenized = raw.map(
+         _tokenize,
+         remove_columns=remove_columns,
+         desc="Tokenizing prompt/answer pairs",
+     )
+     tokenized = tokenized.filter(
+         lambda row: any(token != -100 for token in row["labels"]),
+         desc="Dropping prompt-only rows",
+     )
+     return tokenized
+
+
+ def build_model_and_tokenizer(
+     model_cfg: Dict[str, Any],
+     training_cfg: Dict[str, Any],
+ ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
+     base_model = as_text(model_cfg.get("base_model"))
+     if not base_model:
+         raise ValueError("model.base_model is required.")
+
+     use_bf16 = bool(model_cfg.get("use_bf16", True))
+     dtype = torch.bfloat16 if use_bf16 else torch.float16
+
+     tokenizer = AutoTokenizer.from_pretrained(
+         base_model,
+         trust_remote_code=bool(model_cfg.get("trust_remote_code", False)),
+         use_fast=True,
+     )
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
+     if tokenizer.pad_token is None:
+         tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
+
+     model_kwargs: Dict[str, Any] = {
+         "trust_remote_code": bool(model_cfg.get("trust_remote_code", False)),
+         "torch_dtype": dtype,
+     }
+     attn_impl = as_text(model_cfg.get("attn_implementation"))
+     if attn_impl:
+         model_kwargs["attn_implementation"] = attn_impl
+
+     load_in_4bit = bool(model_cfg.get("load_in_4bit", True))
+     if load_in_4bit:
+         if not torch.cuda.is_available():
+             raise RuntimeError("4-bit loading requested but CUDA is not available.")
+         model_kwargs["quantization_config"] = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type=as_text(model_cfg.get("bnb_4bit_quant_type")) or "nf4",
+             bnb_4bit_use_double_quant=bool(model_cfg.get("bnb_4bit_use_double_quant", True)),
+             bnb_4bit_compute_dtype=dtype,
+         )
+         model_kwargs["device_map"] = "auto"
+
+     model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
+     if tokenizer.pad_token_id is not None:
+         model.config.pad_token_id = tokenizer.pad_token_id
+     model.config.use_cache = False
+
+     if load_in_4bit:
+         model = prepare_model_for_kbit_training(
+             model,
+             use_gradient_checkpointing=bool(training_cfg.get("gradient_checkpointing", True)),
+         )
+
+     lora_cfg = model_cfg.get("lora", {})
+     peft_cfg = LoraConfig(
+         r=int(lora_cfg.get("r", 64)),
+         lora_alpha=int(lora_cfg.get("alpha", 128)),
+         lora_dropout=float(lora_cfg.get("dropout", 0.05)),
+         bias=as_text(lora_cfg.get("bias")) or "none",
+         task_type="CAUSAL_LM",
+         target_modules=lora_cfg.get("target_modules"),
+     )
+     model = get_peft_model(model, peft_cfg)
+     model.print_trainable_parameters()
+     return model, tokenizer
+
+
+ def build_training_args(
+     cfg: Dict[str, Any],
+     has_eval_split: bool,
+ ) -> TrainingArguments:
+     model_cfg = cfg["model"]
+     training_cfg = cfg["training"]
+
+     use_bf16 = bool(model_cfg.get("use_bf16", True))
+     output_dir = Path(as_text(training_cfg.get("output_dir")))
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     return TrainingArguments(
+         output_dir=str(output_dir),
+         num_train_epochs=float(training_cfg.get("num_train_epochs", 1)),
+         per_device_train_batch_size=int(training_cfg.get("per_device_train_batch_size", 1)),
+         per_device_eval_batch_size=int(training_cfg.get("per_device_eval_batch_size", 1)),
+         gradient_accumulation_steps=int(training_cfg.get("gradient_accumulation_steps", 1)),
+         learning_rate=float(training_cfg.get("learning_rate", 2e-5)),
+         weight_decay=float(training_cfg.get("weight_decay", 0.0)),
+         warmup_ratio=float(training_cfg.get("warmup_ratio", 0.0)),
+         lr_scheduler_type=as_text(training_cfg.get("lr_scheduler_type")) or "cosine",
+         max_grad_norm=float(training_cfg.get("max_grad_norm", 1.0)),
+         gradient_checkpointing=bool(training_cfg.get("gradient_checkpointing", True)),
+         logging_steps=int(training_cfg.get("logging_steps", 10)),
+         save_steps=int(training_cfg.get("save_steps", 250)),
+         save_total_limit=int(training_cfg.get("save_total_limit", 3)),
+         dataloader_num_workers=int(training_cfg.get("dataloader_num_workers", 0)),
+         seed=int(training_cfg.get("seed", 17)),
+         bf16=use_bf16,
+         fp16=not use_bf16,
+         remove_unused_columns=False,
+         report_to="none",
+         eval_strategy="steps" if has_eval_split else "no",
+         eval_steps=int(training_cfg.get("eval_steps", 250)) if has_eval_split else None,
+     )
+
+
+ def resolve_repo_id(
+     cfg: Dict[str, Any],
+     username: Optional[str],
+ ) -> Optional[str]:
+     repo_id = as_text(cfg.get("hub", {}).get("repo_id"))
+     if repo_id:
+         return repo_id
+     if not username:
+         return None
+     output_dir = Path(as_text(cfg["training"].get("output_dir")))
+     return f"{username}/{output_dir.name}"
+
+
+ def push_output_to_hub(
+     output_dir: Path,
+     repo_id: str,
+     token: str,
+     private: bool,
+     commit_message: str,
+ ) -> None:
+     api = HfApi(token=token)
+     api.create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True)
+     api.upload_folder(
+         repo_id=repo_id,
+         repo_type="model",
+         folder_path=str(output_dir),
+         commit_message=commit_message,
+     )
+
+
+ def save_resolved_config(
+     cfg: Dict[str, Any],
+     output_dir: Path,
+     config_path: Path,
+ ) -> None:
+     serializable = json.loads(json.dumps(cfg))
+     serializable["resolved_from"] = str(config_path)
+     out_path = output_dir / "resolved_training_config.json"
+     out_path.write_text(json.dumps(serializable, ensure_ascii=True, indent=2), encoding="utf-8")
+
+
+ def main() -> None:
+     args = parse_args()
+     cfg = load_config(args.config)
+     apply_overrides(cfg, args)
+
+     training_cfg = cfg["training"]
+     seed = int(training_cfg.get("seed", 17))
+     set_seed(seed)
+
+     token, username = resolve_auth(cfg)
+     push_to_hub = bool(cfg.get("hub", {}).get("push_to_hub", False))
+     repo_id = resolve_repo_id(cfg, username)
+     if push_to_hub:
+         if token is None:
+             raise ValueError(
+                 "Hub push requested but no token found. Set HF_TOKEN or credentials.path."
+             )
+         if repo_id is None:
+             raise ValueError(
+                 "Hub push requested but repo_id is empty and username is unavailable."
+             )
+
+     model, tokenizer = build_model_and_tokenizer(cfg["model"], training_cfg)
+
+     raw = load_raw_datasets(cfg["data"])
+     raw["train"] = maybe_select(raw["train"], cfg["data"].get("max_train_samples"))
+     raw["validation"] = maybe_select(raw["validation"], cfg["data"].get("max_eval_samples"))
+
+     tokenized = tokenize_datasets(raw, tokenizer, cfg["data"])
+     train_dataset = tokenized["train"]
+     eval_dataset = tokenized["validation"] if len(tokenized["validation"]) > 0 else None
+
+     training_args = build_training_args(cfg, has_eval_split=eval_dataset is not None)
+     data_collator = DataCollatorForSeq2Seq(
+         tokenizer=tokenizer,
+         model=model,
+         label_pad_token_id=-100,
+         pad_to_multiple_of=8,
+     )
+
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_dataset,
+         eval_dataset=eval_dataset,
+         tokenizer=tokenizer,
+         data_collator=data_collator,
+     )
+
+     train_result = trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
+     trainer.log_metrics("train", train_result.metrics)
+     trainer.save_metrics("train", train_result.metrics)
+     trainer.save_state()
+
+     if eval_dataset is not None:
+         eval_metrics = trainer.evaluate()
+         trainer.log_metrics("eval", eval_metrics)
+         trainer.save_metrics("eval", eval_metrics)
+
+     trainer.save_model(training_args.output_dir)
+     tokenizer.save_pretrained(training_args.output_dir)
+
+     output_dir = Path(training_args.output_dir)
+     save_resolved_config(cfg, output_dir, args.config)
+
+     if push_to_hub and repo_id is not None and token is not None:
+         commit_message = as_text(cfg.get("hub", {}).get("commit_message")) or "Upload fine-tuned model."
+         private = bool(cfg.get("hub", {}).get("private", False))
+         push_output_to_hub(output_dir, repo_id, token, private, commit_message)
+         print(f"Pushed model artifacts to https://huggingface.co/{repo_id}")
+
+     print(f"Training finished. Output saved to: {output_dir}")
+
+
+ if __name__ == "__main__":
+     main()
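The key supervision detail in `train_sft.py` is the labeling inside `_tokenize`: prompt tokens get label `-100` (PyTorch's cross-entropy ignore index) so the loss is computed only on the answer span. A stand-alone illustration of that scheme with toy token ids (no tokenizer needed; `mask_prompt_labels` is a hypothetical name mirroring the inline logic, not a function the script exports):

```python
IGNORE_INDEX = -100  # ignored by torch.nn.functional.cross_entropy


def mask_prompt_labels(prompt_ids, full_ids):
    """Mirror train_sft.py: -100 over the prompt prefix, real ids over the answer."""
    prompt_len = min(len(prompt_ids), len(full_ids))
    labels = [IGNORE_INDEX] * prompt_len + full_ids[prompt_len:]
    if prompt_len >= len(full_ids):  # truncation consumed the whole answer
        labels[-1] = full_ids[-1]    # keep at least one supervised token
    return labels


prompt = [11, 12, 13]
full = [11, 12, 13, 99, 100, 101]  # prompt followed by answer tokens
print(mask_prompt_labels(prompt, full))
# -> [-100, -100, -100, 99, 100, 101]
```

The follow-up `filter` in `tokenize_datasets` then drops any row whose labels are all `-100`, so prompt-only examples never reach the loss.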