TorchLLM committed on
Commit
f35adfe
·
verified ·
1 Parent(s): aeba8f5

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Logs
2
+ *.log
3
+ pipeline.log
4
+ training.log
5
+
6
+ # State files
7
+ pipeline_state.json
8
+ *.json.bak
9
+
10
+ # Environment / secrets
11
+ .env
12
+ .env.*
13
+ *.key
14
+ secrets.py
15
+
16
+ # Python cache
17
+ __pycache__/
18
+ *.py[cod]
19
+ *.pyo
20
+ .pytest_cache/
21
+
22
+ # Checkpoints (intermediate, not final)
23
+ checkpoint-*/
24
+ pretrain/
25
+ sft/
26
+ grpo/
27
+ pretrain_model/
28
+ sft_model/
29
+
30
+ # OS
31
+ .DS_Store
32
+ Thumbs.db
33
+
34
+ # Notebooks output
35
+ *.ipynb_checkpoints/
README.md CHANGED
@@ -1,3 +1,153 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ tags:
6
+ - mixture-of-experts
7
+ - mixture-of-recursions
8
+ - causal-lm
9
+ - custom-architecture
10
+ - pytorch
11
+ base_model: Qwen/Qwen2.5-0.5B-Instruct
12
+ pipeline_tag: text-generation
13
+ ---
14
+
15
+ # HybridMoRMoE — Hybrid Mixture-of-Recursions & Mixture-of-Experts
16
+
17
+ A custom causal language model combining **Mixture-of-Recursions (MoR)** with **Mixture-of-Experts (MoE)** routing, built from scratch in PyTorch and trained via a three-stage pipeline (pre-training → SFT → GRPO).
18
+
19
+ ---
20
+
21
+ ## Architecture
22
+
23
+ | Hyperparameter | Value |
24
+ |---|---|
25
+ | Model type | `hybrid_mor_moe` |
26
+ | Hidden dim (`d_model`) | 576 |
27
+ | Feed-forward dim (`d_ff`) | 1536 |
28
+ | Attention heads | 8 |
29
+ | Base layers | 6 |
30
+ | Shared recursive blocks | 6 |
31
+ | Unique last layers | 2 |
32
+ | Total transformer depth | 30 |
33
+ | Number of experts | 4 |
34
+ | Experts per token | 1 |
35
+ | Max recursions | 3 |
36
+ | Router percentile | 0.70 |
37
+ | Sequence length | 4096 |
38
+ | Vocabulary size | 151,665 |
39
+ | Tokenizer | Qwen2Tokenizer (Qwen2.5 compatible) |
40
+
41
+ **Key design choices:**
42
+ - Shared weight blocks are recursively applied based on a learned complexity score
43
+ - A per-token MoE router selects which expert processes each position
44
+ - Auxiliary routing loss (`router_aux_loss_coef = 1e-4`) encourages load balance
45
+ - Chat template follows the ChatML (`<|im_start|>` / `<|im_end|>`) format
46
+
47
+ ---
48
+
49
+ ## Training Pipeline
50
+
51
+ The model was trained in three sequential stages on a single NVIDIA P100 (16 GB HBM2):
52
+
53
+ | Stage | Method | Notes |
54
+ |---|---|---|
55
+ | 1 | **Pre-training** | Causal LM on open-domain text |
56
+ | 2 | **SFT** (Supervised Fine-Tuning) | Instruction following with packing |
57
+ | 3 | **GRPO** (Group Relative Policy Optimisation) | Reinforcement learning from preference signal |
58
+
59
+ Training used FP16 precision throughout (P100 has no BF16 support).
60
+
61
+ ---
62
+
63
+ ## Usage
64
+
65
+ Because this model uses a **custom architecture** not registered in the Hugging Face Transformers library by default, you must load the modelling code alongside the weights.
66
+
67
+ ### Quick inference
68
+
69
+ ```python
70
+ import torch
71
+ from transformers import AutoTokenizer
72
+
73
+ # 1. Clone / download this repo
74
+ # 2. Make sure hybrid_mor_moe_training.py is on your Python path
75
+ # (it registers HybridMoRMoEForCausalLM & HybridMoRMoEConfig with AutoModel)
76
+
77
+ from hybrid_mor_moe_training import HybridMoRMoEConfig, HybridMoRMoEForCausalLM
78
+
79
+ model_path = "TorchLLM/HybridMoRMoE" # or local path
80
+
81
+ config = HybridMoRMoEConfig.from_pretrained(model_path)
82
+ model = HybridMoRMoEForCausalLM.from_pretrained(model_path, config=config)
83
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
84
+
85
+ model.eval()
86
+ device = "cuda" if torch.cuda.is_available() else "cpu"
87
+ model.to(device)
88
+
89
+ messages = [
90
+ {"role": "user", "content": "Explain the difference between MoE and dense transformers."}
91
+ ]
92
+ text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
93
+ inputs = tokenizer(text, return_tensors="pt").to(device)
94
+
95
+ with torch.no_grad():
96
+ out = model.simple_generate(
97
+ inputs["input_ids"],
98
+ max_new_tokens=256,
99
+ temperature=0.7,
100
+ top_p=0.9,
101
+ eos_token_id=tokenizer.eos_token_id,
102
+ )
103
+
104
+ print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
105
+ ```
106
+
107
+ ### Environment setup
108
+
109
+ ```bash
110
+ pip install torch transformers trl datasets accelerate
111
+ ```
112
+
113
+ > **HF_TOKEN**: If you need to access gated datasets during re-training, export your token:
114
+ > ```bash
115
+ > export HF_TOKEN="your_token_here"
116
+ > ```
117
+ > Never hard-code tokens in source files.
118
+
119
+ ---
120
+
121
+ ## Repository Structure
122
+
123
+ ```
124
+ TorchLLM/HybridMoRMoE/
125
+ ├── config.json # Model architecture config
126
+ ├── generation_config.json # Default generation settings
127
+ ├── model.safetensors # Trained weights (SafeTensors format)
128
+ ├── tokenizer.json # Tokenizer vocabulary & rules
129
+ ├── tokenizer_config.json # Tokenizer metadata
130
+ ├── chat_template.jinja # ChatML chat template
131
+ └── hybrid_mor_moe_training.py # Full training pipeline source
132
+ ```
133
+
134
+ ---
135
+
136
+ ## Citation
137
+
138
+ If you use this model or training code in your research, please cite:
139
+
140
+ ```bibtex
141
+ @misc{hybridmormoe2025,
142
+ title = {HybridMoRMoE: Combining Mixture-of-Recursions and Mixture-of-Experts for Efficient Causal LM},
143
+ author = {TorchLLM},
144
+ year = {2025},
145
+ url = {https://huggingface.co/TorchLLM/HybridMoRMoE}
146
+ }
147
+ ```
148
+
149
+ ---
150
+
151
+ ## License
152
+
153
+ Apache 2.0 — see [LICENSE](https://www.apache.org/licenses/LICENSE-2.0) for details.
chat_template.jinja ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0]['role'] == 'system' %}
4
+ {{- messages[0]['content'] }}
5
+ {%- else %}
6
+ {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}
7
+ {%- endif %}
8
+ {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
9
+ {%- for tool in tools %}
10
+ {{- "\n" }}
11
+ {{- tool | tojson }}
12
+ {%- endfor %}
13
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
14
+ {%- else %}
15
+ {%- if messages[0]['role'] == 'system' %}
16
+ {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
17
+ {%- else %}
18
+ {{- '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n' }}
19
+ {%- endif %}
20
+ {%- endif %}
21
+ {%- for message in messages %}
22
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
23
+ {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
24
+ {%- elif message.role == "assistant" %}
25
+ {{- '<|im_start|>' + message.role }}
26
+ {%- if message.content %}
27
+ {{- '\n' + message.content }}
28
+ {%- endif %}
29
+ {%- for tool_call in message.tool_calls %}
30
+ {%- if tool_call.function is defined %}
31
+ {%- set tool_call = tool_call.function %}
32
+ {%- endif %}
33
+ {{- '\n<tool_call>\n{"name": "' }}
34
+ {{- tool_call.name }}
35
+ {{- '", "arguments": ' }}
36
+ {{- tool_call.arguments | tojson }}
37
+ {{- '}\n</tool_call>' }}
38
+ {%- endfor %}
39
+ {{- '<|im_end|>\n' }}
40
+ {%- elif message.role == "tool" %}
41
+ {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
42
+ {{- '<|im_start|>user' }}
43
+ {%- endif %}
44
+ {{- '\n<tool_response>\n' }}
45
+ {{- message.content }}
46
+ {{- '\n</tool_response>' }}
47
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
48
+ {{- '<|im_end|>\n' }}
49
+ {%- endif %}
50
+ {%- endif %}
51
+ {%- endfor %}
52
+ {%- if add_generation_prompt %}
53
+ {{- '<|im_start|>assistant\n' }}
54
+ {%- endif %}
config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "HybridMoRMoEForCausalLM"
4
+ ],
5
+ "complexity_hidden_dim": 64,
6
+ "d_ff": 1536,
7
+ "d_model": 576,
8
+ "dropout": 0.05,
9
+ "dtype": "float32",
10
+ "eos_token_id": 151645,
11
+ "max_recursions": 3,
12
+ "max_seq_len": 4096,
13
+ "model_size": "small",
14
+ "model_type": "hybrid_mor_moe",
15
+ "moe_aux_loss_coef": 0.0001,
16
+ "n_heads": 8,
17
+ "num_base_layers": 6,
18
+ "num_experts": 4,
19
+ "num_experts_per_tok": 1,
20
+ "num_hidden_layers": 30,
21
+ "num_recursions": 3,
22
+ "num_shared_blocks": 6,
23
+ "num_unique_last_layers": 2,
24
+ "pad_token_id": 151643,
25
+ "router_aux_loss_coef": 0.0001,
26
+ "router_percentile": 0.7,
27
+ "transformers_version": "5.4.0",
28
+ "use_cache": false,
29
+ "vocab_size": 151665
30
+ }
generation_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "eos_token_id": [
4
+ 151645
5
+ ],
6
+ "output_attentions": false,
7
+ "output_hidden_states": false,
8
+ "pad_token_id": 151643,
9
+ "transformers_version": "5.4.0"
10
+ }
hybrid_mor_moe_training.py ADDED
@@ -0,0 +1,1723 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ HybridMoRMoE Full Training Pipeline — P100 x1 (SINGLE GPU, 16 GB)
4
+ ==================================================================
5
+ Optimised for NVIDIA P100 (Pascal, compute 6.0, FP16, 16 GB HBM2).
6
+ KEY CHANGES vs T4×2 version
7
+ ────────────────────────────
8
+ ‒ Single-GPU path only — no DataParallel / multi-GPU branching.
9
+ ‒ FP16 forced everywhere (P100 has NO BF16 support).
10
+ ‒ Batch size 2 + grad-accum 8 → eff batch 16 (P100 bandwidth > T4).
11
+ ‒ packing=True for SFT & pretrain → ~2× throughput on long-tail data.
12
+ ‒ Data volumes doubled:
13
+ pretrain_max_samples 200 K → 400 K
14
+ sft_max_samples/dom 5 K → 10 K
15
+ grpo_max_dataset 10 K → 20 K
16
+ ‒ dataloader_num_workers 2 → 4 (P100 hosts usually have ≥4 cores).
17
+ ‒ Save / eval frequency reduced to cut I/O overhead.
18
+ ‒ Sequence length stays 4096; RotaryEmbedding cache 8192.
19
+ ‒ OOM fallback in GRPO is more aggressive (batch 1, accum 16).
20
+ """
21
+ import gc
22
+ import inspect
23
+ import json
24
+ import logging
25
+ import math
26
+ import os
27
+ import re
28
+ import shutil
29
+ import sys
30
+ import time
31
+ import warnings
32
+ from dataclasses import dataclass
33
+ from typing import Dict, List, Optional
34
+ # ── Unbuffered I/O ──
35
# Process-wide environment tweaks. These run at import time, before torch /
# transformers / trl are imported further down, so the libraries see them.
os.environ["PYTHONUNBUFFERED"] = "1"
# Switch stdout/stderr to line buffering. Each stream is checked on its own:
# either one may have been replaced by an object without ``reconfigure``
# (notebook front-ends, capture wrappers), regardless of the other.
for _stream in (sys.stdout, sys.stderr):
    _reconfigure = getattr(_stream, "reconfigure", None)
    if _reconfigure is not None:
        _reconfigure(line_buffering=True)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# NOTE: HF_TOKEN is intentionally not set here; export it in the environment.
os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "disabled"
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
# Silences every Python warning for the whole process.
warnings.filterwarnings("ignore")
os.environ["TQDM_DISABLE"] = "1"
48
+ import torch
49
+ import torch.nn as nn
50
+ import torch.nn.functional as F
51
+ import transformers
52
+ from transformers import (
53
+ AutoConfig,
54
+ AutoModelForCausalLM,
55
+ AutoTokenizer,
56
+ GenerationMixin,
57
+ PretrainedConfig,
58
+ PreTrainedModel,
59
+ Trainer,
60
+ TrainerCallback,
61
+ TrainingArguments,
62
+ )
63
+ import trl
64
+ from trl import GRPOConfig, GRPOTrainer, SFTConfig, SFTTrainer
65
+ from datasets import Dataset, load_dataset
66
+ def _ensure_clean_distributed_state():
67
+ try:
68
+ if torch.distributed.is_initialized():
69
+ try:
70
+ torch.distributed.get_world_size()
71
+ return
72
+ except (ValueError, RuntimeError):
73
+ try:
74
+ torch.distributed.destroy_process_group()
75
+ except Exception:
76
+ pass
77
+ except Exception:
78
+ pass
79
+ try:
80
+ from accelerate.state import PartialState
81
+ if hasattr(PartialState, '_shared_state') and PartialState._shared_state:
82
+ PartialState._shared_state.clear()
83
+ except Exception:
84
+ pass
85
+ _ensure_clean_distributed_state()
86
# True when the /kaggle directory exists on this machine.
IS_KAGGLE = os.path.exists("/kaggle")

# Directory that receives pipeline.log; created eagerly at import time.
if IS_KAGGLE:
    _log_dir = "/kaggle/working/hybrid_mor_moe_P100"
else:
    _log_dir = "./hybrid_mor_moe_P100"
os.makedirs(_log_dir, exist_ok=True)

# Root logging: mirror every record to stdout and append it to pipeline.log.
logging.basicConfig(
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler(
            os.path.join(_log_dir, "pipeline.log"),
            mode="a",
            encoding="utf-8",
        ),
    ],
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s | %(levelname)-7s | %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("HybridMoRMoE_P100")

# Output-storage thresholds in GiB, used by the storage helpers below.
OUTPUT_STORAGE_LIMIT_GB = 12.0
OUTPUT_STORAGE_WARN_GB = 9.5
OUTPUT_STORAGE_CRITICAL_GB = 11.0
102
+ # ════════════════════════════════════════════════════════════════════════════
103
+ # Storage helpers (unchanged logic, lighter logging)
104
+ # ════════════════════════════════════════════════════════════════════════════
105
def get_dir_size_gb(path):
    """Return the combined size of all files below *path*, in GiB.

    Returns ``0.0`` when *path* is not a directory. Files whose size cannot
    be read (``OSError``) are left out of the total.
    """
    if not os.path.isdir(path):
        return 0.0

    def _file_sizes():
        # Walk the whole tree and yield the size of every readable file.
        for root, _subdirs, names in os.walk(path):
            for name in names:
                try:
                    yield os.path.getsize(os.path.join(root, name))
                except OSError:
                    continue

    return sum(_file_sizes()) / (1024 ** 3)
116
def check_output_storage(output_dir) -> str:
    """Log current output-storage usage and classify it.

    The measured directory is ``/kaggle/working`` when ``IS_KAGGLE`` is set,
    otherwise *output_dir*.

    Returns:
        ``"critical"`` at or above ``OUTPUT_STORAGE_CRITICAL_GB``,
        ``"warn"`` at or above ``OUTPUT_STORAGE_WARN_GB``, else ``"ok"``.
    """
    measured_root = "/kaggle/working" if IS_KAGGLE else output_dir
    used = get_dir_size_gb(measured_root)
    pct = used / OUTPUT_STORAGE_LIMIT_GB * 100
    logger.info(f" Storage: {used:.2f} GB / {OUTPUT_STORAGE_LIMIT_GB:.0f} GB ({pct:.0f}%)")

    status = "ok"
    if used >= OUTPUT_STORAGE_CRITICAL_GB:
        logger.warning(f" !! CRITICAL: {used:.2f} GB !!")
        status = "critical"
    elif used >= OUTPUT_STORAGE_WARN_GB:
        logger.warning(f" ! WARNING : {used:.2f} GB")
        status = "warn"
    return status
127
def _rmdir(path: str, reason: str = ""):
    """Delete the directory tree at *path* and log roughly how much it held.

    Args:
        path: Directory to remove; removal errors are ignored.
        reason: Optional short label shown in square brackets in the log line.
    """
    # Measure first: after rmtree there is nothing left to size.
    freed_gb = get_dir_size_gb(path)
    shutil.rmtree(path, ignore_errors=True)
    if reason:
        tag = f" [{reason}]"
    else:
        tag = ""
    logger.info(f" Removed{tag}: {path} (freed ~{freed_gb:.2f} GB)")
132
def _numbered_checkpoints(phase_dir):
    """Return ``checkpoint-*`` directory names in *phase_dir*, oldest step first.

    Only directories whose last ``-``-separated segment parses as an integer
    are returned; anything else (e.g. ``checkpoint-best``) is skipped so that
    one oddly named directory cannot abort the whole cleanup.
    """
    numbered = []
    for entry in os.listdir(phase_dir):
        if not entry.startswith("checkpoint-"):
            continue
        if not os.path.isdir(os.path.join(phase_dir, entry)):
            continue
        try:
            step = int(entry.split("-")[-1])
        except ValueError:
            continue
        numbered.append((step, entry))
    numbered.sort(key=lambda pair: pair[0])
    return [entry for _step, entry in numbered]


def emergency_cleanup(output_dir: str, level: str = "warn"):
    """Free output storage in tiers, stopping as soon as usage is acceptable.

    Usage is measured on ``/kaggle/working`` when ``IS_KAGGLE`` is set,
    otherwise on *output_dir*.

    Tier 1: in each of ``pretrain``/``sft``/``grpo`` delete every numbered
            checkpoint except the newest; with ``level == "critical"`` the
            newest one is deleted as well.
    Tier 2: delete ``pretrain_model`` once ``sft_model`` exists, and
            ``sft_model`` once ``final_model`` exists.
    Tier 3: delete ``best_pretrain``.

    Args:
        output_dir: Pipeline output directory containing the phase folders.
        level: ``"warn"`` or ``"critical"``; only ``"critical"`` removes the
            newest checkpoint of each phase.

    Returns:
        ``"ok"`` when usage ends below ``OUTPUT_STORAGE_WARN_GB``, otherwise
        ``"critical"`` or ``"warn"`` according to the remaining usage.
    """
    base = "/kaggle/working" if IS_KAGGLE else output_dir

    def _used():
        return get_dir_size_gb(base)

    logger.info(f" [Cleanup/{level}] Starting β€” current usage {_used():.2f} GB")

    # ── Tier 1: intermediate trainer checkpoints ──
    for subdir in ["pretrain", "sft", "grpo"]:
        phase_dir = os.path.join(output_dir, subdir)
        if not os.path.isdir(phase_dir):
            continue
        ckpts = _numbered_checkpoints(phase_dir)
        for ckpt in ckpts[:-1]:
            _rmdir(os.path.join(phase_dir, ckpt), "T1-old-ckpt")
        if level == "critical" and ckpts:
            _rmdir(os.path.join(phase_dir, ckpts[-1]), "T1-latest-ckpt")
    if _used() < OUTPUT_STORAGE_WARN_GB:
        return "ok"

    # ── Tier 2: stage outputs that a later stage has already superseded ──
    sft_done = os.path.isdir(os.path.join(output_dir, "sft_model"))
    grpo_done = os.path.isdir(os.path.join(output_dir, "final_model"))
    if os.path.isdir(p := os.path.join(output_dir, "pretrain_model")) and sft_done:
        _rmdir(p, "T2-pretrain_model")
    if os.path.isdir(p := os.path.join(output_dir, "sft_model")) and grpo_done:
        _rmdir(p, "T2-sft_model")
    if _used() < OUTPUT_STORAGE_WARN_GB:
        return "ok"

    # ── Tier 3: last resort ──
    if os.path.isdir(p := os.path.join(output_dir, "best_pretrain")):
        _rmdir(p, "T3-best_pretrain")
    used_after = _used()
    return "ok" if used_after < OUTPUT_STORAGE_WARN_GB else (
        "critical" if used_after >= OUTPUT_STORAGE_CRITICAL_GB else "warn")
165
def enforce_storage_limit(output_dir: str, action: str = "save"):
    """Run the cleanup tier that matches current usage before writing output.

    Usage is measured on ``/kaggle/working`` when ``IS_KAGGLE`` is set,
    otherwise on *output_dir*.

    Args:
        output_dir: Pipeline output directory handed to ``emergency_cleanup``.
        action: Verb describing the pending operation; only used in the
            error message.

    Raises:
        RuntimeError: usage was at or above ``OUTPUT_STORAGE_LIMIT_GB`` and is
            still there after a critical cleanup.
    """
    base = "/kaggle/working" if IS_KAGGLE else output_dir
    used = get_dir_size_gb(base)
    if used >= OUTPUT_STORAGE_LIMIT_GB:
        emergency_cleanup(output_dir, level="critical")
        # Re-measure rather than trusting the cleanup's status string: the
        # hard limit is higher than the thresholds that status is based on.
        used_after = get_dir_size_gb(base)
        if used_after >= OUTPUT_STORAGE_LIMIT_GB:
            raise RuntimeError(f"[StorageGate] Cannot {action}: {used_after:.2f} GB used")
    elif used >= OUTPUT_STORAGE_CRITICAL_GB:
        emergency_cleanup(output_dir, level="critical")
    elif used >= OUTPUT_STORAGE_WARN_GB:
        emergency_cleanup(output_dir, level="warn")
176
+ # ════════════════════════════════════════════════════════════════════════════
177
+ # GPU setup β€” single P100 path
178
+ # ════════════════════════════════════════════════════════════════════════════
179
def setup_gpu():
    """Log the visible CUDA devices and apply the backend settings used here.

    Returns:
        ``(False, 0)`` when CUDA is unavailable, otherwise
        ``(True, <number of visible GPUs>)``.
    """
    if not torch.cuda.is_available():
        logger.warning("No CUDA device. Running on CPU.")
        return False, 0

    num_gpus = torch.cuda.device_count()
    for i in range(num_gpus):
        props = torch.cuda.get_device_properties(i)
        cc = torch.cuda.get_device_capability(i)
        gpu_name, vram_gb = props.name, props.total_memory / 1e9
        logger.info(f"GPU {i}: {gpu_name} | VRAM: {vram_gb:.1f} GB | Compute: {cc[0]}.{cc[1]}")

    # TF32 is switched off for matmul and cuDNN (author's note: Pascal has no
    # TF32); cuDNN autotuning is switched on.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cudnn.benchmark = True
    logger.info(f"Precision: FP16 (P100 β€” no BF16) | GPUs visible: {num_gpus}")

    # Cap this process at 95% of device 0's memory, then drop cached blocks.
    torch.cuda.set_per_process_memory_fraction(0.95, 0)
    torch.cuda.empty_cache()
    gc.collect()
    return True, num_gpus
199
+ # ════════════════════════════════════════════════════════════════════════════
200
+ # Model presets
201
+ # ════════════════════════════════════════════════════════════════════════════
202
# Architecture presets keyed by model size name. HybridMoRMoEConfig copies
# these values in for any key the caller did not pass explicitly.
MODEL_PRESETS = {
    "small": dict(
        d_model=512,
        n_heads=8,
        d_ff=1408,
        num_base_layers=4,
        num_shared_blocks=3,
        num_recursions=2,
        num_unique_last_layers=1,
        num_experts=4,
        max_recursions=2,
    ),
    "medium": dict(
        d_model=576,
        n_heads=8,
        d_ff=1536,
        num_base_layers=6,
        num_shared_blocks=6,
        num_recursions=3,
        num_unique_last_layers=2,
        num_experts=4,
        max_recursions=3,
    ),
    "large": dict(
        d_model=1536,
        n_heads=16,
        d_ff=4096,
        num_base_layers=8,
        num_shared_blocks=8,
        num_recursions=3,
        num_unique_last_layers=3,
        num_experts=8,
        max_recursions=3,
    ),
}
222
+ # ════════════════════════════════════════════════════════════════════════════
223
+ # Pipeline config β€” P100 optimised defaults
224
+ # ════════════════════════════════════════════════════════════════════════════
225
@dataclass
class PipelineConfig:
    """Defaults for the pretrain -> SFT -> GRPO pipeline.

    Several batch / sequence fields are overwritten per model size by
    ``adjust_config_for_model_size``. "Effective batch" in the comments below
    means per-device batch size times gradient-accumulation steps.
    """

    # Model / hardware
    model_size: str = "medium"
    max_seq_len: int = 4096
    dropout: float = 0.05
    num_gpus: int = 1  # single GPU

    # Data, tokenizer and output locations
    sft_data_dir: str = "/kaggle/input/datasets/abhishek0706/sft-dataset"
    pretrain_corpus: str = "./pretraining_corpus.jsonl"
    tokenizer_path: str = "./hf_assets/tokenizer/Qwen2.5-0.5B-Instruct"
    tokenizer_hf_id: str = "Qwen/Qwen2.5-0.5B-Instruct"
    output_dir: str = "./hybrid_mor_moe_P100"

    # Stage 1: pre-training
    pretrain_max_samples: int = 400_000  # previously 200 K
    pretrain_max_steps: int = 10_000
    pretrain_batch_size: int = 2  # author's note: fits a 16 GB P100 at seq 4096
    pretrain_grad_accum: int = 8  # effective batch 16
    pretrain_lr: float = 3e-4
    pretrain_warmup_steps: int = 500
    pretrain_weight_decay: float = 0.1
    pretrain_save_steps: int = 2500  # infrequent saves to limit I/O
    pretrain_eval_steps: int = 2500
    pretrain_logging_steps: int = 50
    pretrain_eval_split: float = 0.02

    # Stage 2: supervised fine-tuning
    sft_max_samples_per_domain: int = 10_000  # previously 5 K
    sft_max_steps: int = 2000
    sft_batch_size: int = 2
    sft_grad_accum: int = 8  # effective batch 16
    sft_lr: float = 5e-4
    sft_warmup_steps: int = 200
    sft_weight_decay: float = 0.1
    sft_max_grad_norm: float = 1.0
    sft_save_steps: int = 1000
    sft_eval_steps: int = 500
    sft_logging_steps: int = 25
    sft_eval_split: float = 0.05

    # Stage 3: GRPO
    grpo_max_steps: int = 1000
    grpo_batch_size: int = 2
    grpo_grad_accum: int = 8  # effective batch 16
    grpo_lr: float = 5e-6
    grpo_warmup_steps: int = 50
    grpo_weight_decay: float = 0.05
    grpo_max_grad_norm: float = 0.5
    grpo_num_generations: int = 2
    grpo_max_completion_length: int = 192
    grpo_max_prompt_length: int = 128
    grpo_beta: float = 0.04
    grpo_save_steps: int = 500
    grpo_logging_steps: int = 25
    grpo_max_dataset_size: int = 20_000  # previously 10 K

    # Shared trainer settings and stage toggles
    save_total_limit: int = 2
    dataloader_num_workers: int = 4  # previously 2
    inference_every_steps: int = 1000
    skip_pretrain: bool = True
    skip_sft: bool = True
281
def adjust_config_for_model_size(cfg: PipelineConfig):
    """Overwrite batch and sequence settings on *cfg* for its ``model_size``.

    ``"large"`` and ``"medium"`` have dedicated settings; every other value is
    treated like ``"small"``. The config is modified in place, the resulting
    settings are logged, and the same object is returned.
    """
    large_overrides = {
        "max_seq_len": 512,
        "pretrain_batch_size": 1,
        "pretrain_grad_accum": 16,
        "sft_batch_size": 1,
        "sft_grad_accum": 16,
        "grpo_batch_size": 1,
        "grpo_grad_accum": 16,
        "grpo_num_generations": 2,
        "grpo_max_completion_length": 256,
        "grpo_max_prompt_length": 256,
    }
    # Author's VRAM budget for GRPO on a 16 GB P100 (294M model): about
    # 600 MB FP16 weights + 2.4 GB FP32 optimizer state + 600 MB gradients,
    # i.e. ~3.6 GB of fixed overhead and ~12 GB for activations and logits.
    # The GRPO scoring forward pass (with grads) over batch x seq x 151K vocab
    # is the bottleneck, and Accelerate's convert_to_fp32 doubles the logits
    # memory, so the token count is kept low: prompt 128 + completion 192.
    medium_overrides = {
        "max_seq_len": 4096,
        "pretrain_batch_size": 2,
        "pretrain_grad_accum": 8,
        "sft_batch_size": 2,
        "sft_grad_accum": 8,
        "grpo_batch_size": 1,
        "grpo_grad_accum": 16,  # effective batch 16
        "grpo_num_generations": 2,
        "grpo_max_completion_length": 192,
        "grpo_max_prompt_length": 128,
    }
    # Small: more headroom, but the large vocabulary still constrains GRPO.
    # grpo_num_generations is intentionally left at whatever cfg already has.
    small_overrides = {
        "max_seq_len": 4096,
        "pretrain_batch_size": 4,
        "pretrain_grad_accum": 4,
        "sft_batch_size": 4,
        "sft_grad_accum": 4,
        "grpo_batch_size": 1,
        "grpo_grad_accum": 16,
        "grpo_max_completion_length": 256,
        "grpo_max_prompt_length": 192,
    }

    if cfg.model_size == "large":
        overrides = large_overrides
    elif cfg.model_size == "medium":
        overrides = medium_overrides
    else:
        overrides = small_overrides
    for field_name, value in overrides.items():
        setattr(cfg, field_name, value)

    eff_sft = cfg.sft_batch_size * cfg.sft_grad_accum
    eff_grpo = cfg.grpo_batch_size * cfg.grpo_grad_accum
    logger.info(f"P100 config β€” {cfg.model_size} model: seq={cfg.max_seq_len}")
    logger.info(f" Per-device batch : SFT={cfg.sft_batch_size}, GRPO={cfg.grpo_batch_size}")
    logger.info(f" Grad accum : SFT={cfg.sft_grad_accum}, GRPO={cfg.grpo_grad_accum}")
    logger.info(f" Effective batch : SFT={eff_sft}, GRPO={eff_grpo}")
    logger.info(f" GRPO seq lengths : prompt={cfg.grpo_max_prompt_length}, "
                f"completion={cfg.grpo_max_completion_length}")
    return cfg
333
+ # ════════════════════════════════════════════════════════════════════════════
334
+ # Model Architecture (identical to original β€” kept for self-containedness)
335
+ # ════════════════════════════════════════════════════════════════════════════
336
class HybridMoRMoEConfig(PretrainedConfig):
    """Configuration for the HybridMoRMoE model.

    Keyword arguments that are not given explicitly are filled from
    ``MODEL_PRESETS[model_size]`` when that preset exists; explicit arguments
    always win. ``num_hidden_layers`` is derived in ``__init__`` from the
    layer and recursion counts.
    """

    model_type = "hybrid_mor_moe"

    # Class-level defaults; instance values come from kwargs and presets.
    model_size: str = "medium"
    d_model: int = 576
    n_heads: int = 8
    d_ff: int = 1536
    vocab_size: int = 151936
    max_seq_len: int = 4096
    dropout: float = 0.05

    # Layer layout / recursion
    num_base_layers: int = 4
    num_shared_blocks: int = 4
    num_recursions: int = 2
    max_recursions: int = 2
    num_unique_last_layers: int = 2
    router_percentile: float = 0.7

    # Mixture-of-experts
    num_experts: int = 4
    num_experts_per_tok: int = 1
    router_aux_loss_coef: float = 0.0001
    moe_aux_loss_coef: float = 0.0001

    # Complexity scoring / thinking budgets
    complexity_hidden_dim: int = 64
    complexity_threshold_easy: float = 0.3
    complexity_threshold_hard: float = 0.7
    think_budget_easy: int = 12
    think_budget_medium: int = 48
    think_budget_hard: int = 96

    def __init__(self, **kwargs):
        # NOTE(review): the fallback here is "small" although the class-level
        # default above is "medium" — confirm which one is intended.
        model_size = kwargs.get("model_size", "small")
        # An unknown size name simply contributes no preset values.
        for preset_key, preset_value in MODEL_PRESETS.get(model_size, {}).items():
            kwargs.setdefault(preset_key, preset_value)
        super().__init__(**kwargs)
        self.model_size = model_size

        # Depth = base layers + one pass over the shared blocks and the
        # unique last layers per effective recursion.
        n_rec = min(self.num_recursions, self.max_recursions)
        shared_depth = n_rec * self.num_shared_blocks
        unique_depth = n_rec * self.num_unique_last_layers
        self.num_hidden_layers = self.num_base_layers + shared_depth + unique_depth
375
class RotaryEmbedding(nn.Module):
    """Cached cosine / sine tables for rotary position embeddings.

    Args:
        dim: Width of the rotary features (the last dimension of the tables).
        max_seq_len: Number of positions to precompute up front.
    """

    def __init__(self, dim, max_seq_len=8192):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))
        self._set_cos_sin_cache(max_seq_len)

    def _set_cos_sin_cache(self, seq_len):
        """(Re)build the tables for positions ``0 .. seq_len - 1``."""
        positions = torch.arange(
            seq_len,
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype,
        )
        angles = torch.outer(positions, self.inv_freq)
        # Each frequency appears twice along the last dimension.
        table = torch.cat((angles, angles), dim=-1)
        # Non-persistent: the tables are left out of the state dict.
        self.register_buffer("cos_cached", table.cos(), persistent=False)
        self.register_buffer("sin_cached", table.sin(), persistent=False)
        self.max_seq_len_cached = seq_len

    def forward(self, seq_len, device):
        """Return ``(cos, sin)`` for the first *seq_len* positions on *device*.

        The cache is rebuilt when *seq_len* exceeds what is currently stored.
        """
        needs_rebuild = seq_len > self.max_seq_len_cached
        if needs_rebuild:
            self._set_cos_sin_cache(seq_len)
        cos = self.cos_cached[:seq_len].to(device)
        sin = self.sin_cached[:seq_len].to(device)
        return cos, sin
392
def apply_rotary_emb(q, k, cos, sin):
    """Apply rotary position embedding to *q* and *k*.

    The sequence length is read from ``q.shape[2]`` and the first that many
    rows of *cos* / *sin* are used, with two leading singleton dimensions
    added for broadcasting. Presumably q/k are (batch, heads, seq, head_dim)
    and cos/sin are (positions, head_dim) — TODO confirm against the caller.

    Returns:
        The rotated ``(q, k)`` pair.
    """
    def _swap_halves(t):
        # (a, b) -> (-b, a) along the last dimension.
        front, back = t.chunk(2, dim=-1)
        return torch.cat((-back, front), dim=-1)

    n_positions = q.shape[2]
    cos_b = cos[:n_positions][None, None]
    sin_b = sin[:n_positions][None, None]

    q_rot = (q * cos_b) + (_swap_halves(q) * sin_b)
    k_rot = (k * cos_b) + (_swap_halves(k) * sin_b)
    return q_rot, k_rot
400
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with optional rotary embedding and KV cache.

    Args:
        d_model: Model width; must be divisible by *n_heads*.
        n_heads: Number of attention heads.
        dropout: Attention dropout probability, applied only in training mode.

    Raises:
        ValueError: if *d_model* is not divisible by *n_heads*.
    """

    def __init__(self, d_model, n_heads, dropout=0.05):
        super().__init__()
        if d_model % n_heads != 0:
            # Otherwise the head reshape in forward() fails with an opaque
            # size-mismatch error.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.d_model = d_model
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        self.attn_dropout_p = dropout

    def forward(self, x, mask=None, cos=None, sin=None, past_key_value=None, use_cache=False):
        """Attend over *x*, optionally extending a key/value cache.

        Args:
            x: Input of shape (batch, seq, d_model).
            mask: Accepted for interface compatibility; not used — attention
                is always purely causal.
            cos, sin: Rotary tables; applied to q/k only when both are given.
            past_key_value: Optional ``(k, v)`` pair from earlier positions,
                concatenated in front of the new keys/values along dim 2.
            use_cache: When true, the full ``(k, v)`` pair is returned.

        Returns:
            ``(output, new_cache)``; ``new_cache`` is ``None`` unless
            *use_cache* is set.
        """
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        if cos is not None and sin is not None:
            q, k = apply_rotary_emb(q, k, cos, sin)

        attn_mask = None
        is_causal = past_key_value is None
        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
            if T > 1:
                # Several new tokens on top of a cache: query i sits at
                # absolute position (S - T + i) and may only see keys up to
                # that position. ``is_causal`` cannot express this offset, so
                # build the mask explicitly. A single new token needs no mask
                # because every cached key is in its past.
                S = k.shape[2]
                attn_mask = torch.tril(
                    torch.ones(T, S, device=x.device), diagonal=S - T
                ).bool()
        new_cache = (k, v) if use_cache else None

        dropout_p = self.attn_dropout_p if self.training else 0.0
        attn_out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=dropout_p,
            is_causal=is_causal,
        )
        output = self.o_proj(attn_out.transpose(1, 2).contiguous().view(B, T, C))
        return output, new_cache
