#!/usr/bin/env python3
"""
Fine-tune models on the XCoder-80K dataset with Hugging Face Transformers and PEFT (LoRA).
Models:
- meta-llama/Llama-2-7b-hf (maps to llama3.2:latest in Ollama)
- google/gemma-7b (maps to gemma3:4b - adjusted)
- google/gemma-2b (maps to gemma3:1b - adjusted)
- LLaVA (multimodal - skipped for text-only fine-tuning)
Dataset: banksy235/XCoder-80K
Fine-tuning approaches:
1. SFT (Supervised Fine-Tuning) - implemented here; simple and effective
2. DPO (Direct Preference Optimization) - not implemented here; requires preference data
3. GRPO (Group Relative Policy Optimization) - not implemented here; for RL environments
"""
import os
import json
import argparse
import logging
from pathlib import Path
from typing import Optional
import torch
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling,
)
from peft import get_peft_model, LoraConfig, TaskType
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Model registry - maps Ollama-style model names to HF model IDs
MODEL_REGISTRY = {
"llama3.2": "meta-llama/Llama-2-7b-hf",
"gemma3:4b": "google/gemma-7b",
"gemma3:1b": "google/gemma-2b",
}
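# To target another model, add a short-name -> HF id entry above, e.g. a
# hypothetical "codellama": "codellama/CodeLlama-7b-hf".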
XCODER_DATASET = "banksy235/XCoder-80K"
def load_xcoder_dataset(split: str = "train", max_samples: Optional[int] = None):
"""Load XCoder-80K dataset from Hugging Face."""
logger.info(f"Loading {XCODER_DATASET} ({split} split)...")
try:
ds = load_dataset(XCODER_DATASET, split=split)
if max_samples:
ds = ds.select(range(min(max_samples, len(ds))))
logger.info(f"Loaded {len(ds)} examples")
return ds
except Exception as e:
logger.error(f"Failed to load dataset: {e}")
raise
def prepare_dataset_for_sft(dataset, tokenizer, max_length: int = 2048):
"""Prepare dataset for SFT (Supervised Fine-Tuning)."""
logger.info("Preparing dataset for SFT...")
    def tokenize_function(examples):
        """Build one training text per example, probing common field names."""
        # Derive the batch size from any column so the loop runs even when
        # the dataset has no 'code' field.
        num_rows = len(next(iter(examples.values())))
        # NOTE: the field names 'code', 'comment', 'problem', and 'text' are
        # assumptions about the XCoder-80K schema.
        texts = []
        for i in range(num_rows):
            if "code" in examples:
                code = examples["code"][i]
                if "comment" in examples:
                    text = f"{examples['comment'][i]}\n{code}"
                elif "problem" in examples:
                    text = f"{examples['problem'][i]}\n{code}"
                else:
                    text = code
            elif "text" in examples:
                text = examples["text"][i]
            else:
                # Fallback: join the i-th value of every list-valued column.
                text = " ".join(
                    str(v[i]) for v in examples.values()
                    if isinstance(v, list) and i < len(v)
                )
            texts.append(text)
        # Tokenize without pre-padding; DataCollatorForLanguageModeling pads
        # each batch dynamically, which is much cheaper than padding every
        # example out to max_length.
        encodings = tokenizer(
            texts,
            max_length=max_length,
            truncation=True,
            return_tensors=None,
        )
return encodings
# Apply tokenization
tokenized_ds = dataset.map(
tokenize_function,
batched=True,
batch_size=32,
remove_columns=dataset.column_names,
)
logger.info(f"Prepared {len(tokenized_ds)} samples")
return tokenized_ds
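# The tokenized dataset carries only input_ids / attention_mask; at batch
# time DataCollatorForLanguageModeling(mlm=False) clones input_ids into
# labels and masks pad positions to -100, so no labels column is built here.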
def setup_lora(model, lora_rank: int = 8, lora_alpha: int = 16):
"""Setup LoRA (Low-Rank Adaptation) for efficient fine-tuning."""
logger.info(f"Setting up LoRA (rank={lora_rank}, alpha={lora_alpha})...")
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=0.1,
bias="none",
target_modules=["q_proj", "v_proj"], # Common for causal LM
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
return model
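# Sketch (not exercised by this script): reloading a saved adapter for
# inference. The checkpoint path is an assumption matching the "final"
# save layout used in finetune_model() below.
#
#   from peft import PeftModel
#   base = AutoModelForCausalLM.from_pretrained(
#       "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, device_map="auto"
#   )
#   model = PeftModel.from_pretrained(base, "./finetuned_models/llama3.2/final")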
def finetune_model(
model_name: str,
output_dir: str = "./finetuned_models",
num_epochs: int = 3,
batch_size: int = 4,
learning_rate: float = 2e-4,
max_samples: Optional[int] = None,
use_lora: bool = True,
use_gradient_checkpointing: bool = True,
device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
"""Fine-tune a model on the XCoder-80K dataset."""
# Validate model
if model_name not in MODEL_REGISTRY:
logger.error(f"Model {model_name} not found. Available: {list(MODEL_REGISTRY.keys())}")
return False
hf_model_id = MODEL_REGISTRY[model_name]
output_model_dir = Path(output_dir) / model_name.replace(":", "_")
output_model_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"\n{'='*60}")
logger.info(f"Fine-tuning: {model_name}")
logger.info(f"HF Model: {hf_model_id}")
logger.info(f"Output: {output_model_dir}")
logger.info(f"Device: {device}")
logger.info(f"{'='*60}\n")
# Load dataset
dataset = load_xcoder_dataset(split="train", max_samples=max_samples)
# Load tokenizer and model
logger.info(f"Loading {hf_model_id}...")
tokenizer = AutoTokenizer.from_pretrained(hf_model_id)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
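    # float16 halves GPU memory; device_map="auto" lets accelerate place (and
    # shard) the weights across whatever GPUs are visible.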
model = AutoModelForCausalLM.from_pretrained(
hf_model_id,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
device_map="auto" if device == "cuda" else "cpu",
)
    if use_gradient_checkpointing:
        model.gradient_checkpointing_enable()
        model.config.use_cache = False  # checkpointing is incompatible with the KV cache
    # Setup LoRA if requested
    if use_lora:
        if use_gradient_checkpointing:
            # Frozen base weights: inputs must require grad for backward to work.
            model.enable_input_require_grads()
        model = setup_lora(model)
# Prepare dataset
train_dataset = prepare_dataset_for_sft(dataset, tokenizer)
# Training arguments
training_args = TrainingArguments(
output_dir=str(output_model_dir),
num_train_epochs=num_epochs,
per_device_train_batch_size=batch_size,
learning_rate=learning_rate,
weight_decay=0.01,
warmup_steps=500,
logging_steps=100,
save_steps=500,
save_total_limit=2,
gradient_accumulation_steps=2,
gradient_checkpointing=use_gradient_checkpointing,
fp16=device == "cuda",
optim="paged_adamw_8bit" if device == "cuda" else "adamw_torch",
report_to=["tensorboard"],
)
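    # Effective batch size per optimizer step = per_device_train_batch_size
    # * gradient_accumulation_steps (4 * 2 = 8 at the defaults).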
# Create trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
# Train
logger.info("Starting training...")
try:
        trainer.train()
        logger.info("✓ Training completed successfully")
        # Save final model and tokenizer
        model.save_pretrained(str(output_model_dir / "final"))
        tokenizer.save_pretrained(str(output_model_dir / "final"))
        logger.info(f"Model saved to: {output_model_dir / 'final'}")
# Save metadata
metadata = {
"model_name": model_name,
"hf_model_id": hf_model_id,
"dataset": XCODER_DATASET,
"training_args": training_args.to_dict(),
"num_epochs": num_epochs,
"batch_size": batch_size,
"learning_rate": learning_rate,
}
with open(output_model_dir / "metadata.json", "w") as f:
json.dump(metadata, f, indent=2)
return True
except Exception as e:
logger.error(f"Training failed: {e}")
return False
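# Programmatic sketch (values illustrative, not tuned): a cheap end-to-end
# check before committing to a full run:
#   finetune_model("gemma3:1b", max_samples=100, num_epochs=1, batch_size=1)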
def main():
parser = argparse.ArgumentParser(description="Fine-tune models on XCoder-80K dataset")
parser.add_argument(
"--model",
type=str,
default="llama3.2",
choices=list(MODEL_REGISTRY.keys()),
help="Model to fine-tune",
)
parser.add_argument(
"--all-models",
action="store_true",
help="Fine-tune all available models sequentially",
)
parser.add_argument(
"--output-dir",
type=str,
default="./finetuned_models",
help="Output directory for fine-tuned models",
)
parser.add_argument(
"--num-epochs",
type=int,
default=3,
help="Number of training epochs",
)
parser.add_argument(
"--batch-size",
type=int,
default=4,
help="Training batch size",
)
parser.add_argument(
"--learning-rate",
type=float,
default=2e-4,
help="Learning rate",
)
parser.add_argument(
"--max-samples",
type=int,
default=None,
help="Maximum number of samples to use (None = all)",
)
parser.add_argument(
"--no-lora",
action="store_true",
help="Disable LoRA (full fine-tuning instead)",
)
parser.add_argument(
"--no-gradient-checkpointing",
action="store_true",
help="Disable gradient checkpointing",
)
args = parser.parse_args()
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")
if args.all_models:
results = {}
for model_name in MODEL_REGISTRY.keys():
success = finetune_model(
model_name=model_name,
output_dir=args.output_dir,
num_epochs=args.num_epochs,
batch_size=args.batch_size,
learning_rate=args.learning_rate,
max_samples=args.max_samples,
use_lora=not args.no_lora,
use_gradient_checkpointing=not args.no_gradient_checkpointing,
device=device,
)
results[model_name] = "✓ Success" if success else "✗ Failed"
logger.info("\n" + "="*60)
logger.info("FINE-TUNING RESULTS")
logger.info("="*60)
for model, status in results.items():
logger.info(f"{model}: {status}")
else:
success = finetune_model(
model_name=args.model,
output_dir=args.output_dir,
num_epochs=args.num_epochs,
batch_size=args.batch_size,
learning_rate=args.learning_rate,
max_samples=args.max_samples,
use_lora=not args.no_lora,
use_gradient_checkpointing=not args.no_gradient_checkpointing,
device=device,
)
if success:
logger.info("\n✓ Fine-tuning completed successfully!")
logger.info(f"Output directory: {args.output_dir}")
else:
logger.error("\n✗ Fine-tuning failed!")
if __name__ == "__main__":
main()