stmasson committed on
Commit
285df1b
·
verified ·
1 Parent(s): e81fa0f

Upload scripts/train_alizee_v2_stage3_merge.py with huggingface_hub

Browse files
scripts/train_alizee_v2_stage3_merge.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # /// script
3
+ # dependencies = [
4
+ # "peft>=0.14.0",
5
+ # "transformers>=4.48.0",
6
+ # "accelerate>=0.35.0",
7
+ # "torch>=2.2.0",
8
+ # "huggingface_hub>=0.25.0",
9
+ # ]
10
+ # ///
11
+
12
+ """
13
+ Stage 3: Adapter Merging and Final Model Publication
14
+
15
+ Merges LoRA adapters from Stage 1 (and optionally Stage 2) into the base model
16
+ and pushes the final merged model as alizee-coder-devstral-2-small.
17
+
18
+ Options:
19
+ 1. Merge Stage 1 only (if skipping DPO)
20
+ 2. Merge Stage 1 + Stage 2 (if DPO was applied)
21
+ """
22
+
23
+ import os
24
+ import torch
25
+ from peft import PeftModel, AutoPeftModelForCausalLM
26
+ from transformers import AutoTokenizer, AutoModelForCausalLM
27
+ from huggingface_hub import HfApi, create_repo
28
+
29
+ # Configuration
30
+ BASE_MODEL = "mistralai/Devstral-Small-2505"
31
+ STAGE1_MODEL = "stmasson/alizee-coder-devstral-2-small-stage1"
32
+ STAGE2_MODEL = "stmasson/alizee-coder-devstral-2-small-stage2" # Optional
33
+ FINAL_REPO = "stmasson/alizee-coder-devstral-2-small"
34
+
35
+ # Set this based on whether you ran Stage 2
36
+ USE_STAGE2 = os.environ.get("USE_STAGE2", "false").lower() == "true"
37
+
38
+ print("=" * 60)
39
+ print("Stage 3: Adapter Merging and Final Model Publication")
40
+ print("=" * 60)
41
+ print(f"Base model: {BASE_MODEL}")
42
+ print(f"Stage 1 adapter: {STAGE1_MODEL}")
43
+ print(f"Stage 2 adapter: {STAGE2_MODEL if USE_STAGE2 else 'SKIPPED'}")
44
+ print(f"Final output: {FINAL_REPO}")
45
+ print("=" * 60)
46
+
# Determine which checkpoint carries the adapters to merge: Stage 2 if DPO
# was run, otherwise the Stage 1 SFT adapters.
if USE_STAGE2:
    source_model = STAGE2_MODEL
else:
    source_model = STAGE1_MODEL
print(f"\n๐Ÿ”„ Loading model from: {source_model}")

