"""
Argus: multi-task perception on a single EUPE-ViT-B backbone.

Usage::

    from transformers import AutoModel

    model = AutoModel.from_pretrained("phanerozoic/argus", trust_remote_code=True)
    result = model.perceive(image)

The EUPE-ViT-B backbone architecture, all supporting layers, and the Argus
task heads are inlined below. The backbone code is reproduced from
facebookresearch/EUPE (Meta FAIR) under the FAIR Research License.
"""

import math
import time
from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
import torch.nn.init
from PIL import Image
from torch import Tensor, nn
from torchvision.ops import nms
from torchvision.transforms import v2
from transformers import PretrainedConfig, PreTrainedModel


# ===========================================================================
# EUPE backbone — vendored verbatim from facebookresearch/EUPE
# ===========================================================================


# ---------- utility helpers (from eupe/utils/utils.py) ---------------------

def cat_keep_shapes(x_list: List[Tensor]) -> Tuple[Tensor, List[Tuple[int]], List[int]]:
    """Flatten each tensor to 2-D (leading dims merged, last dim kept) and concatenate on dim 0.

    Returns the concatenation, the original shapes, and the number of rows each
    tensor contributed, so ``uncat_with_shapes`` can undo the operation.
    """
    shapes = [x.shape for x in x_list]
    # Product of all leading dims == number of rows this tensor contributes.
    num_tokens = [x.select(dim=-1, index=0).numel() for x in x_list]
    flattened = torch.cat([x.flatten(0, -2) for x in x_list])
    return flattened, shapes, num_tokens


def uncat_with_shapes(flattened: Tensor, shapes: List[Tuple[int]], num_tokens: List[int]) -> List[Tensor]:
    """Inverse of ``cat_keep_shapes``.

    The last dimension of each restored tensor is taken from ``flattened``, so a
    feature-dim-changing op may be applied between the two calls.
    """
    outputs_splitted = torch.split_with_sizes(flattened, num_tokens, dim=0)
    shapes_adjusted = [shape[:-1] + torch.Size([flattened.shape[-1]]) for shape in shapes]
    outputs_reshaped = [o.reshape(shape) for o, shape in zip(outputs_splitted, shapes_adjusted)]
    return outputs_reshaped


def named_apply(
    fn: Callable,
    module: nn.Module,
    name: str = "",
    depth_first: bool = True,
    include_root: bool = False,
) -> nn.Module:
    """Recursively call ``fn(module=..., name=...)`` on ``module``'s submodules.

    ``depth_first`` selects post-order (children before parent) vs pre-order.
    ``include_root`` controls whether ``module`` itself is visited. Child
    modules are always visited, with dotted names.
    """
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = ".".join((name, child_name)) if name else child_name
        named_apply(
            fn=fn,
            module=child_module,
            name=child_name,
            depth_first=depth_first,
            include_root=True,
        )
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


# ---------- RMSNorm (from eupe/layers/rms_norm.py) -------------------------

class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dim with a learnable per-channel weight."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def reset_parameters(self) -> None:
        nn.init.constant_(self.weight, 1)

    def _norm(self, x: Tensor) -> Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        # Normalize in float32, then cast back to the input dtype before scaling.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


# ---------- LayerScale (from eupe/layers/layer_scale.py) -------------------

class LayerScale(nn.Module):
    """Learnable per-channel scaling of a residual branch.

    ``gamma`` is allocated with ``torch.empty`` (uninitialized). It gets a real
    value only from ``reset_parameters`` or from a loaded checkpoint.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
        device=None,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(torch.empty(dim, device=device))
        self.init_values = init_values

    def reset_parameters(self):
        nn.init.constant_(self.gamma, self.init_values)

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


# ---------- PatchEmbed (from eupe/layers/patch_embed.py) -------------------

def make_2tuple(x):
    """Return ``x`` unchanged if it is a 2-tuple, or ``(x, x)`` if it is an int (asserts otherwise)."""
    if isinstance(x, tuple):
        assert len(x) == 2
        return x
    assert isinstance(x, int)
    return (x, x)


class PatchEmbed(nn.Module):
    """Non-overlapping patch embedding: a Conv2d with kernel == stride == patch size.

    Output is ``(B, H'*W', embed_dim)``, or ``(B, H', W', embed_dim)`` when
    ``flatten_embedding`` is False. Here H', W' are the conv output spatial dims.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()
        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])
        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        x = self.proj(x)
        # H, W are rebound to the patch-grid size (the input-size values above are unused).
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)
        return x

    def reset_parameters(self):
        # NOTE(review): fan-in uses patch_size[0] ** 2, i.e. assumes square patches.
        k = 1 / (self.in_chans * (self.patch_size[0] ** 2))
        nn.init.uniform_(self.proj.weight, -math.sqrt(k), math.sqrt(k))
        if self.proj.bias is not None:
            nn.init.uniform_(self.proj.bias, -math.sqrt(k), math.sqrt(k))


# ---------- RoPE (from eupe/layers/rope_position_encoding.py) --------------

