""" HyperNetwork — generates LoRA adapter weights from conditioning context. Takes a 256-D conditioning vector (from ConditionEncoder) and outputs: 1. Flat LoRA parameter vector (~200K params for rank-16 on Q,V across 12 blocks) 2. Calibration uncertainty sigma (HypeLoRA-style) The HyperNetwork is a 3-layer MLP that learns to map camera/scene/query context to task-specific LoRA adapters. Training follows SHINE's two-phase approach: first fit to prototype LoRAs via MSE, then end-to-end on task loss. References: - HyperVLA: Mother spawns compact Child policies (arXiv: 2510.04898) - Doc-to-LoRA / Sakana AI (arXiv: 2602.15902) - SHINE: Scalable Hypernetwork Internalization (arXiv: 2602.28901) - HypeLoRA: Calibrated LoRA with uncertainty (arXiv: 2603.19278) """ from __future__ import annotations import torch import torch.nn as nn import torch.nn.functional as F from model.lora import LoRAConfig, compute_total_lora_params class HyperNetwork(nn.Module): """ Generates LoRA adapter weights from a conditioning vector. Architecture: 3-layer MLP with residual connection condition (256-D) → Linear → GELU → Linear → GELU → Linear → LoRA params + calibration head → uncertainty σ Args: cond_dim: Input conditioning dimension (from ConditionEncoder) hidden_dim: Hidden layer dimension lora_config: LoRA configuration (determines output size) num_decoder_blocks: Number of transformer blocks in the MoE decoder decoder_embed_dim: Embedding dimension of the decoder (for LoRA sizing) """ def __init__( self, cond_dim: int = 256, hidden_dim: int = 512, lora_config: LoRAConfig = None, num_decoder_blocks: int = 12, decoder_embed_dim: int = 1024, ): super().__init__() if lora_config is None: lora_config = LoRAConfig() self.lora_config = lora_config # Calculate total LoRA output size self.lora_param_count = compute_total_lora_params( num_blocks=num_decoder_blocks, embed_dim=decoder_embed_dim, rank=lora_config.rank, targets=lora_config.targets, ) # Main parameter generator: 3-layer MLP self.encoder = nn.Sequential( nn.Linear(cond_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(), nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(), ) # Output head: hidden → LoRA params self.param_head = nn.Linear(hidden_dim, self.lora_param_count) # HypeLoRA-style calibration head: outputs per-param uncertainty σ # We output a SCALAR sigma (not per-param) to save memory self.calibration_head = nn.Sequential( nn.Linear(hidden_dim, hidden_dim // 4), nn.GELU(), nn.Linear(hidden_dim // 4, 1), ) # Initialize output heads with small weights for stable start nn.init.normal_(self.param_head.weight, std=0.01) nn.init.zeros_(self.param_head.bias) def forward( self, condition: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """ Generate LoRA adapter parameters from conditioning vector. Args: condition: [B, cond_dim] — conditioning vector from ConditionEncoder Returns: lora_params: [B, lora_param_count] — flat LoRA parameter vector sigma: [B, 1] — calibration uncertainty (lower = more confident) """ h = self.encoder(condition) # [B, hidden_dim] lora_params = self.param_head(h) # [B, lora_param_count] sigma = F.softplus(self.calibration_head(h)) # [B, 1] — always positive return lora_params, sigma def compute_confidence(self, sigma: torch.Tensor) -> torch.Tensor: """ Convert uncertainty sigma to confidence score in [0, 1]. Uses inverse sigmoid mapping: confidence = 1 / (1 + sigma) Args: sigma: [B, 1] — uncertainty from forward() Returns: [B, 1] — confidence score (higher = more confident) """ return 1.0 / (1.0 + sigma) @property def num_generated_params(self) -> int: """Number of LoRA parameters generated per forward pass.""" return self.lora_param_count @property def num_own_params(self) -> int: """Number of trainable parameters in the HyperNetwork itself.""" return sum(p.numel() for p in self.parameters()) def summary(self) -> dict: """Return a summary of the hypernetwork configuration.""" return { "own_params": self.num_own_params, "generated_params": self.num_generated_params, "lora_config": { "rank": self.lora_config.rank, "alpha": self.lora_config.alpha, "targets": self.lora_config.targets, }, "generation_ratio": self.num_generated_params / max(self.num_own_params, 1), }