| | """ |
| | FoveatedEncoder -- DINOv2 vision encoder with query-guided cross-attention. |
| | |
| | Deep query mode only: the query token is projected into DINO dimension then |
| | propagated through every DINO layer using cached K,V from the patch tokens. |
| | Patches never attend to the query (asymmetric mask), so the patch forward pass |
| | runs once and all K,V are cached. The single query-position output after the |
| | final layer is the foveated visual token. |
| | |
Key design decisions (informed by previously fixed bugs):
* query_input_proj has bias=False (BUG-002: a bias term dominated small
  queries, causing uniform attention regardless of query content)
* Deep propagation is the primary path; shallow_query_attend is kept only as
  a diagnostic baseline (BUG-004: a single cross-attention on final
  DINO features gives output correlation ~0.98 -- effectively uniform)
| | * CLS token is kept (DINO was trained with it) |
| | * Layer norm applied after all layers (matches DINO forward) |
| | |
| | torch.compile friendly: |
| | * Fixed loop count (num_layers is a Python int constant per model) |
| | * No Python-level branching in hot paths |
| | * Attention scale stored as a float constant (not recomputed) |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import math |
| | from typing import List, Tuple |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import Dinov2Model |
| |
|
| |
|
| | |
| | |
| | |
# Reference geometry for the supported DINOv2 checkpoints.
# NOTE(review): the encoder reads these values from the loaded HF config at
# runtime, so this table appears to serve as documentation / external lookup.
DINO_CONFIGS = {
    "facebook/dinov2-small": {
        "dim": 384,
        "heads": 6,
        "layers": 12,
        "patch_size": 14,
    },
    "facebook/dinov2-base": {
        "dim": 768,
        "heads": 12,
        "layers": 12,
        "patch_size": 14,
    },
}
| |
|
| |
|
class FoveatedEncoder(nn.Module):
    """
    DINOv2 vision encoder with deep query-guided cross-attention.

    Two-phase usage:
        1. ``patches, kv_cache = encoder.encode_patches(images)``
           Run DINO on all frames, caching K/V at every layer.
        2. ``z = encoder.query_attend(query, kv_cache)``
           Propagate the query through all layers using the cached K/V.
           Returns a single foveated visual token per image.

    The query attends to all patch tokens, but patches never attend to the
    query (asymmetric attention), so the patch forward pass runs exactly once
    per image regardless of how many queries are issued.
    """

    def __init__(
        self,
        dino_model_name: str = "facebook/dinov2-small",
        query_dim: int = 384,
        output_dim: int | None = None,
    ) -> None:
        """
        Args:
            dino_model_name: HuggingFace model id for DINOv2.
            query_dim: Dimension of the incoming query vector (from the LLM).
            output_dim: Dimension of the output foveated token; defaults to
                the DINO hidden size when None.
        """
        super().__init__()

        self.dino: Dinov2Model = Dinov2Model.from_pretrained(dino_model_name)

        # Read geometry from the loaded config (not the DINO_CONFIGS table)
        # so any checkpoint works, not only the ones listed there.
        cfg = self.dino.config
        self.dino_dim: int = cfg.hidden_size
        self.num_heads: int = cfg.num_attention_heads
        self.head_dim: int = self.dino_dim // self.num_heads
        self.num_layers: int = cfg.num_hidden_layers
        self.patch_size: int = cfg.patch_size

        # Stored once as a plain float so it is a constant under
        # torch.compile (never recomputed in the hot path).
        self.attn_scale: float = 1.0 / math.sqrt(self.head_dim)

        if output_dim is None:
            output_dim = self.dino_dim

        # bias=False is deliberate (BUG-002): with a bias, small query
        # vectors were dominated by the constant term and attention became
        # uniform regardless of query content.
        self.query_input_proj = nn.Linear(query_dim, self.dino_dim, bias=False)
        self.output_proj = nn.Linear(self.dino_dim, output_dim)

        # Non-persistent buffer whose only job is to track which device the
        # module currently lives on (see the ``device`` property).
        self.register_buffer("_device_probe", torch.zeros(1), persistent=False)

    @property
    def device(self) -> torch.device:
        """Device the module's parameters/buffers currently live on."""
        return self._device_probe.device

    def num_patches(self, image_size: int = 224) -> int:
        """Number of spatial patch tokens for a square image (excludes CLS)."""
        grid = image_size // self.patch_size
        return grid * grid

    def num_tokens(self, image_size: int = 224) -> int:
        """Total sequence length from DINO (CLS + spatial patches)."""
        return 1 + self.num_patches(image_size)

    def encode_patches(
        self, images: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Encode images through DINOv2, caching K and V at every layer.

        Args:
            images: ``[B*T, 3, H, W]`` input images (ImageNet-normalised).

        Returns:
            patch_features: ``[B*T, N+1, D]`` final embeddings (CLS +
                patches), after the last layer norm.
            kv_cache: List of ``(K, V)`` tuples, one per DINO layer.
                Each K, V has shape ``[B*T, N+1, D]`` (full dim, not yet
                reshaped to multi-head).
        """
        # channels_last generally helps conv/patch-embed throughput on CUDA.
        images = images.to(memory_format=torch.channels_last)
        hidden: torch.Tensor = self.dino.embeddings(images)

        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []

        for layer in self.dino.encoder.layer:
            # K/V must come from the *pre-attention* normed hidden state --
            # exactly what the layer's own attention sees.  The layer
            # forward below recomputes them internally; the duplicated
            # projection is the price of not monkey-patching the HF module.
            normed = layer.norm1(hidden)
            attn_mod = layer.attention.attention
            K = attn_mod.key(normed)
            V = attn_mod.value(normed)
            kv_cache.append((K, V))

            # Standard DINO layer forward (self-attention + MLP residuals).
            layer_out = layer(hidden)
            hidden = layer_out[0] if isinstance(layer_out, tuple) else layer_out

        # Final layer norm, matching Dinov2Model.forward.
        patch_features = self.dino.layernorm(hidden)
        return patch_features, kv_cache

    def query_attend(
        self,
        query: torch.Tensor,
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        return_attention: bool = False,
    ) -> torch.Tensor | Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Propagate a query token through every DINO layer using cached K/V.

        The query can attend to all patch tokens, but patches never see the
        query (asymmetric attention -- enabled by using the cached K/V that
        were computed without the query present).

        Args:
            query: ``[B*T, query_dim]`` query vector from the LLM.
            kv_cache: Output of :meth:`encode_patches` (list of (K, V) per
                layer).
            return_attention: When True, also return the per-layer attention
                weights (detached) for inspection.

        Returns:
            z: ``[B*T, output_dim]`` -- the single foveated visual token.
            When ``return_attention`` is True, returns ``(z, attn_weights)``
            where ``attn_weights`` is a list of ``[B*T, H, 1, N+1]`` tensors,
            one per layer.  (Fixed: the annotation previously claimed a bare
            tensor was always returned.)
        """
        B = query.shape[0]

        # Project into DINO dimension; add a singleton sequence axis.
        q_hidden = self.query_input_proj(query).unsqueeze(1)

        all_attn_weights = [] if return_attention else None

        for layer_idx, layer in enumerate(self.dino.encoder.layer):
            K, V = kv_cache[layer_idx]
            attn_mod = layer.attention.attention

            # Pre-norm, then the layer's own query projection.
            q_normed = layer.norm1(q_hidden)
            Q = attn_mod.query(q_normed)

            # [B, 1, D] -> [B, H, 1, d]; cached K/V -> [B, H, N+1, d].
            Q = Q.view(B, 1, self.num_heads, self.head_dim).transpose(1, 2)
            K_h = K.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
            V_h = V.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)

            if return_attention:
                # Explicit softmax path so the weights can be captured.
                attn_scores = torch.matmul(Q, K_h.transpose(-2, -1)) * self.attn_scale
                attn_weights = F.softmax(attn_scores, dim=-1)
                all_attn_weights.append(attn_weights.detach())
                attn_out = torch.matmul(attn_weights, V_h)
            else:
                # Fused kernel; default scale is 1/sqrt(head_dim), matching
                # self.attn_scale.
                attn_out = F.scaled_dot_product_attention(Q, K_h, V_h)

            # [B, H, 1, d] -> [B, 1, D]
            attn_out = attn_out.transpose(1, 2).contiguous().view(B, 1, self.dino_dim)

            # Attention output projection + dropout, as in the DINO block.
            attn_out = layer.attention.output.dense(attn_out)
            attn_out = layer.attention.output.dropout(attn_out)

            # Residual branch 1: layer-scaled attention output.
            attn_out = layer.layer_scale1(attn_out)
            q_hidden = q_hidden + attn_out

            # Residual branch 2: layer-scaled MLP.
            ffn_out = layer.mlp(layer.norm2(q_hidden))
            ffn_out = layer.layer_scale2(ffn_out)
            q_hidden = q_hidden + ffn_out

        # Final layer norm, matching the patch path.
        q_hidden = self.dino.layernorm(q_hidden)

        # Drop the sequence axis and project to the output dimension.
        z = self.output_proj(q_hidden.squeeze(1))

        if return_attention:
            return z, all_attn_weights
        return z

    def shallow_query_attend(
        self,
        query: torch.Tensor,
        patch_features: torch.Tensor,
    ) -> torch.Tensor:
        """
        Single cross-attention on final DINO features (no layer propagation).

        This is the "shallow" baseline: the query does ONE attention over the
        already-computed final patch embeddings.  Different queries produce
        near-identical outputs (BUG-004 validation) because there's no deep
        propagation to amplify query differences.

        Args:
            query: ``[B, query_dim]``
            patch_features: ``[B, N+1, D]`` (output of encode_patches)

        Returns:
            z: ``[B, output_dim]``
        """
        B = query.shape[0]

        # Project the query into DINO dimension and split into heads.
        q = self.query_input_proj(query).unsqueeze(1)
        Q = q.view(B, 1, self.num_heads, self.head_dim).transpose(1, 2)

        # K/V come from the last layer's key/value projections over the
        # re-normed final features.  (Fixed: dead code previously built K
        # directly from patch_features and V = K.clone(), then immediately
        # overwrote both -- wasted compute with no effect on the result.)
        last_layer = self.dino.encoder.layer[-1]
        attn_mod = last_layer.attention.attention
        normed = last_layer.norm1(patch_features)
        K = attn_mod.key(normed).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = attn_mod.value(normed).view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)

        attn_out = F.scaled_dot_product_attention(Q, K, V)

        # [B, H, 1, d] -> [B, 1, D]
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, 1, self.dino_dim)

        # Final norm + output projection, mirroring the deep path's tail.
        q_hidden = self.dino.layernorm(attn_out)
        z = self.output_proj(q_hidden.squeeze(1))
        return z

    def forward(
        self,
        images: torch.Tensor,
        query: torch.Tensor,
    ) -> torch.Tensor:
        """
        Full forward: encode patches then attend with the query.

        Args:
            images: ``[B, 3, H, W]``
            query: ``[B, query_dim]``

        Returns:
            z: ``[B, output_dim]`` foveated visual token.
        """
        _, kv_cache = self.encode_patches(images)
        return self.query_attend(query, kv_cache)
