gemma-mini-code-agent / train_colab.py
Abhay557's picture
Fix TRL v1.x: remove max_seq_length from SFTConfig, use processing_class
099c050 verified
Raw
History Blame Contribute Delete
8.59 kB
"""
========================================
COLAB CODING AGENT TRAINING SCRIPT
Fine-tune Gemma-3-1B-IT as a Mini Claude Code
Optimized for Google Colab T4 GPU (16GB VRAM)
========================================
INSTRUCTIONS:
1. Open https://colab.research.google.com
2. Change runtime to GPU (Runtime > Change runtime type > T4 GPU)
3. Run the install cell below
4. Authenticate with HuggingFace (Gemma requires license acceptance)
5. Run the training script
"""
# ============================================================
# CELL 1: Install Dependencies
# ============================================================
# !pip install -q transformers trl peft datasets accelerate bitsandbytes huggingface_hub
# ============================================================
# CELL 2: Hugging Face Login
# ============================================================
# from huggingface_hub import notebook_login
# notebook_login()
# # IMPORTANT: Visit https://huggingface.co/google/gemma-3-1b-it
# # and ACCEPT the license before running training!
# ============================================================
# CELL 3: Training Script (copy everything below)
# ============================================================
import torch
import gc
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer, SFTConfig
# Startup banner: a rule, two title lines, and a closing rule.
print(
    "=" * 70,
    " MINI CODING AGENT - Fine-tune Gemma-3-1B-IT",
    " Target: ~1B params | Dataset: Coding instruction pairs",
    "=" * 70,
    sep="\n",
)
# ========================= CONFIG =========================
# Base checkpoint that gets fine-tuned.
MODEL_ID = "google/gemma-3-1b-it"

# --- PICK ONE DATASET ---
# Option A: Magicoder 75K (smaller, faster, proven recipe)
DATASET_NAME = "ise-uiuc/Magicoder-OSS-Instruct-75K"
# Option B: OpenCodeInstruct 5M (higher quality, use subset)
# DATASET_NAME = "nvidia/OpenCodeInstruct"

# Where the adapter checkpoint is written, and the Hub repo name to use
# if pushing is enabled.
OUTPUT_DIR = "./gemma-code-agent"
HUB_MODEL_ID = "YOUR_USERNAME/gemma-3-1b-code-agent"  # <-- CHANGE THIS!

# --- Training hyperparameters ---
MAX_SEQ_LENGTH = 1024   # Safe for Colab T4; use 2048 on A100
NUM_EPOCHS = 2
BATCH_SIZE = 1          # Per-device batch size
GRAD_ACCUM = 16         # BATCH_SIZE x GRAD_ACCUM = effective batch of 16
LEARNING_RATE = 5e-5    # Learning rate applied to the LoRA adapter weights
WARMUP_STEPS = 50

# --- LoRA config (only a small fraction of parameters is trained) ---
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
LORA_MODULES = [
    # attention projections
    "q_proj", "k_proj", "v_proj", "o_proj",
    # MLP projections
    "gate_proj", "up_proj", "down_proj",
]

# Upper bound on training examples (reduce for faster Colab runs)
MAX_SAMPLES = 50_000
# ===========================================================
# Step 1: Load the tokenizer and cap its maximum length.
print("\n[1/7] Loading tokenizer...")
# NOTE(review): trust_remote_code=True lets the Hub repository execute its own
# Python code at load time. Confirm it is actually required for MODEL_ID;
# if not, dropping it is the safer choice.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
if tokenizer.pad_token is None:
    # Fall back to EOS as the padding token when the tokenizer defines none.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
# Cap the tokenizer's advertised maximum sequence length at MAX_SEQ_LENGTH.
# NOTE(review): this attribute is presumably persisted when the tokenizer is
# saved later in the script -- confirm that is intended.
tokenizer.model_max_length = MAX_SEQ_LENGTH
print(f" Vocab size: {len(tokenizer)}, max_length: {tokenizer.model_max_length}")
# Step 2: Load the base model in 4-bit (NF4) so it fits on a Colab T4.
print("\n[2/7] Loading model with 4-bit quantization (NF4)...")

