#!/usr/bin/env python3 """ HyperNetwork Meta-Training — SHINE-style two-phase DDP training. Phase 1 (MSE fitting): Train HyperNetwork to reconstruct prototype LoRA adapters from conditioning vectors. Loss = MSE(generated_params, prototype_params) + calibration_loss. Phase 2 (End-to-end): Fine-tune HyperNetwork on task loss: condition -> HyperNetwork -> LoRA params -> inject into decoder -> decode -> task loss (decode_loss + calibration). Usage: # Phase 1: MSE fitting to prototype LoRAs torchrun --nproc_per_node=8 scripts/train_hypernetwork.py \ --config configs/default.yaml \ --hn_config configs/hypernetwork.yaml \ --stage3_ckpt checkpoints/stage3_final.pt \ --prototype_dir checkpoints/prototypes \ --phase 1 # Phase 2: End-to-end on task loss torchrun --nproc_per_node=8 scripts/train_hypernetwork.py \ --config configs/default.yaml \ --hn_config configs/hypernetwork.yaml \ --stage3_ckpt checkpoints/stage3_final.pt \ --resume checkpoints/hypernetwork_phase1_final.pt \ --phase 2 References: - SHINE: Scalable Hypernetwork Internalization (arXiv: 2602.28901) - HypeLoRA: Calibrated LoRA with uncertainty (arXiv: 2603.19278) - HyperVLA: Mother spawns compact Child policies (arXiv: 2510.04898) """ import argparse import math import os import sys import time import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader, Dataset from torch.utils.data.distributed import DistributedSampler import yaml sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from model.hypernetwork import HyperNetwork from model.condition_encoder import ConditionEncoder from model.lora import LoRAConfig, LoRAInjector from model.vlm import VLJEPAModel from model.tokenizer import BPETokenizer # --------------------------------------------------------------------------- # Dataset helpers # 
# ---------------------------------------------------------------------------


class PrototypeDataset(Dataset):
    """
    Dataset of prototype LoRA adapters for Phase 1 MSE training.

    Each sample contains:
      - A conditioning vector (camera_id, scene_descriptor, query_embedding)
      - The corresponding prototype LoRA parameter vector (flat .pt file)

    A prototype file is either a dict with a ``lora_params`` tensor (plus
    optional ``camera_id`` / ``scene_descriptor`` / ``query_embedding``) or a
    bare tensor. Only files whose flattened parameter count equals
    ``lora_param_count`` are kept; every other file is skipped and reported.

    Missing conditioning fields are filled with random values drawn from a
    generator seeded with ``seed + file_index``. Every DDP rank (and every
    run) therefore assigns the same conditioning to the same prototype, which
    the global RNG would not guarantee.

    Raises:
        RuntimeError: if no usable prototype file is found. No dummy data is
            generated.
    """

    def __init__(
        self,
        prototype_dir: str,
        cond_encoder_cfg: dict,
        lora_param_count: int,
        num_dummy: int = 2000,
        seed: int = 0,
    ):
        # ``num_dummy`` is accepted only for backward compatibility; dummy
        # prototypes are no longer generated.
        self.lora_param_count = lora_param_count
        self.samples: list[dict] = []

        scene_input_dim = cond_encoder_cfg.get("scene_input_dim", 2048)
        query_input_dim = cond_encoder_cfg.get("query_input_dim", 2048)
        n_cameras = cond_encoder_cfg.get("n_cameras", 2048)

        skipped: list[str] = []
        if prototype_dir and os.path.isdir(prototype_dir):
            # Sorted so file_idx (and hence the fill seed) is identical on every rank.
            pt_files = sorted(f for f in os.listdir(prototype_dir) if f.endswith(".pt"))
            for file_idx, fname in enumerate(pt_files):
                fpath = os.path.join(prototype_dir, fname)
                gen = torch.Generator().manual_seed(seed + file_idx)
                try:
                    proto = torch.load(fpath, map_location="cpu", weights_only=True)
                    sample = self._make_sample(
                        proto, lora_param_count, n_cameras,
                        scene_input_dim, query_input_dim, gen,
                    )
                except Exception as exc:  # corrupt/incompatible file: skip, but report it
                    skipped.append(f"{fname} ({type(exc).__name__}: {exc})")
                    continue
                if sample is None:
                    skipped.append(
                        f"{fname} (unsupported format or param count != {lora_param_count})"
                    )
                    continue
                self.samples.append(sample)

        if skipped:
            preview = "; ".join(skipped[:5]) + (" ..." if len(skipped) > 5 else "")
            log(f"  [WARN] Skipped {len(skipped)} prototype file(s): {preview}")

        if self.samples:
            return
        raise RuntimeError(
            f"FATAL: No prototype LoRA files found in '{prototype_dir}'.\n"
            "Train prototypes first: python3 scripts/train_prototype_loras.py\n"
            "Required: .pt files with 'lora_params' key and matching param count"
            + (f"\nSkipped {len(skipped)} file(s): {'; '.join(skipped[:5])}" if skipped else "")
        )

    @staticmethod
    def _make_sample(proto, lora_param_count, n_cameras, scene_input_dim, query_input_dim, gen):
        """Build one sample dict from a loaded prototype, or return None if unusable."""
        if isinstance(proto, dict) and "lora_params" in proto:
            lora_params = proto["lora_params"].flatten()
            meta = proto
        elif isinstance(proto, torch.Tensor):
            lora_params = proto.flatten()
            meta = {}
        else:
            return None
        if lora_params.numel() != lora_param_count:
            return None

        # Draw random fills only for fields that are actually missing.
        camera_id = meta.get("camera_id")
        if camera_id is None:
            camera_id = torch.randint(0, n_cameras, (1,), generator=gen).item()
        scene = meta.get("scene_descriptor")
        if scene is None:
            scene = torch.randn(scene_input_dim, generator=gen)
        query = meta.get("query_embedding")
        if query is None:
            query = torch.randn(query_input_dim, generator=gen)

        return {
            "lora_params": lora_params.float(),
            # Normalise to a plain int so every sample collates to shape (B,).
            "camera_id": int(torch.as_tensor(camera_id).item()),
            "scene_descriptor": torch.as_tensor(scene).float(),
            "query_embedding": torch.as_tensor(query).float(),
        }

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]
        return {
            "lora_params": s["lora_params"],
            "camera_id": torch.tensor(s["camera_id"], dtype=torch.long),
            "scene_descriptor": s["scene_descriptor"],
            "query_embedding": s["query_embedding"],
        }


