sanps committed
Commit 6d320d6 · verified · 1 Parent(s): 561dc59

Upload fVLM-135M: Foveated Vision-Language Model (Stage 3 DPO)
README.md ADDED
@@ -0,0 +1,118 @@
---
license: apache-2.0
language:
- en
tags:
- vision-language
- video-understanding
- foveated-attention
- multimodal
- smollm2
- dinov2
library_name: pytorch
pipeline_tag: image-text-to-text
---

# fVLM-135M (Foveated Vision-Language Model)

A compact vision-language model that uses **foveated attention** to compress each video frame into a single visual token, enabling efficient processing of long videos.

## Architecture

| Component | Details |
|-----------|---------|
| **Language Model** | SmolLM2-135M-Instruct (HuggingFaceTB/SmolLM2-135M-Instruct) |
| **Vision Encoder** | DINOv2-small (facebook/dinov2-small) |
| **Attention** | Deep query-guided foveated cross-attention |
| **Visual Tokens** | 1 token per frame (query-compressed) |
| **Total Parameters** | 157.6M |
| **Query Dimension** | 384 |
| **Visual Scale** | 0.14 |

### How Foveated Attention Works

Unlike standard VLMs that use many visual tokens per image (e.g., 576 for LLaVA), fVLM compresses each frame to a **single visual token** using a learned query mechanism:

1. **DINOv2** encodes each frame into patch features and caches K/V at every layer
2. A **query vector** is propagated through all 12 DINO layers, attending to the patch K/V at each layer (deep query attention)
3. The single output token is projected to the LLM dimension and prepended to the text sequence
4. The **LLM generates the next query** from its hidden state, creating a feedback loop where the model learns *where to look*

This enables processing **64+ frames** with the same memory as a few frames in traditional VLMs.
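The numbered steps above can be sketched end-to-end in a few lines. This is a toy illustration with made-up dimensions: `foveate` collapses the real multi-head, per-layer deep query attention into a single dot-product attention, and the LLM feedback step is stood in for by applying `llm_to_query` to the visual token (in the real model the next query comes from the LLM's hidden state).

```python
import torch

torch.manual_seed(0)

# Toy sizes -- hypothetical; the real model uses query_dim=384, llm_dim=576
D_PATCH, D_QUERY, D_LLM, N_PATCHES = 8, 8, 16, 5

dino_to_llm = torch.nn.Linear(D_PATCH, D_LLM)   # project token into LLM space
llm_to_query = torch.nn.Linear(D_LLM, D_QUERY)  # generate next frame's query

def foveate(patches: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
    """Compress N patch features into ONE token via query attention."""
    scores = patches @ query / (D_PATCH ** 0.5)  # [N] similarity to the query
    weights = torch.softmax(scores, dim=0)       # where to look
    return weights @ patches                     # [D_PATCH] single visual token

frames = [torch.randn(N_PATCHES, D_PATCH) for _ in range(3)]
query = torch.zeros(D_QUERY)  # initial query (cf. q_init in the checkpoint)
visual_tokens = []
for frame in frames:
    z = foveate(frame, query)                # steps 1-2: one token per frame
    visual_tokens.append(dino_to_llm(z))     # step 3: project to LLM dim
    query = llm_to_query(visual_tokens[-1])  # step 4: feedback loop (stand-in)

print(len(visual_tokens), tuple(visual_tokens[0].shape))
```

Because each frame contributes exactly one token, sequence length (and memory) grows by one per frame, which is why 64+ frames fit where multi-token VLMs run out.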
42
+
43
+ ## Training Pipeline
44
+
45
+ The model was trained in a 3-stage pipeline:
46
+
47
+ ### Stage 1: Visual Alignment
48
+ - **Data**: OpenVid-1M (905K) + WebVid (19K) + 14% SmolTalk text retention
49
+ - **Loss**: Full-text cross-entropy (predict all tokens)
50
+ - **LR**: Converging schedule -- connector 1e-3 to 3e-5, backbone 1e-5 to 3e-5
51
+ - **Objective**: Align visual and text embedding spaces
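The converging schedule can be illustrated with a linear ramp. This is a sketch only: `converging_lr` is a hypothetical helper, and the actual curve shape and warmup handling live in the training code (`schedule: converging` in the Stage 1 config).

```python
def converging_lr(step: int, total_steps: int, start_lr: float,
                  target_lr: float = 3.0e-5) -> float:
    """Interpolate a component's LR from its own start toward the shared target."""
    t = min(step / total_steps, 1.0)
    return start_lr + t * (target_lr - start_lr)

# Connector starts hot (1e-3), backbones start cold (1e-5); all meet at 3e-5.
print(converging_lr(0, 1000, 1.0e-3))     # 0.001
print(converging_lr(1000, 1000, 1.0e-3))  # ~3e-05
print(converging_lr(1000, 1000, 1.0e-5))  # ~3e-05
```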

### Stage 2: Vision-Language SFT
- **Data**: Cauldron (2M images) + video datasets (~1.6M) + 14% SmolTalk text retention
- **Loss**: Answer-only cross-entropy (mask user/system tokens)
- **LR**: Flat 3e-5 for all components with cosine decay
- **Objective**: Instruction following on visual inputs

### Stage 3: DPO (Direct Preference Optimization)
- **Data**: RLAIF-V (83K preference pairs)
- **Loss**: DPO with beta=0.1
- **LR**: 1e-6 for all components
- **Objective**: Align model outputs with human preferences
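Stage 3 optimizes the standard DPO objective. A minimal sketch (tensor names are illustrative; the real training loop extracts per-sequence log-probs from the policy and a frozen reference copy of the Stage 2 model):

```python
import torch
import torch.nn.functional as F

def dpo_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected, beta=0.1):
    """DPO: push the policy's chosen-vs-rejected log-prob margin above the
    frozen reference model's margin, with temperature beta."""
    chosen_reward = beta * (pi_chosen - ref_chosen)
    rejected_reward = beta * (pi_rejected - ref_rejected)
    loss = -F.logsigmoid(chosen_reward - rejected_reward).mean()
    # Reward accuracy: fraction of pairs where chosen outscores rejected
    reward_acc = (chosen_reward > rejected_reward).float().mean()
    return loss, reward_acc

# One preference pair: policy favors chosen more than the reference does
pi_c, pi_r = torch.tensor([-5.0]), torch.tensor([-10.0])
ref_c, ref_r = torch.tensor([-6.0]), torch.tensor([-9.0])
loss, acc = dpo_loss(pi_c, pi_r, ref_c, ref_r)
print(loss.item(), acc.item())
```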

## Model Components

The checkpoint contains the full `FoveatedVLM` model with these submodules:

- `encoder.dino.*` -- DINOv2-small vision backbone
- `encoder.query_input_proj.*` -- Query projection into DINO space (bias=False)
- `encoder.output_proj.*` -- Output projection from DINO to query dim
- `dino_to_llm.*` -- Linear projection from DINO dim (384) to LLM dim (576)
- `llm_to_query.*` -- Linear projection from LLM dim (576) to query dim (384)
- `q_static` -- Learnable static query for the coarse pass
- `q_init` -- Learnable initial query for the fine pass (frame 0)
- `llm.*` -- SmolLM2-135M-Instruct language model

## Usage

```python
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Download the checkpoint
ckpt_path = hf_hub_download(
    repo_id="spsanps/fVLM-135M",
    filename="model.safetensors",  # or model.pt
)

# Load into FoveatedVLM (requires the model code from this repo)
# See release/model/foveated_vlm.py and release/model/encoder.py
from release.model import FoveatedVLM

model = FoveatedVLM(
    llm_name="HuggingFaceTB/SmolLM2-135M-Instruct",
    dino_name="facebook/dinov2-small",
    query_dim=384,
    visual_scale=0.14,
    deep_query=True,
)

# Load weights (torch.load cannot read .safetensors files)
if ckpt_path.endswith(".safetensors"):
    state_dict = load_file(ckpt_path)
else:
    state_dict = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```

## Config Files

The training configuration YAML files for all three stages are included in this repository:
- `configs/stage1_135M.yaml` -- Visual alignment config
- `configs/stage2_135M.yaml` -- Vision-language SFT config
- `configs/stage3_135M.yaml` -- DPO config
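The configs are plain YAML and parse with PyYAML (an assumed dependency, not pinned by this repo). A sketch against an inline excerpt of `stage1_135M.yaml`:

```python
import yaml  # PyYAML -- assumed installed (pip install pyyaml)

# Inline excerpt of configs/stage1_135M.yaml, for illustration only
excerpt = """
model:
  deep_query: true
  query_dim: 384
  visual_scale: 0.14
training:
  lr_connector: 1.0e-3
  target_lr: 3.0e-5
"""
cfg = yaml.safe_load(excerpt)
print(cfg["model"]["query_dim"], cfg["training"]["lr_connector"])
```

Note that a YAML 1.1 loader like PyYAML only reads `1.0e-3` as a float because the mantissa has a decimal point and the exponent has a sign; the configs follow that form.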

## License

Apache 2.0
config.json ADDED
@@ -0,0 +1,20 @@
{
  "model_type": "foveated_vlm",
  "architectures": [
    "FoveatedVLM"
  ],
  "llm_name": "HuggingFaceTB/SmolLM2-135M-Instruct",
  "dino_name": "facebook/dinov2-small",
  "llm_dim": 576,
  "dino_dim": 384,
  "query_dim": 384,
  "visual_scale": 0.14,
  "lambda_coarse": 0.0,
  "deep_query": true,
  "total_params": 185622528,
  "training_stages": [
    "Stage 1: Visual Alignment (OpenVid + WebVid + text retention)",
    "Stage 2: Vision-Language SFT (Cauldron + video + text retention)",
    "Stage 3: DPO (RLAIF-V preference pairs)"
  ]
}
configs/stage1_135M.yaml ADDED
@@ -0,0 +1,68 @@
# =============================================================================
# FINAL Stage 1: Visual Alignment — 135M
# =============================================================================
# Model: SmolLM2-135M-Instruct + DINOv2-small (157.6M total params)
# Loss: All-text CE (predict all tokens)
# LR: Converging schedule: connector=1e-3 → 3e-5, backbone=1e-5 → 3e-5
# Data: OpenVid-1M (905K) + WebVid (19K) + 14% SmolTalk S1 text retention
# Prompt: Honest conditioning ("What would be the WebVid caption?")
# Text retention: Proper chat format (not wrapped in WebVid prompt)
# =============================================================================

stage: 1

model:
  llm: /workspace/models/SmolLM2-135M-Instruct
  dino: /workspace/models/dinov2-small
  deep_query: true
  query_dim: 384
  visual_scale: 0.14
  lambda_coarse: 0.0
  gradient_checkpointing: false

data:
  train_shards:
    - "/workspace/data/openvid/*.tar"
    - "/workspace/data/webvid/*.tar"
  val_shards: "/workspace/data/eval/val_10k/*.tar"
  text_shards: "/workspace/data/text_retention/stage1/*.tar"
  text_ratio: 0.14
  max_frames: 64
  frame_size: 224
  num_workers: 6
  prefetch_factor: 4

training:
  total_samples: 1_000_000
  batch_size: 8
  grad_accum: 4
  lr_connector: 1.0e-3
  lr_dino: 1.0e-5
  lr_llm: 1.0e-5
  target_lr: 3.0e-5
  warmup_ratio: 0.03
  weight_decay: 0.01
  max_grad_norm: 1.0
  schedule: converging
  dtype: bfloat16
  compile: false
  seed: 42

loss:
  type: text_ce_all

checkpoint:
  save_dir: /workspace/checkpoints/final/stage1
  save_every_steps: 1000
  keep_last: 2
  keep_best: 1
  metric: val_loss
  resume: auto

eval:
  every_steps: 500
  max_samples: 1000

wandb:
  project: foveated-vlm-final
  run_name: stage1-135M
configs/stage2_135M.yaml ADDED
@@ -0,0 +1,75 @@
# =============================================================================
# FINAL Stage 2: Vision-Language SFT — 135M
# =============================================================================
# Model: SmolLM2-135M-Instruct + DINOv2-small
# Loss: Answer-only CE (mask user/system tokens)
# LR: Flat 3e-5 all components (1:1, SmolVLM2 style) + cosine decay
# Data: Cauldron (2M images) + all video (~1.6M) + 14% SmolTalk S2 text
# Mix: ~55% image, ~45% video (natural shard ratio), +14% text interleave
# Images: Replicated to 8 frames (A8 sweep winner)
# Init: Best Stage 1 checkpoint
# =============================================================================

stage: 2

model:
  llm: /workspace/models/SmolLM2-135M-Instruct
  dino: /workspace/models/dinov2-small
  deep_query: true
  query_dim: 384
  visual_scale: 0.14
  lambda_coarse: 0.0
  gradient_checkpointing: false
  init_from: /workspace/checkpoints/final/stage1/best.pt

data:
  train_shards:
    - "/workspace/data/cauldron_full/*.tar"
    - "/workspace/data/openvid/*.tar"
    - "/workspace/data/webvid/*.tar"
    - "/workspace/data/vista_shards/*.tar"
    - "/workspace/data/vista_extra_shards/*.tar"
    - "/workspace/data/vript_long_shards/*.tar"
    - "/workspace/data/vript_shards/*.tar"
    - "/workspace/data/sharegpt4video_shards/*.tar"
    - "/workspace/data/stage3_youtube/*.tar"
  # No val_shards — pretraining-style, train loss only
  text_shards: "/workspace/data/text_retention/stage2/*.tar"
  text_ratio: 0.14
  max_frames: 64
  frame_size: 224
  num_workers: 2
  prefetch_factor: 2
  replicate_image_frames: 8

training:
  total_samples: 1_000_000
  batch_size: 8
  grad_accum: 4
  lr_connector: 3.0e-5
  lr_dino: 3.0e-5
  lr_llm: 3.0e-5
  warmup_ratio: 0.03
  weight_decay: 0.01
  max_grad_norm: 1.0
  schedule: cosine
  dtype: bfloat16
  compile: false  # 135M too small for torch.compile (40% regression)
  seed: 42

loss:
  type: text_ce_answer_only

checkpoint:
  save_dir: /workspace/checkpoints/final/stage2
  save_every_steps: 1000
  keep_last: 2
  keep_best: 1
  metric: train_loss  # no eval — train loss is the signal for pretraining
  resume: auto

# No eval — pretraining-style, train loss only. Saves ~6min/1M samples.

wandb:
  project: foveated-vlm-final
  run_name: stage2-135M
configs/stage3_135M.yaml ADDED
@@ -0,0 +1,64 @@
# =============================================================================
# FINAL Stage 3: DPO — 135M
# =============================================================================
# Model: SmolLM2-135M-Instruct + DINOv2-small
# Loss: DPO (β=0.1, reference model = frozen Stage 2 best)
# LR: 1e-6 all components (low LR typical for DPO)
# Data: RLAIF-V (83K preference pairs: chosen + rejected)
# Init: Best Stage 2 checkpoint
# Reference: Same checkpoint (frozen copy)
# =============================================================================

stage: 3

model:
  llm: /workspace/models/SmolLM2-135M-Instruct
  dino: /workspace/models/dinov2-small
  deep_query: true
  query_dim: 384
  visual_scale: 0.14
  lambda_coarse: 0.0
  gradient_checkpointing: false
  init_from: /workspace/checkpoints/final/stage2/best.pt

data:
  train_shards: "/workspace/data/rlaif_v/*.tar"
  # No val_shards — train loss only
  max_frames: 64
  frame_size: 224
  num_workers: 2
  prefetch_factor: 2
  replicate_image_frames: 8  # RLAIF-V is image-only

training:
  total_samples: 83_000  # 1 epoch of RLAIF-V
  batch_size: 4  # DPO needs chosen+rejected per sample (2x memory)
  grad_accum: 8  # eff batch = 32
  lr_connector: 1.0e-6
  lr_dino: 1.0e-6
  lr_llm: 1.0e-6
  warmup_ratio: 0.1
  weight_decay: 0.01
  max_grad_norm: 1.0
  schedule: cosine
  dtype: bfloat16
  compile: false
  seed: 42

loss:
  type: dpo  # requires DPO collate + loss implementation
  beta: 0.1  # DPO temperature

checkpoint:
  save_dir: /workspace/checkpoints/final/stage3
  save_every_steps: 500
  keep_last: 2
  keep_best: 1
  metric: train_loss
  resume: auto

# No eval — DPO metric is reward accuracy (chosen > rejected), logged per step.

wandb:
  project: foveated-vlm-final
  run_name: stage3-dpo-135M
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:62a9bc6b203dc3c83f42a1d6b1e90b6a8ac0102db43a7224d73454cfabe56d57
size 742548968
model_code/__init__.py ADDED
@@ -0,0 +1,7 @@
"""Foveated VLM model components."""

from release.model.foveated_vlm import FoveatedVLM
from release.model.encoder import FoveatedEncoder
from release.model.multi_token_vlm import MultiTokenVLM

__all__ = ["FoveatedVLM", "FoveatedEncoder", "MultiTokenVLM"]
model_code/encoder.py ADDED
@@ -0,0 +1,385 @@
"""
FoveatedEncoder -- DINOv2 vision encoder with query-guided cross-attention.

Deep query mode only: the query token is projected into DINO dimension then
propagated through every DINO layer using cached K,V from the patch tokens.
Patches never attend to the query (asymmetric mask), so the patch forward pass
runs once and all K,V are cached. The single query-position output after the
final layer is the foveated visual token.

Key design decisions (pre-fixed bugs baked in):
  * query_input_proj has bias=False (BUG-002: bias dominated small queries,
    causing uniform attention regardless of query content)
  * No shallow mode (BUG-004: single cross-attention on final
    DINO features gives output correlation ~0.98 -- effectively uniform)
  * CLS token is kept (DINO was trained with it)
  * Layer norm applied after all layers (matches DINO forward)

torch.compile friendly:
  * Fixed loop count (num_layers is a Python int constant per model)
  * No Python-level branching in hot paths
  * Attention scale stored as a float constant (not recomputed)
"""

from __future__ import annotations

import math
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Dinov2Model


# ---------------------------------------------------------------------------
# Model configs -- keeps torch.compile happy (loop counts are Python ints)
# ---------------------------------------------------------------------------
DINO_CONFIGS = {
    "facebook/dinov2-small": {"dim": 384, "heads": 6, "layers": 12, "patch_size": 14},
    "facebook/dinov2-base": {"dim": 768, "heads": 12, "layers": 12, "patch_size": 14},
}


class FoveatedEncoder(nn.Module):
    """
    Vision encoder with deep query-guided attention.

    Two-phase usage:
      1. ``patches, kv_cache = encoder.encode_patches(images)``
         Run DINO on all frames, cache K/V at every layer.
      2. ``z = encoder.query_attend(query, kv_cache)``
         Propagate query through all layers using cached K/V.
         Returns a single foveated visual token per image.
    """

    def __init__(
        self,
        dino_model_name: str = "facebook/dinov2-small",
        query_dim: int = 384,
        output_dim: int | None = None,
    ) -> None:
        """
        Args:
            dino_model_name: HuggingFace model id for DINOv2.
            query_dim: Dimension of incoming query vector (from LLM).
            output_dim: Dimension of the output foveated token.
        """
        super().__init__()

        # -- Load pretrained DINOv2 -----------------------------------------
        self.dino: Dinov2Model = Dinov2Model.from_pretrained(dino_model_name)

        # Cache model geometry as plain Python values for torch.compile.
        cfg = self.dino.config
        self.dino_dim: int = cfg.hidden_size
        self.num_heads: int = cfg.num_attention_heads
        self.head_dim: int = self.dino_dim // self.num_heads
        self.num_layers: int = cfg.num_hidden_layers
        self.patch_size: int = cfg.patch_size

        # Pre-compute attention scale as a constant.
        self.attn_scale: float = 1.0 / math.sqrt(self.head_dim)

        # -- Projections ----------------------------------------------------
        if output_dim is None:
            output_dim = self.dino_dim

        # bias=False is CRITICAL (BUG-002). With bias, different queries
        # produce near-identical embeddings at init (bias dominates the small
        # query signal), so attention is uniform and fine == coarse always.
        self.query_input_proj = nn.Linear(query_dim, self.dino_dim, bias=False)
        self.output_proj = nn.Linear(self.dino_dim, output_dim)

        # Dummy buffer for device / dtype inference.
        self.register_buffer("_device_probe", torch.zeros(1), persistent=False)

    # -- Convenience --------------------------------------------------------

    @property
    def device(self) -> torch.device:
        return self._device_probe.device

    def num_patches(self, image_size: int = 224) -> int:
        """Number of spatial patch tokens for a square image (excludes CLS)."""
        grid = image_size // self.patch_size
        return grid * grid

    def num_tokens(self, image_size: int = 224) -> int:
        """Total sequence length from DINO (CLS + spatial patches)."""
        return 1 + self.num_patches(image_size)

    # ======================================================================
    # Phase 1: encode patches (run once per frame set)
    # ======================================================================

    def encode_patches(
        self, images: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Encode images through DINOv2, caching K and V at every layer.

        Args:
            images: ``[B*T, 3, H, W]`` input images (ImageNet-normalised).

        Returns:
            patch_features: ``[B*T, N+1, D]`` final embeddings (CLS + patches),
                after the last layer norm.
            kv_cache: List of ``(K, V)`` tuples, one per DINO layer.
                Each K, V has shape ``[B*T, N+1, D]`` (full dim,
                not yet reshaped to multi-head).
        """
        # Convert to channels_last for better conv performance on tensor cores.
        images = images.to(memory_format=torch.channels_last)
        # Patch + position embedding (includes CLS prepend).
        hidden: torch.Tensor = self.dino.embeddings(images)  # [B*T, N+1, D]

        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []

        # Walk every encoder layer. The loop count (self.num_layers) is a
        # Python int constant, so torch.compile unrolls it -- no graph breaks.
        for layer in self.dino.encoder.layer:
            normed = layer.norm1(hidden)

            # Grab the K, V linear projections on the *normed* input.
            attn_mod = layer.attention.attention  # Dinov2SelfAttention
            K = attn_mod.key(normed)    # [B*T, N+1, D]
            V = attn_mod.value(normed)  # [B*T, N+1, D]
            kv_cache.append((K, V))

            # Full forward for the patch tokens (self-attention + FFN).
            # Patches attend to patches only -- the query is not present yet.
            layer_out = layer(hidden)
            hidden = layer_out[0] if isinstance(layer_out, tuple) else layer_out

        # Final layer norm (matches Dinov2Model.forward).
        patch_features = self.dino.layernorm(hidden)  # [B*T, N+1, D]

        return patch_features, kv_cache

    # ======================================================================
    # Phase 2: query-attend (run per query)
    # ======================================================================

    def query_attend(
        self,
        query: torch.Tensor,
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        return_attention: bool = False,
    ) -> torch.Tensor | Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Propagate a query token through every DINO layer using cached K/V.

        The query can attend to all patch tokens, but patches never see the
        query (asymmetric attention -- enabled by using the cached K/V that
        were computed without the query present).

        Args:
            query: ``[B*T, query_dim]`` query vector from the LLM.
            kv_cache: Output of :meth:`encode_patches` (list of (K, V) per layer).
            return_attention: If True, also return per-layer attention weights.

        Returns:
            z: ``[B*T, output_dim]`` -- the single foveated visual token.
            When ``return_attention`` is True, returns ``(z, attn_weights)``.
        """
        B = query.shape[0]

        # Project query into DINO space.
        q_hidden = self.query_input_proj(query).unsqueeze(1)  # [B, 1, D]

        all_attn_weights = [] if return_attention else None

        # Walk every layer, reusing cached K/V from patches.
        for layer_idx, layer in enumerate(self.dino.encoder.layer):
            K, V = kv_cache[layer_idx]  # each [B, N+1, D]

            attn_mod = layer.attention.attention  # Dinov2SelfAttention

            # Pre-norm for the query token.
            q_normed = layer.norm1(q_hidden)  # [B, 1, D]

            # Q projection for the query token only.
            Q = attn_mod.query(q_normed)  # [B, 1, D]

            # Reshape to multi-head: [B, S, D] -> [B, H, S, d]
            Q = Q.view(B, 1, self.num_heads, self.head_dim).transpose(1, 2)
            K_h = K.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
            V_h = V.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)

            # Scaled dot-product attention (query attends to all patches).
            # Q: [B, H, 1, d], K_h: [B, H, N+1, d], V_h: [B, H, N+1, d]
            if return_attention:
                # Manual path: need explicit weights for visualization.
                attn_scores = torch.matmul(Q, K_h.transpose(-2, -1)) * self.attn_scale
                attn_weights = F.softmax(attn_scores, dim=-1)
                all_attn_weights.append(attn_weights.detach())
                attn_out = torch.matmul(attn_weights, V_h)
            else:
                # SDPA: fused kernel, no intermediate allocations.
                attn_out = F.scaled_dot_product_attention(Q, K_h, V_h)

            # Merge heads: [B, H, 1, d] -> [B, 1, D]
            attn_out = attn_out.transpose(1, 2).contiguous().view(B, 1, self.dino_dim)

            # Output projection + dropout (Dinov2SelfOutput.dense / .dropout).
            attn_out = layer.attention.output.dense(attn_out)
            attn_out = layer.attention.output.dropout(attn_out)

            # Layer scale 1 + residual.
            attn_out = layer.layer_scale1(attn_out)
            q_hidden = q_hidden + attn_out

            # FFN block: norm2 -> MLP -> layer_scale2 -> residual.
            ffn_out = layer.mlp(layer.norm2(q_hidden))
            ffn_out = layer.layer_scale2(ffn_out)
            q_hidden = q_hidden + ffn_out

        # Final layer norm (same norm used at the end of encode_patches).
        q_hidden = self.dino.layernorm(q_hidden)  # [B, 1, D]

        # Squeeze sequence dim and project to output dimension.
        z = self.output_proj(q_hidden.squeeze(1))  # [B, output_dim]

        if return_attention:
            return z, all_attn_weights
        return z

    # ======================================================================
    # Phase 2b: shallow query-attend (single cross-attention on final features)
    # ======================================================================

    def shallow_query_attend(
        self,
        query: torch.Tensor,
        patch_features: torch.Tensor,
    ) -> torch.Tensor:
        """
        Single cross-attention on final DINO features (no layer propagation).

        This is the "shallow" baseline: the query does ONE attention over the
        already-computed final patch embeddings. Different queries produce
        near-identical outputs (BUG-004 validation) because there's no deep
        propagation to amplify query differences.

        Args:
            query: ``[B, query_dim]``
            patch_features: ``[B, N+1, D]`` (output of encode_patches)

        Returns:
            z: ``[B, output_dim]``
        """
        B = query.shape[0]

        # Project query into DINO space.
        q = self.query_input_proj(query).unsqueeze(1)  # [B, 1, D]
        Q = q.view(B, 1, self.num_heads, self.head_dim).transpose(1, 2)

        # Use the last layer's K/V projections for proper attention.
        last_layer = self.dino.encoder.layer[-1]
        attn_mod = last_layer.attention.attention
        normed = last_layer.norm1(patch_features)
        K = attn_mod.key(normed).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = attn_mod.value(normed).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)

        attn_out = F.scaled_dot_product_attention(Q, K, V)  # [B, H, 1, d]

        # Merge heads.
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, 1, self.dino_dim)

        # Output projection + layer norm.
        q_hidden = self.dino.layernorm(attn_out)
        z = self.output_proj(q_hidden.squeeze(1))  # [B, output_dim]
        return z

    # ======================================================================
    # Convenience: full forward (encode + attend in one call)
    # ======================================================================

    def forward(
        self,
        images: torch.Tensor,
        query: torch.Tensor,
    ) -> torch.Tensor:
        """
        Full forward: encode patches then attend with query.

        Args:
            images: ``[B, 3, H, W]``
            query: ``[B, query_dim]``

        Returns:
            z: ``[B, output_dim]`` foveated visual token.
        """
        _, kv_cache = self.encode_patches(images)
        return self.query_attend(query, kv_cache)


# ---------------------------------------------------------------------------
# Self-test
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    print("=" * 60)
    print("Testing FoveatedEncoder (deep query mode)")
    print("=" * 60)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"\nDevice: {device}")

    encoder = FoveatedEncoder(
        dino_model_name="facebook/dinov2-small",
        query_dim=384,
        output_dim=384,
    ).to(device)

    print(f"  dino_dim   = {encoder.dino_dim}")
    print(f"  num_heads  = {encoder.num_heads}")
    print(f"  head_dim   = {encoder.head_dim}")
    print(f"  num_layers = {encoder.num_layers}")
    print(f"  patch_size = {encoder.patch_size}")

    batch_size = 2
    images = torch.randn(batch_size, 3, 224, 224, device=device)
    query_a = torch.randn(batch_size, 384, device=device)
    query_b = torch.randn(batch_size, 384, device=device)

    print(f"\n  num_patches(224) = {encoder.num_patches(224)}")
    print(f"  num_tokens(224)  = {encoder.num_tokens(224)}")

    # -- Phase 1 --
    print("\n--- encode_patches ---")
    patch_features, kv_cache = encoder.encode_patches(images)
    print(f"  patch_features: {patch_features.shape}")
    print(f"  kv_cache: {len(kv_cache)} layers, K shape = {kv_cache[0][0].shape}")

    # -- Phase 2 --
    print("\n--- query_attend ---")
    z_a = encoder.query_attend(query_a, kv_cache)
    z_b = encoder.query_attend(query_b, kv_cache)
    print(f"  z_a: {z_a.shape}")
    print(f"  z_b: {z_b.shape}")

    # Check that different queries give different outputs.
    cosine = F.cosine_similarity(z_a, z_b, dim=-1).mean().item()
    l2_diff = (z_a - z_b).norm(dim=-1).mean().item()
    print(f"  cosine(z_a, z_b) = {cosine:.4f} (should be << 1.0)")
    print(f"  L2 diff          = {l2_diff:.4f} (should be >> 0)")

    # -- Backward --
    print("\n--- backward ---")
    z_a.sum().backward()
    print("  backward: OK")

    # -- Combined forward --
    print("\n--- forward (combined) ---")
    encoder.zero_grad()
    z = encoder(images, query_a)
    z.sum().backward()
    print(f"  z: {z.shape}")
    print("  backward: OK")

    print("\n" + "=" * 60)
    print("All tests passed.")
    print("=" * 60)
model_code/foveated_vlm.py ADDED
@@ -0,0 +1,873 @@
"""
Foveated Vision-Language Model (release implementation).

Architecture: DINOv2 encoder + foveated cross-attention + SmolLM2 LLM.
Each video frame is compressed to ONE visual token via query-guided attention.
The LLM controls WHERE to look by generating the query for the next frame.

Three forward modes:
1. forward_coarse_fine    -- Training (two parallel passes)
2. forward_coarse_only    -- Fast eval (single static-query pass)
3. forward_autoregressive -- True inference (sequential, KV-cached)

Loss: text cross-entropy only (no reconstruction, no VAE).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM
from typing import Dict, Optional


class FoveatedVLM(nn.Module):
    """
    Foveated Vision-Language Model.

    Parameters
    ----------
    llm_name : str
        HuggingFace model id for SmolLM2 (e.g. "HuggingFaceTB/SmolLM2-135M-Instruct").
    dino_name : str
        HuggingFace model id for DINOv2 (e.g. "facebook/dinov2-small").
    query_dim : int
        Dimension of the foveated query vectors (matches the DINO dim by default).
    visual_scale : float
        Multiplicative factor applied to projected visual tokens so their
        magnitude matches the LLM embedding std (~0.14 for SmolLM2).
    lambda_coarse : float
        Weight for the optional auxiliary coarse-pass CE loss during training.
        Set to 0 to disable.
    deep_query : bool
        If True, propagate the query through all DINO layers (deep attention);
        if False, use a single cross-attention over the final patch features.
    """

    def __init__(
        self,
        llm_name: str = "HuggingFaceTB/SmolLM2-135M-Instruct",
        dino_name: str = "facebook/dinov2-small",
        query_dim: int = 384,
        visual_scale: float = 0.14,
        lambda_coarse: float = 0.0,
        deep_query: bool = True,
    ):
        super().__init__()

        # ---- Delayed import so encoder.py can live next to this file ----
        from release.model.encoder import FoveatedEncoder

        # ---- Vision encoder (DINOv2 + query cross-attention) ----
        self.encoder = FoveatedEncoder(
            dino_model_name=dino_name,
            query_dim=query_dim,
            output_dim=None,  # output_dim = dino_dim by default inside encoder
        )
        dino_dim = self.encoder.dino_dim

        # ---- Language model ----
        self.llm = AutoModelForCausalLM.from_pretrained(
            llm_name, attn_implementation="sdpa", torch_dtype=torch.float32,
        )
        self.llm.config.use_cache = False  # training default; overridden per-method
        llm_dim = self.llm.config.hidden_size

        # ---- Projections ----
        self.dino_to_llm = nn.Linear(dino_dim, llm_dim)
        self.llm_to_query = nn.Linear(llm_dim, query_dim)

        # ---- Learnable queries ----
        # BUG-001 FIX: init with std=1.0 so queries dominate over projection
        # bias and produce meaningful (non-uniform) attention patterns.
        self.q_static = nn.Parameter(torch.randn(1, query_dim))  # std=1.0
        self.q_init = nn.Parameter(torch.randn(1, query_dim))    # std=1.0

        # ---- Hyperparams stored as plain Python (not buffers) ----
        self.visual_scale = visual_scale
        self.lambda_coarse = lambda_coarse
        self.query_dim = query_dim
        self.deep_query = deep_query

        # ---- Dimension bookkeeping (useful for external code) ----
        self.dino_dim = dino_dim
        self.llm_dim = llm_dim

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------

    def _get_pad_token_id(self) -> int:
        """Return pad_token_id from the LLM config (never hardcoded)."""
        pid = getattr(self.llm.config, "pad_token_id", None)
        if pid is None:
            pid = getattr(self.llm.config, "eos_token_id", 0)
        return pid

    def _llm_dtype(self) -> torch.dtype:
        """Return the dtype of the LLM parameters (e.g. bfloat16)."""
        return next(self.llm.parameters()).dtype

    def _embed_text(self, input_ids: torch.Tensor) -> torch.Tensor:
        """[B, S] -> [B, S, llm_dim] via the LLM embedding table."""
        return self.llm.get_input_embeddings()(input_ids)

    def _project_visual(self, z: torch.Tensor) -> torch.Tensor:
        """
        Project DINO features to LLM space and rescale.

        z : [B, T, dino_dim] or [B, dino_dim]
        Returns the same shape with last dim = llm_dim.
        """
        h = self.dino_to_llm(z)    # -> llm_dim
        h = h * self.visual_scale  # match LLM embedding magnitude
        return h

    # Maximum frames per DINO encode/query call to prevent OOM on large batches.
    _MAX_ENCODE_CHUNK = 200

    def _encode_all_frames(self, frames: torch.Tensor, frame_mask=None):
        """
        Run DINO patch encoding for every frame in the batch.

        frames     : [B, T, 3, 224, 224]
        frame_mask : [B, T] bool -- True for real frames, False for padding.

        Returns (kv_cache, patch_features, mask_flat):
            kv_cache       : list of (K, V) per layer, each [n_real, N+1, D]
                             (compact -- only real frames, no padding waste).
            patch_features : [n_real, N+1, D] final DINO embeddings (for shallow mode).
            mask_flat      : [B*T] bool tensor or None. Used to scatter results back.
        """
        B, T, C, H, W = frames.shape
        BT = B * T
        frames_flat = frames.reshape(BT, C, H, W)

        if frame_mask is not None:
            mask_flat = frame_mask.reshape(BT)
            n_real = mask_flat.sum().item()
        else:
            mask_flat = None
            n_real = BT

        if mask_flat is not None and n_real < BT:
            real_frames = frames_flat[mask_flat]  # [n_real, C, H, W]
        else:
            real_frames = frames_flat

        # Chunked encoding to prevent OOM on batches with many real frames
        if real_frames.shape[0] <= self._MAX_ENCODE_CHUNK:
            patch_features, kv_cache = self.encoder.encode_patches(real_frames)
        else:
            pf_chunks, kv_chunks = [], []
            for start in range(0, real_frames.shape[0], self._MAX_ENCODE_CHUNK):
                pf_chunk, kv_chunk = self.encoder.encode_patches(
                    real_frames[start:start + self._MAX_ENCODE_CHUNK]
                )
                pf_chunks.append(pf_chunk)
                kv_chunks.append(kv_chunk)
            patch_features = torch.cat(pf_chunks, dim=0)
            kv_cache = [
                (torch.cat([c[li][0] for c in kv_chunks], dim=0),
                 torch.cat([c[li][1] for c in kv_chunks], dim=0))
                for li in range(len(kv_chunks[0]))
            ]

        return kv_cache, patch_features, mask_flat

    def _batched_query_attend(self, queries: torch.Tensor, kv_cache: list,
                              patch_features: torch.Tensor = None) -> torch.Tensor:
        """Chunked query_attend (deep) or shallow_query_attend to prevent OOM."""
        n = queries.shape[0]
        if not self.deep_query:
            # Shallow mode: single cross-attention on final features
            if n <= self._MAX_ENCODE_CHUNK:
                return self.encoder.shallow_query_attend(queries, patch_features)
            chunks = []
            for start in range(0, n, self._MAX_ENCODE_CHUNK):
                end = min(start + self._MAX_ENCODE_CHUNK, n)
                chunks.append(self.encoder.shallow_query_attend(
                    queries[start:end], patch_features[start:end]))
            return torch.cat(chunks, dim=0)
        # Deep mode: propagate through all DINO layers
        if n <= self._MAX_ENCODE_CHUNK:
            return self.encoder.query_attend(queries, kv_cache)
        chunks = []
        for start in range(0, n, self._MAX_ENCODE_CHUNK):
            end = min(start + self._MAX_ENCODE_CHUNK, n)
            kv_slice = [(K[start:end], V[start:end]) for K, V in kv_cache]
            chunks.append(self.encoder.query_attend(queries[start:end], kv_slice))
        return torch.cat(chunks, dim=0)

    def _query_all_frames(
        self, query: torch.Tensor, kv_cache: list,
        B: int, T: int, mask_flat=None, patch_features=None,
    ) -> torch.Tensor:
        """
        Apply a single query to every frame in ONE batched query_attend call.

        query          : [B, query_dim]
        kv_cache       : list of (K, V) per layer, each [n_real, N+1, D]
        B, T           : batch and temporal dimensions
        mask_flat      : [B*T] bool or None
        patch_features : [n_real, N+1, D] (needed for shallow mode)
        Returns        : [B, T, dino_dim]
        """
        BT = B * T
        dd = self.encoder.dino_dim

        # Expand: same query for all T frames -> [B*T, qd]
        query_exp = query.unsqueeze(1).expand(B, T, -1).reshape(BT, -1)

        if mask_flat is not None:
            n_real = mask_flat.sum().item()
            if n_real == 0:
                return torch.zeros(B, T, dd, device=query.device, dtype=query.dtype)
            query_real = query_exp[mask_flat]  # [n_real, qd]
            z_real = self._batched_query_attend(query_real, kv_cache, patch_features)
            z_flat = torch.zeros(BT, dd, device=query.device, dtype=z_real.dtype)
            z_flat[mask_flat] = z_real
        else:
            z_flat = self._batched_query_attend(query_exp, kv_cache, patch_features)

        return z_flat.reshape(B, T, dd)

    def _query_all_frames_batched(
        self, queries: torch.Tensor, kv_cache: list,
        B: int, T: int, mask_flat=None, patch_features=None,
    ) -> torch.Tensor:
        """
        Apply per-frame queries in ONE batched query_attend call.

        queries        : [B, T, query_dim]
        kv_cache       : list of (K, V) per layer, each [n_real, N+1, D]
        B, T           : batch and temporal dimensions
        mask_flat      : [B*T] bool or None
        patch_features : [n_real, N+1, D] (needed for shallow mode)
        Returns        : [B, T, dino_dim]
        """
        BT = B * T
        dd = self.encoder.dino_dim
        queries_flat = queries.reshape(BT, -1)

        if mask_flat is not None:
            n_real = mask_flat.sum().item()
            if n_real == 0:
                return torch.zeros(B, T, dd, device=queries.device, dtype=queries.dtype)
            query_real = queries_flat[mask_flat]  # [n_real, qd]
            z_real = self._batched_query_attend(query_real, kv_cache, patch_features)
            z_flat = torch.zeros(BT, dd, device=queries.device, dtype=z_real.dtype)
            z_flat[mask_flat] = z_real
        else:
            z_flat = self._batched_query_attend(queries_flat, kv_cache, patch_features)

        return z_flat.reshape(B, T, dd)

    def _extract_frame_kv(self, kv_cache: list, mask_flat, B: int, T: int, frame_idx: int):
        """
        Extract a single frame's KV cache from the flat format (for autoregressive/eval).

        Returns list of (K, V) per layer, each [B, N+1, D].
        """
        N1 = kv_cache[0][0].shape[1]
        D = kv_cache[0][0].shape[2]
        frame_kv = []
        if mask_flat is not None:
            # Scatter compact caches to full [B*T], then extract the frame
            for K_real, V_real in kv_cache:
                K_full = torch.zeros(B * T, N1, D, dtype=K_real.dtype, device=K_real.device)
                V_full = torch.zeros(B * T, N1, D, dtype=V_real.dtype, device=V_real.device)
                K_full[mask_flat] = K_real
                V_full[mask_flat] = V_real
                K_t = K_full.reshape(B, T, N1, D)[:, frame_idx]  # [B, N+1, D]
                V_t = V_full.reshape(B, T, N1, D)[:, frame_idx]
                frame_kv.append((K_t, V_t))
        else:
            for K_all, V_all in kv_cache:
                K_t = K_all.reshape(B, T, N1, D)[:, frame_idx]
                V_t = V_all.reshape(B, T, N1, D)[:, frame_idx]
                frame_kv.append((K_t, V_t))
        return frame_kv

    def _build_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """
        Standard causal attention mask [1, 1, S, S] for the LLM.
        True = masked (cannot attend), False = allowed.
        """
        mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).triu(1)
        return mask.unsqueeze(0).unsqueeze(0)  # [1, 1, S, S]

    def _ce_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        loss_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Standard autoregressive CE loss with shift-by-1.

        logits    : [B, S, V] (full sequence logits)
        labels    : [B, S]    (token ids; positions without loss use pad)
        loss_mask : [B, S]    (1 = compute loss, 0 = ignore). Indexed like the
                    unshifted labels, so loss_mask[i] guards label[i].

        Returns scalar loss.
        """
        # Shift: predict position i+1 from position i
        shift_logits = logits[:, :-1, :].contiguous()  # [B, S-1, V]
        shift_labels = labels[:, 1:].contiguous()      # [B, S-1]

        if loss_mask is not None:
            shift_mask = loss_mask[:, 1:].contiguous()  # [B, S-1]
            # Replace masked positions with ignore_index so CE ignores them
            pad_id = self._get_pad_token_id()
            shift_labels = shift_labels.clone()
            shift_labels[shift_mask == 0] = pad_id

        V = shift_logits.shape[-1]
        loss = F.cross_entropy(
            shift_logits.reshape(-1, V),
            shift_labels.reshape(-1),
            ignore_index=self._get_pad_token_id(),
            reduction="mean",
        )
        return loss

    # ------------------------------------------------------------------
    # Forward mode 1: Coarse+Fine (TRAINING)
    # ------------------------------------------------------------------

    def forward_coarse_fine(
        self,
        frames: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: Optional[torch.Tensor] = None,
        frame_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Two-pass parallel training forward.

        Pass 1 (coarse): q_static -> all frames -> z_coarse -> LLM -> dynamic queries
        Pass 2 (fine):   shifted queries -> all frames -> z_fine -> LLM + text -> loss

        Parameters
        ----------
        frames         : [B, T, 3, 224, 224]
        input_ids      : [B, S] tokenized text (prompt + answer)
        attention_mask : [B, S] text attention mask
        loss_mask      : [B, S] which tokens contribute to loss (1=yes, 0=no).
                         If None, all non-pad tokens contribute.
        frame_mask     : [B, T] bool -- True for real frames, False for padding.

        Returns
        -------
        dict with keys: loss, logits, coarse_loss, fine_loss
        """
        B, T = frames.shape[:2]

        # ---- Step 0: Encode all frames (DINO, shared across both passes) ----
        kv_cache, patch_features, mask_flat = self._encode_all_frames(frames, frame_mask)

        # ---- Pass 1: Coarse ----
        q_static = self.q_static.expand(B, -1)  # [B, qd]
        z_coarse = self._query_all_frames(q_static, kv_cache, B, T, mask_flat, patch_features)  # [B, T, dd]
        z_coarse_llm = self._project_visual(z_coarse)  # [B, T, ld]

        # Build coarse sequence: [visual_coarse, text]
        text_embeds = self._embed_text(input_ids)                   # [B, S, ld]
        seq_coarse = torch.cat([z_coarse_llm, text_embeds], dim=1)  # [B, T+S, ld]
        # dtype handled by autocast on GPU; float32 on CPU

        # LLM forward (backbone only, no lm_head yet)
        out_coarse = self.llm.model(inputs_embeds=seq_coarse)
        h_coarse = out_coarse.last_hidden_state  # [B, T+S, ld]

        # Extract dynamic queries from the visual positions.
        # h_coarse[:, 0..T-1] are the hidden states at visual token positions;
        # each one generates a query for the corresponding frame.
        h_visual_coarse = h_coarse[:, :T, :]          # [B, T, ld]
        queries = self.llm_to_query(h_visual_coarse)  # [B, T, qd]

        # Shift queries: frame t gets the query from frame t-1; frame 0 gets q_init
        q_init = self.q_init.expand(B, 1, -1)                          # [B, 1, qd]
        shifted_queries = torch.cat([q_init, queries[:, :-1]], dim=1)  # [B, T, qd]

        # ---- Pass 2: Fine ----
        z_fine = self._query_all_frames_batched(shifted_queries, kv_cache, B, T, mask_flat, patch_features)  # [B, T, dd]
        z_fine_llm = self._project_visual(z_fine)  # [B, T, ld]

        # Build fine sequence: [visual_fine, text]
        seq_fine = torch.cat([z_fine_llm, text_embeds], dim=1)  # [B, T+S, ld]
        # dtype handled by autocast on GPU; float32 on CPU

        out_fine = self.llm.model(inputs_embeds=seq_fine)
        h_fine = out_fine.last_hidden_state  # [B, T+S, ld]

        # Logits over the FULL sequence (visual + text positions)
        logits_full = self.llm.lm_head(h_fine)  # [B, T+S, V]

        # ---- Loss on the text portion only ----
        # Text tokens start at position T in the sequence, so labels aligned
        # with the full sequence give visual positions the pad id.
        pad_id = self._get_pad_token_id()
        visual_pad = torch.full(
            (B, T), pad_id, dtype=input_ids.dtype, device=input_ids.device,
        )
        full_labels = torch.cat([visual_pad, input_ids], dim=1)  # [B, T+S]

        # Build full loss mask: 0 for visual positions, then the provided loss_mask
        if loss_mask is not None:
            visual_no_loss = torch.zeros(
                B, T, dtype=loss_mask.dtype, device=loss_mask.device,
            )
            full_loss_mask = torch.cat([visual_no_loss, loss_mask], dim=1)  # [B, T+S]
        else:
            # Default: compute loss on all non-pad text positions
            visual_no_loss = torch.zeros(B, T, dtype=attention_mask.dtype, device=attention_mask.device)
            full_loss_mask = torch.cat([visual_no_loss, attention_mask], dim=1)

        fine_loss = self._ce_loss(logits_full, full_labels, full_loss_mask)

        # ---- Optional auxiliary coarse loss ----
        coarse_loss = torch.tensor(0.0, device=frames.device)
        if self.lambda_coarse > 0:
            logits_coarse = self.llm.lm_head(h_coarse)
            coarse_loss = self._ce_loss(logits_coarse, full_labels, full_loss_mask)

        # ---- Combined loss ----
        loss = fine_loss + self.lambda_coarse * coarse_loss

        return {
            "loss": loss,
            "fine_loss": fine_loss,
            "coarse_loss": coarse_loss,
            "logits": logits_full,
        }

    # ------------------------------------------------------------------
    # Forward mode: DPO (preference training)
    # ------------------------------------------------------------------

    def forward_dpo(
        self,
        frames: torch.Tensor,
        chosen_input_ids: torch.Tensor,
        chosen_attention_mask: torch.Tensor,
        chosen_loss_mask: torch.Tensor,
        rejected_input_ids: torch.Tensor,
        rejected_attention_mask: torch.Tensor,
        rejected_loss_mask: torch.Tensor,
        frame_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        DPO forward pass: run coarse+fine on both chosen and rejected sequences.

        Shares the DINO encoding across chosen and rejected (same visual input).
        Returns the per-sample sum of log-probabilities for both chosen and
        rejected, masked by loss_mask (answer-only tokens).

        Parameters
        ----------
        frames                  : [B, T, 3, 224, 224]
        chosen_input_ids        : [B, S_c]
        chosen_attention_mask   : [B, S_c]
        chosen_loss_mask        : [B, S_c] (1 = answer token, 0 = prompt/pad)
        rejected_input_ids      : [B, S_r]
        rejected_attention_mask : [B, S_r]
        rejected_loss_mask      : [B, S_r]
        frame_mask              : [B, T] bool (optional)

        Returns
        -------
        dict with keys:
            chosen_logps    : [B] per-sample sum of log-probs on chosen answer tokens
            rejected_logps  : [B] per-sample sum of log-probs on rejected answer tokens
            chosen_logits   : [B, T+S_c, V] full logits for chosen
            rejected_logits : [B, T+S_r, V] full logits for rejected
        """
        B, T = frames.shape[:2]

        # ---- Step 0: Encode all frames (DINO, shared across chosen & rejected) ----
        kv_cache, patch_features, mask_flat = self._encode_all_frames(frames, frame_mask)

        # ---- Coarse pass (shared, used for dynamic query generation) ----
        q_static = self.q_static.expand(B, -1)  # [B, qd]
        z_coarse = self._query_all_frames(q_static, kv_cache, B, T, mask_flat, patch_features)
        z_coarse_llm = self._project_visual(z_coarse)  # [B, T, ld]

        # Run the coarse LLM to get dynamic queries (the chosen text is used
        # for query generation)
        text_embeds_chosen = self._embed_text(chosen_input_ids)  # [B, S_c, ld]
        seq_coarse = torch.cat([z_coarse_llm, text_embeds_chosen], dim=1)
        out_coarse = self.llm.model(inputs_embeds=seq_coarse)
        h_coarse = out_coarse.last_hidden_state

        # Extract dynamic queries from the visual positions
        h_visual_coarse = h_coarse[:, :T, :]          # [B, T, ld]
        queries = self.llm_to_query(h_visual_coarse)  # [B, T, qd]

        q_init = self.q_init.expand(B, 1, -1)
        shifted_queries = torch.cat([q_init, queries[:, :-1]], dim=1)  # [B, T, qd]

        # ---- Fine pass: shared visual features ----
        z_fine = self._query_all_frames_batched(shifted_queries, kv_cache, B, T, mask_flat, patch_features)
        z_fine_llm = self._project_visual(z_fine)  # [B, T, ld]

        # ---- Forward on CHOSEN ----
        seq_chosen = torch.cat([z_fine_llm, text_embeds_chosen], dim=1)  # [B, T+S_c, ld]
        out_chosen = self.llm.model(inputs_embeds=seq_chosen)
        chosen_logits = self.llm.lm_head(out_chosen.last_hidden_state)  # [B, T+S_c, V]

        # ---- Forward on REJECTED ----
        text_embeds_rejected = self._embed_text(rejected_input_ids)  # [B, S_r, ld]
        seq_rejected = torch.cat([z_fine_llm, text_embeds_rejected], dim=1)
        out_rejected = self.llm.model(inputs_embeds=seq_rejected)
        rejected_logits = self.llm.lm_head(out_rejected.last_hidden_state)

        # ---- Per-sample log-probs on answer tokens ----
        chosen_logps = self._sequence_logprobs(
            chosen_logits, chosen_input_ids, chosen_loss_mask, T,
        )
        rejected_logps = self._sequence_logprobs(
            rejected_logits, rejected_input_ids, rejected_loss_mask, T,
        )

        return {
            "chosen_logps": chosen_logps,        # [B]
            "rejected_logps": rejected_logps,    # [B]
            "chosen_logits": chosen_logits,      # [B, T+S_c, V]
            "rejected_logits": rejected_logits,  # [B, T+S_r, V]
        }

    def _sequence_logprobs(
        self,
        logits: torch.Tensor,
        input_ids: torch.Tensor,
        loss_mask: torch.Tensor,
        T: int,
    ) -> torch.Tensor:
        """
        Compute the per-sample sum of log-probabilities on answer tokens.

        logits    : [B, T+S, V] full sequence logits (visual + text)
        input_ids : [B, S]      text token ids
        loss_mask : [B, S]      1.0 for answer tokens, 0.0 otherwise
        T         : int         number of visual token positions

        Returns : [B] sum of log-probs per sample
        """
        # Extract text logits and shift for autoregressive prediction
        text_logits = logits[:, T:, :]         # [B, S, V]
        shift_logits = text_logits[:, :-1, :]  # [B, S-1, V]
        shift_labels = input_ids[:, 1:]        # [B, S-1]
        shift_mask = loss_mask[:, 1:]          # [B, S-1]

        # Per-token log-probs: log_softmax, then gather the label's prob
        log_probs = F.log_softmax(shift_logits, dim=-1)  # [B, S-1, V]
        per_token_logps = log_probs.gather(
            dim=-1, index=shift_labels.unsqueeze(-1),
        ).squeeze(-1)  # [B, S-1]

        # Mask and sum per sample
        per_token_logps = per_token_logps * shift_mask  # zero out non-answer tokens
        return per_token_logps.sum(dim=-1)  # [B]

    # ------------------------------------------------------------------
    # Forward mode 2: Coarse only (FAST EVAL)
    # ------------------------------------------------------------------

    def forward_coarse_only(
        self,
        frames: torch.Tensor,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        loss_mask: Optional[torch.Tensor] = None,
        frame_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Single-pass coarse forward (q_static only, no fine queries).

        Used for:
        - Training the A6 ablation (coarse-only training)
        - Fast eval (wrap in torch.no_grad() externally)

        q_static -> all frames -> z_coarse -> LLM -> logits.

        Parameters
        ----------
        frames         : [B, T, 3, 224, 224]
        input_ids      : [B, S] (optional, for loss computation)
        attention_mask : [B, S] (optional)
        loss_mask      : [B, S] (optional)
        frame_mask     : [B, T] bool (optional)

        Returns
        -------
        dict with keys: logits, and optionally loss
        """
        B, T = frames.shape[:2]

        kv_cache, patch_features, mask_flat = self._encode_all_frames(frames, frame_mask)

        q_static = self.q_static.expand(B, -1)
        z_coarse = self._query_all_frames(q_static, kv_cache, B, T, mask_flat, patch_features)
        z_coarse_llm = self._project_visual(z_coarse)

        if input_ids is not None:
            text_embeds = self._embed_text(input_ids)
            seq = torch.cat([z_coarse_llm, text_embeds], dim=1)
        else:
            seq = z_coarse_llm
        # dtype handled by autocast on GPU; float32 on CPU

        out = self.llm.model(inputs_embeds=seq)
        logits = self.llm.lm_head(out.last_hidden_state)

        result: Dict[str, torch.Tensor] = {"logits": logits}

        if input_ids is not None:
            pad_id = self._get_pad_token_id()
            visual_pad = torch.full(
                (B, T), pad_id, dtype=input_ids.dtype, device=input_ids.device,
            )
            full_labels = torch.cat([visual_pad, input_ids], dim=1)

            if loss_mask is not None:
                visual_no_loss = torch.zeros(
                    B, T, dtype=loss_mask.dtype, device=loss_mask.device,
                )
                full_loss_mask = torch.cat([visual_no_loss, loss_mask], dim=1)
            elif attention_mask is not None:
                visual_no_loss = torch.zeros(
                    B, T, dtype=attention_mask.dtype, device=attention_mask.device,
                )
                full_loss_mask = torch.cat([visual_no_loss, attention_mask], dim=1)
            else:
                full_loss_mask = None

            loss = self._ce_loss(logits, full_labels, full_loss_mask)
            result["loss"] = loss
            result["coarse_loss"] = loss
            result["fine_loss"] = torch.tensor(0.0, device=frames.device)

        return result

    # ------------------------------------------------------------------
    # Forward mode 3: Autoregressive (TRUE INFERENCE)
    # ------------------------------------------------------------------

    @torch.no_grad()
    def forward_autoregressive(
        self,
        frames: torch.Tensor,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        loss_mask: Optional[torch.Tensor] = None,
        frame_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        True autoregressive inference: sequential frame-by-frame with KV cache.

        q_init -> frame_1 -> z_1 -> LLM -> q_1 -> frame_2 -> z_2 -> ...

        No coarse pass. Each query is derived from the LLM hidden state after
        processing the *previous* fine visual token -- exactly what happens at
        real inference time.

        Parameters
        ----------
        frames         : [B, T, 3, 224, 224]
        input_ids      : [B, S] (optional, for loss computation)
        attention_mask : [B, S] (optional)
        loss_mask      : [B, S] (optional)
        frame_mask     : [B, T] bool (optional)

        Returns
        -------
        dict with keys: logits, and optionally loss
        """
        B, T = frames.shape[:2]
        device = frames.device

        # Encode all frames with DINO up front (this is fine -- DINO encoding
        # does not depend on the query; only query_attend does).
        kv_cache, patch_features, mask_flat = self._encode_all_frames(frames, frame_mask)

        # Enable the KV cache on the LLM for incremental decoding
        orig_use_cache = self.llm.config.use_cache
        self.llm.config.use_cache = True

        query = self.q_init.expand(B, -1)  # [B, qd]
        llm_past_kv = None

        for t in range(T):
            # Foveated extraction with the current query
            frame_kv = self._extract_frame_kv(kv_cache, mask_flat, B, T, t)
            z_t = self.encoder.query_attend(query, frame_kv)  # [B, dd]
            z_t_llm = self._project_visual(z_t.unsqueeze(1))  # [B, 1, ld]
            # dtype handled by autocast on GPU; float32 on CPU

            # Incremental LLM forward (one visual token at a time)
            out = self.llm.model(
                inputs_embeds=z_t_llm,
                past_key_values=llm_past_kv,
                use_cache=True,
            )
            llm_past_kv = out.past_key_values

            # Derive the query for the NEXT frame from the current hidden state
            if t < T - 1:
                h_t = out.last_hidden_state[:, -1, :]  # [B, ld]
                query = self.llm_to_query(h_t)         # [B, qd]

        # ---- Process text (if provided) using the accumulated KV cache ----
        if input_ids is not None:
            text_embeds = self._embed_text(input_ids)  # [B, S, ld]

            out_text = self.llm.model(
                inputs_embeds=text_embeds,
                past_key_values=llm_past_kv,
                use_cache=False,
            )
            # The KV cache holds the T visual positions; out_text holds the S
            # text positions. We only need logits over the text portion, plus
            # the last visual token (which predicts the first text token).
            h_text = out_text.last_hidden_state     # [B, S, ld]
            logits_text = self.llm.lm_head(h_text)  # [B, S, V]

            # Logit at the last visual position (predicts the first text token)
            h_last_visual = out.last_hidden_state[:, -1:, :]  # [B, 1, ld]
            logits_last_v = self.llm.lm_head(h_last_visual)   # [B, 1, V]

            # Full logits over [last_visual, text] = [B, 1+S, V]
            logits = torch.cat([logits_last_v, logits_text], dim=1)

            # Labels: [pad_for_last_visual, input_ids]
            pad_id = self._get_pad_token_id()
            lv_pad = torch.full(
                (B, 1), pad_id, dtype=input_ids.dtype, device=device,
            )
            full_labels = torch.cat([lv_pad, input_ids], dim=1)

            # Loss mask
            if loss_mask is not None:
                lv_no_loss = torch.zeros(
                    B, 1, dtype=loss_mask.dtype, device=device,
                )
                full_loss_mask = torch.cat([lv_no_loss, loss_mask], dim=1)
            elif attention_mask is not None:
                lv_no_loss = torch.zeros(
                    B, 1, dtype=attention_mask.dtype, device=device,
                )
                full_loss_mask = torch.cat([lv_no_loss, attention_mask], dim=1)
            else:
                full_loss_mask = None

            loss = self._ce_loss(logits, full_labels, full_loss_mask)

            self.llm.config.use_cache = orig_use_cache
            return {"loss": loss, "logits": logits}

        else:
            # No text -- just return logits at the last visual position
            h_last = out.last_hidden_state  # [B, 1, ld]
            logits = self.llm.lm_head(h_last)
            self.llm.config.use_cache = orig_use_cache
            return {"logits": logits}

    # ------------------------------------------------------------------
    # Convenience: unified forward dispatching by name
    # ------------------------------------------------------------------

    def forward(
        self,
        frames: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        loss_mask: Optional[torch.Tensor] = None,
        frame_mask: Optional[torch.Tensor] = None,
        mode: str = "coarse_fine",
    ) -> Dict[str, torch.Tensor]:
        """
        Unified forward entry point.

        mode       : "coarse_fine" | "coarse_only" | "autoregressive"
        frame_mask : [B, T] bool -- True for real frames, False for padding.
        """
        if mode == "coarse_fine":
            return self.forward_coarse_fine(frames, input_ids, attention_mask, loss_mask, frame_mask)
        elif mode == "coarse_only":
            return self.forward_coarse_only(frames, input_ids, attention_mask, loss_mask, frame_mask)
        elif mode == "autoregressive":
            return self.forward_autoregressive(frames, input_ids, attention_mask, loss_mask, frame_mask)
        else:
            raise ValueError(
                f"Unknown forward mode '{mode}'. "
                "Expected one of: coarse_fine, coarse_only, autoregressive"
            )

    # ------------------------------------------------------------------
    # Utility methods for external callers (train.py, eval.py)
    # ------------------------------------------------------------------

    def enable_gradient_checkpointing(self) -> None:
        """Turn on activation checkpointing for the LLM and DINO."""
        self.llm.gradient_checkpointing_enable()
        if hasattr(self.encoder.dino, "gradient_checkpointing_enable"):
            self.encoder.dino.gradient_checkpointing_enable()

    def get_param_groups(
        self,
        lr_backbone: float = 1e-5,
        lr_connector: float = 1e-4,
    ) -> list:
        """
        Return parameter groups with differential learning rates.

        Groups:
        1. Connector (dino_to_llm, llm_to_query, q_static, q_init) -- highest LR
        2. DINO encoder -- backbone LR
        3. LLM          -- backbone LR

        This is a suggestion; train.py may override.
        """
        connector_params = set()
        for name, param in self.named_parameters():
            if any(k in name for k in [
                "dino_to_llm", "llm_to_query", "q_static", "q_init",
                "query_input_proj", "query_output_proj",
            ]):
                connector_params.add(id(param))

        encoder_params = set()
        for name, param in self.encoder.named_parameters():
            if id(param) not in connector_params:
                encoder_params.add(id(param))

        groups = [
            {
                "params": [p for p in self.parameters()
                           if id(p) in connector_params and p.requires_grad],
                "lr": lr_connector,
                "name": "connector",
            },
            {
                "params": [p for n, p in self.encoder.named_parameters()
                           if id(p) in encoder_params and p.requires_grad],
                "lr": lr_backbone,
                "name": "dino",
            },
            {
                "params": [p for p in self.llm.parameters() if p.requires_grad],
                "lr": lr_backbone,
                "name": "llm",
            },
        ]
        return [g for g in groups if len(g["params"]) > 0]
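

# ----------------------------------------------------------------------
# Illustrative sketch (not part of the release API): a standalone toy
# check of the shift-by-1 / loss-mask bookkeeping that _ce_loss and
# _sequence_logprobs rely on. It uses random tensors only -- no model
# weights, no network -- so it can be run directly as a sanity check.
# The helper name below is purely illustrative.
# ----------------------------------------------------------------------


def _demo_shift_mask_bookkeeping() -> torch.Tensor:
    """Toy replica of the shift/mask logic in _ce_loss / _sequence_logprobs."""
    torch.manual_seed(0)
    B, S, V, pad_id = 2, 5, 11, 0
    logits = torch.randn(B, S, V)
    labels = torch.randint(1, V, (B, S))  # labels never equal pad_id here
    loss_mask = torch.zeros(B, S)
    loss_mask[:, -2:] = 1.0               # pretend the last 2 tokens are the answer

    # Shift: position i predicts token i+1 (same alignment as _ce_loss)
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    shift_mask = loss_mask[:, 1:]

    # Per-token log-probs of the labels (as in _sequence_logprobs)
    per_token = F.log_softmax(shift_logits, dim=-1).gather(
        -1, shift_labels.unsqueeze(-1),
    ).squeeze(-1)                                  # [B, S-1]
    seq_logps = (per_token * shift_mask).sum(-1)   # [B] answer-only sums

    # CE with masked labels set to ignore_index (as in _ce_loss) must equal
    # the negative mean of the kept per-token log-probs.
    masked_labels = shift_labels.masked_fill(shift_mask == 0, pad_id)
    ce = F.cross_entropy(
        shift_logits.reshape(-1, V), masked_labels.reshape(-1),
        ignore_index=pad_id,
    )
    expected = -(per_token * shift_mask).sum() / shift_mask.sum()
    assert torch.allclose(ce, expected, atol=1e-6)
    return seq_logps


if __name__ == "__main__":
    print("shift/mask bookkeeping OK:", _demo_shift_mask_bookkeeping())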