# /// script
# dependencies = [
#     "transformers",
#     "trl",
#     "datasets",
#     "accelerate",
#     "torch",
#     "trackio",
#     "huggingface_hub",
# ]
# ///
"""Supervised fine-tuning of a code model on a mix of code and reasoning data.

The script loads several instruction datasets from the Hugging Face Hub,
converts each of them to the conversational ``{"messages": [...]}`` format,
concatenates and shuffles them, fine-tunes ``MODEL_ID`` with TRL's
``SFTTrainer`` and pushes the result to ``HUB_MODEL_ID``.

Dataset loading is best-effort: a dataset that fails to load or convert is
reported and skipped, and training proceeds with the remaining ones.
"""

import os
import random

from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer
from trl import SFTTrainer, SFTConfig
import trackio

# Configuration
MODEL_ID = "Qwen/Qwen2.5-Coder-1.5B-Instruct"
HUB_MODEL_ID = "moos124/code-reasoning-1.5b"
OUTPUT_DIR = "./code-reasoning-1.5b"

# One seed for sub-sampling, shuffling and training so a run is reproducible.
SEED = 42

# Checkpoint cadence. Each checkpoint is a full copy of the model, and with
# hub_strategy="checkpoint" every save also triggers a Hub push, so saves are
# kept infrequent and only the newest few are retained on disk.
SAVE_STEPS = 500
SAVE_TOTAL_LIMIT = 3


def format_alpaca(ex):
    """Convert an Alpaca-style row to a two-turn conversation.

    Reads the ``instruction`` and ``output`` fields and the optional ``input``
    field. A non-blank ``input`` is appended to the instruction under an
    ``Input:`` heading.
    """
    instruction = ex["instruction"]
    inp = ex.get("input", "")
    if inp and str(inp).strip():
        user_content = f"{instruction}\n\nInput: {inp}"
    else:
        user_content = instruction
    return {"messages": [
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": ex["output"]},
    ]}


def format_contest(ex):
    """Convert a Code Contests row to a conversation using its first solution.

    Rows without any solution get an empty assistant message; those are
    removed afterwards by ``has_assistant_reply``.
    """
    desc = ex["description"]
    # `solutions` may be absent or None; treat both as "no solutions".
    sols = (ex.get("solutions") or {}).get("solution") or []
    sol = sols[0] if sols else ""
    return {"messages": [
        {"role": "user", "content": f"Solve this competitive programming problem:\n\n{desc}"},
        {"role": "assistant", "content": sol},
    ]}


def format_orca(ex):
    """Convert an Orca Math row (``question`` / ``answer``) to a conversation."""
    return {"messages": [
        {"role": "user", "content": ex["question"]},
        {"role": "assistant", "content": ex["answer"]},
    ]}


def format_capybara(ex):
    """Keep only the existing ``messages`` column of a Capybara row."""
    return {"messages": ex["messages"]}


def has_assistant_reply(ex):
    """Return True if the conversation has a non-blank assistant message.

    Used to drop examples that would otherwise train the model to answer
    with an empty string.
    """
    return any(
        msg["role"] == "assistant" and (msg["content"] or "").strip()
        for msg in ex["messages"]
    )


# (label, Hub repo id, formatter, max rows, take a random sample instead of the head)
DATASET_SPECS = [
    # Python code instructions
    ("CodeAlpaca", "sahil2801/CodeAlpaca-20k", format_alpaca, 15000, False),
    # Python code instructions (18k, Alpaca style)
    ("Python Code 18k", "iamtarun/python_code_instructions_18k_alpaca", format_alpaca, 15000, False),
    # Code instructions (120k, Alpaca style), randomly sub-sampled
    ("Code 120k", "iamtarun/code_instructions_120k_alpaca", format_alpaca, 20000, True),
    # Competitive programming / reasoning
    ("Code Contests", "deepmind/code_contests", format_contest, 5000, False),
    # Math reasoning with chain of thought
    ("Orca Math", "microsoft/orca-math-word-problems-200k", format_orca, 10000, False),
    # General reasoning / multi-turn
    ("Capybara", "trl-lib/Capybara", format_capybara, 10000, False),
]


