# argus / argus.py
# Repository: phanerozoic/argus
# Commit 911ed47 (verified): Argus v1.1: add trained FCOS detection head as the fifth task
"""
Argus: multi-task perception on a single EUPE-ViT-B backbone.

Usage::

    from transformers import AutoModel
    model = AutoModel.from_pretrained("phanerozoic/argus", trust_remote_code=True)
    result = model.perceive(image)

The EUPE-ViT-B backbone architecture, all supporting layers, and the Argus
task heads are inlined below. The backbone code is reproduced from
facebookresearch/EUPE (Meta FAIR) under the FAIR Research License.
"""
import math
import time
from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn.init
from PIL import Image
from torch import Tensor, nn
from torchvision.ops import nms
from torchvision.transforms import v2
from transformers import PretrainedConfig, PreTrainedModel
# ===========================================================================
# EUPE backbone — vendored verbatim from facebookresearch/EUPE
# ===========================================================================
# ---------- utility helpers (from eupe/utils/utils.py) ---------------------
def cat_keep_shapes(x_list: List[Tensor]) -> Tuple[Tensor, List[Tuple[int]], List[int]]:
    """Flatten every tensor to 2-D (rows x last dim) and concatenate them row-wise.

    Returns the concatenated tensor together with the original shapes and the
    number of rows each input contributed, so ``uncat_with_shapes`` can undo it.
    """
    shapes = []
    num_tokens = []
    rows = []
    for tensor in x_list:
        shapes.append(tensor.shape)
        # Rows contributed = product of all dims except the last one.
        num_tokens.append(tensor.select(dim=-1, index=0).numel())
        rows.append(tensor.flatten(0, -2))
    return torch.cat(rows), shapes, num_tokens
def uncat_with_shapes(flattened: Tensor, shapes: List[Tuple[int]], num_tokens: List[int]) -> List[Tensor]:
    """Inverse of ``cat_keep_shapes``: split rows back out and restore each shape.

    The last dimension of every restored tensor is taken from ``flattened``
    (so it may differ from the original, e.g. after a projection layer).
    """
    width = flattened.shape[-1]
    pieces = torch.split_with_sizes(flattened, num_tokens, dim=0)
    return [
        piece.reshape(shape[:-1] + torch.Size([width]))
        for piece, shape in zip(pieces, shapes)
    ]
def named_apply(
    fn: Callable,
    module: nn.Module,
    name: str = "",
    depth_first: bool = True,
    include_root: bool = False,
) -> nn.Module:
    """Recursively call ``fn(module=..., name=...)`` on ``module``'s descendants.

    Children always get ``fn`` applied. ``module`` itself is only visited when
    ``include_root`` is True. With ``depth_first`` the call happens after the
    children (post-order); otherwise before them (pre-order). Child names are
    dotted paths prefixed with ``name``. Returns ``module``.
    """
    pre_order_root = include_root and not depth_first
    post_order_root = include_root and depth_first
    if pre_order_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        qualified = f"{name}.{child_name}" if name else child_name
        named_apply(
            fn=fn,
            module=child_module,
            name=qualified,
            depth_first=depth_first,
            include_root=True,
        )
    if post_order_root:
        fn(module=module, name=name)
    return module
# ---------- RMSNorm (from eupe/layers/rms_norm.py) -------------------------
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    The statistic is computed in float32 and cast back to the input dtype
    before the ``weight`` gain is applied.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def reset_parameters(self) -> None:
        """Reset the gain to all ones."""
        nn.init.constant_(self.weight, 1)

    def _norm(self, x: Tensor) -> Tensor:
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight
# ---------- LayerScale (from eupe/layers/layer_scale.py) -------------------
class LayerScale(nn.Module):
    """Per-channel learnable scaling of a residual branch.

    ``gamma`` is allocated uninitialised; ``reset_parameters`` fills it with
    ``init_values`` (or it is overwritten by a loaded state dict). With
    ``inplace=True`` the input tensor is multiplied in place and returned.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
        device=None,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(torch.empty(dim, device=device))
        self.init_values = init_values

    def reset_parameters(self):
        """Fill ``gamma`` with ``init_values``."""
        nn.init.constant_(self.gamma, self.init_values)

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
# ---------- PatchEmbed (from eupe/layers/patch_embed.py) -------------------
def make_2tuple(x):
    """Return ``x`` as an ``(a, b)`` pair: ints are duplicated, 2-tuples pass through."""
    if not isinstance(x, tuple):
        assert isinstance(x, int)
        return (x, x)
    assert len(x) == 2
    return x
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one.

    A ``Conv2d`` with kernel = stride = ``patch_size`` maps ``(B, C, H, W)`` to
    ``(B, D, H', W')`` with ``H' = H // ph`` and ``W' = W // pw``. The result is
    returned as ``(B, H' * W', D)`` tokens when ``flatten_embedding`` is True,
    otherwise as a ``(B, H', W', D)`` grid.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()
        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])
        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        x = self.proj(x)  # (B, D, H', W')
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # (B, H' * W', D), row-major patch order
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)
        return x

    def reset_parameters(self):
        """Re-initialise the projection uniformly in ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``.

        ``fan_in = in_chans * ph * pw``, the bound PyTorch's default ``Conv2d``
        init uses. This is identical to the upstream ``in_chans * ph ** 2`` for
        square patches, and also correct for rectangular ones.
        """
        fan_in = self.in_chans * self.patch_size[0] * self.patch_size[1]
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.proj.weight, -bound, bound)
        if self.proj.bias is not None:
            nn.init.uniform_(self.proj.bias, -bound, bound)
# ---------- RoPE (from eupe/layers/rope_position_encoding.py) --------------
class RopePositionEmbedding(nn.Module):
    """2D rotary position embedding (RoPE) tables for a grid of patch tokens.

    ``forward(H=..., W=...)`` returns ``(sin, cos)``, each of shape
    ``(H * W, D_head)`` where ``D_head = embed_dim // num_heads``. The same
    tables are returned for every head. Frequencies are defined either by
    ``base`` or by the ``min_period``/``max_period`` pair; exactly one of the
    two must be configured. The ``shift_coords`` / ``jitter_coords`` /
    ``rescale_coords`` augmentations are only applied while ``self.training``
    is True.
    """

    def __init__(
        self,
        embed_dim: int,
        *,
        num_heads: int,
        base: Optional[float] = 100.0,
        min_period: Optional[float] = None,
        max_period: Optional[float] = None,
        normalize_coords: Literal["min", "max", "separate"] = "separate",
        shift_coords: Optional[float] = None,
        jitter_coords: Optional[float] = None,
        rescale_coords: Optional[float] = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        # D_head is covered by 2 coordinates x (D_head // 4) periods, tiled twice in
        # forward(), so it must be divisible by 4.
        assert embed_dim % (4 * num_heads) == 0
        both_periods = min_period is not None and max_period is not None
        if (base is None and not both_periods) or (base is not None and both_periods):
            raise ValueError("Either `base` or `min_period`+`max_period` must be provided.")
        D_head = embed_dim // num_heads
        self.base = base
        self.min_period = min_period
        self.max_period = max_period
        self.D_head = D_head
        self.normalize_coords = normalize_coords
        self.shift_coords = shift_coords
        self.jitter_coords = jitter_coords
        self.rescale_coords = rescale_coords
        self.dtype = dtype
        # One period per frequency. Persistent, so it is saved in / loaded from the state dict.
        self.register_buffer(
            "periods",
            torch.empty(D_head // 4, device=device, dtype=dtype),
            persistent=True,
        )
        self._init_weights()

    def forward(self, *, H: int, W: int) -> Tuple[Tensor, Tensor]:
        """Return ``(sin, cos)`` tables of shape ``(H * W, D_head)`` for an H x W patch grid."""
        device = self.periods.device
        dtype = self.dtype
        dd = {"device": device, "dtype": dtype}
        # Patch-centre coordinates (offset 0.5), normalised by max(H, W), by
        # min(H, W), or by each side separately.
        if self.normalize_coords == "max":
            max_HW = max(H, W)
            coords_h = torch.arange(0.5, H, **dd) / max_HW
            coords_w = torch.arange(0.5, W, **dd) / max_HW
        elif self.normalize_coords == "min":
            min_HW = min(H, W)
            coords_h = torch.arange(0.5, H, **dd) / min_HW
            coords_w = torch.arange(0.5, W, **dd) / min_HW
        elif self.normalize_coords == "separate":
            coords_h = torch.arange(0.5, H, **dd) / H
            coords_w = torch.arange(0.5, W, **dd) / W
        else:
            raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
        # (H * W, 2) grid of (row, col) coordinates in row-major order.
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
        coords = coords.flatten(0, 1)
        # Affine map c -> 2c - 1 centres the grid around 0.
        coords = 2.0 * coords - 1.0
        # Training-only augmentations. Applied in order: a random global shift,
        # per-axis log-uniform jitter, then one shared log-uniform rescale.
        if self.training and self.shift_coords is not None:
            shift_hw = torch.empty(2, **dd).uniform_(-self.shift_coords, self.shift_coords)
            coords += shift_hw[None, :]
        if self.training and self.jitter_coords is not None:
            jitter_max = np.log(self.jitter_coords)
            jitter_min = -jitter_max
            jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp()
            coords *= jitter_hw[None, :]
        if self.training and self.rescale_coords is not None:
            rescale_max = np.log(self.rescale_coords)
            rescale_min = -rescale_max
            rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp()
            coords *= rescale_hw
        # (H * W, 2, D_head // 4) -> (H * W, D_head // 2) -> tiled to (H * W, D_head).
        angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :]
        angles = angles.flatten(1, 2)
        angles = angles.tile(2)
        cos = torch.cos(angles)
        sin = torch.sin(angles)
        return (sin, cos)

    def _init_weights(self):
        """(Re)compute ``self.periods`` in place from ``base`` or ``min_period``/``max_period``."""
        device = self.periods.device
        dtype = self.dtype
        if self.base is not None:
            # Geometric progression: base ** (2i / (D_head // 2)) for i in [0, D_head // 4).
            periods = self.base ** (
                2 * torch.arange(self.D_head // 4, device=device, dtype=dtype) / (self.D_head // 2)
            )
        else:
            # Log-spaced periods running from min_period (exponent 0) to max_period (exponent 1).
            base = self.max_period / self.min_period
            exponents = torch.linspace(0, 1, self.D_head // 4, device=device, dtype=dtype)
            periods = base ** exponents
            periods = periods / base
            periods = periods * self.max_period
        self.periods.data = periods
# ---------- FFN layers (from eupe/layers/ffn_layers.py) --------------------
class ListForwardMixin:
    """Mixin adding ``forward_list``: run ``forward`` once over several token tensors.

    The inputs are flattened and concatenated (``cat_keep_shapes``), pushed
    through a single ``forward`` call, then split back to their original
    leading shapes (``uncat_with_shapes``).
    """

    def forward(self, x: Tensor):
        raise NotImplementedError

    def forward_list(self, x_list: List[Tensor]) -> List[Tensor]:
        packed, shapes, row_counts = cat_keep_shapes(x_list)
        return uncat_with_shapes(self.forward(packed), shapes, row_counts)
class Mlp(nn.Module, ListForwardMixin):
    """Two-layer feed-forward network: fc1 -> activation -> dropout -> fc2 -> dropout.

    ``hidden_features`` and ``out_features`` default to ``in_features``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
        device=None,
    ) -> None:
        super().__init__()
        hidden_dim = hidden_features or in_features
        out_dim = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_dim, bias=bias, device=device)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, out_dim, bias=bias, device=device)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class SwiGLUFFN(nn.Module, ListForwardMixin):
    """SwiGLU feed-forward block: ``w3(silu(w1(x)) * w2(x))``.

    The gated hidden width is ``int(hidden_features * 2 / 3)``, rounded up to
    a multiple of ``align_to``. ``act_layer`` and ``drop`` are accepted for
    signature compatibility with ``Mlp`` but are not used.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Optional[Callable[..., nn.Module]] = None,
        drop: float = 0.0,
        bias: bool = True,
        align_to: int = 8,
        device=None,
    ) -> None:
        super().__init__()
        out_dim = out_features or in_features
        nominal_hidden = hidden_features or in_features
        gated = int(nominal_hidden * 2 / 3)
        # Round up to the next multiple of align_to (ceil-division trick).
        gated_aligned = -(-gated // align_to) * align_to
        self.w1 = nn.Linear(in_features, gated_aligned, bias=bias, device=device)
        self.w2 = nn.Linear(in_features, gated_aligned, bias=bias, device=device)
        self.w3 = nn.Linear(gated_aligned, out_dim, bias=bias, device=device)

    def forward(self, x: Tensor) -> Tensor:
        gate = F.silu(self.w1(x))
        return self.w3(gate * self.w2(x))
# ---------- Attention (from eupe/layers/attention.py) ----------------------
def rope_rotate_half(x: Tensor) -> Tensor:
    """Split the last dim into halves ``(a, b)`` and return ``(-b, a)``."""
    first, second = x.chunk(2, dim=-1)
    return torch.cat((torch.neg(second), first), dim=-1)


def rope_apply(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor:
    """Apply a rotary embedding: ``x * cos + rotate_half(x) * sin`` (broadcasting)."""
    rotated = rope_rotate_half(x)
    return (x * cos) + (rotated * sin)
class LinearKMaskedBias(nn.Linear):
    """Fused qkv ``nn.Linear`` whose bias is multiplied element-wise by ``bias_mask``.

    ``out_features`` must be divisible by 3. When a bias exists, ``bias_mask``
    is registered as a buffer filled with NaN, so it must be set (by
    ``init_weights_vit`` or a loaded state dict) before the layer produces
    finite outputs.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.out_features % 3 == 0
        if self.bias is not None:
            self.register_buffer("bias_mask", torch.full_like(self.bias, fill_value=math.nan))

    def forward(self, input: Tensor) -> Tensor:
        if self.bias is None:
            return F.linear(input, self.weight, None)
        mask = self.bias_mask.to(self.bias.dtype)
        return F.linear(input, self.weight, self.bias * mask)