class RopePositionEmbedding(nn.Module):
    """2-D rotary position embedding for a grid of patch tokens.

    ``forward(H=..., W=...)`` returns ``(sin, cos)``, each of shape
    ``(H * W, D_head)`` with ``D_head = embed_dim // num_heads``. Coordinates
    are normalized to [-1, 1]. In training mode only, they may additionally be
    randomly shifted, jittered (per axis) or rescaled (both axes).
    """

    def __init__(
        self,
        embed_dim: int,
        *,
        num_heads: int,
        base: Optional[float] = 100.0,
        min_period: Optional[float] = None,
        max_period: Optional[float] = None,
        normalize_coords: Literal["min", "max", "separate"] = "separate",
        shift_coords: Optional[float] = None,
        jitter_coords: Optional[float] = None,
        rescale_coords: Optional[float] = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        assert embed_dim % (4 * num_heads) == 0
        # Exactly one of `base` or the (min_period, max_period) pair must be given.
        both_periods = min_period is not None and max_period is not None
        if (base is None and not both_periods) or (base is not None and both_periods):
            raise ValueError("Either `base` or `min_period`+`max_period` must be provided.")

        D_head = embed_dim // num_heads
        self.base = base
        self.min_period = min_period
        self.max_period = max_period
        self.D_head = D_head
        self.normalize_coords = normalize_coords
        self.shift_coords = shift_coords
        self.jitter_coords = jitter_coords
        self.rescale_coords = rescale_coords

        self.dtype = dtype
        # D_head // 4 frequencies per axis: x and y each get half the head dim,
        # and the sin/cos pairing halves it again (see `tile(2)` in forward).
        self.register_buffer(
            "periods",
            torch.empty(D_head // 4, device=device, dtype=dtype),
            persistent=True,
        )
        self._init_weights()

    def forward(self, *, H: int, W: int) -> Tuple[Tensor, Tensor]:
        device = self.periods.device
        dtype = self.dtype
        dd = {"device": device, "dtype": dtype}

        # Pixel-center coordinates (0.5, 1.5, ...) scaled into [0, 1].
        if self.normalize_coords == "max":
            max_HW = max(H, W)
            coords_h = torch.arange(0.5, H, **dd) / max_HW
            coords_w = torch.arange(0.5, W, **dd) / max_HW
        elif self.normalize_coords == "min":
            min_HW = min(H, W)
            coords_h = torch.arange(0.5, H, **dd) / min_HW
            coords_w = torch.arange(0.5, W, **dd) / min_HW
        elif self.normalize_coords == "separate":
            coords_h = torch.arange(0.5, H, **dd) / H
            coords_w = torch.arange(0.5, W, **dd) / W
        else:
            raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)  # (H, W, 2)
        coords = coords.flatten(0, 1)  # (H*W, 2)
        coords = 2.0 * coords - 1.0  # [0, 1] -> [-1, 1]

        # Training-time coordinate augmentations (no-ops in eval mode).
        if self.training and self.shift_coords is not None:
            shift_hw = torch.empty(2, **dd).uniform_(-self.shift_coords, self.shift_coords)
            coords += shift_hw[None, :]

        if self.training and self.jitter_coords is not None:
            jitter_max = np.log(self.jitter_coords)
            jitter_min = -jitter_max
            jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp()
            coords *= jitter_hw[None, :]

        if self.training and self.rescale_coords is not None:
            rescale_max = np.log(self.rescale_coords)
            rescale_min = -rescale_max
            rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp()
            coords *= rescale_hw

        angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :]  # (H*W, 2, D_head//4)
        angles = angles.flatten(1, 2)  # (H*W, D_head//2)
        angles = angles.tile(2)  # (H*W, D_head)
        cos = torch.cos(angles)
        sin = torch.sin(angles)

        return (sin, cos)

    def _init_weights(self):
        """Fill the ``periods`` buffer from ``base`` or from ``[min_period, max_period]``."""
        device = self.periods.device
        dtype = self.dtype
        if self.base is not None:
            periods = self.base ** (
                2 * torch.arange(self.D_head // 4, device=device, dtype=dtype) / (self.D_head // 2)
            )
        else:
            # Geometric spacing from min_period to max_period.
            base = self.max_period / self.min_period
            exponents = torch.linspace(0, 1, self.D_head // 4, device=device, dtype=dtype)
            periods = base ** exponents
            periods = periods / base
            periods = periods * self.max_period
        self.periods.data = periods


# ---------- FFN layers (from eupe/layers/ffn_layers.py) --------------------

class ListForwardMixin(object):
    """Adds ``forward_list``: run ``forward`` once on a concatenation of several token tensors."""

    def forward(self, x: Tensor):
        raise NotImplementedError

    def forward_list(self, x_list: List[Tensor]) -> List[Tensor]:
        x_flat, shapes, num_tokens = cat_keep_shapes(x_list)
        x_flat = self.forward(x_flat)
        return uncat_with_shapes(x_flat, shapes, num_tokens)


class Mlp(nn.Module, ListForwardMixin):
    """Two-layer MLP: fc1 -> activation -> dropout -> fc2 -> dropout."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
        device=None,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, device=device)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, device=device)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class SwiGLUFFN(nn.Module, ListForwardMixin):
    """Gated FFN: ``w3(silu(w1(x)) * w2(x))``.

    The hidden width is 2/3 of ``hidden_features``, rounded up to a multiple of
    ``align_to``. ``act_layer`` and ``drop`` are accepted for signature
    compatibility but not used.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Optional[Callable[..., nn.Module]] = None,
        drop: float = 0.0,
        bias: bool = True,
        align_to: int = 8,
        device=None,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        d = int(hidden_features * 2 / 3)
        # Round d up to the next multiple of align_to.
        swiglu_hidden_features = d + (-d % align_to)
        self.w1 = nn.Linear(in_features, swiglu_hidden_features, bias=bias, device=device)
        self.w2 = nn.Linear(in_features, swiglu_hidden_features, bias=bias, device=device)
        self.w3 = nn.Linear(swiglu_hidden_features, out_features, bias=bias, device=device)

    def forward(self, x: Tensor) -> Tensor:
        x1 = self.w1(x)
        x2 = self.w2(x)
        hidden = F.silu(x1) * x2
        return self.w3(hidden)


# ---------- Attention (from eupe/layers/attention.py) ----------------------

def rope_rotate_half(x: Tensor) -> Tensor:
    """Split the last dim into halves (x1, x2) and return ``cat(-x2, x1)``."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)


def rope_apply(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor:
    """Apply the rotary embedding: ``x * cos + rotate_half(x) * sin``."""
    return (x * cos) + (rope_rotate_half(x) * sin)


class LinearKMaskedBias(nn.Linear):
    """Linear layer whose bias is multiplied elementwise by a ``bias_mask`` buffer.

    ``bias_mask`` starts as all-NaN, so outputs are NaN until the mask is set.
    It is set either by ``init_weights_vit``, which zeroes the middle (K) third,
    or by loading a checkpoint.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        o = self.out_features
        assert o % 3 == 0
        if self.bias is not None:
            self.register_buffer("bias_mask", torch.full_like(self.bias, fill_value=math.nan))

    def forward(self, input: Tensor) -> Tensor:
        masked_bias = self.bias * self.bias_mask.to(self.bias.dtype) if self.bias is not None else None
        return F.linear(input, self.weight, masked_bias)


class SelfAttention(nn.Module):
    """Multi-head self-attention with optional RoPE, computed via ``F.scaled_dot_product_attention``.

    Note: ``self.scale`` and ``self.attn_drop`` are stored but not used. SDPA is
    called with its default scaling and no dropout.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        mask_k_bias: bool = False,
        device=None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        linear_class = LinearKMaskedBias if mask_k_bias else nn.Linear
        self.qkv = linear_class(dim, dim * 3, bias=qkv_bias, device=device)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias, device=device)
        self.proj_drop = nn.Dropout(proj_drop)

    def apply_rope(self, q: Tensor, k: Tensor, rope) -> Tuple[Tensor, Tensor]:
        """Rotate the trailing ``sin.shape[-2]`` tokens of q and k.

        Leading "prefix" tokens (cls + storage tokens in this model) are left
        untouched. Math is done in the RoPE dtype, then cast back.
        """
        q_dtype = q.dtype
        k_dtype = k.dtype
        sin, cos = rope
        rope_dtype = sin.dtype
        q = q.to(dtype=rope_dtype)
        k = k.to(dtype=rope_dtype)
        N = q.shape[-2]
        prefix = N - sin.shape[-2]
        assert prefix >= 0
        q_prefix = q[:, :, :prefix, :]
        q = rope_apply(q[:, :, prefix:, :], sin, cos)
        q = torch.cat((q_prefix, q), dim=-2)
        k_prefix = k[:, :, :prefix, :]
        k = rope_apply(k[:, :, prefix:, :], sin, cos)
        k = torch.cat((k_prefix, k), dim=-2)
        q = q.to(dtype=q_dtype)
        k = k.to(dtype=k_dtype)
        return q, k

    def forward(self, x: Tensor, attn_bias=None, rope=None) -> Tensor:
        qkv = self.qkv(x)
        attn_v = self.compute_attention(qkv=qkv, attn_bias=attn_bias, rope=rope)
        x = self.proj(attn_v)
        x = self.proj_drop(x)
        return x

    def forward_list(self, x_list, attn_bias=None, rope_list=None) -> List[Tensor]:
        """Batch the qkv/proj linears over several token sets; attention runs per set."""
        assert len(x_list) == len(rope_list)
        x_flat, shapes, num_tokens = cat_keep_shapes(x_list)
        qkv_flat = self.qkv(x_flat)
        qkv_list = uncat_with_shapes(qkv_flat, shapes, num_tokens)
        att_out = []
        for _, (qkv, _, rope) in enumerate(zip(qkv_list, shapes, rope_list)):
            att_out.append(self.compute_attention(qkv, attn_bias=attn_bias, rope=rope))
        x_flat, shapes, num_tokens = cat_keep_shapes(att_out)
        x_flat = self.proj(x_flat)
        return uncat_with_shapes(x_flat, shapes, num_tokens)

    def compute_attention(self, qkv: Tensor, attn_bias=None, rope=None) -> Tensor:
        """Attention over a (B, N, 3*C) qkv tensor; returns (B, N, C). ``attn_bias`` must be None."""
        assert attn_bias is None
        B, N, _ = qkv.shape
        C = self.qkv.in_features

        # Output dim of qkv is laid out as [q | k | v], each split across heads.
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = torch.unbind(qkv, 2)
        q, k, v = [t.transpose(1, 2) for t in [q, k, v]]  # -> (B, heads, N, head_dim)
        if rope is not None:
            q, k = self.apply_rope(q, k, rope)
        x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        x = x.transpose(1, 2)
        return x.reshape([B, N, C])


# ---------- Block (from eupe/layers/block.py) ------------------------------

class SelfAttentionBlock(nn.Module):
    """Pre-norm transformer block: ``x + ls1(attn(norm1(x)))`` then ``+ ls2(mlp(norm2(.)))``."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        ffn_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = SelfAttention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        mask_k_bias: bool = False,
        device=None,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            mask_k_bias=mask_k_bias,
            device=device,
        )
        self.ls1 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * ffn_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
            device=device,
        )
        self.ls2 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity()
        self.sample_drop_ratio = drop_path

    @staticmethod
    def _maybe_index_rope(rope, indices: Tensor):
        """Select ``indices`` along dim 0 of a batched (4-D) rope; pass through otherwise."""
        if rope is None:
            return None
        sin, cos = rope
        assert sin.ndim == cos.ndim
        if sin.ndim == 4:
            return sin[indices], cos[indices]
        return sin, cos

    def _forward_list(self, x_list: List[Tensor], rope_list=None) -> List[Tensor]:
        b_list = [x.shape[0] for x in x_list]
        sample_subset_sizes = [max(int(b * (1 - self.sample_drop_ratio)), 1) for b in b_list]

        if self.training and self.sample_drop_ratio > 0.0:
            # Stochastic depth on batch subsets: each residual branch runs on a random
            # subset of s samples, and its output is added back (index_add) scaled by b / s.
            # NOTE: the order of randperm calls matters for reproducibility.
            residual_scale_factors = [b / s for b, s in zip(b_list, sample_subset_sizes)]
            indices_1_list = [
                torch.randperm(b, device=x.device)[:s] for x, b, s in zip(x_list, b_list, sample_subset_sizes)
            ]
            x_subset_1_list = [x[i] for x, i in zip(x_list, indices_1_list)]
            if rope_list is not None:
                rope_subset_list = [
                    self._maybe_index_rope(r, i) for r, i in zip(rope_list, indices_1_list)
                ]
            else:
                rope_subset_list = rope_list
            flattened, shapes, num_tokens = cat_keep_shapes(x_subset_1_list)
            norm1 = uncat_with_shapes(self.norm1(flattened), shapes, num_tokens)
            residual_1_list = self.attn.forward_list(norm1, rope_list=rope_subset_list)
            x_attn_list = [
                torch.index_add(x, dim=0, source=self.ls1(r1), index=i1, alpha=rsf)
                for x, r1, i1, rsf in zip(x_list, residual_1_list, indices_1_list, residual_scale_factors)
            ]
            indices_2_list = [
                torch.randperm(b, device=x.device)[:s] for x, b, s in zip(x_list, b_list, sample_subset_sizes)
            ]
            x_subset_2_list = [x[i] for x, i in zip(x_attn_list, indices_2_list)]
            flattened, shapes, num_tokens = cat_keep_shapes(x_subset_2_list)
            norm2_list = uncat_with_shapes(self.norm2(flattened), shapes, num_tokens)
            residual_2_list = self.mlp.forward_list(norm2_list)
            x_ffn = [
                torch.index_add(xa, dim=0, source=self.ls2(r2), index=i2, alpha=rsf)
                for xa, r2, i2, rsf in zip(x_attn_list, residual_2_list, indices_2_list, residual_scale_factors)
            ]
        else:
            # Plain residual path (eval mode, or drop_path == 0).
            x_out = []
            for x, rope in zip(x_list, rope_list):
                x_attn = x + self.ls1(self.attn(self.norm1(x), rope=rope))
                x_ffn = x_attn + self.ls2(self.mlp(self.norm2(x_attn)))
                x_out.append(x_ffn)
            x_ffn = x_out
        return x_ffn

    def forward(self, x_or_x_list, rope_or_rope_list=None) -> List[Tensor]:
        """Accept a single tensor (returns a tensor) or a list of tensors (returns a list)."""
        if isinstance(x_or_x_list, Tensor):
            return self._forward_list([x_or_x_list], rope_list=[rope_or_rope_list])[0]
        elif isinstance(x_or_x_list, list):
            if rope_or_rope_list is None:
                rope_or_rope_list = [None for _ in x_or_x_list]
            return self._forward_list(x_or_x_list, rope_list=rope_or_rope_list)
        raise AssertionError


