Image-Text-to-Text
Transformers
English
vision-language-model
vlm
surveillance
iot
gemma
vl-jepa
multimodal
object-detection
video-analytics
Instructions to use hardiksa/arcisvlm with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use hardiksa/arcisvlm with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("image-text-to-text", model="hardiksa/arcisvlm")

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("hardiksa/arcisvlm", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use hardiksa/arcisvlm with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm
# Start the vLLM server:
vllm serve "hardiksa/arcisvlm"
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "hardiksa/arcisvlm",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
Use Docker
docker model run hf.co/hardiksa/arcisvlm
- SGLang
How to use hardiksa/arcisvlm with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang
# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "hardiksa/arcisvlm" \
  --host 0.0.0.0 \
  --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "hardiksa/arcisvlm",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
Use Docker images
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
    --model-path "hardiksa/arcisvlm" \
    --host 0.0.0.0 \
    --port 30000
# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "hardiksa/arcisvlm",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'
- Docker Model Runner
How to use hardiksa/arcisvlm with Docker Model Runner:
docker model run hf.co/hardiksa/arcisvlm
| #!/usr/bin/env python3 | |
| """ | |
| HyperNetwork Meta-Training — SHINE-style two-phase DDP training. | |
| Phase 1 (MSE fitting): | |
| Train HyperNetwork to reconstruct prototype LoRA adapters from conditioning | |
| vectors. Loss = MSE(generated_params, prototype_params) + calibration_loss. | |
| Phase 2 (End-to-end): | |
| Fine-tune HyperNetwork on task loss: condition -> HyperNetwork -> LoRA params | |
| -> inject into decoder -> decode -> task loss (decode_loss + calibration). | |
| Usage: | |
| # Phase 1: MSE fitting to prototype LoRAs | |
| torchrun --nproc_per_node=8 scripts/train_hypernetwork.py \ | |
| --config configs/default.yaml \ | |
| --hn_config configs/hypernetwork.yaml \ | |
| --stage3_ckpt checkpoints/stage3_final.pt \ | |
| --prototype_dir checkpoints/prototypes \ | |
| --phase 1 | |
| # Phase 2: End-to-end on task loss | |
| torchrun --nproc_per_node=8 scripts/train_hypernetwork.py \ | |
| --config configs/default.yaml \ | |
| --hn_config configs/hypernetwork.yaml \ | |
| --stage3_ckpt checkpoints/stage3_final.pt \ | |
| --resume checkpoints/hypernetwork_phase1_final.pt \ | |
| --phase 2 | |
| References: | |
| - SHINE: Scalable Hypernetwork Internalization (arXiv: 2602.28901) | |
| - HypeLoRA: Calibrated LoRA with uncertainty (arXiv: 2603.19278) | |
| - HyperVLA: Mother spawns compact Child policies (arXiv: 2510.04898) | |
| """ | |
| import argparse | |
| import math | |
| import os | |
| import sys | |
| import time | |
| import torch | |
| import torch.distributed as dist | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.nn.parallel import DistributedDataParallel as DDP | |
| from torch.utils.data import DataLoader, Dataset | |
| from torch.utils.data.distributed import DistributedSampler | |
| import yaml | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from model.hypernetwork import HyperNetwork | |
| from model.condition_encoder import ConditionEncoder | |
| from model.lora import LoRAConfig, LoRAInjector | |
| from model.vlm import VLJEPAModel | |
| from model.tokenizer import BPETokenizer | |
| # --------------------------------------------------------------------------- | |
| # Dataset helpers | |
| # --------------------------------------------------------------------------- | |
class PrototypeDataset(Dataset):
    """
    Dataset of prototype LoRA adapters for Phase 1 MSE training.

    Each sample contains:
    - A conditioning triple (camera_id, scene_descriptor, query_embedding)
    - The corresponding prototype LoRA parameter vector (flattened, float32)

    Prototypes are read from the ``*.pt`` files in ``prototype_dir``. A file
    may hold either a dict with a ``lora_params`` tensor (plus optional
    ``camera_id`` / ``scene_descriptor`` / ``query_embedding`` entries) or a
    bare tensor. Conditioning fields that are absent are filled with random
    values. Files that cannot be loaded, have an unsupported layout, or whose
    flattened size differs from ``lora_param_count`` are skipped with a
    warning.

    There is no dummy-data fallback: if no usable prototype is found the
    constructor raises ``RuntimeError``.
    """

    def __init__(
        self,
        prototype_dir: str,
        cond_encoder_cfg: dict,
        lora_param_count: int,
        num_dummy: int = 2000,
    ):
        """
        Args:
            prototype_dir: Directory scanned (non-recursively) for ``*.pt`` files.
            cond_encoder_cfg: Optional keys ``scene_input_dim``,
                ``query_input_dim`` and ``n_cameras`` (each defaults to 2048);
                used only to size the random fallback conditioning.
            lora_param_count: Required number of elements in every prototype.
            num_dummy: Unused; kept so existing callers keep working.

        Raises:
            RuntimeError: If no usable prototype file was found.
        """
        # Local import: the file's import block does not provide `warnings`.
        import warnings

        self.lora_param_count = lora_param_count
        self.samples: list[dict] = []
        scene_input_dim = cond_encoder_cfg.get("scene_input_dim", 2048)
        query_input_dim = cond_encoder_cfg.get("query_input_dim", 2048)
        n_cameras = cond_encoder_cfg.get("n_cameras", 2048)

        if prototype_dir and os.path.isdir(prototype_dir):
            # Sorted so that sample order is reproducible across ranks/runs.
            pt_files = sorted(
                f for f in os.listdir(prototype_dir) if f.endswith(".pt")
            )
            for fname in pt_files:
                fpath = os.path.join(prototype_dir, fname)
                try:
                    proto = torch.load(fpath, map_location="cpu", weights_only=True)
                    sample = self._build_sample(
                        proto, n_cameras, scene_input_dim, query_input_dim
                    )
                except Exception as exc:
                    # Best-effort loading: one bad file must not abort the
                    # run, but it must not disappear silently either.
                    warnings.warn(
                        f"PrototypeDataset: skipping '{fpath}': {exc!r}",
                        stacklevel=2,
                    )
                    continue
                if sample is None:
                    warnings.warn(
                        f"PrototypeDataset: skipping '{fpath}': unsupported layout "
                        f"or parameter count != {lora_param_count}",
                        stacklevel=2,
                    )
                    continue
                self.samples.append(sample)

        if self.samples:
            return
        raise RuntimeError(
            f"FATAL: No prototype LoRA files found in '{prototype_dir}'.\n"
            "Train prototypes first: python3 scripts/train_prototype_loras.py\n"
            "Required: .pt files with 'lora_params' key and matching param count"
        )

    def _build_sample(self, proto, n_cameras, scene_input_dim, query_input_dim):
        """Turn one loaded prototype into a sample dict, or None if unusable."""
        if isinstance(proto, dict) and "lora_params" in proto:
            flat = proto["lora_params"].flatten()
            cond = proto
        elif isinstance(proto, torch.Tensor):
            flat = proto.flatten()
            cond = {}
        else:
            return None
        if flat.numel() != self.lora_param_count:
            return None

        # Random conditioning is drawn only for fields the file does not supply.
        if "camera_id" in cond:
            camera_id = cond["camera_id"]
        else:
            camera_id = torch.randint(0, n_cameras, (1,)).item()
        if "scene_descriptor" in cond:
            scene = cond["scene_descriptor"]
        else:
            scene = torch.randn(scene_input_dim)
        if "query_embedding" in cond:
            query = cond["query_embedding"]
        else:
            query = torch.randn(query_input_dim)

        return {
            "lora_params": flat.float(),
            # Normalise here so a malformed value is rejected at load time
            # rather than crashing a DataLoader worker in __getitem__.
            "camera_id": int(camera_id),
            "scene_descriptor": torch.as_tensor(scene),
            "query_embedding": torch.as_tensor(query),
        }

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]
        return {
            "lora_params": s["lora_params"],
            "camera_id": torch.tensor(s["camera_id"], dtype=torch.long),
            "scene_descriptor": s["scene_descriptor"].float(),
            "query_embedding": s["query_embedding"].float(),
        }
