"""
Condition Encoder — fuses camera identity, scene context, and query intent
into a conditioning vector for the HyperNetwork.

The conditioning vector drives adapter generation: different cameras in
different scenes processing different queries get different LoRA adapters.

Input sources (default dimensions):
    - camera_id        → learned embedding (64-D)
    - scene_descriptor → mean of recent JEPA embeddings from MemoryManager
                         (2048-D → 128-D)
    - query_embedding  → JEPA predictor output for current query
                         (2048-D → 128-D)

Output: 256-D conditioning vector → HyperNetwork

Reference: HyperVLA (arXiv: 2510.04898), Doc-to-LoRA (arXiv: 2602.15902)
"""

from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConditionEncoder(nn.Module):
    """
    Fuses camera identity, scene context, and query intent into a single
    conditioning vector for the HyperNetwork.

    The camera ID is looked up in a learned embedding table. The scene
    descriptor and query embedding each go through Linear → LayerNorm → GELU.
    The three results are concatenated and projected to ``out_dim`` by a
    final Linear → LayerNorm → GELU.

    Args:
        n_cameras: Maximum number of camera IDs (embedding table size)
        camera_dim: Camera embedding dimension
        scene_input_dim: Dimension of scene descriptor from MemoryManager (JEPA embed_dim)
        scene_dim: Projected scene dimension
        query_input_dim: Dimension of query embedding from JEPA predictor
        query_dim: Projected query dimension
        out_dim: Output conditioning vector dimension
    """

    def __init__(
        self,
        n_cameras: int = 2048,
        camera_dim: int = 64,
        scene_input_dim: int = 2048,
        scene_dim: int = 128,
        query_input_dim: int = 2048,
        query_dim: int = 128,
        out_dim: int = 256,
    ):
        super().__init__()
        self.out_dim = out_dim

        # Camera identity embedding: one learned row per camera ID.
        self.camera_embed = nn.Embedding(n_cameras, camera_dim)

        # Scene context projection (from mean of recent JEPA embeddings).
        self.scene_proj = nn.Sequential(
            nn.Linear(scene_input_dim, scene_dim),
            nn.LayerNorm(scene_dim),
            nn.GELU(),
        )

        # Query intent projection (from JEPA predictor output).
        self.query_proj = nn.Sequential(
            nn.Linear(query_input_dim, query_dim),
            nn.LayerNorm(query_dim),
            nn.GELU(),
        )

        # Fusion: concatenate all three → project to out_dim.
        fusion_input_dim = camera_dim + scene_dim + query_dim
        self.fuse = nn.Sequential(
            nn.Linear(fusion_input_dim, out_dim),
            nn.LayerNorm(out_dim),
            nn.GELU(),
        )

    def _zero_camera(self, batch_size: int) -> torch.Tensor:
        """Return a ``[batch_size, camera_dim]`` all-zero camera embedding.

        The tensor is allocated from the embedding weight, so its dtype and
        device match what ``self.camera_embed`` would return. A default
        float32 tensor would make ``torch.cat`` up-cast the fused input and
        break ``self.fuse`` when the module runs in half/bfloat16.
        """
        return self.camera_embed.weight.new_zeros(
            batch_size, self.camera_embed.embedding_dim
        )

    def _fuse(
        self,
        cam: torch.Tensor,
        scene_descriptor: torch.Tensor,
        query_embedding: torch.Tensor,
    ) -> torch.Tensor:
        """Project scene/query, concatenate with ``cam``, and fuse to ``out_dim``."""
        scene = self.scene_proj(scene_descriptor)  # [B, scene_dim]
        query = self.query_proj(query_embedding)  # [B, query_dim]
        # [B, camera_dim + scene_dim + query_dim]
        combined = torch.cat([cam, scene, query], dim=-1)
        return self.fuse(combined)  # [B, out_dim]

    def forward(
        self,
        camera_id: Optional[torch.Tensor],
        scene_descriptor: torch.Tensor,
        query_embedding: torch.Tensor,
    ) -> torch.Tensor:
        """
        Encode conditioning context.

        Args:
            camera_id: [B] — integer camera IDs in ``[0, n_cameras)``, or
                ``None`` to use the zero-embedding fallback (same result as
                :meth:`forward_no_camera`).
            scene_descriptor: [B, scene_input_dim] — mean embedding from memory ring buffer
            query_embedding: [B, query_input_dim] — JEPA predictor output for current query

        Returns:
            [B, out_dim] — conditioning vector for HyperNetwork
        """
        if camera_id is None:
            cam = self._zero_camera(scene_descriptor.shape[0])
        else:
            cam = self.camera_embed(camera_id)  # [B, camera_dim]
        return self._fuse(cam, scene_descriptor, query_embedding)

    def forward_no_camera(
        self,
        scene_descriptor: torch.Tensor,
        query_embedding: torch.Tensor,
    ) -> torch.Tensor:
        """
        Encode conditioning without camera ID (e.g., for new/unknown cameras).

        Uses an all-zero camera embedding as fallback. The zero vector has
        the embedding table's dtype and device. No gradient flows into
        ``camera_embed`` on this path.

        Args:
            scene_descriptor: [B, scene_input_dim]
            query_embedding: [B, query_input_dim]

        Returns:
            [B, out_dim] — conditioning vector
        """
        cam = self._zero_camera(scene_descriptor.shape[0])
        return self._fuse(cam, scene_descriptor, query_embedding)