# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def _window_starts(length: int, window_size: int) -> list[int]:
    """Return the minimal set of window offsets that tile ``[0, length)`` along one axis.

    Offsets are placed every ``window_size`` pixels; the last one is pinned to ``length - window_size``
    so the final window ends exactly at the border (it may overlap its predecessor). Callers must ensure
    ``length >= window_size``.
    """
    num_windows = (length + window_size - 1) // window_size
    return [index * window_size for index in range(num_windows - 1)] + [length - window_size]


def _to_uint8_hwc(image: torch.Tensor):
    """Convert a ``(C, H, W)`` tensor into an ``(H, W, C)`` uint8 NumPy array.

    Images whose maximum is ``<= 1.0`` are treated as ``[0, 1]`` and rescaled to ``[0, 255]``; anything
    else is treated as already being in ``[0, 255]``. Values are clipped before the uint8 cast.
    """
    image_np = image.detach().permute(1, 2, 0).cpu().numpy()
    if image.max() <= 1.0:
        image_np = image_np * 255
    return image_np.clip(0, 255).astype("uint8")


def extract_patch_tokens_min_windows(
    images: torch.Tensor,
    model: nn.Module,
    processor,
    window_size: int = 224,
    device: str | torch.device = "cuda",
) -> torch.Tensor:
    r"""
    Tile each image with a minimal window set and return averaged DINO patch tokens.

    Each image is covered by the fewest `window_size` x `window_size` windows that span it (the last
    window on each axis is shifted back to end at the border, so windows may overlap). Every window is
    encoded by `model` independently and per-patch tokens from overlapping windows are averaged.

    Args:
        images (`torch.Tensor`):
            Batch of RGB images `(B, C, H, W)`, in `[0, 1]` (detected via `max() <= 1`) or `[0, 255]`.
        model:
            DINO vision transformer. Must expose `config.hidden_size` and `config.patch_size`; the first
            token of `last_hidden_state` is dropped (presumably the CLS token — models that prepend extra
            register tokens would need a different offset).
        processor:
            Hugging Face image processor for DINO, called on each uint8 `(window_size, window_size, 3)` crop.
            Assumed to yield a `window_size`-sized model input without further resizing/cropping — TODO
            confirm against the processor configuration in use.
        window_size (`int`):
            Sliding-window size in pixels.
        device:
            Device for intermediate tensors and model inputs.

    Returns:
        `torch.Tensor` of shape `(B, H // patch, W // patch, hidden_size)`.

    Raises:
        ValueError: If `H` or `W` is smaller than `window_size`.
    """
    _, _, height, width = images.shape
    if height < window_size or width < window_size:
        # A negative window offset would slice the wrong pixels and later fail with an opaque
        # shape mismatch while accumulating tokens.
        raise ValueError(f"Images must be at least {window_size}x{window_size} pixels, got {height}x{width}.")

    hidden_size = model.config.hidden_size
    patch_size = model.config.patch_size
    grid_h, grid_w = height // patch_size, width // patch_size
    tokens_per_side = window_size // patch_size
    y_positions = _window_starts(height, window_size)
    x_positions = _window_starts(width, window_size)

    token_avgs = []
    for image in images:
        image_np = _to_uint8_hwc(image)
        token_sum = torch.zeros((grid_h, grid_w, hidden_size), device=device)
        token_count = torch.zeros((grid_h, grid_w, 1), device=device)

        for y in y_positions:
            for x in x_positions:
                patch = image_np[y : y + window_size, x : x + window_size, :]
                inputs = processor(images=patch, return_tensors="pt").to(device)
                with torch.no_grad():
                    outputs = model(**inputs)
                # Drop the leading token and lay the remaining tokens out on the window's patch grid.
                patch_tokens = outputs.last_hidden_state[:, 1:, :].reshape(
                    tokens_per_side, tokens_per_side, hidden_size
                )
                # Offsets that are not patch-aligned are floored, so the border window's tokens can be
                # shifted by up to `patch_size - 1` pixels relative to the image.
                y0, x0 = y // patch_size, x // patch_size
                y1, x1 = y0 + tokens_per_side, x0 + tokens_per_side
                token_sum[y0:y1, x0:x1, :] += patch_tokens
                token_count[y0:y1, x0:x1, 0] += 1

        token_avgs.append(token_sum / token_count)

    if not token_avgs:
        # Empty batch: `torch.stack` cannot stack an empty list.
        return torch.zeros((0, grid_h, grid_w, hidden_size), device=device)
    return torch.stack(token_avgs, dim=0)