class SelfAttention(nn.Module):
    """Multi-head self-attention over tokens, with optional RoPE on the trailing tokens.

    Attention is computed with ``torch.nn.functional.scaled_dot_product_attention``.
    ``forward`` handles one ``(B, N, dim)`` tensor; ``forward_list`` handles a
    list of them, sharing the qkv/proj matmuls across the list.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        mask_k_bias: bool = False,
        device=None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE(review): not passed to scaled_dot_product_attention below, which
        # applies its own default scaling.
        self.scale = head_dim ** -0.5
        # With mask_k_bias the fused qkv bias is multiplied by a mask buffer
        # (init_weights_vit zeroes the middle third of that mask).
        linear_class = LinearKMaskedBias if mask_k_bias else nn.Linear
        self.qkv = linear_class(dim, dim * 3, bias=qkv_bias, device=device)
        # NOTE(review): attn_drop is constructed but not used in compute_attention.
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias, device=device)
        self.proj_drop = nn.Dropout(proj_drop)

    def apply_rope(self, q: Tensor, k: Tensor, rope) -> Tuple[Tensor, Tensor]:
        """Rotate the last ``sin.shape[-2]`` tokens of q and k; leading tokens are left as-is.

        q/k are cast to the RoPE table dtype for the rotation and cast back
        afterwards. The unrotated prefix is presumably the cls + storage tokens
        (see DinoVisionTransformer); confirm if reused elsewhere.
        """
        q_dtype = q.dtype
        k_dtype = k.dtype
        sin, cos = rope
        rope_dtype = sin.dtype
        q = q.to(dtype=rope_dtype)
        k = k.to(dtype=rope_dtype)
        N = q.shape[-2]
        # Number of leading tokens that have no RoPE position.
        prefix = N - sin.shape[-2]
        assert prefix >= 0
        q_prefix = q[:, :, :prefix, :]
        q = rope_apply(q[:, :, prefix:, :], sin, cos)
        q = torch.cat((q_prefix, q), dim=-2)
        k_prefix = k[:, :, :prefix, :]
        k = rope_apply(k[:, :, prefix:, :], sin, cos)
        k = torch.cat((k_prefix, k), dim=-2)
        q = q.to(dtype=q_dtype)
        k = k.to(dtype=k_dtype)
        return q, k

    def forward(self, x: Tensor, attn_bias=None, rope=None) -> Tensor:
        """Attend over ``x`` of shape ``(B, N, dim)``; ``rope`` is an optional ``(sin, cos)`` pair."""
        qkv = self.qkv(x)
        attn_v = self.compute_attention(qkv=qkv, attn_bias=attn_bias, rope=rope)
        x = self.proj(attn_v)
        x = self.proj_drop(x)
        return x

    def forward_list(self, x_list, attn_bias=None, rope_list=None) -> List[Tensor]:
        """Attention for each tensor in ``x_list`` with its own rope entry.

        The qkv and output projections run once on the concatenated tokens;
        attention itself is computed per tensor. Unlike ``forward``,
        ``proj_drop`` is not applied here.
        NOTE(review): ``rope_list`` must be a list of the same length as
        ``x_list``; the ``None`` default makes the ``len`` assert raise.
        """
        assert len(x_list) == len(rope_list)
        x_flat, shapes, num_tokens = cat_keep_shapes(x_list)
        qkv_flat = self.qkv(x_flat)
        qkv_list = uncat_with_shapes(qkv_flat, shapes, num_tokens)
        att_out = []
        for _, (qkv, _, rope) in enumerate(zip(qkv_list, shapes, rope_list)):
            att_out.append(self.compute_attention(qkv, attn_bias=attn_bias, rope=rope))
        x_flat, shapes, num_tokens = cat_keep_shapes(att_out)
        x_flat = self.proj(x_flat)
        return uncat_with_shapes(x_flat, shapes, num_tokens)

    def compute_attention(self, qkv: Tensor, attn_bias=None, rope=None) -> Tensor:
        """Split fused ``(B, N, 3 * C)`` qkv into heads, apply RoPE, attend; returns ``(B, N, C)``."""
        # Attention biases/masks are not supported by this implementation.
        assert attn_bias is None
        B, N, _ = qkv.shape
        C = self.qkv.in_features
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = torch.unbind(qkv, 2)
        # -> (B, num_heads, N, head_dim) each.
        q, k, v = [t.transpose(1, 2) for t in [q, k, v]]
        if rope is not None:
            q, k = self.apply_rope(q, k, rope)
        x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        x = x.transpose(1, 2)
        return x.reshape([B, N, C])
# ---------- Block (from eupe/layers/block.py) ------------------------------
class SelfAttentionBlock(nn.Module):
    """Pre-norm transformer block with attention and MLP residual branches.

    Standard path, per input tensor::

        x = x + ls1(attn(norm1(x), rope))
        x = x + ls2(mlp(norm2(x)))

    When training with ``drop_path > 0``, each residual branch is evaluated on
    a random subset of ``max(int(b * (1 - drop_path)), 1)`` samples of the
    batch. It is added back with ``torch.index_add``, scaled by
    ``b / subset_size``. ``ls1``/``ls2`` are ``LayerScale`` when
    ``init_values`` is truthy, otherwise ``nn.Identity``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        ffn_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = SelfAttention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        mask_k_bias: bool = False,
        device=None,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            mask_k_bias=mask_k_bias,
            device=device,
        )
        self.ls1 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * ffn_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
            device=device,
        )
        self.ls2 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity()
        # drop_path is stored as the fraction of samples whose residual update is skipped.
        self.sample_drop_ratio = drop_path

    @staticmethod
    def _maybe_index_rope(rope, indices: Tensor):
        """Subset per-sample (4-D) rope tables along dim 0; shared tables are returned unchanged."""
        if rope is None:
            return None
        sin, cos = rope
        assert sin.ndim == cos.ndim
        if sin.ndim == 4:
            return sin[indices], cos[indices]
        return sin, cos

    def _forward_list(self, x_list: List[Tensor], rope_list=None) -> List[Tensor]:
        """Apply the block to each tensor in ``x_list`` (paired with ``rope_list``)."""
        b_list = [x.shape[0] for x in x_list]
        sample_subset_sizes = [max(int(b * (1 - self.sample_drop_ratio)), 1) for b in b_list]
        if self.training and self.sample_drop_ratio > 0.0:
            # Stochastic depth: rescale so the expected residual matches the full-batch update.
            residual_scale_factors = [b / s for b, s in zip(b_list, sample_subset_sizes)]
            indices_1_list = [
                torch.randperm(b, device=x.device)[:s]
                for x, b, s in zip(x_list, b_list, sample_subset_sizes)
            ]
            x_subset_1_list = [x[i] for x, i in zip(x_list, indices_1_list)]
            if rope_list is not None:
                rope_subset_list = [
                    self._maybe_index_rope(r, i) for r, i in zip(rope_list, indices_1_list)
                ]
            else:
                rope_subset_list = rope_list
            flattened, shapes, num_tokens = cat_keep_shapes(x_subset_1_list)
            norm1 = uncat_with_shapes(self.norm1(flattened), shapes, num_tokens)
            residual_1_list = self.attn.forward_list(norm1, rope_list=rope_subset_list)
            x_attn_list = [
                torch.index_add(x, dim=0, source=self.ls1(r1), index=i1, alpha=rsf)
                for x, r1, i1, rsf in zip(x_list, residual_1_list, indices_1_list, residual_scale_factors)
            ]
            # A fresh random subset is drawn for the MLP branch.
            indices_2_list = [
                torch.randperm(b, device=x.device)[:s]
                for x, b, s in zip(x_list, b_list, sample_subset_sizes)
            ]
            x_subset_2_list = [x[i] for x, i in zip(x_attn_list, indices_2_list)]
            flattened, shapes, num_tokens = cat_keep_shapes(x_subset_2_list)
            norm2_list = uncat_with_shapes(self.norm2(flattened), shapes, num_tokens)
            residual_2_list = self.mlp.forward_list(norm2_list)
            x_ffn = [
                torch.index_add(xa, dim=0, source=self.ls2(r2), index=i2, alpha=rsf)
                for xa, r2, i2, rsf in zip(x_attn_list, residual_2_list, indices_2_list, residual_scale_factors)
            ]
        else:
            # Inference / no drop-path: plain pre-norm residual updates, one tensor at a time.
            x_out = []
            for x, rope in zip(x_list, rope_list):
                x_attn = x + self.ls1(self.attn(self.norm1(x), rope=rope))
                x_ffn = x_attn + self.ls2(self.mlp(self.norm2(x_attn)))
                x_out.append(x_ffn)
            x_ffn = x_out
        return x_ffn

    def forward(self, x_or_x_list, rope_or_rope_list=None) -> Union[Tensor, List[Tensor]]:
        """Apply the block to a tensor (returns a tensor) or a list of tensors (returns a list).

        For list input, ``rope_or_rope_list`` is a parallel list (``None``
        means no RoPE for every entry). Any other input type raises
        ``AssertionError``.
        """
        if isinstance(x_or_x_list, Tensor):
            return self._forward_list([x_or_x_list], rope_list=[rope_or_rope_list])[0]
        elif isinstance(x_or_x_list, list):
            if rope_or_rope_list is None:
                rope_or_rope_list = [None for _ in x_or_x_list]
            return self._forward_list(x_or_x_list, rope_list=rope_or_rope_list)
        raise AssertionError