def load_chat_dataset(label, repo_id, formatter, limit, sample=False):
    """Load one dataset and return it in ``messages`` format, capped at ``limit``.

    Args:
        label: Human-readable name used in progress messages.
        repo_id: Hub dataset id; its ``train`` split is loaded.
        formatter: Row -> ``{"messages": [...]}`` conversion function.
        limit: Maximum number of rows to keep.
        sample: If True, keep a seeded random sample of ``limit`` rows;
            otherwise keep the first ``limit`` rows.

    Returns:
        The prepared dataset, or ``None`` if loading or conversion failed.
    """
    try:
        ds = load_dataset(repo_id, split="train")
        ds = ds.map(formatter, remove_columns=ds.column_names)
        # Filter before capping so the cap counts usable examples only.
        ds = ds.filter(has_assistant_reply)
        if len(ds) > limit:
            if sample:
                # Dedicated seeded RNG: reproducible, and independent of
                # whatever else touches the global `random` state.
                indices = random.Random(SEED).sample(range(len(ds)), limit)
            else:
                indices = range(limit)
            ds = ds.select(indices)
    except Exception as e:
        # Deliberately broad: any failure for one dataset (network, schema,
        # gated repo, ...) skips that dataset instead of aborting the run.
        print(f"{label}: skipped ({e})")
        return None

    shown_label = f"{label} (sampled)" if sample else label
    print(f"{shown_label}: {len(ds)} examples")
    return ds


def build_train_dataset():
    """Load every dataset in ``DATASET_SPECS`` and return the shuffled union.

    Raises:
        RuntimeError: If none of the datasets could be loaded.
    """
    all_datasets = []
    for label, repo_id, formatter, limit, sample in DATASET_SPECS:
        ds = load_chat_dataset(label, repo_id, formatter, limit, sample=sample)
        if ds is not None:
            all_datasets.append(ds)

    if not all_datasets:
        raise RuntimeError("No datasets could be loaded; nothing to train on.")

    return concatenate_datasets(all_datasets).shuffle(seed=SEED)


def main():
    """Prepare the data, run supervised fine-tuning and push the model."""
    # Initialize Trackio
    trackio.init(project="code-reasoning-ft", name="qwen2.5-coder-1.5b-code-reasoning")

    # Load tokenizer
    # NOTE(review): trust_remote_code=True allows the model repo to run its own
    # Python code on load -- confirm it is actually required for this model.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    print("Loading and preparing datasets...")
    train_dataset = build_train_dataset()
    print(f"\nTotal training examples: {len(train_dataset)}")

    # Training configuration
    training_args = SFTConfig(
        output_dir=OUTPUT_DIR,
        hub_model_id=HUB_MODEL_ID,
        push_to_hub=True,
        num_train_epochs=2,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=5e-5,
        warmup_steps=300,
        lr_scheduler_type="cosine",
        bf16=True,
        gradient_checkpointing=True,
        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        save_strategy="steps",
        save_steps=SAVE_STEPS,
        save_total_limit=SAVE_TOTAL_LIMIT,
        packing=False,
        dataset_num_proc=4,
        disable_tqdm=True,
        report_to=["trackio"],
        seed=SEED,
        hub_strategy="checkpoint",
    )

    print("\nInitializing SFTTrainer...")
    trainer = SFTTrainer(
        model=MODEL_ID,
        train_dataset=train_dataset,
        args=training_args,
        processing_class=tokenizer,
    )

    print("Starting training...")
    trainer.train()

    print("Saving final model...")
    trainer.save_model(OUTPUT_DIR)
    trainer.push_to_hub(commit_message="Final model after code+reasoning fine-tuning")
    print("Training complete! Model pushed to", HUB_MODEL_ID)


if __name__ == "__main__":
    main()