moos124 commited on
Commit
e96fb55
·
verified ·
1 Parent(s): ba076fc

Upload train_code_reasoning.py

Browse files
Files changed (1) hide show
  1. train_code_reasoning.py +221 -0
train_code_reasoning.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = [
3
+ # "transformers",
4
+ # "trl",
5
+ # "datasets",
6
+ # "accelerate",
7
+ # "torch",
8
+ # "trackio",
9
+ # "huggingface_hub",
10
+ # "peft",
11
+ # ]
12
+ # ///
13
+
14
+ import os
15
+ import random
16
+ from datasets import load_dataset, concatenate_datasets
17
+ from transformers import AutoTokenizer
18
+ from trl import SFTTrainer, SFTConfig
19
+ from peft import LoraConfig, TaskType
20
+ import trackio
21
+
22
+ # Configuration - smaller model to fit in A10G 24GB VRAM comfortably
23
+ MODEL_ID = "Qwen/Qwen2.5-Coder-0.5B-Instruct"
24
+ HUB_MODEL_ID = "moos124/code-reasoning-0.5b"
25
+ OUTPUT_DIR = "./code-reasoning-0.5b"
26
+
27
+ # Initialize Trackio
28
+ trackio.init(project="code-reasoning-ft", name="qwen2.5-coder-0.5b-code-reasoning")
29
+
30
+ # Load tokenizer
31
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
32
+
33
+ print("Loading and preparing datasets...")
34
+
35
+ all_datasets = []
36
+
37
+ # ============= DATASET 1: CodeAlpaca (Python code instructions) =============
38
+ try:
39
+ codealpaca = load_dataset("sahil2801/CodeAlpaca-20k", split="train")
40
+ def format_codealpaca(ex):
41
+ instruction = ex["instruction"]
42
+ inp = ex.get("input", "")
43
+ output = ex["output"]
44
+ if inp and str(inp).strip():
45
+ user_content = f"{instruction}\n\nInput: {inp}"
46
+ else:
47
+ user_content = instruction
48
+ return {"messages": [
49
+ {"role": "user", "content": user_content},
50
+ {"role": "assistant", "content": output}
51
+ ]}
52
+ codealpaca = codealpaca.map(format_codealpaca, remove_columns=codealpaca.column_names)
53
+ if len(codealpaca) > 15000:
54
+ codealpaca = codealpaca.select(range(15000))
55
+ all_datasets.append(codealpaca)
56
+ print(f"CodeAlpaca: {len(codealpaca)} examples")
57
+ except Exception as e:
58
+ print(f"CodeAlpaca: skipped ({e})")
59
+
60
+ # ============= DATASET 2: Python Code Instructions (18k Alpaca style) =============
61
+ try:
62
+ pycode = load_dataset("iamtarun/python_code_instructions_18k_alpaca", split="train")
63
+ def format_pycode(ex):
64
+ instruction = ex["instruction"]
65
+ inp = ex.get("input", "")
66
+ output = ex["output"]
67
+ if inp and str(inp).strip():
68
+ user_content = f"{instruction}\n\nInput: {inp}"
69
+ else:
70
+ user_content = instruction
71
+ return {"messages": [
72
+ {"role": "user", "content": user_content},
73
+ {"role": "assistant", "content": output}
74
+ ]}
75
+ pycode = pycode.map(format_pycode, remove_columns=pycode.column_names)
76
+ if len(pycode) > 15000:
77
+ pycode = pycode.select(range(15000))
78
+ all_datasets.append(pycode)
79
+ print(f"Python Code 18k: {len(pycode)} examples")
80
+ except Exception as e:
81
+ print(f"Python Code 18k: skipped ({e})")
82
+
83
+ # ============= DATASET 3: Code instructions 120k Alpaca =============
84
+ try:
85
+ code120k = load_dataset("iamtarun/code_instructions_120k_alpaca", split="train")
86
+ def format_code120k(ex):
87
+ instruction = ex["instruction"]
88
+ inp = ex.get("input", "")
89
+ output = ex["output"]
90
+ if inp and str(inp).strip():
91
+ user_content = f"{instruction}\n\nInput: {inp}"
92
+ else:
93
+ user_content = instruction
94
+ return {"messages": [
95
+ {"role": "user", "content": user_content},
96
+ {"role": "assistant", "content": output}
97
+ ]}
98
+ code120k = code120k.map(format_code120k, remove_columns=code120k.column_names)
99
+ if len(code120k) > 20000:
100
+ indices = random.sample(range(len(code120k)), 20000)
101
+ code120k = code120k.select(indices)
102
+ all_datasets.append(code120k)
103
+ print(f"Code 120k (sampled): {len(code120k)} examples")
104
+ except Exception as e:
105
+ print(f"Code 120k: skipped ({e})")
106
+
107
+ # ============= DATASET 4: Code Contests (competitive programming / reasoning) =============
108
+ try:
109
+ contests = load_dataset("deepmind/code_contests", split="train")
110
+ def format_contest(ex):
111
+ desc = ex["description"]
112
+ sols = ex.get("solutions", {}).get("solution", [])
113
+ if sols:
114
+ sol = sols[0]
115
+ else:
116
+ sol = ""
117
+ return {"messages": [
118
+ {"role": "user", "content": f"Solve this competitive programming problem:\n\n{desc}"},
119
+ {"role": "assistant", "content": sol}
120
+ ]}
121
+ contests = contests.map(format_contest, remove_columns=contests.column_names)
122
+ if len(contests) > 5000:
123
+ contests = contests.select(range(5000))
124
+ all_datasets.append(contests)
125
+ print(f"Code Contests: {len(contests)} examples")
126
+ except Exception as e:
127
+ print(f"Code Contests: skipped ({e})")
128
+
129
+ # ============= DATASET 5: Orca Math (math reasoning with CoT) =============
130
+ try:
131
+ orca_math = load_dataset("microsoft/orca-math-word-problems-200k", split="train")
132
+ def format_orca(ex):
133
+ return {"messages": [
134
+ {"role": "user", "content": ex["question"]},
135
+ {"role": "assistant", "content": ex["answer"]}
136
+ ]}
137
+ orca_math = orca_math.map(format_orca, remove_columns=orca_math.column_names)
138
+ if len(orca_math) > 10000:
139
+ orca_math = orca_math.select(range(10000))
140
+ all_datasets.append(orca_math)
141
+ print(f"Orca Math: {len(orca_math)} examples")
142
+ except Exception as e:
143
+ print(f"Orca Math: skipped ({e})")
144
+
145
+ # ============= DATASET 6: Capybara (general reasoning / multi-turn) =============
146
+ try:
147
+ capybara = load_dataset("trl-lib/Capybara", split="train")
148
+ def format_capybara(ex):
149
+ return {"messages": ex["messages"]}
150
+ capybara = capybara.map(format_capybara, remove_columns=capybara.column_names)
151
+ if len(capybara) > 10000:
152
+ capybara = capybara.select(range(10000))
153
+ all_datasets.append(capybara)
154
+ print(f"Capybara: {len(capybara)} examples")
155
+ except Exception as e:
156
+ print(f"Capybara: skipped ({e})")
157
+
158
+ # Combine all datasets
159
+ train_dataset = concatenate_datasets(all_datasets).shuffle(seed=42)
160
+ print(f"\nTotal training examples: {len(train_dataset)}")
161
+
162
+ # LoRA config for memory efficiency
163
+ peft_config = LoraConfig(
164
+ r=32,
165
+ lora_alpha=16,
166
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
167
+ lora_dropout=0.05,
168
+ bias="none",
169
+ task_type=TaskType.CAUSAL_LM,
170
+ )
171
+
172
+ # Training configuration
173
+ # A10G-LARGE: 12 vCPU / 46GB RAM / 24GB GPU
174
+ # 0.5B model with LoRA + bf16 + grad checkpointing fits easily in 24GB
175
+ training_args = SFTConfig(
176
+ output_dir=OUTPUT_DIR,
177
+ hub_model_id=HUB_MODEL_ID,
178
+ push_to_hub=True,
179
+ num_train_epochs=2,
180
+ per_device_train_batch_size=4,
181
+ gradient_accumulation_steps=4,
182
+ learning_rate=1e-4,
183
+ warmup_steps=300,
184
+ lr_scheduler_type="cosine",
185
+ bf16=True,
186
+ gradient_checkpointing=True,
187
+ logging_strategy="steps",
188
+ logging_steps=10,
189
+ logging_first_step=True,
190
+ save_strategy="steps",
191
+ save_steps=10,
192
+ packing=False,
193
+ dataset_num_proc=4,
194
+ disable_tqdm=True,
195
+ report_to=["trackio"],
196
+ seed=42,
197
+ hub_strategy="checkpoint",
198
+ )
199
+
200
+ print("\nInitializing SFTTrainer...")
201
+ trainer = SFTTrainer(
202
+ model=MODEL_ID,
203
+ train_dataset=train_dataset,
204
+ args=training_args,
205
+ processing_class=tokenizer,
206
+ peft_config=peft_config,
207
+ )
208
+
209
+ print("Starting training...")
210
+ trainer.train()
211
+
212
+ print("Saving final model...")
213
+ trainer.save_model(OUTPUT_DIR)
214
+
215
+ # Merge LoRA weights and push full model
216
+ from peft import AutoPeftModelForCausalLM
217
+ model = AutoPeftModelForCausalLM.from_pretrained(OUTPUT_DIR)
218
+ merged = model.merge_and_unload()
219
+ merged.push_to_hub(HUB_MODEL_ID, commit_message="Merged LoRA after code+reasoning fine-tuning")
220
+
221
+ print("Training complete! Model pushed to", HUB_MODEL_ID)