class LayerNorm2d(nn.Module):
    r"""
    Layer normalization over the channel dimension of a `(B, C, H, W)` tensor.

    Each spatial position is normalized independently across its `C` channels by moving channels last,
    applying `nn.LayerNorm`, and moving them back.
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.norm = nn.LayerNorm([channels])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 3, 1)  # (B, H, W, C)
        x = self.norm(x)
        return x.permute(0, 3, 1, 2)  # back to (B, C, H, W)


class IMAA(nn.Module):
    r"""
    Intrinsic Map-Aware Attention (IMAA) gating module.

    Produces per-map attention biases from DINO patch tokens and learnable map embeddings. Patch tokens
    are projected to `common_dim` with a 1x1 conv; the embedding of each requested map is projected to
    `common_dim` and broadcast over the token grid; both are concatenated and fused, and a small conv head
    reduces the result to a single-channel map that is passed through a sigmoid.

    Args:
        dino_model: Optional DINO backbone used when `forward` receives raw images. Its parameters are
            frozen and it is kept in eval mode even while this module is in training mode.
        processor: Hugging Face image processor paired with `dino_model`.
        num_maps: Number of learnable map embeddings.
        map_embedding_dim: Size of each map embedding.
        common_dim: Channel width of the fused feature map.
        conv_channels: Hidden channel widths of the conv head (`[128, 64]` when `None`; an explicit empty
            list yields a head with only the final 1x1 conv).
        dino_patch_dim: Channel size of the DINO patch tokens.
        window_size: Sliding-window size used when extracting tokens from raw images.
    """

    def __init__(
        self,
        dino_model: Optional[nn.Module] = None,
        processor=None,
        num_maps: int = 5,
        map_embedding_dim: int = 256,
        common_dim: int = 128,
        conv_channels: Optional[list[int]] = None,
        dino_patch_dim: int = 768,
        window_size: int = 224,
    ) -> None:
        super().__init__()
        conv_channels = [128, 64] if conv_channels is None else list(conv_channels)

        self.dino = dino_model
        self.processor = processor
        if self.dino is not None:
            self.dino.eval()
            for param in self.dino.parameters():
                param.requires_grad = False

        self.num_maps = num_maps
        self.map_embedding_dim = map_embedding_dim
        self.common_dim = common_dim
        self.dino_patch_dim = dino_patch_dim
        self.window_size = window_size

        self.map_embedding = nn.Parameter(torch.randn(num_maps, map_embedding_dim))
        self.dino_proj = nn.Conv2d(dino_patch_dim, common_dim, kernel_size=1)
        self.map_proj = nn.Linear(map_embedding_dim, common_dim)
        self.fusion_layer = nn.Sequential(
            nn.Conv2d(common_dim * 2, common_dim, 1),
            LayerNorm2d(common_dim),
            nn.ReLU(),
            nn.Conv2d(common_dim, common_dim, 3, padding=1),
        )

        conv_layers: list[nn.Module] = []
        in_channels = common_dim
        for out_channels in conv_channels:
            conv_layers.extend([nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU()])
            in_channels = out_channels
        conv_layers.append(nn.Conv2d(in_channels, 1, kernel_size=1))
        self.conv_head = nn.Sequential(*conv_layers)

    def train(self, mode: bool = True) -> IMAA:
        """Set training mode on the trainable layers while keeping the frozen DINO backbone in eval mode."""
        super().train(mode)
        if self.dino is not None:
            # `nn.Module.train` recurses into every submodule; re-freeze the backbone so layers such as
            # dropout keep their inference behavior.
            self.dino.eval()
        return self

    def forward(
        self,
        image: Optional[torch.Tensor] = None,
        patch_tokens: Optional[torch.Tensor] = None,
        output_size: Optional[Tuple[int, int]] = None,
        map_ids: Optional[torch.Tensor | int] = None,
    ) -> torch.Tensor:
        r"""
        Compute gating weights for the requested intrinsic maps.

        Args:
            image: RGB images `(B, C, H, W)`; only used when `patch_tokens` is not given.
            patch_tokens: Precomputed DINO patch tokens `(B, h, w, dino_patch_dim)`.
            output_size: Optional `(H_out, W_out)`; the gating map is bilinearly resized to it.
            map_ids: Map index per sample, shape `(B,)`. A Python int or 0-d tensor is broadcast to every
                sample in the batch.

        Returns:
            Gating weights in `(0, 1)` of shape `(B, 1, h, w)`, or `(B, 1, H_out, W_out)` when
            `output_size` is given.

        Raises:
            ValueError: If `map_ids` is missing, or if neither `patch_tokens` nor (`image` and a DINO model)
                are available.
        """
        # Checked first so a missing argument does not waste a full DINO pass.
        if map_ids is None:
            raise ValueError("`map_ids` must be provided to select which map embedding(s) to use.")
        if patch_tokens is None:
            if self.dino is None or image is None:
                raise ValueError("Either `patch_tokens` or (`image` and a frozen DINO model) must be provided.")
            patch_tokens = extract_patch_tokens_min_windows(
                image, self.dino, self.processor, window_size=self.window_size, device=image.device
            )

        dino_proj = self.dino_proj(patch_tokens.permute(0, 3, 1, 2))  # (B, common_dim, h, w)
        batch_size, _, grid_h, grid_w = dino_proj.shape

        if not isinstance(map_ids, torch.Tensor):
            map_ids = torch.as_tensor(map_ids, dtype=torch.long, device=self.map_embedding.device)
        if map_ids.dim() == 0:
            map_ids = map_ids.expand(batch_size)
        map_emb = self.map_embedding[map_ids]  # (B, map_embedding_dim)
        # Broadcast each sample's projected map embedding over the whole token grid.
        map_proj = self.map_proj(map_emb)[:, :, None, None].expand(-1, -1, grid_h, grid_w)

        fused_map = self.fusion_layer(torch.cat([dino_proj, map_proj], dim=1))
        gating_map = self.conv_head(fused_map)
        if output_size is not None:
            gating_map = F.interpolate(gating_map, size=output_size, mode="bilinear", align_corners=False)
        return torch.sigmoid(gating_map)


def build_attn_mask(
    w_gating: torch.Tensor,
    text_len: int,
    img_len: int,
    lam: float,
) -> torch.Tensor:
    r"""
    Build an additive attention mask from IMAA gating weights.

    Text-token columns receive zero bias and image-token columns receive `lam * w_gating`. A flattened
    gating map with more than `img_len` entries is truncated; one with fewer entries is zero-padded.

    Args:
        w_gating (`torch.Tensor`):
            Gating weights with batch first, e.g. `[B, 1, H, W]`, `[B, H, W]`, or already flattened
            `[B, img_len]`. All non-batch dimensions are flattened in row-major order.
        text_len (`int`): Number of text tokens prepended to image tokens.
        img_len (`int`): Expected number of image tokens.
        lam (`float`): Mask scaling factor.

    Returns:
        Attention bias tensor of shape `[B, 1, 1, text_len + img_len]`, shaped for SD3 joint attention
        (presumably broadcast over heads and query positions).
    """
    batch_size = w_gating.shape[0]
    total_len = text_len + img_len

    if w_gating.dim() > 2:
        # `flatten` (unlike `view`) also handles non-contiguous inputs.
        w_gating = w_gating.flatten(start_dim=1)
    gating = lam * w_gating

    actual_img_len = gating.shape[1]
    if actual_img_len > img_len:
        gating = gating[:, :img_len]
    elif actual_img_len < img_len:
        gating = F.pad(gating, (0, img_len - actual_img_len))

    # Use the scaled dtype so a float `lam` applied to an integer map is not truncated.
    col_bias = torch.zeros(batch_size, total_len, device=gating.device, dtype=gating.dtype)
    col_bias[:, text_len:] = gating
    return col_bias.view(batch_size, 1, 1, total_len)