# Source: arcisvlm/scripts/train_stage1_gpu.py (3.19 kB)
# Last commit 7a564e3 by Hardik Sanghvi:
#   feat: integrate Gemma 4 E2B backbone for production-quality VLM inference
"""Stage 1: JEPA Pretraining on GPU (full 371M model)."""
import yaml
import torch
import os
import time
from model.vlm import VLJEPAModel
from model.tokenizer import BPETokenizer
from data.dataset import CaptionDataset
from torch.utils.data import DataLoader
# Load the shared experiment configuration. The encoding is passed explicitly
# so that parsing the YAML file does not depend on the platform's locale
# default.
with open("configs/default.yaml", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Full model config for RTX 3090 (24GB): override the stage-1 batch size and
# epoch count read from the YAML file for this run.
config["train_stage1"]["batch_size"] = 8
config["train_stage1"]["max_epochs"] = 20
# This script is GPU-only: fail fast with a clear message rather than an
# opaque CUDA initialisation error from the first torch.cuda query below.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available; this training script requires a GPU")

device = torch.device("cuda")

# Report name and VRAM for the same GPU: an index-less "cuda" device refers to
# the current CUDA device, so query that index for both values.
gpu_index = torch.cuda.current_device()
gpu_props = torch.cuda.get_device_properties(gpu_index)
print(f"Device: {device}")
print(f"GPU: {torch.cuda.get_device_name(gpu_index)}")
print(f"VRAM: {gpu_props.total_memory / 1e9:.1f} GB")
# Construct the BPE tokenizer with the vocabulary size from the decoder
# config, then call load() on the saved tokenizer file under checkpoints/.
decoder_cfg = config["decoder"]
tokenizer = BPETokenizer(vocab_size=decoder_cfg["vocab_size"])
tokenizer.load("checkpoints/tokenizer.json")

# len(tokenizer) is reported as the token count.
token_count = len(tokenizer)
print(f"Tokenizer: {token_count} tokens")
# Caption dataset over the Flickr8k image directory and captions file; it is
# handed the tokenizer and the image size from the vision config.
dataset_kwargs = {
    "image_dir": "data/flickr8k/Images",
    "captions_file": "data/flickr8k/captions.txt",
    "tokenizer": tokenizer,
    "img_size": config["vision"]["img_size"],
}
dataset = CaptionDataset(**dataset_kwargs)

# Shuffled loader with four worker processes and pinned host memory.
# NOTE(review): with num_workers > 0, platforms that start workers via
# "spawn" need this module-level code under an `if __name__ == "__main__":`
# guard -- confirm the target platform.
loader_kwargs = {
    "batch_size": config["train_stage1"]["batch_size"],
    "shuffle": True,
    "num_workers": 4,
    "pin_memory": True,
}
loader = DataLoader(dataset, **loader_kwargs)

num_samples = len(dataset)
num_batches = len(loader)
print(f"Dataset: {num_samples} samples, {num_batches} batches")
# Build the model from the config, then move it onto the training device.
model = VLJEPAModel(config)
model = model.to(device)

# Print every entry of the mapping returned by count_parameters(), with
# thousands separators.
params = model.count_parameters()
for component, count in params.items():
    print(f" {component}: {count:,}")
# The Y-Encoder trains in its own parameter group whose learning rate is the
# base stage-1 rate scaled by the configured multiplier (intended to be a
# slower rate).
lr = config["train_stage1"]["learning_rate"]
y_lr = lr * config["y_encoder"]["lr_multiplier"]

y_params = list(model.y_encoder.parameters())

# Base group: every trainable parameter whose name starts with neither
# "y_encoder" nor "decoder". Decoder parameters are not given to the optimizer.
excluded_prefixes = ("y_encoder", "decoder")
other_params = []
for param_name, param in model.named_parameters():
    if param_name.startswith(excluded_prefixes):
        continue
    if param.requires_grad:
        other_params.append(param)

param_groups = [
    {"params": other_params, "lr": lr},
    {"params": y_params, "lr": y_lr},
]
optimizer = torch.optim.AdamW(param_groups, weight_decay=0.01)
# Stage-1 training loop: runs max_epochs passes over the loader, prints the
# mean loss per epoch, writes a checkpoint every 5 epochs and a final one at
# the end.
model.train()
os.makedirs("checkpoints", exist_ok=True)
max_epochs = config["train_stage1"]["max_epochs"]

# An empty loader would otherwise only show up as a ZeroDivisionError when the
# first epoch's mean loss is computed; report the real problem up front.
if len(loader) == 0:
    raise RuntimeError("DataLoader produced no batches; check the dataset paths")

# Clip exactly the parameters the optimizer updates (and zeroes). Parameters
# left out of the optimizer never have their gradients reset here, so they
# should not contribute to the clipping norm.
clip_params = [p for group in optimizer.param_groups for p in group["params"]]

# Defined up front so the final save/print still work if max_epochs is 0.
avg = float("nan")
start = time.time()
for epoch in range(max_epochs):
    total_loss = 0.0
    n = 0
    for batch in loader:
        # The loader pins host memory, so these copies can be asynchronous.
        images = batch["image"].to(device, non_blocking=True)
        cap_ids = batch["caption_ids"].to(device, non_blocking=True)
        cap_mask = batch["caption_mask"].to(device, non_blocking=True)

        output = model.forward_stage1(images, None, None, cap_ids, cap_mask)
        loss = output["loss"]

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(clip_params, 1.0)
        optimizer.step()

        total_loss += loss.item()
        n += 1

    avg = total_loss / n
    elapsed = time.time() - start
    # Peak allocation since the start of the process, in GB.
    gpu_mem = torch.cuda.max_memory_allocated() / 1e9
    print(f"Epoch {epoch+1}/{max_epochs}: loss={avg:.4f} | {elapsed:.0f}s | GPU mem: {gpu_mem:.1f}GB", flush=True)

    # Periodic checkpoint every 5th epoch.
    if (epoch + 1) % 5 == 0:
        ckpt = f"checkpoints/stage1_epoch{epoch+1}.pt"
        torch.save({"epoch": epoch + 1, "model_state_dict": model.state_dict(), "loss": avg}, ckpt)
        print(f" Saved {ckpt}", flush=True)

# Final save
torch.save({"epoch": max_epochs, "model_state_dict": model.state_dict(), "loss": avg}, "checkpoints/stage1_final.pt")
total_time = time.time() - start
print(f"\nTraining complete. Final loss: {avg:.4f}")
print(f"Total time: {total_time:.0f}s ({total_time/60:.1f} min)")