| def _require_e2e_dataset(): | |
| """Phase 2 end-to-end training requires real data.""" | |
| raise RuntimeError( | |
| "FATAL: Phase 2 end-to-end training requires real data.\n" | |
| "Download real data first: python3 scripts/download_all_data.py --stage 3\n" | |
| "Required: JSONL files with image_path, question, answer, and conditioning fields" | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # LR scheduler | |
| # --------------------------------------------------------------------------- | |
class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Linear warmup from 0 to the base LR, then cosine decay to ``min_lr``.

    After ``total_steps`` the learning rate stays at ``min_lr``.

    Args:
        optimizer: Wrapped optimizer.
        warmup_steps: Number of scheduler steps spent in linear warmup.
        total_steps: Step at which the cosine decay reaches ``min_lr``.
        min_lr: Floor of the cosine decay.
        last_epoch: Index of the last step, ``-1`` to start fresh.
    """

    def __init__(self, optimizer, warmup_steps: int, total_steps: int,
                 min_lr: float = 1e-7, last_epoch: int = -1):
        # Attributes must exist before super().__init__, which calls get_lr().
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_steps:
            scale = step / max(1, self.warmup_steps)
            return [base_lr * scale for base_lr in self.base_lrs]

        decay_span = max(1, self.total_steps - self.warmup_steps)
        # Clamp: without it, stepping past total_steps pushes the cosine past
        # its minimum and the LR climbs back towards base_lr.
        progress = min(1.0, (step - self.warmup_steps) / decay_span)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return [self.min_lr + (base_lr - self.min_lr) * cosine for base_lr in self.base_lrs]
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
def setup_distributed():
    """Join the NCCL process group and bind this process to its local GPU.

    Reads the ``LOCAL_RANK`` environment variable and selects the CUDA device
    with that index.

    Returns:
        The local rank as an int.
    """
    dist.init_process_group(backend="nccl")
    rank_on_node = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(rank_on_node)
    return rank_on_node
def cleanup():
    """Destroy the process group if one was initialised; otherwise do nothing."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
def is_rank0():
    """Return True on global rank 0, or when no process group is initialised."""
    if dist.is_initialized():
        return dist.get_rank() == 0
    return True
def log(msg: str):
    """Print ``msg`` (flushed immediately) on rank 0 only; other ranks stay silent."""
    if not is_rank0():
        return
    print(msg, flush=True)
def save_checkpoint(state: dict, path: str):
    """Serialise ``state`` to ``path`` with ``torch.save`` (rank 0 only).

    Args:
        state: Checkpoint payload.
        path: Destination file; its parent directory is created if missing.
    """
    if not is_rank0():
        return
    parent = os.path.dirname(path)
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(state, path)
    log(f" Checkpoint saved: {path}")
def push_to_hf(ckpt_path: str, repo: str):
    """Upload one checkpoint file to a HuggingFace model repo (rank 0 only).

    Best-effort: a missing ``huggingface_hub`` package, a missing ``HF_TOKEN``
    environment variable, or any upload error is logged as a warning and
    never raised. The file is stored in the repo under the same relative path
    as ``ckpt_path``.

    Args:
        ckpt_path: Local path of the checkpoint to upload.
        repo: Target repo id; created if it does not exist.
    """
    if not is_rank0():
        return
    try:
        from huggingface_hub import HfApi, create_repo

        hf_token = os.getenv("HF_TOKEN")
        if not hf_token:
            log(" [WARN] HF_TOKEN not set — skipping HuggingFace push")
            return

        hub = HfApi(token=hf_token)
        create_repo(repo, repo_type="model", exist_ok=True, token=hf_token)
        upload_kwargs = {
            "path_or_fileobj": ckpt_path,
            "path_in_repo": ckpt_path,
            "repo_id": repo,
            "repo_type": "model",
            "token": hf_token,
        }
        hub.upload_file(**upload_kwargs)
        log(f" Pushed to HuggingFace: {repo}/{ckpt_path}")
    except ImportError:
        log(" [WARN] huggingface_hub not installed — skipping push")
    except Exception as err:
        log(f" [WARN] HuggingFace push failed: {err}")
| # --------------------------------------------------------------------------- | |
| # Phase 1: MSE fitting to prototype LoRAs | |
| # --------------------------------------------------------------------------- | |
def train_phase1(args, config, hn_config):
    """Phase 1: Train HyperNetwork to reconstruct prototype LoRA adapters.

    Builds a ConditionEncoder and a HyperNetwork, wraps both in DDP and fits
    them to the prototypes in ``args.prototype_dir`` with the loss
    ``MSE(pred, target) + calibration_weight * MSE(sigma, per-sample MSE)``.
    A checkpoint is written every 10 epochs and at the last epoch, followed
    by ``hypernetwork_phase1_final.pt``; each is optionally pushed to the
    HuggingFace repo named by ``args.hf_push``.

    Args:
        args: Parsed CLI namespace; reads ``prototype_dir``, ``resume``,
            ``output_dir`` and ``hf_push``.
        config: Main config dict; reads ``decoder`` and ``predictor``.
        hn_config: HyperNetwork config dict; reads ``lora``,
            ``hypernetwork``, ``condition_encoder`` and ``train_hypernetwork``.

    Raises:
        RuntimeError: If no usable prototype exists (raised by
            PrototypeDataset) or the loader yields zero batches per GPU.
    """
    # ---- Distributed setup ----
    local_rank = setup_distributed()
    world_size = dist.get_world_size()
    global_rank = dist.get_rank()
    device = torch.device(f"cuda:{local_rank}")

    # ---- Config ----
    lora_cfg = hn_config["lora"]
    hn_cfg = hn_config["hypernetwork"]
    ce_cfg = hn_config["condition_encoder"]
    train_cfg = hn_config["train_hypernetwork"]
    lora_config = LoRAConfig(
        rank=lora_cfg["rank"],
        alpha=lora_cfg["alpha"],
        dropout=lora_cfg["dropout"],
        targets=tuple(lora_cfg["targets"]),
    )
    decoder_cfg = config["decoder"]
    num_decoder_blocks = decoder_cfg["num_blocks"]
    decoder_embed_dim = decoder_cfg["hidden_dim"]
    per_gpu_batch = max(1, train_cfg["mse_batch_size"] // world_size)
    # NOTE: because per_gpu_batch is itself derived from the global batch,
    # this quotient always evaluates to 1 with the current formula.
    grad_accum = max(1, train_cfg["mse_batch_size"] // (per_gpu_batch * world_size))
    max_epochs = train_cfg["mse_epochs"]
    lr = train_cfg["mse_lr"]
    calibration_weight = train_cfg["calibration_weight"]
    warmup_steps = max_epochs * 2  # Small warmup for MSE phase

    log("=" * 70)
    log("ArcisVLM — HyperNetwork Phase 1: MSE Fitting (DDP)")
    log("=" * 70)
    log(f" World size: {world_size}")
    log(f" Global batch: {train_cfg['mse_batch_size']}")
    log(f" Per-GPU batch: {per_gpu_batch}")
    log(f" Gradient accumulation:{grad_accum}")
    log(f" Effective batch: {per_gpu_batch * world_size * grad_accum}")
    log(f" Max epochs: {max_epochs}")
    log(f" Learning rate: {lr}")
    log(f" Calibration weight: {calibration_weight}")

    # ---- Build models ----
    # ConditionEncoder
    scene_input_dim = ce_cfg.get("scene_input_dim", config["predictor"]["embed_dim"])
    query_input_dim = ce_cfg.get("query_input_dim", config["predictor"]["embed_dim"])
    cond_encoder = ConditionEncoder(
        n_cameras=ce_cfg["n_cameras"],
        camera_dim=ce_cfg["camera_dim"],
        scene_input_dim=scene_input_dim,
        scene_dim=ce_cfg["scene_dim"],
        query_input_dim=query_input_dim,
        query_dim=ce_cfg["query_dim"],
        out_dim=hn_cfg["cond_dim"],
    ).to(device)
    # HyperNetwork
    hypernetwork = HyperNetwork(
        cond_dim=hn_cfg["cond_dim"],
        hidden_dim=hn_cfg["hidden_dim"],
        lora_config=lora_config,
        num_decoder_blocks=num_decoder_blocks,
        decoder_embed_dim=decoder_embed_dim,
    ).to(device)
    lora_param_count = hypernetwork.num_generated_params
    log(f" LoRA param count: {lora_param_count:,}")
    log(f" HyperNetwork params: {hypernetwork.num_own_params:,}")
    log(f" CondEncoder params: {sum(p.numel() for p in cond_encoder.parameters()):,}")

    # ---- Dataset ----
    cond_enc_data_cfg = {
        "n_cameras": ce_cfg["n_cameras"],
        "scene_input_dim": scene_input_dim,
        "query_input_dim": query_input_dim,
    }
    dataset = PrototypeDataset(
        prototype_dir=args.prototype_dir,
        cond_encoder_cfg=cond_enc_data_cfg,
        lora_param_count=lora_param_count,
    )
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=global_rank, shuffle=True)
    loader = DataLoader(
        dataset,
        batch_size=per_gpu_batch,
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    )
    log(f" Dataset: {len(dataset)} prototypes, {len(loader)} batches/GPU")
    if len(loader) == 0:
        # drop_last=True discards the only (partial) batch; without this check
        # the run would perform no updates and still save a "final" checkpoint.
        raise RuntimeError(
            f"FATAL: {len(dataset)} prototypes yield 0 batches per GPU "
            f"(per-GPU batch {per_gpu_batch}, world size {world_size}).\n"
            "Provide more prototypes via --prototype_dir or lower mse_batch_size."
        )
    if len(dataset) < 100:
        # PrototypeDataset never synthesises data, so only the size is reported.
        log(f" [WARN] Only {len(dataset)} prototypes loaded (< 100) — add more via --prototype_dir")

    # ---- Optimizer ----
    params = list(cond_encoder.parameters()) + list(hypernetwork.parameters())
    optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=0.01)

    # ---- Scheduler ----
    steps_per_epoch = max(1, len(loader) // grad_accum)
    total_steps = max_epochs * steps_per_epoch
    scheduler = CosineWarmupScheduler(optimizer, warmup_steps=warmup_steps, total_steps=total_steps)

    # ---- Mixed precision ----
    # bf16 autocast; no GradScaler is used.
    autocast_dtype = torch.bfloat16

    # ---- Resume ----
    start_epoch = 0
    global_step = 0
    # Pre-bind the metrics so the final checkpoint still works when the epoch
    # loop runs zero times (e.g. resuming a run that already finished).
    avg_mse = avg_cal = avg_total = float("nan")
    if args.resume and os.path.exists(args.resume):
        # NOTE(security): weights_only=False unpickles arbitrary objects;
        # only resume from checkpoints you trust.
        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
        cond_encoder.load_state_dict(ckpt["cond_encoder_state_dict"])
        hypernetwork.load_state_dict(ckpt["hypernetwork_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        if "scheduler_state_dict" in ckpt:
            scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        start_epoch = ckpt["epoch"]
        global_step = ckpt.get("global_step", start_epoch * steps_per_epoch)
        avg_total = ckpt["loss"]
        avg_mse = ckpt.get("mse_loss", avg_mse)
        avg_cal = ckpt.get("cal_loss", avg_cal)
        log(f" Resumed from {args.resume} (epoch {start_epoch}, loss {ckpt['loss']:.4f})")

    # ---- DDP wrap ----
    cond_encoder = DDP(cond_encoder, device_ids=[local_rank], output_device=local_rank,
                       find_unused_parameters=False)
    hypernetwork = DDP(hypernetwork, device_ids=[local_rank], output_device=local_rank,
                       find_unused_parameters=False)

    def _checkpoint_state(epoch_num, step, total, mse, cal):
        """Assemble the checkpoint payload from the unwrapped module states."""
        hn_state = hypernetwork.module.state_dict() if hasattr(hypernetwork, "module") else hypernetwork.state_dict()
        ce_state = cond_encoder.module.state_dict() if hasattr(cond_encoder, "module") else cond_encoder.state_dict()
        return {
            "epoch": epoch_num,
            "global_step": step,
            "hypernetwork_state_dict": hn_state,
            "cond_encoder_state_dict": ce_state,
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "loss": total,
            "mse_loss": mse,
            "cal_loss": cal,
            "phase": 1,
            "lora_config": {
                "rank": lora_config.rank,
                "alpha": lora_config.alpha,
                "dropout": lora_config.dropout,
                "targets": list(lora_config.targets),
            },
        }

    # ---- Training loop ----
    cond_encoder.train()
    hypernetwork.train()
    os.makedirs(args.output_dir, exist_ok=True)
    for epoch in range(start_epoch, max_epochs):
        # Re-seed the sampler so each epoch gets a different shuffle.
        sampler.set_epoch(epoch)
        epoch_mse = 0.0
        epoch_cal = 0.0
        epoch_total = 0.0
        epoch_steps = 0
        epoch_start = time.time()
        optimizer.zero_grad(set_to_none=True)
        for batch_idx, batch in enumerate(loader):
            camera_id = batch["camera_id"].to(device, non_blocking=True)
            scene_desc = batch["scene_descriptor"].to(device, non_blocking=True)
            query_emb = batch["query_embedding"].to(device, non_blocking=True)
            target_params = batch["lora_params"].to(device, non_blocking=True)
            with torch.amp.autocast("cuda", dtype=autocast_dtype):
                # Condition encoding
                condition = cond_encoder(camera_id, scene_desc, query_emb)
                # HyperNetwork forward
                pred_params, sigma = hypernetwork(condition)
                # MSE loss: match prototype LoRA params
                mse_loss = F.mse_loss(pred_params, target_params)
                # Calibration loss (HypeLoRA): sigma should predict MSE.
                # The target is detached so only sigma is trained by it.
                per_sample_mse = (pred_params - target_params).pow(2).mean(dim=-1, keepdim=True)
                cal_loss = F.mse_loss(sigma, per_sample_mse.detach())
                loss = (mse_loss + calibration_weight * cal_loss) / grad_accum
            loss.backward()
            epoch_mse += mse_loss.item()
            epoch_cal += cal_loss.item()
            epoch_total += (mse_loss + calibration_weight * cal_loss).item()
            epoch_steps += 1
            # Optimizer step every grad_accum batches
            if (batch_idx + 1) % grad_accum == 0:
                torch.nn.utils.clip_grad_norm_(params, 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                global_step += 1
                if global_step % 50 == 0:
                    current_lr = scheduler.get_last_lr()[0]
                    gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
                    confidence = hypernetwork.module.compute_confidence(sigma).mean().item()
                    log(f" [Step {global_step}] mse={mse_loss.item():.6f} "
                        f"cal={cal_loss.item():.6f} "
                        f"conf={confidence:.3f} "
                        f"lr={current_lr:.2e} GPU={gpu_mem:.1f}GB")

        # ---- Epoch summary ----
        # Sum the per-rank totals and step counts, then average globally.
        metrics = torch.tensor([epoch_mse, epoch_cal, epoch_total, epoch_steps],
                               device=device, dtype=torch.float64)
        dist.all_reduce(metrics, op=dist.ReduceOp.SUM)
        avg_mse = (metrics[0] / metrics[3]).item()
        avg_cal = (metrics[1] / metrics[3]).item()
        avg_total = (metrics[2] / metrics[3]).item()
        epoch_time = time.time() - epoch_start
        current_lr = scheduler.get_last_lr()[0]
        gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
        log(f"\nEpoch {epoch + 1}/{max_epochs}: mse={avg_mse:.6f} cal={avg_cal:.6f} "
            f"total={avg_total:.6f} lr={current_lr:.2e} time={epoch_time:.0f}s GPU={gpu_mem:.1f}GB")

        # ---- Checkpoint every 10 epochs ----
        if (epoch + 1) % 10 == 0 or (epoch + 1) == max_epochs:
            ckpt_path = os.path.join(args.output_dir, f"hypernetwork_phase1_epoch{epoch + 1}.pt")
            save_checkpoint(
                _checkpoint_state(epoch + 1, global_step, avg_total, avg_mse, avg_cal),
                ckpt_path,
            )
            if args.hf_push:
                push_to_hf(ckpt_path, args.hf_push)
        # Keep ranks in step while rank 0 writes/uploads the checkpoint.
        dist.barrier()

    # ---- Final checkpoint ----
    final_path = os.path.join(args.output_dir, "hypernetwork_phase1_final.pt")
    save_checkpoint(
        _checkpoint_state(max_epochs, global_step, avg_total, avg_mse, avg_cal),
        final_path,
    )
    if args.hf_push:
        push_to_hf(final_path, args.hf_push)
    log("\n" + "=" * 70)
    log(f"Phase 1 complete. Final MSE: {avg_mse:.6f} Cal: {avg_cal:.6f}")
    log("=" * 70)
    cleanup()
| # --------------------------------------------------------------------------- | |
| # Phase 2: End-to-end task loss | |
| # --------------------------------------------------------------------------- | |
def train_phase2(args, config, hn_config):
    """Phase 2: Fine-tune HyperNetwork end-to-end on decode loss + calibration.

    Loads a frozen VLJEPAModel backbone plus the Phase 1 HyperNetwork and
    ConditionEncoder weights, then is meant to train the latter two on
    ``decode_loss + calibration_weight * cal_loss`` under DDP.

    NOTE: as written, this function always raises RuntimeError at the
    ``_require_e2e_dataset()`` call, so the loader, optimizer and training
    loop below it are currently unreachable.

    NOTE(review): the source this block was recovered from had lost its
    indentation; the nesting below (in particular the extent of the
    ``torch.no_grad()`` and ``autocast`` scopes) is a reconstruction —
    confirm against the original file.

    Args:
        args: Parsed CLI namespace; reads ``stage3_ckpt``, ``resume``,
            ``output_dir`` and ``hf_push``.
        config: Main config dict; reads ``decoder`` and ``predictor`` and is
            passed whole to ``VLJEPAModel``.
        hn_config: HyperNetwork config dict; reads ``lora``,
            ``hypernetwork``, ``condition_encoder`` and ``train_hypernetwork``.

    Raises:
        RuntimeError: Always (from ``_require_e2e_dataset``).
    """
    # ---- Distributed setup ----
    local_rank = setup_distributed()
    world_size = dist.get_world_size()
    global_rank = dist.get_rank()
    device = torch.device(f"cuda:{local_rank}")
    # ---- Config ----
    lora_cfg = hn_config["lora"]
    hn_cfg = hn_config["hypernetwork"]
    ce_cfg = hn_config["condition_encoder"]
    train_cfg = hn_config["train_hypernetwork"]
    lora_config = LoRAConfig(
        rank=lora_cfg["rank"],
        alpha=lora_cfg["alpha"],
        dropout=lora_cfg["dropout"],
        targets=tuple(lora_cfg["targets"]),
    )
    decoder_cfg = config["decoder"]
    num_decoder_blocks = decoder_cfg["num_blocks"]
    decoder_embed_dim = decoder_cfg["hidden_dim"]
    per_gpu_batch = max(1, train_cfg["e2e_batch_size"] // world_size)
    # NOTE: per_gpu_batch is derived from the same global batch, so this
    # quotient always evaluates to 1 with the current formula.
    grad_accum = max(1, train_cfg["e2e_batch_size"] // (per_gpu_batch * world_size))
    max_epochs = train_cfg["e2e_epochs"]
    lr = train_cfg["e2e_lr"]
    calibration_weight = train_cfg["calibration_weight"]
    warmup_steps = max_epochs # Minimal warmup for fine-tuning phase
    log("=" * 70)
    log("ArcisVLM — HyperNetwork Phase 2: End-to-End (DDP)")
    log("=" * 70)
    log(f" World size: {world_size}")
    log(f" Global batch: {train_cfg['e2e_batch_size']}")
    log(f" Per-GPU batch: {per_gpu_batch}")
    log(f" Gradient accumulation:{grad_accum}")
    log(f" Effective batch: {per_gpu_batch * world_size * grad_accum}")
    log(f" Max epochs: {max_epochs}")
    log(f" Learning rate: {lr}")
    log(f" Calibration weight: {calibration_weight}")
    # ---- Build VLM backbone (frozen) ----
    vlm = VLJEPAModel(config).to(device)
    if os.path.exists(args.stage3_ckpt):
        # NOTE(security): weights_only=False unpickles arbitrary objects;
        # only load checkpoints you trust.
        ckpt = torch.load(args.stage3_ckpt, map_location=device, weights_only=False)
        vlm.load_state_dict(ckpt["model_state_dict"])
        log(f" Loaded Stage 3 ckpt: {args.stage3_ckpt}")
    else:
        log(f" [WARN] Stage 3 checkpoint not found: {args.stage3_ckpt} — using random weights")
    # Freeze VLM backbone — only HyperNetwork + ConditionEncoder are trainable
    for param in vlm.parameters():
        param.requires_grad = False
    vlm.eval()
    # ---- Build HyperNetwork + ConditionEncoder ----
    scene_input_dim = ce_cfg.get("scene_input_dim", config["predictor"]["embed_dim"])
    query_input_dim = ce_cfg.get("query_input_dim", config["predictor"]["embed_dim"])
    cond_encoder = ConditionEncoder(
        n_cameras=ce_cfg["n_cameras"],
        camera_dim=ce_cfg["camera_dim"],
        scene_input_dim=scene_input_dim,
        scene_dim=ce_cfg["scene_dim"],
        query_input_dim=query_input_dim,
        query_dim=ce_cfg["query_dim"],
        out_dim=hn_cfg["cond_dim"],
    ).to(device)
    hypernetwork = HyperNetwork(
        cond_dim=hn_cfg["cond_dim"],
        hidden_dim=hn_cfg["hidden_dim"],
        lora_config=lora_config,
        num_decoder_blocks=num_decoder_blocks,
        decoder_embed_dim=decoder_embed_dim,
    ).to(device)
    log(f" LoRA param count: {hypernetwork.num_generated_params:,}")
    log(f" HyperNetwork params: {hypernetwork.num_own_params:,}")
    # ---- Load Phase 1 checkpoint ----
    # --resume doubles as the Phase 1 checkpoint path here; the same path is
    # examined again in the "Resume Phase 2 checkpoint" section below.
    phase1_path = args.resume or os.path.join(args.output_dir, "hypernetwork_phase1_final.pt")
    if os.path.exists(phase1_path):
        ckpt = torch.load(phase1_path, map_location=device, weights_only=False)
        hypernetwork.load_state_dict(ckpt["hypernetwork_state_dict"])
        cond_encoder.load_state_dict(ckpt["cond_encoder_state_dict"])
        log(f" Loaded Phase 1 ckpt: {phase1_path} (loss {ckpt.get('loss', -1):.4f})")
    else:
        log(f" [WARN] Phase 1 checkpoint not found: {phase1_path} — training from scratch")
    # ---- Tokenizer ----
    tokenizer = BPETokenizer(vocab_size=config["decoder"]["vocab_size"])
    # for/else: the else branch runs only if no candidate path existed
    # (i.e. the loop finished without hitting `break`).
    for tok_path in ["checkpoints/tokenizer_32k.json", "checkpoints/tokenizer.json"]:
        if os.path.exists(tok_path):
            tokenizer.load(tok_path)
            log(f" Tokenizer: {len(tokenizer)} tokens (from {tok_path})")
            break
    else:
        log(" [WARN] No tokenizer found — using untrained tokenizer")
    # ---- Dataset ---- (real data required, no dummy fallback)
    # This call raises unconditionally, so nothing below it executes.
    _require_e2e_dataset()
    dataset = None # unreachable — _require_e2e_dataset always raises
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=global_rank, shuffle=True)
    loader = DataLoader(
        dataset,
        batch_size=per_gpu_batch,
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    )
    log(f" Dataset: {len(dataset)} samples, {len(loader)} batches/GPU")
    # ---- Optimizer (only HyperNetwork + ConditionEncoder) ----
    trainable_params = list(cond_encoder.parameters()) + list(hypernetwork.parameters())
    optimizer = torch.optim.AdamW(trainable_params, lr=lr, weight_decay=0.01)
    # ---- Scheduler ----
    steps_per_epoch = max(1, len(loader) // grad_accum)
    total_steps = max_epochs * steps_per_epoch
    scheduler = CosineWarmupScheduler(optimizer, warmup_steps=warmup_steps, total_steps=total_steps)
    # ---- Mixed precision ----
    use_bf16 = True  # not read anywhere in this function
    autocast_dtype = torch.bfloat16
    # ---- Resume Phase 2 checkpoint ----
    start_epoch = 0
    global_step = 0
    if args.resume and os.path.exists(args.resume):
        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
        # Only a checkpoint written by Phase 2 restores optimizer/scheduler
        # state and the epoch counter; a Phase 1 checkpoint is ignored here.
        if ckpt.get("phase") == 2:
            cond_encoder.load_state_dict(ckpt["cond_encoder_state_dict"])
            hypernetwork.load_state_dict(ckpt["hypernetwork_state_dict"])
            optimizer.load_state_dict(ckpt["optimizer_state_dict"])
            if "scheduler_state_dict" in ckpt:
                scheduler.load_state_dict(ckpt["scheduler_state_dict"])
            start_epoch = ckpt["epoch"]
            global_step = ckpt.get("global_step", start_epoch * steps_per_epoch)
            log(f" Resumed Phase 2 from {args.resume} (epoch {start_epoch}, loss {ckpt['loss']:.4f})")
    # ---- DDP wrap (HyperNetwork + ConditionEncoder only) ----
    cond_encoder = DDP(cond_encoder, device_ids=[local_rank], output_device=local_rank,
                       find_unused_parameters=False)
    hypernetwork = DDP(hypernetwork, device_ids=[local_rank], output_device=local_rank,
                       find_unused_parameters=False)
    # ---- LoRA injector ----
    injector = LoRAInjector(lora_config, num_decoder_blocks, decoder_embed_dim)
    # ---- Training loop ----
    cond_encoder.train()
    hypernetwork.train()
    os.makedirs(args.output_dir, exist_ok=True)
    for epoch in range(start_epoch, max_epochs):
        sampler.set_epoch(epoch)
        epoch_decode = 0.0
        epoch_cal = 0.0
        epoch_total = 0.0
        epoch_steps = 0
        epoch_start = time.time()
        optimizer.zero_grad(set_to_none=True)
        for batch_idx, batch in enumerate(loader):
            images = batch["image"].to(device, non_blocking=True)
            q_ids = batch["question_ids"].to(device, non_blocking=True)
            q_mask = batch["question_mask"].to(device, non_blocking=True)
            a_ids = batch["answer_ids"].to(device, non_blocking=True)
            camera_id = batch["camera_id"].to(device, non_blocking=True)
            scene_desc = batch["scene_descriptor"].to(device, non_blocking=True)
            query_emb = batch["query_embedding"].to(device, non_blocking=True)
            with torch.amp.autocast("cuda", dtype=autocast_dtype):
                # Step 1: Generate conditioning vector
                condition = cond_encoder(camera_id, scene_desc, query_emb)
                # Step 2: HyperNetwork generates LoRA params
                lora_params, sigma = hypernetwork(condition)
                # Step 3: Inject LoRA into decoder and run forward
                # Process each sample in the batch (LoRA params are per-sample)
                batch_decode_loss = torch.tensor(0.0, device=device)
                B = images.shape[0]
                for i in range(B):
                    # Create LoRA layers from this sample's params
                    sample_params = lora_params[i]
                    lora_layers = injector.create_lora_layers(
                        sample_params.detach() if not sample_params.requires_grad else sample_params,
                        device=device,
                    )
                    # Inject LoRA into the (frozen-weights) decoder
                    vlm.decoder.apply_lora(lora_layers)
                    # Forward through VLM with LoRA active
                    # NOTE(review): presumably only the encoder and predictor
                    # run under no_grad, with the decoder call outside so the
                    # decode loss can reach the HyperNetwork — confirm.
                    with torch.no_grad():
                        visual_tokens = vlm.x_encoder(images[i:i+1])
                        pred_embeds = vlm.predictor(visual_tokens, q_ids[i:i+1], q_mask[i:i+1])
                    logits, decode_loss = vlm.decoder(pred_embeds, a_ids[i:i+1])
                    batch_decode_loss = batch_decode_loss + decode_loss
                    # Clean up LoRA for next sample
                    vlm.decoder.clear_lora()
                batch_decode_loss = batch_decode_loss / B
                # Calibration loss: sigma should correlate with decode loss
                # Higher decode loss -> higher sigma (more uncertain)
                # Every sample in the batch gets the same target: the
                # detached batch-mean decode loss.
                target_sigma = batch_decode_loss.detach().unsqueeze(0).expand(B, 1)
                cal_loss = F.mse_loss(sigma, target_sigma)
                loss = (batch_decode_loss + calibration_weight * cal_loss) / grad_accum
            loss.backward()
            epoch_decode += batch_decode_loss.item()
            epoch_cal += cal_loss.item()
            epoch_total += (batch_decode_loss + calibration_weight * cal_loss).item()
            epoch_steps += 1
            # Optimizer step every grad_accum batches
            if (batch_idx + 1) % grad_accum == 0:
                torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                global_step += 1
                if global_step % 50 == 0:
                    current_lr = scheduler.get_last_lr()[0]
                    gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
                    confidence = hypernetwork.module.compute_confidence(sigma).mean().item()
                    log(f" [Step {global_step}] decode={batch_decode_loss.item():.4f} "
                        f"cal={cal_loss.item():.6f} "
                        f"conf={confidence:.3f} "
                        f"lr={current_lr:.2e} GPU={gpu_mem:.1f}GB")
        # ---- Epoch summary ----
        # Sum the per-rank totals and step counts, then average globally.
        metrics = torch.tensor([epoch_decode, epoch_cal, epoch_total, epoch_steps],
                               device=device, dtype=torch.float64)
        dist.all_reduce(metrics, op=dist.ReduceOp.SUM)
        avg_decode = (metrics[0] / metrics[3]).item()
        avg_cal = (metrics[1] / metrics[3]).item()
        avg_total = (metrics[2] / metrics[3]).item()
        epoch_time = time.time() - epoch_start
        current_lr = scheduler.get_last_lr()[0]
        gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
        log(f"\nEpoch {epoch + 1}/{max_epochs}: decode={avg_decode:.4f} cal={avg_cal:.6f} "
            f"total={avg_total:.4f} lr={current_lr:.2e} time={epoch_time:.0f}s GPU={gpu_mem:.1f}GB")
        # ---- Checkpoint every epoch ----
        ckpt_path = os.path.join(args.output_dir, f"hypernetwork_phase2_epoch{epoch + 1}.pt")
        hn_state = hypernetwork.module.state_dict() if hasattr(hypernetwork, "module") else hypernetwork.state_dict()
        ce_state = cond_encoder.module.state_dict() if hasattr(cond_encoder, "module") else cond_encoder.state_dict()
        save_checkpoint({
            "epoch": epoch + 1,
            "global_step": global_step,
            "hypernetwork_state_dict": hn_state,
            "cond_encoder_state_dict": ce_state,
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "loss": avg_total,
            "decode_loss": avg_decode,
            "cal_loss": avg_cal,
            "phase": 2,
            "lora_config": {
                "rank": lora_config.rank,
                "alpha": lora_config.alpha,
                "dropout": lora_config.dropout,
                "targets": list(lora_config.targets),
            },
        }, ckpt_path)
        if args.hf_push:
            push_to_hf(ckpt_path, args.hf_push)
        dist.barrier()
    # ---- Final checkpoint ----
    # NOTE(review): avg_total / avg_decode / avg_cal are only bound inside the
    # epoch loop; if the loop runs zero times (resume at epoch >= max_epochs)
    # the references below raise NameError.
    final_path = os.path.join(args.output_dir, "hypernetwork_phase2_final.pt")
    hn_state = hypernetwork.module.state_dict() if hasattr(hypernetwork, "module") else hypernetwork.state_dict()
    ce_state = cond_encoder.module.state_dict() if hasattr(cond_encoder, "module") else cond_encoder.state_dict()
    save_checkpoint({
        "epoch": max_epochs,
        "global_step": global_step,
        "hypernetwork_state_dict": hn_state,
        "cond_encoder_state_dict": ce_state,
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "loss": avg_total,
        "decode_loss": avg_decode,
        "cal_loss": avg_cal,
        "phase": 2,
        "lora_config": {
            "rank": lora_config.rank,
            "alpha": lora_config.alpha,
            "dropout": lora_config.dropout,
            "targets": list(lora_config.targets),
        },
    }, final_path)
    if args.hf_push:
        push_to_hf(final_path, args.hf_push)
    log("\n" + "=" * 70)
    log(f"Phase 2 complete. Final decode: {avg_decode:.4f} cal: {avg_cal:.6f}")
    log("=" * 70)
    cleanup()
| # --------------------------------------------------------------------------- | |
| # Main | |
| # --------------------------------------------------------------------------- | |
def main():
    """CLI entry point: parse arguments, load both YAML configs, run the chosen phase."""
    cli = argparse.ArgumentParser(description="HyperNetwork DDP: SHINE-style Meta-Training")
    # (flag, keyword arguments) in the order they appear in --help.
    option_table = (
        ("--config", dict(
            type=str, required=True,
            help="Path to main YAML config (e.g., configs/default.yaml)")),
        ("--stage3_ckpt", dict(
            type=str, default="checkpoints/stage3_final.pt",
            help="Path to Stage 3 checkpoint (frozen VLM backbone for Phase 2)")),
        ("--hn_config", dict(
            type=str, default="configs/hypernetwork.yaml",
            help="Path to HyperNetwork YAML config")),
        ("--prototype_dir", dict(
            type=str, default="checkpoints/prototypes",
            help="Directory with prototype LoRA .pt files (Phase 1)")),
        ("--resume", dict(
            type=str, default=None,
            help="Path to checkpoint to resume from")),
        ("--output_dir", dict(
            type=str, default="checkpoints",
            help="Output directory for checkpoints")),
        ("--hf_push", dict(
            type=str, default=None,
            help="HuggingFace repo to push checkpoints (e.g., hardiksa/arcisvlm)")),
        ("--phase", dict(
            type=int, required=True, choices=[1, 2],
            help="Training phase: 1=MSE fitting, 2=end-to-end")),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    args = cli.parse_args()

    # ---- Load configs ----
    loaded = []
    for cfg_path in (args.config, args.hn_config):
        with open(cfg_path) as fh:
            loaded.append(yaml.safe_load(fh))
    config, hn_config = loaded

    # argparse restricts --phase to 1 or 2.
    run_phase = train_phase1 if args.phase == 1 else train_phase2
    run_phase(args, config, hn_config)


if __name__ == "__main__":
    main()