430
class Expert(nn.Module):
    """Gated feed-forward expert computing ``w2(dropout(silu(w1(x)) * w3(x)))``."""

    def __init__(self, d_model, d_ff, dropout=0.05):
        super().__init__()
        # Creation order (w1, w3, w2) is kept so seeded initialisation matches.
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Project up through the SiLU gate and value branches, then back down."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        hidden = self.dropout(gate * value)
        return self.w2(hidden)
439
class MoELayer(nn.Module):
    """Top-k mixture-of-experts feed-forward layer with a softmax gate.

    Each token is sent to its ``top_k`` highest-probability experts; their
    outputs are combined with the renormalised gate probabilities.
    """

    def __init__(self, d_model, d_ff, num_experts, top_k, dropout=0.05):
        super().__init__()
        if top_k > num_experts:
            # torch.topk would otherwise fail on the first forward pass.
            raise ValueError(
                f"top_k ({top_k}) cannot exceed num_experts ({num_experts})"
            )
        self.num_experts, self.top_k = num_experts, top_k
        self.experts = nn.ModuleList([Expert(d_model, d_ff, dropout) for _ in range(num_experts)])
        self.gate = nn.Linear(d_model, num_experts, bias=False)

    def forward(self, x):
        """Route tokens through their selected experts.

        Args:
            x: Input of shape ``(B, T, C)``.

        Returns:
            ``(output, aux_loss)``: the mixed expert output reshaped to
            ``(B, T, C)`` and a scalar computed as the sum of squared mean gate
            probabilities multiplied by ``num_experts``.
        """
        B, T, C = x.shape
        xf = x.reshape(-1, C)
        gp = F.softmax(self.gate(xf), dim=-1)
        tp, ti = torch.topk(gp, self.top_k, dim=-1)
        tp = tp / (tp.sum(dim=-1, keepdim=True) + 1e-8)
        out = torch.zeros_like(xf)
        for i, expert in enumerate(self.experts):
            hit = ti == i
            m = hit.any(dim=-1)
            if m.any():
                eo = expert(xf[m])
                # Keep the weights in the gate's dtype: an unconditional
                # float32 cast promotes the product and makes the indexed
                # in-place add fail when ``out`` is half/bfloat16.
                w = (tp[m] * hit[m].to(tp.dtype)).sum(dim=-1, keepdim=True)
                out[m] += (w * eo).to(out.dtype)
        aux_loss = (gp.mean(0) ** 2).sum() * self.num_experts
        return out.view(B, T, C), aux_loss