# Tokenizer comes from the same checkpoint as the adapters.
print("\n๐Ÿ“ Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(source_model, trust_remote_code=True)

print("\n๐Ÿ”— Loading and merging adapters...")
print(" This may take several minutes for a 24B model...")

# Loading options shared by both loading strategies below.
_load_kwargs = dict(
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

# Method 1: the checkpoint was saved as a PEFT model (adapter_config.json
# present), so AutoPeftModel can resolve base + adapters in one call.
try:
    model = AutoPeftModelForCausalLM.from_pretrained(source_model, **_load_kwargs)
    print(" Merging LoRA weights into base model...")
    model = model.merge_and_unload()
    print(" โœ“ Adapters merged successfully")
except Exception as e:
    # Best-effort fallback: report the failure and retry with the
    # explicit two-step load rather than aborting the merge.
    print(f" AutoPeftModel failed: {e}")
    print(" Trying alternative loading method...")

    # Method 2: load the base model first, then attach the adapters.
    base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, **_load_kwargs)

    model = PeftModel.from_pretrained(
        base_model,
        source_model,
        torch_dtype=torch.bfloat16,
    )
    print(" Merging LoRA weights into base model...")
    model = model.merge_and_unload()
    print(" โœ“ Adapters merged successfully")
# Make sure the destination repository exists before pushing.
print(f"\n๐Ÿ“ Creating repository: {FINAL_REPO}")
api = HfApi()
try:
    create_repo(FINAL_REPO, repo_type="model", exist_ok=True)
except Exception as e:
    # exist_ok already tolerates an existing repo; this catches other
    # Hub-side errors without aborting the (long) merge run.
    print(f" Repository exists or error: {e}")

# Upload the merged weights and the tokenizer to the Hub.
print("\n๐Ÿ’พ Pushing merged model to Hub...")
print(" This will take a while for a 24B model...")

_model_commit = "Alizee-Coder-Devstral-2-Small: Reasoning-enhanced coding model"
model.push_to_hub(
    FINAL_REPO,
    commit_message=_model_commit,
    safe_serialization=True,
)

tokenizer.push_to_hub(FINAL_REPO, commit_message="Add tokenizer")
# Create and upload the model card (README.md with YAML frontmatter).
print("\n๐Ÿ“„ Creating model card...")

# Build the datasets list conditionally. The previous inline conditional
# (`- {"..." if USE_STAGE2 else ""}`) emitted a dangling empty "- " YAML
# list item when USE_STAGE2 was false, which produces invalid frontmatter
# metadata on the Hub.
_dataset_lines = [
    "- nvidia/OpenCodeReasoning",
    "- bigcode/starcoderdata",
]
if USE_STAGE2:
    _dataset_lines.append("- RLHFlow/CodeUltraFeedback-standard")
datasets_block = "\n".join(_dataset_lines)

# Optional Stage 2 section: emit a complete markdown section when DPO was
# run, and nothing (no stray blank lines) when it was skipped.
stage2_section = (
    "### Stage 2: Light DPO Refresh\n"
    "- **Dataset**: RLHFlow/CodeUltraFeedback-standard\n"
    "- **Method**: Conservative DPO (beta=0.1, lr=5e-6)\n"
    "- **Purpose**: Restore alignment after reasoning SFT\n"
    if USE_STAGE2
    else ""
)

model_card = f"""---
license: apache-2.0
base_model: {BASE_MODEL}
tags:
- code
- reasoning
- devstral
- fine-tuned
- qlora
datasets:
{datasets_block}
pipeline_tag: text-generation
---

# Alizee-Coder-Devstral-2-Small

A reasoning-enhanced coding model fine-tuned from [stmasson/alizee-coder-devstral-1-small](https://huggingface.co/stmasson/alizee-coder-devstral-1-small).

## Training Pipeline

This model was trained using a {"three" if USE_STAGE2 else "two"}-stage approach:

### Stage 1: Reasoning Distillation via SFT
- **Dataset**: nvidia/OpenCodeReasoning (736K samples) + bigcode/starcoderdata (15% mix)
- **Method**: QLoRA (r=64, alpha=128)
- **Config**: lr=5e-5, batch_size=256, epochs=2, warmup=5%, cosine scheduler
- **Context**: 32K tokens

{stage2_section}
### Stage 3: Adapter Merging
- Merged LoRA adapters into base model
- Full precision model for inference

## Model Details

| Parameter | Value |
|-----------|-------|
| **Base Model** | mistralai/Devstral-Small-2505 |
| **Parameters** | ~24B |
| **Architecture** | Mistral Small |
| **Context Length** | 32K tokens |
| **Training Data** | ~860K samples (736K reasoning + 124K coding) |

## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "stmasson/alizee-coder-devstral-2-small",
    torch_dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("stmasson/alizee-coder-devstral-2-small")

messages = [
    {{"role": "user", "content": "Solve this problem step by step: Find the longest palindromic substring."}}
]

inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
outputs = model.generate(inputs, max_new_tokens=2048)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

## Training Methodology

Based on NVIDIA's OpenCodeReasoning research findings:
- Performance improves linearly from 25K to 736K samples
- Execution filtering is crucial - only solutions that pass tests
- Data mixing (85% reasoning, 15% coding) preserves capabilities
- QLoRA enables efficient training of 24B models

## Citation

```bibtex
@misc{{alizee-coder-v2,
  author = {{stmasson}},
  title = {{Alizee-Coder-Devstral-2-Small}},
  year = {{2025}},
  publisher = {{Hugging Face}},
  howpublished = {{\\url{{https://huggingface.co/stmasson/alizee-coder-devstral-2-small}}}}
}}
```
"""

# Upload the rendered card as README.md in the final repo.
api.upload_file(
    path_or_fileobj=model_card.encode(),
    path_in_repo="README.md",
    repo_id=FINAL_REPO,
    repo_type="model",
)
# Closing banner: summary of the published model plus suggested follow-ups.
_closing = [
    "\n" + "=" * 60,
    "โœ… Stage 3 Complete!",
    f" Final model: https://huggingface.co/{FINAL_REPO}",
    "=" * 60,
    "\n๐ŸŽฏ Recommended next steps:",
    " 1. Evaluate on LiveCodeBench, HumanEval, SWE-Bench",
    " 2. Compare with v1 baseline",
    " 3. Test reasoning quality on sample problems",
    "=" * 60,
]
print("\n".join(_closing))