# ---------- DinoVisionTransformer (from eupe/models/vision_transformer.py)
# Config-string -> feed-forward layer constructor used by DinoVisionTransformer.
ffn_layer_dict = dict(
    mlp=Mlp,
    swiglu=SwiGLUFFN,
    swiglu32=partial(SwiGLUFFN, align_to=32),
    swiglu64=partial(SwiGLUFFN, align_to=64),
    swiglu128=partial(SwiGLUFFN, align_to=128),
)
# Config-string -> normalisation layer constructor (called with the embedding dim).
norm_layer_dict = dict(
    layernorm=partial(nn.LayerNorm, eps=1e-6),
    layernormbf16=partial(nn.LayerNorm, eps=1e-5),
    rmsnorm=RMSNorm,
)
# Config-string -> torch dtype (used for the RoPE tables).
dtype_dict = dict(
    fp32=torch.float32,
    fp16=torch.float16,
    bf16=torch.bfloat16,
)
def init_weights_vit(module: nn.Module, name: str = ""):
    """Initialise a single module in place; intended to be driven by ``named_apply``.

    - ``nn.Linear``: truncated-normal weights (std 0.02), zero bias. If the
      layer has a ``bias_mask`` (``LinearKMaskedBias``), the mask is set to 1
      except for its middle third, which is set to 0.
    - ``nn.LayerNorm`` / ``LayerScale`` / ``PatchEmbed`` / ``RMSNorm``: their
      own ``reset_parameters()``.

    ``name`` is accepted for the ``named_apply`` calling convention and is unused.
    """
    if isinstance(module, nn.Linear):
        torch.nn.init.trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        if getattr(module, "bias_mask", None) is not None:
            out_dim = module.out_features
            module.bias_mask.fill_(1)
            module.bias_mask[out_dim // 3 : 2 * out_dim // 3].fill_(0)
    if isinstance(module, (nn.LayerNorm, LayerScale, PatchEmbed, RMSNorm)):
        module.reset_parameters()
class DinoVisionTransformer(nn.Module):
    """ViT backbone with RoPE positions, a CLS token and optional storage tokens.

    Token layout produced by ``prepare_tokens_with_masks``:
    ``[cls, storage_0 .. storage_{n-1}, patch_0 .. patch_{H*W-1}]``.
    ``forward_features`` returns one dict per input, with keys
    ``x_norm_clstoken``, ``x_storage_tokens``, ``x_norm_patchtokens``,
    ``x_prenorm`` and ``masks``. Parameters allocated with ``torch.empty``
    (e.g. ``cls_token``) hold garbage until ``init_weights()`` runs or a state
    dict is loaded.
    """

    def __init__(
        self,
        *,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        pos_embed_rope_base: float = 100.0,
        pos_embed_rope_min_period: Optional[float] = None,
        pos_embed_rope_max_period: Optional[float] = None,
        pos_embed_rope_normalize_coords: Literal["min", "max", "separate"] = "separate",
        pos_embed_rope_shift_coords: Optional[float] = None,
        pos_embed_rope_jitter_coords: Optional[float] = None,
        pos_embed_rope_rescale_coords: Optional[float] = None,
        pos_embed_rope_dtype: str = "bf16",
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        ffn_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_path_rate: float = 0.0,
        layerscale_init: Optional[float] = None,
        norm_layer: str = "layernorm",
        ffn_layer: str = "mlp",
        ffn_bias: bool = True,
        proj_bias: bool = True,
        n_storage_tokens: int = 0,
        mask_k_bias: bool = False,
        untie_cls_and_patch_norms: bool = False,
        untie_global_and_local_cls_norm: bool = False,
        device: Any = None,
        **ignored_kwargs,
    ):
        super().__init__()
        # Extra config keys are accepted and discarded.
        del ignored_kwargs
        norm_layer_cls = norm_layer_dict[norm_layer]
        self.num_features = self.embed_dim = embed_dim
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        # flatten_embedding=False: patch_embed returns a (B, H, W, D) grid so H, W are recoverable.
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            flatten_embedding=False,
        )
        self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, device=device))
        self.n_storage_tokens = n_storage_tokens
        if self.n_storage_tokens > 0:
            self.storage_tokens = nn.Parameter(torch.empty(1, n_storage_tokens, embed_dim, device=device))
        self.rope_embed = RopePositionEmbedding(
            embed_dim=embed_dim,
            num_heads=num_heads,
            base=pos_embed_rope_base,
            min_period=pos_embed_rope_min_period,
            max_period=pos_embed_rope_max_period,
            normalize_coords=pos_embed_rope_normalize_coords,
            shift_coords=pos_embed_rope_shift_coords,
            jitter_coords=pos_embed_rope_jitter_coords,
            rescale_coords=pos_embed_rope_rescale_coords,
            dtype=dtype_dict[pos_embed_rope_dtype],
            device=device,
        )
        ffn_layer_cls = ffn_layer_dict[ffn_layer]
        ffn_ratio_sequence = [ffn_ratio] * depth
        blocks_list = [
            SelfAttentionBlock(
                dim=embed_dim,
                num_heads=num_heads,
                ffn_ratio=ffn_ratio_sequence[i],
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=drop_path_rate,
                norm_layer=norm_layer_cls,
                act_layer=nn.GELU,
                ffn_layer=ffn_layer_cls,
                init_values=layerscale_init,
                mask_k_bias=mask_k_bias,
                device=device,
            )
            for i in range(depth)
        ]
        self.chunked_blocks = False
        self.blocks = nn.ModuleList(blocks_list)
        self.norm = norm_layer_cls(embed_dim)
        # Optional separate final norms for the cls/storage tokens.
        self.untie_cls_and_patch_norms = untie_cls_and_patch_norms
        self.cls_norm = norm_layer_cls(embed_dim) if untie_cls_and_patch_norms else None
        self.untie_global_and_local_cls_norm = untie_global_and_local_cls_norm
        self.local_cls_norm = norm_layer_cls(embed_dim) if untie_global_and_local_cls_norm else None
        self.head = nn.Identity()
        self.mask_token = nn.Parameter(torch.empty(1, embed_dim, device=device))

    def init_weights(self):
        """Initialise RoPE periods, cls/storage/mask tokens and every submodule (via init_weights_vit)."""
        self.rope_embed._init_weights()
        nn.init.normal_(self.cls_token, std=0.02)
        if self.n_storage_tokens > 0:
            nn.init.normal_(self.storage_tokens, std=0.02)
        nn.init.zeros_(self.mask_token)
        named_apply(init_weights_vit, self)

    def prepare_tokens_with_masks(self, x: Tensor, masks=None) -> Tuple[Tensor, Tuple[int, int]]:
        """Embed image ``x`` into ``(B, 1 + n_storage_tokens + H * W, D)`` tokens; also return (H, W).

        Where ``masks`` is set, patch tokens are replaced by ``mask_token``
        (masks presumably boolean ``(B, H * W)`` -- TODO confirm with callers).
        """
        x = self.patch_embed(x)
        B, H, W, _ = x.shape
        x = x.flatten(1, 2)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
            cls_token = self.cls_token
        else:
            # Numerically a no-op; presumably keeps mask_token in the autograd graph
            # when no masks are given (NOTE(review): confirm intent).
            cls_token = self.cls_token + 0 * self.mask_token
        if self.n_storage_tokens > 0:
            storage_tokens = self.storage_tokens
        else:
            storage_tokens = torch.empty(
                1, 0, cls_token.shape[-1],
                dtype=cls_token.dtype, device=cls_token.device,
            )
        x = torch.cat(
            [cls_token.expand(B, -1, -1), storage_tokens.expand(B, -1, -1), x],
            dim=1,
        )
        return x, (H, W)

    def forward_features_list(self, x_list: List[Tensor], masks_list: List[Tensor]) -> List[Dict[str, Tensor]]:
        """Run the backbone on several image batches together and return one feature dict per batch."""
        x = []
        rope = []
        for t_x, t_masks in zip(x_list, masks_list):
            t2_x, hw_tuple = self.prepare_tokens_with_masks(t_x, t_masks)
            x.append(t2_x)
            rope.append(hw_tuple)
        for blk in self.blocks:
            # RoPE tables are recomputed for every block, so in training mode the
            # random coordinate augmentations differ per block.
            if self.rope_embed is not None:
                rope_sincos = [self.rope_embed(H=H, W=W) for H, W in rope]
            else:
                rope_sincos = [None for _ in rope]
            x = blk(x, rope_sincos)
        all_x = x
        output = []
        for idx, (x, masks) in enumerate(zip(all_x, masks_list)):
            if self.untie_cls_and_patch_norms or self.untie_global_and_local_cls_norm:
                # During training, list entry 1 uses local_cls_norm (presumably the
                # local-crop group -- confirm with the training code).
                if self.untie_global_and_local_cls_norm and self.training and idx == 1:
                    x_norm_cls_reg = self.local_cls_norm(x[:, : self.n_storage_tokens + 1])
                elif self.untie_cls_and_patch_norms:
                    x_norm_cls_reg = self.cls_norm(x[:, : self.n_storage_tokens + 1])
                else:
                    x_norm_cls_reg = self.norm(x[:, : self.n_storage_tokens + 1])
                x_norm_patch = self.norm(x[:, self.n_storage_tokens + 1 :])
            else:
                x_norm = self.norm(x)
                x_norm_cls_reg = x_norm[:, : self.n_storage_tokens + 1]
                x_norm_patch = x_norm[:, self.n_storage_tokens + 1 :]
            output.append({
                "x_norm_clstoken": x_norm_cls_reg[:, 0],
                "x_storage_tokens": x_norm_cls_reg[:, 1:],
                "x_norm_patchtokens": x_norm_patch,
                "x_prenorm": x,
                "masks": masks,
            })
        return output

    def forward_features(self, x, masks: Optional[Tensor] = None):
        """Feature dict for a single tensor, or a list of dicts for a list of tensors.

        NOTE(review): for list input ``masks`` must be a parallel list; the
        ``None`` default is only valid for a single tensor.
        """
        if isinstance(x, torch.Tensor):
            return self.forward_features_list([x], [masks])[0]
        return self.forward_features_list(x, masks)

    def forward(self, *args, is_training: bool = False, **kwargs):
        """Return ``head(cls_token)`` (``head`` is Identity here), or the raw features if ``is_training``.

        The non-training path indexes the result by key, so it expects
        single-tensor input.
        """
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        return self.head(ret["x_norm_clstoken"])
def build_eupe_vitb16() -> DinoVisionTransformer:
    """Construct the EUPE ViT-B/16 backbone configuration used by Argus.

    Architecture: 12 blocks, 768-dim, 12 heads, 16 px patches, 4 storage
    tokens, fp32 RoPE tables with base 100, and masked-K qkv bias. No weights
    are initialised or loaded here.
    """
    vitb16_config = {
        "img_size": 224,
        "patch_size": 16,
        "in_chans": 3,
        "pos_embed_rope_base": 100,
        "pos_embed_rope_normalize_coords": "separate",
        "pos_embed_rope_rescale_coords": 2,
        "pos_embed_rope_dtype": "fp32",
        "embed_dim": 768,
        "depth": 12,
        "num_heads": 12,
        "ffn_ratio": 4,
        "qkv_bias": True,
        "drop_path_rate": 0.0,
        "layerscale_init": 1.0e-05,
        "norm_layer": "layernormbf16",
        "ffn_layer": "mlp",
        "ffn_bias": True,
        "proj_bias": True,
        "n_storage_tokens": 4,
        "mask_k_bias": True,
    }
    return DinoVisionTransformer(**vitb16_config)