460
class PercentileRouter(nn.Module):
    """Score tokens with a linear probe and select those at or above a quantile."""

    def __init__(self, d_model, percentile=0.7):
        super().__init__()
        self.percentile = percentile
        self.router = nn.Linear(d_model, 1)

    def _active_mask(self, mask, scores, device):
        # Build the boolean "token is active" mask aligned with ``scores``.
        everything = torch.ones_like(scores, dtype=torch.bool, device=device)
        if mask is None:
            return everything
        active = mask.bool().to(device)
        if active.shape == scores.shape:
            return active
        width = scores.shape[-1]
        if active.shape[0] == scores.shape[0] and active.shape[-1] >= width:
            # Mask is longer than the current chunk: keep its trailing columns.
            return active[..., -width:]
        # Incompatible shapes: fall back to treating every token as active.
        return everything

    def forward(self, x, mask=None):
        """Return ``(selected, scores, z_loss)`` for the hidden states ``x``.

        ``scores`` is a softmax over the last axis of the clamped router
        logits. The threshold is the ``percentile`` quantile of all active
        scores pooled together; ``selected`` marks active tokens at or above it.
        """
        device = x.device
        logits = self.router(x).squeeze(-1).clamp(-50.0, 50.0)
        scores = torch.softmax(logits, dim=-1)

        # Replace any non-finite score with a uniform value.
        bad = torch.isnan(scores) | torch.isinf(scores)
        if bad.any():
            uniform = torch.ones_like(scores) / max(scores.shape[-1], 1)
            scores = torch.where(bad, uniform, scores)

        am = self._active_mask(mask, scores, device)
        active_scores = scores[am]
        if active_scores.numel() > 0:
            thr = torch.quantile(active_scores.float(), self.percentile)
            sel = (scores >= thr) & am
            # NOTE(review): this penalty is computed on softmax probabilities,
            # not on the raw router logits — confirm that is intended.
            zl = torch.logsumexp(active_scores.float(), dim=0) ** 2
        else:
            sel = am
            zl = torch.tensor(0.0, device=device)
        return sel, scores, zl
491
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention then a dense or MoE feed-forward.

    With ``use_moe`` the feed-forward is a ``MoELayer``; otherwise it is a
    gated SiLU MLP owned directly by the block.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout, use_moe=False, num_experts=8, top_k=2):
        super().__init__()
        self.use_moe = use_moe
        self.ln1 = nn.RMSNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ln2 = nn.RMSNorm(d_model)
        if not use_moe:
            self.w1 = nn.Linear(d_model, d_ff, bias=False)
            self.w3 = nn.Linear(d_model, d_ff, bias=False)
            self.w2 = nn.Linear(d_ff, d_model, bias=False)
            self.ffn_dropout = nn.Dropout(dropout)
        else:
            self.ffn = MoELayer(d_model, d_ff, num_experts, top_k, dropout)

    def _ffn(self, x):
        # Returns (feed-forward output, auxiliary loss or None for the dense path).
        normed = self.ln2(x)
        if self.use_moe:
            return self.ffn(normed)
        gated = F.silu(self.w1(normed)) * self.w3(normed)
        return self.w2(self.ffn_dropout(gated)), None

    def forward(self, x, mask=None, cos=None, sin=None, past_key_value=None, use_cache=False):
        """Apply attention and feed-forward sublayers with residual connections.

        Returns:
            ``(hidden, aux_loss, new_cache)`` where ``aux_loss`` is ``None``
            for dense blocks and ``new_cache`` comes from the attention module.
        """
        attn_out, new_cache = self.attn(
            self.ln1(x), mask, cos, sin, past_key_value, use_cache
        )
        x = x + attn_out
        ffn_out, aux = self._ffn(x)
        return x + ffn_out, aux, new_cache