# ---------- DinoVisionTransformer (from eupe/models/vision_transformer.py)

# String keys accepted by DinoVisionTransformer(ffn_layer=..., norm_layer=..., pos_embed_rope_dtype=...).
ffn_layer_dict = {
    "mlp": Mlp,
    "swiglu": SwiGLUFFN,
    "swiglu32": partial(SwiGLUFFN, align_to=32),
    "swiglu64": partial(SwiGLUFFN, align_to=64),
    "swiglu128": partial(SwiGLUFFN, align_to=128),
}

norm_layer_dict = {
    "layernorm": partial(nn.LayerNorm, eps=1e-6),
    "layernormbf16": partial(nn.LayerNorm, eps=1e-5),
    "rmsnorm": RMSNorm,
}

dtype_dict = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}


def init_weights_vit(module: nn.Module, name: str = ""):
    """Per-module initializer intended for ``named_apply``.

    Linear weights get a truncated normal (std=0.02) and zero bias. For a
    ``LinearKMaskedBias``, ``bias_mask`` is set to 1 except its middle third (the
    K slice of the [q | k | v] layout), which is zeroed. Norm, LayerScale and
    PatchEmbed modules are reset via ``reset_parameters``.
    """
    if isinstance(module, nn.Linear):
        torch.nn.init.trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        if hasattr(module, "bias_mask") and module.bias_mask is not None:
            o = module.out_features
            module.bias_mask.fill_(1)
            module.bias_mask[o // 3 : 2 * o // 3].fill_(0)
    if isinstance(module, nn.LayerNorm):
        module.reset_parameters()
    if isinstance(module, LayerScale):
        module.reset_parameters()
    if isinstance(module, PatchEmbed):
        module.reset_parameters()
    if isinstance(module, RMSNorm):
        module.reset_parameters()


class DinoVisionTransformer(nn.Module):
    """ViT backbone with RoPE positions, a CLS token and optional storage (register) tokens.

    Token layout after ``prepare_tokens_with_masks``:
    ``[cls, storage_0 .. storage_{n-1}, patch tokens (row-major over the patch grid)]``.
    Several parameters/buffers are allocated with ``torch.empty``. They hold real
    values only after ``init_weights()`` or loading a checkpoint.
    """

    def __init__(
        self,
        *,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        pos_embed_rope_base: float = 100.0,
        pos_embed_rope_min_period: Optional[float] = None,
        pos_embed_rope_max_period: Optional[float] = None,
        pos_embed_rope_normalize_coords: Literal["min", "max", "separate"] = "separate",
        pos_embed_rope_shift_coords: Optional[float] = None,
        pos_embed_rope_jitter_coords: Optional[float] = None,
        pos_embed_rope_rescale_coords: Optional[float] = None,
        pos_embed_rope_dtype: str = "bf16",
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        ffn_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_path_rate: float = 0.0,
        layerscale_init: Optional[float] = None,
        norm_layer: str = "layernorm",
        ffn_layer: str = "mlp",
        ffn_bias: bool = True,
        proj_bias: bool = True,
        n_storage_tokens: int = 0,
        mask_k_bias: bool = False,
        untie_cls_and_patch_norms: bool = False,
        untie_global_and_local_cls_norm: bool = False,
        device: Any = None,
        **ignored_kwargs,
    ):
        super().__init__()
        del ignored_kwargs

        norm_layer_cls = norm_layer_dict[norm_layer]

        self.num_features = self.embed_dim = embed_dim
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            flatten_embedding=False,
        )

        self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, device=device))
        self.n_storage_tokens = n_storage_tokens
        if self.n_storage_tokens > 0:
            self.storage_tokens = nn.Parameter(torch.empty(1, n_storage_tokens, embed_dim, device=device))
        self.rope_embed = RopePositionEmbedding(
            embed_dim=embed_dim,
            num_heads=num_heads,
            base=pos_embed_rope_base,
            min_period=pos_embed_rope_min_period,
            max_period=pos_embed_rope_max_period,
            normalize_coords=pos_embed_rope_normalize_coords,
            shift_coords=pos_embed_rope_shift_coords,
            jitter_coords=pos_embed_rope_jitter_coords,
            rescale_coords=pos_embed_rope_rescale_coords,
            dtype=dtype_dict[pos_embed_rope_dtype],
            device=device,
        )
        ffn_layer_cls = ffn_layer_dict[ffn_layer]
        ffn_ratio_sequence = [ffn_ratio] * depth
        blocks_list = [
            SelfAttentionBlock(
                dim=embed_dim,
                num_heads=num_heads,
                ffn_ratio=ffn_ratio_sequence[i],
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=drop_path_rate,
                norm_layer=norm_layer_cls,
                act_layer=nn.GELU,
                ffn_layer=ffn_layer_cls,
                init_values=layerscale_init,
                mask_k_bias=mask_k_bias,
                device=device,
            )
            for i in range(depth)
        ]

        self.chunked_blocks = False
        self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer_cls(embed_dim)

        # Optional separate norms for the cls/storage tokens (see forward_features_list).
        self.untie_cls_and_patch_norms = untie_cls_and_patch_norms
        self.cls_norm = norm_layer_cls(embed_dim) if untie_cls_and_patch_norms else None

        self.untie_global_and_local_cls_norm = untie_global_and_local_cls_norm
        self.local_cls_norm = norm_layer_cls(embed_dim) if untie_global_and_local_cls_norm else None

        self.head = nn.Identity()
        self.mask_token = nn.Parameter(torch.empty(1, embed_dim, device=device))

    def init_weights(self):
        """Initialize RoPE periods, cls/storage/mask tokens, then every submodule via ``init_weights_vit``."""
        self.rope_embed._init_weights()
        nn.init.normal_(self.cls_token, std=0.02)
        if self.n_storage_tokens > 0:
            nn.init.normal_(self.storage_tokens, std=0.02)
        nn.init.zeros_(self.mask_token)
        named_apply(init_weights_vit, self)

    def prepare_tokens_with_masks(self, x: Tensor, masks=None) -> Tuple[Tensor, Tuple[int, int]]:
        """Patchify ``x``, optionally replace masked patches, and prepend cls + storage tokens.

        Returns the token sequence and the patch-grid size ``(H, W)``.
        """
        x = self.patch_embed(x)
        B, H, W, _ = x.shape
        x = x.flatten(1, 2)

        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
            cls_token = self.cls_token
        else:
            # Adding 0 * mask_token leaves the value unchanged. Presumably this keeps
            # mask_token in the autograd graph when no mask is given — not verified here.
            cls_token = self.cls_token + 0 * self.mask_token
        if self.n_storage_tokens > 0:
            storage_tokens = self.storage_tokens
        else:
            storage_tokens = torch.empty(
                1,
                0,
                cls_token.shape[-1],
                dtype=cls_token.dtype,
                device=cls_token.device,
            )

        x = torch.cat(
            [cls_token.expand(B, -1, -1), storage_tokens.expand(B, -1, -1), x],
            dim=1,
        )

        return x, (H, W)

    def forward_features_list(self, x_list: List[Tensor], masks_list: List[Tensor]) -> List[Dict[str, Tensor]]:
        """Run the backbone on several image batches and return one feature dict per batch."""
        x = []
        rope = []
        for t_x, t_masks in zip(x_list, masks_list):
            t2_x, hw_tuple = self.prepare_tokens_with_masks(t_x, t_masks)
            x.append(t2_x)
            rope.append(hw_tuple)
        for blk in self.blocks:
            # RoPE is recomputed for every block. In training mode with coordinate
            # augmentation enabled, each block therefore draws its own random factors.
            if self.rope_embed is not None:
                rope_sincos = [self.rope_embed(H=H, W=W) for H, W in rope]
            else:
                rope_sincos = [None for _ in rope]
            x = blk(x, rope_sincos)
        all_x = x
        output = []
        for idx, (x, masks) in enumerate(zip(all_x, masks_list)):
            if self.untie_cls_and_patch_norms or self.untie_global_and_local_cls_norm:
                if self.untie_global_and_local_cls_norm and self.training and idx == 1:
                    x_norm_cls_reg = self.local_cls_norm(x[:, : self.n_storage_tokens + 1])
                elif self.untie_cls_and_patch_norms:
                    x_norm_cls_reg = self.cls_norm(x[:, : self.n_storage_tokens + 1])
                else:
                    x_norm_cls_reg = self.norm(x[:, : self.n_storage_tokens + 1])
                x_norm_patch = self.norm(x[:, self.n_storage_tokens + 1 :])
            else:
                x_norm = self.norm(x)
                x_norm_cls_reg = x_norm[:, : self.n_storage_tokens + 1]
                x_norm_patch = x_norm[:, self.n_storage_tokens + 1 :]
            output.append({
                "x_norm_clstoken": x_norm_cls_reg[:, 0],
                "x_storage_tokens": x_norm_cls_reg[:, 1:],
                "x_norm_patchtokens": x_norm_patch,
                "x_prenorm": x,
                "masks": masks,
            })
        return output

    def forward_features(self, x, masks: Optional[Tensor] = None):
        """Single tensor in -> single feature dict out; list in -> list of dicts out."""
        if isinstance(x, torch.Tensor):
            return self.forward_features_list([x], [masks])[0]
        return self.forward_features_list(x, masks)

    def forward(self, *args, is_training: bool = False, **kwargs):
        """Return the feature dict(s) if ``is_training``, else ``head(cls token)``.

        NOTE(review): with a list input and ``is_training=False``, ``ret`` is a
        list and the string indexing below fails.
        """
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        return self.head(ret["x_norm_clstoken"])


