sanps committed
Commit 9a444f5 · verified · 1 Parent(s): 22ed968

Upload model.py

Files changed (1)
  1. model.py +996 -0
model.py ADDED
@@ -0,0 +1,996 @@
"""
Foveated Vision-Language Model.

Architecture: DINOv2 encoder + foveated cross-attention + SmolLM2 LLM.
Each video frame is compressed to ONE visual token via query-guided attention.
The LLM controls WHERE to look by generating the query for the next frame.

Three forward modes:
1. forward_coarse_fine    -- Training (two parallel passes)
2. forward_coarse_only    -- Fast eval (single static-query pass)
3. forward_autoregressive -- True inference (sequential, KV-cached)

Loss: text cross-entropy only (no reconstruction, no VAE).
"""
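
# Minimal usage sketch (illustrative only; assumes `frames` [B, T, 3, 224, 224]
# and `input_ids` / `attention_mask` [B, S] were prepared as documented in
# `forward` below -- tokenizer and frame preprocessing live outside this file):
#
#   model = FoveatedVLM().cuda()
#   out = model(frames, input_ids, attention_mask, mode="coarse_fine")
#   out["loss"].backward()  # training step
#   out = model(frames, input_ids, attention_mask, mode="autoregressive")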

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM
from typing import Dict, Optional

# Optional: Liger Kernel fused CE loss (never materializes [B, S, V] logits)
try:
    from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss
    _HAS_LIGER = True
except ImportError:
    _HAS_LIGER = False

class FoveatedVLM(nn.Module):
    """
    Foveated Vision-Language Model.

    Parameters
    ----------
    llm_name : str
        HuggingFace model id for SmolLM2 (e.g. "HuggingFaceTB/SmolLM2-135M-Instruct").
    dino_name : str
        HuggingFace model id for DINOv2 (e.g. "facebook/dinov2-small").
    query_dim : int
        Dimension of the foveated query vectors (matches DINO dim by default).
    visual_scale : float
        Multiplicative factor applied to projected visual tokens so their
        magnitude matches the LLM embedding std (~0.14 for SmolLM2).
    lambda_coarse : float
        Weight for the optional auxiliary coarse-pass CE loss during training.
        Set to 0 to disable.
    deep_query : bool
        If True, propagate the query through all DINO layers (deep
        query_attend); if False, use a single cross-attention over the final
        DINO features (shallow mode).
    use_fused_ce : bool
        If True and Liger Kernel is installed, use the fused lm_head + CE loss.
    """

    def __init__(
        self,
        llm_name: str = "HuggingFaceTB/SmolLM2-135M-Instruct",
        dino_name: str = "facebook/dinov2-small",
        query_dim: int = 384,
        visual_scale: float = 0.14,
        lambda_coarse: float = 0.0,
        deep_query: bool = True,
        use_fused_ce: bool = False,
    ):
        super().__init__()

        # ---- delayed import so encoder.py can live next to this file ----
        from encoder import FoveatedEncoder

        # ---- Vision encoder (DINOv2 + query cross-attention) ----
        self.encoder = FoveatedEncoder(
            dino_model_name=dino_name,
            query_dim=query_dim,
            output_dim=None,  # output_dim = dino_dim by default inside encoder
        )
        dino_dim = self.encoder.dino_dim

        # ---- Language model ----
        self.llm = AutoModelForCausalLM.from_pretrained(
            llm_name, attn_implementation="sdpa", torch_dtype=torch.bfloat16,
        )
        self.llm.config.use_cache = False  # training default; overridden per-method
        llm_dim = self.llm.config.hidden_size

        # ---- Projections ----
        self.dino_to_llm = nn.Linear(dino_dim, llm_dim)
        self.llm_to_query = nn.Linear(llm_dim, query_dim)

        # ---- Learnable queries ----
        # BUG-001 FIX: init with std=1.0 so queries dominate over projection
        # bias and produce meaningful (non-uniform) attention patterns.
        self.q_static = nn.Parameter(torch.randn(1, query_dim))  # std=1.0
        self.q_init = nn.Parameter(torch.randn(1, query_dim))    # std=1.0

        # ---- Hyperparams stored as plain Python (not buffers) ----
        self.visual_scale = visual_scale
        self.lambda_coarse = lambda_coarse
        self.query_dim = query_dim
        self.deep_query = deep_query
        self.use_fused_ce = use_fused_ce and _HAS_LIGER

        # ---- Dimension bookkeeping (useful for external code) ----
        self.dino_dim = dino_dim
        self.llm_dim = llm_dim

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------

    def _get_pad_token_id(self) -> int:
        """Return pad_token_id from the LLM config (never hardcoded)."""
        pid = getattr(self.llm.config, "pad_token_id", None)
        if pid is None:
            pid = getattr(self.llm.config, "eos_token_id", 0)
        return pid

    def _llm_dtype(self) -> torch.dtype:
        """Return the dtype of the LLM parameters (e.g. bfloat16)."""
        return next(self.llm.parameters()).dtype

    def _embed_text(self, input_ids: torch.Tensor) -> torch.Tensor:
        """[B, S] -> [B, S, llm_dim] via LLM embedding table."""
        return self.llm.get_input_embeddings()(input_ids)

    def _project_visual(self, z: torch.Tensor) -> torch.Tensor:
        """
        Project DINO features to LLM space and rescale.

        z : [B, T, dino_dim] or [B, dino_dim]
        Returns same shape with last dim = llm_dim.
        """
        h = self.dino_to_llm(z)    # -> llm_dim
        h = h * self.visual_scale  # match LLM embedding magnitude
        return h

    # Maximum frames per DINO encode/query call to prevent OOM on large batches.
    _MAX_ENCODE_CHUNK = 200

    def _encode_all_frames(self, frames: torch.Tensor, frame_mask=None):
        """
        Run DINO patch encoding for every frame in the batch.

        frames : [B, T, 3, 224, 224]
        frame_mask : [B, T] bool -- True for real frames, False for padding.

        Returns (kv_cache, patch_features, mask_flat):
            kv_cache : list of (K, V) per layer, each [n_real, N+1, D]
                (compact -- only real frames, no padding waste).
            patch_features : [n_real, N+1, D] final DINO embeddings (for shallow mode).
            mask_flat : [B*T] bool tensor or None. Used to scatter results back.
        """
        B, T, C, H, W = frames.shape
        BT = B * T
        frames_flat = frames.reshape(BT, C, H, W)

        if frame_mask is not None:
            mask_flat = frame_mask.reshape(BT)
            n_real = mask_flat.sum().item()
        else:
            mask_flat = None
            n_real = BT

        if mask_flat is not None and n_real < BT:
            real_frames = frames_flat[mask_flat]  # [n_real, C, H, W]
        else:
            real_frames = frames_flat

        # Chunked encoding to prevent OOM on batches with many real frames
        if real_frames.shape[0] <= self._MAX_ENCODE_CHUNK:
            patch_features, kv_cache = self.encoder.encode_patches(real_frames)
        else:
            pf_chunks, kv_chunks = [], []
            for start in range(0, real_frames.shape[0], self._MAX_ENCODE_CHUNK):
                pf_chunk, kv_chunk = self.encoder.encode_patches(
                    real_frames[start:start + self._MAX_ENCODE_CHUNK]
                )
                pf_chunks.append(pf_chunk)
                kv_chunks.append(kv_chunk)
            patch_features = torch.cat(pf_chunks, dim=0)
            kv_cache = [
                (torch.cat([c[li][0] for c in kv_chunks], dim=0),
                 torch.cat([c[li][1] for c in kv_chunks], dim=0))
                for li in range(len(kv_chunks[0]))
            ]

        return kv_cache, patch_features, mask_flat

    def _batched_query_attend(self, queries: torch.Tensor, kv_cache: list,
                              patch_features: torch.Tensor = None) -> torch.Tensor:
        """Chunked query_attend (deep) or shallow_query_attend to prevent OOM."""
        n = queries.shape[0]
        if not self.deep_query:
            # Shallow mode: single cross-attention on final features
            if n <= self._MAX_ENCODE_CHUNK:
                return self.encoder.shallow_query_attend(queries, patch_features)
            chunks = []
            for start in range(0, n, self._MAX_ENCODE_CHUNK):
                end = min(start + self._MAX_ENCODE_CHUNK, n)
                chunks.append(self.encoder.shallow_query_attend(
                    queries[start:end], patch_features[start:end]))
            return torch.cat(chunks, dim=0)
        # Deep mode: propagate through all DINO layers
        if n <= self._MAX_ENCODE_CHUNK:
            return self.encoder.query_attend(queries, kv_cache)
        chunks = []
        for start in range(0, n, self._MAX_ENCODE_CHUNK):
            end = min(start + self._MAX_ENCODE_CHUNK, n)
            kv_slice = [(K[start:end], V[start:end]) for K, V in kv_cache]
            chunks.append(self.encoder.query_attend(queries[start:end], kv_slice))
        return torch.cat(chunks, dim=0)

    def _query_all_frames(
        self, query: torch.Tensor, kv_cache: list,
        B: int, T: int, mask_flat=None, patch_features=None,
    ) -> torch.Tensor:
        """
        Apply a single query to every frame in ONE batched query_attend call.

        query : [B, query_dim]
        kv_cache : list of (K, V) per layer, each [n_real, N+1, D]
        B, T : batch and temporal dimensions
        mask_flat : [B*T] bool or None
        patch_features : [n_real, N+1, D] (needed for shallow mode)
        Returns : [B, T, dino_dim]
        """
        BT = B * T
        dd = self.encoder.dino_dim

        # Expand: same query for all T frames -> [B*T, qd]
        query_exp = query.unsqueeze(1).expand(B, T, -1).reshape(BT, -1)

        if mask_flat is not None:
            n_real = mask_flat.sum().item()
            if n_real == 0:
                return torch.zeros(B, T, dd, device=query.device, dtype=query.dtype)
            query_real = query_exp[mask_flat]  # [n_real, qd]
            z_real = self._batched_query_attend(query_real, kv_cache, patch_features)
            z_flat = torch.zeros(BT, dd, device=query.device, dtype=z_real.dtype)
            z_flat[mask_flat] = z_real
        else:
            z_flat = self._batched_query_attend(query_exp, kv_cache, patch_features)

        return z_flat.reshape(B, T, dd)

    def _query_all_frames_batched(
        self, queries: torch.Tensor, kv_cache: list,
        B: int, T: int, mask_flat=None, patch_features=None,
    ) -> torch.Tensor:
        """
        Apply per-frame queries in ONE batched query_attend call.

        queries : [B, T, query_dim]
        kv_cache : list of (K, V) per layer, each [n_real, N+1, D]
        B, T : batch and temporal dimensions
        mask_flat : [B*T] bool or None
        patch_features : [n_real, N+1, D] (needed for shallow mode)
        Returns : [B, T, dino_dim]
        """
        BT = B * T
        dd = self.encoder.dino_dim
        queries_flat = queries.reshape(BT, -1)

        if mask_flat is not None:
            n_real = mask_flat.sum().item()
            if n_real == 0:
                return torch.zeros(B, T, dd, device=queries.device, dtype=queries.dtype)
            query_real = queries_flat[mask_flat]  # [n_real, qd]
            z_real = self._batched_query_attend(query_real, kv_cache, patch_features)
            z_flat = torch.zeros(BT, dd, device=queries.device, dtype=z_real.dtype)
            z_flat[mask_flat] = z_real
        else:
            z_flat = self._batched_query_attend(queries_flat, kv_cache, patch_features)

        return z_flat.reshape(B, T, dd)

    def _extract_frame_kv(self, kv_cache: list, mask_flat, B: int, T: int, frame_idx: int):
        """
        Extract single-frame KV cache from flat format (for autoregressive/eval).

        Returns list of (K, V) per layer, each [B, N+1, D].
        """
        N1 = kv_cache[0][0].shape[1]
        D = kv_cache[0][0].shape[2]
        frame_kv = []
        if mask_flat is not None:
            # Scatter compact caches to full [B*T] then extract frame
            for K_real, V_real in kv_cache:
                K_full = torch.zeros(B * T, N1, D, dtype=K_real.dtype, device=K_real.device)
                V_full = torch.zeros(B * T, N1, D, dtype=V_real.dtype, device=V_real.device)
                K_full[mask_flat] = K_real
                V_full[mask_flat] = V_real
                K_t = K_full.reshape(B, T, N1, D)[:, frame_idx]  # [B, N+1, D]
                V_t = V_full.reshape(B, T, N1, D)[:, frame_idx]
                frame_kv.append((K_t, V_t))
        else:
            for K_all, V_all in kv_cache:
                K_t = K_all.reshape(B, T, N1, D)[:, frame_idx]
                V_t = V_all.reshape(B, T, N1, D)[:, frame_idx]
                frame_kv.append((K_t, V_t))
        return frame_kv

    def _build_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """
        Standard causal attention mask [1, 1, S, S] for the LLM.
        True = masked (cannot attend), False = allowed.
        """
        mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).triu(1)
        return mask.unsqueeze(0).unsqueeze(0)  # [1, 1, S, S]

    def _ce_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        loss_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Standard autoregressive CE loss with shift-by-1.

        logits : [B, S, V] (full sequence logits)
        labels : [B, S] (token ids; positions without loss use pad)
        loss_mask : [B, S] (1 = compute loss, 0 = ignore). Shifted together
            with the labels, so loss_mask[i] guards label[i].

        Returns scalar loss.
        """
        # Shift: predict position i+1 from position i
        shift_logits = logits[:, :-1, :].contiguous()  # [B, S-1, V]
        shift_labels = labels[:, 1:].contiguous()      # [B, S-1]

        if loss_mask is not None:
            shift_mask = loss_mask[:, 1:].contiguous()  # [B, S-1]
            # Replace masked positions with -100 (standard PyTorch ignore_index)
            shift_labels = shift_labels.clone()
            shift_labels[shift_mask == 0] = -100

        V = shift_logits.shape[-1]
        loss = F.cross_entropy(
            shift_logits.reshape(-1, V),
            shift_labels.reshape(-1),
            ignore_index=-100,
            reduction="mean",
        )
        return loss
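
    # Toy example of the mask/label alignment above (hypothetical tokens):
    # labels = [p0, a1, a2] with loss_mask = [0, 1, 1] (p0 is a prompt token)
    # gives shift_labels = [a1, a2] and shift_mask = [1, 1], so a1 and a2 are
    # predicted from positions 0 and 1. With loss_mask = [0, 0, 1] only a2
    # contributes, since shift_mask = [0, 1] turns a1 into ignore_index;
    # p0 itself is never a target (nothing precedes it).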

    def _fused_ce_loss(
        self,
        hidden_states: torch.Tensor,
        labels: torch.Tensor,
        loss_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Fused lm_head + CE loss via Liger Kernel.

        Never materializes the [B, S, V] logits tensor -- computes CE in chunks
        inside the fused kernel. Saves ~2x memory on the loss computation.

        hidden_states : [B, S, ld] (LLM hidden states, NOT yet projected by lm_head)
        labels : [B, S] (token ids)
        loss_mask : [B, S] (1 = compute loss, 0 = ignore)

        Returns scalar loss.
        """
        # Shift: predict position i+1 from position i
        h_input = hidden_states[:, :-1, :].contiguous()  # [B, S-1, ld]
        shift_labels = labels[:, 1:].contiguous()        # [B, S-1]

        if loss_mask is not None:
            shift_mask = loss_mask[:, 1:].contiguous()
            # Replace masked positions with -100 (standard PyTorch ignore_index)
            shift_labels = shift_labels.clone()
            shift_labels[shift_mask == 0] = -100

        # Flatten for Liger: [B*(S-1), ld] and [B*(S-1)].
        # Liger's module forward takes the projection weight first,
        # i.e. (lin_weight, _input, target).
        n = h_input.shape[0] * h_input.shape[1]
        return LigerFusedLinearCrossEntropyLoss(ignore_index=-100)(
            self.llm.lm_head.weight,
            h_input.reshape(n, -1),
            shift_labels.reshape(-1),
        )

    # ------------------------------------------------------------------
    # Forward mode 1: Coarse+Fine (TRAINING)
    # ------------------------------------------------------------------

    def forward_coarse_fine(
        self,
        frames: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: Optional[torch.Tensor] = None,
        frame_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Two-pass parallel training forward.

        Pass 1 (coarse): q_static -> all frames -> z_coarse -> LLM(visual only) -> queries
        Pass 2 (fine):   shifted queries -> all frames -> z_fine -> LLM + text -> loss

        Optimization: the coarse LLM pass processes ONLY visual tokens (not text).
        Because causal attention means visual positions never see text tokens,
        removing text produces mathematically identical hidden states at visual
        positions while reducing sequence length from T+S to T (~10-30x shorter).

        Parameters
        ----------
        frames : [B, T, 3, 224, 224]
        input_ids : [B, S] tokenized text (prompt + answer)
        attention_mask : [B, S] text attention mask
        loss_mask : [B, S] which tokens contribute to loss (1=yes, 0=no).
            If None, all non-pad tokens have loss.

        Returns
        -------
        dict with keys: loss, fine_loss, coarse_loss, logits
        (logits is None when the fused CE loss is active)
        """
        B, T = frames.shape[:2]

        # ---- Step 0: Encode all frames (DINO, shared across both passes) ----
        # Use prefetched DINO results if available (from CUDA stream overlap)
        prefetched = self._get_prefetched_dino()
        if prefetched is not None:
            kv_cache, patch_features, mask_flat = prefetched
        else:
            kv_cache, patch_features, mask_flat = self._encode_all_frames(frames, frame_mask)

        # ---- Pass 1: Coarse (visual tokens ONLY -- text is invisible to them) ----
        q_static = self.q_static.expand(B, -1)  # [B, qd]
        z_coarse = self._query_all_frames(q_static, kv_cache, B, T, mask_flat, patch_features)  # [B,T,dd]
        z_coarse_llm = self._project_visual(z_coarse)  # [B,T,ld]

        # Coarse LLM: process ONLY visual tokens (T tokens, not T+S).
        # Causal attention: visual pos i only sees visual pos 0..i, never text.
        # This is ~30x faster for typical T=8, S=256 batches.
        out_coarse = self.llm.model(inputs_embeds=z_coarse_llm)
        h_coarse = out_coarse.last_hidden_state  # [B,T,ld]

        # Extract dynamic queries from visual positions
        queries = self.llm_to_query(h_coarse)  # [B,T,qd]

        # Shift queries: frame t gets query from frame t-1; frame 0 gets q_init
        q_init = self.q_init.expand(B, 1, -1)  # [B,1,qd]
        shifted_queries = torch.cat([q_init, queries[:, :-1]], dim=1)  # [B,T,qd]

        # ---- Pass 2: Fine ----
        z_fine = self._query_all_frames_batched(shifted_queries, kv_cache, B, T, mask_flat, patch_features)  # [B,T,dd]
        z_fine_llm = self._project_visual(z_fine)  # [B,T,ld]

        # Build fine sequence: [visual_fine, text]
        text_embeds = self._embed_text(input_ids)  # [B,S,ld]
        seq_fine = torch.cat([z_fine_llm, text_embeds], dim=1)  # [B,T+S,ld]

        out_fine = self.llm.model(inputs_embeds=seq_fine)
        h_fine = out_fine.last_hidden_state  # [B,T+S,ld]

        # ---- Loss on text portion ----
        h_text = h_fine[:, T:, :]  # [B,S,ld]
        if loss_mask is None:
            loss_mask = attention_mask.float()

        if self.use_fused_ce:
            # Liger Kernel: fused lm_head + CE, never materializes [B,S,V] logits
            fine_loss = self._fused_ce_loss(h_text, input_ids, loss_mask)
            logits_text = None  # not available with fused loss
        else:
            logits_text = self.llm.lm_head(h_text)  # [B,S,V]
            fine_loss = self._ce_loss(logits_text, input_ids, loss_mask)

        # ---- Optional auxiliary coarse loss ----
        coarse_loss = torch.tensor(0.0, device=frames.device)
        if self.lambda_coarse > 0:
            seq_coarse_full = torch.cat([z_coarse_llm, text_embeds], dim=1)
            out_coarse_full = self.llm.model(inputs_embeds=seq_coarse_full)
            h_coarse_text = out_coarse_full.last_hidden_state[:, T:, :]
            if self.use_fused_ce:
                coarse_loss = self._fused_ce_loss(h_coarse_text, input_ids, loss_mask)
            else:
                logits_coarse = self.llm.lm_head(h_coarse_text)
                coarse_loss = self._ce_loss(logits_coarse, input_ids, loss_mask)

        # ---- Combined loss ----
        loss = fine_loss + self.lambda_coarse * coarse_loss

        return {
            "loss": loss,
            "fine_loss": fine_loss,
            "coarse_loss": coarse_loss,
            "logits": logits_text,  # [B,S,V] text positions only (or None)
        }

    # ------------------------------------------------------------------
    # Forward mode: DPO (preference training)
    # ------------------------------------------------------------------

    def forward_dpo(
        self,
        frames: torch.Tensor,
        chosen_input_ids: torch.Tensor,
        chosen_attention_mask: torch.Tensor,
        chosen_loss_mask: torch.Tensor,
        rejected_input_ids: torch.Tensor,
        rejected_attention_mask: torch.Tensor,
        rejected_loss_mask: torch.Tensor,
        frame_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        DPO forward pass: run coarse+fine on both chosen and rejected sequences.

        Shares DINO encoding across chosen and rejected (same visual input).
        Returns per-sample sums of log-probabilities for both chosen and
        rejected, masked by loss_mask (answer-only tokens). A sketch of the
        resulting DPO objective follows `_sequence_logprobs` below.

        Parameters
        ----------
        frames : [B, T, 3, 224, 224]
        chosen_input_ids : [B, S_c]
        chosen_attention_mask : [B, S_c]
        chosen_loss_mask : [B, S_c] (1 = answer token, 0 = prompt/pad)
        rejected_input_ids : [B, S_r]
        rejected_attention_mask : [B, S_r]
        rejected_loss_mask : [B, S_r]
        frame_mask : [B, T] bool (optional)

        Returns
        -------
        dict with keys:
            chosen_logps : [B] per-sample sum of log-probs on chosen answer tokens
            rejected_logps : [B] per-sample sum of log-probs on rejected answer tokens
            chosen_logits : [B, S_c, V] text-position logits for chosen
            rejected_logits : [B, S_r, V] text-position logits for rejected
        """
        B, T = frames.shape[:2]

        # ---- Step 0: Encode all frames (DINO, shared across chosen & rejected) ----
        kv_cache, patch_features, mask_flat = self._encode_all_frames(frames, frame_mask)

        # ---- Coarse pass (visual tokens ONLY -- text invisible in causal attn) ----
        q_static = self.q_static.expand(B, -1)  # [B, qd]
        z_coarse = self._query_all_frames(q_static, kv_cache, B, T, mask_flat, patch_features)
        z_coarse_llm = self._project_visual(z_coarse)  # [B, T, ld]

        # Coarse LLM: visual tokens only (T, not T+S_c). Causal attention means
        # visual positions never see text, so this is mathematically identical.
        out_coarse = self.llm.model(inputs_embeds=z_coarse_llm)
        h_coarse = out_coarse.last_hidden_state  # [B, T, ld]

        # Extract dynamic queries from visual positions
        queries = self.llm_to_query(h_coarse)  # [B, T, qd]

        q_init = self.q_init.expand(B, 1, -1)
        shifted_queries = torch.cat([q_init, queries[:, :-1]], dim=1)  # [B, T, qd]

        # ---- Fine pass: shared visual features ----
        z_fine = self._query_all_frames_batched(shifted_queries, kv_cache, B, T, mask_flat, patch_features)
        z_fine_llm = self._project_visual(z_fine)  # [B, T, ld]

        # ---- Forward on CHOSEN (lm_head on text positions only) ----
        text_embeds_chosen = self._embed_text(chosen_input_ids)  # [B, S_c, ld]
        seq_chosen = torch.cat([z_fine_llm, text_embeds_chosen], dim=1)  # [B, T+S_c, ld]
        out_chosen = self.llm.model(inputs_embeds=seq_chosen)
        chosen_logits = self.llm.lm_head(out_chosen.last_hidden_state[:, T:, :])  # [B, S_c, V]

        # ---- Forward on REJECTED (lm_head on text positions only) ----
        text_embeds_rejected = self._embed_text(rejected_input_ids)  # [B, S_r, ld]
        seq_rejected = torch.cat([z_fine_llm, text_embeds_rejected], dim=1)
        out_rejected = self.llm.model(inputs_embeds=seq_rejected)
        rejected_logits = self.llm.lm_head(out_rejected.last_hidden_state[:, T:, :])  # [B, S_r, V]

        # ---- Compute per-token log-probs ----
        chosen_logps = self._sequence_logprobs(
            chosen_logits, chosen_input_ids, chosen_loss_mask,
        )
        rejected_logps = self._sequence_logprobs(
            rejected_logits, rejected_input_ids, rejected_loss_mask,
        )

        return {
            "chosen_logps": chosen_logps,        # [B]
            "rejected_logps": rejected_logps,    # [B]
            "chosen_logits": chosen_logits,      # [B, S_c, V]
            "rejected_logits": rejected_logits,  # [B, S_r, V]
        }

    def _sequence_logprobs(
        self,
        logits: torch.Tensor,
        input_ids: torch.Tensor,
        loss_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute per-sample sum of log-probabilities on answer tokens.

        logits : [B, S, V] text-only logits (visual positions excluded)
        input_ids : [B, S] text token ids
        loss_mask : [B, S] 1.0 for answer tokens, 0.0 otherwise

        Returns : [B] sum of log-probs per sample
        """
        # Shift for autoregressive prediction
        shift_logits = logits[:, :-1, :]  # [B, S-1, V]
        shift_labels = input_ids[:, 1:]   # [B, S-1]
        shift_mask = loss_mask[:, 1:]     # [B, S-1]

        # Per-token log-probs: log_softmax then gather the label's prob
        log_probs = F.log_softmax(shift_logits, dim=-1)  # [B, S-1, V]
        per_token_logps = log_probs.gather(
            dim=-1, index=shift_labels.unsqueeze(-1),
        ).squeeze(-1)  # [B, S-1]

        # Mask and sum per sample
        per_token_logps = per_token_logps * shift_mask  # zero out non-answer tokens
        return per_token_logps.sum(dim=-1)  # [B]
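
    # Sketch of how a trainer might combine these per-sample log-prob sums into
    # the standard DPO objective; `beta` and the frozen `ref_model` are the
    # caller's responsibility and are not part of this file:
    #
    #   out = model.forward_dpo(frames, **pair_batch)
    #   with torch.no_grad():
    #       ref = ref_model.forward_dpo(frames, **pair_batch)
    #   margin = (out["chosen_logps"] - out["rejected_logps"]) \
    #            - (ref["chosen_logps"] - ref["rejected_logps"])
    #   dpo_loss = -F.logsigmoid(beta * margin).mean()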

    # ------------------------------------------------------------------
    # Forward mode 2: Coarse only (FAST EVAL)
    # ------------------------------------------------------------------

    def forward_coarse_only(
        self,
        frames: torch.Tensor,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        loss_mask: Optional[torch.Tensor] = None,
        frame_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Single-pass coarse forward (q_static only, no fine queries).

        Used for:
        - Training the A6 ablation (coarse-only training)
        - Fast eval (wrap in torch.no_grad() externally)

        q_static -> all frames -> z_coarse -> LLM -> logits.

        Parameters
        ----------
        frames : [B, T, 3, 224, 224]
        input_ids : [B, S] (optional, for loss computation)
        attention_mask : [B, S] (optional)
        loss_mask : [B, S] (optional)

        Returns
        -------
        dict with keys: logits, and optionally loss
        """
        B, T = frames.shape[:2]

        kv_cache, patch_features, mask_flat = self._encode_all_frames(frames, frame_mask)

        q_static = self.q_static.expand(B, -1)
        z_coarse = self._query_all_frames(q_static, kv_cache, B, T, mask_flat, patch_features)
        z_coarse_llm = self._project_visual(z_coarse)

        if input_ids is not None:
            text_embeds = self._embed_text(input_ids)
            seq = torch.cat([z_coarse_llm, text_embeds], dim=1)
        else:
            seq = z_coarse_llm
        # dtype handled by autocast on GPU; float32 on CPU

        out = self.llm.model(inputs_embeds=seq)
        h = out.last_hidden_state  # [B, T+S, ld]

        if input_ids is not None:
            pad_id = self._get_pad_token_id()
            visual_pad = torch.full(
                (B, T), pad_id, dtype=input_ids.dtype, device=input_ids.device,
            )
            full_labels = torch.cat([visual_pad, input_ids], dim=1)

            if loss_mask is not None:
                visual_no_loss = torch.zeros(
                    B, T, dtype=loss_mask.dtype, device=loss_mask.device,
                )
                full_loss_mask = torch.cat([visual_no_loss, loss_mask], dim=1)
            elif attention_mask is not None:
                visual_no_loss = torch.zeros(
                    B, T, dtype=attention_mask.dtype, device=attention_mask.device,
                )
                full_loss_mask = torch.cat([visual_no_loss, attention_mask], dim=1)
            else:
                full_loss_mask = None

            if self.use_fused_ce and self.training:
                # Fused CE: skip lm_head, never materializes [B, T+S, V]
                loss = self._fused_ce_loss(h, full_labels, full_loss_mask)
                logits = None
            else:
                logits = self.llm.lm_head(h)
                loss = self._ce_loss(logits, full_labels, full_loss_mask)

            result: Dict[str, torch.Tensor] = {"logits": logits, "loss": loss}
            result["coarse_loss"] = loss
            result["fine_loss"] = torch.tensor(0.0, device=frames.device)
        else:
            logits = self.llm.lm_head(h)
            result = {"logits": logits}

        return result

    # ------------------------------------------------------------------
    # Forward mode 3: Autoregressive (TRUE INFERENCE)
    # ------------------------------------------------------------------

    @torch.no_grad()
    def forward_autoregressive(
        self,
        frames: torch.Tensor,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        loss_mask: Optional[torch.Tensor] = None,
        frame_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        True autoregressive inference: sequential frame-by-frame with KV cache.

        q_init -> frame_1 -> z_1 -> LLM -> q_1 -> frame_2 -> z_2 -> ...

        No coarse pass. Each query is derived from the LLM hidden state after
        processing the *previous* fine visual token -- exactly what happens at
        real inference time.

        Parameters
        ----------
        frames : [B, T, 3, 224, 224]
        input_ids : [B, S] (optional, for loss computation)
        attention_mask : [B, S] (optional)
        loss_mask : [B, S] (optional)

        Returns
        -------
        dict with keys: logits, and optionally loss
        """
        B, T = frames.shape[:2]
        device = frames.device

        # Encode all frames with DINO up front (this is OK -- DINO encoding
        # does not depend on the query, only query_attend does).
        kv_cache, patch_features, mask_flat = self._encode_all_frames(frames, frame_mask)

        # Enable KV cache on the LLM for incremental decoding
        orig_use_cache = self.llm.config.use_cache
        self.llm.config.use_cache = True

        query = self.q_init.expand(B, -1)  # [B, qd]
        llm_past_kv = None

        for t in range(T):
            # Foveated extraction with current query
            frame_kv = self._extract_frame_kv(kv_cache, mask_flat, B, T, t)
            z_t = self.encoder.query_attend(query, frame_kv)   # [B, dd]
            z_t_llm = self._project_visual(z_t.unsqueeze(1))   # [B,1,ld]
            # dtype handled by autocast on GPU; float32 on CPU

            # Incremental LLM forward (one visual token at a time)
            out = self.llm.model(
                inputs_embeds=z_t_llm,
                past_key_values=llm_past_kv,
                use_cache=True,
            )
            llm_past_kv = out.past_key_values

            # Derive query for the NEXT frame from the current hidden state
            if t < T - 1:
                h_t = out.last_hidden_state[:, -1, :]  # [B, ld]
                query = self.llm_to_query(h_t)         # [B, qd]

        # ---- Now process text (if provided) using the accumulated KV cache ----
        if input_ids is not None:
            text_embeds = self._embed_text(input_ids)  # [B, S, ld]

            out_text = self.llm.model(
                inputs_embeds=text_embeds,
                past_key_values=llm_past_kv,
                use_cache=False,
            )
            # The KV cache holds the T visual positions; out_text holds the S
            # text positions. For the loss we need logits over the text tokens
            # plus the logit at the last visual position, which predicts the
            # first text token.
            h_text = out_text.last_hidden_state      # [B, S, ld]
            logits_text = self.llm.lm_head(h_text)   # [B, S, V]

            # Re-derive the logit at the last visual position:
            h_last_visual = out.last_hidden_state[:, -1:, :]  # [B,1,ld]
            logits_last_v = self.llm.lm_head(h_last_visual)   # [B,1,V]

            # Full logits over [last_visual, text] = [B, 1+S, V]
            logits = torch.cat([logits_last_v, logits_text], dim=1)

            # Labels: [pad_for_last_visual, input_ids]
            pad_id = self._get_pad_token_id()
            lv_pad = torch.full(
                (B, 1), pad_id, dtype=input_ids.dtype, device=device,
            )
            full_labels = torch.cat([lv_pad, input_ids], dim=1)

            # Loss mask
            if loss_mask is not None:
                lv_no_loss = torch.zeros(
                    B, 1, dtype=loss_mask.dtype, device=device,
                )
                full_loss_mask = torch.cat([lv_no_loss, loss_mask], dim=1)
            elif attention_mask is not None:
                lv_no_loss = torch.zeros(
                    B, 1, dtype=attention_mask.dtype, device=device,
                )
                full_loss_mask = torch.cat([lv_no_loss, attention_mask], dim=1)
            else:
                full_loss_mask = None

            loss = self._ce_loss(logits, full_labels, full_loss_mask)

            self.llm.config.use_cache = orig_use_cache
            return {"loss": loss, "logits": logits}

        else:
            # No text -- just return logits at the last visual position
            h_last = out.last_hidden_state  # [B, 1, ld]
            logits = self.llm.lm_head(h_last)
            self.llm.config.use_cache = orig_use_cache
            return {"logits": logits}

    # ------------------------------------------------------------------
    # Convenience: unified forward dispatching by name
    # ------------------------------------------------------------------

    def forward(
        self,
        frames: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: Optional[torch.Tensor] = None,
        frame_mask: Optional[torch.Tensor] = None,
        mode: str = "coarse_fine",
    ) -> Dict[str, torch.Tensor]:
        """
        Unified forward entry point.

        Parameters
        ----------
        frames : Tensor [B, T, 3, 224, 224]
            Preprocessed video frames (DINOv2 normalization).
            For **video**: T = number of sampled frames (1-64).
            For **images**: replicate the single frame to T=8 to match the
            training distribution (``frame.unsqueeze(0).repeat(8, 1, 1, 1)``).
            The model was trained with ``replicate_image_frames: 8`` in
            Stages 2-3, so single-frame image input will produce degraded
            results. (An illustrative call follows this method.)
        input_ids : Tensor [B, S]
            Tokenized text (prompt + response).
        attention_mask : Tensor [B, S]
            1 for real tokens, 0 for padding.
        loss_mask : Tensor [B, S], optional
            1 for tokens that contribute to loss, 0 to skip.
        frame_mask : Tensor [B, T] bool, optional
            True for real frames, False for padding (for variable-length batches).
        mode : str
            "coarse_fine" -- two-pass parallel forward (recommended, uses foveation)
            "coarse_only" -- single static-query pass (fastest, no foveation)
            "autoregressive" -- sequential inference with KV cache
        """
        if mode == "coarse_fine":
            return self.forward_coarse_fine(frames, input_ids, attention_mask, loss_mask, frame_mask)
        elif mode == "coarse_only":
            return self.forward_coarse_only(frames, input_ids, attention_mask, loss_mask, frame_mask)
        elif mode == "autoregressive":
            return self.forward_autoregressive(frames, input_ids, attention_mask, loss_mask, frame_mask)
        else:
            raise ValueError(
                f"Unknown forward mode '{mode}'. "
                "Expected one of: coarse_fine, coarse_only, autoregressive"
            )
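
    # Illustrative image-input call (a single preprocessed `frame` [3, 224, 224]
    # is a placeholder; replication to T=8 follows the note in the docstring):
    #
    #   frames = frame.unsqueeze(0).repeat(8, 1, 1, 1).unsqueeze(0)  # [1,8,3,224,224]
    #   out = model(frames, input_ids, attention_mask, mode="coarse_fine")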

    # ------------------------------------------------------------------
    # CUDA Stream Prefetch -- overlap DINO encoding with LLM backward
    # ------------------------------------------------------------------

    def prefetch_dino(self, frames: torch.Tensor, frame_mask=None, stream=None):
        """
        Start DINO encoding on a separate CUDA stream.

        Call this while the previous batch's backward pass is running.
        The DINO encoder is frozen during training, so there's no gradient
        dependency between the backward pass and this prefetch.

        Args:
            frames: [B, T, 3, 224, 224] next batch's frames
            frame_mask: [B, T] bool, optional
            stream: torch.cuda.Stream to run on (caller manages lifecycle)

        Returns:
            None -- results are stored internally and consumed automatically
            by the next forward_coarse_fine call (a loop sketch follows
            _get_prefetched_dino below).
        """
        if stream is None:
            stream = torch.cuda.Stream()
        # Make the side stream wait for pending work on `frames` (e.g. its
        # host-to-device copy on the default stream) before encoding.
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            with torch.no_grad():
                self._prefetched_dino = self._encode_all_frames(frames, frame_mask)
        self._prefetch_stream = stream

    def _get_prefetched_dino(self):
        """Retrieve and clear prefetched DINO results, synchronizing the stream."""
        if getattr(self, "_prefetched_dino", None) is not None:
            self._prefetch_stream.synchronize()
            result = self._prefetched_dino
            self._prefetched_dino = None
            self._prefetch_stream = None
            return result
        return None
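
    # Sketch of the intended overlap in a training loop (`loader` and
    # `optimizer` are placeholders; the real loop lives in train.py):
    #
    #   batch = next(loader)
    #   for next_batch in loader:
    #       out = model(batch["frames"], batch["input_ids"], batch["attention_mask"])
    #       out["loss"].backward()                     # queues backward kernels
    #       model.prefetch_dino(next_batch["frames"])  # overlaps with backward
    #       optimizer.step(); optimizer.zero_grad()
    #       batch = next_batch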

    # ------------------------------------------------------------------
    # Utility methods for external callers (train.py, eval.py)
    # ------------------------------------------------------------------

    def enable_gradient_checkpointing(
        self, llm_only: bool = False, use_reentrant: bool = True,
    ) -> None:
        """Turn on activation checkpointing for the LLM (and optionally DINO).

        Args:
            llm_only: If True, only enable for the LLM backbone. Leave DINO
                un-checkpointed so it can be safely torch.compiled.
                DINO is small (22M params) so checkpointing saves
                little memory there.
            use_reentrant: If False, use non-reentrant checkpointing, which
                is compatible with torch.compile (the reentrant
                version causes NaN with compile). Default True
                for backward compat; set False when using compile.
        """
        ckpt_kwargs = {"use_reentrant": use_reentrant}
        self.llm.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs=ckpt_kwargs
        )
        if not llm_only and hasattr(self.encoder.dino, "gradient_checkpointing_enable"):
            self.encoder.dino.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs=ckpt_kwargs
            )

    def get_param_groups(
        self,
        lr_backbone: float = 1e-5,
        lr_connector: float = 1e-4,
    ) -> list:
        """
        Return parameter groups with differential learning rates.

        Groups:
        1. Connector (dino_to_llm, llm_to_query, q_static, q_init) -- highest LR
        2. DINO encoder -- backbone LR
        3. LLM -- backbone LR

        This is a suggestion; train.py may override.
        """
        connector_params = set()
        for name, param in self.named_parameters():
            if any(k in name for k in [
                "dino_to_llm", "llm_to_query", "q_static", "q_init",
                "query_input_proj", "query_output_proj",
            ]):
                connector_params.add(id(param))

        encoder_params = set()
        for name, param in self.encoder.named_parameters():
            if id(param) not in connector_params:
                encoder_params.add(id(param))

        groups = [
            {
                "params": [p for p in self.parameters()
                           if id(p) in connector_params and p.requires_grad],
                "lr": lr_connector,
                "name": "connector",
            },
            {
                "params": [p for n, p in self.encoder.named_parameters()
                           if id(p) in encoder_params and p.requires_grad],
                "lr": lr_backbone,
                "name": "dino",
            },
            {
                "params": [p for p in self.llm.parameters() if p.requires_grad],
                "lr": lr_backbone,
                "name": "llm",
            },
        ]
        return [g for g in groups if len(g["params"]) > 0]
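
    # Sketch of optimizer construction from these groups (AdamW and the LR /
    # weight-decay values are illustrative, not prescribed by this file):
    #
    #   param_groups = model.get_param_groups(lr_backbone=1e-5, lr_connector=1e-4)
    #   optimizer = torch.optim.AdamW(param_groups, weight_decay=0.01)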