516
+ class ComplexityScorer(nn.Module):
517
+ def __init__(self, d_model, hidden_dim=128):
518
+ super().__init__()
519
+ self.pool_proj = nn.Linear(d_model, hidden_dim)
520
+ self.scorer = nn.Sequential(
521
+ nn.RMSNorm(hidden_dim), nn.GELU(),
522
+ nn.Linear(hidden_dim, hidden_dim), nn.GELU(),
523
+ nn.Linear(hidden_dim, 1),
524
+ )
525
+ def forward(self, hidden_states, attention_mask=None):
526
+ if attention_mask is not None:
527
+ m = attention_mask.unsqueeze(-1).float()
528
+ pooled = (hidden_states * m).sum(1) / m.sum(1).clamp(min=1)
529
+ else:
530
+ pooled = hidden_states.mean(dim=1)
531
+ return torch.sigmoid(self.scorer(self.pool_proj(pooled)).squeeze(-1))
532
class HybridMoRMoEForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal LM combining dense base layers with routed recursions of MoE blocks.

    Layer order in ``forward``:
      1. ``base_layers`` (dense blocks),
      2. for each recursion: a ``PercentileRouter`` selects tokens, every
         block in ``shared_blocks`` (MoE) is applied and its result kept only
         for selected tokens, then that recursion's ``unique_last_layers``
         (dense) run on all tokens,
      3. final RMSNorm and ``lm_head``.

    ``complexity_scorer`` is constructed here but is not called by ``forward``.
    """

    config_class = HybridMoRMoEConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _supports_sdpa = False
    _no_split_modules = []

    @classmethod
    def _can_set_experts_implementation(cls) -> bool:
        return False

    def __init__(self, config: HybridMoRMoEConfig):
        super().__init__(config)
        self.config = config
        self.gradient_checkpointing = False
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.rotary_emb = RotaryEmbedding(config.d_model // config.n_heads, config.max_seq_len * 2)
        self.base_layers = nn.ModuleList([
            TransformerBlock(config.d_model, config.n_heads, config.d_ff, config.dropout, False)
            for _ in range(config.num_base_layers)
        ])
        self.shared_blocks = nn.ModuleList([
            TransformerBlock(config.d_model, config.n_heads, config.d_ff, config.dropout,
                             True, config.num_experts, config.num_experts_per_tok)
            for _ in range(config.num_shared_blocks)
        ])
        self.routers = nn.ModuleList([
            PercentileRouter(config.d_model, config.router_percentile)
            for _ in range(config.num_recursions)
        ])
        self.unique_last_layers = nn.ModuleList([
            nn.ModuleList([
                TransformerBlock(config.d_model, config.n_heads, config.d_ff, config.dropout, False)
                for _ in range(config.num_unique_last_layers)
            ])
            for _ in range(config.num_recursions)
        ])
        self.complexity_scorer = ComplexityScorer(config.d_model, config.complexity_hidden_dim)
        self.ln_f = nn.RMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self._num_kv_layers = self._count_kv_layers()
        self.post_init()

    def _count_kv_layers(self):
        """Count the attention invocations (and hence KV-cache slots) per forward pass."""
        count = len(self.base_layers)
        n_rec = min(self.config.num_recursions, len(self.routers))
        for ri in range(n_rec):
            count += len(self.shared_blocks) + len(self.unique_last_layers[ri])
        return count

    def _set_gradient_checkpointing(self, enable=True, gradient_checkpointing_func=None):
        # Only the flag is stored; ``gradient_checkpointing_func`` is unused.
        self.gradient_checkpointing = enable

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights; zero Linear biases."""
        std = 0.02
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, std=std)

    def get_input_embeddings(self):
        return self.token_embedding

    def set_input_embeddings(self, value):
        self.token_embedding = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @staticmethod
    def _layer_past(past_key_values, layer_idx):
        # Cached (k, v) for one attention slot, or None when unavailable.
        if past_key_values and layer_idx < len(past_key_values):
            return past_key_values[layer_idx]
        return None

    @staticmethod
    def _run_block(block, x, attention_mask, cos, sin, past_kv, use_cache, use_ckpt):
        # Call one transformer block, optionally under activation checkpointing.
        if use_ckpt:
            return torch.utils.checkpoint.checkpoint(
                block, x, attention_mask, cos, sin, past_kv, use_cache, use_reentrant=False)
        return block(x, attention_mask, cos, sin, past_kv, use_cache)

    def forward(self, input_ids=None, attention_mask=None, labels=None,
                past_key_values=None, use_cache=False, return_dict=True, **kwargs):
        """Run the model and optionally compute the language-modelling loss.

        Args:
            input_ids: Token ids of shape ``(batch, seq_len)``; required.
                Ids are clamped into ``[0, vocab_size - 1]``.
            attention_mask: Passed through to every block and router.
            labels: Optional targets; positions equal to ``-100`` are ignored.
                Logits and labels are shifted by one position for the loss.
            past_key_values: ``None``, a list/tuple of per-slot ``(k, v)``
                pairs, or a ``DynamicCache`` (converted to such a list).
            use_cache: Return the per-slot ``(k, v)`` pairs as a tuple.
            return_dict: Return ``CausalLMOutputWithPast`` instead of a tuple.

        Returns:
            ``CausalLMOutputWithPast`` or ``(loss, logits, cache)`` /
            ``(logits, cache)`` when ``return_dict`` is false.

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        from transformers.cache_utils import DynamicCache

        if input_ids is None:
            raise ValueError("input_ids is required")
        device = input_ids.device
        seq_len = input_ids.shape[1]
        input_ids = input_ids.clamp(0, self.config.vocab_size - 1)
        x = self.token_embedding(input_ids)

        if isinstance(past_key_values, DynamicCache):
            # Flatten the cache object into the list-of-pairs layout used below.
            # NOTE(review): relies on the ``key_cache``/``value_cache``
            # attributes being present in the installed transformers version.
            if past_key_values.get_seq_length() > 0:
                past_key_values = [
                    (past_key_values.key_cache[i], past_key_values.value_cache[i])
                    for i in range(len(past_key_values.key_cache))
                ]
            else:
                past_key_values = None

        past_length = 0
        if (past_key_values is not None and isinstance(past_key_values, (list, tuple))
                and len(past_key_values) > 0 and past_key_values[0] is not None):
            past_length = past_key_values[0][0].shape[2]

        # Rotary tables for the absolute positions of the new tokens only.
        total_len = past_length + seq_len
        cos, sin = self.rotary_emb(total_len, device)
        cos = cos[past_length:total_len]
        sin = sin[past_length:total_len]

        new_past_key_values = []
        layer_idx = 0
        use_ckpt = self.gradient_checkpointing and self.training and not use_cache

        for layer in self.base_layers:
            past_kv = self._layer_past(past_key_values, layer_idx)
            x, _, new_cache = self._run_block(
                layer, x, attention_mask, cos, sin, past_kv, use_cache, use_ckpt)
            new_past_key_values.append(new_cache)
            layer_idx += 1

        router_losses, moe_losses = [], []
        n_rec = min(self.config.num_recursions, len(self.routers))
        for ri in range(n_rec):
            sel, _, zl = self.routers[ri](x, attention_mask)
            router_losses.append(zl)
            for blk in self.shared_blocks:
                past_kv = self._layer_past(past_key_values, layer_idx)
                xb, ml, new_cache = self._run_block(
                    blk, x, attention_mask, cos, sin, past_kv, use_cache, use_ckpt)
                # Only routed tokens take the block output; the rest pass through.
                x = torch.where(sel.unsqueeze(-1), xb, x)
                new_past_key_values.append(new_cache)
                layer_idx += 1
                if ml is not None:
                    moe_losses.append(ml)
            for layer in self.unique_last_layers[ri]:
                past_kv = self._layer_past(past_key_values, layer_idx)
                x, _, new_cache = self._run_block(
                    layer, x, attention_mask, cos, sin, past_kv, use_cache, use_ckpt)
                new_past_key_values.append(new_cache)
                layer_idx += 1

        x = self.ln_f(x)
        logits = self.lm_head(x)
        # In-place cleanup avoids allocating copies of the large logits tensor.
        logits.nan_to_num_(nan=0.0, posinf=100.0, neginf=-100.0)
        logits.clamp_(-100.0, 100.0)

        loss = None
        if labels is not None:
            cl = labels.clone()
            valid = cl != -100
            cl[valid] = cl[valid].clamp(0, self.config.vocab_size - 1)
            sl = logits[..., :-1, :].contiguous()
            tl = cl[..., 1:].contiguous()
            loss = F.cross_entropy(sl.view(-1, sl.size(-1)), tl.view(-1), ignore_index=-100)
            if router_losses:
                loss = loss + self.config.router_aux_loss_coef * torch.stack(router_losses).mean()
            if moe_losses:
                loss = loss + self.config.moe_aux_loss_coef * torch.stack(moe_losses).mean()

        output_cache = tuple(new_past_key_values) if use_cache else None
        if return_dict:
            from transformers.modeling_outputs import CausalLMOutputWithPast
            return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=output_cache)
        return (loss, logits, output_cache) if loss is not None else (logits, output_cache)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
        """Trim ``input_ids`` to the last token once a non-empty cache exists."""
        from transformers.cache_utils import DynamicCache

        has_past = False
        if past_key_values is not None:
            if isinstance(past_key_values, DynamicCache):
                has_past = past_key_values.get_seq_length() > 0
            elif isinstance(past_key_values, (list, tuple)) and len(past_key_values) > 0:
                has_past = past_key_values[0] is not None
        if has_past:
            input_ids = input_ids[:, -1:]
        return {"input_ids": input_ids, "attention_mask": attention_mask,
                "past_key_values": past_key_values, "use_cache": True}

    @torch.no_grad()
    def simple_generate(self, input_ids, max_new_tokens=256, temperature=0.7,
                        top_k=50, top_p=0.9, pad_token_id=0, eos_token_id=None, use_cache=True):
        """Sample up to ``max_new_tokens`` tokens with temperature/top-k/top-p.

        Rows that have already produced ``eos_token_id`` are filled with
        ``pad_token_id`` (or ``eos_token_id`` when ``pad_token_id`` is ``None``)
        and generation stops once every row has finished. The module's
        train/eval mode is restored to what it was on entry.

        Returns:
            The prompt with the generated tokens appended along axis 1.
        """
        was_training = self.training
        self.eval()
        try:
            gen_model = self.module if hasattr(self, 'module') else self
            generated = input_ids.clone()
            past_key_values = None
            finished = torch.zeros(generated.shape[0], dtype=torch.bool, device=generated.device)
            fill_id = pad_token_id if pad_token_id is not None else eos_token_id
            for _ in range(max_new_tokens):
                if past_key_values is not None and use_cache:
                    current_input = generated[:, -1:]
                else:
                    current_input = generated
                outputs = gen_model.forward(current_input, past_key_values=past_key_values,
                                            use_cache=use_cache, return_dict=True)
                if use_cache:
                    past_key_values = outputs.past_key_values
                next_logits = outputs.logits[:, -1, :].float() / max(temperature, 1e-8)
                if top_k > 0:
                    v, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
                    next_logits[next_logits < v[..., -1, None]] = float("-inf")
                if top_p < 1.0:
                    sorted_logits, sorted_idx = torch.sort(next_logits, descending=True)
                    cumsum = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    remove = cumsum > top_p
                    # Shift right so the token that crosses top_p is kept.
                    remove[..., 1:] = remove[..., :-1].clone()
                    remove[..., 0] = False
                    next_logits[remove.scatter(1, sorted_idx, remove)] = float("-inf")
                probs = F.softmax(next_logits, dim=-1).clamp(min=0.0)
                if torch.isnan(probs).any() or probs.sum(dim=-1).min() < 1e-8:
                    # Degenerate distribution: fall back to uniform sampling.
                    probs = torch.ones_like(probs) / probs.shape[-1]
                next_token = torch.multinomial(probs, num_samples=1)
                if eos_token_id is not None:
                    if fill_id is not None:
                        next_token = torch.where(
                            finished.unsqueeze(-1),
                            torch.full_like(next_token, fill_id),
                            next_token,
                        )
                    finished = finished | (next_token.squeeze(-1) == eos_token_id)
                generated = torch.cat([generated, next_token], dim=1)
                if eos_token_id is not None and bool(finished.all()):
                    break
        finally:
            self.train(was_training)
        return generated
725
# Register the custom config and model with the transformers Auto* factories
# under the "hybrid_mor_moe" model type, and expose the model class as an
# attribute of the ``transformers`` package.
AutoConfig.register("hybrid_mor_moe", HybridMoRMoEConfig)
AutoModelForCausalLM.register(HybridMoRMoEConfig, HybridMoRMoEForCausalLM)
transformers.HybridMoRMoEForCausalLM = HybridMoRMoEForCausalLM
728
+ # ════════════════════════════════════════════════════════════════════════════
729
+ # Checkpoint Utilities
730
+ # ════════════════════════════════════════════════════════════════════════════
731
def find_latest_checkpoint(output_dir):
    """Return the path of the highest-numbered ``checkpoint-<step>`` directory.

    Directories whose trailing ``-<suffix>`` is not an integer (for example
    ``checkpoint-best``) are ignored instead of raising ``ValueError``.

    Args:
        output_dir: Directory that may contain ``checkpoint-*`` subdirectories.

    Returns:
        The joined path of the latest checkpoint, or ``None`` when
        ``output_dir`` is not a directory or holds no numbered checkpoint.
    """
    if not os.path.isdir(output_dir):
        return None

    def _step(name):
        # Trailing integer after the last "-", or None if it is not numeric.
        try:
            return int(name.split("-")[-1])
        except ValueError:
            return None

    checkpoints = [
        d for d in os.listdir(output_dir)
        if d.startswith("checkpoint-")
        and os.path.isdir(os.path.join(output_dir, d))
        and _step(d) is not None
    ]
    if not checkpoints:
        return None
    checkpoints.sort(key=_step)
    latest = os.path.join(output_dir, checkpoints[-1])
    logger.info(f" Found checkpoint to resume from: {latest}")
    return latest
742
def cleanup_checkpoints(output_dir, keep_last=0):
    """Delete numbered ``checkpoint-<step>`` directories, keeping the newest ones.

    Directories whose trailing ``-<suffix>`` is not an integer are left alone
    instead of crashing the sort with ``ValueError``.

    Args:
        output_dir: Directory containing ``checkpoint-*`` subdirectories.
        keep_last: Number of highest-step checkpoints to keep; ``0`` (the
            default) or a negative value removes all numbered checkpoints.
    """
    if not os.path.isdir(output_dir):
        return

    def _step(name):
        # Trailing integer after the last "-", or None if it is not numeric.
        try:
            return int(name.split("-")[-1])
        except ValueError:
            return None

    checkpoints = sorted(
        (
            d for d in os.listdir(output_dir)
            if d.startswith("checkpoint-")
            and os.path.isdir(os.path.join(output_dir, d))
            and _step(d) is not None
        ),
        key=_step,
    )
    to_remove = checkpoints[:-keep_last] if keep_last > 0 else checkpoints
    for ckpt in to_remove:
        path = os.path.join(output_dir, ckpt)
        # Best-effort removal: errors while deleting are ignored.
        shutil.rmtree(path, ignore_errors=True)
        logger.info(f" Cleaned up checkpoint: {path}")
755
+ # ════════════════════════════════════════════════════════════════════════════
756
+ # Robust Model Loading
757
+ # ════════════════════════════════════════════════════════════════════════════
758
def load_checkpoint_robust(config, checkpoint_dir, device="cpu"):
    """Build a model from ``config`` and load as much of a checkpoint as fits.

    Weights are read from every ``*.safetensors`` file in ``checkpoint_dir``
    (merged in sorted filename order) or, failing that, from
    ``pytorch_model.bin``. Tensors with matching shapes are taken as-is;
    tensors with the same rank but different sizes are copied over their
    overlapping region; tensors with a different rank or unknown keys are
    skipped. Model parameters absent from the checkpoint keep their fresh
    initialisation.

    Args:
        config: Configuration passed to ``HybridMoRMoEForCausalLM``.
        checkpoint_dir: Directory holding the checkpoint files.
        device: Device the model is moved to after loading.

    Returns:
        The loaded model on ``device``.

    Raises:
        FileNotFoundError: If the directory or any supported weight file is missing.
    """
    from safetensors.torch import load_file as safetensors_load

    # Validate and read the checkpoint before paying for model construction.
    if not os.path.isdir(checkpoint_dir):
        raise FileNotFoundError(f"Checkpoint dir not found: {checkpoint_dir}")
    sf_files = sorted(f for f in os.listdir(checkpoint_dir) if f.endswith(".safetensors"))
    if sf_files:
        ckpt_state = {}
        for sf in sf_files:
            ckpt_state.update(safetensors_load(os.path.join(checkpoint_dir, sf), device="cpu"))
    else:
        pt_bin = os.path.join(checkpoint_dir, "pytorch_model.bin")
        if not os.path.isfile(pt_bin):
            raise FileNotFoundError(f"No .safetensors or pytorch_model.bin in {checkpoint_dir}")
        # NOTE(security): weights_only=False unpickles arbitrary objects; only
        # load pytorch_model.bin files from trusted sources.
        ckpt_state = torch.load(pt_bin, map_location="cpu", weights_only=False)

    model = HybridMoRMoEForCausalLM(config)
    model_state = model.state_dict()
    loaded, skipped_unexpected, partial_loaded, skipped_rank = 0, 0, 0, 0
    for key, ckpt_param in ckpt_state.items():
        if key not in model_state:
            skipped_unexpected += 1
            continue
        model_param = model_state[key]
        if ckpt_param.shape == model_param.shape:
            model_state[key] = ckpt_param
            loaded += 1
        elif ckpt_param.dim() != model_param.dim():
            # A rank mismatch cannot be overlapped axis by axis; slicing only
            # the leading axes would broadcast or fail, so keep the fresh init.
            skipped_rank += 1
            logger.warning(f" [load] Rank mismatch {key}: ckpt={list(ckpt_param.shape)} "
                           f"vs model={list(model_param.shape)} (kept random init)")
        else:
            slices = tuple(
                slice(0, min(cs, ms))
                for cs, ms in zip(ckpt_param.shape, model_param.shape)
            )
            model_state[key][slices] = ckpt_param[slices]
            partial_loaded += 1
            logger.info(f" [load] Partial copy {key}: ckpt={list(ckpt_param.shape)} -> model={list(model_param.shape)}")
    missing = [k for k in model_state if k not in ckpt_state]
    model.load_state_dict(model_state, strict=True)
    logger.info(f" [load] Loaded: {loaded} | Partial: {partial_loaded} | "
                f"Unexpected (skipped): {skipped_unexpected} | Missing (random init): {len(missing)}")
    if skipped_rank:
        logger.info(f" [load] Rank-mismatched (skipped): {skipped_rank}")
    model.to(device)
    return model
797
+ # ════════════════════════════════════════════════════════════════════════════
798
+ # Pipeline State Manager
799
+ # ════════════════════════════════════════════════════════════════════════════
800
class PipelineStateManager:
    """Persist which pipeline phases have completed in ``pipeline_state.json``.

    The state is a dict with ``completed_phases`` (list), ``best_eval_loss``
    and ``phase_steps`` (dicts keyed by phase name), plus a ``last_updated``
    timestamp written on every save.
    """

    def __init__(self, output_dir: str):
        self.path = os.path.join(output_dir, "pipeline_state.json")
        self._state = self._load()

    @staticmethod
    def _empty_state() -> dict:
        # Fresh containers each call so instances never share mutable state.
        return {"completed_phases": [], "best_eval_loss": {}, "phase_steps": {}}

    def _load(self) -> dict:
        """Read the state file, falling back to an empty state if unusable."""
        if os.path.exists(self.path):
            try:
                with open(self.path, "r", encoding="utf-8") as f:
                    state = json.load(f)
                if not isinstance(state, dict):
                    raise ValueError("top-level JSON value is not an object")
            except (OSError, ValueError) as e:
                logger.warning(f" [Checkpoint] Could not read pipeline_state.json: {e}")
            else:
                # Older or hand-edited files may lack keys the methods index.
                for key, default in self._empty_state().items():
                    state.setdefault(key, default)
                logger.info(f" [Checkpoint] Loaded pipeline state: completed={state.get('completed_phases', [])}")
                return state
        return self._empty_state()

    def _save(self):
        """Write the state atomically (temp file followed by ``os.replace``)."""
        self._state["last_updated"] = time.strftime("%Y-%m-%d %H:%M:%S")
        # The output directory may not exist yet on a first run.
        os.makedirs(os.path.dirname(self.path) or ".", exist_ok=True)
        tmp = self.path + ".tmp"
        with open(tmp, "w", encoding="utf-8") as f:
            json.dump(self._state, f, indent=2)
        os.replace(tmp, self.path)

    def mark_complete(self, phase: str, best_eval_loss: Optional[float] = None,
                      steps: Optional[int] = None):
        """Record ``phase`` as finished, optionally with its best loss and step count."""
        if phase not in self._state["completed_phases"]:
            self._state["completed_phases"].append(phase)
        if best_eval_loss is not None:
            self._state["best_eval_loss"][phase] = round(best_eval_loss, 6)
        if steps is not None:
            self._state["phase_steps"][phase] = steps
        self._save()
        logger.info(f" [Checkpoint] Phase '{phase}' marked complete")

    def is_complete(self, phase: str) -> bool:
        return phase in self._state["completed_phases"]

    def get_best_loss(self, phase: str) -> Optional[float]:
        return self._state["best_eval_loss"].get(phase)

    def summary(self) -> str:
        """Return a one-line description of the completed phases."""
        done = self._state.get("completed_phases", [])
        losses = self._state.get("best_eval_loss", {})
        parts = []
        for p in done:
            loss = losses.get(p)
            # ``is not None`` so that a recorded loss of 0.0 is still shown.
            parts.append(f"{p}(loss={loss:.4f})" if loss is not None else p)
        return "Completed: " + (", ".join(parts) if parts else "none")
841
+ # ════════════════════════════════════════════════════════════════════════════
842
+ # Callbacks
843
+ # ════════════════════════════════════════════════════════════════════════════
844
class BestModelCallback(TrainerCallback):
    """Save the model and tokenizer to ``best_<phase>`` whenever eval loss improves."""

    def __init__(self, output_dir: str, phase: str, tokenizer):
        self.best_dir = os.path.join(output_dir, f"best_{phase}")
        self.phase = phase
        self.tokenizer = tokenizer
        self.best_loss = float("inf")

    def on_evaluate(self, args, state, control, model=None, metrics=None, **kwargs):
        """Persist a new best checkpoint plus a small JSON metadata file."""
        if model is None or metrics is None:
            return
        eval_loss = metrics.get("eval_loss")
        if eval_loss is None:
            return
        improved = eval_loss < self.best_loss
        if not improved:
            return

        self.best_loss = eval_loss
        # Unwrap a ``.module`` wrapper if the model has one.
        target = model.module if hasattr(model, "module") else model
        target.save_pretrained(self.best_dir)
        self.tokenizer.save_pretrained(self.best_dir)

        meta = {
            "step": state.global_step,
            "eval_loss": round(eval_loss, 6),
            "phase": self.phase,
            "saved_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        }
        meta_path = os.path.join(self.best_dir, "best_checkpoint_meta.json")
        with open(meta_path, "w") as f:
            json.dump(meta, f, indent=2)
        logger.info(f" [BestModel/{self.phase}] New best eval_loss={eval_loss:.4f} @ step {state.global_step}")
866
class StorageMonitorCallback(TrainerCallback):
    """Check output storage every N steps and trigger cleanup when it runs low."""

    def __init__(self, output_dir, check_every_steps=200):
        self.output_dir = output_dir
        self.check_every_steps = check_every_steps

    def on_step_end(self, args, state, control, **kwargs):
        """Run the storage check on due steps; stop training if cleanup fails.

        Training is stopped only when the status is ``"critical"`` and
        ``emergency_cleanup`` still reports ``"critical"`` afterwards.
        """
        step = state.global_step
        due = step > 0 and step % self.check_every_steps == 0
        if not due:
            return

        status = check_output_storage(self.output_dir)
        if status == "warn":
            emergency_cleanup(self.output_dir, level="warn")
        elif status == "critical":
            after = emergency_cleanup(self.output_dir, level="critical")
            if after == "critical":
                control.should_training_stop = True
880
class ValidationLoggerCallback(TrainerCallback):
    """Log each evaluation loss (with perplexity) and track the best one."""

    def __init__(self, phase=""):
        self.phase = phase
        self.eval_history = []
        self.best_eval_loss = float("inf")
        self.best_step = 0

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Record ``eval_loss`` for the current step and log it."""
        eval_loss = None if metrics is None else metrics.get("eval_loss")
        if eval_loss is None:
            return

        step = state.global_step
        self.eval_history.append((step, eval_loss))

        is_best = eval_loss < self.best_eval_loss
        if is_best:
            self.best_eval_loss, self.best_step = eval_loss, step

        # Perplexity is omitted when exp() overflows.
        try:
            ppl_str = f" | ppl={math.exp(eval_loss):.2f}"
        except OverflowError:
            ppl_str = ""

        if is_best:
            best_str = " BEST"
        else:
            best_str = f" (best={self.best_eval_loss:.4f} @{self.best_step})"
        logger.info(f" [{self.phase}] Step {step}: eval_loss={eval_loss:.4f}{ppl_str}{best_str}")

    def on_train_end(self, args, state, control, **kwargs):
        """Log the best loss seen, if any evaluation was recorded."""
        if not self.eval_history:
            return
        logger.info(f" [{self.phase}] Summary: best={self.best_eval_loss:.4f} @step {self.best_step}")
