Diffusers
Safetensors
BiliSakura's picture
Upload folder using huggingface_hub
acccad2 verified
Raw
History Blame Contribute Delete
7.91 kB
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
def extract_patch_tokens_min_windows(
    images: torch.Tensor,
    model: nn.Module,
    processor,
    window_size: int = 224,
    device: str | torch.device = "cuda",
) -> torch.Tensor:
    r"""
    Tile each image with a minimal window set and return averaged DINO patch tokens.

    Every image is covered by `ceil(H / window_size) x ceil(W / window_size)` square windows. The last window along
    each axis is shifted back so that it ends exactly at the image border, which means it may overlap its
    neighbour. Patch tokens falling into an overlap are averaged.

    Args:
        images (`torch.Tensor`): Batch of RGB images `(B, C, H, W)`. An image whose maximum is `<= 1.0` is
            treated as being in `[0, 1]` and rescaled to `[0, 255]` before conversion to `uint8`.
        model: DINO vision transformer. Must expose `config.hidden_size` and `config.patch_size`, and return an
            output with `last_hidden_state` whose first token is dropped (assumed to be the CLS token).
        processor: Hugging Face image processor for DINO. Called as `processor(images=..., return_tensors="pt")`.
        window_size (`int`): Sliding-window size in pixels.
        device: Device for intermediate tensors and for the processed model inputs.

    Returns:
        `torch.Tensor` of shape `(B, H//patch, W//patch, hidden_size)`.

    Raises:
        ValueError: If the images are smaller than `window_size` along height or width.
    """
    _, _, height, width = images.shape
    hidden_size = model.config.hidden_size
    patch_size = model.config.patch_size

    # A window larger than the image would yield negative origins whose slices silently clamp,
    # ending in a shape mismatch (or NaNs) further down. Fail early with a readable message.
    if height < window_size or width < window_size:
        raise ValueError(
            f"Images of size {height}x{width} are smaller than `window_size`={window_size}; "
            "every spatial dimension must be at least `window_size`."
        )

    # Everything below depends only on the (shared) image size, so compute it once for the whole batch.
    tokens_per_side = window_size // patch_size
    grid_height, grid_width = height // patch_size, width // patch_size
    num_y = (height + window_size - 1) // window_size
    num_x = (width + window_size - 1) // window_size
    # Regular stride for all but the last window; the last one is anchored to the far border.
    y_positions = [index * window_size for index in range(num_y - 1)] + [height - window_size]
    x_positions = [index * window_size for index in range(num_x - 1)] + [width - window_size]

    token_avgs = []
    for image in images:
        image_np = image.permute(1, 2, 0).cpu().numpy()
        if image.max() <= 1.0:
            image_np = image_np * 255
        image_np = image_np.clip(0, 255).astype("uint8")

        token_sum = torch.zeros((grid_height, grid_width, hidden_size), device=device)
        token_count = torch.zeros((grid_height, grid_width, 1), device=device)

        for y in y_positions:
            for x in x_positions:
                patch = image_np[y : y + window_size, x : x + window_size, :]
                inputs = processor(images=patch, return_tensors="pt").to(device)
                with torch.no_grad():
                    outputs = model(**inputs)

                # Drop the leading token and lay the remaining ones out on the window's patch grid.
                patch_tokens = outputs.last_hidden_state[:, 1:, :]
                patch_tokens = patch_tokens.reshape(1, tokens_per_side, tokens_per_side, hidden_size).squeeze(0)

                y0, x0 = y // patch_size, x // patch_size
                y1, x1 = y0 + tokens_per_side, x0 + tokens_per_side
                token_sum[y0:y1, x0:x1, :] += patch_tokens
                token_count[y0:y1, x0:x1, 0] += 1

        # Overlapping windows contribute more than once; divide by the visit count to average them.
        token_avgs.append(token_sum / token_count)

    return torch.stack(token_avgs, dim=0)