| |
|
| |
|
| | |
| | |
| | |
| | if __name__ == "__main__": |
| | print("=" * 60) |
| | print("Testing FoveatedEncoder (deep query mode)") |
| | print("=" * 60) |
| |
|
| | device = "cuda" if torch.cuda.is_available() else "cpu" |
| | print(f"\nDevice: {device}") |
| |
|
| | encoder = FoveatedEncoder( |
| | dino_model_name="facebook/dinov2-small", |
| | query_dim=384, |
| | output_dim=384, |
| | ).to(device) |
| |
|
| | print(f" dino_dim = {encoder.dino_dim}") |
| | print(f" num_heads = {encoder.num_heads}") |
| | print(f" head_dim = {encoder.head_dim}") |
| | print(f" num_layers = {encoder.num_layers}") |
| | print(f" patch_size = {encoder.patch_size}") |
| |
|
| | batch_size = 2 |
| | images = torch.randn(batch_size, 3, 224, 224, device=device) |
| | query_a = torch.randn(batch_size, 384, device=device) |
| | query_b = torch.randn(batch_size, 384, device=device) |
| |
|
| | print(f"\n num_patches(224) = {encoder.num_patches(224)}") |
| | print(f" num_tokens(224) = {encoder.num_tokens(224)}") |
| |
|
| | |
| | print("\n--- encode_patches ---") |
| | patch_features, kv_cache = encoder.encode_patches(images) |
| | print(f" patch_features: {patch_features.shape}") |
| | print(f" kv_cache: {len(kv_cache)} layers, K shape = {kv_cache[0][0].shape}") |
| |
|
| | |
| | print("\n--- query_attend ---") |
| | z_a = encoder.query_attend(query_a, kv_cache) |
| | z_b = encoder.query_attend(query_b, kv_cache) |
| | print(f" z_a: {z_a.shape}") |
| | print(f" z_b: {z_b.shape}") |
| |
|
| | |
| | cosine = F.cosine_similarity(z_a, z_b, dim=-1).mean().item() |
| | l2_diff = (z_a - z_b).norm(dim=-1).mean().item() |
| | print(f" cosine(z_a, z_b) = {cosine:.4f} (should be << 1.0)") |
| | print(f" L2 diff = {l2_diff:.4f} (should be >> 0)") |
| |
|
| | |
| | print("\n--- backward ---") |
| | z_a.sum().backward() |
| | print(" backward: OK") |
| |
|
| | |
| | print("\n--- forward (combined) ---") |
| | encoder.zero_grad() |
| | z = encoder(images, query_a) |
| | z.sum().backward() |
| | print(f" z: {z.shape}") |
| | print(" backward: OK") |
| |
|
| | print("\n" + "=" * 60) |
| | print("All tests passed.") |
| | print("=" * 60) |
| |
|