907
class PrintProgressCallback(TrainerCallback):
    """Print one-line training progress (loss, lr, speed, ETA) to stdout."""

    def __init__(self, phase=""):
        self.phase = phase
        self.start_time = None

    def on_train_begin(self, args, state, control, **kwargs):
        """Start the wall-clock timer and print a banner."""
        self.start_time = time.time()
        rule = "=" * 70
        print(f"\n{rule}", flush=True)
        print(f"[{self.phase}] Training started | max_steps={args.max_steps} | "
              f"lr={args.learning_rate} | gpu=P100", flush=True)
        print(f"{rule}", flush=True)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Print progress for the current step; skipped at step 0 or without logs."""
        if logs is None or state.global_step == 0:
            return

        elapsed = time.time() - self.start_time
        steps_done = state.global_step
        # Fall back to the trainer-computed total when max_steps is not set.
        steps_total = args.max_steps if args.max_steps > 0 else state.max_steps
        speed = steps_done / elapsed if elapsed > 0 else 0
        eta = (steps_total - steps_done) / speed if speed > 0 else 0

        loss = logs.get("loss", logs.get("train_loss"))
        lr = logs.get("learning_rate")

        parts = [f"[{self.phase}] step {steps_done}/{steps_total}"]
        if loss is not None:
            parts.append(f"loss={loss:.4f}")
        if lr is not None:
            parts.append(f"lr={lr:.2e}")
        parts += [f"{speed:.2f} it/s", f"eta={eta/60:.1f}m"]
        print(" | ".join(parts), flush=True)

    def on_train_end(self, args, state, control, **kwargs):
        """Print the total number of steps and elapsed minutes."""
        elapsed = time.time() - self.start_time
        print(f"[{self.phase}] Done: {state.global_step} steps in {elapsed/60:.1f}m", flush=True)
936
class PipelineCallback(TrainerCallback):
    """Log GPU memory periodically and run sample generations during training."""

    def __init__(self, model, tokenizer, eval_prompts, phase="", eval_every=1000, max_new_tokens=256):
        self.model = model
        self.tokenizer = tokenizer
        self.eval_prompts = eval_prompts
        self.phase = phase
        self.eval_every = eval_every
        self.max_new_tokens = max_new_tokens
        self.start_time = None

    def on_train_begin(self, args, state, control, **kwargs):
        self.start_time = time.time()

    def on_step_end(self, args, state, control, **kwargs):
        """Every 500 steps log GPU 0 memory; every ``eval_every`` steps sample."""
        step = state.global_step
        if step <= 0:
            return
        if step % 500 == 0 and torch.cuda.is_available():
            mem = torch.cuda.memory_allocated(0) / 1e9
            logger.info(f" [{self.phase}] Step {step} | GPU 0: {mem:.1f}GB")
        if step % self.eval_every == 0:
            self._run_inference(step)

    def on_train_end(self, args, state, control, **kwargs):
        """Log the total duration and run the full prompt set once."""
        elapsed = time.time() - self.start_time
        logger.info(f"{self.phase} complete: {state.global_step} steps in {elapsed/3600:.2f}h")
        self._run_inference(state.global_step, final=True)

    @torch.no_grad()
    def _run_inference(self, step, final=False):
        """Generate answers for the eval prompts and log them.

        Intermediate runs use only the first prompt of each domain; the final
        run uses all of them. The model is switched to eval mode for the
        duration and back to train mode afterwards.
        """
        self.model.eval()
        device = next(self.model.parameters()).device
        tag = "FINAL" if final else f"Step {step}"
        logger.info(f"\n--- {self.phase} Inference @ {tag} ---")

        if isinstance(self.eval_prompts, dict):
            prompts_dict = self.eval_prompts
        else:
            prompts_dict = {"general": self.eval_prompts}

        for domain, prompts in prompts_dict.items():
            chosen = prompts if final else prompts[:1]
            for prompt in chosen:
                formatted = f"User: {prompt}\n\nAssistant:"
                # The prompt is truncated to ``max_new_tokens`` tokens as well.
                inputs = self.tokenizer(formatted, return_tensors="pt", truncation=True,
                                        max_length=self.max_new_tokens).to(device)
                try:
                    outputs = self.model.simple_generate(
                        input_ids=inputs["input_ids"],
                        max_new_tokens=self.max_new_tokens,
                        temperature=0.7,
                        top_k=50,
                        top_p=0.9,
                        pad_token_id=self.tokenizer.pad_token_id,
                        eos_token_id=self.tokenizer.eos_token_id,
                    )
                    prompt_len = inputs["input_ids"].shape[1]
                    response = self.tokenizer.decode(outputs[0][prompt_len:],
                                                     skip_special_tokens=True)
                except Exception as e:
                    # Best-effort sampling: a failure is logged, not raised.
                    response = f"[Error: {str(e)[:80]}]"
                logger.info(f" [{domain}] Q: {prompt[:80]}")
                logger.info(f" [{domain}] A: {response[:250]}")
        self.model.train()
984
+ # ════════════════════════════════════════════════════════════════════════════
985
+ # Eval Prompts
986
+ # ════════════════════════════════════════════════════════════════════════════
987
# Sample prompts, grouped by domain, used for qualitative generation checks.
EVAL_PROMPTS = dict(
    math=[
        "Martha has 18 crayons. She lost half of them, so she bought a new set of 20 crayons. How many crayons in total does Martha have after the purchase?",
        "The four-digit numeral 3AA1 is divisible by 9. What digit does A represent?",
        "Find the remainder when 2^100 is divided by 7.",
        "Solve: 3x + 7 = 22. What is x?",
    ],
    coding=[
        "Given an array of integers, implement insertion sort in Python.",
        "Write a Python function to find the longest common subsequence of two strings.",
    ],
    conversation=[
        "What are the key differences between renewable and non-renewable energy sources?",
        "What is the difference between machine learning and deep learning?",
    ],
    reasoning=[
        "A bat and a ball cost $1.10 in total. The bat costs $1 more than the ball. How much does the ball cost?",
        "If P implies Q, and Q is false, what can we say about P?",
    ],
    greetings=[
        "Hello! How are you today?",
        "Hi, can you help me with something?",
    ],
)
1011
+ # ════════════════════════════════════════════════════════════════════════════
1012
+ # Dataset Loading — more data everywhere
1013
+ # ════════════════════════════════════════════════════════════════════════════
1014
def load_pretrain_dataset(cfg, tokenizer):
    """Load the pre-training corpus and split it into train and eval sets.

    The corpus path comes from ``cfg.pretrain_corpus``; on Kaggle a fixed
    input path is tried when that file is missing. If no corpus file exists, a
    tiny synthetic dataset of one repeated sentence is used. ``tokenizer`` is
    accepted but not used here.

    Returns:
        ``(train_dataset, eval_dataset)``.
    """
    logger.info("Loading pretraining corpus...")

    corpus_path = cfg.pretrain_corpus
    if IS_KAGGLE and not os.path.exists(corpus_path):
        corpus_path = "/kaggle/input/pretraining-corpus/pretraining_corpus.jsonl"

    if os.path.exists(corpus_path):
        ds = load_dataset("json", data_files=corpus_path, split="train")
    else:
        # No corpus available: fall back to placeholder text.
        placeholder = ["Mathematics studies numbers and shapes."] * 1000
        ds = Dataset.from_dict({"text": placeholder})

    # NOTE(review): the source's indentation was lost; the sample cap is
    # applied to both branches here — confirm against the original file.
    limit = cfg.pretrain_max_samples
    if len(ds) > limit:
        ds = ds.shuffle(seed=42).select(range(limit))

    split = ds.train_test_split(test_size=cfg.pretrain_eval_split, seed=42)
    train_ds, eval_ds = split["train"], split["test"]
    logger.info(f"Pretrain: {len(train_ds):,} train | {len(eval_ds):,} eval")
    return train_ds, eval_ds
1029
def load_sft_dataset(cfg):
    """Load per-domain SFT records from JSONL files into one shuffled dataset.

    Each domain file under ``cfg.sft_data_dir`` is read line by line. Blank
    lines, lines that are not valid JSON, and JSON values that are not objects
    are skipped. Every domain except ``greetings`` is capped at
    ``cfg.sft_max_samples_per_domain`` after a seeded shuffle. Each record is
    tagged with its ``domain``. If nothing could be loaded, a small synthetic
    dataset is used.

    Shuffling uses a private ``random.Random(42)`` per shuffle, which yields
    the same order as seeding the global generator with 42 but leaves the
    process-wide random state untouched.

    Returns:
        A ``Dataset`` built from the combined, shuffled records.
    """
    import random

    logger.info("Loading SFT dataset...")
    data_dir = cfg.sft_data_dir
    if IS_KAGGLE and not os.path.isdir(data_dir):
        for cand in ["/kaggle/input/datasets/abhishek0706/sft-dataset",
                     "/kaggle/input/datasets/abhishekgandhiau/sft-dataset-v1",
                     "/kaggle/input/datasets/abhishekgandhi0706/sft-dataset",
                     "/kaggle/input/sft-dataset/SFT_dataset", "/kaggle/input/sft-dataset"]:
            if os.path.isdir(cand):
                data_dir = cand
                break
    domain_files = {"math": "math_records.jsonl", "coding": "coding_records.jsonl",
                    "conversation": "conversation_records.jsonl", "reasoning": "reasoning_records.jsonl",
                    "greetings": "greetings_records.jsonl"}
    use_all = {"greetings"}  # domains exempt from the per-domain cap
    all_records = []
    for domain, filename in domain_files.items():
        filepath = os.path.join(data_dir, filename)
        if not os.path.exists(filepath):
            continue
        records = []
        with open(filepath, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    continue
                # Only JSON objects can carry the "domain" tag added below.
                if isinstance(record, dict):
                    records.append(record)
        total = len(records)
        if domain not in use_all and len(records) > cfg.sft_max_samples_per_domain:
            random.Random(42).shuffle(records)
            records = records[:cfg.sft_max_samples_per_domain]
        logger.info(f" {domain}: {len(records):,}/{total:,}")
        for r in records:
            r["domain"] = domain
        all_records.extend(records)
    if not all_records:
        all_records = [{"prompt": "What is 2+2?", "thinking": "4", "answer": "4", "domain": "math"}] * 100
    random.Random(42).shuffle(all_records)
    logger.info(f"Total SFT: {len(all_records):,}")
    return Dataset.from_list(all_records)
1072
def format_sft_text(example):
    """Render one SFT record as a single training string.

    Reads ``prompt``, ``thinking`` and ``answer`` from ``example``. Missing
    keys and explicit ``None`` values are treated as empty strings, and
    non-string values are converted with ``str``. Thinking text longer than
    3000 characters is shortened to its first and last 1500 characters joined
    by ``" ... "``; answers are cut to 2000 characters.

    Returns:
        ``{"text": ...}`` in the ``User: / Assistant: <think>…</think>
        <answer>…</answer>`` layout.
    """
    def _as_text(value):
        # A JSON null arrives as None: it would crash len() and would be
        # rendered as the literal text "None" inside the prompt.
        return "" if value is None else str(value)

    p = _as_text(example.get("prompt", ""))
    t = _as_text(example.get("thinking", ""))
    a = _as_text(example.get("answer", ""))
    if len(t) > 3000:
        t = t[:1500] + " ... " + t[-1500:]
    if len(a) > 2000:
        a = a[:2000]
    return {"text": f"User: {p}\n\nAssistant: <think>{t}</think>\n<answer>{a}</answer>"}
1081
def create_grpo_dataset(sft_dataset, cfg):
    """Derive the GRPO dataset (``prompt`` / ``solution`` columns) from SFT data.

    Each example's prompt is wrapped in the ``User: … Assistant:`` template
    and its ``answer`` becomes ``solution``. The result is capped at
    ``cfg.grpo_max_dataset_size`` rows via a seeded shuffle.
    """
    def _to_grpo(example):
        question = example.get('prompt', '')
        return {
            "prompt": f"User: {question}\n\nAssistant:",
            "solution": example.get("answer", ""),
        }

    ds = sft_dataset.map(_to_grpo)
    cap = cfg.grpo_max_dataset_size
    if len(ds) > cap:
        ds = ds.shuffle(seed=42).select(range(cap))
    logger.info(f"GRPO dataset: {len(ds):,}")
    return ds
1089
+ # ════════════════════════════════════════════════════════════════════════════
1090
+ # Reward Functions
1091
+ # ════════════════════════════════════════════════════════════════════════════
1092
def format_reward_func(completions, **kwargs):
    """Reward completions that contain ``<think>`` and ``<answer>`` sections.

    Scores: 1.0 when both tagged sections are present, plus 0.5 when the first
    ``<think>`` precedes the first ``<answer>``; 0.3 when only one is present;
    0.0 otherwise. A completion given as a list of message dicts is flattened
    by joining the ``content`` fields with spaces.

    Returns:
        One float per completion, in input order.
    """
    scores = []
    for completion in completions:
        if isinstance(completion, list):
            text = " ".join(
                msg.get("content", "") for msg in completion if isinstance(msg, dict)
            )
        else:
            text = str(completion)

        think_ok = re.search(r"<think>.*?</think>", text, re.DOTALL) is not None
        answer_ok = re.search(r"<answer>.*?</answer>", text, re.DOTALL) is not None

        if think_ok and answer_ok:
            score = 1.0
            if text.find("<think>") < text.find("<answer>"):
                score += 0.5
        elif think_ok or answer_ok:
            score = 0.3
        else:
            score = 0.0
        scores.append(score)
    return scores
1107
def length_reward_func(completions, **kwargs):
    """Reward completions by whitespace-delimited word count.

    Bands: 20-200 words -> 1.0; fewer than 10 -> 0.1; more than 300 -> 0.4;
    everything else (10-19 and 201-300 words) -> 0.7.

    Args:
        completions: Iterable of completions.  A list completion is treated as
            chat messages whose ``"content"`` values are joined with spaces;
            anything else is converted with ``str``.
        **kwargs: Ignored.

    Returns:
        List of float rewards, one per completion.
    """
    scores = []
    for completion in completions:
        if isinstance(completion, list):
            text = " ".join(
                msg.get("content", "") for msg in completion if isinstance(msg, dict)
            )
        else:
            text = str(completion)

        word_count = len(text.split())
        if 20 <= word_count <= 200:
            score = 1.0
        elif word_count < 10:
            score = 0.1
        elif word_count > 300:
            score = 0.4
        else:
            score = 0.7
        scores.append(score)
    return scores
1114
def reasoning_quality_reward_func(completions, **kwargs):
    """Reward the length and connective vocabulary of the <think> section.

    For the first ``<think>...</think>`` span of each completion:
        * +0.5 if it has at least 10 words, and a further +0.3 at 30 words;
        * +0.1 per indicator phrase found as a substring of the lower-cased
          reasoning, capped at +0.5.
    Completions without a ``<think>`` span score 0.0.

    Args:
        completions: Iterable of completions.  A list completion is treated as
            chat messages whose ``"content"`` values are joined with spaces;
            anything else is converted with ``str``.
        **kwargs: Ignored.

    Returns:
        List of float rewards, one per completion.
    """
    indicator_phrases = ["step", "first", "then", "therefore", "because",
                         "since", "thus", "let me", "we can"]
    scores = []
    for completion in completions:
        if isinstance(completion, list):
            text = " ".join(
                msg.get("content", "") for msg in completion if isinstance(msg, dict)
            )
        else:
            text = str(completion)

        score = 0.0
        think_match = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
        if think_match:
            reasoning = think_match.group(1).strip()
            n_words = len(reasoning.split())
            if n_words >= 10:
                score += 0.5
            if n_words >= 30:
                score += 0.3
            lowered = reasoning.lower()
            hits = sum(1 for phrase in indicator_phrases if phrase in lowered)
            score += min(hits * 0.1, 0.5)
        scores.append(score)
    return scores
1129
def repetition_penalty_reward_func(completions, **kwargs):
    """Penalise completions that repeat the same 4-word sequences.

    Completions whose stripped text is shorter than 5 characters score 0.0.
    Otherwise the score starts at 1.0 and, when there are at least 4 words,
    the share of duplicated 4-grams (``1 - unique / total``) is computed:
    above 0.6 costs 0.5, above 0.4 costs 0.3.  The result is floored at 0.0.

    Args:
        completions: Iterable of completions.  A list completion is treated as
            chat messages whose ``"content"`` values are joined with spaces;
            anything else is converted with ``str``.
        **kwargs: Ignored.

    Returns:
        List of float rewards, one per completion.
    """
    scores = []
    for completion in completions:
        if isinstance(completion, list):
            text = " ".join(
                msg.get("content", "") for msg in completion if isinstance(msg, dict)
            )
        else:
            text = str(completion)

        if len(text.strip()) < 5:
            scores.append(0.0)
            continue

        score = 1.0
        tokens = text.lower().split()
        if len(tokens) >= 4:
            four_grams = [tuple(tokens[start:start + 4]) for start in range(len(tokens) - 3)]
            repeat_ratio = 1.0 - len(set(four_grams)) / len(four_grams)
            if repeat_ratio > 0.6:
                score -= 0.5
            elif repeat_ratio > 0.4:
                score -= 0.3
        scores.append(max(score, 0.0))
    return scores
1144
def correctness_reward_func(completions, **kwargs):
    """Score each completion's <answer> against its ground-truth solution.

    Per completion:
        * 0.0 if no ground truth is available or no ``<answer>...</answer>``
          span is found;
        * 1.0 if prediction and ground truth both parse as floats (commas
          removed) and differ by less than 1e-6;
        * otherwise the token-level F1 between the lower-cased,
          whitespace-split prediction and ground truth (0.0 when they share
          no token).

    A ground truth counts as missing only when it is ``None`` or blank after
    ``str(...).strip()``; a numeric ``0`` is a valid solution.

    Args:
        completions: Iterable of completions.  A list completion is treated as
            chat messages whose ``"content"`` values are joined with spaces;
            anything else is converted with ``str``.
        **kwargs: ``solution`` may hold a sequence of ground truths aligned
            with ``completions``; absent, ``None`` or too short means the
            unmatched completions have no ground truth.

    Returns:
        List of float rewards in ``[0.0, 1.0]``, one per completion.
    """
    solutions = kwargs.get("solution")
    if solutions is None:
        # Also covers an explicit ``solution=None``.
        solutions = [None] * len(completions)

    rewards = []
    for index, completion in enumerate(completions):
        if isinstance(completion, list):
            text = " ".join(
                msg.get("content", "") for msg in completion if isinstance(msg, dict)
            )
        else:
            text = str(completion)

        truth = solutions[index] if index < len(solutions) else None
        # Compare on the stringified value so that a numeric 0 is not
        # mistaken for "no solution".
        truth_text = "" if truth is None else str(truth).strip()
        if not truth_text:
            rewards.append(0.0)
            continue

        answer_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
        if answer_match is None:
            rewards.append(0.0)
            continue
        prediction = answer_match.group(1).strip()

        # Numeric comparison first; non-numeric text falls through to F1.
        try:
            numerically_equal = abs(
                float(prediction.replace(",", "")) - float(truth_text.replace(",", ""))
            ) < 1e-6
        except ValueError:
            numerically_equal = False
        if numerically_equal:
            rewards.append(1.0)
            continue

        predicted_tokens = set(prediction.lower().split())
        truth_tokens = set(truth_text.lower().split())
        shared = predicted_tokens & truth_tokens
        if not shared:
            # Empty prediction/truth or disjoint vocabularies: F1 is zero.
            rewards.append(0.0)
            continue
        precision = len(shared) / len(predicted_tokens)
        recall = len(shared) / len(truth_tokens)
        rewards.append(min(2 * precision * recall / (precision + recall), 1.0))
    return rewards
1169
+ # ════════════════════════════════════════════════════════════════════════════
1170
+ # Phase 0: Pretraining
1171
+ # ════════════════════════════════════════════════════════════════════════════
1172
def run_pretraining(model, tokenizer, cfg, device, state_mgr):
    """Phase 0: pre-train ``model`` with ``SFTTrainer``.

    Loads the pre-training train/eval datasets, builds an ``SFTConfig`` (FP16,
    gradient checkpointing, cosine schedule), trains while resuming from the
    latest checkpoint under ``<output_dir>/pretrain`` when one is found, saves
    model and tokenizer to ``<output_dir>/pretrain_model``, removes the
    intermediate checkpoints and marks the ``"pretrain"`` stage complete.

    Args:
        model: Model to train (trained in place).
        tokenizer: Tokenizer handed to the trainer and saved next to the model.
        cfg: Pipeline configuration (``pretrain_*`` fields, ``output_dir``, ...).
        device: Not used by this function.
        state_mgr: Pipeline state manager used to record completion.

    Returns:
        The same ``model`` object.
    """
    logger.info("\n" + "=" * 70 + "\nPHASE 0: PRETRAINING\n" + "=" * 70)
    pretrain_train_ds, pretrain_eval_ds = load_pretrain_dataset(cfg, tokenizer)
    checkpoint_dir = os.path.join(cfg.output_dir, "pretrain")
    _ensure_clean_distributed_state()

    # Only pass optional SFTConfig arguments that this installed version accepts.
    accepted_config_args = set(inspect.signature(SFTConfig.__init__).parameters.keys())
    config_kwargs = {
        "output_dir": checkpoint_dir,
        "per_device_train_batch_size": cfg.pretrain_batch_size,
        "per_device_eval_batch_size": cfg.pretrain_batch_size * 2,
        "gradient_accumulation_steps": cfg.pretrain_grad_accum,
        "learning_rate": cfg.pretrain_lr,
        "max_steps": cfg.pretrain_max_steps,
        "warmup_steps": cfg.pretrain_warmup_steps,
        "weight_decay": cfg.pretrain_weight_decay,
        "logging_steps": cfg.pretrain_logging_steps,
        "save_steps": cfg.pretrain_save_steps,
        "eval_strategy": "steps",
        "eval_steps": cfg.pretrain_eval_steps,
        "save_total_limit": cfg.save_total_limit,
        # Half precision is FP16 rather than BF16.
        "bf16": False,
        "fp16": True,
        "report_to": "none",
        "gradient_checkpointing": True,
        "lr_scheduler_type": "cosine",
        "dataloader_num_workers": cfg.dataloader_num_workers,
        "dataloader_pin_memory": True,
        "load_best_model_at_end": False,
        "disable_tqdm": True,
    }
    if "max_seq_length" in accepted_config_args:
        config_kwargs["max_seq_length"] = cfg.max_seq_len
    if "dataset_text_field" in accepted_config_args:
        config_kwargs["dataset_text_field"] = "text"
    if "packing" in accepted_config_args:
        config_kwargs["packing"] = True
    sft_config = SFTConfig(**config_kwargs)
    model._set_gradient_checkpointing(True)

    best_model_cb = BestModelCallback(cfg.output_dir, "pretrain", tokenizer)
    validation_cb = ValidationLoggerCallback("Pretrain")
    callbacks = [
        PrintProgressCallback("Pretrain"),
        PipelineCallback(model, tokenizer, EVAL_PROMPTS, "Pretrain", cfg.pretrain_save_steps, 128),
        StorageMonitorCallback(cfg.output_dir),
        validation_cb,
        best_model_cb,
    ]

    trainer_kwargs = {
        "model": model,
        "args": sft_config,
        "train_dataset": pretrain_train_ds,
        "eval_dataset": pretrain_eval_ds,
        "callbacks": callbacks,
    }
    # The tokenizer keyword differs between SFTTrainer versions.
    accepted_trainer_args = set(inspect.signature(SFTTrainer.__init__).parameters.keys())
    if "processing_class" in accepted_trainer_args:
        trainer_kwargs["processing_class"] = tokenizer
    else:
        trainer_kwargs["tokenizer"] = tokenizer
    trainer = SFTTrainer(**trainer_kwargs)

    logger.info(f"Pretrain: single P100 | eff_batch={cfg.pretrain_batch_size * cfg.pretrain_grad_accum} | packing=ON")
    latest_checkpoint = find_latest_checkpoint(checkpoint_dir)
    trainer.train(resume_from_checkpoint=latest_checkpoint)

    enforce_storage_limit(cfg.output_dir, "save pretrain_model")
    model_dir = os.path.join(cfg.output_dir, "pretrain_model")
    trainer.save_model(model_dir)
    tokenizer.save_pretrained(model_dir)
    cleanup_checkpoints(checkpoint_dir, keep_last=0)
    state_mgr.mark_complete(
        "pretrain",
        best_eval_loss=validation_cb.best_eval_loss,
        steps=trainer.state.global_step,
    )

    del trainer
    torch.cuda.empty_cache()
    gc.collect()
    return model
1229
+ # ════════════════════════════════════════════════════════════════════════════
1230
+ # Phase 1: SFT
1231
+ # ════════════════════════════════════════════════════════════════════════════
1232
def run_sft(model, tokenizer, cfg, device, state_mgr):
    """Phase 1: supervised fine-tuning with ``SFTTrainer``.

    Loads the SFT dataset, splits it into train/eval (seed 42), renders both
    splits with ``format_sft_text``, trains (resuming from the latest
    checkpoint under ``<output_dir>/sft`` when one is found), saves model and
    tokenizer to ``<output_dir>/sft_model``, removes intermediate checkpoints
    and the ``pretrain_model`` directory, and marks the ``"sft"`` stage
    complete.

    Args:
        model: Model to fine-tune (trained in place).
        tokenizer: Tokenizer handed to the trainer and saved next to the model.
        cfg: Pipeline configuration (``sft_*`` fields, ``output_dir``, ...).
        device: Not used by this function.
        state_mgr: Pipeline state manager used to record completion.

    Returns:
        ``(model, raw_ds)`` where ``raw_ds`` is the unformatted SFT dataset.
    """
    logger.info("\n" + "=" * 70 + "\nPHASE 1: SFT\n" + "=" * 70)
    raw_ds = load_sft_dataset(cfg)
    splits = raw_ds.train_test_split(test_size=cfg.sft_eval_split, seed=42)
    train_split = splits["train"]
    train_ds = train_split.map(format_sft_text, remove_columns=train_split.column_names)
    eval_split = splits["test"]
    eval_ds = eval_split.map(format_sft_text, remove_columns=eval_split.column_names)
    logger.info(f"SFT Train: {len(train_ds):,} | Eval: {len(eval_ds):,}")

    checkpoint_dir = os.path.join(cfg.output_dir, "sft")
    _ensure_clean_distributed_state()

    # Only pass optional SFTConfig arguments that this installed version accepts.
    accepted_config_args = set(inspect.signature(SFTConfig.__init__).parameters.keys())
    config_kwargs = {
        "output_dir": checkpoint_dir,
        "per_device_train_batch_size": cfg.sft_batch_size,
        "per_device_eval_batch_size": cfg.sft_batch_size * 2,
        "gradient_accumulation_steps": cfg.sft_grad_accum,
        "learning_rate": cfg.sft_lr,
        "max_steps": cfg.sft_max_steps,
        "warmup_steps": cfg.sft_warmup_steps,
        "weight_decay": cfg.sft_weight_decay,
        "max_grad_norm": cfg.sft_max_grad_norm,
        "logging_steps": cfg.sft_logging_steps,
        "save_steps": cfg.sft_save_steps,
        "eval_strategy": "steps",
        "eval_steps": cfg.sft_eval_steps,
        "save_total_limit": cfg.save_total_limit,
        # Half precision is FP16 rather than BF16.
        "bf16": False,
        "fp16": True,
        "report_to": "none",
        "gradient_checkpointing": True,
        "lr_scheduler_type": "cosine",
        "load_best_model_at_end": False,
        "dataloader_num_workers": cfg.dataloader_num_workers,
        "dataloader_pin_memory": True,
        "disable_tqdm": True,
    }
    if "max_seq_length" in accepted_config_args:
        config_kwargs["max_seq_length"] = cfg.max_seq_len
    if "dataset_text_field" in accepted_config_args:
        config_kwargs["dataset_text_field"] = "text"
    if "packing" in accepted_config_args:
        config_kwargs["packing"] = True
    sft_config = SFTConfig(**config_kwargs)
    model._set_gradient_checkpointing(True)

    best_model_cb = BestModelCallback(cfg.output_dir, "sft", tokenizer)
    validation_cb = ValidationLoggerCallback("SFT")
    callbacks = [
        PrintProgressCallback("SFT"),
        PipelineCallback(model, tokenizer, EVAL_PROMPTS, "SFT", cfg.inference_every_steps, cfg.max_seq_len),
        StorageMonitorCallback(cfg.output_dir),
        validation_cb,
        best_model_cb,
    ]

    trainer_kwargs = {
        "model": model,
        "args": sft_config,
        "train_dataset": train_ds,
        "eval_dataset": eval_ds,
        "callbacks": callbacks,
    }
    # The tokenizer keyword differs between SFTTrainer versions.
    accepted_trainer_args = set(inspect.signature(SFTTrainer.__init__).parameters.keys())
    if "processing_class" in accepted_trainer_args:
        trainer_kwargs["processing_class"] = tokenizer
    else:
        trainer_kwargs["tokenizer"] = tokenizer
    trainer = SFTTrainer(**trainer_kwargs)

    logger.info(f"SFT: single P100 | eff_batch={cfg.sft_batch_size * cfg.sft_grad_accum} | "
                f"max_steps={cfg.sft_max_steps} | seq={cfg.max_seq_len} | packing=ON")
    latest_checkpoint = find_latest_checkpoint(checkpoint_dir)
    trainer.train(resume_from_checkpoint=latest_checkpoint)

    enforce_storage_limit(cfg.output_dir, "save sft_model")
    model_dir = os.path.join(cfg.output_dir, "sft_model")
    trainer.save_model(model_dir)
    tokenizer.save_pretrained(model_dir)
    cleanup_checkpoints(checkpoint_dir, keep_last=0)

    # The pre-training export is no longer needed once the SFT model is saved.
    stale_pretrain_dir = os.path.join(cfg.output_dir, "pretrain_model")
    if os.path.isdir(stale_pretrain_dir):
        shutil.rmtree(stale_pretrain_dir, ignore_errors=True)

    state_mgr.mark_complete(
        "sft",
        best_eval_loss=validation_cb.best_eval_loss,
        steps=trainer.state.global_step,
    )
    check_output_storage(cfg.output_dir)

    del trainer
    torch.cuda.empty_cache()
    gc.collect()
    return model, raw_ds
1298
+
1299
+
1300
+ # ════════════════════════════════════════════════════════════════════════════
1301
+ # Helper: monkey-patch TRL's create_model_from_path to prevent auto ref
1302
+ # model creation for custom model types (would fail with empty _name_or_path)
1303
+ # ════════════════════════════════════════════════════════════════════════════
1304
+ def _patch_trl_no_ref_model():
1305
+ """
1306
+ Returns a context-manager-like pair (patch, unpatch) that replaces
1307
+ trl.trainer.grpo_trainer.create_model_from_path with a no-op so that
1308
+ GRPOTrainer.__init__ skips automatic ref-model creation.
1309
+ """
1310
+ import trl.trainer.grpo_trainer as _grpo_mod
1311
+ _orig = getattr(_grpo_mod, "create_model_from_path", None)
1312
+ def _noop(*_args, **_kwargs):
1313
+ return None
1314
+ def patch():
1315
+ if _orig is not None:
1316
+ _grpo_mod.create_model_from_path = _noop
1317
+ def unpatch():
1318
+ if _orig is not None:
1319
+ _grpo_mod.create_model_from_path = _orig
1320
+ return patch, unpatch
1321
+
1322
+
1323
def _nuclear_gpu_cleanup(model, device):
    """Round-trip ``model`` through the CPU to release cached GPU memory.

    Sequence: move the model to the CPU, garbage-collect, empty the CUDA
    cache, synchronize, garbage-collect again, log free/total memory of GPU 0
    (in units of 1e9 bytes), move the model back to ``device``, re-enable
    gradient checkpointing, empty the cache once more and reset the
    distributed state.

    Args:
        model: Model exposing ``cpu``, ``to`` and ``_set_gradient_checkpointing``.
        device: Device the model is moved back to.
    """
    model.cpu()
    gc.collect()
    torch.cuda.empty_cache()

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.synchronize()
    gc.collect()
    if cuda_present:
        free_bytes, total_bytes = torch.cuda.mem_get_info(0)
        free_mem = free_bytes / 1e9
        total_mem = total_bytes / 1e9
        logger.info(f" After cleanup: {free_mem:.1f} / {total_mem:.1f} GB free")

    model.to(device)
    model._set_gradient_checkpointing(True)
    torch.cuda.empty_cache()
    _ensure_clean_distributed_state()
1339
+
1340
+
1341
+ # ════════════════════════════════════════════════════════════════════════════
1342
+ # Phase 2: GRPO β€” single GPU, robust OOM fallback
1343
+ # ════════════════════════════════════════════════════════════════════════════
1344
def _release_grpo_trainer(trainer):
    """Best-effort: move a trainer's reference model (if any) to the CPU."""
    if trainer is None:
        return
    try:
        ref = getattr(trainer, "ref_model", None)
        if ref is not None:
            ref.cpu()
    except Exception as exc:
        # Cleanup must never mask the OOM recovery it is part of.
        logger.warning(f"Could not offload trainer ref_model: {exc}")


