"""Stage 2: MoE Decoder Supervised Finetuning on GPU.

Loads the Stage 1 checkpoint (when present), freezes the X-Encoder and trains
the Predictor + MoE Decoder on VQA-style data.

This script always trains on templated ("dummy") question/answer pairs built
over the images in ``data/flickr8k/Images``; no real VQAv2 loader is wired in
here.
"""

import yaml
import torch
import os
import time
import json
from model.vlm import VLJEPAModel
from model.tokenizer import BPETokenizer
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image


class SimpleVQADataset(Dataset):
    """VQA dataset from available images + generated Q&A pairs.

    Every ``.jpg`` file directly inside ``image_dir`` (sorted by path) becomes
    one sample. Sample ``i`` is paired with template ``i % len(qa_templates)``,
    so the question/answer text does not depend on the image content.

    Args:
        image_dir: Directory scanned (non-recursively) for ``.jpg`` files.
        tokenizer: Object exposing ``encode(text)`` and ``pad_id``. ``encode``
            is expected to return a ``list`` of token ids, because the result
            is sliced and concatenated with a list in ``_pad``.
        img_size: Images are resized to ``img_size x img_size``.
        max_q: Fixed token length questions are truncated/padded to.
        max_a: Fixed token length answers are truncated/padded to.

    Raises:
        OSError: Propagated from ``os.listdir`` when ``image_dir`` cannot be
            listed (e.g. it does not exist).
    """

    def __init__(self, image_dir, tokenizer, img_size=384, max_q=64, max_a=32):
        self.tokenizer = tokenizer
        self.max_q = max_q
        self.max_a = max_a
        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # Sorted so the image -> template assignment below is deterministic.
        self.images = sorted([
            os.path.join(image_dir, f)
            for f in os.listdir(image_dir)
            if f.endswith(".jpg")
        ])

        # Generate diverse VQA pairs for training
        self.qa_templates = [
            ("What do you see in this image?", "objects and scene"),
            ("Describe this image.", "a visual scene"),
            ("What is happening here?", "activity in scene"),
            ("How many objects are there?", "several"),
            ("What colors are visible?", "various colors"),
            ("Is there a person in this image?", "possibly"),
            ("What is the main subject?", "the central object"),
            ("What is in the background?", "background elements"),
            ("Is this indoor or outdoor?", "a scene"),
            ("What time of day is it?", "daytime"),
        ]

        # Round-robin assignment of one template per image.
        self.samples = []
        for i, img_path in enumerate(self.images):
            q, a = self.qa_templates[i % len(self.qa_templates)]
            self.samples.append((img_path, q, a))

    def __len__(self):
        """Return the number of (image, question, answer) samples."""
        return len(self.samples)

    def _pad(self, ids, max_len):
        """Truncate ``ids`` to ``max_len`` and right-pad with ``pad_id``.

        Returns:
            ``(ids, mask)``, both of length ``max_len``. ``mask`` is ``True``
            at positions holding real tokens and ``False`` at padding.
        """
        ids = ids[:max_len]
        mask = [True] * len(ids) + [False] * (max_len - len(ids))
        ids = ids + [self.tokenizer.pad_id] * (max_len - len(ids))
        return ids, mask

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as a dict of tensors.

        Returns:
            Dict with keys ``image`` (transformed image tensor),
            ``question_ids`` / ``answer_ids`` (``torch.long``) and
            ``question_mask`` / ``answer_mask`` (``torch.bool``, ``True`` on
            real tokens, see ``_pad``).
        """
        img_path, question, answer = self.samples[idx]
        # Context manager releases the underlying file handle deterministically
        # instead of leaving it to garbage collection.
        with Image.open(img_path) as img:
            image = self.transform(img.convert("RGB"))

        q_ids, q_mask = self._pad(self.tokenizer.encode(question), self.max_q)
        a_ids, a_mask = self._pad(self.tokenizer.encode(answer), self.max_a)

        return {
            "image": image,
            "question_ids": torch.tensor(q_ids, dtype=torch.long),
            "question_mask": torch.tensor(q_mask, dtype=torch.bool),
            "answer_ids": torch.tensor(a_ids, dtype=torch.long),
            "answer_mask": torch.tensor(a_mask, dtype=torch.bool),
        }


def main():
    """Run Stage 2 fine-tuning end to end.

    Reads ``configs/default.yaml``, loads the tokenizer and (if present) the
    Stage 1 checkpoint, trains for a fixed number of epochs, writes a
    checkpoint every 5 epochs plus ``checkpoints/stage2_final.pt``, and prints
    one test generation for the first dataset sample.

    Raises:
        RuntimeError: If no CUDA device is available, or if the image
            directory contains no ``.jpg`` files.
    """
    with open("configs/default.yaml") as f:
        config = yaml.safe_load(f)

    # Hard-coded overrides of the YAML values for this run.
    config["train_stage2"]["batch_size"] = 4  # RTX 3090
    config["train_stage2"]["max_epochs"] = 15

    # Fail fast with a clear message instead of an opaque error from the
    # first CUDA call further down.
    if not torch.cuda.is_available():
        raise RuntimeError(
            "Stage 2 training requires a CUDA-capable GPU, but none is available."
        )
    device = torch.device("cuda")
    print(f"Device: {device}")
    print(f"GPU: {torch.cuda.get_device_name()}")

    tokenizer = BPETokenizer(vocab_size=config["decoder"]["vocab_size"])
    tokenizer.load("checkpoints/tokenizer.json")
    print(f"Tokenizer: {len(tokenizer)} tokens")

    # Load model and Stage 1 checkpoint
    model = VLJEPAModel(config).to(device)
    stage1_ckpt = "checkpoints/stage1_final.pt"
    if os.path.exists(stage1_ckpt):
        # NOTE: torch.load unpickles the file; only load checkpoints that come
        # from a trusted source.
        ckpt = torch.load(stage1_ckpt, map_location=device)
        model.load_state_dict(ckpt["model_state_dict"])
        print(f"Loaded Stage 1 checkpoint (epoch {ckpt['epoch']}, loss {ckpt['loss']:.4f})")
    else:
        print("WARNING: No Stage 1 checkpoint found. Training from scratch.")

    # Freeze X-Encoder for Stage 2
    model.freeze_x_encoder()
    params = model.count_parameters()
    print(f"Total params: {params['total']:,}")
    print(f"Trainable params: {params['trainable']:,}")

    # Dataset
    image_dir = "data/flickr8k/Images"
    dataset = SimpleVQADataset(image_dir, tokenizer, img_size=config["vision"]["img_size"])
    if len(dataset) == 0:
        # Without this check the epoch averages below would divide by zero.
        raise RuntimeError(f"No .jpg images found in {image_dir}; cannot train Stage 2.")
    loader = DataLoader(
        dataset,
        batch_size=config["train_stage2"]["batch_size"],
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )
    print(f"Dataset: {len(dataset)} VQA samples, {len(loader)} batches")

    # Optimizer (only trainable params)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=config["train_stage2"]["learning_rate"],
        weight_decay=0.01,
    )
    # The scheduler is stepped once per batch, so T_max is the total number of
    # optimizer steps over the whole run.
    total_steps = config["train_stage2"]["max_epochs"] * len(loader)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)

    # NOTE(review): train() is called after freeze_x_encoder(); unless the
    # model overrides it, this presumably switches the frozen encoder's
    # submodules (dropout/norm layers) back to training mode too - confirm.
    model.train()
    max_epochs = config["train_stage2"]["max_epochs"]
    lb_weight = config["train_stage2"]["load_balance_weight"]
    start = time.time()

    for epoch in range(max_epochs):
        total_loss = 0
        total_decode = 0
        total_lb = 0
        n = 0

        for batch in loader:
            # NOTE(review): question_mask is True on real tokens (see
            # SimpleVQADataset._pad) but is passed as query_padding_mask;
            # confirm forward_stage2 expects that polarity.
            # NOTE(review): batch["answer_mask"] is produced by the dataset
            # but not passed here - verify padding is handled inside the model.
            output = model.forward_stage2(
                images=batch["image"].to(device),
                query_ids=batch["question_ids"].to(device),
                query_padding_mask=batch["question_mask"].to(device),
                answer_ids=batch["answer_ids"].to(device),
                load_balance_weight=lb_weight,
            )

            loss = output["loss"]
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), config["train_stage2"]["gradient_clip"]
            )
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()
            total_decode += output["decode_loss"].item()
            total_lb += output["load_balance_loss"].item()
            n += 1

        avg_loss = total_loss / n
        avg_decode = total_decode / n
        avg_lb = total_lb / n
        elapsed = time.time() - start
        # Peak allocation since process start (the counter is never reset).
        gpu_mem = torch.cuda.max_memory_allocated() / 1e9

        print(
            f"Epoch {epoch+1}/{max_epochs}: loss={avg_loss:.4f} "
            f"decode={avg_decode:.4f} lb={avg_lb:.4f} | "
            f"{elapsed:.0f}s | GPU: {gpu_mem:.1f}GB",
            flush=True,
        )

        # Periodic checkpoint every 5 epochs.
        if (epoch + 1) % 5 == 0:
            ckpt_path = f"checkpoints/stage2_epoch{epoch+1}.pt"
            torch.save(
                {
                    "epoch": epoch + 1,
                    "model_state_dict": model.state_dict(),
                    "loss": avg_loss,
                },
                ckpt_path,
            )
            print(f" Saved {ckpt_path}", flush=True)

    # Final save
    torch.save(
        {
            "epoch": max_epochs,
            "model_state_dict": model.state_dict(),
            "loss": avg_loss,
        },
        "checkpoints/stage2_final.pt",
    )

    # Test generation
    model.eval()
    with torch.no_grad():
        sample = dataset[0]
        # unsqueeze(0) adds the batch dimension for a single sample.
        img = sample["image"].unsqueeze(0).to(device)
        q = sample["question_ids"].unsqueeze(0).to(device)
        qm = sample["question_mask"].unsqueeze(0).to(device)
        tokens = model.generate(img, q, qm, max_new_tokens=20)
        text = tokenizer.decode(tokens[0].tolist())

    print("\nTest generation:")
    print(f" Q: {dataset.samples[0][1]}")
    print(f" A: '{text}'")

    total_time = time.time() - start
    print(f"\nStage 2 complete. Final loss: {avg_loss:.4f}")
    print(f"Total time: {total_time:.0f}s ({total_time/60:.1f} min)")


if __name__ == "__main__":
    main()