arcisvlm / scripts /train_hypernetwork.py
Hardik Sanghvi
feat: integrate Gemma 4 E2B backbone for production-quality VLM inference
7a564e3
Raw
History Blame Contribute Delete
35.8 kB
#!/usr/bin/env python3
"""
HyperNetwork Meta-Training — SHINE-style two-phase DDP training.
Phase 1 (MSE fitting):
Train HyperNetwork to reconstruct prototype LoRA adapters from conditioning
vectors. Loss = MSE(generated_params, prototype_params) + calibration_loss.
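Phase 2 (End-to-end):
Fine-tune HyperNetwork on task loss: condition -> HyperNetwork -> LoRA params
-> inject into decoder -> decode -> task loss (decode_loss + calibration).
NOTE: Phase 2 is not runnable yet. No end-to-end dataset is wired in, so the
script raises a RuntimeError asking for real data.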
Usage:
# Phase 1: MSE fitting to prototype LoRAs
torchrun --nproc_per_node=8 scripts/train_hypernetwork.py \
--config configs/default.yaml \
--hn_config configs/hypernetwork.yaml \
--stage3_ckpt checkpoints/stage3_final.pt \
--prototype_dir checkpoints/prototypes \
--phase 1
# Phase 2: End-to-end on task loss
torchrun --nproc_per_node=8 scripts/train_hypernetwork.py \
--config configs/default.yaml \
--hn_config configs/hypernetwork.yaml \
--stage3_ckpt checkpoints/stage3_final.pt \
--resume checkpoints/hypernetwork_phase1_final.pt \
--phase 2
References:
- SHINE: Scalable Hypernetwork Internalization (arXiv: 2602.28901)
- HypeLoRA: Calibrated LoRA with uncertainty (arXiv: 2603.19278)
- HyperVLA: Mother spawns compact Child policies (arXiv: 2510.04898)
"""
import argparse
import math
import os
import sys
import time
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
import yaml
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from model.hypernetwork import HyperNetwork
from model.condition_encoder import ConditionEncoder
from model.lora import LoRAConfig, LoRAInjector
from model.vlm import VLJEPAModel
from model.tokenizer import BPETokenizer
# ---------------------------------------------------------------------------
# Dataset helpers
# ---------------------------------------------------------------------------
class PrototypeDataset(Dataset):
    """
    Dataset of prototype LoRA adapters for Phase 1 MSE training.

    Each sample contains:
      - a conditioning triple (camera_id, scene_descriptor, query_embedding)
      - the corresponding prototype LoRA parameter vector (flattened, float32)

    Prototypes are read from the ``*.pt`` files in ``prototype_dir`` in sorted
    filename order. A file may hold either:
      - a dict with a ``lora_params`` tensor, optionally with ``camera_id`` /
        ``scene_descriptor`` / ``query_embedding``, or
      - a bare parameter tensor.

    Files that fail to load, or whose flattened size differs from
    ``lora_param_count``, are skipped. They are reported in a warning, or in
    the error if nothing usable remains.

    Conditioning fields missing from a file are filled with pseudo-random
    values from a generator seeded with ``seed``. Every process that builds
    the dataset (e.g. every DDP rank) therefore sees the same sample at the
    same index.

    Args:
        prototype_dir: Directory containing prototype ``.pt`` files.
        cond_encoder_cfg: Dict with optional ``scene_input_dim``,
            ``query_input_dim`` and ``n_cameras``, used to size filler values.
        lora_param_count: Required flattened length of every prototype.
        num_dummy: Unused; kept for backward compatibility. No dummy data is
            generated.
        seed: Seed for the filler-conditioning generator.

    Raises:
        RuntimeError: If no usable prototype file is found.
    """

    def __init__(
        self,
        prototype_dir: str,
        cond_encoder_cfg: dict,
        lora_param_count: int,
        num_dummy: int = 2000,
        seed: int = 0,
    ):
        import warnings

        self.lora_param_count = lora_param_count
        self.samples: list[dict] = []
        scene_input_dim = cond_encoder_cfg.get("scene_input_dim", 2048)
        query_input_dim = cond_encoder_cfg.get("query_input_dim", 2048)
        # NOTE(review): the 2048 camera-count default mirrors the feature-dim
        # defaults above. Confirm it matches the ConditionEncoder's n_cameras.
        n_cameras = cond_encoder_cfg.get("n_cameras", 2048)
        # Dedicated, seeded RNG for filler values. Each DDP rank builds its own
        # copy of this dataset, and DistributedSampler assumes index i is the
        # same sample on every rank.
        gen = torch.Generator().manual_seed(seed)
        rejected: list[str] = []

        pt_files: list[str] = []
        if prototype_dir and os.path.isdir(prototype_dir):
            pt_files = sorted(
                f for f in os.listdir(prototype_dir) if f.endswith(".pt")
            )

        for fname in pt_files:
            fpath = os.path.join(prototype_dir, fname)
            try:
                proto = torch.load(fpath, map_location="cpu", weights_only=True)
                if isinstance(proto, dict) and "lora_params" in proto:
                    flat, meta = proto["lora_params"].flatten(), proto
                elif isinstance(proto, torch.Tensor):
                    flat, meta = proto.flatten(), {}
                else:
                    rejected.append(f"{fname}: no 'lora_params' tensor")
                    continue
            except Exception as exc:  # unreadable / incompatible file: skip it
                rejected.append(f"{fname}: {type(exc).__name__}: {exc}")
                continue

            if flat.numel() != lora_param_count:
                rejected.append(
                    f"{fname}: {flat.numel()} params, expected {lora_param_count}"
                )
                continue

            # Only draw filler values for fields the file does not provide.
            if "camera_id" in meta:
                camera_id = meta["camera_id"]
            else:
                camera_id = torch.randint(0, n_cameras, (1,), generator=gen).item()
            if "scene_descriptor" in meta:
                scene_descriptor = meta["scene_descriptor"]
            else:
                scene_descriptor = torch.randn(scene_input_dim, generator=gen)
            if "query_embedding" in meta:
                query_embedding = meta["query_embedding"]
            else:
                query_embedding = torch.randn(query_input_dim, generator=gen)

            self.samples.append({
                "lora_params": flat.float(),
                "camera_id": camera_id,
                "scene_descriptor": scene_descriptor,
                "query_embedding": query_embedding,
            })

        if not self.samples:
            detail = ""
            if rejected:
                detail = f"\nRejected {len(rejected)} .pt file(s), e.g. {rejected[0]}"
            raise RuntimeError(
                f"FATAL: No prototype LoRA files found in '{prototype_dir}'.\n"
                "Train prototypes first: python3 scripts/train_prototype_loras.py\n"
                "Required: .pt files with 'lora_params' key and matching param count"
                + detail
            )
        if rejected:
            warnings.warn(
                f"PrototypeDataset skipped {len(rejected)} file(s) in "
                f"'{prototype_dir}': " + "; ".join(rejected[:5]),
                stacklevel=2,
            )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]
        return {
            "lora_params": s["lora_params"],
            # as_tensor accepts both Python ints (filler values) and tensors
            # stored in prototype files, without torch.tensor's copy warning.
            "camera_id": torch.as_tensor(s["camera_id"], dtype=torch.long),
            "scene_descriptor": s["scene_descriptor"].float(),
            "query_embedding": s["query_embedding"].float(),
        }
def _require_e2e_dataset():
    """Abort Phase 2 because no real end-to-end dataset is wired into this script.

    Raises:
        RuntimeError: Always, with instructions for obtaining the data.
    """
    message_lines = (
        "FATAL: Phase 2 end-to-end training requires real data.",
        "Download real data first: python3 scripts/download_all_data.py --stage 3",
        "Required: JSONL files with image_path, question, answer, and conditioning fields",
    )
    raise RuntimeError("\n".join(message_lines))
# ---------------------------------------------------------------------------
# LR scheduler
# ---------------------------------------------------------------------------
class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Linear warmup, then cosine decay to ``min_lr``, then hold at ``min_lr``.

    With ``t = self.last_epoch`` (advanced once per ``scheduler.step()``):
      * ``t < warmup_steps``: ``lr = base_lr * t / warmup_steps``
      * otherwise: cosine-anneal from ``base_lr`` to ``min_lr`` over
        ``total_steps - warmup_steps`` steps, then stay at ``min_lr``.

    Args:
        optimizer: Wrapped optimizer.
        warmup_steps: Number of linear-warmup steps.
        total_steps: Step at which the cosine reaches ``min_lr``.
        min_lr: Floor learning rate reached at the end of the decay.
        last_epoch: Index of the last step (``-1`` to start fresh).
    """

    def __init__(self, optimizer, warmup_steps: int, total_steps: int,
                 min_lr: float = 1e-7, last_epoch: int = -1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_steps:
            scale = step / max(1, self.warmup_steps)
            return [base_lr * scale for base_lr in self.base_lrs]
        decay_steps = max(1, self.total_steps - self.warmup_steps)
        # Clamp progress to 1. Stepping past total_steps (e.g. resuming with
        # more epochs, where load_state_dict restores the old total_steps) must
        # hold at min_lr instead of letting the cosine climb back to base_lr.
        progress = min(1.0, (step - self.warmup_steps) / decay_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return [self.min_lr + (base_lr - self.min_lr) * cosine for base_lr in self.base_lrs]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def setup_distributed():
    """Initialise the NCCL process group and bind this process to its GPU.

    Reads ``LOCAL_RANK`` from the environment (set by ``torchrun``) after the
    process group is created, and uses it as the CUDA device index.

    Returns:
        int: The local rank of this process.
    """
    dist.init_process_group(backend="nccl")
    gpu_index = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(gpu_index)
    return gpu_index
def cleanup():
    """Destroy the default process group, doing nothing if none is initialised."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
def is_rank0():
    """Return True on global rank 0, or always when no process group is initialised."""
    if dist.is_initialized():
        return dist.get_rank() == 0
    return True
def log(msg: str):
    """Print ``msg`` with an immediate flush, from the rank-0 process only."""
    if not is_rank0():
        return
    print(msg, flush=True)
def save_checkpoint(state: dict, path: str):
    """Serialise ``state`` to ``path`` with ``torch.save`` (rank 0 only).

    Non-zero ranks return immediately without writing anything.

    The data is first written to ``<path>.tmp`` and then moved into place with
    ``os.replace``. A crash or interrupt mid-save therefore never leaves a
    truncated file at ``path``, and an earlier checkpoint at that path stays
    intact.

    Args:
        state: Picklable checkpoint dictionary.
        path: Destination file; its parent directory is created if needed.
    """
    if not is_rank0():
        return
    parent = os.path.dirname(path)
    # os.makedirs("") raises, so only create a directory when path has one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    tmp_path = f"{path}.tmp"
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    log(f" Checkpoint saved: {path}")
def push_to_hf(ckpt_path: str, repo: str):
    """Push a single checkpoint to HuggingFace Hub (rank 0 only).

    This is best-effort by design. A missing ``HF_TOKEN``, a missing
    ``huggingface_hub`` install, or any upload error is logged as a warning
    and never raised, so a failed push cannot abort training. The repo is
    created on first use (``exist_ok=True``).

    Args:
        ckpt_path: Local checkpoint file. Relative paths keep their layout
            inside the repo (normalised, ``/``-separated). Absolute paths and
            paths that climb out via ``..`` are stored under their basename,
            so the local directory layout is not copied into the repo path.
        repo: Hub repo id, e.g. ``"user/name"``.
    """
    if not is_rank0():
        return
    path_in_repo = os.path.normpath(ckpt_path)
    if os.path.isabs(path_in_repo) or path_in_repo.split(os.sep)[0] == os.pardir:
        path_in_repo = os.path.basename(path_in_repo)
    path_in_repo = path_in_repo.replace(os.sep, "/")
    try:
        from huggingface_hub import HfApi, create_repo

        token = os.environ.get("HF_TOKEN")
        if not token:
            log(" [WARN] HF_TOKEN not set — skipping HuggingFace push")
            return
        api = HfApi(token=token)
        create_repo(repo, repo_type="model", exist_ok=True, token=token)
        api.upload_file(
            path_or_fileobj=ckpt_path,
            path_in_repo=path_in_repo,
            repo_id=repo,
            repo_type="model",
            token=token,
        )
        log(f" Pushed to HuggingFace: {repo}/{path_in_repo}")
    except ImportError:
        log(" [WARN] huggingface_hub not installed — skipping push")
    except Exception as e:
        log(f" [WARN] HuggingFace push failed: {e}")
# ---------------------------------------------------------------------------
# Phase 1: MSE fitting to prototype LoRAs
# ---------------------------------------------------------------------------
def train_phase1(args, config, hn_config):
    """Phase 1: train the HyperNetwork to reconstruct prototype LoRA adapters.

    Runs one process per GPU under DDP (launched via ``torchrun``). The
    ConditionEncoder and HyperNetwork are optimised jointly on

        MSE(pred_params, prototype_params) + calibration_weight * cal_loss

    where ``cal_loss`` fits the HyperNetwork's ``sigma`` output to the
    detached per-sample reconstruction MSE.

    Checkpoints are written to ``args.output_dir`` every 10 epochs, at the
    last epoch, and as ``hypernetwork_phase1_final.pt``. Each is optionally
    pushed to the HuggingFace repo named by ``args.hf_push``.

    Args:
        args: Parsed CLI namespace; ``prototype_dir``, ``resume``,
            ``output_dir`` and ``hf_push`` are read.
        config: Main model config; the ``decoder`` and ``predictor`` sections
            are read.
        hn_config: HyperNetwork config; the ``lora``, ``hypernetwork``,
            ``condition_encoder`` and ``train_hypernetwork`` sections are read.

    Raises:
        RuntimeError: If no usable prototypes are found, or the dataset is too
            small to yield a single full per-GPU batch.
    """
    # ---- Distributed setup ----
    local_rank = setup_distributed()
    world_size = dist.get_world_size()
    global_rank = dist.get_rank()
    device = torch.device(f"cuda:{local_rank}")

    # ---- Config ----
    lora_cfg = hn_config["lora"]
    hn_cfg = hn_config["hypernetwork"]
    ce_cfg = hn_config["condition_encoder"]
    train_cfg = hn_config["train_hypernetwork"]
    lora_config = LoRAConfig(
        rank=lora_cfg["rank"],
        alpha=lora_cfg["alpha"],
        dropout=lora_cfg["dropout"],
        targets=tuple(lora_cfg["targets"]),
    )
    decoder_cfg = config["decoder"]
    num_decoder_blocks = decoder_cfg["num_blocks"]
    decoder_embed_dim = decoder_cfg["hidden_dim"]
    per_gpu_batch = max(1, train_cfg["mse_batch_size"] // world_size)
    grad_accum = max(1, train_cfg["mse_batch_size"] // (per_gpu_batch * world_size))
    max_epochs = train_cfg["mse_epochs"]
    lr = train_cfg["mse_lr"]
    calibration_weight = train_cfg["calibration_weight"]
    warmup_steps = max_epochs * 2  # Small warmup for MSE phase

    log("=" * 70)
    log("ArcisVLM — HyperNetwork Phase 1: MSE Fitting (DDP)")
    log("=" * 70)
    log(f" World size: {world_size}")
    log(f" Global batch: {train_cfg['mse_batch_size']}")
    log(f" Per-GPU batch: {per_gpu_batch}")
    log(f" Gradient accumulation:{grad_accum}")
    log(f" Effective batch: {per_gpu_batch * world_size * grad_accum}")
    log(f" Max epochs: {max_epochs}")
    log(f" Learning rate: {lr}")
    log(f" Calibration weight: {calibration_weight}")

    # ---- Build models ----
    scene_input_dim = ce_cfg.get("scene_input_dim", config["predictor"]["embed_dim"])
    query_input_dim = ce_cfg.get("query_input_dim", config["predictor"]["embed_dim"])
    cond_encoder = ConditionEncoder(
        n_cameras=ce_cfg["n_cameras"],
        camera_dim=ce_cfg["camera_dim"],
        scene_input_dim=scene_input_dim,
        scene_dim=ce_cfg["scene_dim"],
        query_input_dim=query_input_dim,
        query_dim=ce_cfg["query_dim"],
        out_dim=hn_cfg["cond_dim"],
    ).to(device)
    hypernetwork = HyperNetwork(
        cond_dim=hn_cfg["cond_dim"],
        hidden_dim=hn_cfg["hidden_dim"],
        lora_config=lora_config,
        num_decoder_blocks=num_decoder_blocks,
        decoder_embed_dim=decoder_embed_dim,
    ).to(device)
    lora_param_count = hypernetwork.num_generated_params
    log(f" LoRA param count: {lora_param_count:,}")
    log(f" HyperNetwork params: {hypernetwork.num_own_params:,}")
    log(f" CondEncoder params: {sum(p.numel() for p in cond_encoder.parameters()):,}")

    # ---- Dataset ----
    cond_enc_data_cfg = {
        "n_cameras": ce_cfg["n_cameras"],
        "scene_input_dim": scene_input_dim,
        "query_input_dim": query_input_dim,
    }
    dataset = PrototypeDataset(
        prototype_dir=args.prototype_dir,
        cond_encoder_cfg=cond_enc_data_cfg,
        lora_param_count=lora_param_count,
    )
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=global_rank, shuffle=True)
    loader = DataLoader(
        dataset,
        batch_size=per_gpu_batch,
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    )
    log(f" Dataset: {len(dataset)} prototypes, {len(loader)} batches/GPU")
    # With drop_last=True a small dataset yields zero batches; training would
    # then do nothing and the epoch averages below would be 0/0 = NaN.
    if len(loader) == 0:
        raise RuntimeError(
            f"FATAL: {len(dataset)} prototypes over {world_size} GPUs cannot fill one "
            f"per-GPU batch of {per_gpu_batch} (drop_last=True).\n"
            "Lower mse_batch_size or provide more prototypes via --prototype_dir"
        )
    if len(dataset) < 100:
        log(f" [WARN] Only {len(dataset)} prototypes — consider providing more via --prototype_dir")

    # ---- Optimizer ----
    params = list(cond_encoder.parameters()) + list(hypernetwork.parameters())
    optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=0.01)

    # ---- Scheduler ----
    steps_per_epoch = max(1, len(loader) // grad_accum)
    total_steps = max_epochs * steps_per_epoch
    scheduler = CosineWarmupScheduler(optimizer, warmup_steps=warmup_steps, total_steps=total_steps)

    # ---- Mixed precision (bf16 autocast, so no GradScaler is needed) ----
    autocast_dtype = torch.bfloat16

    # ---- Resume ----
    start_epoch = 0
    global_step = 0
    # Pre-bind the epoch metrics: if the resumed run is already complete
    # (start_epoch >= max_epochs) the loop body never runs, and the final
    # checkpoint would otherwise hit an unbound name.
    avg_mse = avg_cal = avg_total = float("nan")
    if args.resume and os.path.exists(args.resume):
        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
        cond_encoder.load_state_dict(ckpt["cond_encoder_state_dict"])
        hypernetwork.load_state_dict(ckpt["hypernetwork_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        if "scheduler_state_dict" in ckpt:
            scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        start_epoch = ckpt["epoch"]
        global_step = ckpt.get("global_step", start_epoch * steps_per_epoch)
        avg_total = float(ckpt.get("loss", float("nan")))
        avg_mse = float(ckpt.get("mse_loss", float("nan")))
        avg_cal = float(ckpt.get("cal_loss", float("nan")))
        log(f" Resumed from {args.resume} (epoch {start_epoch}, loss {avg_total:.4f})")

    # ---- DDP wrap ----
    cond_encoder = DDP(cond_encoder, device_ids=[local_rank], output_device=local_rank,
                       find_unused_parameters=False)
    hypernetwork = DDP(hypernetwork, device_ids=[local_rank], output_device=local_rank,
                       find_unused_parameters=False)

    def _checkpoint_state(epoch_num):
        """Build the checkpoint dict from the current (unwrapped) models and metrics."""
        hn_module = hypernetwork.module if hasattr(hypernetwork, "module") else hypernetwork
        ce_module = cond_encoder.module if hasattr(cond_encoder, "module") else cond_encoder
        return {
            "epoch": epoch_num,
            "global_step": global_step,
            "hypernetwork_state_dict": hn_module.state_dict(),
            "cond_encoder_state_dict": ce_module.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "loss": avg_total,
            "mse_loss": avg_mse,
            "cal_loss": avg_cal,
            "phase": 1,
            "lora_config": {
                "rank": lora_config.rank,
                "alpha": lora_config.alpha,
                "dropout": lora_config.dropout,
                "targets": list(lora_config.targets),
            },
        }

    # ---- Training loop ----
    cond_encoder.train()
    hypernetwork.train()
    os.makedirs(args.output_dir, exist_ok=True)

    for epoch in range(start_epoch, max_epochs):
        sampler.set_epoch(epoch)
        epoch_mse = 0.0
        epoch_cal = 0.0
        epoch_total = 0.0
        epoch_steps = 0
        epoch_start = time.time()
        optimizer.zero_grad(set_to_none=True)

        for batch_idx, batch in enumerate(loader):
            camera_id = batch["camera_id"].to(device, non_blocking=True)
            scene_desc = batch["scene_descriptor"].to(device, non_blocking=True)
            query_emb = batch["query_embedding"].to(device, non_blocking=True)
            target_params = batch["lora_params"].to(device, non_blocking=True)

            with torch.amp.autocast("cuda", dtype=autocast_dtype):
                # Condition encoding
                condition = cond_encoder(camera_id, scene_desc, query_emb)
                # HyperNetwork forward
                pred_params, sigma = hypernetwork(condition)
                # MSE loss: match prototype LoRA params
                mse_loss = F.mse_loss(pred_params, target_params)
                # Calibration loss (HypeLoRA): sigma should predict per-sample MSE
                per_sample_mse = (pred_params - target_params).pow(2).mean(dim=-1, keepdim=True)
                cal_loss = F.mse_loss(sigma, per_sample_mse.detach())
                loss = (mse_loss + calibration_weight * cal_loss) / grad_accum

            loss.backward()
            epoch_mse += mse_loss.item()
            epoch_cal += cal_loss.item()
            epoch_total += (mse_loss + calibration_weight * cal_loss).item()
            epoch_steps += 1

            # Optimizer step every grad_accum batches
            if (batch_idx + 1) % grad_accum == 0:
                torch.nn.utils.clip_grad_norm_(params, 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                global_step += 1

                if global_step % 50 == 0:
                    current_lr = scheduler.get_last_lr()[0]
                    gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
                    confidence = hypernetwork.module.compute_confidence(sigma).mean().item()
                    log(f" [Step {global_step}] mse={mse_loss.item():.6f} "
                        f"cal={cal_loss.item():.6f} "
                        f"conf={confidence:.3f} "
                        f"lr={current_lr:.2e} GPU={gpu_mem:.1f}GB")

        # ---- Epoch summary (averaged over all ranks) ----
        metrics = torch.tensor([epoch_mse, epoch_cal, epoch_total, epoch_steps],
                               device=device, dtype=torch.float64)
        dist.all_reduce(metrics, op=dist.ReduceOp.SUM)
        avg_mse = (metrics[0] / metrics[3]).item()
        avg_cal = (metrics[1] / metrics[3]).item()
        avg_total = (metrics[2] / metrics[3]).item()
        epoch_time = time.time() - epoch_start
        current_lr = scheduler.get_last_lr()[0]
        gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
        log(f"\nEpoch {epoch + 1}/{max_epochs}: mse={avg_mse:.6f} cal={avg_cal:.6f} "
            f"total={avg_total:.6f} lr={current_lr:.2e} time={epoch_time:.0f}s GPU={gpu_mem:.1f}GB")

        # ---- Checkpoint every 10 epochs and at the last epoch ----
        if (epoch + 1) % 10 == 0 or (epoch + 1) == max_epochs:
            ckpt_path = os.path.join(args.output_dir, f"hypernetwork_phase1_epoch{epoch + 1}.pt")
            save_checkpoint(_checkpoint_state(epoch + 1), ckpt_path)
            if args.hf_push:
                push_to_hf(ckpt_path, args.hf_push)
        dist.barrier()

    # ---- Final checkpoint ----
    final_path = os.path.join(args.output_dir, "hypernetwork_phase1_final.pt")
    save_checkpoint(_checkpoint_state(max_epochs), final_path)
    if args.hf_push:
        push_to_hf(final_path, args.hf_push)

    log("\n" + "=" * 70)
    log(f"Phase 1 complete. Final MSE: {avg_mse:.6f} Cal: {avg_cal:.6f}")
    log("=" * 70)
    cleanup()
# ---------------------------------------------------------------------------
# Phase 2: End-to-end task loss
# ---------------------------------------------------------------------------
def train_phase2(args, config, hn_config):
    """Phase 2: fine-tune the HyperNetwork end-to-end on decode loss + calibration.

    The intended pipeline, per batch, is:
    condition -> HyperNetwork -> per-sample LoRA params -> inject into the
    frozen VLM decoder -> decode loss. The decoder forward runs with autograd
    enabled, so the decode loss back-propagates through the generated LoRA
    params into the HyperNetwork and ConditionEncoder. The VLM's own
    parameters stay frozen (``requires_grad=False``).

    NOTE: no end-to-end dataset is wired in yet. This function therefore raises
    (via ``_require_e2e_dataset``) before any distributed setup or model
    loading. The code below is the intended training loop and needs a real
    dataset at the marked TODO before it can run.

    Args:
        args: Parsed CLI namespace; ``stage3_ckpt``, ``resume``,
            ``output_dir`` and ``hf_push`` are read.
        config: Main model config (VLM, ``decoder`` and ``predictor`` sections).
        hn_config: HyperNetwork config; the ``lora``, ``hypernetwork``,
            ``condition_encoder`` and ``train_hypernetwork`` sections are read.

    Raises:
        RuntimeError: Always, until a real Phase 2 dataset is provided.
    """
    # Fail fast: there is no dataset yet, so raise before initialising NCCL
    # and loading the VLM backbone rather than after.
    _require_e2e_dataset()

    # ---- Distributed setup ----
    local_rank = setup_distributed()
    world_size = dist.get_world_size()
    global_rank = dist.get_rank()
    device = torch.device(f"cuda:{local_rank}")

    # ---- Config ----
    lora_cfg = hn_config["lora"]
    hn_cfg = hn_config["hypernetwork"]
    ce_cfg = hn_config["condition_encoder"]
    train_cfg = hn_config["train_hypernetwork"]
    lora_config = LoRAConfig(
        rank=lora_cfg["rank"],
        alpha=lora_cfg["alpha"],
        dropout=lora_cfg["dropout"],
        targets=tuple(lora_cfg["targets"]),
    )
    decoder_cfg = config["decoder"]
    num_decoder_blocks = decoder_cfg["num_blocks"]
    decoder_embed_dim = decoder_cfg["hidden_dim"]
    per_gpu_batch = max(1, train_cfg["e2e_batch_size"] // world_size)
    grad_accum = max(1, train_cfg["e2e_batch_size"] // (per_gpu_batch * world_size))
    max_epochs = train_cfg["e2e_epochs"]
    lr = train_cfg["e2e_lr"]
    calibration_weight = train_cfg["calibration_weight"]
    warmup_steps = max_epochs  # Minimal warmup for fine-tuning phase

    log("=" * 70)
    log("ArcisVLM — HyperNetwork Phase 2: End-to-End (DDP)")
    log("=" * 70)
    log(f" World size: {world_size}")
    log(f" Global batch: {train_cfg['e2e_batch_size']}")
    log(f" Per-GPU batch: {per_gpu_batch}")
    log(f" Gradient accumulation:{grad_accum}")
    log(f" Effective batch: {per_gpu_batch * world_size * grad_accum}")
    log(f" Max epochs: {max_epochs}")
    log(f" Learning rate: {lr}")
    log(f" Calibration weight: {calibration_weight}")

    # ---- Build VLM backbone (frozen) ----
    vlm = VLJEPAModel(config).to(device)
    if os.path.exists(args.stage3_ckpt):
        ckpt = torch.load(args.stage3_ckpt, map_location=device, weights_only=False)
        vlm.load_state_dict(ckpt["model_state_dict"])
        log(f" Loaded Stage 3 ckpt: {args.stage3_ckpt}")
    else:
        log(f" [WARN] Stage 3 checkpoint not found: {args.stage3_ckpt} — using random weights")
    # Freeze VLM backbone — only HyperNetwork + ConditionEncoder are trainable
    for param in vlm.parameters():
        param.requires_grad = False
    vlm.eval()

    # ---- Build HyperNetwork + ConditionEncoder ----
    scene_input_dim = ce_cfg.get("scene_input_dim", config["predictor"]["embed_dim"])
    query_input_dim = ce_cfg.get("query_input_dim", config["predictor"]["embed_dim"])
    cond_encoder = ConditionEncoder(
        n_cameras=ce_cfg["n_cameras"],
        camera_dim=ce_cfg["camera_dim"],
        scene_input_dim=scene_input_dim,
        scene_dim=ce_cfg["scene_dim"],
        query_input_dim=query_input_dim,
        query_dim=ce_cfg["query_dim"],
        out_dim=hn_cfg["cond_dim"],
    ).to(device)
    hypernetwork = HyperNetwork(
        cond_dim=hn_cfg["cond_dim"],
        hidden_dim=hn_cfg["hidden_dim"],
        lora_config=lora_config,
        num_decoder_blocks=num_decoder_blocks,
        decoder_embed_dim=decoder_embed_dim,
    ).to(device)
    log(f" LoRA param count: {hypernetwork.num_generated_params:,}")
    log(f" HyperNetwork params: {hypernetwork.num_own_params:,}")

    # ---- Load Phase 1 checkpoint ----
    phase1_path = args.resume or os.path.join(args.output_dir, "hypernetwork_phase1_final.pt")
    if os.path.exists(phase1_path):
        ckpt = torch.load(phase1_path, map_location=device, weights_only=False)
        hypernetwork.load_state_dict(ckpt["hypernetwork_state_dict"])
        cond_encoder.load_state_dict(ckpt["cond_encoder_state_dict"])
        log(f" Loaded Phase 1 ckpt: {phase1_path} (loss {ckpt.get('loss', -1):.4f})")
    else:
        log(f" [WARN] Phase 1 checkpoint not found: {phase1_path} — training from scratch")

    # ---- Tokenizer ----
    tokenizer = BPETokenizer(vocab_size=config["decoder"]["vocab_size"])
    for tok_path in ["checkpoints/tokenizer_32k.json", "checkpoints/tokenizer.json"]:
        if os.path.exists(tok_path):
            tokenizer.load(tok_path)
            log(f" Tokenizer: {len(tokenizer)} tokens (from {tok_path})")
            break
    else:
        log(" [WARN] No tokenizer found — using untrained tokenizer")

    # ---- Dataset ----
    # TODO: build the real end-to-end dataset here. The loop below reads the
    # keys image, question_ids, question_mask, answer_ids, camera_id,
    # scene_descriptor and query_embedding. This path is unreachable until
    # then, because _require_e2e_dataset() raises at the top of the function.
    dataset = None
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=global_rank, shuffle=True)
    loader = DataLoader(
        dataset,
        batch_size=per_gpu_batch,
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    )
    log(f" Dataset: {len(dataset)} samples, {len(loader)} batches/GPU")
    if len(loader) == 0:
        raise RuntimeError(
            f"FATAL: {len(dataset)} samples over {world_size} GPUs cannot fill one "
            f"per-GPU batch of {per_gpu_batch} (drop_last=True).\n"
            "Lower e2e_batch_size or provide more data"
        )

    # ---- Optimizer (only HyperNetwork + ConditionEncoder) ----
    trainable_params = list(cond_encoder.parameters()) + list(hypernetwork.parameters())
    optimizer = torch.optim.AdamW(trainable_params, lr=lr, weight_decay=0.01)

    # ---- Scheduler ----
    steps_per_epoch = max(1, len(loader) // grad_accum)
    total_steps = max_epochs * steps_per_epoch
    scheduler = CosineWarmupScheduler(optimizer, warmup_steps=warmup_steps, total_steps=total_steps)

    # ---- Mixed precision (bf16 autocast) ----
    autocast_dtype = torch.bfloat16

    # ---- Resume Phase 2 checkpoint ----
    start_epoch = 0
    global_step = 0
    # Pre-bind metrics so a fully completed resumed run can still save its
    # final checkpoint without hitting an unbound name.
    avg_decode = avg_cal = avg_total = float("nan")
    if args.resume and os.path.exists(args.resume):
        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
        if ckpt.get("phase") == 2:
            cond_encoder.load_state_dict(ckpt["cond_encoder_state_dict"])
            hypernetwork.load_state_dict(ckpt["hypernetwork_state_dict"])
            optimizer.load_state_dict(ckpt["optimizer_state_dict"])
            if "scheduler_state_dict" in ckpt:
                scheduler.load_state_dict(ckpt["scheduler_state_dict"])
            start_epoch = ckpt["epoch"]
            global_step = ckpt.get("global_step", start_epoch * steps_per_epoch)
            avg_total = float(ckpt.get("loss", float("nan")))
            avg_decode = float(ckpt.get("decode_loss", float("nan")))
            avg_cal = float(ckpt.get("cal_loss", float("nan")))
            log(f" Resumed Phase 2 from {args.resume} (epoch {start_epoch}, loss {avg_total:.4f})")

    # ---- DDP wrap (HyperNetwork + ConditionEncoder only) ----
    cond_encoder = DDP(cond_encoder, device_ids=[local_rank], output_device=local_rank,
                       find_unused_parameters=False)
    hypernetwork = DDP(hypernetwork, device_ids=[local_rank], output_device=local_rank,
                       find_unused_parameters=False)

    # ---- LoRA injector ----
    injector = LoRAInjector(lora_config, num_decoder_blocks, decoder_embed_dim)

    def _checkpoint_state(epoch_num):
        """Build the checkpoint dict from the current (unwrapped) models and metrics."""
        hn_module = hypernetwork.module if hasattr(hypernetwork, "module") else hypernetwork
        ce_module = cond_encoder.module if hasattr(cond_encoder, "module") else cond_encoder
        return {
            "epoch": epoch_num,
            "global_step": global_step,
            "hypernetwork_state_dict": hn_module.state_dict(),
            "cond_encoder_state_dict": ce_module.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "loss": avg_total,
            "decode_loss": avg_decode,
            "cal_loss": avg_cal,
            "phase": 2,
            "lora_config": {
                "rank": lora_config.rank,
                "alpha": lora_config.alpha,
                "dropout": lora_config.dropout,
                "targets": list(lora_config.targets),
            },
        }

    # ---- Training loop ----
    cond_encoder.train()
    hypernetwork.train()
    os.makedirs(args.output_dir, exist_ok=True)

    for epoch in range(start_epoch, max_epochs):
        sampler.set_epoch(epoch)
        epoch_decode = 0.0
        epoch_cal = 0.0
        epoch_total = 0.0
        epoch_steps = 0
        epoch_start = time.time()
        optimizer.zero_grad(set_to_none=True)

        for batch_idx, batch in enumerate(loader):
            images = batch["image"].to(device, non_blocking=True)
            q_ids = batch["question_ids"].to(device, non_blocking=True)
            q_mask = batch["question_mask"].to(device, non_blocking=True)
            a_ids = batch["answer_ids"].to(device, non_blocking=True)
            camera_id = batch["camera_id"].to(device, non_blocking=True)
            scene_desc = batch["scene_descriptor"].to(device, non_blocking=True)
            query_emb = batch["query_embedding"].to(device, non_blocking=True)

            with torch.amp.autocast("cuda", dtype=autocast_dtype):
                # Step 1: Generate conditioning vector
                condition = cond_encoder(camera_id, scene_desc, query_emb)
                # Step 2: HyperNetwork generates per-sample LoRA params + sigma
                lora_params, sigma = hypernetwork(condition)

                # Step 3: Decode each sample with its own LoRA injected
                B = images.shape[0]
                sample_losses = []
                for i in range(B):
                    # Encoder and predictor don't see the LoRA weights, so no
                    # autograd graph is needed for them.
                    with torch.no_grad():
                        visual_tokens = vlm.x_encoder(images[i:i + 1])
                        pred_embeds = vlm.predictor(visual_tokens, q_ids[i:i + 1], q_mask[i:i + 1])
                    # NOTE(review): assumes create_lora_layers keeps the autograd
                    # link to lora_params[i] (e.g. uses the tensor, not a copy).
                    lora_layers = injector.create_lora_layers(lora_params[i], device=device)
                    vlm.decoder.apply_lora(lora_layers)
                    try:
                        # Decoder runs WITH autograd so decode_loss reaches
                        # lora_params -> HyperNetwork -> ConditionEncoder.
                        _, decode_loss = vlm.decoder(pred_embeds, a_ids[i:i + 1])
                    finally:
                        # Always remove this sample's LoRA, even on error.
                        vlm.decoder.clear_lora()
                    sample_losses.append(decode_loss)

                # assumes decode_loss is a scalar per sample -> (B,) — TODO confirm
                per_sample_decode = torch.stack(sample_losses)
                batch_decode_loss = per_sample_decode.mean()

                # Calibration: each sample's sigma should track that sample's
                # own decode loss (higher loss -> higher sigma), mirroring the
                # per-sample target used in Phase 1.
                target_sigma = per_sample_decode.detach().reshape(B, 1)
                cal_loss = F.mse_loss(sigma, target_sigma)
                loss = (batch_decode_loss + calibration_weight * cal_loss) / grad_accum

            loss.backward()
            epoch_decode += batch_decode_loss.item()
            epoch_cal += cal_loss.item()
            epoch_total += (batch_decode_loss + calibration_weight * cal_loss).item()
            epoch_steps += 1

            # Optimizer step every grad_accum batches
            if (batch_idx + 1) % grad_accum == 0:
                torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                global_step += 1

                if global_step % 50 == 0:
                    current_lr = scheduler.get_last_lr()[0]
                    gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
                    confidence = hypernetwork.module.compute_confidence(sigma).mean().item()
                    log(f" [Step {global_step}] decode={batch_decode_loss.item():.4f} "
                        f"cal={cal_loss.item():.6f} "
                        f"conf={confidence:.3f} "
                        f"lr={current_lr:.2e} GPU={gpu_mem:.1f}GB")

        # ---- Epoch summary (averaged over all ranks) ----
        metrics = torch.tensor([epoch_decode, epoch_cal, epoch_total, epoch_steps],
                               device=device, dtype=torch.float64)
        dist.all_reduce(metrics, op=dist.ReduceOp.SUM)
        avg_decode = (metrics[0] / metrics[3]).item()
        avg_cal = (metrics[1] / metrics[3]).item()
        avg_total = (metrics[2] / metrics[3]).item()
        epoch_time = time.time() - epoch_start
        current_lr = scheduler.get_last_lr()[0]
        gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
        log(f"\nEpoch {epoch + 1}/{max_epochs}: decode={avg_decode:.4f} cal={avg_cal:.6f} "
            f"total={avg_total:.4f} lr={current_lr:.2e} time={epoch_time:.0f}s GPU={gpu_mem:.1f}GB")

        # ---- Checkpoint every epoch ----
        ckpt_path = os.path.join(args.output_dir, f"hypernetwork_phase2_epoch{epoch + 1}.pt")
        save_checkpoint(_checkpoint_state(epoch + 1), ckpt_path)
        if args.hf_push:
            push_to_hf(ckpt_path, args.hf_push)
        dist.barrier()

    # ---- Final checkpoint ----
    final_path = os.path.join(args.output_dir, "hypernetwork_phase2_final.pt")
    save_checkpoint(_checkpoint_state(max_epochs), final_path)
    if args.hf_push:
        push_to_hf(final_path, args.hf_push)

    log("\n" + "=" * 70)
    log(f"Phase 2 complete. Final decode: {avg_decode:.4f} cal: {avg_cal:.6f}")
    log("=" * 70)
    cleanup()
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Parse CLI arguments, load both YAML configs, and run the requested phase."""
    parser = argparse.ArgumentParser(description="HyperNetwork DDP: SHINE-style Meta-Training")
    parser.add_argument("--config", type=str, required=True,
                        help="Path to main YAML config (e.g., configs/default.yaml)")
    parser.add_argument("--stage3_ckpt", type=str, default="checkpoints/stage3_final.pt",
                        help="Path to Stage 3 checkpoint (frozen VLM backbone for Phase 2)")
    parser.add_argument("--hn_config", type=str, default="configs/hypernetwork.yaml",
                        help="Path to HyperNetwork YAML config")
    parser.add_argument("--prototype_dir", type=str, default="checkpoints/prototypes",
                        help="Directory with prototype LoRA .pt files (Phase 1)")
    parser.add_argument("--resume", type=str, default=None,
                        help="Path to checkpoint to resume from")
    parser.add_argument("--output_dir", type=str, default="checkpoints",
                        help="Output directory for checkpoints")
    parser.add_argument("--hf_push", type=str, default=None,
                        help="HuggingFace repo to push checkpoints (e.g., hardiksa/arcisvlm)")
    parser.add_argument("--phase", type=int, required=True, choices=[1, 2],
                        help="Training phase: 1=MSE fitting, 2=end-to-end")
    args = parser.parse_args()

    # ---- Load configs ----
    # Read configs as UTF-8 explicitly rather than with the platform's locale
    # encoding, so non-ASCII config values decode the same on every machine.
    with open(args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f)
    with open(args.hn_config, encoding="utf-8") as f:
        hn_config = yaml.safe_load(f)

    if args.phase == 1:
        train_phase1(args, config, hn_config)
    else:
        train_phase2(args, config, hn_config)


if __name__ == "__main__":
    main()