def _build_grpo_trainer(trainer_kwargs, patch_fn, unpatch_fn, suppress_ref_model):
    """Construct a GRPOTrainer, optionally preventing/stripping its ref model."""
    if suppress_ref_model:
        patch_fn()
    try:
        trainer = GRPOTrainer(**trainer_kwargs)
    finally:
        unpatch_fn()
    # Belt and suspenders: drop a ref model the trainer created anyway.
    if suppress_ref_model and getattr(trainer, "ref_model", None) is not None:
        trainer.ref_model.cpu()
        del trainer.ref_model
        trainer.ref_model = None
        torch.cuda.empty_cache()
    return trainer


def run_grpo(model, tokenizer, raw_sft_ds, cfg, device, state_mgr,
             ref_model_path: str = None):
    """Phase 2: GRPO training with up to two out-of-memory fallbacks.

    Attempt 1 uses the configured GRPO settings and all five reward
    functions, resuming from the latest checkpoint under
    ``<output_dir>/grpo`` when one exists.  If training raises
    ``torch.cuda.OutOfMemoryError`` the trainer and any reference model are
    released, GPU memory is purged, and progressively smaller configurations
    are tried (batch size 1, 2 generations, beta 0, three reward functions;
    first prompt=96/completion=128, then prompt=64/completion=64 with
    ``max_steps=500``).  If every attempt runs out of memory, the current
    model is saved as the final model instead.

    Recovery is performed *after* each ``except`` clause has finished rather
    than inside it, so the failed step's traceback is no longer referenced
    while memory is being reclaimed and re-allocated.

    Args:
        model: Model to train (trained in place).
        tokenizer: Tokenizer; its ``padding_side`` is set to ``"left"``.
        raw_sft_ds: Raw SFT dataset from which the GRPO dataset is derived.
        cfg: Pipeline configuration (``grpo_*`` fields, ``output_dir``, ...).
        device: Device used for the reference model and GPU cleanup.
        state_mgr: Pipeline state manager used to record completion.
        ref_model_path: Directory (or a file inside it) holding reference
            weights; defaults to ``<output_dir>/sft_model``.

    Returns:
        The same ``model`` object.

    Raises:
        FileNotFoundError: If the reference model path does not exist.
    """
    logger.info("\n" + "=" * 70 + "\nPHASE 2: GRPO\n" + "=" * 70)
    grpo_ds = create_grpo_dataset(raw_sft_ds, cfg)
    tokenizer.padding_side = "left"
    grpo_path = os.path.join(cfg.output_dir, "grpo")
    _ensure_clean_distributed_state()
    grpo_params = set(inspect.signature(GRPOConfig.__init__).parameters.keys())
    base_args = dict(
        output_dir=grpo_path,
        per_device_train_batch_size=cfg.grpo_batch_size,
        gradient_accumulation_steps=cfg.grpo_grad_accum,
        learning_rate=cfg.grpo_lr,
        max_steps=cfg.grpo_max_steps,
        logging_steps=cfg.grpo_logging_steps,
        save_steps=cfg.grpo_save_steps,
        warmup_steps=cfg.grpo_warmup_steps,
        weight_decay=cfg.grpo_weight_decay,
        max_grad_norm=cfg.grpo_max_grad_norm,
        bf16=False, fp16=True,  # half precision is FP16 rather than BF16
        report_to="none",
        save_total_limit=cfg.save_total_limit,
        gradient_checkpointing=True,
        lr_scheduler_type="cosine",
        dataloader_pin_memory=True,
        disable_tqdm=True,
    )

    def _make_config(arg_overrides, num_generations, max_completion, max_prompt, beta):
        # Optional arguments are only passed when this GRPOConfig accepts them.
        opt = {}
        if "num_generations" in grpo_params:
            opt["num_generations"] = num_generations
        if "max_completion_length" in grpo_params:
            opt["max_completion_length"] = max_completion
        if "max_prompt_length" in grpo_params:
            opt["max_prompt_length"] = max_prompt
        if "beta" in grpo_params:
            opt["beta"] = beta
        if "remove_unused_columns" in grpo_params:
            opt["remove_unused_columns"] = False
        return GRPOConfig(**{**base_args, **arg_overrides}, **opt)

    grpo_config = _make_config({}, cfg.grpo_num_generations, cfg.grpo_max_completion_length,
                               cfg.grpo_max_prompt_length, cfg.grpo_beta)
    model._set_gradient_checkpointing(True)

    # -- Validate ref_model_path --
    if ref_model_path is None:
        ref_model_path = os.path.join(cfg.output_dir, "sft_model")
    if not os.path.isdir(ref_model_path):
        if os.path.isfile(ref_model_path):
            ref_model_path = os.path.dirname(ref_model_path)
        else:
            raise FileNotFoundError(f"Reference model path does not exist: {ref_model_path}")
    logger.info(f"Reference model weights dir: {ref_model_path}")

    # Give the config a non-empty _name_or_path so TRL's model-id lookup has
    # a valid path for this custom model type.
    model.config._name_or_path = ref_model_path

    # -- Reference model: load it only if this GRPOTrainer accepts one --
    gi = set(inspect.signature(GRPOTrainer.__init__).parameters.keys())
    ref_model_accepted = "ref_model" in gi
    ref_model = None
    if ref_model_accepted:
        ref_device = torch.device(device)
        ref_model = load_checkpoint_robust(model.config, ref_model_path, device=ref_device)
        ref_model.eval()
        for p in ref_model.parameters():
            p.requires_grad = False
        torch.cuda.empty_cache()
        logger.info(f"Reference model loaded on {ref_device} (beta={cfg.grpo_beta})")
    else:
        # No ref_model parameter: run reward-only GRPO (beta=0) and skip the
        # trainer's automatic reference-model creation to save memory.
        grpo_config.beta = 0.0
        logger.info(" ref_model not accepted as param β€” set beta=0.0 "
                    "(reward-only GRPO, no KL penalty)")

    best_cb = BestModelCallback(cfg.output_dir, "grpo", tokenizer)
    cbs = [PrintProgressCallback("GRPO"),
           PipelineCallback(model, tokenizer, EVAL_PROMPTS, "GRPO",
                            cfg.inference_every_steps, cfg.grpo_max_completion_length),
           StorageMonitorCallback(cfg.output_dir, 200), best_cb]
    reward_funcs = [format_reward_func, length_reward_func, reasoning_quality_reward_func,
                    repetition_penalty_reward_func, correctness_reward_func]

    gk = dict(model=model, args=grpo_config, train_dataset=grpo_ds,
              reward_funcs=reward_funcs, callbacks=cbs)
    if ref_model_accepted and ref_model is not None:
        gk["ref_model"] = ref_model
    if "processing_class" in gi:
        gk["processing_class"] = tokenizer
    elif "tokenizer" in gi:
        gk["tokenizer"] = tokenizer
    else:
        gk["processing_class"] = tokenizer

    patch_fn, unpatch_fn = _patch_trl_no_ref_model()
    grpo_trainer = _build_grpo_trainer(gk, patch_fn, unpatch_fn,
                                       suppress_ref_model=not ref_model_accepted)

    logger.info(f"GRPO: single P100 | eff_batch={cfg.grpo_batch_size * cfg.grpo_grad_accum} | "
                f"prompt={cfg.grpo_max_prompt_length} | completion={cfg.grpo_max_completion_length}")
    resume = find_latest_checkpoint(grpo_path)
    grpo_succeeded = False

    # -- Attempt 1: normal settings --
    hit_oom = False
    try:
        grpo_trainer.train(resume_from_checkpoint=resume)
        grpo_succeeded = True
    except torch.cuda.OutOfMemoryError:
        # Only record the failure here; recovery runs below, once the
        # exception (and the frames its traceback references) is released.
        hit_oom = True

    if hit_oom:
        logger.warning("OOM during GRPO attempt 1! Performing full GPU memory reset...")

        # Free the trainer and the reference model completely.
        _release_grpo_trainer(grpo_trainer)
        grpo_trainer = None
        if ref_model is not None:
            ref_model.cpu()
            ref_model = None
        gk.pop("ref_model", None)

        # Fallbacks use a reduced reward set and no reference model.
        gk["reward_funcs"] = [format_reward_func, length_reward_func, correctness_reward_func]
        fallback_plans = [
            dict(tag="GRPO-fallback2",
                 announce=" Fallback attempt 2: bs=1, prompt=96, completion=128, beta=0.0",
                 arg_overrides={"per_device_train_batch_size": 1,
                                "gradient_accumulation_steps": 16},
                 max_completion=128, max_prompt=96,
                 oom_message="OOM on attempt 2! Trying minimal config (prompt=64, completion=64)..."),
            dict(tag="GRPO-fallback3",
                 announce=None,
                 arg_overrides={"per_device_train_batch_size": 1,
                                "gradient_accumulation_steps": 16,
                                "max_steps": 500},
                 max_completion=64, max_prompt=64,
                 oom_message="OOM on attempt 3! Skipping GRPO β€” saving current model as final."),
        ]

        for plan in fallback_plans:
            _nuclear_gpu_cleanup(model, device)
            if plan["announce"]:
                logger.info(plan["announce"])

            fallback_config = _make_config(plan["arg_overrides"], 2,
                                           plan["max_completion"], plan["max_prompt"], 0.0)
            # Force beta=0 even if the constructor did not take the argument.
            fallback_config.beta = 0.0

            gk["model"] = model
            gk["args"] = fallback_config
            gk["callbacks"] = [PrintProgressCallback(plan["tag"]),
                               StorageMonitorCallback(cfg.output_dir, 200), best_cb]

            hit_oom = False
            try:
                grpo_trainer = _build_grpo_trainer(gk, patch_fn, unpatch_fn,
                                                   suppress_ref_model=True)
                grpo_trainer.args.beta = 0.0
                grpo_trainer.train()
                grpo_succeeded = True
            except torch.cuda.OutOfMemoryError:
                hit_oom = True

            if not hit_oom:
                break
            logger.error(plan["oom_message"])
            _release_grpo_trainer(grpo_trainer)
            grpo_trainer = None

    # -- Save final model --
    final_path = os.path.join(cfg.output_dir, "final_model")
    os.makedirs(final_path, exist_ok=True)
    enforce_storage_limit(cfg.output_dir, "save final_model")
    if grpo_succeeded:
        grpo_trainer.save_model(final_path)
        tokenizer.save_pretrained(final_path)
        state_mgr.mark_complete("grpo", steps=grpo_trainer.state.global_step)
        del grpo_trainer
    else:
        # GRPO could not run: save the incoming model as "final".
        save_model = model.module if hasattr(model, "module") else model
        save_model.save_pretrained(final_path)
        tokenizer.save_pretrained(final_path)
        state_mgr.mark_complete("grpo", steps=0)

    cleanup_checkpoints(grpo_path, keep_last=0)
    local_sft_model = os.path.join(cfg.output_dir, "sft_model")
    if os.path.isdir(local_sft_model):
        shutil.rmtree(local_sft_model, ignore_errors=True)
    check_output_storage(cfg.output_dir)
    torch.cuda.empty_cache()
    gc.collect()
    return model