# ===========================================================================
# Argus task heads
# ===========================================================================
def make_eupe_transform(resize_size: int):
    """Build the preprocessing pipeline: square resize, float in [0, 1], ImageNet normalisation."""
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)
    steps = [
        v2.ToImage(),
        v2.Resize((resize_size, resize_size), antialias=True),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return v2.Compose(steps)
def _normalize_image_input(image_or_images) -> Tuple[bool, list]:
    """Returns (was_single, [images]). Accepts a PIL.Image or an iterable of them.

    Raises ``ValueError`` for an empty iterable and ``TypeError`` naming the
    first element that is not a ``PIL.Image``.
    """
    if isinstance(image_or_images, Image.Image):
        return True, [image_or_images]
    batch = list(image_or_images)
    if not batch:
        raise ValueError("empty image list")
    offenders = (
        (i, img) for i, img in enumerate(batch) if not isinstance(img, Image.Image)
    )
    first_bad = next(offenders, None)
    if first_bad is not None:
        i, img = first_bad
        raise TypeError(f"images[{i}] is {type(img).__name__}, expected PIL.Image")
    return False, batch
class _BackboneExportWrapper(nn.Module):
    """ONNX-friendly wrapper: returns (cls, spatial) instead of a dict.

    ``cls`` is the normalised CLS token ``(B, D)``. ``spatial`` is the
    normalised patch tokens reshaped to a ``(B, D, h, w)`` feature map.
    """

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        out = self.backbone.forward_features(x)
        cls = out["x_norm_clstoken"]
        patches = out["x_norm_patchtokens"]
        B, N, D = patches.shape
        h, w = self._grid_size(x, N)
        # (B, N, D) -> (B, D, N) -> (B, D, h, w); patch tokens are in row-major grid order.
        spatial = patches.permute(0, 2, 1).reshape(B, D, h, w)
        return cls, spatial

    def _grid_size(self, x: Tensor, num_patches: int) -> Tuple[int, int]:
        """Patch-grid ``(h, w)`` for input ``x``.

        Derived from the input's spatial size and the backbone's
        ``patch_size``, so non-square inputs work. If the backbone has no
        ``patch_size`` attribute, it falls back to assuming a square grid.
        """
        patch_size = getattr(self.backbone, "patch_size", None)
        if patch_size is None:
            side = int(num_patches ** 0.5)
            return side, side
        if isinstance(patch_size, tuple):
            ph, pw = patch_size
        else:
            ph = pw = patch_size
        return x.shape[-2] // ph, x.shape[-1] // pw
class SegmentationHead(nn.Module):
    """Per-pixel classifier: BatchNorm over the feature map followed by a 1x1 conv to class logits."""

    def __init__(self, in_dim: int = 768, num_classes: int = 150):
        super().__init__()
        self.batchnorm_layer = nn.BatchNorm2d(in_dim)
        self.conv = nn.Conv2d(in_dim, num_classes, kernel_size=1)

    def forward(self, x: Tensor) -> Tensor:
        normalized = self.batchnorm_layer(x)
        logits = self.conv(normalized)
        return logits
class DepthHead(nn.Module):
    """Depth head: soft classification over uniformly spaced depth bins.

    BatchNorm + 1x1 conv produce ``n_bins`` logits per pixel. These become
    strictly positive bin weights that are normalised to sum to 1, and the
    predicted depth is the weighted average of ``n_bins`` values linearly
    spaced in ``[min_depth, max_depth]``. Output shape is ``(B, 1, H, W)``.
    """

    def __init__(self, in_dim: int = 768, n_bins: int = 256,
                 min_depth: float = 0.001, max_depth: float = 10.0):
        super().__init__()
        self.batchnorm_layer = nn.BatchNorm2d(in_dim)
        self.conv_depth = nn.Conv2d(in_dim, n_bins, kernel_size=1)
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.n_bins = n_bins

    def forward(self, x: Tensor) -> Tensor:
        logits = self.conv_depth(self.batchnorm_layer(x))
        # ReLU + 0.1 keeps every weight strictly positive, so the sum below is never zero.
        weights = torch.relu(logits) + 0.1
        weights = weights / weights.sum(dim=1, keepdim=True)
        # Create the bin values in the activations' dtype. The einsum operands then
        # agree, and the output keeps the head's precision (fp16/bf16/fp64), instead
        # of always pairing with a float32 bins vector.
        bins = torch.linspace(
            self.min_depth, self.max_depth, self.n_bins,
            device=x.device, dtype=weights.dtype,
        )
        return torch.einsum("bkhw,k->bhw", weights, bins).unsqueeze(1)
# ===========================================================================
# Detection (FCOS with ViTDet-style simple feature pyramid)
# ===========================================================================
# Strides (input pixels per feature cell) of the P3..P7 levels returned by
# SimpleFeaturePyramid: P3 is the stride-16 ViT map upsampled x2, P4 is the map
# itself, and P5..P7 each downsample x2. FCOSHead allocates one learnable
# regression scale per entry.
FPN_STRIDES = [8, 16, 32, 64, 128]
# The 80 COCO detection category names. Presumably index i names the i-th
# class channel of the detection head -- confirm against the training label map.
COCO_CLASSES = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
    "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
    "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra",
    "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
    "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
    "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
    "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
    "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
    "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier",
    "toothbrush",
]
class SimpleFeaturePyramid(nn.Module):
"""ViTDet-style simple FPN: a single stride-16 ViT feature map -> P3..P7."""
def __init__(self, in_channels: int = 768, fpn_channels: int = 256):
super().__init__()
self.fpn_channels = fpn_channels
self.p3 = nn.Sequential(
nn.ConvTranspose2d(in_channels, in_channels, 2, stride=2),
nn.GroupNorm(32, in_channels),
nn.GELU(),
nn.Conv2d(in_channels, fpn_channels, 1),
nn.GroupNorm(32, fpn_channels),
nn.Conv2d(fpn_channels, fpn_channels, 3, padding=1),
nn.GroupNorm(32, fpn_channels),
)
self.p4 = nn.Sequential(
nn.Conv2d(in_channels, fpn_channels, 1),
nn.GroupNorm(32, fpn_channels),
nn.Conv2d(fpn_channels, fpn_channels, 3, padding=1),
nn.GroupNorm(32, fpn_channels),
)
self.p5 = nn.Sequential(
nn.Conv2d(in_channels, in_channels, 3, stride=2, padding=1),
nn.GroupNorm(32, in_channels),
nn.GELU(),
nn.Conv2d(in_channels, fpn_channels, 1),
nn.GroupNorm(32, fpn_channels),
nn.Conv2d(fpn_channels, fpn_channels, 3, padding=1),
nn.GroupNorm(32, fpn_channels),
)
self.p6 = nn.Sequential(
nn.Conv2d(fpn_channels, fpn_channels, 3, stride=2, padding=1),
nn.GroupNorm(32, fpn_channels),
)
self.p7 = nn.Sequential(
nn.Conv2d(fpn_channels, fpn_channels, 3, stride=2, padding=1),
nn.GroupNorm(32, fpn_channels),
)
def forward(self, x: Tensor) -> List[Tensor]:
p3 = self.p3(x)
p4 = self.p4(x)
p5 = self.p5(x)
p6 = self.p6(p5)
p7 = self.p7(p6)
return [p3, p4, p5, p6, p7]
class FCOSHead(nn.Module):
    """Shared classification / box regression / centerness towers across pyramid levels.

    For each level, ``forward`` returns:

    - class logits ``(B, num_classes, h, w)``;
    - positive box distances ``exp(clamp(reg * scale_l, -10, 10))`` with
      4 channels;
    - centerness logits with 1 channel.

    ``scale_l`` is a learnable per-level scalar.
    """

    def __init__(self, fpn_channels: int = 256, num_classes: int = 80, num_convs: int = 4):
        super().__init__()
        self.num_classes = num_classes

        def conv_gn_gelu():
            return [
                nn.Conv2d(fpn_channels, fpn_channels, 3, padding=1),
                nn.GroupNorm(32, fpn_channels),
                nn.GELU(),
            ]

        cls_layers, reg_layers = [], []
        for _ in range(num_convs):
            cls_layers.extend(conv_gn_gelu())
            reg_layers.extend(conv_gn_gelu())
        self.cls_tower = nn.Sequential(*cls_layers)
        self.reg_tower = nn.Sequential(*reg_layers)
        self.cls_pred = nn.Conv2d(fpn_channels, num_classes, 3, padding=1)
        self.reg_pred = nn.Conv2d(fpn_channels, 4, 3, padding=1)
        self.center_pred = nn.Conv2d(fpn_channels, 1, 3, padding=1)
        self.scales = nn.Parameter(torch.ones(len(FPN_STRIDES)))
        # Bias the classifier so every class starts at probability ~prior_prob.
        prior_prob = 0.01
        nn.init.constant_(self.cls_pred.bias, -math.log((1 - prior_prob) / prior_prob))
        # Zero-init regression and centerness outputs.
        for predictor in (self.reg_pred, self.center_pred):
            nn.init.zeros_(predictor.weight)
            nn.init.zeros_(predictor.bias)

    def forward(self, fpn_features: List[Tensor]) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
        cls_logits, box_regs, centernesses = [], [], []
        for level, feature in enumerate(fpn_features):
            cls_feature = self.cls_tower(feature)
            reg_feature = self.reg_tower(feature)
            cls_logits.append(self.cls_pred(cls_feature))
            # Clamp before exp to keep the predicted distances finite.
            scaled = torch.clamp(self.reg_pred(reg_feature) * self.scales[level], min=-10.0, max=10.0)
            box_regs.append(scaled.exp())
            centernesses.append(self.center_pred(reg_feature))
        return cls_logits, box_regs, centernesses