# NF4 weights with double quantization; matmuls are computed in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    quantization_config=bnb_config,
    trust_remote_code=True,
)

# Enable gradient checkpointing, then prepare the quantized model for
# k-bit (adapter) training.
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

print(f" Model loaded on: {next(model.parameters()).device}")
# Step 3: Wrap the quantized model with trainable LoRA adapters.
print("\n[3/7] Attaching LoRA adapters...")

# Adapter hyperparameters come from the CONFIG section above.
lora_cfg = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=LORA_MODULES,
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
)

model = get_peft_model(model, lora_cfg)
# Report how many parameters are trainable versus frozen.
model.print_trainable_parameters()
# Step 4: Load the dataset and keep a shuffled subset of at most MAX_SAMPLES rows.
print(f"\n[4/7] Loading dataset: {DATASET_NAME}...")
ds = load_dataset(DATASET_NAME, split="train")

total_available = len(ds)
use_samples = min(total_available, MAX_SAMPLES)

# Fixed seed so the selected subset is the same on every run.
ds = ds.shuffle(seed=42)
ds = ds.select(range(use_samples))

print(f" Using {use_samples:,} / {total_available:,} samples")
def format_magicoder(example: dict) -> dict:
    """Convert one Magicoder-style record into a two-turn chat example.

    Args:
        example: Mapping with a "problem" entry (used as the user turn) and a
            "solution" entry (used as the assistant turn). Other keys are
            ignored.

    Returns:
        A dict with a single "messages" key holding the user and assistant
        turns, each as a {"role", "content"} dict.

    Raises:
        KeyError: If "problem" or "solution" is missing from ``example``.
    """
    return {
        "messages": [
            {"role": "user", "content": example["problem"]},
            {"role": "assistant", "content": example["solution"]},
        ]
    }
def format_opencode(example: dict) -> dict:
    """Convert one input/output record into a two-turn chat example.

    Args:
        example: Mapping with an "input" entry (used as the user turn) and an
            "output" entry (used as the assistant turn). Other keys are
            ignored.

    Returns:
        A dict with a single "messages" key holding the user and assistant
        turns, each as a {"role", "content"} dict.

    Raises:
        KeyError: If "input" or "output" is missing from ``example``.
    """
    return {
        "messages": [
            {"role": "user", "content": example["input"]},
            {"role": "assistant", "content": example["output"]},
        ]
    }
# Choose the formatter by dataset name: Magicoder records use
# "problem"/"solution"; every other dataset is treated as "input"/"output".
formatter = format_magicoder if "Magicoder" in DATASET_NAME else format_opencode
# Drop all original columns so only the chat-style "messages" column remains.
ds = ds.map(formatter, remove_columns=ds.column_names)