class LayerNorm2d(nn.Module):
    r"""
    Layer normalization applied over the channel axis of a channels-first 4D feature map.

    The input is permuted so that dimension 1 becomes the last dimension, normalized with `nn.LayerNorm` over that
    last dimension, and permuted back to its original layout.

    Args:
        channels (`int`): Size of dimension 1 of the input, i.e. the normalized axis.
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        # The attribute is named `norm`, so its parameters are stored as `norm.weight` / `norm.bias`.
        self.norm = nn.LayerNorm([channels])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, C, H, W) -> (B, H, W, C) for the norm, then back to (B, C, H, W).
        return self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
class IMAA(nn.Module):
    r"""
    Intrinsic Map-Aware Attention (IMAA) gating module.

    Produces per-map attention biases from DINO patch tokens and learnable map embeddings.

    Args:
        dino_model (`nn.Module`, *optional*): Backbone used to extract patch tokens when `forward` is given an
            `image` instead of `patch_tokens`. It is switched to eval mode and all of its parameters are frozen.
        processor (*optional*): Image processor passed to `extract_patch_tokens_min_windows` with `dino_model`.
        num_maps (`int`): Number of learnable map embeddings.
        map_embedding_dim (`int`): Size of each map embedding.
        common_dim (`int`): Channel width that patch tokens and map embeddings are projected to before fusion.
        conv_channels (`list[int]`, *optional*): Output channels of the 3x3 convolutions in the gating head.
            `None` (or an empty list) falls back to `[128, 64]`.
        dino_patch_dim (`int`): Channel size of the patch tokens fed to the module.
    """

    def __init__(
        self,
        dino_model: Optional[nn.Module] = None,
        processor=None,
        num_maps: int = 5,
        map_embedding_dim: int = 256,
        common_dim: int = 128,
        conv_channels: Optional[list[int]] = None,
        dino_patch_dim: int = 768,
    ) -> None:
        super().__init__()
        conv_channels = conv_channels or [128, 64]

        self.dino = dino_model
        self.processor = processor
        if self.dino is not None:
            # The backbone is a fixed feature extractor: eval mode, no gradients.
            self.dino.eval()
            for param in self.dino.parameters():
                param.requires_grad = False

        self.num_maps = num_maps
        self.map_embedding_dim = map_embedding_dim
        self.common_dim = common_dim
        self.dino_patch_dim = dino_patch_dim

        # One learnable embedding per map type, selected in `forward` through `map_ids`.
        self.map_embedding = nn.Parameter(torch.randn(num_maps, map_embedding_dim))
        self.dino_proj = nn.Conv2d(dino_patch_dim, common_dim, kernel_size=1)
        self.map_proj = nn.Linear(map_embedding_dim, common_dim)

        # Fuses the concatenated (patch features, broadcast map embedding) back down to `common_dim` channels.
        self.fusion_layer = nn.Sequential(
            nn.Conv2d(common_dim * 2, common_dim, 1),
            LayerNorm2d(common_dim),
            nn.ReLU(),
            nn.Conv2d(common_dim, common_dim, 3, padding=1),
        )

        # Gating head: a stack of 3x3 conv + ReLU layers followed by a 1x1 conv down to a single channel.
        conv_layers: list[nn.Module] = []
        in_channels = common_dim
        for out_channels in conv_channels:
            conv_layers.extend([nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU()])
            in_channels = out_channels
        conv_layers.append(nn.Conv2d(in_channels, 1, kernel_size=1))
        self.conv_head = nn.Sequential(*conv_layers)

    def train(self, mode: bool = True) -> "IMAA":
        r"""
        Set the training mode of the module while keeping the frozen DINO backbone in eval mode.

        `nn.Module.train` recurses into every registered submodule, which would otherwise flip the backbone back
        to training mode and undo the `eval()` call made in `__init__`.
        """
        super().train(mode)
        if self.dino is not None:
            self.dino.eval()
        return self

    def forward(
        self,
        image: Optional[torch.Tensor] = None,
        patch_tokens: Optional[torch.Tensor] = None,
        output_size: Optional[Tuple[int, int]] = None,
        map_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Compute the sigmoid gating map for each sample.

        Args:
            image (`torch.Tensor`, *optional*): Image batch `(B, C, H, W)`; only used to extract patch tokens with
                the DINO backbone when `patch_tokens` is not given.
            patch_tokens (`torch.Tensor`, *optional*): Pre-computed patch tokens in channels-last layout
                `(B, h, w, dino_patch_dim)`.
            output_size (`Tuple[int, int]`, *optional*): If given, the gating map is bilinearly resized to this
                spatial size before the sigmoid.
            map_ids (`torch.Tensor`): 1D index tensor with one entry per sample, selecting the map embedding to
                use. Required despite the `None` default.

        Returns:
            `torch.Tensor` of shape `(B, 1, h, w)`, or `(B, 1, *output_size)` when `output_size` is set, with
            values in `(0, 1)`.

        Raises:
            ValueError: If neither `patch_tokens` nor (`image` and a DINO model) is available, or if `map_ids`
                is missing.
        """
        if patch_tokens is None and (self.dino is None or image is None):
            raise ValueError("Either `patch_tokens` or (`image` and a frozen DINO model) must be provided.")
        # Indexing the embedding table with `None` would add a dimension and fail later in `expand` with an
        # unrelated message, so reject it here, before any expensive DINO forward passes.
        if map_ids is None:
            raise ValueError("`map_ids` must be provided to select a map embedding for each sample.")

        if patch_tokens is None:
            patch_tokens = extract_patch_tokens_min_windows(
                image, self.dino, self.processor, window_size=224, device=image.device
            )

        # Channels-last tokens -> channels-first feature map for the convolutions.
        dino_feat_map = patch_tokens.permute(0, 3, 1, 2)
        dino_proj = self.dino_proj(dino_feat_map)

        # Project the selected embeddings and broadcast them over the spatial grid of the patch features.
        map_emb = self.map_embedding[map_ids]
        map_proj = self.map_proj(map_emb).unsqueeze(-1).unsqueeze(-1)
        map_proj = map_proj.expand(-1, -1, dino_proj.size(2), dino_proj.size(3))

        fused_map = self.fusion_layer(torch.cat([dino_proj, map_proj], dim=1))
        raw_gating_map = self.conv_head(fused_map)

        if output_size is not None:
            aligned_map = F.interpolate(raw_gating_map, size=output_size, mode="bilinear", align_corners=False)
        else:
            aligned_map = raw_gating_map
        return torch.sigmoid(aligned_map)
