"""
GIS-Coder 7B: Production QLoRA SFT Training Script
====================================================
Fine-tunes Qwen2.5-Coder-7B-Instruct for GIS code generation.
Hardware requirements:
- Minimum: 1x A10G (24GB) or 1x RTX 4090 (24GB)
- Recommended: 1x A100 (80GB) for faster training + larger batches
- Also works on: H100, L40S, RTX 3090

Training recipe based on:
- CFD fine-tuning (arXiv:2504.09602): QLoRA, r=16, 88.7% accuracy on domain tasks
- MapCoder-Lite (arXiv:2509.17489): Qwen2.5-Coder-7B as the best backbone for code LoRA
- LoRA Without Regret: target all-linear layers, lr=2e-4 for LoRA

Usage:
    # Single GPU
    python train_7b.py

    # Multi-GPU with accelerate
    accelerate launch --num_processes 2 train_7b.py

    # With custom settings
    python train_7b.py --epochs 5 --lr 1e-4 --lora_r 32 --max_length 4096
"""
import os
import argparse
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, prepare_model_for_kbit_training
from trl import SFTConfig, SFTTrainer
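# Dependencies (inferred from the imports above): torch, transformers, datasets, peft,
# trl, plus bitsandbytes for 4-bit loading; flash-attn is needed only with
# --use_flash_attn, trackio only with --use_trackio. Pin versions to your environment.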
def parse_args():
parser = argparse.ArgumentParser(description="Train GIS-Coder 7B")
parser.add_argument("--model_id", type=str, default="Qwen/Qwen2.5-Coder-7B-Instruct")
parser.add_argument("--dataset_id", type=str, default="RhodWeo/gis-code-instructions")
parser.add_argument("--hub_model_id", type=str, default="RhodWeo/GIS-Coder-7B")
parser.add_argument("--output_dir", type=str, default="./gis-coder-7b-output")
# Training hyperparameters
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate (2e-4 for LoRA)")
parser.add_argument("--batch_size", type=int, default=2, help="Per-device batch size")
parser.add_argument("--grad_accum", type=int, default=8, help="Gradient accumulation steps")
parser.add_argument("--max_length", type=int, default=4096, help="Max sequence length")
parser.add_argument("--warmup_ratio", type=float, default=0.1)
parser.add_argument("--weight_decay", type=float, default=0.01)
parser.add_argument("--scheduler", type=str, default="cosine")
# LoRA hyperparameters
parser.add_argument("--lora_r", type=int, default=32, help="LoRA rank")
parser.add_argument("--lora_alpha", type=int, default=16, help="LoRA alpha")
parser.add_argument("--lora_dropout", type=float, default=0.05)
parser.add_argument("--target_modules", type=str, default="all-linear",
help="Target modules (all-linear or comma-separated list)")
# Quantization
    parser.add_argument("--no_quantize", action="store_true", help="Disable 4-bit quantization (plain bf16 LoRA)")
parser.add_argument("--use_flash_attn", action="store_true", help="Use Flash Attention 2")
# Tracking
parser.add_argument("--use_trackio", action="store_true", help="Enable Trackio monitoring")
parser.add_argument("--trackio_project", type=str, default="gis-coder-7b")
return parser.parse_args()
def main():
args = parse_args()
# ─── Trackio (optional) ────────────────────────────────────────────────
if args.use_trackio:
import trackio
trackio.init(
project=args.trackio_project,
config=vars(args),
)
# ─── Dataset ───────────────────────────────────────────────────────────
print(f"Loading dataset: {args.dataset_id}")
dataset = load_dataset(args.dataset_id, data_files="data/train.jsonl", split="train")
print(f" {len(dataset)} examples, columns: {dataset.column_names}")
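    # NOTE: TRL's SFTTrainer infers the example format from the columns; the dataset
    # is assumed to provide either a conversational "messages" column or a plain
    # "text" column (other TRL-supported formats work as well).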
# ─── Model ─────────────────────────────────────────────────────────────
print(f"Loading model: {args.model_id}")
model_kwargs = {
"trust_remote_code": True,
"attn_implementation": "flash_attention_2" if args.use_flash_attn else "eager",
}
if not args.no_quantize:
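        # QLoRA setup: base weights are loaded as 4-bit NF4 with double quantization,
        # while activations and the LoRA math run in bfloat16 (bnb_4bit_compute_dtype).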
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16,
)
model_kwargs["quantization_config"] = bnb_config
model_kwargs["dtype"] = torch.bfloat16
else:
model_kwargs["dtype"] = torch.bfloat16
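    # NOTE: `dtype` is the current transformers kwarg name; older releases expect
    # `torch_dtype` here instead.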
model = AutoModelForCausalLM.from_pretrained(
args.model_id,
device_map="auto",
**model_kwargs,
)
tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
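    # Qwen2.5 tokenizers normally define a pad token already; the fallback above
    # only applies when the tokenizer ships without one.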
if not args.no_quantize:
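        # prepare_model_for_kbit_training freezes the quantized base weights, upcasts
        # norm layers to fp32 for stability, and enables input gradients so gradient
        # checkpointing works with the frozen backbone.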
model = prepare_model_for_kbit_training(model)
print(f" Parameters: {model.num_parameters()/1e9:.2f}B")
# ─── LoRA ──────────────────────────────────────────────────────────────
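    # "all-linear" makes PEFT target every linear layer except the LM head; a
    # comma-separated list such as "q_proj,k_proj,v_proj,o_proj" restricts the
    # adapters to specific modules instead.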
target = args.target_modules
if target != "all-linear":
target = target.split(",")
peft_config = LoraConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
target_modules=target,
lora_dropout=args.lora_dropout,
bias="none",
task_type="CAUSAL_LM",
)
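    # Effective LoRA scaling is alpha / r (0.5 at the defaults r=32, alpha=16);
    # if you raise --lora_r, consider raising --lora_alpha to keep the scale similar.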
print(f" LoRA: r={args.lora_r}, alpha={args.lora_alpha}, targets={target}")
# ─── Training Config ───────────────────────────────────────────────────
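    # Effective batch size = per-device batch * grad accumulation * num GPUs
    # (2 * 8 = 16 per device at the defaults).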
training_args = SFTConfig(
output_dir=args.output_dir,
num_train_epochs=args.epochs,
per_device_train_batch_size=args.batch_size,
gradient_accumulation_steps=args.grad_accum,
learning_rate=args.lr,
lr_scheduler_type=args.scheduler,
warmup_ratio=args.warmup_ratio,
weight_decay=args.weight_decay,
gradient_checkpointing=True,
bf16=True,
max_length=args.max_length,
logging_steps=1,
logging_first_step=True,
logging_strategy="steps",
disable_tqdm=True,
report_to="trackio" if args.use_trackio else "none",
save_strategy="epoch",
save_total_limit=3,
push_to_hub=True,
hub_model_id=args.hub_model_id,
hub_strategy="every_save",
dataloader_num_workers=4,
seed=42,
)
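    # With save_strategy="epoch" and hub_strategy="every_save", each epoch checkpoint
    # is also pushed to args.hub_model_id as training runs.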
# ─── Trainer ───────────────────────────────────────────────────────────
trainer = SFTTrainer(
model=model,
processing_class=tokenizer,
args=training_args,
train_dataset=dataset,
peft_config=peft_config,
)
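    # SFTTrainer applies peft_config itself (wrapping the model with the LoRA
    # adapters), so parameter counts are read from trainer.model below.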
    trainable = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in trainer.model.parameters())
print(f" Trainable: {trainable:,} ({trainable/total*100:.2f}%)")
# ─── Train ─────────────────────────────────────────────────────────────
eff_bs = args.batch_size * args.grad_accum
print(f"\n{'='*60}")
print(f"TRAINING: {args.model_id}")
print(f" Dataset: {len(dataset)} examples")
print(f" Method: {'QLoRA' if not args.no_quantize else 'LoRA'} (r={args.lora_r})")
print(f" LR: {args.lr}, Epochs: {args.epochs}, Eff. batch: {eff_bs}")
print(f" Max length: {args.max_length}")
print(f" Push to: {args.hub_model_id}")
print(f"{'='*60}\n")
result = trainer.train()
# ─── Save ──────────────────────────────────────────────────────────────
print("\nSaving final model...")
trainer.save_model(os.path.join(args.output_dir, "final"))
    trainer.push_to_hub(commit_message="GIS-Coder 7B - final after training")
m = result.metrics
print(f"\nDone! Loss: {m.get('train_loss','?')}, Time: {m.get('train_runtime',0):.0f}s")
print(f"Model: https://huggingface.co/{args.hub_model_id}")
if args.use_trackio:
import trackio
trackio.log({"final_loss": m.get("train_loss", 0), "runtime": m.get("train_runtime", 0)})
trackio.finish()
if __name__ == "__main__":
main()