moos124 committed
Commit 27e9bab · verified · 1 Parent(s): 0bcf8cc

Upload train_code_reasoning.py

Files changed (1): train_code_reasoning.py (+203, -0)
train_code_reasoning.py ADDED
@@ -0,0 +1,203 @@
# /// script
# dependencies = [
#     "transformers",
#     "trl",
#     "datasets",
#     "accelerate",
#     "torch",
#     "trackio",
#     "huggingface_hub",
# ]
# ///

import random

from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer
from trl import SFTTrainer, SFTConfig
import trackio

# Configuration
MODEL_ID = "Qwen/Qwen2.5-Coder-1.5B-Instruct"
HUB_MODEL_ID = "moos124/code-reasoning-1.5b"
OUTPUT_DIR = "./code-reasoning-1.5b"

random.seed(42)  # make the random subsampling of Code 120k below reproducible

# Initialize Trackio experiment tracking
trackio.init(project="code-reasoning-ft", name="qwen2.5-coder-1.5b-code-reasoning")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

print("Loading and preparing datasets...")

all_datasets = []
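
# Each dataset section below is wrapped in try/except so that a single
# failed download or schema change skips that source instead of aborting
# the whole run.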

# All three Alpaca-style code datasets share the same
# (instruction, input, output) schema, so one formatter covers them.
def format_alpaca(ex):
    instruction = ex["instruction"]
    inp = ex.get("input", "")
    output = ex["output"]
    if inp and str(inp).strip():
        user_content = f"{instruction}\n\nInput: {inp}"
    else:
        user_content = instruction
    return {"messages": [
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": output},
    ]}

# ============= DATASET 1: CodeAlpaca (Python code instructions) =============
try:
    codealpaca = load_dataset("sahil2801/CodeAlpaca-20k", split="train")
    codealpaca = codealpaca.map(format_alpaca, remove_columns=codealpaca.column_names)
    if len(codealpaca) > 15000:
        codealpaca = codealpaca.select(range(15000))
    all_datasets.append(codealpaca)
    print(f"CodeAlpaca: {len(codealpaca)} examples")
except Exception as e:
    print(f"CodeAlpaca: skipped ({e})")

# ============= DATASET 2: Python Code Instructions (18k Alpaca style) =============
try:
    pycode = load_dataset("iamtarun/python_code_instructions_18k_alpaca", split="train")
    pycode = pycode.map(format_alpaca, remove_columns=pycode.column_names)
    if len(pycode) > 15000:
        pycode = pycode.select(range(15000))
    all_datasets.append(pycode)
    print(f"Python Code 18k: {len(pycode)} examples")
except Exception as e:
    print(f"Python Code 18k: skipped ({e})")

# ============= DATASET 3: Code instructions 120k Alpaca =============
try:
    code120k = load_dataset("iamtarun/code_instructions_120k_alpaca", split="train")
    code120k = code120k.map(format_alpaca, remove_columns=code120k.column_names)
    if len(code120k) > 20000:
        # Random subsample rather than a head slice, to avoid ordering bias
        indices = random.sample(range(len(code120k)), 20000)
        code120k = code120k.select(indices)
    all_datasets.append(code120k)
    print(f"Code 120k (sampled): {len(code120k)} examples")
except Exception as e:
    print(f"Code 120k: skipped ({e})")

# ============= DATASET 4: Code Contests (competitive programming / reasoning) =============
try:
    contests = load_dataset("deepmind/code_contests", split="train")

    def format_contest(ex):
        desc = ex["description"]
        sols = ex.get("solutions", {}).get("solution", [])
        sol = sols[0] if sols else ""
        return {"messages": [
            {"role": "user", "content": f"Solve this competitive programming problem:\n\n{desc}"},
            {"role": "assistant", "content": sol},
        ]}

    contests = contests.map(format_contest, remove_columns=contests.column_names)
    # Drop problems without a reference solution (empty assistant turn)
    contests = contests.filter(lambda ex: ex["messages"][1]["content"].strip() != "")
    if len(contests) > 5000:
        contests = contests.select(range(5000))
    all_datasets.append(contests)
    print(f"Code Contests: {len(contests)} examples")
except Exception as e:
    print(f"Code Contests: skipped ({e})")

# ============= DATASET 5: Orca Math (math reasoning with CoT) =============
try:
    orca_math = load_dataset("microsoft/orca-math-word-problems-200k", split="train")

    def format_orca(ex):
        return {"messages": [
            {"role": "user", "content": ex["question"]},
            {"role": "assistant", "content": ex["answer"]},
        ]}

    orca_math = orca_math.map(format_orca, remove_columns=orca_math.column_names)
    if len(orca_math) > 10000:
        orca_math = orca_math.select(range(10000))
    all_datasets.append(orca_math)
    print(f"Orca Math: {len(orca_math)} examples")
except Exception as e:
    print(f"Orca Math: skipped ({e})")

# ============= DATASET 6: Capybara (general reasoning / multi-turn) =============
try:
    capybara = load_dataset("trl-lib/Capybara", split="train")
    # Already in chat format; keep only the "messages" column
    capybara = capybara.map(lambda ex: {"messages": ex["messages"]},
                            remove_columns=capybara.column_names)
    if len(capybara) > 10000:
        capybara = capybara.select(range(10000))
    all_datasets.append(capybara)
    print(f"Capybara: {len(capybara)} examples")
except Exception as e:
    print(f"Capybara: skipped ({e})")

# Combine all datasets and shuffle
train_dataset = concatenate_datasets(all_datasets).shuffle(seed=42)
print(f"\nTotal training examples: {len(train_dataset)}")
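# If every source loads, the per-dataset caps above sum to at most
# 15,000 + 15,000 + 20,000 + 5,000 + 10,000 + 10,000 = 75,000 examples
# (slightly fewer in practice after the Code Contests empty-solution filter).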

# Training configuration
training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    hub_model_id=HUB_MODEL_ID,
    push_to_hub=True,
    num_train_epochs=2,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # effective batch: 4 x 4 = 16 sequences per device
    learning_rate=5e-5,
    max_length=2048,  # named max_seq_length in TRL < 0.15
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    bf16=True,
    gradient_checkpointing=True,
    logging_strategy="steps",
    logging_steps=10,
    logging_first_step=True,
    save_strategy="steps",
    save_steps=10,  # frequent checkpoints; hub_strategy="checkpoint" pushes only the latest, for resuming
    packing=True,
    dataset_num_proc=4,
    disable_tqdm=True,
    report_to=["trackio"],
    seed=42,
    hub_strategy="checkpoint",
    hub_always_push=True,
)

print("\nInitializing SFTTrainer...")
trainer = SFTTrainer(
    model=MODEL_ID,  # SFTTrainer loads the model itself when given a Hub ID
    train_dataset=train_dataset,
    args=training_args,
    processing_class=tokenizer,  # this argument was named tokenizer in TRL < 0.12
)

print("Starting training...")
trainer.train()

print("Saving final model...")
trainer.save_model(OUTPUT_DIR)
trainer.push_to_hub(commit_message="Final model after code+reasoning fine-tuning")

print("Training complete! Model pushed to", HUB_MODEL_ID)
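
The # /// script header is PEP 723 inline metadata, so the file can be run directly with uv run train_code_reasoning.py; the Hub pushes additionally need write access to moos124/code-reasoning-1.5b (for example via huggingface-cli login or an HF_TOKEN environment variable). After a run finishes, a quick check of the pushed model might look like the sketch below; the prompt is purely illustrative, and the sketch assumes the final push succeeded and that the tokenizer (with its chat template) was uploaded alongside the weights, which Trainer.push_to_hub normally does.

# Post-training smoke test: a minimal sketch, assuming the final push
# above succeeded. The prompt is illustrative only.
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "moos124/code-reasoning-1.5b"
tok = AutoTokenizer.from_pretrained(repo)
model = AutoModelForCausalLM.from_pretrained(repo, device_map="auto")

messages = [{"role": "user",
             "content": "Write a Python function that returns the n-th Fibonacci number."}]
inputs = tok.apply_chat_template(messages, add_generation_prompt=True,
                                 return_tensors="pt").to(model.device)
out = model.generate(inputs, max_new_tokens=256)
print(tok.decode(out[0][inputs.shape[-1]:], skip_special_tokens=True))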