arcisvlm / scripts /train_stage3_ddp.py
Hardik Sanghvi
feat: integrate Gemma 4 E2B backbone for production-quality VLM inference
7a564e3
Raw
History Blame Contribute Delete
22.6 kB
#!/usr/bin/env python3
"""
Stage 3: Domain Fine-Tuning — DDP Training on 8x A100.
Loads Stage 2 checkpoint and fine-tunes on domain-specific data
(COCO detection, VisDrone, MOT, UCF-Crime, ActivityNet, synthetic RTSP)
with very low learning rate.
Usage:
torchrun --nproc_per_node=8 scripts/train_stage3_ddp.py \\
--config configs/scale_1.3b.yaml \\
--stage2_ckpt checkpoints/stage2_final.pt
torchrun --nproc_per_node=8 scripts/train_stage3_ddp.py \\
--config configs/scale_1.3b.yaml \\
--stage2_ckpt checkpoints/stage2_final.pt \\
--resume checkpoints/stage3_epoch3.pt
"""
import argparse
import math
import os
import sys
import time
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
import yaml
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from model.vlm import VLJEPAModel
from model.tokenizer import BPETokenizer
# ---------------------------------------------------------------------------
# Dataset helpers
# ---------------------------------------------------------------------------
class _LocalJSONLDataset(Dataset):
    """Image/question/answer samples read from local ``*.jsonl`` files.

    Every ``*.jsonl`` file in ``jsonl_dir`` is read line by line. A line is
    kept only if it parses to a JSON object whose ``image_path`` exists on
    disk at construction time; everything else is counted as skipped.

    Defined at module level (not inside the builder) so instances can be
    pickled for DataLoader workers regardless of the start method.
    """

    def __init__(self, jsonl_dir, tokenizer, img_size=448, max_q=64, max_a=128):
        import json  # local import: the file header does not import json

        self.samples = []
        self.tokenizer = tokenizer
        self.max_q = max_q
        self.max_a = max_a
        self.img_size = img_size
        self.vocab_size = tokenizer.vocab_size if hasattr(tokenizer, 'vocab_size') else 32768
        # Built lazily on first use so torchvision is only needed when an
        # image is actually loaded, and only built once per process.
        self._transform = None

        skipped = 0
        for fname in sorted(os.listdir(jsonl_dir)):
            if not fname.endswith('.jsonl'):
                continue
            fpath = os.path.join(jsonl_dir, fname)
            with open(fpath, encoding="utf-8") as f:
                for line in f:
                    try:
                        item = json.loads(line.strip())
                    except json.JSONDecodeError:
                        continue
                    # A valid JSON line that is not an object (list, number,
                    # ...) has no image_path; treat it as a skipped sample
                    # instead of letting AttributeError abort the whole load.
                    img_path = item.get("image_path") if isinstance(item, dict) else None
                    if img_path and os.path.exists(img_path):
                        self.samples.append(item)
                    else:
                        skipped += 1
        if skipped > 0:
            print(f" [Stage 3 data] Skipped {skipped} samples without images")

    def __len__(self):
        return len(self.samples)

    def _get_transform(self):
        """Return the cached resize/to-tensor/normalize pipeline."""
        if self._transform is None:
            from torchvision import transforms

            self._transform = transforms.Compose([
                transforms.Resize((self.img_size, self.img_size)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ])
        return self._transform

    @staticmethod
    def _extract_qa(item):
        """Pull a (question, answer) pair out of the supported record layouts."""
        question = ""
        answer = ""
        # "conversations" list: first entry is the question, second the answer.
        if "conversations" in item:
            convos = item["conversations"]
            if isinstance(convos, list) and len(convos) >= 2:
                question = convos[0].get("value", "") if isinstance(convos[0], dict) else str(convos[0])
                answer = convos[1].get("value", "") if isinstance(convos[1], dict) else str(convos[1])
        # Flat "question"/"answer" style records.
        if not question:
            question = item.get("question", item.get("text", "What do you see?"))
        if not answer:
            answer = item.get("answer", item.get("multiple_choice_answer", ""))
        if not answer and "answers" in item:
            answers = item["answers"]
            if isinstance(answers, list) and answers:
                if isinstance(answers[0], dict):
                    answer = answers[0].get("answer", "")
                else:
                    answer = str(answers[0])
        if not answer:
            answer = "unknown"
        return question, answer

    def __getitem__(self, idx):
        """Return one sample as a dict of tensors.

        Raises:
            FileNotFoundError: if the image has disappeared since
                construction or cannot be decoded/transformed.
        """
        item = self.samples[idx]
        question, answer = self._extract_qa(item)

        image_path = item.get("image_path")
        if not (image_path and os.path.exists(image_path)):
            raise FileNotFoundError(
                f"Image not found: {image_path}. Data may be corrupted or incomplete."
            )
        try:
            from PIL import Image as PILImage

            pil_img = PILImage.open(image_path).convert("RGB")
            image = self._get_transform()(pil_img)
        except Exception as e:
            # Same exception type as before; chained so the root cause survives.
            raise FileNotFoundError(
                f"Image not found or corrupted: {image_path}. Original error: {e}"
            ) from e

        q_ids = self._pad(self.tokenizer.encode(str(question)), self.max_q)
        a_ids = self._pad(self.tokenizer.encode(str(answer)), self.max_a)
        q_tensor = torch.tensor(q_ids, dtype=torch.long)
        a_tensor = torch.tensor(a_ids, dtype=torch.long)
        return {
            "image": image,
            "question_ids": q_tensor,
            "question_mask": (q_tensor != self.tokenizer.pad_id).long(),
            "answer_ids": a_tensor,
            "answer_mask": (a_tensor != self.tokenizer.pad_id).long(),
        }

    def _pad(self, ids, max_len):
        """Truncate ``ids`` to ``max_len`` or right-pad it with ``pad_id``."""
        if len(ids) > max_len:
            return ids[:max_len]
        return ids + [self.tokenizer.pad_id] * (max_len - len(ids))


class _TokenizedWrapper(Dataset):
    """Wrap a dataset that may yield raw strings and tokenize on the fly.

    Items that already contain ``"question_ids"`` are passed through
    unchanged. Items without an image become an all-padding placeholder
    with zero masks.
    """

    def __init__(self, ds, tokenizer, img_size=448, max_q=64, max_a=128):
        self.ds = ds
        self.tokenizer = tokenizer
        self.img_size = img_size
        self.max_q = max_q
        self.max_a = max_a
        self.vocab_size = tokenizer.vocab_size if hasattr(tokenizer, 'vocab_size') else 32768

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        item = self.ds[idx]
        # Already tokenized upstream: nothing to do.
        if "question_ids" in item:
            return item

        pad_id = getattr(self.tokenizer, 'pad_id', 0)
        image = item.get("image")
        if image is None:
            # Text-only sample: emit a fixed-shape placeholder so batching
            # still works. NOTE(review): nothing in this file filters these
            # out; whether they contribute zero loss depends on how the model
            # uses the masks -- confirm.
            # (No function-local ``import torch`` here: that made ``torch`` a
            # local name and broke the normal path below with
            # UnboundLocalError.)
            return {
                "image": torch.zeros(3, self.img_size, self.img_size),
                "question_ids": torch.full((self.max_q,), pad_id, dtype=torch.long),
                "question_mask": torch.zeros(self.max_q, dtype=torch.long),
                "answer_ids": torch.full((self.max_a,), pad_id, dtype=torch.long),
                "answer_mask": torch.zeros(self.max_a, dtype=torch.long),
            }

        question = str(item.get("question", "What do you see?"))
        answer = str(item.get("answer", "unknown"))
        q_ids = self.tokenizer.encode(question)
        a_ids = self.tokenizer.encode(answer)
        # Truncate, then right-pad to the fixed length.
        q_ids = (q_ids[:self.max_q] + [pad_id] * self.max_q)[:self.max_q]
        a_ids = (a_ids[:self.max_a] + [pad_id] * self.max_a)[:self.max_a]
        q_tensor = torch.tensor(q_ids, dtype=torch.long)
        a_tensor = torch.tensor(a_ids, dtype=torch.long)
        return {
            "image": image,
            "question_ids": q_tensor,
            "question_mask": (q_tensor != pad_id).long(),
            "answer_ids": a_tensor,
            "answer_mask": (a_tensor != pad_id).long(),
        }


def build_stage3_dataset(config: dict, tokenizer) -> Dataset:
    """Build the Stage 3 domain fine-tuning dataset.

    Sources are tried in order:

    1. Local JSONL files under ``data/downloads/stage3`` (used only when more
       than 100 samples with an existing image are found).
    2. ``data.multi_dataset.build_stage3_dataset``, wrapped so raw-text items
       are tokenized on access.

    Args:
        config: Parsed YAML config; ``config["vision"]["img_size"]`` is read.
        tokenizer: Object providing ``encode(str) -> list[int]`` and
            ``pad_id`` (and optionally ``vocab_size``).

    Returns:
        A ``torch.utils.data.Dataset`` yielding dicts with ``image``,
        ``question_ids``, ``question_mask``, ``answer_ids``, ``answer_mask``.

    Raises:
        RuntimeError: if neither source yields usable data.
    """
    img_size = config["vision"]["img_size"]

    # 1) Local JSONL files produced by the project's download script.
    jsonl_dir = "data/downloads/stage3"
    if os.path.exists(jsonl_dir):
        try:
            dataset = _LocalJSONLDataset(jsonl_dir, tokenizer, img_size)
        except Exception as e:
            # Best effort: fall through to the next source.
            print(f" [WARN] Local JSONL loading failed: {e}")
        else:
            if len(dataset) > 100:
                print(f" [REAL DATA] Local JSONL: {len(dataset)} samples from {jsonl_dir}")
                return dataset
            # Previously this case fell through without any message.
            print(f" [WARN] Local JSONL has only {len(dataset)} usable samples "
                  f"in {jsonl_dir} (need > 100); trying multi_dataset")

    # 2) data/multi_dataset.py, with on-the-fly tokenization.
    try:
        from data.multi_dataset import build_stage3_dataset as _build

        raw_dataset = _build(config, tokenizer)
        if len(raw_dataset) > 0:
            dataset = _TokenizedWrapper(raw_dataset, tokenizer, img_size)
            print(f" [REAL DATA] multi_dataset: {len(dataset)} samples (with tokenization wrapper)")
            return dataset
    except Exception as e:  # includes ImportError when the module is absent
        print(f" [WARN] multi_dataset loading failed: {e}")

    raise RuntimeError(
        "FATAL: No Stage 3 training data found.\n"
        "Download real data first: python3 scripts/download_all_data.py --stage 3\n"
        "Required: data/downloads/stage3/ with JSONL files"
    )
# ---------------------------------------------------------------------------
# LR scheduler
# ---------------------------------------------------------------------------
class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Linear warmup from 0 to the base LR, then cosine decay to ``min_lr``.

    Args:
        optimizer: Wrapped optimizer.
        warmup_steps: Number of scheduler steps spent ramping up linearly.
        total_steps: Step index at which the cosine decay reaches ``min_lr``.
        min_lr: Floor of the cosine decay.
        last_epoch: Index of the last step (``-1`` starts fresh).

    Once ``total_steps`` is exceeded the LR stays at ``min_lr``. Without that
    clamp the cosine is periodic and the LR would climb back toward the base
    LR, e.g. when a run is resumed with more epochs than the restored
    ``total_steps`` accounts for.
    """

    def __init__(self, optimizer, warmup_steps: int, total_steps: int,
                 min_lr: float = 1e-7, last_epoch: int = -1):
        # Attributes must exist before super().__init__, which performs the
        # initial step and therefore calls get_lr().
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the LR for every param group at the current step."""
        step = self.last_epoch
        if step < self.warmup_steps:
            scale = step / max(1, self.warmup_steps)
            return [base_lr * scale for base_lr in self.base_lrs]

        decay_steps = max(1, self.total_steps - self.warmup_steps)
        # Clamp so the schedule holds at min_lr past total_steps.
        progress = min(1.0, (step - self.warmup_steps) / decay_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return [self.min_lr + (base_lr - self.min_lr) * cosine for base_lr in self.base_lrs]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def setup_distributed():
    """Join the NCCL process group and bind this process to its GPU.

    The GPU index is taken from the ``LOCAL_RANK`` environment variable.

    Returns:
        The local rank as an ``int``.
    """
    dist.init_process_group(backend="nccl")
    gpu_index = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(gpu_index)
    return gpu_index
def cleanup():
    """Tear down the process group; a no-op when none was initialised."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
def is_rank0():
    """Return True on global rank 0, or when no process group is initialised."""
    if not dist.is_initialized():
        return True
    return dist.get_rank() == 0
def log(msg: str):
    """Print ``msg`` (flushed immediately) on rank 0 only; other ranks stay silent."""
    if not is_rank0():
        return
    print(msg, flush=True)
def save_checkpoint(model, optimizer, scheduler, epoch, global_step, loss, path):
    """Write a training checkpoint to ``path`` (rank 0 only).

    Args:
        model: The model; if it exposes ``.module`` (e.g. a DDP wrapper) the
            inner module's state dict is saved.
        optimizer: Optimizer whose state is saved.
        scheduler: LR scheduler whose state is saved.
        epoch: Number of completed epochs.
        global_step: Number of optimizer steps taken so far.
        loss: Loss value recorded with the checkpoint.
        path: Destination file. A bare filename (no directory part) is
            written to the current working directory.
    """
    if not is_rank0():
        return

    # os.path.dirname("ckpt.pt") is "", and os.makedirs("") raises, so only
    # create the parent when there is one.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    state_dict = model.module.state_dict() if hasattr(model, "module") else model.state_dict()

    # Write to a sibling temp file and rename it into place, so an interrupted
    # save never leaves a truncated file at ``path`` for --resume to pick up.
    tmp_path = f"{path}.tmp"
    torch.save({
        "epoch": epoch,
        "global_step": global_step,
        "model_state_dict": state_dict,
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "loss": loss,
    }, tmp_path)
    os.replace(tmp_path, path)
    log(f" Checkpoint saved: {path}")
def push_checkpoints():
    """Placeholder for pushing checkpoints to GitHub LFS; intentionally does nothing.

    Pushing is disabled during training to avoid git lock issues. Run
    scripts/push_checkpoints.py manually once training has finished.
    """
    # Disabled -- run push_checkpoints.py separately after training.
    return None
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
class _Stage2ForwardAdapter(nn.Module):
    """Expose ``model.forward_stage2`` as ``forward`` so DDP can wrap it.

    DDP only arms its gradient all-reduce when the wrapper's own ``forward``
    is invoked. Calling ``ddp.module.forward_stage2(...)`` directly skips
    that, so gradients are never synchronised across ranks.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, **kwargs):
        return self.model.forward_stage2(**kwargs)


def main():
    """Run Stage 3 domain fine-tuning under DDP.

    Flow: parse CLI args, initialise the process group, read the YAML
    config, build tokenizer/dataset/loader, load the Stage 2 weights,
    optionally resume a Stage 3 checkpoint, then train with gradient
    accumulation and write one checkpoint per epoch plus a final one.

    Behavioural notes:
      * The forward pass goes through the DDP wrapper so gradients are
        all-reduced; non-stepping micro-batches run under ``no_sync()``.
      * The last batch of an epoch always triggers an optimizer step, so the
        number of steps per epoch equals ``ceil(len(loader) / grad_accum)``,
        which is what the scheduler length assumes.
      * Checkpoints contain the bare model's state dict (no wrapper prefix).

    Raises:
        ValueError: if the configured batch size is smaller than the world
            size or ``gradient_accumulation`` is < 1.
    """
    from contextlib import nullcontext

    parser = argparse.ArgumentParser(description="Stage 3 DDP: Domain Fine-Tuning")
    parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
    parser.add_argument("--stage2_ckpt", type=str, default="checkpoints/stage2_final.pt",
                        help="Path to Stage 2 checkpoint")
    parser.add_argument("--resume", type=str, default=None,
                        help="Path to Stage 3 checkpoint to resume from")
    args = parser.parse_args()

    # ---- Distributed setup ----
    local_rank = setup_distributed()
    world_size = dist.get_world_size()
    global_rank = dist.get_rank()
    device = torch.device(f"cuda:{local_rank}")

    # ---- Config ----
    with open(args.config) as f:
        config = yaml.safe_load(f)
    stage_cfg = config["train_stage3"]
    per_gpu_batch = stage_cfg["batch_size"] // world_size
    if per_gpu_batch < 1:
        raise ValueError(
            f"train_stage3.batch_size ({stage_cfg['batch_size']}) must be at least "
            f"the world size ({world_size})"
        )
    grad_accum = stage_cfg.get("gradient_accumulation", 4)
    if grad_accum < 1:
        raise ValueError(f"train_stage3.gradient_accumulation must be >= 1, got {grad_accum}")
    max_epochs = stage_cfg["max_epochs"]
    lr = stage_cfg["learning_rate"]  # Very low: 5e-6
    warmup_steps = stage_cfg["warmup_steps"]
    grad_clip = stage_cfg["gradient_clip"]
    lb_weight = stage_cfg["load_balance_weight"]

    log("=" * 70)
    log("ArcisVLM — Stage 3: Domain Fine-Tuning (DDP)")
    log("=" * 70)
    log(f" World size: {world_size}")
    log(f" Global batch: {stage_cfg['batch_size']}")
    log(f" Per-GPU batch: {per_gpu_batch}")
    log(f" Gradient accumulation:{grad_accum}")
    log(f" Effective batch: {stage_cfg['batch_size'] * grad_accum}")
    log(f" Max epochs: {max_epochs}")
    log(f" Learning rate: {lr}")
    log(f" Warmup steps: {warmup_steps}")
    log(f" Load balance weight: {lb_weight}")
    log(f" Precision: {stage_cfg.get('precision', 'bf16')}")

    # ---- Tokenizer ----
    tokenizer = BPETokenizer(vocab_size=config["decoder"]["vocab_size"])
    for tok_path in ["checkpoints/tokenizer_32k.json", "checkpoints/tokenizer.json"]:
        if os.path.exists(tok_path):
            tokenizer.load(tok_path)
            log(f" Tokenizer: {len(tokenizer)} tokens (from {tok_path})")
            break
    else:
        log(" [WARN] No tokenizer found — using untrained tokenizer")

    # ---- Dataset ----
    dataset = build_stage3_dataset(config, tokenizer)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=global_rank, shuffle=True)
    loader = DataLoader(
        dataset,
        batch_size=per_gpu_batch,
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    )
    num_batches = len(loader)
    log(f" Dataset: {len(dataset)} samples, {num_batches} batches/GPU")

    # ---- Model + Stage 2 checkpoint ----
    base_model = VLJEPAModel(config).to(device)
    if os.path.exists(args.stage2_ckpt):
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load checkpoints from a trusted source.
        ckpt = torch.load(args.stage2_ckpt, map_location=device, weights_only=False)
        base_model.load_state_dict(ckpt["model_state_dict"])
        log(f" Loaded Stage 2 ckpt: {args.stage2_ckpt} (epoch {ckpt['epoch']}, loss {ckpt['loss']:.4f})")
    else:
        log(f" [WARN] Stage 2 checkpoint not found: {args.stage2_ckpt} — training from scratch")

    # Stage 3: all parameters trainable (unfrozen) but at very low LR
    base_model.unfreeze_all()
    if is_rank0():
        params = base_model.count_parameters()
        for k, v in params.items():
            log(f" {k}: {v:,}")

    # ---- Optimizer ----
    optimizer = torch.optim.AdamW(base_model.parameters(), lr=lr, weight_decay=0.01)

    # ---- Scheduler ----
    steps_per_epoch = math.ceil(num_batches / grad_accum)
    total_steps = max_epochs * steps_per_epoch
    scheduler = CosineWarmupScheduler(optimizer, warmup_steps=warmup_steps, total_steps=total_steps)

    # ---- Mixed precision ----
    use_bf16 = stage_cfg.get("precision", "bf16") == "bf16"
    scaler = torch.amp.GradScaler("cuda", enabled=(not use_bf16))
    autocast_dtype = torch.bfloat16 if use_bf16 else torch.float16

    # ---- Resume ----
    start_epoch = 0
    global_step = 0
    # Defined up front so the final checkpoint/log still work when the epoch
    # loop does not run (e.g. resuming a run that already reached max_epochs).
    avg_loss = float("nan")
    avg_decode = float("nan")
    if args.resume:
        if os.path.exists(args.resume):
            # NOTE(review): same pickle caveat as the Stage 2 load above.
            ckpt = torch.load(args.resume, map_location=device, weights_only=False)
            base_model.load_state_dict(ckpt["model_state_dict"])
            optimizer.load_state_dict(ckpt["optimizer_state_dict"])
            if "scheduler_state_dict" in ckpt:
                scheduler.load_state_dict(ckpt["scheduler_state_dict"])
            start_epoch = ckpt["epoch"]
            global_step = ckpt.get("global_step", start_epoch * steps_per_epoch)
            log(f" Resumed from {args.resume} (epoch {start_epoch}, loss {ckpt['loss']:.4f})")
            avg_loss = ckpt["loss"]
        else:
            log(f" [WARN] Resume checkpoint not found: {args.resume} — starting Stage 3 from epoch 0")

    # ---- DDP wrap ----
    # Wrap the adapter (not the bare model) so the training forward runs
    # through DDP.forward and gradients are actually all-reduced.
    # NOTE(review): find_unused_parameters=True is the defensive choice now
    # that DDP's reducer is active -- with False, any parameter that
    # forward_stage2 does not touch makes DDP raise. Switch back to False if
    # every trainable parameter is known to receive a gradient each step.
    model = DDP(_Stage2ForwardAdapter(base_model), device_ids=[local_rank],
                output_device=local_rank, find_unused_parameters=True)

    # ---- Training loop ----
    model.train()
    os.makedirs("checkpoints", exist_ok=True)
    for epoch in range(start_epoch, max_epochs):
        sampler.set_epoch(epoch)
        epoch_loss = 0.0
        epoch_decode_loss = 0.0
        epoch_lb_loss = 0.0
        epoch_steps = 0
        epoch_start = time.time()
        optimizer.zero_grad(set_to_none=True)

        for batch_idx, batch in enumerate(loader):
            images = batch["image"].to(device, non_blocking=True)
            q_ids = batch["question_ids"].to(device, non_blocking=True)
            q_mask = batch["question_mask"].to(device, non_blocking=True)
            a_ids = batch["answer_ids"].to(device, non_blocking=True)

            # Step every grad_accum batches, and also on the last batch so a
            # trailing partial group is applied instead of being discarded.
            is_update_step = ((batch_idx + 1) % grad_accum == 0
                              or (batch_idx + 1) == num_batches)
            # Skip the all-reduce on micro-batches that only accumulate; the
            # context must cover both forward and backward.
            sync_ctx = nullcontext() if is_update_step else model.no_sync()
            with sync_ctx:
                with torch.amp.autocast("cuda", dtype=autocast_dtype):
                    output = model(
                        images=images,
                        query_ids=q_ids,
                        query_padding_mask=q_mask,
                        answer_ids=a_ids,
                        load_balance_weight=lb_weight,
                    )
                loss = output["loss"] / grad_accum
                if use_bf16:
                    loss.backward()
                else:
                    scaler.scale(loss).backward()

            epoch_loss += output["loss"].item()
            epoch_decode_loss += output["decode_loss"].item()
            epoch_lb_loss += output["load_balance_loss"].item()
            epoch_steps += 1

            if is_update_step:
                if use_bf16:
                    torch.nn.utils.clip_grad_norm_(base_model.parameters(), grad_clip)
                    optimizer.step()
                else:
                    # Unscale first so clipping sees true gradient magnitudes.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(base_model.parameters(), grad_clip)
                    scaler.step(optimizer)
                    scaler.update()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                global_step += 1

                if global_step % 50 == 0:
                    current_lr = scheduler.get_last_lr()[0]
                    gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
                    log(f" [Step {global_step}] loss={output['loss'].item():.4f} "
                        f"decode={output['decode_loss'].item():.4f} "
                        f"lb={output['load_balance_loss'].item():.4f} "
                        f"lr={current_lr:.2e} GPU={gpu_mem:.1f}GB")

        # ---- Epoch summary ----
        # Sum the per-rank totals and batch counts, then average globally.
        metrics = torch.tensor([epoch_loss, epoch_decode_loss, epoch_lb_loss, epoch_steps],
                               device=device, dtype=torch.float64)
        dist.all_reduce(metrics, op=dist.ReduceOp.SUM)
        avg_loss = (metrics[0] / metrics[3]).item()
        avg_decode = (metrics[1] / metrics[3]).item()
        avg_lb = (metrics[2] / metrics[3]).item()
        epoch_time = time.time() - epoch_start
        current_lr = scheduler.get_last_lr()[0]
        gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
        log(f"\nEpoch {epoch + 1}/{max_epochs}: loss={avg_loss:.4f} decode={avg_decode:.4f} "
            f"lb={avg_lb:.4f} lr={current_lr:.2e} time={epoch_time:.0f}s GPU={gpu_mem:.1f}GB")

        # ---- Checkpoint every epoch ----
        # Pass the bare model so the saved keys carry no wrapper prefix.
        ckpt_path = f"checkpoints/stage3_epoch{epoch + 1}.pt"
        save_checkpoint(base_model, optimizer, scheduler, epoch + 1, global_step, avg_loss, ckpt_path)
        push_checkpoints()
        dist.barrier()

    # ---- Final checkpoint ----
    save_checkpoint(base_model, optimizer, scheduler, max_epochs, global_step, avg_loss,
                    "checkpoints/stage3_final.pt")
    push_checkpoints()
    log("\n" + "=" * 70)
    log(f"Stage 3 complete. Final loss: {avg_loss:.4f} decode: {avg_decode:.4f}")
    log("=" * 70)
    cleanup()
# Script entry point: run training only when executed directly (e.g. via torchrun).
if __name__ == "__main__":
    main()