def _require_e2e_dataset():
    """Phase 2 end-to-end training requires real data."""
    raise RuntimeError(
        "FATAL: Phase 2 end-to-end training requires real data.\n"
        "Download real data first: python3 scripts/download_all_data.py --stage 3\n"
        "Required: JSONL files with image_path, question, answer, and conditioning fields"
    )


# ---------------------------------------------------------------------------
# LR scheduler
# ---------------------------------------------------------------------------


class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Linear warmup, then cosine decay from each base LR down to ``min_lr``.

    During warmup the LR is ``base_lr * step / warmup_steps``. After warmup it
    follows a half cosine over ``total_steps - warmup_steps`` steps. Progress
    is clamped at 1, so stepping past ``total_steps`` holds the LR at
    ``min_lr`` instead of letting the cosine climb back up.
    """

    def __init__(self, optimizer, warmup_steps: int, total_steps: int,
                 min_lr: float = 1e-7, last_epoch: int = -1):
        # These must exist before super().__init__(), which calls get_lr().
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_steps:
            scale = step / max(1, self.warmup_steps)
            return [base_lr * scale for base_lr in self.base_lrs]
        decay_steps = max(1, self.total_steps - self.warmup_steps)
        progress = min(1.0, (step - self.warmup_steps) / decay_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return [self.min_lr + (base_lr - self.min_lr) * cosine for base_lr in self.base_lrs]


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def setup_distributed():
    """Bind this process to its GPU and initialise the NCCL process group.

    Reads ``LOCAL_RANK`` as set by ``torchrun``.

    Returns:
        The local rank of this process.
    """
    local_rank = int(os.environ["LOCAL_RANK"])
    # Bind the device before creating the group so NCCL uses the right GPU.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")
    return local_rank


def cleanup():
    """Destroy the process group if one is active (safe to call repeatedly)."""
    if dist.is_initialized():
        dist.destroy_process_group()


def is_rank0():
    """True on global rank 0, or when torch.distributed is not initialised."""
    return not dist.is_initialized() or dist.get_rank() == 0


def log(msg: str):
    """Print ``msg`` (flushed) on rank 0 only."""
    if is_rank0():
        print(msg, flush=True)


def save_checkpoint(state: dict, path: str):
    """Save ``state`` to ``path`` on rank 0 only, atomically.

    The state is written to ``path + ".tmp"`` and then renamed over ``path``,
    so a crash mid-write never leaves a truncated checkpoint. A bare filename
    with no directory component is allowed.
    """
    if not is_rank0():
        return
    parent = os.path.dirname(path)
    if parent:  # os.makedirs("") would raise for a bare filename
        os.makedirs(parent, exist_ok=True)
    tmp_path = f"{path}.tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)
    log(f"  Checkpoint saved: {path}")


def push_to_hf(ckpt_path: str, repo: str):
    """Push a single checkpoint to HuggingFace Hub (rank 0 only, best-effort).

    Requires ``HF_TOKEN`` in the environment. Failures are logged, never
    raised, so a flaky upload cannot kill a training run.
    """
    if not is_rank0():
        return
    try:
        from huggingface_hub import HfApi, create_repo

        token = os.environ.get("HF_TOKEN")
        if not token:
            log("  [WARN] HF_TOKEN not set — skipping HuggingFace push")
            return
        # Keep the relative layout (e.g. checkpoints/x.pt); an absolute local
        # path is not a meaningful repo path, so fall back to the file name.
        path_in_repo = os.path.basename(ckpt_path) if os.path.isabs(ckpt_path) else ckpt_path
        api = HfApi(token=token)
        create_repo(repo, repo_type="model", exist_ok=True, token=token)
        api.upload_file(
            path_or_fileobj=ckpt_path,
            path_in_repo=path_in_repo,
            repo_id=repo,
            repo_type="model",
            token=token,
        )
        log(f"  Pushed to HuggingFace: {repo}/{path_in_repo}")
    except ImportError:
        log("  [WARN] huggingface_hub not installed — skipping push")
    except Exception as e:
        log(f"  [WARN] HuggingFace push failed: {e}")


def _unwrap(module):
    """Return the wrapped module of a DDP wrapper, or ``module`` itself."""
    return getattr(module, "module", module)


def _build_lora_config(hn_config):
    """Build the LoRAConfig from the ``lora`` section of the HyperNetwork config."""
    lora_cfg = hn_config["lora"]
    return LoRAConfig(
        rank=lora_cfg["rank"],
        alpha=lora_cfg["alpha"],
        dropout=lora_cfg["dropout"],
        targets=tuple(lora_cfg["targets"]),
    )


def _build_trainable_modules(config, hn_config, lora_config, device):
    """Build the ConditionEncoder and HyperNetwork on ``device``.

    Returns:
        ``(cond_encoder, hypernetwork, scene_input_dim, query_input_dim)``.
    """
    hn_cfg = hn_config["hypernetwork"]
    ce_cfg = hn_config["condition_encoder"]
    decoder_cfg = config["decoder"]
    scene_input_dim = ce_cfg.get("scene_input_dim", config["predictor"]["embed_dim"])
    query_input_dim = ce_cfg.get("query_input_dim", config["predictor"]["embed_dim"])

    cond_encoder = ConditionEncoder(
        n_cameras=ce_cfg["n_cameras"],
        camera_dim=ce_cfg["camera_dim"],
        scene_input_dim=scene_input_dim,
        scene_dim=ce_cfg["scene_dim"],
        query_input_dim=query_input_dim,
        query_dim=ce_cfg["query_dim"],
        out_dim=hn_cfg["cond_dim"],
    ).to(device)

    hypernetwork = HyperNetwork(
        cond_dim=hn_cfg["cond_dim"],
        hidden_dim=hn_cfg["hidden_dim"],
        lora_config=lora_config,
        num_decoder_blocks=decoder_cfg["num_blocks"],
        decoder_embed_dim=decoder_cfg["hidden_dim"],
    ).to(device)
    return cond_encoder, hypernetwork, scene_input_dim, query_input_dim


def _batch_plan(global_batch, world_size):
    """Split a global batch into ``(per_gpu_batch, grad_accum)``.

    NOTE(review): with this formula ``grad_accum`` always evaluates to 1
    (``per_gpu_batch * world_size`` is never less than half of
    ``global_batch``). If accumulation is intended, cap the per-GPU batch
    separately.
    """
    per_gpu_batch = max(1, global_batch // world_size)
    grad_accum = max(1, global_batch // (per_gpu_batch * world_size))
    return per_gpu_batch, grad_accum


def _log_run_header(title, world_size, global_batch, per_gpu_batch, grad_accum,
                    max_epochs, lr, calibration_weight):
    """Log the run banner and batch/optimiser settings (rank 0 only)."""
    log("=" * 70)
    log(title)
    log("=" * 70)
    log(f" World size: {world_size}")
    log(f" Global batch: {global_batch}")
    log(f" Per-GPU batch: {per_gpu_batch}")
    log(f" Gradient accumulation:{grad_accum}")
    log(f" Effective batch: {per_gpu_batch * world_size * grad_accum}")
    log(f" Max epochs: {max_epochs}")
    log(f" Learning rate: {lr}")
    log(f" Calibration weight: {calibration_weight}")


def _reduce_epoch_metrics(sums, steps, device):
    """All-reduce per-rank loss sums and step counts; return global per-step means.

    Returns NaN for every metric when no rank ran a single step (e.g. an
    empty DataLoader).
    """
    metrics = torch.tensor([*sums, steps], device=device, dtype=torch.float64)
    if dist.is_initialized():
        dist.all_reduce(metrics, op=dist.ReduceOp.SUM)
    total_steps = metrics[-1].item()
    if total_steps == 0:
        return [float("nan")] * len(sums)
    return [(m / total_steps).item() for m in metrics[:-1]]


def _checkpoint_payload(*, epoch, global_step, hypernetwork, cond_encoder, optimizer,
                        scheduler, lora_config, phase, losses):
    """Assemble the checkpoint dict shared by both phases (DDP wrappers unwrapped)."""
    return {
        "epoch": epoch,
        "global_step": global_step,
        "hypernetwork_state_dict": _unwrap(hypernetwork).state_dict(),
        "cond_encoder_state_dict": _unwrap(cond_encoder).state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        **losses,
        "phase": phase,
        "lora_config": {
            "rank": lora_config.rank,
            "alpha": lora_config.alpha,
            "dropout": lora_config.dropout,
            "targets": list(lora_config.targets),
        },
    }


def _save_and_push(state, path, hf_repo):
    """Save a checkpoint (rank 0) and optionally push it to HuggingFace."""
    save_checkpoint(state, path)
    if hf_repo:
        push_to_hf(path, hf_repo)


# ---------------------------------------------------------------------------
# Phase 1: MSE fitting to prototype LoRAs
# ---------------------------------------------------------------------------


def train_phase1(args, config, hn_config):
    """Phase 1: train the HyperNetwork to reconstruct prototype LoRA adapters.

    Per micro-batch the loss is ``MSE(pred, prototype) + calibration_weight *
    MSE(sigma, per-sample MSE)``. Must be launched under ``torchrun``.

    Writes ``hypernetwork_phase1_epoch{N}.pt`` every 10 epochs (and at the
    last epoch), plus ``hypernetwork_phase1_final.pt``, into
    ``args.output_dir``.
    """
    # ---- Distributed setup ----
    local_rank = setup_distributed()
    world_size = dist.get_world_size()
    global_rank = dist.get_rank()
    device = torch.device(f"cuda:{local_rank}")

    # ---- Config ----
    ce_cfg = hn_config["condition_encoder"]
    train_cfg = hn_config["train_hypernetwork"]
    lora_config = _build_lora_config(hn_config)

    global_batch = train_cfg["mse_batch_size"]
    per_gpu_batch, grad_accum = _batch_plan(global_batch, world_size)
    max_epochs = train_cfg["mse_epochs"]
    lr = train_cfg["mse_lr"]
    calibration_weight = train_cfg["calibration_weight"]
    warmup_steps = max_epochs * 2  # Small warmup for MSE phase

    _log_run_header("ArcisVLM — HyperNetwork Phase 1: MSE Fitting (DDP)", world_size,
                    global_batch, per_gpu_batch, grad_accum, max_epochs, lr,
                    calibration_weight)

    # ---- Build models ----
    cond_encoder, hypernetwork, scene_input_dim, query_input_dim = _build_trainable_modules(
        config, hn_config, lora_config, device
    )
    lora_param_count = hypernetwork.num_generated_params
    log(f" LoRA param count: {lora_param_count:,}")
    log(f" HyperNetwork params: {hypernetwork.num_own_params:,}")
    log(f" CondEncoder params: {sum(p.numel() for p in cond_encoder.parameters()):,}")

    # ---- Dataset ----
    dataset = PrototypeDataset(
        prototype_dir=args.prototype_dir,
        cond_encoder_cfg={
            "n_cameras": ce_cfg["n_cameras"],
            "scene_input_dim": scene_input_dim,
            "query_input_dim": query_input_dim,
        },
        lora_param_count=lora_param_count,
    )
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=global_rank, shuffle=True)
    loader = DataLoader(
        dataset,
        batch_size=per_gpu_batch,
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    )
    log(f" Dataset: {len(dataset)} prototypes, {len(loader)} batches/GPU")
    if len(dataset) < 100:
        # The dataset never fabricates data, so this is about few real prototypes.
        log(f" [WARN] Only {len(dataset)} prototypes found — HyperNetwork may overfit; "
            "train more via scripts/train_prototype_loras.py")
    if len(loader) == 0:
        log(" [WARN] DataLoader is empty (fewer prototypes than per-GPU batch x world size "
            "with drop_last=True) — no training steps will run")

    # ---- Optimizer / scheduler ----
    params = list(cond_encoder.parameters()) + list(hypernetwork.parameters())
    optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=0.01)
    steps_per_epoch = max(1, len(loader) // grad_accum)
    total_steps = max_epochs * steps_per_epoch
    scheduler = CosineWarmupScheduler(optimizer, warmup_steps=warmup_steps, total_steps=total_steps)

    # ---- Mixed precision ---- (bf16 autocast needs no GradScaler)
    autocast_dtype = torch.bfloat16

    # ---- Resume ----
    start_epoch = 0
    global_step = 0
    # Initialised here so the final checkpoint is valid even if no epoch runs
    # (e.g. resuming from a checkpoint already at max_epochs).
    avg_mse = avg_cal = avg_total = float("nan")
    if args.resume and os.path.exists(args.resume):
        # weights_only=False unpickles arbitrary objects: only resume from trusted files.
        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
        cond_encoder.load_state_dict(ckpt["cond_encoder_state_dict"])
        hypernetwork.load_state_dict(ckpt["hypernetwork_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        if "scheduler_state_dict" in ckpt:
            scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        start_epoch = ckpt["epoch"]
        global_step = ckpt.get("global_step", start_epoch * steps_per_epoch)
        avg_total = ckpt.get("loss", avg_total)
        avg_mse = ckpt.get("mse_loss", avg_mse)
        avg_cal = ckpt.get("cal_loss", avg_cal)
        log(f" Resumed from {args.resume} (epoch {start_epoch}, loss {avg_total:.4f})")

    # ---- DDP wrap ----
    cond_encoder = DDP(cond_encoder, device_ids=[local_rank], output_device=local_rank,
                       find_unused_parameters=False)
    hypernetwork = DDP(hypernetwork, device_ids=[local_rank], output_device=local_rank,
                       find_unused_parameters=False)

    # ---- Training loop ----
    cond_encoder.train()
    hypernetwork.train()
    os.makedirs(args.output_dir, exist_ok=True)

    for epoch in range(start_epoch, max_epochs):
        sampler.set_epoch(epoch)
        epoch_mse = epoch_cal = epoch_total = 0.0
        epoch_steps = 0
        epoch_start = time.time()
        optimizer.zero_grad(set_to_none=True)

        for batch_idx, batch in enumerate(loader):
            camera_id = batch["camera_id"].to(device, non_blocking=True)
            scene_desc = batch["scene_descriptor"].to(device, non_blocking=True)
            query_emb = batch["query_embedding"].to(device, non_blocking=True)
            target_params = batch["lora_params"].to(device, non_blocking=True)

            with torch.amp.autocast("cuda", dtype=autocast_dtype):
                condition = cond_encoder(camera_id, scene_desc, query_emb)
                pred_params, sigma = hypernetwork(condition)

                # Reconstruction: match the prototype LoRA params.
                mse_loss = F.mse_loss(pred_params, target_params)
                # Calibration (HypeLoRA): sigma should predict each sample's own MSE.
                per_sample_mse = (pred_params - target_params).pow(2).mean(dim=-1, keepdim=True)
                cal_loss = F.mse_loss(sigma, per_sample_mse.detach())
                total_loss = mse_loss + calibration_weight * cal_loss

            (total_loss / grad_accum).backward()

            epoch_mse += mse_loss.item()
            epoch_cal += cal_loss.item()
            epoch_total += total_loss.item()
            epoch_steps += 1

            # Optimizer step every grad_accum batches.
            if (batch_idx + 1) % grad_accum == 0:
                torch.nn.utils.clip_grad_norm_(params, 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                global_step += 1

                if global_step % 50 == 0:
                    current_lr = scheduler.get_last_lr()[0]
                    gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
                    confidence = _unwrap(hypernetwork).compute_confidence(sigma).mean().item()
                    log(f" [Step {global_step}] mse={mse_loss.item():.6f} "
                        f"cal={cal_loss.item():.6f} "
                        f"conf={confidence:.3f} "
                        f"lr={current_lr:.2e} GPU={gpu_mem:.1f}GB")

        # ---- Epoch summary ----
        avg_mse, avg_cal, avg_total = _reduce_epoch_metrics(
            [epoch_mse, epoch_cal, epoch_total], epoch_steps, device
        )
        epoch_time = time.time() - epoch_start
        current_lr = scheduler.get_last_lr()[0]
        gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
        log(f"\nEpoch {epoch + 1}/{max_epochs}: mse={avg_mse:.6f} cal={avg_cal:.6f} "
            f"total={avg_total:.6f} lr={current_lr:.2e} time={epoch_time:.0f}s GPU={gpu_mem:.1f}GB")

        # ---- Checkpoint every 10 epochs (and at the last epoch) ----
        if (epoch + 1) % 10 == 0 or (epoch + 1) == max_epochs:
            ckpt_path = os.path.join(args.output_dir, f"hypernetwork_phase1_epoch{epoch + 1}.pt")
            _save_and_push(
                _checkpoint_payload(
                    epoch=epoch + 1, global_step=global_step,
                    hypernetwork=hypernetwork, cond_encoder=cond_encoder,
                    optimizer=optimizer, scheduler=scheduler,
                    lora_config=lora_config, phase=1,
                    losses={"loss": avg_total, "mse_loss": avg_mse, "cal_loss": avg_cal},
                ),
                ckpt_path, args.hf_push,
            )
            # Keep ranks in step while rank 0 does checkpoint I/O.
            dist.barrier()

    # ---- Final checkpoint ----
    final_path = os.path.join(args.output_dir, "hypernetwork_phase1_final.pt")
    _save_and_push(
        _checkpoint_payload(
            epoch=max_epochs, global_step=global_step,
            hypernetwork=hypernetwork, cond_encoder=cond_encoder,
            optimizer=optimizer, scheduler=scheduler,
            lora_config=lora_config, phase=1,
            losses={"loss": avg_total, "mse_loss": avg_mse, "cal_loss": avg_cal},
        ),
        final_path, args.hf_push,
    )

    log("\n" + "=" * 70)
    log(f"Phase 1 complete. Final MSE: {avg_mse:.6f} Cal: {avg_cal:.6f}")
    log("=" * 70)
    cleanup()


# ---------------------------------------------------------------------------
# Phase 2: End-to-end task loss
# ---------------------------------------------------------------------------


def train_phase2(args, config, hn_config):
    """Phase 2: fine-tune HyperNetwork + ConditionEncoder end-to-end on decode loss.

    Pipeline per sample: condition -> HyperNetwork -> LoRA params -> inject
    into the frozen decoder -> decode loss. The loss back-propagates through
    the injected LoRA into the HyperNetwork. A calibration term trains sigma
    to predict each sample's decode loss.

    Raises:
        RuntimeError: always for now. No end-to-end dataset builder exists yet
            (see ``_require_e2e_dataset``); the check runs first so the run
            fails before DDP setup and before the frozen VLM is loaded.
    """
    # Fail fast: don't spin up DDP or load the Stage 3 VLM just to crash later.
    _require_e2e_dataset()

    # ---- Distributed setup ----
    local_rank = setup_distributed()
    world_size = dist.get_world_size()
    global_rank = dist.get_rank()
    device = torch.device(f"cuda:{local_rank}")

    # ---- Config ----
    train_cfg = hn_config["train_hypernetwork"]
    lora_config = _build_lora_config(hn_config)
    decoder_cfg = config["decoder"]
    num_decoder_blocks = decoder_cfg["num_blocks"]
    decoder_embed_dim = decoder_cfg["hidden_dim"]

    global_batch = train_cfg["e2e_batch_size"]
    per_gpu_batch, grad_accum = _batch_plan(global_batch, world_size)
    max_epochs = train_cfg["e2e_epochs"]
    lr = train_cfg["e2e_lr"]
    calibration_weight = train_cfg["calibration_weight"]
    warmup_steps = max_epochs  # Minimal warmup for fine-tuning phase

    _log_run_header("ArcisVLM — HyperNetwork Phase 2: End-to-End (DDP)", world_size,
                    global_batch, per_gpu_batch, grad_accum, max_epochs, lr,
                    calibration_weight)

    # ---- Build VLM backbone (frozen) ----
    vlm = VLJEPAModel(config).to(device)
    if os.path.exists(args.stage3_ckpt):
        # weights_only=False unpickles arbitrary objects: only load trusted files.
        ckpt = torch.load(args.stage3_ckpt, map_location=device, weights_only=False)
        vlm.load_state_dict(ckpt["model_state_dict"])
        log(f" Loaded Stage 3 ckpt: {args.stage3_ckpt}")
    else:
        log(f" [WARN] Stage 3 checkpoint not found: {args.stage3_ckpt} — using random weights")
    # Only HyperNetwork + ConditionEncoder are trainable.
    for param in vlm.parameters():
        param.requires_grad = False
    vlm.eval()

    # ---- Build HyperNetwork + ConditionEncoder ----
    cond_encoder, hypernetwork, _, _ = _build_trainable_modules(
        config, hn_config, lora_config, device
    )
    log(f" LoRA param count: {hypernetwork.num_generated_params:,}")
    log(f" HyperNetwork params: {hypernetwork.num_own_params:,}")

    # ---- Load initial weights (Phase 1 result, or a Phase 2 checkpoint to resume) ----
    init_path = args.resume or os.path.join(args.output_dir, "hypernetwork_phase1_final.pt")
    init_ckpt = None
    if os.path.exists(init_path):
        init_ckpt = torch.load(init_path, map_location=device, weights_only=False)
        hypernetwork.load_state_dict(init_ckpt["hypernetwork_state_dict"])
        cond_encoder.load_state_dict(init_ckpt["cond_encoder_state_dict"])
        log(f" Loaded Phase 1 ckpt: {init_path} (loss {init_ckpt.get('loss', -1):.4f})")
    else:
        log(f" [WARN] Phase 1 checkpoint not found: {init_path} — training from scratch")

    # ---- Tokenizer ----
    tokenizer = BPETokenizer(vocab_size=config["decoder"]["vocab_size"])
    for tok_path in ["checkpoints/tokenizer_32k.json", "checkpoints/tokenizer.json"]:
        if os.path.exists(tok_path):
            tokenizer.load(tok_path)
            log(f" Tokenizer: {len(tokenizer)} tokens (from {tok_path})")
            break
    else:
        log(" [WARN] No tokenizer found — using untrained tokenizer")

    # ---- Dataset ----
    # TODO: build the real E2E dataset here. Its items must provide image,
    # question_ids, question_mask, answer_ids, camera_id, scene_descriptor and
    # query_embedding (the keys read below). Unreachable until then, because
    # _require_e2e_dataset() at the top of this function always raises.
    dataset = None
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=global_rank, shuffle=True)
    loader = DataLoader(
        dataset,
        batch_size=per_gpu_batch,
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    )
    log(f" Dataset: {len(dataset)} samples, {len(loader)} batches/GPU")

    # ---- Optimizer (only HyperNetwork + ConditionEncoder) / scheduler ----
    trainable_params = list(cond_encoder.parameters()) + list(hypernetwork.parameters())
    optimizer = torch.optim.AdamW(trainable_params, lr=lr, weight_decay=0.01)
    steps_per_epoch = max(1, len(loader) // grad_accum)
    total_steps = max_epochs * steps_per_epoch
    scheduler = CosineWarmupScheduler(optimizer, warmup_steps=warmup_steps, total_steps=total_steps)

    # ---- Mixed precision ---- (bf16 autocast needs no GradScaler)
    autocast_dtype = torch.bfloat16

    # ---- Resume Phase 2 optimizer/scheduler state ----
    start_epoch = 0
    global_step = 0
    avg_decode = avg_cal = avg_total = float("nan")
    if args.resume and init_ckpt is not None and init_ckpt.get("phase") == 2:
        optimizer.load_state_dict(init_ckpt["optimizer_state_dict"])
        if "scheduler_state_dict" in init_ckpt:
            scheduler.load_state_dict(init_ckpt["scheduler_state_dict"])
        start_epoch = init_ckpt["epoch"]
        global_step = init_ckpt.get("global_step", start_epoch * steps_per_epoch)
        avg_total = init_ckpt.get("loss", avg_total)
        avg_decode = init_ckpt.get("decode_loss", avg_decode)
        avg_cal = init_ckpt.get("cal_loss", avg_cal)
        log(f" Resumed Phase 2 from {args.resume} (epoch {start_epoch}, loss {avg_total:.4f})")

    # ---- DDP wrap (HyperNetwork + ConditionEncoder only) ----
    cond_encoder = DDP(cond_encoder, device_ids=[local_rank], output_device=local_rank,
                       find_unused_parameters=False)
    hypernetwork = DDP(hypernetwork, device_ids=[local_rank], output_device=local_rank,
                       find_unused_parameters=False)

    # ---- LoRA injector ----
    injector = LoRAInjector(lora_config, num_decoder_blocks, decoder_embed_dim)

    # ---- Training loop ----
    cond_encoder.train()
    hypernetwork.train()
    os.makedirs(args.output_dir, exist_ok=True)

    for epoch in range(start_epoch, max_epochs):
        sampler.set_epoch(epoch)
        epoch_decode = epoch_cal = epoch_total = 0.0
        epoch_steps = 0
        epoch_start = time.time()
        optimizer.zero_grad(set_to_none=True)

        for batch_idx, batch in enumerate(loader):
            images = batch["image"].to(device, non_blocking=True)
            q_ids = batch["question_ids"].to(device, non_blocking=True)
            q_mask = batch["question_mask"].to(device, non_blocking=True)
            a_ids = batch["answer_ids"].to(device, non_blocking=True)
            camera_id = batch["camera_id"].to(device, non_blocking=True)
            scene_desc = batch["scene_descriptor"].to(device, non_blocking=True)
            query_emb = batch["query_embedding"].to(device, non_blocking=True)

            with torch.amp.autocast("cuda", dtype=autocast_dtype):
                # Steps 1-2: conditioning vector -> per-sample LoRA params + sigma.
                condition = cond_encoder(camera_id, scene_desc, query_emb)
                lora_params, sigma = hypernetwork(condition)

                # Step 3: the frozen encoder stages don't see the LoRA, so run
                # them once per batch with no autograd graph.
                # NOTE(review): assumes x_encoder/predictor give the same
                # result batched as per-sample (vlm is in eval mode); confirm.
                with torch.no_grad():
                    visual_tokens = vlm.x_encoder(images)
                    pred_embeds = vlm.predictor(visual_tokens, q_ids, q_mask)

                # Step 4: inject each sample's LoRA and decode. Grad stays
                # ENABLED here so the decode loss reaches the HyperNetwork
                # through the injected LoRA weights.
                sample_losses = []
                for i in range(images.shape[0]):
                    lora_layers = injector.create_lora_layers(lora_params[i], device=device)
                    vlm.decoder.apply_lora(lora_layers)
                    try:
                        _, decode_loss = vlm.decoder(pred_embeds[i:i + 1], a_ids[i:i + 1])
                    finally:
                        # Never leave a sample's LoRA attached to the shared decoder.
                        vlm.decoder.clear_lora()
                    sample_losses.append(decode_loss.mean())
                per_sample_decode = torch.stack(sample_losses)  # (B,)
                batch_decode_loss = per_sample_decode.mean()

                # Calibration (HypeLoRA): sigma predicts each sample's own
                # decode loss, mirroring Phase 1's per-sample target.
                target_sigma = per_sample_decode.detach().unsqueeze(-1)  # (B, 1)
                cal_loss = F.mse_loss(sigma, target_sigma)
                total_loss = batch_decode_loss + calibration_weight * cal_loss

            (total_loss / grad_accum).backward()

            epoch_decode += batch_decode_loss.item()
            epoch_cal += cal_loss.item()
            epoch_total += total_loss.item()
            epoch_steps += 1

            # Optimizer step every grad_accum batches.
            if (batch_idx + 1) % grad_accum == 0:
                torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                global_step += 1

                if global_step % 50 == 0:
                    current_lr = scheduler.get_last_lr()[0]
                    gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
                    confidence = _unwrap(hypernetwork).compute_confidence(sigma).mean().item()
                    log(f" [Step {global_step}] decode={batch_decode_loss.item():.4f} "
                        f"cal={cal_loss.item():.6f} "
                        f"conf={confidence:.3f} "
                        f"lr={current_lr:.2e} GPU={gpu_mem:.1f}GB")

        # ---- Epoch summary ----
        avg_decode, avg_cal, avg_total = _reduce_epoch_metrics(
            [epoch_decode, epoch_cal, epoch_total], epoch_steps, device
        )
        epoch_time = time.time() - epoch_start
        current_lr = scheduler.get_last_lr()[0]
        gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
        log(f"\nEpoch {epoch + 1}/{max_epochs}: decode={avg_decode:.4f} cal={avg_cal:.6f} "
            f"total={avg_total:.4f} lr={current_lr:.2e} time={epoch_time:.0f}s GPU={gpu_mem:.1f}GB")

        # ---- Checkpoint every epoch ----
        ckpt_path = os.path.join(args.output_dir, f"hypernetwork_phase2_epoch{epoch + 1}.pt")
        _save_and_push(
            _checkpoint_payload(
                epoch=epoch + 1, global_step=global_step,
                hypernetwork=hypernetwork, cond_encoder=cond_encoder,
                optimizer=optimizer, scheduler=scheduler,
                lora_config=lora_config, phase=2,
                losses={"loss": avg_total, "decode_loss": avg_decode, "cal_loss": avg_cal},
            ),
            ckpt_path, args.hf_push,
        )
        dist.barrier()

    # ---- Final checkpoint ----
    final_path = os.path.join(args.output_dir, "hypernetwork_phase2_final.pt")
    _save_and_push(
        _checkpoint_payload(
            epoch=max_epochs, global_step=global_step,
            hypernetwork=hypernetwork, cond_encoder=cond_encoder,
            optimizer=optimizer, scheduler=scheduler,
            lora_config=lora_config, phase=2,
            losses={"loss": avg_total, "decode_loss": avg_decode, "cal_loss": avg_cal},
        ),
        final_path, args.hf_push,
    )

    log("\n" + "=" * 70)
    log(f"Phase 2 complete. Final decode: {avg_decode:.4f} cal: {avg_cal:.6f}")
    log("=" * 70)
    cleanup()


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def main():
    """Parse CLI args, load both YAML configs and run the selected phase."""
    parser = argparse.ArgumentParser(description="HyperNetwork DDP: SHINE-style Meta-Training")
    parser.add_argument("--config", type=str, required=True,
                        help="Path to main YAML config (e.g., configs/default.yaml)")
    parser.add_argument("--stage3_ckpt", type=str, default="checkpoints/stage3_final.pt",
                        help="Path to Stage 3 checkpoint (frozen VLM backbone for Phase 2)")
    parser.add_argument("--hn_config", type=str, default="configs/hypernetwork.yaml",
                        help="Path to HyperNetwork YAML config")
    parser.add_argument("--prototype_dir", type=str, default="checkpoints/prototypes",
                        help="Directory with prototype LoRA .pt files (Phase 1)")
    parser.add_argument("--resume", type=str, default=None,
                        help="Path to checkpoint to resume from")
    parser.add_argument("--output_dir", type=str, default="checkpoints",
                        help="Output directory for checkpoints")
    parser.add_argument("--hf_push", type=str, default=None,
                        help="HuggingFace repo to push checkpoints (e.g., hardiksa/arcisvlm)")
    parser.add_argument("--phase", type=int, required=True, choices=[1, 2],
                        help="Training phase: 1=MSE fitting, 2=end-to-end")
    args = parser.parse_args()

    # ---- Load configs ----
    with open(args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    with open(args.hn_config, encoding="utf-8") as f:
        hn_config = yaml.safe_load(f)

    try:
        if args.phase == 1:
            train_phase1(args, config, hn_config)
        else:
            train_phase2(args, config, hn_config)
    finally:
        # Tear down the process group even if training raised (no-op if done).
        cleanup()


if __name__ == "__main__":
    main()