1595
+ # ════════════════════════════════════════════════════════════════════════════
1596
+ # Final Evaluation
1597
+ # ════════════════════════════════════════════════════════════════════════════
1598
def run_final_eval(model, tokenizer, cfg, device):
    """Generate and log answers for a fixed set of prompts, grouped by domain.

    The model is switched to eval mode.  Each prompt is wrapped in the
    ``User: ...\\n\\nAssistant:`` template, tokenized (truncated to
    ``cfg.max_seq_len``), passed to ``model.simple_generate`` with
    temperature 0.7 and top-p 0.9, and the newly generated tokens are decoded
    and logged.

    Args:
        model: Model exposing ``eval`` and ``simple_generate``.
        tokenizer: Tokenizer used for encoding and decoding.
        cfg: Pipeline configuration; only ``max_seq_len`` is read.
        device: Device the tokenized inputs are moved to.
    """
    logger.info("\n" + "=" * 70 + "\nFINAL EVALUATION\n" + "=" * 70)
    model.eval()

    prompts_by_domain = {
        "math": [
            "A store sells notebooks for $3 each. Buy 5+ get 20% off. How much do 7 cost?",
            "What is the sum of all integers from 1 to 100?",
        ],
        "coding": ["Write a Python function to compute factorial using recursion."],
        "conversation": ["Explain the greenhouse effect and its role in climate change."],
        "reasoning": ["If it rains, the ground gets wet. The ground is wet. Did it necessarily rain?"],
        "greetings": ["Good morning! What can you do?"],
    }

    for domain, questions in prompts_by_domain.items():
        logger.info(f"\n--- [{domain.upper()}] ---")
        for prompt in questions:
            formatted = f"User: {prompt}\n\nAssistant:"
            encoded = tokenizer(
                formatted,
                return_tensors="pt",
                truncation=True,
                max_length=cfg.max_seq_len,
            ).to(device)
            prompt_ids = encoded["input_ids"]
            with torch.no_grad():
                generated = model.simple_generate(
                    prompt_ids,
                    max_new_tokens=cfg.max_seq_len,
                    temperature=0.7,
                    top_p=0.9,
                    eos_token_id=tokenizer.eos_token_id,
                )
            # Decode only the tokens that follow the prompt.
            new_tokens = generated[0][prompt_ids.shape[1]:]
            response = tokenizer.decode(new_tokens, skip_special_tokens=True)
            logger.info(f" Q: {prompt}")
            logger.info(f" A: {response}")