class DetectionHead(nn.Module):
    """Combined SFP + FCOS head: a ViT feature map in, per-level FCOS outputs out."""

    def __init__(self, in_channels: int = 768, fpn_channels: int = 256, num_classes: int = 80, num_convs: int = 4):
        super().__init__()
        self.fpn = SimpleFeaturePyramid(in_channels=in_channels, fpn_channels=fpn_channels)
        self.head = FCOSHead(fpn_channels=fpn_channels, num_classes=num_classes, num_convs=num_convs)
        self.num_classes = num_classes

    def forward(self, spatial_features: Tensor) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
        return self.head(self.fpn(spatial_features))
def _make_locations(feature_sizes: List[Tuple[int, int]], strides: List[int], device) -> List[Tensor]:
    """Per-level center coordinates of feature-map locations in image space.

    For a level of size ``(h, w)`` and stride ``s``, returns an ``(h * w, 2)``
    float32 tensor of ``(x, y)`` centres ``((j + 0.5) * s, (i + 0.5) * s)``,
    in row-major order.
    """

    def level_centers(h: int, w: int, stride: int) -> Tensor:
        ys = (torch.arange(h, device=device, dtype=torch.float32) + 0.5) * stride
        xs = (torch.arange(w, device=device, dtype=torch.float32) + 0.5) * stride
        # "xy" indexing yields (h, w) grids with x varying along columns.
        grid_x, grid_y = torch.meshgrid(xs, ys, indexing="xy")
        return torch.stack((grid_x.reshape(-1), grid_y.reshape(-1)), dim=-1)

    return [level_centers(h, w, s) for (h, w), s in zip(feature_sizes, strides)]
@torch.inference_mode()
def _decode_detections(
    cls_logits_per_level: List[Tensor],
    box_regs_per_level: List[Tensor],
    centernesses_per_level: List[Tensor],
    locations_per_level: List[Tensor],
    image_sizes: List[Tuple[int, int]],
    score_thresh: float = 0.05,
    nms_thresh: float = 0.5,
    max_per_level: int = 1000,
    max_per_image: int = 100,
) -> List[Dict[str, Tensor]]:
    """Convert per-level logits/regs/centerness into per-image detections (xyxy boxes).

    Args:
        cls_logits_per_level: per level, ``[B, num_classes, H, W]`` class logits.
        box_regs_per_level: per level, ``[B, 4, H, W]`` distances from each
            location to the box's left, top, right and bottom edges.
        centernesses_per_level: per level, ``[B, 1, H, W]`` centerness logits.
        locations_per_level: per level, ``[H * W, 2]`` ``(x, y)`` location
            centers in the same row-major order as the flattened maps.
        image_sizes: per image ``(H, W)`` used to clip boxes.
        score_thresh: minimum ``sigmoid(cls) * sigmoid(centerness)`` to keep.
        nms_thresh: IoU threshold for class-aware NMS.
        max_per_level: cap on candidates taken from each level before NMS.
        max_per_image: cap on detections returned per image.

    Returns:
        One dict per image with ``"boxes"`` ``[N, 4]`` (xyxy), ``"scores"``
        ``[N]`` and ``"labels"`` ``[N]`` (long). Entries are always sorted by
        decreasing score.
    """
    # torchvision is already a dependency of this module (``nms``, ``v2``).
    from torchvision.ops import batched_nms

    num_images = cls_logits_per_level[0].shape[0]
    device = cls_logits_per_level[0].device
    per_image_results: List[Dict[str, Tensor]] = []
    for image_idx in range(num_images):
        level_boxes, level_scores, level_labels = [], [], []
        for cls_l, reg_l, ctr_l, locs_l in zip(
            cls_logits_per_level, box_regs_per_level, centernesses_per_level, locations_per_level
        ):
            picked = _select_level_candidates(
                cls_l[image_idx], reg_l[image_idx], ctr_l[image_idx], locs_l,
                score_thresh=score_thresh, max_candidates=max_per_level,
            )
            if picked is not None:
                level_boxes.append(picked[0])
                level_scores.append(picked[1])
                level_labels.append(picked[2])
        if not level_boxes:
            per_image_results.append({
                "boxes": torch.zeros((0, 4), device=device),
                "scores": torch.zeros((0,), device=device),
                "labels": torch.zeros((0,), dtype=torch.long, device=device),
            })
            continue
        boxes = torch.cat(level_boxes, dim=0)
        scores = torch.cat(level_scores, dim=0)
        labels = torch.cat(level_labels, dim=0)
        height, width = image_sizes[image_idx]
        boxes[:, 0::2] = boxes[:, 0::2].clamp(0, width)
        boxes[:, 1::2] = boxes[:, 1::2].clamp(0, height)
        # One class-aware NMS call. The returned indices are sorted by decreasing
        # score, so slicing keeps the top ``max_per_image`` detections in order.
        keep = batched_nms(boxes, scores, labels, nms_thresh)[:max_per_image]
        per_image_results.append({"boxes": boxes[keep], "scores": scores[keep], "labels": labels[keep]})
    return per_image_results


