# arcisvlm / scripts /train_stage2_gpu.py
# Hardik Sanghvi
# feat: integrate Gemma 4 E2B backbone for production-quality VLM inference
# 7a564e3
# Raw
# History Blame Contribute Delete
# 7.37 kB
"""Stage 2: MoE Decoder Supervised Finetuning on GPU.
Loads Stage 1 checkpoint, freezes X-Encoder, trains Predictor + MoE Decoder on VQA data.
Uses templated placeholder Q&A pairs over the Flickr8k images; loading real VQAv2 data is not implemented in this script.
"""
import yaml
import torch
import os
import time
import json
from model.vlm import VLJEPAModel
from model.tokenizer import BPETokenizer
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
class SimpleVQADataset(Dataset):
    """Image dataset paired with templated question/answer text.

    Every JPEG found directly inside ``image_dir`` becomes one sample. The
    question/answer text is not derived from the image content: sample ``i``
    simply receives template ``i % len(qa_templates)``, so the pairs are
    placeholders that exercise the training pipeline.

    Args:
        image_dir: Directory listed (non-recursively) for image files.
        tokenizer: Object exposing ``encode(text)`` (a sequence of token ids)
            and a ``pad_id`` attribute.
        img_size: Images are resized to ``img_size x img_size``.
        max_q: Fixed token length of the question tensors.
        max_a: Fixed token length of the answer tensors.
    """

    # Extensions accepted as images; matched case-insensitively so files such
    # as ``photo.JPG`` or ``photo.jpeg`` are not silently skipped.
    _IMAGE_EXTENSIONS = (".jpg", ".jpeg")

    def __init__(self, image_dir, tokenizer, img_size=384, max_q=64, max_a=32):
        self.tokenizer = tokenizer
        self.max_q = max_q
        self.max_a = max_a
        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            # Per-channel mean / std (the widely used ImageNet statistics).
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # Sorted so that the image -> template assignment below is stable
        # across runs regardless of the order os.listdir happens to return.
        self.images = sorted(
            os.path.join(image_dir, name)
            for name in os.listdir(image_dir)
            if name.lower().endswith(self._IMAGE_EXTENSIONS)
        )
        # Generate diverse VQA pairs for training
        self.qa_templates = [
            ("What do you see in this image?", "objects and scene"),
            ("Describe this image.", "a visual scene"),
            ("What is happening here?", "activity in scene"),
            ("How many objects are there?", "several"),
            ("What colors are visible?", "various colors"),
            ("Is there a person in this image?", "possibly"),
            ("What is the main subject?", "the central object"),
            ("What is in the background?", "background elements"),
            ("Is this indoor or outdoor?", "a scene"),
            ("What time of day is it?", "daytime"),
        ]
        # One (path, question, answer) triple per image, cycling the templates.
        self.samples = [
            (path, *self.qa_templates[i % len(self.qa_templates)])
            for i, path in enumerate(self.images)
        ]

    def __len__(self) -> int:
        """Return the number of (image, question, answer) samples."""
        return len(self.samples)

    def _pad(self, ids, max_len: int):
        """Truncate or right-pad ``ids`` to exactly ``max_len`` entries.

        Returns:
            ``(padded_ids, mask)`` as two lists of length ``max_len``; ``mask``
            is ``True`` at real-token positions and ``False`` at padding.
        """
        # list(...) so that a tokenizer returning a tuple (or other sequence)
        # still supports the list concatenation below.
        ids = list(ids)[:max_len]
        n_pad = max_len - len(ids)
        mask = [True] * len(ids) + [False] * n_pad
        return ids + [self.tokenizer.pad_id] * n_pad, mask

    def __getitem__(self, idx: int):
        """Load sample ``idx`` and return it as a dict of tensors.

        Keys: ``image`` (transformed image), ``question_ids`` / ``answer_ids``
        (long tensors of length ``max_q`` / ``max_a``) and ``question_mask`` /
        ``answer_mask`` (bool tensors, ``True`` on real tokens).
        """
        img_path, question, answer = self.samples[idx]
        # Context manager closes the underlying file deterministically;
        # convert() returns an independent in-memory copy.
        with Image.open(img_path) as raw:
            image = self.transform(raw.convert("RGB"))
        q_ids, q_mask = self._pad(self.tokenizer.encode(question), self.max_q)
        a_ids, a_mask = self._pad(self.tokenizer.encode(answer), self.max_a)
        return {
            "image": image,
            "question_ids": torch.tensor(q_ids, dtype=torch.long),
            "question_mask": torch.tensor(q_mask, dtype=torch.bool),
            "answer_ids": torch.tensor(a_ids, dtype=torch.long),
            "answer_mask": torch.tensor(a_mask, dtype=torch.bool),
        }
