"""Stage 1: JEPA Pretraining on GPU (full 371M model).

Reads ``configs/default.yaml``, overrides the stage-1 batch size and epoch
count for an RTX 3090 (24 GB), then trains every ``VLJEPAModel`` parameter
whose name does not start with ``decoder`` through ``forward_stage1`` on
Flickr8k image/caption pairs. Checkpoints are written to ``checkpoints/``
every ``SAVE_EVERY`` epochs and once more at the end.
"""
import os
import time

import torch
import yaml
from torch.utils.data import DataLoader

from data.dataset import CaptionDataset
from model.tokenizer import BPETokenizer
from model.vlm import VLJEPAModel

CONFIG_PATH = "configs/default.yaml"
CHECKPOINT_DIR = "checkpoints"
SAVE_EVERY = 5  # epochs between periodic checkpoints
GRAD_CLIP_NORM = 1.0  # max global L2 norm over the optimized parameters' grads
WEIGHT_DECAY = 0.01


def _load_config(path=CONFIG_PATH):
    """Load the YAML config and apply the stage-1 overrides for a 24 GB GPU."""
    with open(path, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    # Full model config for RTX 3090 (24GB)
    config["train_stage1"]["batch_size"] = 8
    config["train_stage1"]["max_epochs"] = 20
    return config


def _cuda_device():
    """Return the CUDA device after printing its name and total memory.

    Raises:
        RuntimeError: if CUDA is unavailable (this script only runs on GPU).
    """
    if not torch.cuda.is_available():
        raise RuntimeError(
            "Stage 1 pretraining requires a CUDA GPU, but none is available."
        )
    device = torch.device("cuda")
    print(f"Device: {device}")
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    return device


def _build_loader(config, tokenizer):
    """Build the shuffled Flickr8k caption DataLoader used for stage 1."""
    dataset = CaptionDataset(
        image_dir="data/flickr8k/Images",
        captions_file="data/flickr8k/captions.txt",
        tokenizer=tokenizer,
        img_size=config["vision"]["img_size"],
    )
    loader = DataLoader(
        dataset,
        batch_size=config["train_stage1"]["batch_size"],
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )
    print(f"Dataset: {len(dataset)} samples, {len(loader)} batches")
    return loader


def _build_optimizer(model, config):
    """Create AdamW over the trainable non-decoder params; Y-Encoder gets a scaled LR.

    Returns:
        (optimizer, params) where ``params`` is the flat list of parameters
        the optimizer updates (used for gradient clipping).
    """
    # Y-Encoder gets slower learning rate. Both groups filter on requires_grad
    # so they are selected consistently.
    y_params = [p for p in model.y_encoder.parameters() if p.requires_grad]
    other_params = [
        p
        for n, p in model.named_parameters()
        if p.requires_grad
        and not n.startswith("y_encoder")
        and not n.startswith("decoder")
    ]
    lr = config["train_stage1"]["learning_rate"]
    y_lr = lr * config["y_encoder"]["lr_multiplier"]
    groups = [
        {"params": other_params, "lr": lr},
        {"params": y_params, "lr": y_lr},
    ]
    # Drop empty groups (e.g. a fully frozen Y-Encoder).
    optimizer = torch.optim.AdamW(
        [g for g in groups if g["params"]], weight_decay=WEIGHT_DECAY
    )
    return optimizer, other_params + y_params


def _train_one_epoch(model, loader, optimizer, params, device):
    """Run one pass over ``loader``; return the mean per-batch loss (NaN if no batches)."""
    total_loss = 0.0
    num_batches = 0
    for batch in loader:
        # The loader pins host memory, so these copies can run asynchronously.
        images = batch["image"].to(device, non_blocking=True)
        cap_ids = batch["caption_ids"].to(device, non_blocking=True)
        cap_mask = batch["caption_mask"].to(device, non_blocking=True)

        output = model.forward_stage1(images, None, None, cap_ids, cap_mask)
        loss = output["loss"]

        # Clear grads on the whole model, not just the optimizer's params.
        # If forward_stage1 reaches parameters outside the optimizer (e.g. the
        # decoder), their grads would otherwise accumulate across steps.
        model.zero_grad(set_to_none=True)
        loss.backward()
        # Clip only what the optimizer updates, so excluded params cannot
        # inflate the norm and shrink the effective step.
        torch.nn.utils.clip_grad_norm_(params, GRAD_CLIP_NORM)
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else float("nan")


def _save_checkpoint(path, epoch, model, loss):
    """Save the epoch number, model weights and mean loss to ``path``."""
    torch.save(
        {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
        path,
    )


def main():
    """Run the full stage-1 pretraining job."""
    config = _load_config()
    device = _cuda_device()

    tokenizer = BPETokenizer(vocab_size=config["decoder"]["vocab_size"])
    tokenizer.load("checkpoints/tokenizer.json")
    print(f"Tokenizer: {len(tokenizer)} tokens")

    loader = _build_loader(config, tokenizer)

    model = VLJEPAModel(config).to(device)
    for k, v in model.count_parameters().items():
        print(f" {k}: {v:,}")

    optimizer, params = _build_optimizer(model, config)

    model.train()
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    max_epochs = config["train_stage1"]["max_epochs"]
    # Defined up front so the final save/print works even when max_epochs == 0.
    avg = float("nan")
    start = time.perf_counter()
    for epoch in range(1, max_epochs + 1):
        avg = _train_one_epoch(model, loader, optimizer, params, device)
        elapsed = time.perf_counter() - start
        # Peak allocation since process start (not reset per epoch).
        gpu_mem = torch.cuda.max_memory_allocated() / 1e9
        print(
            f"Epoch {epoch}/{max_epochs}: loss={avg:.4f} | {elapsed:.0f}s"
            f" | GPU mem: {gpu_mem:.1f}GB",
            flush=True,
        )
        if epoch % SAVE_EVERY == 0:
            ckpt = f"{CHECKPOINT_DIR}/stage1_epoch{epoch}.pt"
            _save_checkpoint(ckpt, epoch, model, avg)
            print(f" Saved {ckpt}", flush=True)

    # Final save
    _save_checkpoint(f"{CHECKPOINT_DIR}/stage1_final.pt", max_epochs, model, avg)
    total_time = time.perf_counter() - start
    print(f"\nTraining complete. Final loss: {avg:.4f}")
    print(f"Total time: {total_time:.0f}s ({total_time/60:.1f} min)")


if __name__ == "__main__":
    # Required: with num_workers > 0 under the "spawn" start method
    # (Windows/macOS), DataLoader workers re-import this module; without the
    # guard each worker would rerun the whole training script.
    main()