# fVLM-135M / model_code / foveated_vlm.py
# Uploaded by sanps: "Upload fVLM-135M: Foveated Vision-Language Model (Stage 3 DPO)"
# Revision 6d320d6 (verified)
"""
Foveated Vision-Language Model (release implementation).
Architecture: DINOv2 encoder + foveated cross-attention + SmolLM2 LLM.
Each video frame is compressed to ONE visual token via query-guided attention.
The LLM controls WHERE to look by generating the query for the next frame.
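Forward modes (the first three are selectable via ``forward(mode=...)``):
1. forward_coarse_fine -- Training (two parallel passes)
2. forward_coarse_only -- Fast eval (single static-query pass)
3. forward_autoregressive -- True inference (sequential, KV-cached)
4. forward_dpo -- Preference training (coarse+fine on chosen/rejected pairs)
Loss: text cross-entropy only (no reconstruction, no VAE); forward_dpo returns
per-sample answer log-probs for an external DPO objective.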
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoConfig
from typing import Dict, Optional
class FoveatedVLM(nn.Module):
    """
    Foveated Vision-Language Model.

    Parameters
    ----------
    llm_name : str
        HuggingFace model id for SmolLM2 (e.g. "HuggingFaceTB/SmolLM2-135M-Instruct").
    dino_name : str
        HuggingFace model id for DINOv2 (e.g. "facebook/dinov2-small").
    query_dim : int
        Dimension of the foveated query vectors (matches DINO dim by default).
    visual_scale : float
        Multiplicative factor applied to projected visual tokens so their
        magnitude matches the LLM embedding std (~0.14 for SmolLM2).
    lambda_coarse : float
        Weight for the optional auxiliary coarse-pass CE loss during training.
        Set to 0 to disable.
    deep_query : bool
        True: the query is propagated through every DINO layer using the
        cached per-layer K/V (``encoder.query_attend``).
        False: a single cross-attention over the final patch features
        (``encoder.shallow_query_attend``).
    """

    # Maximum frames per DINO encode/query call to prevent OOM on large batches.
    _MAX_ENCODE_CHUNK = 200
    # Label value skipped by the CE loss (PyTorch's default ignore_index).
    # Using a dedicated value (instead of pad_token_id) keeps real tokens that
    # happen to equal the pad id -- e.g. EOS when pad == eos -- in the loss.
    _IGNORE_INDEX = -100

    def __init__(
        self,
        llm_name: str = "HuggingFaceTB/SmolLM2-135M-Instruct",
        dino_name: str = "facebook/dinov2-small",
        query_dim: int = 384,
        visual_scale: float = 0.14,
        lambda_coarse: float = 0.0,
        deep_query: bool = True,
    ):
        super().__init__()
        # ---- delayed import so encoder.py can live next to this file ----
        from release.model.encoder import FoveatedEncoder

        # ---- Vision encoder (DINOv2 + query cross-attention) ----
        self.encoder = FoveatedEncoder(
            dino_model_name=dino_name,
            query_dim=query_dim,
            output_dim=None,  # output_dim = dino_dim by default inside encoder
        )
        dino_dim = self.encoder.dino_dim

        # ---- Language model ----
        self.llm = AutoModelForCausalLM.from_pretrained(
            llm_name, attn_implementation="sdpa", torch_dtype=torch.float32,
        )
        self.llm.config.use_cache = False  # training default; overridden per-method
        llm_dim = self.llm.config.hidden_size

        # ---- Projections ----
        self.dino_to_llm = nn.Linear(dino_dim, llm_dim)
        self.llm_to_query = nn.Linear(llm_dim, query_dim)

        # ---- Learnable queries ----
        # BUG-001 FIX: init with std=1.0 so queries dominate over projection
        # bias and produce meaningful (non-uniform) attention patterns.
        self.q_static = nn.Parameter(torch.randn(1, query_dim))  # std=1.0
        self.q_init = nn.Parameter(torch.randn(1, query_dim))  # std=1.0

        # ---- Hyperparams stored as plain Python (not buffers) ----
        self.visual_scale = visual_scale
        self.lambda_coarse = lambda_coarse
        self.query_dim = query_dim
        self.deep_query = deep_query

        # ---- Dimension bookkeeping (useful for external code) ----
        self.dino_dim = dino_dim
        self.llm_dim = llm_dim

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------
    def _get_pad_token_id(self) -> int:
        """
        Return pad_token_id from the LLM config (never hardcoded).

        Falls back to ``eos_token_id`` when no pad token is configured and to
        0 when neither is set. If ``eos_token_id`` is a list, its first entry
        is used, so the result is always a single int.
        """
        config = self.llm.config
        pid = getattr(config, "pad_token_id", None)
        if pid is None:
            pid = getattr(config, "eos_token_id", None)
        if isinstance(pid, (list, tuple)):
            pid = pid[0] if pid else None
        return 0 if pid is None else int(pid)

    def _llm_dtype(self) -> torch.dtype:
        """Return the dtype of the LLM parameters (e.g. bfloat16)."""
        return next(self.llm.parameters()).dtype

    def _embed_text(self, input_ids: torch.Tensor) -> torch.Tensor:
        """[B, S] -> [B, S, llm_dim] via LLM embedding table."""
        return self.llm.get_input_embeddings()(input_ids)

    def _project_visual(self, z: torch.Tensor) -> torch.Tensor:
        """
        Project DINO features to LLM space and rescale.

        z : [B, T, dino_dim] or [B, dino_dim]
        Returns same shape with last dim = llm_dim.
        """
        h = self.dino_to_llm(z)  # -> llm_dim
        return h * self.visual_scale  # match LLM embedding magnitude

    def _encode_all_frames(self, frames: torch.Tensor, frame_mask=None):
        """
        Run DINO patch encoding for every real frame in the batch.

        frames     : [B, T, 3, 224, 224]
        frame_mask : [B, T] -- True / nonzero for real frames, False / 0 for padding.

        Returns (kv_cache, patch_features, mask_flat):
            kv_cache       : list of (K, V) per layer, each [n_real, N+1, D]
                             (compact -- only real frames, no padding waste).
            patch_features : [n_real, N+1, D] final DINO embeddings (for shallow mode).
            mask_flat      : [B*T] bool tensor or None. Used to scatter results back.
        """
        B, T, C, H, W = frames.shape
        BT = B * T
        frames_flat = frames.reshape(BT, C, H, W)

        mask_flat = None
        real_frames = frames_flat
        if frame_mask is not None:
            # Coerce to bool: an int 0/1 mask would otherwise be interpreted
            # as gather indices by tensor indexing.
            mask_flat = frame_mask.reshape(BT).to(device=frames.device, dtype=torch.bool)
            if not bool(mask_flat.all()):
                real_frames = frames_flat[mask_flat]  # [n_real, C, H, W]

        chunk = self._MAX_ENCODE_CHUNK
        if real_frames.shape[0] <= chunk:
            patch_features, kv_cache = self.encoder.encode_patches(real_frames)
            return kv_cache, patch_features, mask_flat

        # Chunked encoding to prevent OOM on batches with many real frames.
        pf_chunks, kv_chunks = [], []
        for start in range(0, real_frames.shape[0], chunk):
            pf_chunk, kv_chunk = self.encoder.encode_patches(
                real_frames[start:start + chunk]
            )
            pf_chunks.append(pf_chunk)
            kv_chunks.append(kv_chunk)
        patch_features = torch.cat(pf_chunks, dim=0)
        # zip(*kv_chunks) groups the per-chunk (K, V) pairs of each layer.
        kv_cache = [
            (torch.cat([K for K, _ in layer], dim=0),
             torch.cat([V for _, V in layer], dim=0))
            for layer in zip(*kv_chunks)
        ]
        return kv_cache, patch_features, mask_flat

    def _batched_query_attend(self, queries: torch.Tensor, kv_cache: list,
                              patch_features: torch.Tensor = None) -> torch.Tensor:
        """
        Chunked query_attend (deep) or shallow_query_attend to prevent OOM.

        queries        : [n, query_dim]
        kv_cache       : per-layer (K, V), each with n rows (deep mode)
        patch_features : [n, N+1, D] (required in shallow mode)
        Returns the encoder output for every query, concatenated along dim 0.
        """
        n = queries.shape[0]
        chunk = self._MAX_ENCODE_CHUNK

        if not self.deep_query:
            # Shallow mode: single cross-attention on final features
            if patch_features is None:
                raise ValueError("patch_features is required when deep_query=False")
            if n <= chunk:
                return self.encoder.shallow_query_attend(queries, patch_features)
            return torch.cat([
                self.encoder.shallow_query_attend(
                    queries[s:s + chunk], patch_features[s:s + chunk])
                for s in range(0, n, chunk)
            ], dim=0)

        # Deep mode: propagate through all DINO layers
        if n <= chunk:
            return self.encoder.query_attend(queries, kv_cache)
        return torch.cat([
            self.encoder.query_attend(
                queries[s:s + chunk],
                [(K[s:s + chunk], V[s:s + chunk]) for K, V in kv_cache])
            for s in range(0, n, chunk)
        ], dim=0)

    def _query_all_frames(
        self, query: torch.Tensor, kv_cache: list,
        B: int, T: int, mask_flat=None, patch_features=None,
    ) -> torch.Tensor:
        """
        Apply a single query to every frame in ONE batched query_attend call.

        query          : [B, query_dim]
        kv_cache       : list of (K, V) per layer, each [n_real, N+1, D]
        B, T           : batch and temporal dimensions
        mask_flat      : [B*T] bool or None
        patch_features : [n_real, N+1, D] (needed for shallow mode)
        Returns        : [B, T, dino_dim] (zeros at padded frames)
        """
        # Same query for all T frames -> [B, T, qd]
        queries = query.unsqueeze(1).expand(B, T, -1)
        return self._query_all_frames_batched(
            queries, kv_cache, B, T, mask_flat, patch_features,
        )

    def _query_all_frames_batched(
        self, queries: torch.Tensor, kv_cache: list,
        B: int, T: int, mask_flat=None, patch_features=None,
    ) -> torch.Tensor:
        """
        Apply per-frame queries in ONE batched query_attend call.

        queries        : [B, T, query_dim]
        kv_cache       : list of (K, V) per layer, each [n_real, N+1, D]
        B, T           : batch and temporal dimensions
        mask_flat      : [B*T] bool or None
        patch_features : [n_real, N+1, D] (needed for shallow mode)
        Returns        : [B, T, dino_dim] (zeros at padded frames)
        """
        BT = B * T
        dd = self.encoder.dino_dim
        queries_flat = queries.reshape(BT, -1)

        if mask_flat is None:
            z_flat = self._batched_query_attend(queries_flat, kv_cache, patch_features)
            return z_flat.reshape(B, T, dd)

        if mask_flat.sum().item() == 0:
            return torch.zeros(B, T, dd, device=queries.device, dtype=queries.dtype)
        query_real = queries_flat[mask_flat]  # [n_real, qd]
        z_real = self._batched_query_attend(query_real, kv_cache, patch_features)
        z_flat = torch.zeros(BT, dd, device=queries.device, dtype=z_real.dtype)
        z_flat[mask_flat] = z_real
        return z_flat.reshape(B, T, dd)

    @staticmethod
    def _gather_frame_rows(x: torch.Tensor, mask_flat, B: int, T: int,
                           frame_idx: int) -> torch.Tensor:
        """
        Select frame ``frame_idx`` of every sample from a per-frame tensor.

        x         : [n_real, ...] compact rows (when mask_flat is given) or
                    [B*T, ...] rows in (b, t) order (when mask_flat is None).
        mask_flat : [B*T] bool or None.
        Returns   : [B, ...]; rows of padded frames are zero-filled.

        Only B rows are touched, so this costs O(B) per call instead of
        scattering the whole compact tensor back to [B*T, ...].
        """
        flat_idx = torch.arange(B, device=x.device) * T + frame_idx  # [B]
        if mask_flat is None:
            return x[flat_idx]
        mask_flat = mask_flat.to(device=x.device, dtype=torch.bool)
        valid = mask_flat[flat_idx]  # [B]
        # Position of every flat frame inside the compact (real-only) tensor.
        compact_idx = torch.cumsum(mask_flat.long(), dim=0) - 1  # [B*T]
        out = x.new_zeros((B, *x.shape[1:]))
        out[valid] = x[compact_idx[flat_idx][valid]]
        return out

    def _extract_frame_kv(self, kv_cache: list, mask_flat, B: int, T: int, frame_idx: int):
        """
        Extract single-frame KV cache from flat format (for autoregressive/eval).

        Returns list of (K, V) per layer, each [B, N+1, D]; padded frames get
        all-zero K/V.
        """
        return [
            (self._gather_frame_rows(K, mask_flat, B, T, frame_idx),
             self._gather_frame_rows(V, mask_flat, B, T, frame_idx))
            for K, V in kv_cache
        ]

    def _attend_single_frame(self, query: torch.Tensor, kv_cache: list,
                             patch_features, mask_flat, B: int, T: int,
                             frame_idx: int) -> torch.Tensor:
        """
        Foveated extraction for frame ``frame_idx`` of every sample -> [B, dino_dim].

        Honors ``deep_query`` (same encoder path as training) and zeroes the
        output for padded frames, matching ``_query_all_frames_batched``.
        """
        if self.deep_query:
            frame_kv = self._extract_frame_kv(kv_cache, mask_flat, B, T, frame_idx)
            z = self._batched_query_attend(query, frame_kv)
        else:
            if patch_features is None:
                raise ValueError("patch_features is required when deep_query=False")
            pf_t = self._gather_frame_rows(patch_features, mask_flat, B, T, frame_idx)
            z = self._batched_query_attend(query, None, pf_t)
        if mask_flat is not None:
            valid = mask_flat.reshape(B, T)[:, frame_idx].to(device=z.device, dtype=torch.bool)
            valid = valid.view(B, *([1] * (z.dim() - 1)))
            z = torch.where(valid, z, torch.zeros_like(z))
        return z

    def _shifted_queries(self, h_visual: torch.Tensor) -> torch.Tensor:
        """
        Turn visual hidden states [B, T, ld] into per-frame fine queries [B, T, qd].

        Frame t gets the query produced at frame t-1; frame 0 gets q_init.
        """
        queries = self.llm_to_query(h_visual)  # [B, T, qd]
        q_init = self.q_init.expand(h_visual.shape[0], 1, -1)  # [B, 1, qd]
        return torch.cat([q_init, queries[:, :-1]], dim=1)

    def _prefixed_labels(self, input_ids: torch.Tensor,
                         attention_mask: Optional[torch.Tensor],
                         loss_mask: Optional[torch.Tensor], n_prefix: int):
        """
        Build labels / loss mask for a sequence of ``n_prefix`` non-text
        positions followed by ``input_ids``.

        Prefix positions get pad labels and 0 loss. The text mask is
        ``loss_mask`` if given, else ``attention_mask``, else None (in which
        case ``_ce_loss`` ignores pad-valued labels).
        Returns (full_labels [B, n_prefix+S], full_loss_mask or None).
        """
        B = input_ids.shape[0]
        prefix_labels = torch.full(
            (B, n_prefix), self._get_pad_token_id(),
            dtype=input_ids.dtype, device=input_ids.device,
        )
        full_labels = torch.cat([prefix_labels, input_ids], dim=1)
        text_mask = loss_mask if loss_mask is not None else attention_mask
        if text_mask is None:
            return full_labels, None
        prefix_mask = torch.zeros(
            B, n_prefix, dtype=text_mask.dtype, device=text_mask.device,
        )
        return full_labels, torch.cat([prefix_mask, text_mask], dim=1)

    def _build_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """
        Standard causal attention mask [1, 1, S, S] for the LLM.
        True = masked (cannot attend), False = allowed.
        """
        mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).triu(1)
        return mask.unsqueeze(0).unsqueeze(0)  # [1, 1, S, S]

    def _ce_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        loss_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Standard autoregressive CE loss with shift-by-1.

        logits    : [B, S, V]  (full sequence logits)
        labels    : [B, S]     (token ids; positions without loss use pad)
        loss_mask : [B, S]     (1 = compute loss, 0 = ignore). loss_mask[i]
                    guards label[i]. When given, it alone decides which
                    labels count, so a real token equal to the pad id (e.g.
                    EOS when pad == eos) is still trained on. When None,
                    labels equal to the pad id are ignored.

        Returns the mean CE over counted tokens (0.0 if no token is counted,
        instead of NaN).
        """
        # Shift: predict position i+1 from position i
        shift_logits = logits[:, :-1, :]  # [B, S-1, V]
        shift_labels = labels[:, 1:].clone()  # [B, S-1]
        if loss_mask is not None:
            shift_labels[loss_mask[:, 1:] == 0] = self._IGNORE_INDEX
        else:
            shift_labels[shift_labels == self._get_pad_token_id()] = self._IGNORE_INDEX

        V = shift_logits.shape[-1]
        loss_sum = F.cross_entropy(
            shift_logits.reshape(-1, V),
            shift_labels.reshape(-1),
            ignore_index=self._IGNORE_INDEX,
            reduction="sum",
        )
        # sum / count == "mean" reduction, but yields 0 (not NaN) when every
        # position is ignored, and needs no host sync.
        n_valid = (shift_labels != self._IGNORE_INDEX).sum()
        return loss_sum / n_valid.clamp(min=1)

    # ------------------------------------------------------------------
    # Forward mode 1: Coarse+Fine (TRAINING)
    # ------------------------------------------------------------------
    def forward_coarse_fine(
        self,
        frames: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: Optional[torch.Tensor] = None,
        frame_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Two-pass parallel training forward.

        Pass 1 (coarse): q_static -> all frames -> z_coarse -> LLM -> dynamic queries
        Pass 2 (fine):   shifted queries -> all frames -> z_fine -> LLM + text -> loss

        Parameters
        ----------
        frames         : [B, T, 3, 224, 224]
        input_ids      : [B, S] tokenized text (prompt + answer)
        attention_mask : [B, S] text attention mask
        loss_mask      : [B, S] which tokens contribute to loss (1=yes, 0=no).
                         If None, all non-pad tokens (attention_mask) have loss.
        frame_mask     : [B, T] bool, True for real frames (optional)

        Returns
        -------
        dict with keys: loss, logits, coarse_loss, fine_loss
        """
        B, T = frames.shape[:2]

        # ---- Step 0: Encode all frames (DINO, shared across both passes) ----
        kv_cache, patch_features, mask_flat = self._encode_all_frames(frames, frame_mask)

        # ---- Pass 1: Coarse ----
        q_static = self.q_static.expand(B, -1)  # [B, qd]
        z_coarse = self._query_all_frames(
            q_static, kv_cache, B, T, mask_flat, patch_features)  # [B,T,dd]
        z_coarse_llm = self._project_visual(z_coarse)  # [B,T,ld]
        text_embeds = self._embed_text(input_ids)  # [B,S,ld]

        # The LLM is causal, so hidden states at the T visual positions do not
        # depend on the text that follows them. Text only needs to go through
        # the coarse pass when the auxiliary coarse loss is enabled.
        if self.lambda_coarse > 0:
            seq_coarse = torch.cat([z_coarse_llm, text_embeds], dim=1)  # [B,T+S,ld]
        else:
            seq_coarse = z_coarse_llm  # [B,T,ld]
        # dtype handled by autocast on GPU; float32 on CPU
        h_coarse = self.llm.model(inputs_embeds=seq_coarse).last_hidden_state

        # Each visual position generates the query for the NEXT frame.
        shifted_queries = self._shifted_queries(h_coarse[:, :T, :])  # [B,T,qd]

        # ---- Pass 2: Fine ----
        z_fine = self._query_all_frames_batched(
            shifted_queries, kv_cache, B, T, mask_flat, patch_features)  # [B,T,dd]
        z_fine_llm = self._project_visual(z_fine)  # [B,T,ld]
        seq_fine = torch.cat([z_fine_llm, text_embeds], dim=1)  # [B,T+S,ld]
        h_fine = self.llm.model(inputs_embeds=seq_fine).last_hidden_state
        logits_full = self.llm.lm_head(h_fine)  # [B,T+S,V]

        # ---- Loss on text portion only (visual positions: pad label, no loss) ----
        full_labels, full_loss_mask = self._prefixed_labels(
            input_ids, attention_mask, loss_mask, T)
        fine_loss = self._ce_loss(logits_full, full_labels, full_loss_mask)

        # ---- Optional auxiliary coarse loss ----
        coarse_loss = torch.tensor(0.0, device=frames.device)
        if self.lambda_coarse > 0:
            logits_coarse = self.llm.lm_head(h_coarse)
            coarse_loss = self._ce_loss(logits_coarse, full_labels, full_loss_mask)

        loss = fine_loss + self.lambda_coarse * coarse_loss
        return {
            "loss": loss,
            "fine_loss": fine_loss,
            "coarse_loss": coarse_loss,
            "logits": logits_full,
        }

    # ------------------------------------------------------------------
    # Forward mode: DPO (preference training)
    # ------------------------------------------------------------------
    def forward_dpo(
        self,
        frames: torch.Tensor,
        chosen_input_ids: torch.Tensor,
        chosen_attention_mask: torch.Tensor,
        chosen_loss_mask: torch.Tensor,
        rejected_input_ids: torch.Tensor,
        rejected_attention_mask: torch.Tensor,
        rejected_loss_mask: torch.Tensor,
        frame_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        DPO forward pass: run coarse+fine on both chosen and rejected sequences.

        DINO encoding, the coarse pass and the fine visual tokens are shared
        across chosen and rejected (same visual input). Returns per-sample sum
        of log-probabilities for both, masked by loss_mask (answer-only tokens).

        Parameters
        ----------
        frames                  : [B, T, 3, 224, 224]
        chosen_input_ids        : [B, S_c]
        chosen_attention_mask   : [B, S_c] (accepted for API symmetry; unused here)
        chosen_loss_mask        : [B, S_c]  (1 = answer token, 0 = prompt/pad)
        rejected_input_ids      : [B, S_r]
        rejected_attention_mask : [B, S_r] (accepted for API symmetry; unused here)
        rejected_loss_mask      : [B, S_r]
        frame_mask              : [B, T] bool (optional)

        Returns
        -------
        dict with keys:
            chosen_logps    : [B] per-sample sum of log-probs on chosen answer tokens
            rejected_logps  : [B] per-sample sum of log-probs on rejected answer tokens
            chosen_logits   : [B, T+S_c, V] full logits for chosen
            rejected_logits : [B, T+S_r, V] full logits for rejected
        """
        B, T = frames.shape[:2]

        kv_cache, patch_features, mask_flat = self._encode_all_frames(frames, frame_mask)

        # ---- Coarse pass (shared, used for dynamic query generation) ----
        q_static = self.q_static.expand(B, -1)  # [B, qd]
        z_coarse = self._query_all_frames(q_static, kv_cache, B, T, mask_flat, patch_features)
        z_coarse_llm = self._project_visual(z_coarse)  # [B, T, ld]
        # Causal LLM: visual-position hidden states cannot see any text placed
        # after them, so running only the T visual tokens gives the same queries
        # as running [visual, chosen_text] at a fraction of the cost.
        h_coarse = self.llm.model(inputs_embeds=z_coarse_llm).last_hidden_state
        shifted_queries = self._shifted_queries(h_coarse[:, :T, :])  # [B, T, qd]

        # ---- Fine pass: shared visual features ----
        z_fine = self._query_all_frames_batched(
            shifted_queries, kv_cache, B, T, mask_flat, patch_features)
        z_fine_llm = self._project_visual(z_fine)  # [B, T, ld]

        def _logits_for(input_ids: torch.Tensor) -> torch.Tensor:
            # [visual_fine, text] -> LLM -> logits over the full sequence
            seq = torch.cat([z_fine_llm, self._embed_text(input_ids)], dim=1)
            return self.llm.lm_head(self.llm.model(inputs_embeds=seq).last_hidden_state)

        chosen_logits = _logits_for(chosen_input_ids)  # [B, T+S_c, V]
        rejected_logits = _logits_for(rejected_input_ids)  # [B, T+S_r, V]

        chosen_logps = self._sequence_logprobs(
            chosen_logits, chosen_input_ids, chosen_loss_mask, T,
        )
        rejected_logps = self._sequence_logprobs(
            rejected_logits, rejected_input_ids, rejected_loss_mask, T,
        )
        return {
            "chosen_logps": chosen_logps,  # [B]
            "rejected_logps": rejected_logps,  # [B]
            "chosen_logits": chosen_logits,  # [B, T+S_c, V]
            "rejected_logits": rejected_logits,  # [B, T+S_r, V]
        }

    def _sequence_logprobs(
        self,
        logits: torch.Tensor,
        input_ids: torch.Tensor,
        loss_mask: torch.Tensor,
        T: int,
    ) -> torch.Tensor:
        """
        Compute per-sample sum of log-probabilities on answer tokens.

        logits    : [B, T+S, V] full sequence logits (visual + text)
        input_ids : [B, S] text token ids
        loss_mask : [B, S] 1.0 for answer tokens, 0.0 otherwise
        T         : int number of visual token positions
        Returns   : [B] sum of log-probs per sample

        Note: only text positions are scored, so the first text token (which
        would be predicted from the last visual position) is never counted;
        loss_mask[:, 0] has no effect.
        """
        # Extract text logits and shift for autoregressive prediction
        text_logits = logits[:, T:, :]  # [B, S, V]
        shift_logits = text_logits[:, :-1, :]  # [B, S-1, V]
        shift_labels = input_ids[:, 1:]  # [B, S-1]
        shift_mask = loss_mask[:, 1:]  # [B, S-1]

        # Per-token log-probs: log_softmax then gather the label's log-prob
        log_probs = F.log_softmax(shift_logits, dim=-1)  # [B, S-1, V]
        per_token_logps = log_probs.gather(
            dim=-1, index=shift_labels.unsqueeze(-1),
        ).squeeze(-1)  # [B, S-1]

        # Zero out non-answer tokens, then sum per sample
        per_token_logps = per_token_logps * shift_mask.to(per_token_logps.dtype)
        return per_token_logps.sum(dim=-1)  # [B]

    # ------------------------------------------------------------------
    # Forward mode 2: Coarse only (FAST EVAL)
    # ------------------------------------------------------------------
    def forward_coarse_only(
        self,
        frames: torch.Tensor,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        loss_mask: Optional[torch.Tensor] = None,
        frame_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Single-pass coarse forward (q_static only, no fine queries).

        Used for:
            - Training A6 ablation (coarse-only training)
            - Fast eval (wrap in torch.no_grad() externally)

        q_static -> all frames -> z_coarse -> LLM -> logits.

        Parameters
        ----------
        frames         : [B, T, 3, 224, 224]
        input_ids      : [B, S] (optional, for loss computation)
        attention_mask : [B, S] (optional)
        loss_mask      : [B, S] (optional)
        frame_mask     : [B, T] bool (optional)

        Returns
        -------
        dict with keys: logits, and (when input_ids is given) loss,
        coarse_loss (== loss) and fine_loss (0.0)
        """
        B, T = frames.shape[:2]
        kv_cache, patch_features, mask_flat = self._encode_all_frames(frames, frame_mask)

        q_static = self.q_static.expand(B, -1)
        z_coarse = self._query_all_frames(q_static, kv_cache, B, T, mask_flat, patch_features)
        z_coarse_llm = self._project_visual(z_coarse)

        if input_ids is None:
            seq = z_coarse_llm
        else:
            seq = torch.cat([z_coarse_llm, self._embed_text(input_ids)], dim=1)

        # dtype handled by autocast on GPU; float32 on CPU
        out = self.llm.model(inputs_embeds=seq)
        logits = self.llm.lm_head(out.last_hidden_state)

        result: Dict[str, torch.Tensor] = {"logits": logits}
        if input_ids is not None:
            full_labels, full_loss_mask = self._prefixed_labels(
                input_ids, attention_mask, loss_mask, T)
            loss = self._ce_loss(logits, full_labels, full_loss_mask)
            result["loss"] = loss
            result["coarse_loss"] = loss
            result["fine_loss"] = torch.tensor(0.0, device=frames.device)
        return result

    # ------------------------------------------------------------------
    # Forward mode 3: Autoregressive (TRUE INFERENCE)
    # ------------------------------------------------------------------
    @torch.no_grad()
    def forward_autoregressive(
        self,
        frames: torch.Tensor,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        loss_mask: Optional[torch.Tensor] = None,
        frame_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        True autoregressive inference: sequential frame-by-frame with KV cache.

        q_init -> frame_1 -> z_1 -> LLM -> q_1 -> frame_2 -> z_2 -> ...

        No coarse pass. Each query is derived from the LLM hidden state after
        processing the *previous* fine visual token -- exactly what happens at
        real inference time. The same encoder path as training is used
        (deep or shallow per ``deep_query``), and padded frames contribute a
        zero feature, as in ``_query_all_frames_batched``.

        Parameters
        ----------
        frames         : [B, T, 3, 224, 224]  (T >= 1)
        input_ids      : [B, S] (optional, for loss computation)
        attention_mask : [B, S] (optional)
        loss_mask      : [B, S] (optional)
        frame_mask     : [B, T] bool (optional)

        Returns
        -------
        dict with keys: logits, and optionally loss.
        With text, logits cover [last_visual, text] = [B, 1+S, V];
        without text, logits are for the last visual position, [B, 1, V].

        Raises
        ------
        ValueError if the clip has no frames (T == 0).

        ``llm.config.use_cache`` is restored to its previous value even if an
        exception is raised.
        """
        B, T = frames.shape[:2]
        if T == 0:
            raise ValueError("forward_autoregressive needs at least one frame (T >= 1).")

        # Encode all frames with DINO up front (this is OK -- DINO encoding
        # does not depend on the query, only query_attend does).
        kv_cache, patch_features, mask_flat = self._encode_all_frames(frames, frame_mask)

        # Enable KV cache on the LLM for incremental decoding
        orig_use_cache = self.llm.config.use_cache
        self.llm.config.use_cache = True
        try:
            query = self.q_init.expand(B, -1)  # [B, qd]
            llm_past_kv = None
            for t in range(T):
                # Foveated extraction with current query
                z_t = self._attend_single_frame(
                    query, kv_cache, patch_features, mask_flat, B, T, t)  # [B, dd]
                z_t_llm = self._project_visual(z_t.unsqueeze(1))  # [B,1,ld]
                # Incremental LLM forward (one visual token at a time)
                out = self.llm.model(
                    inputs_embeds=z_t_llm,
                    past_key_values=llm_past_kv,
                    use_cache=True,
                )
                llm_past_kv = out.past_key_values
                # Derive query for the NEXT frame from the current hidden state
                if t < T - 1:
                    query = self.llm_to_query(out.last_hidden_state[:, -1, :])  # [B, qd]

            # Logit at the last visual position (predicts the first text token).
            logits_last_v = self.llm.lm_head(out.last_hidden_state[:, -1:, :])  # [B,1,V]
            if input_ids is None:
                return {"logits": logits_last_v}

            # ---- Process text using the accumulated visual KV cache ----
            text_embeds = self._embed_text(input_ids)  # [B, S, ld]
            out_text = self.llm.model(
                inputs_embeds=text_embeds,
                past_key_values=llm_past_kv,
                use_cache=False,
            )
            logits_text = self.llm.lm_head(out_text.last_hidden_state)  # [B, S, V]
            # Full logits over [last_visual, text] = [B, 1+S, V]; loss on text only.
            logits = torch.cat([logits_last_v, logits_text], dim=1)
            full_labels, full_loss_mask = self._prefixed_labels(
                input_ids, attention_mask, loss_mask, 1)
            loss = self._ce_loss(logits, full_labels, full_loss_mask)
            return {"loss": loss, "logits": logits}
        finally:
            self.llm.config.use_cache = orig_use_cache

    # ------------------------------------------------------------------
    # Convenience: unified forward dispatching by name
    # ------------------------------------------------------------------
    def forward(
        self,
        frames: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: Optional[torch.Tensor] = None,
        frame_mask: Optional[torch.Tensor] = None,
        mode: str = "coarse_fine",
    ) -> Dict[str, torch.Tensor]:
        """
        Unified forward entry point.

        mode       : "coarse_fine" | "coarse_only" | "autoregressive"
        frame_mask : [B, T] bool -- True for real frames, False for padding.

        Raises ValueError for an unknown mode.
        """
        if mode == "coarse_fine":
            return self.forward_coarse_fine(frames, input_ids, attention_mask, loss_mask, frame_mask)
        if mode == "coarse_only":
            return self.forward_coarse_only(frames, input_ids, attention_mask, loss_mask, frame_mask)
        if mode == "autoregressive":
            return self.forward_autoregressive(frames, input_ids, attention_mask, loss_mask, frame_mask)
        raise ValueError(
            f"Unknown forward mode '{mode}'. "
            "Expected one of: coarse_fine, coarse_only, autoregressive"
        )

    # ------------------------------------------------------------------
    # Utility methods for external callers (train.py, eval.py)
    # ------------------------------------------------------------------
    def enable_gradient_checkpointing(self) -> None:
        """Turn on activation checkpointing for LLM and DINO."""
        self.llm.gradient_checkpointing_enable()
        if hasattr(self.encoder.dino, 'gradient_checkpointing_enable'):
            self.encoder.dino.gradient_checkpointing_enable()

    def get_param_groups(
        self,
        lr_backbone: float = 1e-5,
        lr_connector: float = 1e-4,
    ) -> list:
        """
        Return parameter groups with differential learning rates.

        Groups:
            1. Connector (dino_to_llm, llm_to_query, q_static, q_init and the
               encoder's query_input_proj / query_output_proj) -- highest LR
            2. DINO encoder (everything else in self.encoder) -- backbone LR
            3. LLM -- backbone LR

        Only trainable parameters are included; empty groups are dropped, and
        no parameter is placed in more than one group.
        This is a suggestion; train.py may override.
        """
        connector_keys = (
            "dino_to_llm", "llm_to_query", "q_static", "q_init",
            "query_input_proj", "query_output_proj",
        )
        connector_ids = {
            id(param) for name, param in self.named_parameters()
            if any(key in name for key in connector_keys)
        }

        groups = [
            {
                "params": [p for p in self.parameters()
                           if id(p) in connector_ids and p.requires_grad],
                "lr": lr_connector,
                "name": "connector",
            },
            {
                "params": [p for p in self.encoder.parameters()
                           if id(p) not in connector_ids and p.requires_grad],
                "lr": lr_backbone,
                "name": "dino",
            },
            {
                "params": [p for p in self.llm.parameters()
                           if id(p) not in connector_ids and p.requires_grad],
                "lr": lr_backbone,
                "name": "llm",
            },
        ]
        return [g for g in groups if g["params"]]