def _select_level_candidates(
    cls_map: Tensor,
    reg_map: Tensor,
    ctr_map: Tensor,
    locations: Tensor,
    score_thresh: float,
    max_candidates: int,
) -> Optional[Tuple[Tensor, Tensor, Tensor]]:
    """Score one pyramid level of one image and decode its candidate boxes.

    ``cls_map`` is ``[C, H, W]``, ``reg_map`` is ``[4, H, W]`` and ``ctr_map``
    is ``[1, H, W]``. A (location, class) pair is kept when
    ``sigmoid(cls) * sigmoid(centerness) > score_thresh``. At most
    ``max_candidates`` of the highest-scoring pairs are kept.

    Returns:
        ``(boxes_xyxy, scores, labels)``, or ``None`` when no pair passes.
    """
    num_classes = cls_map.shape[0]
    cls_prob = torch.sigmoid(cls_map.permute(1, 2, 0).reshape(-1, num_classes))
    reg = reg_map.permute(1, 2, 0).reshape(-1, 4)
    ctr_prob = torch.sigmoid(ctr_map.permute(1, 2, 0).reshape(-1))
    scores = cls_prob * ctr_prob[:, None]
    loc_idx, cls_idx = (scores > score_thresh).nonzero(as_tuple=True)
    if loc_idx.numel() == 0:
        return None
    cand_scores = scores[loc_idx, cls_idx]
    if cand_scores.numel() > max_candidates:
        cand_scores, order = cand_scores.topk(max_candidates)
        loc_idx, cls_idx = loc_idx[order], cls_idx[order]
    xy = locations[loc_idx]
    ltrb = reg[loc_idx]
    # (x - left, y - top, x + right, y + bottom)
    boxes = torch.cat([xy - ltrb[:, :2], xy + ltrb[:, 2:]], dim=-1)
    return boxes, cand_scores, cls_idx
def _letterbox_to_square(image: Image.Image, resolution: int) -> Tuple[Image.Image, float, Tuple[int, int]]:
    """Resize preserving aspect ratio and pad bottom/right with black. Matches the training transform.

    Args:
        image: input PIL image.
        resolution: side length of the square RGB output canvas.

    Returns:
        ``(canvas, scale, (W0, H0))``. ``scale`` maps original pixels to canvas
        pixels, so dividing canvas coordinates by it recovers original ones.
        ``(W0, H0)`` is the original image size.

    Raises:
        ValueError: if the image has a zero width or height.
    """
    orig_w, orig_h = image.size
    if orig_w <= 0 or orig_h <= 0:
        raise ValueError(f"cannot letterbox an empty image of size {image.size}")
    scale = resolution / max(orig_h, orig_w)
    # Keep at least one pixel per side. For extreme aspect ratios the short side
    # would otherwise round to 0, which PIL's resize rejects.
    new_w = max(1, int(round(orig_w * scale)))
    new_h = max(1, int(round(orig_h * scale)))
    resized = image.resize((new_w, new_h), Image.BILINEAR)
    canvas = Image.new("RGB", (resolution, resolution), (0, 0, 0))
    canvas.paste(resized, (0, 0))
    return canvas, scale, (orig_w, orig_h)
# ===========================================================================
# Argus model (transformers-compatible)
# ===========================================================================
class ArgusConfig(PretrainedConfig):
    """Configuration for the Argus multi-task model.

    Holds the backbone geometry, the task-head sizes and the label vocabularies.
    Falsy label lists (``None`` or empty) become a fresh ``[]`` for the
    classification ids/names. The detection names fall back to ``COCO_CLASSES``.
    """

    model_type = "argus"

    def __init__(
        self,
        embed_dim: int = 768,
        patch_size: int = 16,
        num_seg_classes: int = 150,
        depth_n_bins: int = 256,
        depth_min_depth: float = 0.001,
        depth_max_depth: float = 10.0,
        num_imagenet_classes: int = 1000,
        class_ids: Optional[list] = None,
        class_names: Optional[list] = None,
        detection_num_classes: int = 80,
        detection_fpn_channels: int = 256,
        detection_num_convs: int = 4,
        detection_class_names: Optional[list] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Backbone geometry.
        self.embed_dim = embed_dim
        self.patch_size = patch_size

        # Segmentation head.
        self.num_seg_classes = num_seg_classes

        # Depth head: number of bins and the depth range they span.
        self.depth_n_bins = depth_n_bins
        self.depth_min_depth = depth_min_depth
        self.depth_max_depth = depth_max_depth

        # Classification vocabulary.
        self.num_imagenet_classes = num_imagenet_classes
        self.class_ids = class_ids if class_ids else []
        self.class_names = class_names if class_names else []

        # Detection head.
        self.detection_num_classes = detection_num_classes
        self.detection_fpn_channels = detection_fpn_channels
        self.detection_num_convs = detection_num_convs
        self.detection_class_names = (
            detection_class_names if detection_class_names else list(COCO_CLASSES)
        )
class Argus(PreTrainedModel):
    """Several perception task heads sharing one EUPE-ViT-B/16 backbone.

    Task entry points defined on this class include ``classify``, ``segment``,
    ``depth``, ``correspond``, ``detect`` and ``perceive``.
    """

    config_class = ArgusConfig
    base_model_prefix = "argus"
    supports_gradient_checkpointing = False
    # Declares that no weights are tied between modules.
    _tied_weights_keys: list = []
    all_tied_weights_keys: dict = {}

    def __init__(self, config: ArgusConfig):
        """Build the backbone, task heads and classification buffers from ``config``.

        The buffers start as zeros. The backbone parameters are frozen, and the
        backbone and every head are switched to eval mode.
        """
        super().__init__(config)
        # The backbone builder takes no arguments, so its architecture does not depend on ``config``.
        self.backbone = build_eupe_vitb16()
        self.seg_head = SegmentationHead(config.embed_dim, config.num_seg_classes)
        self.depth_head = DepthHead(
            in_dim=config.embed_dim,
            n_bins=config.depth_n_bins,
            min_depth=config.depth_min_depth,
            max_depth=config.depth_max_depth,
        )
        # Classification buffers. persistent=True puts them in the state_dict, so
        # these zero placeholders are presumably overwritten by the saved weights.
        # ``classify(method="knn")`` reads class_prototypes.
        self.register_buffer(
            "class_prototypes",
            torch.zeros(config.num_imagenet_classes, config.embed_dim),
            persistent=True,
        )
        # ``classify(method="softmax")`` reads class_logit_weight / class_logit_bias.
        self.register_buffer(
            "class_logit_weight",
            torch.zeros(config.num_imagenet_classes, config.embed_dim),
            persistent=True,
        )
        self.register_buffer(
            "class_logit_bias",
            torch.zeros(config.num_imagenet_classes),
            persistent=True,
        )
        self.detection_head = DetectionHead(
            in_channels=config.embed_dim,
            fpn_channels=config.detection_fpn_channels,
            num_classes=config.detection_num_classes,
            num_convs=config.detection_num_convs,
        )
        # Only the backbone is frozen here; head parameters keep their default requires_grad.
        for p in self.backbone.parameters():
            p.requires_grad = False
        # Inference mode for all components (affects layers such as BatchNorm).
        self.backbone.eval()
        self.seg_head.eval()
        self.depth_head.eval()
        self.detection_head.eval()
def _init_weights(self, module):
    """Give sensible values to tensors that were not loaded from a checkpoint.

    Transformers may allocate parameters/buffers missing from the checkpoint
    with uninitialized memory on ``from_pretrained``. This hook does three things:
    - Conv2d/ConvTranspose2d weights get Kaiming-normal (fan_out, relu) init and
      their biases are zeroed.
    - GroupNorm affine parameters, when present, are set to 1 (weight) and 0 (bias).
    - On the root module, any Argus classification buffer holding NaN or inf is
      zeroed.
    """
    if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.GroupNorm):
        # GroupNorm(affine=False) has weight/bias set to None.
        if module.weight is not None:
            nn.init.ones_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    if module is self:
        for name in ("class_prototypes", "class_logit_weight", "class_logit_bias"):
            buf = getattr(self, name, None)
            if buf is not None and not torch.isfinite(buf).all():
                buf.data.zero_()
@property
def class_ids(self):
    """Classification label ids from the config.

    ``classify`` indexes this list with a ranked score column.
    """
    cfg = self.config
    return cfg.class_ids
@property
def class_names(self):
    """Human-readable classification label names from the config.

    This list is indexed in parallel with ``class_ids``.
    """
    cfg = self.config
    return cfg.class_names
@torch.inference_mode()
def _extract(self, image_tensor: Tensor) -> Tuple[Tensor, Tensor]:
    """Run the backbone and return float32 ``(cls, spatial)`` features.

    The backbone runs under bf16 autocast on CUDA and in full precision elsewhere.

    Args:
        image_tensor: normalized image batch, assumed ``[B, 3, H, W]``.

    Returns:
        cls: normalized class-token features, ``[B, D]``.
        spatial: normalized patch tokens laid out as ``[B, D, h, w]``.

    Raises:
        ValueError: if the number of patch tokens fits neither the input's patch
            grid nor a square grid.
    """
    with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"):
        out = self.backbone.forward_features(image_tensor)
    cls = out["x_norm_clstoken"].float()
    patches = out["x_norm_patchtokens"].float()
    batch, num_patches, dim = patches.shape
    # Derive the grid from the input size so non-square inputs keep their
    # layout. A square grid is the fallback (the historical assumption).
    patch = self.config.patch_size
    grid_h, grid_w = image_tensor.shape[-2] // patch, image_tensor.shape[-1] // patch
    if grid_h * grid_w != num_patches:
        side = math.isqrt(num_patches)
        if side * side != num_patches:
            raise ValueError(
                f"cannot arrange {num_patches} patch tokens into a grid for input "
                f"{tuple(image_tensor.shape[-2:])} with patch size {patch}"
            )
        grid_h = grid_w = side
    spatial = patches.permute(0, 2, 1).reshape(batch, dim, grid_h, grid_w)
    return cls, spatial
