Falcon-Perception / anyup.py
lkhphuc's picture
Add exported weights
def4178
"""
AnyUp – flattened into a single module for HuggingFace trust_remote_code compatibility.
Original package structure:
anyup/layers/convolutions.py → ResBlock
anyup/layers/feature_unification.py → LearnedFeatureUnification
anyup/layers/positional_encoding.py → RoPE (AnyUp-internal)
anyup/layers/attention/attention_masking.py → window2d, compute_attention_mask, get_attention_mask_mod
anyup/layers/attention/chunked_attention.py → FlexCrossAttention, CrossAttentionBlock
anyup/model.py → AnyUp
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops as E
from typing import Tuple
from functools import lru_cache
from torch.nn.attention.flex_attention import flex_attention
from torch.distributed.tensor import DTensor, distribute_tensor
# Module-level compiled wrapper around `flex_attention`, created once at import
# time and shared by every FlexCrossAttention instance in this file.
# `dynamic=True` is forwarded to torch.compile so that varying query/key
# lengths are traced with dynamic shapes.
compiled_flex_attn_prefill = torch.compile(flex_attention, dynamic=True)
# ---------------------------------------------------------------------------
# ResBlock (from layers/convolutions.py)
# ---------------------------------------------------------------------------
class ResBlock(nn.Module):
    """Pre-activation residual block: (norm -> act -> conv) x 2 plus a shortcut.

    Args:
        in_channels: channel count of the input feature map.
        out_channels: channel count produced by both convolutions.
        kernel_size: spatial size of both convolutions; padding is
            ``kernel_size // 2``.
        num_groups: first positional argument handed to ``norm_fn``.
        pad_mode: ``padding_mode`` used by every convolution in the block.
        norm_fn: callable invoked as ``norm_fn(num_groups, channels)``; when
            falsy, ``nn.Identity`` is used in place of the normalisation.
        activation_fn: zero-argument factory for the activation module.
        use_conv_shortcut: force a 1x1 convolution on the skip path even when
            the channel counts already match.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        num_groups=8,
        pad_mode="zeros",
        norm_fn=None,
        activation_fn=nn.SiLU,
        use_conv_shortcut=False,
    ):
        super().__init__()

        def make_norm(channels):
            # Identity keeps the Sequential indices stable when no norm is used.
            if norm_fn:
                return norm_fn(num_groups, channels)
            return nn.Identity()

        def make_conv(channels_in):
            return nn.Conv2d(
                channels_in,
                out_channels,
                kernel_size,
                padding=kernel_size // 2,
                padding_mode=pad_mode,
                bias=False,
            )

        # Layers are instantiated left to right, i.e. in forward order.
        layers = [
            make_norm(in_channels),
            activation_fn(),
            make_conv(in_channels),
            make_norm(out_channels),
            activation_fn(),
            make_conv(out_channels),
        ]
        self.block = nn.Sequential(*layers)

        # A 1x1 projection is needed whenever the residual sum would otherwise
        # mix tensors with different channel counts (or when requested).
        needs_projection = use_conv_shortcut or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Conv2d(
                in_channels, out_channels, 1, bias=False, padding_mode=pad_mode
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Return ``block(x) + shortcut(x)``."""
        residual = self.block(x)
        skip = self.shortcut(x)
        return residual + skip
# ---------------------------------------------------------------------------
# LearnedFeatureUnification (from layers/feature_unification.py)
# ---------------------------------------------------------------------------
class LearnedFeatureUnification(nn.Module):
    """Map a feature map with any channel count to ``out_channels`` channels.

    Every input channel is filtered with the same learned bank of
    ``out_channels`` single-channel kernels; the responses are soft-maxed and
    averaged so the output no longer depends on the input channel count.

    Args:
        out_channels: number of learned basis kernels / output channels.
        kernel_size: spatial size of each basis kernel.
        init_gaussian_derivatives: accepted for interface compatibility; it is
            not read anywhere in this class.
    """

    def __init__(
        self,
        out_channels: int,
        kernel_size: int = 3,
        init_gaussian_derivatives: bool = False,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        initial_basis = torch.randn(out_channels, 1, kernel_size, kernel_size)
        self.basis = nn.Parameter(initial_basis)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Return a ``(b, out_channels, h, w)`` map for ``(b, c, h, w)`` input."""
        batch, channels, height, width = features.shape
        responses = self._depthwise_conv(features, self.basis, self.kernel_size)
        # NOTE(review): the grouped convolution emits channels grouped per
        # input channel (index = input_channel * out_channels + basis_index),
        # whereas this view reads them as (out_channels, c). The layout is
        # kept exactly as is; presumably existing weights were trained with
        # it -- confirm against the upstream package before changing.
        responses = responses.view(
            batch, self.out_channels, channels, height, width
        )
        weights = F.softmax(responses, dim=1)
        return weights.mean(dim=2)

    @staticmethod
    def _depthwise_conv(feats, basis, k):
        """Filter each channel of ``feats`` with every kernel in ``basis``.

        The input is zero padded by ``k // 2`` and the result is divided by
        the number of in-bounds taps under the window, so border responses
        are rescaled by the fraction of the kernel that overlaps real data.
        """
        _, channels, height, width = feats.shape
        border = k // 2
        padding = (border, border, border, border)

        filtered = F.conv2d(
            F.pad(feats, padding, value=0),
            basis.repeat(channels, 1, 1, 1),
            groups=channels,
        )

        # Count of valid (non-padded) positions under the window per pixel.
        valid = torch.ones(
            1, 1, height, width, dtype=filtered.dtype, device=filtered.device
        )
        box = torch.ones(1, 1, k, k, device=filtered.device, dtype=filtered.dtype)
        coverage = F.conv2d(F.pad(valid, padding, value=0), box)
        return filtered / coverage
# ---------------------------------------------------------------------------
# RoPE (from layers/positional_encoding.py) – AnyUp-internal, separate from
# the main model's 3D RoPE
# ---------------------------------------------------------------------------
def _rotate_half(x):
    """Split the last dim in two halves ``(a, b)`` and return ``(-b, a)``."""
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat([second.neg(), first], dim=-1)
class AnyUpRoPE(nn.Module):
def __init__(
self,
dim: int,
theta: int = 100,
):
super().__init__()
self.dim = dim
self.theta = theta
self.freqs = nn.Parameter(torch.empty(2, self.dim))
def _device_weight_init(self):
if isinstance(self.freqs, DTensor):
target_device = self.freqs.to_local().device
target_dtype = self.freqs.to_local().dtype
else:
target_device = self.freqs.device
target_dtype = self.freqs.dtype
freqs_1d = self.theta ** torch.linspace(
0, -1, self.dim // 4, device=target_device, dtype=target_dtype
)
freqs_1d = torch.cat([freqs_1d, freqs_1d])
freqs_2d = torch.zeros(2, self.dim, device=target_device, dtype=target_dtype)
freqs_2d[0, : self.dim // 2] = freqs_1d
freqs_2d[1, -self.dim // 2 :] = freqs_1d
freqs_2d.mul_(2 * torch.pi)
with torch.no_grad():
if isinstance(self.freqs, DTensor):
dist_freqs = distribute_tensor(
freqs_2d, self.freqs.device_mesh, placements=self.freqs.placements
)
self.freqs.to_local().copy_(dist_freqs.to_local())
else:
self.freqs.copy_(freqs_2d)
def forward(self, x: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
angle = coords @ self.freqs
return x * angle.cos() + _rotate_half(x) * angle.sin()
# ---------------------------------------------------------------------------
# Attention masking (from layers/attention/attention_masking.py)
# ---------------------------------------------------------------------------
def window2d(
low_res: int | Tuple[int, int],
high_res: int | Tuple[int, int],
ratio: float,
*,
device: str = "cpu",
) -> torch.Tensor:
"""Calculate the lower and upper bounds of row and col for each pixel/position"""
if isinstance(high_res, int):
H = W = high_res
else:
H, W = high_res
if isinstance(low_res, int):
Lh = Lw = low_res
else:
Lh, Lw = low_res
r_pos = (torch.arange(H, device=device, dtype=torch.float32) + 0.5) / H
c_pos = (torch.arange(W, device=device, dtype=torch.float32) + 0.5) / W
pos_r, pos_c = torch.meshgrid(r_pos, c_pos, indexing="ij")
r_lo = (pos_r - ratio).clamp(0.0, 1.0)
r_hi = (pos_r + ratio).clamp(0.0, 1.0)
c_lo = (pos_c - ratio).clamp(0.0, 1.0)
c_hi = (pos_c + ratio).clamp(0.0, 1.0)
r0 = (r_lo * Lh).floor().long()
r1 = (r_hi * Lh).ceil().long()
c0 = (c_lo * Lw).floor().long()
c1 = (c_hi * Lw).ceil().long()
return torch.stack([r0, r1, c0, c1], dim=2)
@lru_cache
def compute_attention_mask(
    high_res_h, high_res_w, low_res_h, low_res_w, window_size_ratio, device="cpu"
):
    """Build a dense boolean mask for windowed cross-attention.

    Returns a ``(high_res_h * high_res_w, low_res_h * low_res_w)`` tensor in
    which ``True`` marks key positions that lie OUTSIDE the query's window
    (the in-window mask is inverted before being returned).

    Results are memoised by ``lru_cache``, so callers with equal arguments
    receive the same tensor object and should not modify it in place.
    """
    bounds = window2d(
        low_res=(low_res_h, low_res_w),
        high_res=(high_res_h, high_res_w),
        ratio=window_size_ratio,
        device=device,
    )
    num_queries = high_res_h * high_res_w
    # One (num_queries, 1) column per bound so it broadcasts against the keys.
    row_lo, row_hi, col_lo, col_hi = (
        bounds[..., index].reshape(num_queries, 1) for index in range(4)
    )

    key_rows = torch.arange(low_res_h, device=device)
    key_cols = torch.arange(low_res_w, device=device)
    rows_inside = (key_rows >= row_lo) & (key_rows < row_hi)
    cols_inside = (key_cols >= col_lo) & (key_cols < col_hi)

    # Outer "and" over (row, col) gives the 2-D window, flattened row-major.
    inside = rows_inside.unsqueeze(2) & cols_inside.unsqueeze(1)
    inside = inside.reshape(num_queries, low_res_h * low_res_w).to(dtype=torch.bool)
    return ~inside
def get_attention_mask_mod(
    high_res_h, high_res_w, low_res_h, low_res_w, window_size_ratio=0.1, device="cpu"
):
    """Return a FlexAttention ``mask_mod`` implementing the 2-D window.

    The returned callable takes ``(b_idx, h_idx, q_idx, kv_idx)`` and yields
    ``True`` where the key at flat index ``kv_idx`` falls inside the window of
    the query at flat index ``q_idx``. Flat indices are decoded row-major
    with widths ``high_res_w`` (queries) and ``low_res_w`` (keys). Batch and
    head indices are ignored.
    """
    h, w = high_res_h, high_res_w
    h_, w_ = low_res_h, low_res_w
    bounds = window2d(
        low_res=(h_, w_),
        high_res=(h, w),
        ratio=window_size_ratio,
        device=device,
    )
    # Four (h, w) lookup tables captured by the closure below.
    row_lo, row_hi, col_lo, col_hi = bounds.unbind(dim=-1)

    def _mask_mod(b_idx, h_idx, q_idx, kv_idx):
        q_row, q_col = q_idx // w, q_idx % w
        kv_row, kv_col = kv_idx // w_, kv_idx % w_
        rows_inside = (kv_row >= row_lo[q_row, q_col]) & (
            kv_row < row_hi[q_row, q_col]
        )
        cols_inside = (kv_col >= col_lo[q_row, q_col]) & (
            kv_col < col_hi[q_row, q_col]
        )
        return rows_inside & cols_inside

    return _mask_mod
# ---------------------------------------------------------------------------
# Cross-attention (from layers/attention/chunked_attention.py)
# ---------------------------------------------------------------------------
class AttentionWrapper(nn.Module):
    """Holds a packed q/k/v input projection and applies the q and k parts.

    ``in_proj_weight`` is ``(3 * qk_dim, qk_dim)`` and ``in_proj_bias`` is
    ``(3 * qk_dim,)``; both are allocated uninitialised. Only the first two
    thirds (query and key) are used in ``forward``; the value third is kept
    in the parameters but never applied.
    """

    def __init__(self, qk_dim: int):
        super().__init__()
        self.in_proj_weight = nn.Parameter(torch.empty([qk_dim * 3, qk_dim]))
        self.in_proj_bias = nn.Parameter(torch.empty([qk_dim * 3]))

    def forward(self, x_q, x_k, x_v):
        """Project ``x_q`` and ``x_k``; return ``x_v`` unchanged."""
        weight_q, weight_k, _unused_weight_v = self.in_proj_weight.chunk(3, dim=0)
        bias_q, bias_k, _unused_bias_v = self.in_proj_bias.chunk(3)
        projected_q = x_q @ weight_q.T + bias_q
        projected_k = x_k @ weight_k.T + bias_k
        return projected_q, projected_k, x_v
class FlexCrossAttention(nn.Module):
    """Multi-head cross-attention evaluated with the compiled FlexAttention.

    Queries and keys are RMS-normalised and linearly projected; values are
    only split into heads, never projected, so the output has the same
    channel count as ``value``.

    Args:
        qk_dim: channel count of queries and keys.
        num_heads: number of attention heads; must divide the channel count
            of queries, keys and values.
        **kwargs: accepted and ignored.
    """

    def __init__(self, qk_dim: int, num_heads: int, **kwargs):
        super().__init__()
        self.dim = qk_dim
        self.num_head = num_heads
        self.norm_q = nn.RMSNorm(qk_dim)
        self.norm_k = nn.RMSNorm(qk_dim)
        self.attention = AttentionWrapper(qk_dim)

    def forward(self, query, key, value, mask=None, **kwargs):
        """Attend from ``query`` to ``key``/``value``.

        Inputs are laid out as ``(batch, tokens, channels)``. ``mask`` is
        forwarded as ``block_mask`` to FlexAttention. The result is
        ``(batch, query_tokens, value_channels)``.
        """
        heads = self.num_head
        normed_q = self.norm_q(query)
        normed_k = self.norm_k(key)
        proj_q, proj_k, _ = self.attention(normed_q, normed_k, value)

        # Split the channel axis into heads for all three operands.
        q_heads = E.rearrange(proj_q, "b HW (h d) -> b h HW d", h=heads)
        k_heads = E.rearrange(proj_k, "b hw (h d) -> b h hw d", h=heads)
        v_heads = E.rearrange(value, "b hw (h d) -> b h hw d", h=heads)

        attended = compiled_flex_attn_prefill(
            q_heads, k_heads, v_heads, block_mask=mask
        )
        # Merge the heads back into a single channel axis.
        return E.rearrange(attended, "b h hw d -> b hw (h d)")
class CrossAttentionBlock(nn.Module):
    """Cross-attention over 2-D maps: conv on queries, then windowed attention.

    Args:
        qk_dim: channel count of the query and key maps.
        num_heads: number of attention heads.
        window_ratio: stored on the module; not read by ``forward`` (the
            window is supplied through ``block_mask``).
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        qk_dim,
        num_heads,
        window_ratio: float = 0.1,
        **kwargs,
    ):
        super().__init__()
        self.cross_attn = FlexCrossAttention(qk_dim, num_heads)
        self.window_ratio = window_ratio
        self.conv2d = nn.Conv2d(
            qk_dim, qk_dim, kernel_size=3, stride=1, padding=1, bias=False
        )

    def forward(self, q, k, v, block_mask, **kwargs):
        """Return attended values reshaped to the query map's spatial size.

        ``q``, ``k`` and ``v`` are ``(b, c, h, w)`` maps; ``k`` and ``v`` must
        share their spatial size. Extra keyword arguments are ignored.
        """
        _, _, out_h, out_w = q.shape
        q_tokens = E.rearrange(self.conv2d(q), "b c h w -> b (h w) c")
        k_tokens = E.rearrange(k, "b c h w -> b (h w) c")
        v_tokens = E.rearrange(v, "b c h w -> b (h w) c")
        attended = self.cross_attn(q_tokens, k_tokens, v_tokens, mask=block_mask)
        # Unflatten the token axis back onto the query grid.
        return E.rearrange(attended, "b (h w) c -> b c h w", h=out_h, w=out_w)
# ---------------------------------------------------------------------------
# AnyUp (from model.py)
# ---------------------------------------------------------------------------
# Per-channel normalisation statistics, shaped (1, 3, 1, 1) so they broadcast
# over a (batch, 3, height, width) image tensor.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
def create_coordinate(h, w, start=0.0, end=1.0, device=None, dtype=None):
    """Return a ``(1, h * w, 2)`` grid of evenly spaced 2-D coordinates.

    Both axes span ``[start, end]`` inclusive. Entries are ordered row-major;
    the last dimension holds ``(row_coordinate, column_coordinate)``.
    """
    axes = [
        torch.linspace(start, end, steps, device=device, dtype=dtype)
        for steps in (h, w)
    ]
    grid = torch.stack(torch.meshgrid(*axes, indexing="ij"), -1)
    return grid.view(1, h * w, 2)
class AnyUp(nn.Module):
    """Image-guided feature upsampler built on windowed cross-attention.

    Queries come from the encoded image at the output resolution, keys from
    the encoded image pooled to the feature resolution fused with a
    channel-count-agnostic encoding of the features, and values are the raw
    features themselves.

    Args:
        input_dim: channel count of the guidance image.
        qk_dim: channel count used for queries and keys.
        kernel_size: kernel size of the first image-encoder convolution.
        kernel_size_lfu: kernel size of the LearnedFeatureUnification layer.
        window_ratio: stored and forwarded to the cross-attention block.
        num_heads: number of attention heads.
        init_gaussian_derivatives: forwarded to LearnedFeatureUnification.
        **kwargs: accepted and ignored.
    """

    # Encoders that ``compile`` optimises.
    _COMPILED_SUBMODULES = (
        "image_encoder",
        "key_encoder",
        "query_encoder",
        "key_features_encoder",
        "aggregation",
    )

    def __init__(
        self,
        input_dim=3,
        qk_dim=128,
        kernel_size=1,
        kernel_size_lfu=5,
        window_ratio=0.1,
        num_heads=4,
        init_gaussian_derivatives=False,
        **kwargs,
    ):
        super().__init__()
        self.qk_dim = qk_dim
        self.window_ratio = window_ratio
        # Shared ResBlock configuration for every encoder.
        self._rb_args = dict(
            kernel_size=1,
            num_groups=8,
            pad_mode="reflect",
            norm_fn=nn.GroupNorm,
            activation_fn=nn.SiLU,
        )
        self.image_encoder = self._make_encoder(input_dim, kernel_size)
        self.key_encoder = self._make_encoder(qk_dim, 1)
        self.query_encoder = self._make_encoder(qk_dim, 1)
        # in_ch is None here: the first layer is a LearnedFeatureUnification,
        # which does not take an input channel count.
        self.key_features_encoder = self._make_encoder(
            None,
            1,
            first_layer_k=kernel_size_lfu,
            init_gaussian_derivatives=init_gaussian_derivatives,
        )
        self.cross_decode = CrossAttentionBlock(
            qk_dim=qk_dim, num_heads=num_heads, window_ratio=window_ratio
        )
        self.aggregation = self._make_encoder(2 * qk_dim, 3)
        self.rope = AnyUpRoPE(qk_dim)
        # The RoPE table is allocated uninitialised; fill it now.
        self.rope._device_weight_init()
        self._compiled_encoders = False

    def compile(self, *, mode: str | None = None, dynamic: bool = True):
        """Compile the convolutional encoders; idempotent, returns ``self``.

        Each encoder is compiled in place via ``nn.Module.compile`` rather
        than being replaced by the wrapper that ``torch.compile(module)``
        returns. The wrapper would register the real module under
        ``_orig_mod`` and thereby rename every ``state_dict`` key of the
        encoder, breaking checkpoint loading/saving after compilation.

        Args:
            mode: forwarded to ``torch.compile``.
            dynamic: forwarded to ``torch.compile``.
        """
        if self._compiled_encoders:
            return self
        for name in self._COMPILED_SUBMODULES:
            getattr(self, name).compile(dynamic=dynamic, mode=mode)
        self._compiled_encoders = True
        return self

    def _make_encoder(
        self, in_ch, k, layers=2, first_layer_k=0, init_gaussian_derivatives=False
    ):
        """Build ``first layer -> `layers` ResBlocks`` as an ``nn.Sequential``.

        The first layer is a reflect-padded ``k x k`` convolution from
        ``in_ch`` to ``qk_dim`` channels, unless ``first_layer_k`` is
        non-zero, in which case it is a LearnedFeatureUnification with that
        kernel size (and ``in_ch`` / ``k`` are not used for it).
        """
        if first_layer_k == 0:
            pre = nn.Conv2d(
                in_ch,
                self.qk_dim,
                k,
                padding=k // 2,
                padding_mode="reflect",
                bias=False,
            )
        else:
            pre = LearnedFeatureUnification(
                self.qk_dim,
                first_layer_k,
                init_gaussian_derivatives=init_gaussian_derivatives,
            )
        blocks = [
            ResBlock(self.qk_dim, self.qk_dim, **self._rb_args) for _ in range(layers)
        ]
        return nn.Sequential(pre, *blocks)

    def upsample(
        self, enc_img, feats, attn_mask, out_size, vis_attn=False, q_chunk_size=None
    ):
        """Upsample ``feats`` to ``out_size`` guided by the encoded image.

        Args:
            enc_img: encoded (and position-rotated) image map.
            feats: ``(b, c, h, w)`` feature map; also used as attention values.
            attn_mask: passed to the cross-attention block as its block mask.
            out_size: spatial size of the query grid / output.
            vis_attn: forwarded as a keyword to the cross-attention block.
            q_chunk_size: forwarded as a keyword to the cross-attention block.
        """
        _, _, h, w = feats.shape
        q = F.adaptive_avg_pool2d(self.query_encoder(enc_img), output_size=out_size)
        k = F.adaptive_avg_pool2d(self.key_encoder(enc_img), output_size=(h, w))
        # Keys combine image guidance with channel-normalised feature content.
        k = torch.cat([k, self.key_features_encoder(F.normalize(feats, dim=1))], dim=1)
        k = self.aggregation(k)
        return self.cross_decode(
            q, k, feats, attn_mask, vis_attn=vis_attn, q_chunk_size=q_chunk_size
        )

    def forward(
        self,
        images,
        features,
        attn_mask,
        output_size=None,
        vis_attn=False,
        q_chunk_size=None,
    ):
        """Upsample ``features`` to ``output_size`` using ``images`` as guide.

        Args:
            images: guidance images, ``(b, 3, H, W)``.
            features: ``(b, c, h, w)`` feature map to upsample.
            attn_mask: block mask restricting the attention window.
            output_size: target spatial size; defaults to the image size.
            vis_attn: forwarded to ``upsample``.
            q_chunk_size: forwarded to ``upsample``.
        """
        output_size = output_size if output_size is not None else images.shape[-2:]
        # Presumably the caller supplies images in [-1, 1]; this maps them to
        # [0, 1] before ImageNet normalisation -- confirm against the caller.
        images = images * 0.5 + 0.5
        images = (images - IMAGENET_MEAN.to(images)) / IMAGENET_STD.to(images)
        # Match the dtype/device of the features for the rest of the pipeline.
        images = images.to(features)
        enc = self.image_encoder(images)
        h = enc.shape[-2]
        coords = create_coordinate(h, enc.shape[-1], device=enc.device, dtype=enc.dtype)
        # (b, c, h, w) -> (b, h*w, c) tokens, rotate, then back to a map.
        enc = enc.permute(0, 2, 3, 1).view(enc.shape[0], -1, enc.shape[1])
        enc = self.rope(enc, coords)
        enc = enc.view(enc.shape[0], h, -1, enc.shape[-1]).permute(0, 3, 1, 2)
        return self.upsample(
            enc,
            features,
            attn_mask,
            output_size,
            vis_attn=vis_attn,
            q_chunk_size=q_chunk_size,
        )