arcisvlm / model /hypernetwork.py
Hardik Sanghvi
feat: integrate Gemma 4 E2B backbone for production-quality VLM inference
7a564e3
Raw
History Blame Contribute Delete
5.19 kB
"""
HyperNetwork — generates LoRA adapter weights from conditioning context.
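Takes a conditioning vector (256-D by default, from ConditionEncoder) and outputs:
1. Flat LoRA parameter vector whose length is computed by compute_total_lora_params from the decoder block count, decoder embedding dim, LoRA rank and targets (e.g. ~200K params for rank-16 on Q,V across 12 blocks)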
2. Calibration uncertainty sigma (HypeLoRA-style)
The HyperNetwork is a 3-layer MLP that learns to map camera/scene/query
context to task-specific LoRA adapters. Training follows SHINE's two-phase
approach: first fit to prototype LoRAs via MSE, then end-to-end on task loss.
References:
- HyperVLA: Mother spawns compact Child policies (arXiv: 2510.04898)
- Doc-to-LoRA / Sakana AI (arXiv: 2602.15902)
- SHINE: Scalable Hypernetwork Internalization (arXiv: 2602.28901)
- HypeLoRA: Calibrated LoRA with uncertainty (arXiv: 2603.19278)
"""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.lora import LoRAConfig, compute_total_lora_params
class HyperNetwork(nn.Module):
    """
    Generates LoRA adapter weights from a conditioning vector.

    Architecture: 3-layer MLP (no residual connection)
        condition (cond_dim) → Linear → LayerNorm → GELU
                             → Linear → LayerNorm → GELU     (shared trunk)
                             → Linear                         → flat LoRA params
        + calibration head: trunk → Linear → GELU → Linear → softplus
                                                              → scalar σ per sample

    Args:
        cond_dim: Input conditioning dimension (from ConditionEncoder).
        hidden_dim: Hidden layer dimension of the shared trunk. The calibration
            head uses a bottleneck of ``max(hidden_dim // 4, 1)`` units.
        lora_config: LoRA configuration (determines output size). A default
            ``LoRAConfig()`` is created when ``None``.
        num_decoder_blocks: Number of transformer blocks in the MoE decoder.
        decoder_embed_dim: Embedding dimension of the decoder (for LoRA sizing).
    """

    def __init__(
        self,
        cond_dim: int = 256,
        hidden_dim: int = 512,
        lora_config: LoRAConfig | None = None,
        num_decoder_blocks: int = 12,
        decoder_embed_dim: int = 1024,
    ):
        super().__init__()
        if lora_config is None:
            lora_config = LoRAConfig()
        self.lora_config = lora_config

        # Length of the flat LoRA vector that param_head must emit.
        self.lora_param_count = compute_total_lora_params(
            num_blocks=num_decoder_blocks,
            embed_dim=decoder_embed_dim,
            rank=lora_config.rank,
            targets=lora_config.targets,
        )

        # Shared trunk: two Linear → LayerNorm → GELU stages.
        self.encoder = nn.Sequential(
            nn.Linear(cond_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
        )

        # Output head: hidden → flat LoRA params.
        self.param_head = nn.Linear(hidden_dim, self.lora_param_count)

        # HypeLoRA-style calibration head. It emits a single SCALAR sigma per
        # sample (not one per generated parameter) to save memory.
        # The bottleneck width is clamped to >= 1. With hidden_dim < 4,
        # `hidden_dim // 4` would be 0, which yields a zero-width layer and
        # makes sigma a constant (softplus of the last bias) independent of
        # the input.
        calib_dim = max(hidden_dim // 4, 1)
        self.calibration_head = nn.Sequential(
            nn.Linear(hidden_dim, calib_dim),
            nn.GELU(),
            nn.Linear(calib_dim, 1),
        )

        # Only the LoRA param head gets a custom init: small weights and zero
        # bias so the generated LoRA params start small for a stable start.
        nn.init.normal_(self.param_head.weight, std=0.01)
        nn.init.zeros_(self.param_head.bias)

    def forward(
        self,
        condition: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Generate LoRA adapter parameters from a conditioning vector.

        Args:
            condition: [B, cond_dim] — conditioning vector from ConditionEncoder.

        Returns:
            lora_params: [B, lora_param_count] — flat LoRA parameter vector.
            sigma: [B, 1] — calibration uncertainty (lower = more confident).
                Non-negative because it passes through softplus.
        """
        h = self.encoder(condition)  # [B, hidden_dim]
        lora_params = self.param_head(h)  # [B, lora_param_count]
        # softplus maps the raw calibration output to a non-negative sigma.
        sigma = F.softplus(self.calibration_head(h))  # [B, 1]
        return lora_params, sigma

    def compute_confidence(self, sigma: torch.Tensor) -> torch.Tensor:
        """
        Convert uncertainty sigma to a confidence score.

        Uses the rational mapping ``confidence = 1 / (1 + sigma)``:
        sigma = 0 gives confidence 1, and confidence decreases monotonically
        toward 0 as sigma grows. For the non-negative sigma produced by
        forward(), the result lies in (0, 1].

        This method does not read any instance state.

        Args:
            sigma: [B, 1] — uncertainty from forward().

        Returns:
            [B, 1] — confidence score (higher = more confident).
        """
        return 1.0 / (1.0 + sigma)

    @property
    def num_generated_params(self) -> int:
        """Number of LoRA parameters generated per sample by forward()."""
        return self.lora_param_count

    @property
    def num_own_params(self) -> int:
        """
        Total number of parameters in the HyperNetwork itself.

        Counts every registered parameter, including any with
        ``requires_grad=False``.
        """
        return sum(p.numel() for p in self.parameters())

    def summary(self) -> dict:
        """
        Return a summary of the hypernetwork configuration.

        Keys:
            own_params: total parameter count of this module.
            generated_params: length of the generated flat LoRA vector.
            lora_config: rank / alpha / targets taken from ``self.lora_config``.
            generation_ratio: generated_params / own_params. The denominator is
                floored at 1 to avoid division by zero.
        """
        return {
            "own_params": self.num_own_params,
            "generated_params": self.num_generated_params,
            "lora_config": {
                "rank": self.lora_config.rank,
                "alpha": self.lora_config.alpha,
                "targets": self.lora_config.targets,
            },
            "generation_ratio": self.num_generated_params / max(self.num_own_params, 1),
        }