1622
+ # ════════════════════════════════════════════════════════════════════════════
1623
+ # Main
1624
+ # ════════════════════════════════════════════════════════════════════════════
1625
def main():
    """Run the full pipeline: (pretrain) -> (SFT) -> GRPO -> final evaluation.

    Stages are skipped when the ``SKIP_PRETRAIN`` / ``SKIP_SFT`` environment
    variables are ``"1"`` (the default) or when the pipeline state already
    marks them complete.  The starting weights are chosen in this order:

    1. the Kaggle input SFT model, when SFT is skipped and that directory exists;
    2. the locally saved ``<output_dir>/sft_model``, when SFT is skipped and it
       exists (so a resumed run continues from its own SFT result);
    3. ``<output_dir>/pretrain_model``, when pre-training is skipped and it exists;
    4. otherwise a freshly initialised model.

    Returns:
        ``(model, tokenizer)`` after the final evaluation.
    """
    train_start = time.time()
    logger.info("\n" + "=" * 70)
    logger.info("HybridMoRMoE Full Pipeline β€” P100 x1 (SINGLE GPU, 16 GB)")
    logger.info("=" * 70)
    has_gpu, num_gpus = setup_gpu()
    device = "cuda" if has_gpu else "cpu"

    cfg = PipelineConfig()
    cfg.model_size = os.environ.get("MODEL_SIZE", "medium")
    cfg.skip_pretrain = os.environ.get("SKIP_PRETRAIN", "1") == "1"
    cfg.skip_sft = os.environ.get("SKIP_SFT", "1") == "1"
    cfg.num_gpus = 1  # the pipeline always runs on a single GPU
    if IS_KAGGLE:
        cfg.output_dir = "/kaggle/working/hybrid_mor_moe_P100"
        cfg.sft_data_dir = "/kaggle/input/datasets/abhishekgandhiau/sft-dataset-v1"
        cfg.pretrain_corpus = "/kaggle/input/pretraining-corpus/pretraining_corpus.jsonl"
        cfg.tokenizer_path = "/kaggle/input/qwen-tokenizer/Qwen2.5-0.5B-Instruct"
    cfg = adjust_config_for_model_size(cfg)
    os.makedirs(cfg.output_dir, exist_ok=True)

    # Stages recorded as complete by an earlier run are skipped.
    state_mgr = PipelineStateManager(cfg.output_dir)
    logger.info(f" [Checkpoint] {state_mgr.summary()}")
    if state_mgr.is_complete("pretrain"):
        cfg.skip_pretrain = True
        logger.info(" [Checkpoint] pretrain already done β†’ skip")
    if state_mgr.is_complete("sft"):
        cfg.skip_sft = True
        logger.info(" [Checkpoint] sft already done β†’ skip")
    logger.info(f"Model: {cfg.model_size} | GPU: P100 x1 | Seq: {cfg.max_seq_len} | SFT steps: {cfg.sft_max_steps}")
    logger.info(f"Data: pretrain={cfg.pretrain_max_samples//1000}K sft/dom={cfg.sft_max_samples_per_domain//1000}K "
                f"grpo={cfg.grpo_max_dataset_size//1000}K")
    logger.info(f"Skip pretrain: {cfg.skip_pretrain} | Skip SFT: {cfg.skip_sft}")
    check_output_storage(cfg.output_dir)

    # -- Tokenizer --
    if os.path.isdir(cfg.tokenizer_path):
        tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_path, trust_remote_code=True, local_files_only=True)
    else:
        tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_hf_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    model_config = HybridMoRMoEConfig(
        model_size=cfg.model_size, max_seq_len=cfg.max_seq_len, dropout=cfg.dropout)
    model_config.vocab_size = len(tokenizer)

    # -- Candidate locations for existing weights --
    INPUT_SFT_MODEL_DIR = "/kaggle/input/models/abhishekgandhiau/hybrid-mor-moe/transformers/default/1"
    local_sft_model_path = os.path.join(cfg.output_dir, "sft_model")
    pretrain_model_path = os.path.join(cfg.output_dir, "pretrain_model")
    raw_sft_ds = None
    ref_model_path_for_grpo = None

    # -- Model loading --
    if cfg.skip_sft and os.path.isdir(INPUT_SFT_MODEL_DIR):
        logger.info(f"Loading existing SFT model: {INPUT_SFT_MODEL_DIR}")
        model = load_checkpoint_robust(model_config, INPUT_SFT_MODEL_DIR, device=device)
        raw_sft_ds = load_sft_dataset(cfg)
        ref_model_path_for_grpo = INPUT_SFT_MODEL_DIR
    elif cfg.skip_sft and os.path.isdir(local_sft_model_path):
        # Resume after a finished SFT stage: run_sft has already removed
        # pretrain_model, so without this branch GRPO would start from a
        # freshly initialised model.
        logger.info(f"Loading existing SFT model: {local_sft_model_path}")
        model = load_checkpoint_robust(model_config, local_sft_model_path, device=device)
        ref_model_path_for_grpo = local_sft_model_path
    elif cfg.skip_pretrain and os.path.isdir(pretrain_model_path):
        logger.info(f"Loading existing pretrain model: {pretrain_model_path}")
        model = load_checkpoint_robust(model_config, pretrain_model_path, device=device)
    else:
        model = HybridMoRMoEForCausalLM(model_config)
        model.to(device)
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Model: {total_params:,} params ({total_params/1e6:.1f}M)")

    # Smoke test: one forward pass on random token ids before any training.
    with torch.no_grad():
        test_ids = torch.randint(0, model_config.vocab_size, (2, 32), device=device)
        test_out = model(test_ids, labels=test_ids, return_dict=True)
    logger.info(f"Forward pass OK, loss={test_out.loss.item():.4f}")
    del test_ids, test_out
    torch.cuda.empty_cache()

    # -- Phase 0: Pretrain --
    if not cfg.skip_pretrain:
        model = run_pretraining(model, tokenizer, cfg, device, state_mgr)
    else:
        logger.info("\nPHASE 0: PRETRAINING β€” SKIPPED")
        ckpt_dir = os.path.join(cfg.output_dir, "pretrain")
        if os.path.isdir(ckpt_dir):
            cleanup_checkpoints(ckpt_dir, keep_last=0)

    # -- Phase 1: SFT --
    if not cfg.skip_sft:
        model, raw_sft_ds = run_sft(model, tokenizer, cfg, device, state_mgr)
        ref_model_path_for_grpo = os.path.join(cfg.output_dir, "sft_model")
    else:
        logger.info("\nPHASE 1: SFT β€” SKIPPED")
        if raw_sft_ds is None:
            raw_sft_ds = load_sft_dataset(cfg)

    # -- Phase 2: GRPO --
    model = run_grpo(model, tokenizer, raw_sft_ds, cfg, device, state_mgr,
                     ref_model_path=ref_model_path_for_grpo)
    run_final_eval(model, tokenizer, cfg, device)
    check_output_storage(cfg.output_dir)

    total_time = time.time() - train_start
    logger.info("\n" + "=" * 70)
    logger.info("PIPELINE COMPLETE!")
    logger.info(f" Model: {cfg.model_size} ({total_params/1e6:.1f}M) | GPU: P100 x1")
    logger.info(f" Wall time: {total_time/3600:.2f}h")
    logger.info(f" {state_mgr.summary()}")
    logger.info(f" Final model: {os.path.join(cfg.output_dir, 'final_model')}")
    logger.info("=" * 70)
    return model, tokenizer
1722
+ if __name__ == "__main__":
1723
+ main()
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ac8da7c6d9840b73e97532c92ca59936ebeb2917f26df401e3eafe0a91553dd
3
+ size 1176801876
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3fd169731d2cbde95e10bf356d66d5997fd885dd8dbb6fb4684da3f23b2585d8
3
+ size 11421892
tokenizer_config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "backend": "tokenizers",
4
+ "bos_token": null,
5
+ "clean_up_tokenization_spaces": false,
6
+ "eos_token": "<|im_end|>",
7
+ "errors": "replace",
8
+ "extra_special_tokens": [
9
+ "<|im_start|>",
10
+ "<|im_end|>",
11
+ "<|object_ref_start|>",
12
+ "<|object_ref_end|>",
13
+ "<|box_start|>",
14
+ "<|box_end|>",
15
+ "<|quad_start|>",
16
+ "<|quad_end|>",
17
+ "<|vision_start|>",
18
+ "<|vision_end|>",
19
+ "<|vision_pad|>",
20
+ "<|image_pad|>",
21
+ "<|video_pad|>"
22
+ ],
23
+ "is_local": false,
24
+ "model_max_length": 131072,
25
+ "pad_token": "<|endoftext|>",
26
+ "split_special_tokens": false,
27
+ "tokenizer_class": "Qwen2Tokenizer",
28
+ "unk_token": null
29
+ }