print(f" Dataset ready: {len(ds):,} examples")
print(" Sample:")
# Preview the first 80 characters of each turn of the first example.
sample_user, sample_assistant = ds[0]["messages"]
print(f" User: {sample_user['content'][:80]}...")
print(f" Assistant: {sample_assistant['content'][:80]}...")
# Step 5: Setup Training
print("\n[5/7] Configuring trainer...")
# Use bf16 mixed precision when the GPU reports support, otherwise fp16.
use_bf16 = torch.cuda.is_bf16_supported()
args = SFTConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    lr_scheduler_type="cosine",
    logging_steps=10,
    save_steps=500,
    save_total_limit=2,
    bf16=use_bf16,
    fp16=not use_bf16,
    gradient_checkpointing=True,
    push_to_hub=False,
    hub_model_id=HUB_MODEL_ID,
    report_to="none",
    dataloader_num_workers=2,
    remove_unused_columns=False,
    # Recent TRL releases renamed `max_seq_length` to `max_length`. Pass it
    # explicitly so MAX_SEQ_LENGTH actually controls truncation instead of
    # silently falling back to the library default.
    max_length=MAX_SEQ_LENGTH,
)
# Build the trainer. The tokenizer is passed as `processing_class`
# (older TRL releases called this argument `tokenizer`).
trainer = SFTTrainer(
    args=args,
    model=model,
    processing_class=tokenizer,
    train_dataset=ds,
)
# Step 6: Print a summary of the run settings, then train.
print(
    "\n[6/7] Starting training...",
    f" Epochs: {NUM_EPOCHS} | Batch: {BATCH_SIZE} x {GRAD_ACCUM}",
    f" LR: {LEARNING_RATE} | Warmup: {WARMUP_STEPS} | Max length: {MAX_SEQ_LENGTH}",
    "-" * 70,
    sep="\n",
)
trainer.train()
# Step 7: Save the LoRA adapter, then merge it into the base model and save that too.
print("\n[7/7] Saving model...")
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

print(" Merging LoRA adapters into base model...")
merged_dir = f"{OUTPUT_DIR}-merged"
# NOTE(review): the base model was loaded 4-bit quantized, so this merge folds
# the adapters into quantized weights. If full-precision merged weights are
# wanted, reload the base model unquantized before merging -- confirm intent.
merged_model = model.merge_and_unload()
merged_model.save_pretrained(merged_dir)
tokenizer.save_pretrained(merged_dir)

print("\n" + "=" * 70)
print(" TRAINING COMPLETE!")
# OUTPUT_DIR already carries its own "./" prefix; do not prepend another one
# (the previous message printed "././gemma-code-agent/").
print(f" LoRA model: {OUTPUT_DIR}/")
print(f" Merged model: {merged_dir}/")
print("=" * 70)

# Release Python garbage and cached CUDA blocks before running inference.
gc.collect()
torch.cuda.empty_cache()
# ============================================================
# CELL 4: Inference / Test the Coding Agent
# ============================================================
def chat_with_agent(prompt: str, max_new_tokens: int = 512) -> str:
    """Send a coding task to the fine-tuned agent and return its reply.

    Uses the module-level ``tokenizer`` and ``merged_model``.

    Args:
        prompt: The user's request, sent as a single user turn.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded completion only (prompt tokens are stripped), with
        special tokens removed. Sampling is enabled, so output varies
        between calls.
    """
    messages = [{"role": "user", "content": prompt}]
    # Generation runs on merged_model, so place the inputs on *its* device
    # rather than on the device of the training-time `model` wrapper.
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        return_tensors="pt",
        add_generation_prompt=True,
        return_dict=True,
    ).to(merged_model.device)

    # The model is left in training mode after fine-tuning; switch to eval so
    # dropout is off. (Transformers may also disable the KV cache while a
    # gradient-checkpointed model is in training mode, which slows generation.)
    merged_model.eval()
    with torch.no_grad():
        outputs = merged_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Keep only the tokens produced after the prompt.
    prompt_len = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
# Smoke-test prompts sent to the fine-tuned agent.
test_prompts = [
    "Write a Python function to find the longest common subsequence of two strings.",
    "Create a function that checks if a linked list has a cycle using Floyd's algorithm.",
    "Write a Python script that fetches weather data from a public API and prints the temperature.",
    "Implement a simple LRU cache in Python using a dictionary and a doubly linked list.",
]

print("\n" + "=" * 70)
print(" TESTING CODING AGENT")
print("=" * 70)
for i, prompt in enumerate(test_prompts, 1):
    print(f"\n--- Test {i} ---")
    print(f"User: {prompt}")
    # Only the first 500 characters of each reply are shown.
    print(f"\nAgent: {chat_with_agent(prompt)[:500]}...")
    # Separator after every test case.
    print("-" * 70)