def build_eupe_vitb16() -> DinoVisionTransformer:
    """Construct the EUPE ViT-B/16 backbone (768-dim, 12 blocks, 12 heads, 4 storage tokens).

    Weights are not initialized here; they are expected to come from a
    checkpoint. ``pos_embed_rope_rescale_coords=2`` only affects training mode.
    """
    return DinoVisionTransformer(
        img_size=224,
        patch_size=16,
        in_chans=3,
        pos_embed_rope_base=100,
        pos_embed_rope_normalize_coords="separate",
        pos_embed_rope_rescale_coords=2,
        pos_embed_rope_dtype="fp32",
        embed_dim=768,
        depth=12,
        num_heads=12,
        ffn_ratio=4,
        qkv_bias=True,
        drop_path_rate=0.0,
        layerscale_init=1.0e-05,
        norm_layer="layernormbf16",
        ffn_layer="mlp",
        ffn_bias=True,
        proj_bias=True,
        n_storage_tokens=4,
        mask_k_bias=True,
    )
=========================================================================== # Argus task heads # =========================================================================== def make_eupe_transform(resize_size: int): return v2.Compose([ v2.ToImage(), v2.Resize((resize_size, resize_size), antialias=True), v2.ToDtype(torch.float32, scale=True), v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), ]) def _normalize_image_input(image_or_images) -> Tuple[bool, list]: """Returns (was_single, [images]). Accepts a PIL.Image or an iterable of them.""" if isinstance(image_or_images, Image.Image): return True, [image_or_images] images = list(image_or_images) if not images: raise ValueError("empty image list") for i, img in enumerate(images): if not isinstance(img, Image.Image): raise TypeError(f"images[{i}] is {type(img).__name__}, expected PIL.Image") return False, images class _BackboneExportWrapper(nn.Module): """ONNX-friendly wrapper: returns (cls, spatial) instead of a dict.""" def __init__(self, backbone: nn.Module): super().__init__() self.backbone = backbone def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: out = self.backbone.forward_features(x) cls = out["x_norm_clstoken"] patches = out["x_norm_patchtokens"] B, N, D = patches.shape h = w = int(N ** 0.5) spatial = patches.permute(0, 2, 1).reshape(B, D, h, w) return cls, spatial class SegmentationHead(nn.Module): def __init__(self, in_dim: int = 768, num_classes: int = 150): super().__init__() self.batchnorm_layer = nn.BatchNorm2d(in_dim) self.conv = nn.Conv2d(in_dim, num_classes, kernel_size=1) def forward(self, x: Tensor) -> Tensor: return self.conv(self.batchnorm_layer(x)) class DepthHead(nn.Module): def __init__(self, in_dim: int = 768, n_bins: int = 256, min_depth: float = 0.001, max_depth: float = 10.0): super().__init__() self.batchnorm_layer = nn.BatchNorm2d(in_dim) self.conv_depth = nn.Conv2d(in_dim, n_bins, kernel_size=1) self.min_depth = min_depth self.max_depth = max_depth self.n_bins = 
n_bins def forward(self, x: Tensor) -> Tensor: logits = self.conv_depth(self.batchnorm_layer(x)) logit = torch.relu(logits) + 0.1 logit = logit / logit.sum(dim=1, keepdim=True) bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=x.device) return torch.einsum("bkhw,k->bhw", logit, bins).unsqueeze(1) # =========================================================================== # Detection (FCOS with ViTDet-style simple feature pyramid) # =========================================================================== FPN_STRIDES = [8, 16, 32, 64, 128] COCO_CLASSES = [ "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush", ] class SimpleFeaturePyramid(nn.Module): """ViTDet-style simple FPN: a single stride-16 ViT feature map -> P3..P7.""" def __init__(self, in_channels: int = 768, fpn_channels: int = 256): super().__init__() self.fpn_channels = fpn_channels self.p3 = nn.Sequential( nn.ConvTranspose2d(in_channels, in_channels, 2, stride=2), nn.GroupNorm(32, in_channels), nn.GELU(), nn.Conv2d(in_channels, fpn_channels, 1), nn.GroupNorm(32, fpn_channels), nn.Conv2d(fpn_channels, fpn_channels, 3, padding=1), nn.GroupNorm(32, fpn_channels), ) self.p4 = 
nn.Sequential( nn.Conv2d(in_channels, fpn_channels, 1), nn.GroupNorm(32, fpn_channels), nn.Conv2d(fpn_channels, fpn_channels, 3, padding=1), nn.GroupNorm(32, fpn_channels), ) self.p5 = nn.Sequential( nn.Conv2d(in_channels, in_channels, 3, stride=2, padding=1), nn.GroupNorm(32, in_channels), nn.GELU(), nn.Conv2d(in_channels, fpn_channels, 1), nn.GroupNorm(32, fpn_channels), nn.Conv2d(fpn_channels, fpn_channels, 3, padding=1), nn.GroupNorm(32, fpn_channels), ) self.p6 = nn.Sequential( nn.Conv2d(fpn_channels, fpn_channels, 3, stride=2, padding=1), nn.GroupNorm(32, fpn_channels), ) self.p7 = nn.Sequential( nn.Conv2d(fpn_channels, fpn_channels, 3, stride=2, padding=1), nn.GroupNorm(32, fpn_channels), ) def forward(self, x: Tensor) -> List[Tensor]: p3 = self.p3(x) p4 = self.p4(x) p5 = self.p5(x) p6 = self.p6(p5) p7 = self.p7(p6) return [p3, p4, p5, p6, p7] class FCOSHead(nn.Module): """Shared classification / box regression / centerness towers across pyramid levels.""" def __init__(self, fpn_channels: int = 256, num_classes: int = 80, num_convs: int = 4): super().__init__() self.num_classes = num_classes cls_tower, reg_tower = [], [] for _ in range(num_convs): cls_tower += [ nn.Conv2d(fpn_channels, fpn_channels, 3, padding=1), nn.GroupNorm(32, fpn_channels), nn.GELU(), ] reg_tower += [ nn.Conv2d(fpn_channels, fpn_channels, 3, padding=1), nn.GroupNorm(32, fpn_channels), nn.GELU(), ] self.cls_tower = nn.Sequential(*cls_tower) self.reg_tower = nn.Sequential(*reg_tower) self.cls_pred = nn.Conv2d(fpn_channels, num_classes, 3, padding=1) self.reg_pred = nn.Conv2d(fpn_channels, 4, 3, padding=1) self.center_pred = nn.Conv2d(fpn_channels, 1, 3, padding=1) self.scales = nn.Parameter(torch.ones(len(FPN_STRIDES))) prior = 0.01 nn.init.constant_(self.cls_pred.bias, -math.log((1 - prior) / prior)) nn.init.zeros_(self.reg_pred.weight) nn.init.zeros_(self.reg_pred.bias) nn.init.zeros_(self.center_pred.weight) nn.init.zeros_(self.center_pred.bias) def forward(self, fpn_features: 
List[Tensor]) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]: cls_logits, box_regs, centernesses = [], [], [] for level_idx, feat in enumerate(fpn_features): cls_feat = self.cls_tower(feat) reg_feat = self.reg_tower(feat) cls_logits.append(self.cls_pred(cls_feat)) reg_raw = self.reg_pred(reg_feat) * self.scales[level_idx] reg_raw = reg_raw.clamp(min=-10.0, max=10.0) box_regs.append(torch.exp(reg_raw)) centernesses.append(self.center_pred(reg_feat)) return cls_logits, box_regs, centernesses class DetectionHead(nn.Module): """Combined SFP + FCOS head.""" def __init__(self, in_channels: int = 768, fpn_channels: int = 256, num_classes: int = 80, num_convs: int = 4): super().__init__() self.fpn = SimpleFeaturePyramid(in_channels=in_channels, fpn_channels=fpn_channels) self.head = FCOSHead(fpn_channels=fpn_channels, num_classes=num_classes, num_convs=num_convs) self.num_classes = num_classes def forward(self, spatial_features: Tensor) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]: fpn = self.fpn(spatial_features) return self.head(fpn) def _make_locations(feature_sizes: List[Tuple[int, int]], strides: List[int], device) -> List[Tensor]: """Per-level center coordinates of feature-map locations in image space.""" all_locs = [] for (h, w), s in zip(feature_sizes, strides): ys = (torch.arange(h, device=device, dtype=torch.float32) + 0.5) * s xs = (torch.arange(w, device=device, dtype=torch.float32) + 0.5) * s grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij") locs = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1) all_locs.append(locs) return all_locs @torch.inference_mode() def _decode_detections( cls_logits_per_level: List[Tensor], box_regs_per_level: List[Tensor], centernesses_per_level: List[Tensor], locations_per_level: List[Tensor], image_sizes: List[Tuple[int, int]], score_thresh: float = 0.05, nms_thresh: float = 0.5, max_per_level: int = 1000, max_per_image: int = 100, ) -> List[Dict[str, Tensor]]: """Convert per-level logits/regs/centerness 
into per-image detections (xyxy boxes).""" B = cls_logits_per_level[0].shape[0] num_classes = cls_logits_per_level[0].shape[1] device = cls_logits_per_level[0].device per_image_results = [] for image_idx in range(B): all_boxes, all_scores, all_labels = [], [], [] for cls_l, reg_l, ctr_l, locs_l in zip( cls_logits_per_level, box_regs_per_level, centernesses_per_level, locations_per_level ): cls = cls_l[image_idx].permute(1, 2, 0).reshape(-1, num_classes) reg = reg_l[image_idx].permute(1, 2, 0).reshape(-1, 4) ctr = ctr_l[image_idx].permute(1, 2, 0).reshape(-1) cls_prob = torch.sigmoid(cls) ctr_prob = torch.sigmoid(ctr) scores = cls_prob * ctr_prob[:, None] mask = scores > score_thresh if not mask.any(): continue cand_loc, cand_cls = mask.nonzero(as_tuple=True) cand_scores = scores[cand_loc, cand_cls] if cand_scores.numel() > max_per_level: top = cand_scores.topk(max_per_level) cand_scores = top.values idx = top.indices cand_loc = cand_loc[idx] cand_cls = cand_cls[idx] cand_locs_xy = locs_l[cand_loc] cand_reg = reg[cand_loc] boxes = torch.stack([ cand_locs_xy[:, 0] - cand_reg[:, 0], cand_locs_xy[:, 1] - cand_reg[:, 1], cand_locs_xy[:, 0] + cand_reg[:, 2], cand_locs_xy[:, 1] + cand_reg[:, 3], ], dim=-1) all_boxes.append(boxes) all_scores.append(cand_scores) all_labels.append(cand_cls) if all_boxes: boxes = torch.cat(all_boxes, dim=0) scores = torch.cat(all_scores, dim=0) labels = torch.cat(all_labels, dim=0) H, W = image_sizes[image_idx] boxes[:, 0::2] = boxes[:, 0::2].clamp(0, W) boxes[:, 1::2] = boxes[:, 1::2].clamp(0, H) keep_all = [] for c in labels.unique(): cm = labels == c keep = nms(boxes[cm], scores[cm], nms_thresh) keep_idx = cm.nonzero(as_tuple=True)[0][keep] keep_all.append(keep_idx) keep_all = torch.cat(keep_all, dim=0) boxes = boxes[keep_all] scores = scores[keep_all] labels = labels[keep_all] if scores.numel() > max_per_image: top = scores.topk(max_per_image) boxes = boxes[top.indices] scores = top.values labels = labels[top.indices] else: boxes = 
torch.zeros((0, 4), device=device) scores = torch.zeros((0,), device=device) labels = torch.zeros((0,), dtype=torch.long, device=device) per_image_results.append({"boxes": boxes, "scores": scores, "labels": labels}) return per_image_results def _letterbox_to_square(image: Image.Image, resolution: int) -> Tuple[Image.Image, float, Tuple[int, int]]: """Resize preserving aspect ratio and pad bottom/right with black. Matches the training transform.""" W0, H0 = image.size scale = resolution / max(H0, W0) new_w = int(round(W0 * scale)) new_h = int(round(H0 * scale)) resized = image.resize((new_w, new_h), Image.BILINEAR) canvas = Image.new("RGB", (resolution, resolution), (0, 0, 0)) canvas.paste(resized, (0, 0)) return canvas, scale, (W0, H0) # =========================================================================== # Argus model (transformers-compatible) # =========================================================================== class ArgusConfig(PretrainedConfig): model_type = "argus" def __init__( self, embed_dim: int = 768, patch_size: int = 16, num_seg_classes: int = 150, depth_n_bins: int = 256, depth_min_depth: float = 0.001, depth_max_depth: float = 10.0, num_imagenet_classes: int = 1000, class_ids: Optional[list] = None, class_names: Optional[list] = None, detection_num_classes: int = 80, detection_fpn_channels: int = 256, detection_num_convs: int = 4, detection_class_names: Optional[list] = None, **kwargs, ): super().__init__(**kwargs) self.embed_dim = embed_dim self.patch_size = patch_size self.num_seg_classes = num_seg_classes self.depth_n_bins = depth_n_bins self.depth_min_depth = depth_min_depth self.depth_max_depth = depth_max_depth self.num_imagenet_classes = num_imagenet_classes self.class_ids = class_ids or [] self.class_names = class_names or [] self.detection_num_classes = detection_num_classes self.detection_fpn_channels = detection_fpn_channels self.detection_num_convs = detection_num_convs self.detection_class_names = detection_class_names or 
list(COCO_CLASSES) class Argus(PreTrainedModel): config_class = ArgusConfig base_model_prefix = "argus" supports_gradient_checkpointing = False _tied_weights_keys: list = [] all_tied_weights_keys: dict = {} def __init__(self, config: ArgusConfig): super().__init__(config) self.backbone = build_eupe_vitb16() self.seg_head = SegmentationHead(config.embed_dim, config.num_seg_classes) self.depth_head = DepthHead( in_dim=config.embed_dim, n_bins=config.depth_n_bins, min_depth=config.depth_min_depth, max_depth=config.depth_max_depth, ) self.register_buffer( "class_prototypes", torch.zeros(config.num_imagenet_classes, config.embed_dim), persistent=True, ) self.register_buffer( "class_logit_weight", torch.zeros(config.num_imagenet_classes, config.embed_dim), persistent=True, ) self.register_buffer( "class_logit_bias", torch.zeros(config.num_imagenet_classes), persistent=True, ) self.detection_head = DetectionHead( in_channels=config.embed_dim, fpn_channels=config.detection_fpn_channels, num_classes=config.detection_num_classes, num_convs=config.detection_num_convs, ) for p in self.backbone.parameters(): p.requires_grad = False self.backbone.eval() self.seg_head.eval() self.depth_head.eval() self.detection_head.eval() def _init_weights(self, module): # HF reallocates missing buffers and parameters with torch.empty() # (uninitialized memory) on from_pretrained. Populate sensible defaults # for the standard layer types used by the detection head, and zero any # Argus-level buffer that came back NaN. 
if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)): nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.GroupNorm): nn.init.ones_(module.weight) nn.init.zeros_(module.bias) if module is self: for name in ("class_prototypes", "class_logit_weight", "class_logit_bias"): if hasattr(self, name): buf = getattr(self, name) if torch.isnan(buf).any() or torch.isinf(buf).any(): buf.data.zero_() @property def class_ids(self): return self.config.class_ids @property def class_names(self): return self.config.class_names @torch.inference_mode() def _extract(self, image_tensor: Tensor) -> Tuple[Tensor, Tensor]: with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"): out = self.backbone.forward_features(image_tensor) cls = out["x_norm_clstoken"].float() patches = out["x_norm_patchtokens"].float() B, N, D = patches.shape h = w = int(N ** 0.5) spatial = patches.permute(0, 2, 1).reshape(B, D, h, w) return cls, spatial @torch.inference_mode() def classify( self, image_or_images, top_k: int = 5, method: Literal["knn", "softmax"] = "knn", ): single, images = _normalize_image_input(image_or_images) transform = make_eupe_transform(224) batch = torch.stack([transform(img) for img in images]).to(self.device) cls, _ = self._extract(batch) cls = F.normalize(cls, dim=-1) if method == "knn": proto = self.class_prototypes.to(cls.dtype) scores_full = cls @ proto.T # cosine similarity in [-1, 1] elif method == "softmax": w = self.class_logit_weight.to(cls.dtype) b = self.class_logit_bias.to(cls.dtype) logits = F.linear(cls, w, b) scores_full = F.softmax(logits, dim=-1) # in [0, 1] else: raise ValueError(f"unknown classification method: {method!r} (expected 'knn' or 'softmax')") topk = scores_full.topk(top_k, dim=-1) top2 = scores_full.topk(2, dim=-1) margins = (top2.values[:, 0] - top2.values[:, 1]).tolist() results = [] for b in range(len(images)): 
entries = [] for score, idx in zip(topk.values[b].tolist(), topk.indices[b].tolist()): entries.append({ "class_id": self.class_ids[idx], "class_name": self.class_names[idx], "score": float(score), }) entries[0]["margin"] = float(margins[b]) results.append(entries) return results[0] if single else results @torch.inference_mode() def segment(self, image_or_images, resolution: int = 512, return_confidence: bool = False): single, images = _normalize_image_input(image_or_images) transform = make_eupe_transform(resolution) batch = torch.stack([transform(img) for img in images]).to(self.device) _, spatial = self._extract(batch) with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"): logits = self.seg_head(spatial) logits = F.interpolate(logits, size=(resolution, resolution), mode="bilinear", align_corners=False) seg_maps = logits.argmax(dim=1) # [B, H, W] if return_confidence: probs = F.softmax(logits.float(), dim=1) conf_maps = probs.max(dim=1).values # [B, H, W] in [0, 1] if single: return seg_maps[0], conf_maps[0] return [(seg_maps[i], conf_maps[i]) for i in range(len(images))] if single: return seg_maps[0] return [seg_maps[i] for i in range(len(images))] @torch.inference_mode() def depth(self, image_or_images, resolution: int = 416, return_confidence: bool = False): single, images = _normalize_image_input(image_or_images) transform = make_eupe_transform(resolution) batch = torch.stack([transform(img) for img in images]).to(self.device) _, spatial = self._extract(batch) with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"): if return_confidence: # inline depth head to expose the bin distribution normed = self.depth_head.batchnorm_layer(spatial) bin_logits = self.depth_head.conv_depth(normed) distribution = torch.relu(bin_logits) + 0.1 distribution = distribution / distribution.sum(dim=1, keepdim=True) bins = torch.linspace( self.depth_head.min_depth, self.depth_head.max_depth, 
self.depth_head.n_bins, device=spatial.device, ) depth_b = torch.einsum("bkhw,k->bhw", distribution, bins).unsqueeze(1) mean_sq = torch.einsum("bkhw,k->bhw", distribution, bins ** 2) variance = (mean_sq - depth_b.squeeze(1) ** 2).clamp(min=0) std_b = torch.sqrt(variance).unsqueeze(1) else: depth_b = self.depth_head(spatial) std_b = None depth_b = F.interpolate(depth_b, size=(resolution, resolution), mode="bilinear", align_corners=False) if std_b is not None: std_b = F.interpolate(std_b, size=(resolution, resolution), mode="bilinear", align_corners=False) depth_squeezed = depth_b[:, 0].float() # [B, H, W] if return_confidence: std_squeezed = std_b[:, 0].float() # [B, H, W] if single: return depth_squeezed[0], std_squeezed[0] return [(depth_squeezed[i], std_squeezed[i]) for i in range(len(images))] if single: return depth_squeezed[0] return [depth_squeezed[i] for i in range(len(images))] @torch.inference_mode() def correspond( self, src_image: Image.Image, tgt_image: Image.Image, src_keypoints: list, resolution: int = 512, ): sw, sh = src_image.size tw, th = tgt_image.size transform = make_eupe_transform(resolution) src_t = transform(src_image).unsqueeze(0).to(self.device) tgt_t = transform(tgt_image).unsqueeze(0).to(self.device) _, src_feats = self._extract(src_t) _, tgt_feats = self._extract(tgt_t) src_feats = F.interpolate(src_feats, size=(resolution, resolution), mode="bilinear", align_corners=False) tgt_feats = F.interpolate(tgt_feats, size=(resolution, resolution), mode="bilinear", align_corners=False) src_feats = F.normalize(src_feats[0].permute(1, 2, 0), dim=-1) tgt_feats = F.normalize(tgt_feats[0].permute(1, 2, 0), dim=-1) preds = [] for kp in src_keypoints: sx = min(max(int(kp[0] / sw * resolution), 0), resolution - 1) sy = min(max(int(kp[1] / sh * resolution), 0), resolution - 1) src_vec = src_feats[sy, sx] sim_map = torch.einsum("d,hwd->hw", src_vec, tgt_feats) flat = sim_map.argmax().item() py, px = flat // resolution, flat % resolution preds.append([px 
/ resolution * tw, py / resolution * th]) return preds @torch.inference_mode() def detect( self, image_or_images, resolution: int = 640, score_thresh: float = 0.05, nms_thresh: float = 0.5, max_per_image: int = 100, ): single, images = _normalize_image_input(image_or_images) # Letterbox each image to match the training transform (resize long side # to `resolution`, pad bottom/right with black). Box coordinates are # recovered after decoding by unscaling. canvases, scales, orig_sizes = [], [], [] for img in images: canvas, scale, orig = _letterbox_to_square(img, resolution) canvases.append(canvas) scales.append(scale) orig_sizes.append(orig) det_normalize = v2.Compose([ v2.ToImage(), v2.ToDtype(torch.float32, scale=True), v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), ]) batch = torch.stack([det_normalize(c) for c in canvases]).to(self.device) _, spatial = self._extract(batch) with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"): cls_logits, box_regs, centernesses = self.detection_head(spatial) cls_logits = [c.float() for c in cls_logits] box_regs = [b.float() for b in box_regs] centernesses = [c.float() for c in centernesses] feature_sizes = [(cl.shape[2], cl.shape[3]) for cl in cls_logits] locations = _make_locations(feature_sizes, FPN_STRIDES, spatial.device) image_sizes = [(resolution, resolution)] * len(images) results = _decode_detections( cls_logits, box_regs, centernesses, locations, image_sizes=image_sizes, score_thresh=score_thresh, nms_thresh=nms_thresh, max_per_image=max_per_image, ) class_names = self.config.detection_class_names formatted = [] for i, r in enumerate(results): scale = scales[i] orig_w, orig_h = orig_sizes[i] boxes = r["boxes"].cpu().numpy() / scale boxes[:, 0::2] = boxes[:, 0::2].clip(0, orig_w) boxes[:, 1::2] = boxes[:, 1::2].clip(0, orig_h) detections = [] for box, score, label in zip( boxes, r["scores"].cpu().numpy(), r["labels"].cpu().numpy() ): detections.append({ 
"box": [float(v) for v in box.tolist()], "score": float(score), "label": int(label), "class_name": class_names[int(label)] if int(label) < len(class_names) else f"class_{int(label)}", }) formatted.append(detections) return formatted[0] if single else formatted def perceive(self, image_or_images, return_confidence: bool = False): single, images = _normalize_image_input(image_or_images) t0 = time.time() classif = self.classify(images, top_k=5) t1 = time.time() seg_out = self.segment(images, resolution=512, return_confidence=return_confidence) t2 = time.time() depth_out = self.depth(images, resolution=416, return_confidence=return_confidence) t3 = time.time() if return_confidence: seg_maps = [s for s, _ in seg_out] seg_confs = [c for _, c in seg_out] depth_maps = [d for d, _ in depth_out] depth_uncerts = [u for _, u in depth_out] else: seg_maps = seg_out depth_maps = depth_out seg_confs = depth_uncerts = None timings = { "classify": (t1 - t0) * 1000, "segment": (t2 - t1) * 1000, "depth": (t3 - t2) * 1000, "total": (t3 - t0) * 1000, } results = [] for i in range(len(images)): entry = { "classification": classif[i], "segmentation": seg_maps[i].cpu().numpy(), "depth": depth_maps[i].cpu().numpy(), "timings_ms": timings, } if return_confidence: entry["segmentation_confidence"] = seg_confs[i].cpu().numpy() entry["depth_uncertainty"] = depth_uncerts[i].cpu().numpy() results.append(entry) return results[0] if single else results def export_onnx( self, out_dir: str, backbone_resolution: int = 224, dynamic_batch: bool = True, verify: bool = True, tolerance: float = 5e-2, opset_version: int = 17, ) -> dict: """Export backbone, seg head, and depth head to ONNX. 
kNN classification and correspondence run on top of the backbone output and need no separate graph.""" import os os.makedirs(out_dir, exist_ok=True) if backbone_resolution % self.config.patch_size != 0: raise ValueError( f"backbone_resolution ({backbone_resolution}) must be a multiple of patch_size ({self.config.patch_size})" ) spatial_resolution = backbone_resolution // self.config.patch_size wrapper = _BackboneExportWrapper(self.backbone).to(self.device).eval() dummy_image = torch.randn( 1, 3, backbone_resolution, backbone_resolution, device=self.device, dtype=torch.float32, ) dummy_spatial = torch.randn( 1, self.config.embed_dim, spatial_resolution, spatial_resolution, device=self.device, dtype=torch.float32, ) backbone_path = os.path.join(out_dir, "argus_backbone.onnx") seg_path = os.path.join(out_dir, "argus_seg_head.onnx") depth_path = os.path.join(out_dir, "argus_depth_head.onnx") backbone_axes = None head_axes = None if dynamic_batch: backbone_axes = { "image": {0: "batch"}, "cls_token": {0: "batch"}, "spatial_features": {0: "batch"}, } head_axes = { "spatial_features": {0: "batch"}, "seg_logits": {0: "batch"}, "depth_map": {0: "batch"}, } # dynamo path crashes on EUPE's list-based forward; use legacy. 
with torch.inference_mode(): torch.onnx.export( wrapper, dummy_image, backbone_path, input_names=["image"], output_names=["cls_token", "spatial_features"], dynamic_axes=backbone_axes, opset_version=opset_version, do_constant_folding=True, dynamo=False, ) torch.onnx.export( self.seg_head, dummy_spatial, seg_path, input_names=["spatial_features"], output_names=["seg_logits"], dynamic_axes={"spatial_features": head_axes["spatial_features"], "seg_logits": head_axes["seg_logits"]} if head_axes else None, opset_version=opset_version, do_constant_folding=True, dynamo=False, ) torch.onnx.export( self.depth_head, dummy_spatial, depth_path, input_names=["spatial_features"], output_names=["depth_map"], dynamic_axes={"spatial_features": head_axes["spatial_features"], "depth_map": head_axes["depth_map"]} if head_axes else None, opset_version=opset_version, do_constant_folding=True, dynamo=False, ) result = { "backbone": backbone_path, "seg_head": seg_path, "depth_head": depth_path, } if verify: try: import onnxruntime as ort except ImportError as e: raise ImportError("onnxruntime not installed; pip install onnxruntime") from e providers = ["CPUExecutionProvider"] verify_image = torch.randn(2, 3, backbone_resolution, backbone_resolution, dtype=torch.float32) verify_spatial = torch.randn(2, self.config.embed_dim, spatial_resolution, spatial_resolution, dtype=torch.float32) with torch.inference_mode(): ref_cls, ref_spatial = wrapper(verify_image.to(self.device)) ref_seg = self.seg_head(verify_spatial.to(self.device)) ref_depth = self.depth_head(verify_spatial.to(self.device)) sess = ort.InferenceSession(backbone_path, providers=providers) ort_cls, ort_spatial = sess.run(None, {"image": verify_image.numpy()}) cls_diff = float(np.abs(ort_cls - ref_cls.cpu().numpy()).max()) spatial_diff = float(np.abs(ort_spatial - ref_spatial.cpu().numpy()).max()) sess = ort.InferenceSession(seg_path, providers=providers) ort_seg = sess.run(None, {"spatial_features": verify_spatial.numpy()})[0] 
seg_diff = float(np.abs(ort_seg - ref_seg.cpu().numpy()).max()) sess = ort.InferenceSession(depth_path, providers=providers) ort_depth = sess.run(None, {"spatial_features": verify_spatial.numpy()})[0] depth_diff = float(np.abs(ort_depth - ref_depth.cpu().numpy()).max()) verification = { "backbone_cls_max_diff": cls_diff, "backbone_spatial_max_diff": spatial_diff, "seg_head_max_diff": seg_diff, "depth_head_max_diff": depth_diff, "tolerance": tolerance, "verified_batch_size": 2, } for key, val in verification.items(): if key.endswith("_max_diff") and val > tolerance: raise RuntimeError( f"ONNX/PyTorch divergence in {key}: {val:.2e} > tolerance {tolerance:.2e}" ) result["verification"] = verification return result