arcisvlm / model /condition_encoder.py
Hardik Sanghvi
feat: integrate Gemma 4 E2B backbone for production-quality VLM inference
7a564e3
Raw
History Blame Contribute Delete
4.42 kB
"""
Condition Encoder — fuses camera identity, scene context, and query intent
into a conditioning vector for the HyperNetwork.

The conditioning vector drives adapter generation: different cameras in
different scenes processing different queries get different LoRA adapters.

Input sources:
- camera_id → learned embedding (64-D)
- scene_descriptor → mean of recent JEPA embeddings from MemoryManager (2048-D → 128-D)
- query_embedding → JEPA predictor output for current query (2048-D → 128-D)

Output: 256-D conditioning vector → HyperNetwork

Reference: HyperVLA (arXiv: 2510.04898), Doc-to-LoRA (arXiv: 2602.15902)
"""
from __future__ import annotations
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConditionEncoder(nn.Module):
    """
    Fuse camera identity, scene context, and query intent into a single
    conditioning vector for the HyperNetwork.

    The three sources are embedded/projected independently, concatenated
    along the last dimension, and passed through a final
    Linear -> LayerNorm -> GELU fusion stack.

    Args:
        n_cameras: Maximum number of camera IDs (embedding table size)
        camera_dim: Camera embedding dimension
        scene_input_dim: Dimension of scene descriptor from MemoryManager (JEPA embed_dim)
        scene_dim: Projected scene dimension
        query_input_dim: Dimension of query embedding from JEPA predictor
        query_dim: Projected query dimension
        out_dim: Output conditioning vector dimension
    """

    def __init__(
        self,
        n_cameras: int = 2048,
        camera_dim: int = 64,
        scene_input_dim: int = 2048,
        scene_dim: int = 128,
        query_input_dim: int = 2048,
        query_dim: int = 128,
        out_dim: int = 256,
    ):
        super().__init__()
        self.out_dim = out_dim

        # NOTE: submodule names and creation order are kept stable so that
        # existing checkpoints (state_dict keys) and seeded initialisation
        # remain compatible.

        # Camera identity embedding
        self.camera_embed = nn.Embedding(n_cameras, camera_dim)

        # Scene context projection (from mean of recent JEPA embeddings)
        self.scene_proj = nn.Sequential(
            nn.Linear(scene_input_dim, scene_dim),
            nn.LayerNorm(scene_dim),
            nn.GELU(),
        )

        # Query intent projection (from JEPA predictor output)
        self.query_proj = nn.Sequential(
            nn.Linear(query_input_dim, query_dim),
            nn.LayerNorm(query_dim),
            nn.GELU(),
        )

        # Fusion: concatenate all three -> project to out_dim
        fusion_input_dim = camera_dim + scene_dim + query_dim
        self.fuse = nn.Sequential(
            nn.Linear(fusion_input_dim, out_dim),
            nn.LayerNorm(out_dim),
            nn.GELU(),
        )

    def _fuse_features(
        self,
        cam: torch.Tensor,
        scene: torch.Tensor,
        query: torch.Tensor,
    ) -> torch.Tensor:
        """Concatenate the three feature tensors on the last dim and fuse them."""
        combined = torch.cat([cam, scene, query], dim=-1)  # [..., camera_dim + scene_dim + query_dim]
        return self.fuse(combined)  # [..., out_dim]

    def forward(
        self,
        camera_id: torch.Tensor,
        scene_descriptor: torch.Tensor,
        query_embedding: torch.Tensor,
    ) -> torch.Tensor:
        """
        Encode conditioning context.

        Args:
            camera_id: [B] - integer camera IDs (0 to n_cameras-1)
            scene_descriptor: [B, scene_input_dim] - mean embedding from memory ring buffer
            query_embedding: [B, query_input_dim] - JEPA predictor output for current query

        Returns:
            [B, out_dim] - conditioning vector for HyperNetwork
        """
        cam = self.camera_embed(camera_id)  # [B, camera_dim]
        scene = self.scene_proj(scene_descriptor)  # [B, scene_dim]
        query = self.query_proj(query_embedding)  # [B, query_dim]
        return self._fuse_features(cam, scene, query)

    def forward_no_camera(
        self,
        scene_descriptor: torch.Tensor,
        query_embedding: torch.Tensor,
    ) -> torch.Tensor:
        """
        Encode conditioning without camera ID (e.g., for new/unknown cameras).

        Uses an all-zero camera embedding as fallback; the scene and query
        paths are identical to ``forward``.

        Args:
            scene_descriptor: [B, scene_input_dim]
            query_embedding: [B, query_input_dim]

        Returns:
            [B, out_dim] - conditioning vector
        """
        scene = self.scene_proj(scene_descriptor)
        query = self.query_proj(query_embedding)

        # Build the zero placeholder from the projected scene features so it
        # inherits their dtype, device and leading (batch) dimensions. A plain
        # torch.zeros(...) would always be float32, which makes torch.cat
        # promote the fused input to float32 and then fail inside self.fuse
        # when the module runs in half/bfloat16 precision.
        cam = scene.new_zeros(*scene.shape[:-1], self.camera_embed.embedding_dim)
        return self._fuse_features(cam, scene, query)