def _save_checkpoint(model, epoch, loss, path):
    """Save the model weights together with the epoch number and its mean loss."""
    torch.save({"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss}, path)


def _train_one_epoch(model, loader, optimizer, scheduler, device, params, lb_weight, grad_clip):
    """Run one optimisation pass over ``loader``.

    Returns:
        Tuple ``(mean_loss, mean_decode_loss, mean_load_balance_loss)`` averaged
        over the batches of this epoch.
    """
    sum_loss = sum_decode = sum_lb = 0.0
    steps = 0
    for batch in loader:
        # NOTE(review): the dataset's masks are True on real tokens, yet the
        # argument is named query_padding_mask -- confirm forward_stage2
        # expects that polarity.
        # NOTE(review): batch["answer_mask"] is never passed on; confirm the
        # model excludes padded answer positions from the loss.
        output = model.forward_stage2(
            images=batch["image"].to(device),
            query_ids=batch["question_ids"].to(device),
            query_padding_mask=batch["question_mask"].to(device),
            answer_ids=batch["answer_ids"].to(device),
            load_balance_weight=lb_weight,
        )
        loss = output["loss"]
        optimizer.zero_grad()
        loss.backward()
        # Clip exactly the parameters the optimizer updates.
        torch.nn.utils.clip_grad_norm_(params, grad_clip)
        optimizer.step()
        scheduler.step()  # the cosine schedule is stepped per batch, not per epoch
        sum_loss += loss.item()
        sum_decode += output["decode_loss"].item()
        sum_lb += output["load_balance_loss"].item()
        steps += 1
    return sum_loss / steps, sum_decode / steps, sum_lb / steps


def main():
    """Run Stage 2 finetuning end to end on a single CUDA device.

    Reads ``configs/default.yaml``, loads the tokenizer and (if present) the
    Stage 1 checkpoint, freezes the X-Encoder, trains on ``SimpleVQADataset``
    built from ``data/flickr8k/Images``, writes a checkpoint every 5 epochs and
    at the end, then prints one sample generation.

    Raises:
        RuntimeError: If CUDA is unavailable or no training images are found.
    """
    with open("configs/default.yaml") as f:
        config = yaml.safe_load(f)
    # Same dict object as config["train_stage2"], so the overrides below are
    # also visible to anything that receives ``config``.
    stage2_cfg = config["train_stage2"]
    stage2_cfg["batch_size"] = 4  # RTX 3090
    stage2_cfg["max_epochs"] = 15

    # Fail early with a clear message instead of deep inside a CUDA call.
    if not torch.cuda.is_available():
        raise RuntimeError("Stage 2 GPU training requires CUDA, but no CUDA device is available.")
    device = torch.device("cuda")
    print(f"Device: {device}")
    print(f"GPU: {torch.cuda.get_device_name()}")

    tokenizer = BPETokenizer(vocab_size=config["decoder"]["vocab_size"])
    tokenizer.load("checkpoints/tokenizer.json")
    print(f"Tokenizer: {len(tokenizer)} tokens")

    # Load model and Stage 1 checkpoint
    model = VLJEPAModel(config).to(device)
    stage1_ckpt = "checkpoints/stage1_final.pt"
    if os.path.exists(stage1_ckpt):
        # torch.load unpickles the file; only load checkpoints from a trusted source.
        ckpt = torch.load(stage1_ckpt, map_location=device)
        model.load_state_dict(ckpt["model_state_dict"])
        print(f"Loaded Stage 1 checkpoint (epoch {ckpt['epoch']}, loss {ckpt['loss']:.4f})")
    else:
        print("WARNING: No Stage 1 checkpoint found. Training from scratch.")

    # Freeze X-Encoder for Stage 2
    model.freeze_x_encoder()
    params = model.count_parameters()
    print(f"Total params: {params['total']:,}")
    print(f"Trainable params: {params['trainable']:,}")

    # Dataset
    image_dir = "data/flickr8k/Images"
    dataset = SimpleVQADataset(image_dir, tokenizer, img_size=config["vision"]["img_size"])
    if len(dataset) == 0:
        # Without this check the failure surfaces later as an unrelated
        # sampler / division error.
        raise RuntimeError(f"No training images found in {image_dir!r}.")
    loader = DataLoader(dataset, batch_size=stage2_cfg["batch_size"],
                        shuffle=True, num_workers=4, pin_memory=True)
    print(f"Dataset: {len(dataset)} VQA samples, {len(loader)} batches")

    # Optimizer (only trainable params)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=stage2_cfg["learning_rate"], weight_decay=0.01)
    max_epochs = stage2_cfg["max_epochs"]
    total_steps = max_epochs * len(loader)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)

    model.train()
    lb_weight = stage2_cfg["load_balance_weight"]
    grad_clip = stage2_cfg["gradient_clip"]
    avg_loss = float("nan")  # keeps the final save/report well-defined if no epoch runs
    start = time.time()
    for epoch in range(max_epochs):
        avg_loss, avg_decode, avg_lb = _train_one_epoch(
            model, loader, optimizer, scheduler, device,
            trainable_params, lb_weight, grad_clip,
        )
        elapsed = time.time() - start
        gpu_mem = torch.cuda.max_memory_allocated() / 1e9
        print(f"Epoch {epoch+1}/{max_epochs}: loss={avg_loss:.4f} decode={avg_decode:.4f} lb={avg_lb:.4f} | {elapsed:.0f}s | GPU: {gpu_mem:.1f}GB", flush=True)
        if (epoch + 1) % 5 == 0:
            ckpt_path = f"checkpoints/stage2_epoch{epoch+1}.pt"
            _save_checkpoint(model, epoch + 1, avg_loss, ckpt_path)
            print(f" Saved {ckpt_path}", flush=True)

    # Final save
    _save_checkpoint(model, max_epochs, avg_loss, "checkpoints/stage2_final.pt")

    # Test generation
    model.eval()
    with torch.no_grad():
        sample = dataset[0]
        img = sample["image"].unsqueeze(0).to(device)
        q = sample["question_ids"].unsqueeze(0).to(device)
        qm = sample["question_mask"].unsqueeze(0).to(device)
        tokens = model.generate(img, q, qm, max_new_tokens=20)
        text = tokenizer.decode(tokens[0].tolist())
    print("\nTest generation:")
    print(f" Q: {dataset.samples[0][1]}")
    print(f" A: '{text}'")

    total_time = time.time() - start
    print(f"\nStage 2 complete. Final loss: {avg_loss:.4f}")
    print(f"Total time: {total_time:.0f}s ({total_time/60:.1f} min)")


if __name__ == "__main__":
    main()