@torch.inference_mode()
def classify(
    self,
    image_or_images,
    top_k: int = 5,
    method: Literal["knn", "softmax"] = "knn",
):
    """Rank classes for one image or a batch of images (224x224 input).

    Args:
        image_or_images: a single image or a collection accepted by
            ``_normalize_image_input``.
        top_k: number of ranked classes returned per image. It is clamped to the
            number of available classes.
        method: how classes are scored.
            - ``"knn"``: dot product of the L2-normalized CLS feature with
              ``class_prototypes`` (cosine similarity when the prototypes are
              unit-norm).
            - ``"softmax"``: softmax over the linear head
              ``class_logit_weight``/``class_logit_bias``.

    Returns:
        For a single image, a list of ``{"class_id", "class_name", "score"}``
        dicts in descending score order. The first entry also carries
        ``"margin"``, the top-1 minus top-2 score; with a single class it is the
        top score itself. For a batch, a list of such lists.

    Raises:
        ValueError: if ``top_k < 1`` or ``method`` is unknown.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")
    single, images = _normalize_image_input(image_or_images)
    transform = make_eupe_transform(224)
    batch = torch.stack([transform(img) for img in images]).to(self.device)
    cls, _ = self._extract(batch)
    cls = F.normalize(cls, dim=-1)
    if method == "knn":
        proto = self.class_prototypes.to(cls.dtype)
        scores_full = cls @ proto.T
    elif method == "softmax":
        weight = self.class_logit_weight.to(cls.dtype)
        bias = self.class_logit_bias.to(cls.dtype)
        scores_full = F.softmax(F.linear(cls, weight, bias), dim=-1)  # in [0, 1]
    else:
        raise ValueError(f"unknown classification method: {method!r} (expected 'knn' or 'softmax')")

    num_classes = scores_full.shape[-1]
    k = min(top_k, num_classes)
    # A single topk serves both the returned entries and the margin. Rank at
    # least two columns (when available) so a runner-up exists.
    ranked = scores_full.topk(min(max(k, 2), num_classes), dim=-1)
    ranked_values = ranked.values.tolist()
    ranked_indices = ranked.indices.tolist()

    results = []
    for row_values, row_indices in zip(ranked_values, ranked_indices):
        entries = [
            {
                "class_id": self.class_ids[idx],
                "class_name": self.class_names[idx],
                "score": float(score),
            }
            for score, idx in zip(row_values[:k], row_indices[:k])
        ]
        runner_up = row_values[1] if len(row_values) > 1 else 0.0
        entries[0]["margin"] = float(row_values[0] - runner_up)
        results.append(entries)
    return results[0] if single else results
@torch.inference_mode()
def segment(self, image_or_images, resolution: int = 512, return_confidence: bool = False):
    """Semantic segmentation at ``resolution`` x ``resolution``.

    Args:
        image_or_images: a single image or a collection accepted by
            ``_normalize_image_input``.
        resolution: side length of the preprocessed input and of the output maps.
        return_confidence: also return the per-pixel max softmax probability.

    Returns:
        Per image, an argmax label map ``[resolution, resolution]``. When
        ``return_confidence`` is set, each result is instead a
        ``(labels, confidence)`` pair with confidence in ``[0, 1]``. A single
        input yields one result; a batch yields a list.
    """
    single, images = _normalize_image_input(image_or_images)
    preprocess = make_eupe_transform(resolution)
    pixels = torch.stack([preprocess(img) for img in images]).to(self.device)
    _, feats = self._extract(pixels)
    use_bf16 = self.device.type == "cuda"
    with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=use_bf16):
        upsampled = F.interpolate(
            self.seg_head(feats), size=(resolution, resolution), mode="bilinear", align_corners=False
        )
    labels = upsampled.argmax(dim=1)  # [B, H, W]
    count = len(images)
    if not return_confidence:
        return labels[0] if single else [labels[idx] for idx in range(count)]
    confidence = F.softmax(upsampled.float(), dim=1).max(dim=1).values  # [B, H, W] in [0, 1]
    if single:
        return labels[0], confidence[0]
    return [(labels[idx], confidence[idx]) for idx in range(count)]
@torch.inference_mode()
def depth(self, image_or_images, resolution: int = 416, return_confidence: bool = False):
    """Predict a dense depth map at ``resolution`` x ``resolution``.

    Args:
        image_or_images: a single image or a collection accepted by
            ``_normalize_image_input``.
        resolution: side length of the preprocessed input and of the output maps.
        return_confidence: also return the per-pixel standard deviation of the
            predicted depth-bin distribution, in the same units as the bins.

    Returns:
        Per image, a float32 ``[resolution, resolution]`` depth map. When
        ``return_confidence`` is set, each result is a ``(depth, std)`` pair
        instead. A single input yields one result; a batch yields a list.
    """
    single, images = _normalize_image_input(image_or_images)
    transform = make_eupe_transform(resolution)
    batch = torch.stack([transform(img) for img in images]).to(self.device)
    _, spatial = self._extract(batch)
    std_b = None
    with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"):
        if return_confidence:
            # Inline the depth head to expose the bin distribution.
            # NOTE(review): mirrors DepthHead's internals; keep the two in sync.
            normed = self.depth_head.batchnorm_layer(spatial)
            bin_logits = self.depth_head.conv_depth(normed)
            distribution = torch.relu(bin_logits) + 0.1
            distribution = distribution / distribution.sum(dim=1, keepdim=True)
            bins = torch.linspace(
                self.depth_head.min_depth, self.depth_head.max_depth,
                self.depth_head.n_bins, device=spatial.device,
            )
            depth_b = torch.einsum("bkhw,k->bhw", distribution, bins).unsqueeze(1)
            # Compute the spread in float32 with autocast off. Under bf16
            # autocast the einsum runs as a bf16 matmul, and E[d^2] - E[d]^2
            # cancels catastrophically. The centered form is also non-negative.
            with torch.autocast(self.device.type, enabled=False):
                probs = torch.relu(bin_logits.float()) + 0.1
                probs = probs / probs.sum(dim=1, keepdim=True)
                bins32 = bins.float()
                mean = torch.einsum("bkhw,k->bhw", probs, bins32)
                sq_dev = (bins32.view(1, -1, 1, 1) - mean.unsqueeze(1)) ** 2
                variance = (probs * sq_dev).sum(dim=1).clamp(min=0)
            std_b = torch.sqrt(variance).unsqueeze(1)
        else:
            depth_b = self.depth_head(spatial)
    depth_b = F.interpolate(depth_b, size=(resolution, resolution), mode="bilinear", align_corners=False)
    depth_maps = depth_b[:, 0].float()  # [B, H, W]
    if std_b is None:
        if single:
            return depth_maps[0]
        return [depth_maps[i] for i in range(len(images))]
    std_b = F.interpolate(std_b, size=(resolution, resolution), mode="bilinear", align_corners=False)
    std_maps = std_b[:, 0].float()  # [B, H, W]
    if single:
        return depth_maps[0], std_maps[0]
    return [(depth_maps[i], std_maps[i]) for i in range(len(images))]
@torch.inference_mode()
def correspond(
    self,
    src_image: Image.Image,
    tgt_image: Image.Image,
    src_keypoints: list,
    resolution: int = 512,
):
    """Match source keypoints to target pixels by dense feature similarity.

    Both images are encoded at ``resolution``. Patch features are bilinearly
    upsampled to a ``resolution`` x ``resolution`` grid and L2-normalized per
    pixel. Each source keypoint is matched to the target pixel with the highest
    dot-product similarity.

    Args:
        src_image: source PIL image.
        tgt_image: target PIL image.
        src_keypoints: ``(x, y)`` positions in source-image pixels.
        resolution: working resolution of the dense feature grids.

    Returns:
        A list of ``[x, y]`` predictions in target-image pixels, one per keypoint.
    """
    src_w, src_h = src_image.size
    tgt_w, tgt_h = tgt_image.size
    transform = make_eupe_transform(resolution)
    last = resolution - 1

    def dense_unit_features(img):
        # [resolution, resolution, D] grid of unit-norm feature vectors.
        _, grid = self._extract(transform(img).unsqueeze(0).to(self.device))
        grid = F.interpolate(grid, size=(resolution, resolution), mode="bilinear", align_corners=False)
        return F.normalize(grid[0].permute(1, 2, 0), dim=-1)

    src_dense = dense_unit_features(src_image)
    tgt_dense = dense_unit_features(tgt_image)

    preds = []
    for kp in src_keypoints:
        col = min(max(int(kp[0] / src_w * resolution), 0), last)
        row = min(max(int(kp[1] / src_h * resolution), 0), last)
        similarity = torch.einsum("d,hwd->hw", src_dense[row, col], tgt_dense)
        best_row, best_col = divmod(similarity.argmax().item(), resolution)
        preds.append([best_col / resolution * tgt_w, best_row / resolution * tgt_h])
    return preds
@torch.inference_mode()
def detect(
    self,
    image_or_images,
    resolution: int = 640,
    score_thresh: float = 0.05,
    nms_thresh: float = 0.5,
    max_per_image: int = 100,
):
    """Detect objects with the FCOS head.

    Each image is letterboxed onto a ``resolution`` x ``resolution`` canvas: the
    long side is resized to ``resolution`` and the bottom/right is padded with
    black, matching the training transform. Detections are decoded on the
    canvas. Boxes are then mapped back by dividing by the letterbox scale and
    clipping to the original width/height.

    Returns:
        Per image, a list of dicts with these keys:
        - ``"box"``: ``[x1, y1, x2, y2]`` in original-image pixels.
        - ``"score"``.
        - ``"label"`` (int).
        - ``"class_name"``: ``"class_<label>"`` when the label has no configured name.
        A single input returns one list; a batch returns a list of lists.
    """
    single, images = _normalize_image_input(image_or_images)
    # Each entry is (canvas, scale, (orig_w, orig_h)).
    letterboxed = [_letterbox_to_square(img, resolution) for img in images]
    normalize = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    batch = torch.stack([normalize(canvas) for canvas, _, _ in letterboxed]).to(self.device)
    _, spatial = self._extract(batch)
    with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"):
        raw_cls, raw_reg, raw_ctr = self.detection_head(spatial)
    # Decode in float32 regardless of the autocast dtype.
    cls_logits = [t.float() for t in raw_cls]
    box_regs = [t.float() for t in raw_reg]
    centernesses = [t.float() for t in raw_ctr]
    locations = _make_locations(
        [(t.shape[2], t.shape[3]) for t in cls_logits], FPN_STRIDES, spatial.device
    )
    decoded = _decode_detections(
        cls_logits, box_regs, centernesses, locations,
        image_sizes=[(resolution, resolution)] * len(images),
        score_thresh=score_thresh,
        nms_thresh=nms_thresh,
        max_per_image=max_per_image,
    )
    names = self.config.detection_class_names
    formatted = []
    for det, (_, scale, (orig_w, orig_h)) in zip(decoded, letterboxed):
        boxes = det["boxes"].cpu().numpy() / scale
        boxes[:, 0::2] = boxes[:, 0::2].clip(0, orig_w)
        boxes[:, 1::2] = boxes[:, 1::2].clip(0, orig_h)
        scores = det["scores"].cpu().numpy()
        labels = det["labels"].cpu().numpy()
        per_image = []
        for box, score, label in zip(boxes, scores, labels):
            label_id = int(label)
            per_image.append({
                "box": [float(v) for v in box.tolist()],
                "score": float(score),
                "label": label_id,
                "class_name": names[label_id] if label_id < len(names) else f"class_{label_id}",
            })
        formatted.append(per_image)
    return formatted[0] if single else formatted
def perceive(self, image_or_images, return_confidence: bool = False):
    """Run classification, segmentation (512px) and depth (416px) on the same input(s).

    Detection and correspondence are not run here; call ``detect`` /
    ``correspond`` separately.

    Args:
        image_or_images: a single image or a collection accepted by
            ``_normalize_image_input``.
        return_confidence: also include segmentation confidence and depth
            uncertainty maps.

    Returns:
        Per image, a dict with these keys:
        - ``"classification"``, ``"segmentation"``, ``"depth"``.
        - ``"timings_ms"``: per-stage wall-clock milliseconds, a separate dict
          for each entry.
        - When ``return_confidence`` is set, also ``"segmentation_confidence"``
          and ``"depth_uncertainty"``.
        Maps are numpy arrays. A single input returns one dict; a batch returns a list.
    """
    single, images = _normalize_image_input(image_or_images)

    def _now() -> float:
        # CUDA kernels launch asynchronously. Wait for them so each stage's
        # timing includes its own GPU work instead of leaking into the next one.
        if self.device.type == "cuda":
            torch.cuda.synchronize(self.device)
        return time.perf_counter()

    t0 = _now()
    classif = self.classify(images, top_k=5)
    t1 = _now()
    seg_out = self.segment(images, resolution=512, return_confidence=return_confidence)
    t2 = _now()
    depth_out = self.depth(images, resolution=416, return_confidence=return_confidence)
    t3 = _now()
    if return_confidence:
        seg_maps = [s for s, _ in seg_out]
        seg_confs = [c for _, c in seg_out]
        depth_maps = [d for d, _ in depth_out]
        depth_uncerts = [u for _, u in depth_out]
    else:
        seg_maps = seg_out
        depth_maps = depth_out
        seg_confs = depth_uncerts = None
    timings = {
        "classify": (t1 - t0) * 1000,
        "segment": (t2 - t1) * 1000,
        "depth": (t3 - t2) * 1000,
        "total": (t3 - t0) * 1000,
    }
    results = []
    for i in range(len(images)):
        entry = {
            "classification": classif[i],
            "segmentation": seg_maps[i].cpu().numpy(),
            "depth": depth_maps[i].cpu().numpy(),
            # Copy so mutating one entry's timings cannot affect the others.
            "timings_ms": dict(timings),
        }
        if return_confidence:
            entry["segmentation_confidence"] = seg_confs[i].cpu().numpy()
            entry["depth_uncertainty"] = depth_uncerts[i].cpu().numpy()
        results.append(entry)
    return results[0] if single else results
def export_onnx(
    self,
    out_dir: str,
    backbone_resolution: int = 224,
    dynamic_batch: bool = True,
    verify: bool = True,
    tolerance: float = 5e-2,
    opset_version: int = 17,
) -> dict:
    """Export backbone, seg head, and depth head to ONNX. kNN classification and correspondence run on top of the backbone output and need no separate graph.

    Args:
        out_dir: directory for the ``.onnx`` files (created if missing).
        backbone_resolution: square input size; must be a multiple of patch_size.
        dynamic_batch: mark the batch dimension dynamic. Otherwise batch is
            fixed to 1, the batch size of the dummy inputs.
        verify: compare ONNX Runtime (CPU) outputs against PyTorch on random inputs.
        tolerance: maximum allowed absolute difference per output.
        opset_version: ONNX opset passed to ``torch.onnx.export``.

    Returns:
        Dict with ``"backbone"``, ``"seg_head"`` and ``"depth_head"`` file paths.
        When ``verify`` is set, it also holds ``"verification"`` with the max
        absolute differences.

    Raises:
        ValueError: if ``backbone_resolution`` is not a multiple of patch_size.
        ImportError: if ``verify`` is set and onnxruntime is not installed.
        RuntimeError: if any output diverges beyond ``tolerance`` (or is NaN).
    """
    import os

    os.makedirs(out_dir, exist_ok=True)
    if backbone_resolution % self.config.patch_size != 0:
        raise ValueError(
            f"backbone_resolution ({backbone_resolution}) must be a multiple of patch_size ({self.config.patch_size})"
        )
    spatial_resolution = backbone_resolution // self.config.patch_size
    wrapper = _BackboneExportWrapper(self.backbone).to(self.device).eval()
    dummy_image = torch.randn(
        1, 3, backbone_resolution, backbone_resolution,
        device=self.device, dtype=torch.float32,
    )
    dummy_spatial = torch.randn(
        1, self.config.embed_dim, spatial_resolution, spatial_resolution,
        device=self.device, dtype=torch.float32,
    )
    backbone_path = os.path.join(out_dir, "argus_backbone.onnx")
    seg_path = os.path.join(out_dir, "argus_seg_head.onnx")
    depth_path = os.path.join(out_dir, "argus_depth_head.onnx")
    backbone_axes = None
    head_axes = None
    if dynamic_batch:
        backbone_axes = {
            "image": {0: "batch"},
            "cls_token": {0: "batch"},
            "spatial_features": {0: "batch"},
        }
        head_axes = {
            "spatial_features": {0: "batch"},
            "seg_logits": {0: "batch"},
            "depth_map": {0: "batch"},
        }
    # dynamo path crashes on EUPE's list-based forward; use legacy.
    with torch.inference_mode():
        torch.onnx.export(
            wrapper, dummy_image, backbone_path,
            input_names=["image"],
            output_names=["cls_token", "spatial_features"],
            dynamic_axes=backbone_axes,
            opset_version=opset_version,
            do_constant_folding=True,
            dynamo=False,
        )
        torch.onnx.export(
            self.seg_head, dummy_spatial, seg_path,
            input_names=["spatial_features"],
            output_names=["seg_logits"],
            dynamic_axes={"spatial_features": head_axes["spatial_features"], "seg_logits": head_axes["seg_logits"]} if head_axes else None,
            opset_version=opset_version,
            do_constant_folding=True,
            dynamo=False,
        )
        torch.onnx.export(
            self.depth_head, dummy_spatial, depth_path,
            input_names=["spatial_features"],
            output_names=["depth_map"],
            dynamic_axes={"spatial_features": head_axes["spatial_features"], "depth_map": head_axes["depth_map"]} if head_axes else None,
            opset_version=opset_version,
            do_constant_folding=True,
            dynamo=False,
        )
    result = {
        "backbone": backbone_path,
        "seg_head": seg_path,
        "depth_head": depth_path,
    }
    if not verify:
        return result

    try:
        import onnxruntime as ort
    except ImportError as e:
        raise ImportError("onnxruntime not installed; pip install onnxruntime") from e
    providers = ["CPUExecutionProvider"]
    # Without dynamic axes the graphs are fixed to the dummy batch size of 1,
    # and ONNX Runtime rejects any other batch size.
    verify_batch = 2 if dynamic_batch else 1
    verify_image = torch.randn(verify_batch, 3, backbone_resolution, backbone_resolution, dtype=torch.float32)
    verify_spatial = torch.randn(verify_batch, self.config.embed_dim, spatial_resolution, spatial_resolution, dtype=torch.float32)
    with torch.inference_mode():
        ref_cls, ref_spatial = wrapper(verify_image.to(self.device))
        ref_seg = self.seg_head(verify_spatial.to(self.device))
        ref_depth = self.depth_head(verify_spatial.to(self.device))
    sess = ort.InferenceSession(backbone_path, providers=providers)
    ort_cls, ort_spatial = sess.run(None, {"image": verify_image.numpy()})
    cls_diff = float(np.abs(ort_cls - ref_cls.cpu().numpy()).max())
    spatial_diff = float(np.abs(ort_spatial - ref_spatial.cpu().numpy()).max())
    sess = ort.InferenceSession(seg_path, providers=providers)
    ort_seg = sess.run(None, {"spatial_features": verify_spatial.numpy()})[0]
    seg_diff = float(np.abs(ort_seg - ref_seg.cpu().numpy()).max())
    sess = ort.InferenceSession(depth_path, providers=providers)
    ort_depth = sess.run(None, {"spatial_features": verify_spatial.numpy()})[0]
    depth_diff = float(np.abs(ort_depth - ref_depth.cpu().numpy()).max())
    verification = {
        "backbone_cls_max_diff": cls_diff,
        "backbone_spatial_max_diff": spatial_diff,
        "seg_head_max_diff": seg_diff,
        "depth_head_max_diff": depth_diff,
        "tolerance": tolerance,
        "verified_batch_size": verify_batch,
    }
    for key, val in verification.items():
        # ``not (val <= tolerance)`` also fails on NaN, which ``val > tolerance`` lets through.
        if key.endswith("_max_diff") and not (val <= tolerance):
            raise RuntimeError(
                f"ONNX/PyTorch divergence in {key}: {val:.2e} > tolerance {tolerance:.2e}"
            )
    result["verification"] = verification
    return result