def build_attn_mask(
    w_gating: torch.Tensor,
    text_len: int,
    img_len: int,
    lam: float,
) -> torch.Tensor:
    r"""
    Build an additive attention mask from IMAA gating weights.

    The gating values are scaled by `lam` and written to the image-token columns of the bias; the text-token
    columns stay at zero. If the number of gating values differs from `img_len`, the values are truncated or
    right-padded with zeros to fit.

    Args:
        w_gating (`torch.Tensor`): Gating map `[B, 1, H, W]` or flattened `[B, img_len]`. Inputs with more than
            two dimensions are flattened per sample and may be non-contiguous.
        text_len (`int`): Number of text tokens prepended to image tokens.
        img_len (`int`): Expected number of image tokens.
        lam (`float`): Mask scaling factor.

    Returns:
        Attention bias tensor of shape `[B, 1, 1, text_len + img_len]`, shaped for SD3 joint attention.
    """
    batch_size = w_gating.shape[0]
    total_len = text_len + img_len

    if w_gating.dim() > 2:
        # `reshape` rather than `view`: a gating map that was transposed or sliced upstream is not contiguous,
        # and `view` raises on such inputs.
        w_gating = w_gating.reshape(batch_size, -1)

    gating = lam * w_gating
    actual_img_len = gating.shape[1]
    if actual_img_len > img_len:
        # More gating values than image tokens: keep the leading `img_len`.
        gating = gating[:, :img_len]
    elif actual_img_len < img_len:
        # Fewer gating values than image tokens: the missing tail gets no bias.
        padding = gating.new_zeros(batch_size, img_len - actual_img_len)
        gating = torch.cat([gating, padding], dim=1)

    # Text columns are left at zero; only image columns are biased.
    col_bias = torch.zeros(batch_size, total_len, device=w_gating.device, dtype=w_gating.dtype)
    col_bias[:, text_len:] = gating
    return col_bias.view(batch_size, 1, 1, total_len)