| """ |
| Argus: multi-task perception on a single EUPE-ViT-B backbone. |
| |
| from transformers import AutoModel |
| model = AutoModel.from_pretrained("phanerozoic/argus", trust_remote_code=True) |
| result = model.perceive(image) |
| |
| The EUPE-ViT-B backbone architecture, all supporting layers, and the Argus |
| task heads are inlined below. The backbone code is reproduced from |
| facebookresearch/EUPE (Meta FAIR) under the FAIR Research License. |
| """ |
|
|
| import math |
| import time |
| from functools import partial |
| from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import torch.nn.init |
| from PIL import Image |
| from torch import Tensor, nn |
| from torchvision.ops import nms |
| from torchvision.transforms import v2 |
| from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
|
| |
| |
| |
|
|
| |
|
|
def cat_keep_shapes(x_list: List[Tensor]) -> Tuple[Tensor, List[Tuple[int]], List[int]]:
    """Stack several (..., D) tensors into one flat (total_tokens, D) tensor.

    Every tensor is viewed as a flat list of D-dimensional tokens. Also returns
    the original shapes and per-tensor token counts so the concatenation can be
    undone later (see uncat_with_shapes).
    """
    shapes = [t.shape for t in x_list]
    token_counts = [t.shape[:-1].numel() for t in x_list]
    stacked = torch.cat([t.flatten(0, -2) for t in x_list])
    return stacked, shapes, token_counts
|
|
|
|
def uncat_with_shapes(flattened: Tensor, shapes: List[Tuple[int]], num_tokens: List[int]) -> List[Tensor]:
    """Inverse of cat_keep_shapes: split a (total_tokens, D) tensor back into
    tensors with the original leading dims and the current trailing dim D.

    The trailing dim is taken from `flattened` (not `shapes`) so this also
    works after a projection changed the feature width.
    """
    chunks = torch.split_with_sizes(flattened, num_tokens, dim=0)
    feature_dim = flattened.shape[-1]
    return [chunk.reshape(shape[:-1] + (feature_dim,)) for chunk, shape in zip(chunks, shapes)]
|
|
|
|
def named_apply(
    fn: Callable,
    module: nn.Module,
    name: str = "",
    depth_first: bool = True,
    include_root: bool = False,
) -> nn.Module:
    """Recursively call ``fn(module=..., name=...)`` over a module tree.

    Names are dotted paths relative to the root. ``depth_first`` controls
    whether a node is visited before (False) or after (True) its children;
    children are always visited (include_root only gates the current node).
    Returns the root module for chaining.
    """
    if include_root and not depth_first:
        fn(module=module, name=name)
    for child_name, child in module.named_children():
        qualified = f"{name}.{child_name}" if name else child_name
        named_apply(
            fn=fn,
            module=child,
            name=qualified,
            depth_first=depth_first,
            include_root=True,
        )
    if include_root and depth_first:
        fn(module=module, name=name)
    return module
|
|
|
|
| |
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learnable per-channel gain.

    Normalizes over the last dim by 1/sqrt(mean(x^2) + eps); the reduction is
    done in fp32 and cast back to the input dtype before applying the gain.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def reset_parameters(self) -> None:
        nn.init.constant_(self.weight, 1)

    def _norm(self, x: Tensor) -> Tensor:
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x: Tensor) -> Tensor:
        normed = self._norm(x.float()).type_as(x)
        return self.weight * normed
|
|
|
|
| |
|
|
class LayerScale(nn.Module):
    """Learnable per-channel scaling of a residual branch (gamma).

    gamma is allocated empty; reset_parameters() fills it with init_values.
    With inplace=True the input tensor is scaled in place and returned.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
        device=None,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(torch.empty(dim, device=device))
        self.init_values = init_values

    def reset_parameters(self):
        nn.init.constant_(self.gamma, self.init_values)

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
|
|
|
| |
|
|
def make_2tuple(x):
    """Coerce an int to a symmetric 2-tuple; pass length-2 tuples through.

    Anything else (or a tuple of the wrong length) fails an assertion.
    """
    if isinstance(x, int):
        return (x, x)
    assert isinstance(x, tuple) and len(x) == 2
    return x
|
|
|
|
class PatchEmbed(nn.Module):
    """Image-to-patch embedding via a strided convolution.

    Args:
        img_size: nominal input size (int or (H, W)); only used for the
            `patches_resolution` / `num_patches` bookkeeping — forward()
            accepts any size the conv can consume.
        patch_size: patch side length (int or (H, W)); also the conv stride.
        in_chans: number of input channels.
        embed_dim: per-patch embedding dimension.
        norm_layer: optional norm constructor applied to the embeddings.
        flatten_embedding: if True, forward returns (B, N, D); otherwise
            (B, H', W', D) with the patch grid kept spatial.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()
        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Embed a (B, C, H, W) batch into patch tokens.

        Returns (B, N, D) when flatten_embedding, else (B, H', W', D).
        """
        # Fixed: the original unpacked H, W from x.shape here and immediately
        # overwrote them with the post-conv sizes — dead code removed.
        x = self.proj(x)  # (B, D, H', W')
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # (B, N, D)
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)
        return x

    def reset_parameters(self):
        """Uniform init matching nn.Linear's default fan-in bound.

        NOTE(review): fan-in uses patch_size[0]**2, which assumes square
        patches — confirm if non-square patches are ever configured.
        """
        k = 1 / (self.in_chans * (self.patch_size[0] ** 2))
        nn.init.uniform_(self.proj.weight, -math.sqrt(k), math.sqrt(k))
        if self.proj.bias is not None:
            nn.init.uniform_(self.proj.bias, -math.sqrt(k), math.sqrt(k))
|
|
|
|
| |
|
|
class RopePositionEmbedding(nn.Module):
    """2-D rotary position embedding (RoPE) for an (H, W) patch grid.

    forward() returns a (sin, cos) pair, each of shape (H*W, D_head), to be
    consumed by the attention layers' rope application. Per head, D_head // 4
    frequencies ("periods") are shared by the two spatial axes; coordinates
    are mapped into [-1, 1].

    Exactly one frequency parameterization must be given:
      * `base` — classic geometric progression base**(2i / (D_head // 2)), or
      * `min_period` + `max_period` — log-spaced periods in that range.

    Optional train-time coordinate augmentations (shift / jitter / rescale)
    are applied only when `self.training` and the corresponding option is set.
    """

    def __init__(
        self,
        embed_dim: int,
        *,
        num_heads: int,
        base: Optional[float] = 100.0,
        min_period: Optional[float] = None,
        max_period: Optional[float] = None,
        normalize_coords: Literal["min", "max", "separate"] = "separate",
        shift_coords: Optional[float] = None,
        jitter_coords: Optional[float] = None,
        rescale_coords: Optional[float] = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        # D_head must be divisible by 4: two spatial axes x the duplicated
        # sin/cos half-layout used by rope rotation.
        assert embed_dim % (4 * num_heads) == 0
        both_periods = min_period is not None and max_period is not None
        # The two parameterizations are mutually exclusive.
        if (base is None and not both_periods) or (base is not None and both_periods):
            raise ValueError("Either `base` or `min_period`+`max_period` must be provided.")

        D_head = embed_dim // num_heads
        self.base = base
        self.min_period = min_period
        self.max_period = max_period
        self.D_head = D_head
        self.normalize_coords = normalize_coords
        self.shift_coords = shift_coords
        self.jitter_coords = jitter_coords
        self.rescale_coords = rescale_coords

        self.dtype = dtype
        # Persistent buffer so the computed periods travel with checkpoints.
        self.register_buffer(
            "periods",
            torch.empty(D_head // 4, device=device, dtype=dtype),
            persistent=True,
        )
        self._init_weights()

    def forward(self, *, H: int, W: int) -> Tuple[Tensor, Tensor]:
        """Build (sin, cos), each (H*W, D_head), for an H x W patch grid."""
        device = self.periods.device
        dtype = self.dtype
        dd = {"device": device, "dtype": dtype}

        # Patch-center coordinates (offset 0.5), normalized to [0, 1] by the
        # configured rule, per axis.
        if self.normalize_coords == "max":
            max_HW = max(H, W)
            coords_h = torch.arange(0.5, H, **dd) / max_HW
            coords_w = torch.arange(0.5, W, **dd) / max_HW
        elif self.normalize_coords == "min":
            min_HW = min(H, W)
            coords_h = torch.arange(0.5, H, **dd) / min_HW
            coords_w = torch.arange(0.5, W, **dd) / min_HW
        elif self.normalize_coords == "separate":
            # Each axis normalized by its own extent (aspect ratio discarded).
            coords_h = torch.arange(0.5, H, **dd) / H
            coords_w = torch.arange(0.5, W, **dd) / W
        else:
            raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
        coords = coords.flatten(0, 1)  # (H*W, 2)
        coords = 2.0 * coords - 1.0  # [0, 1] -> [-1, 1]

        # Train-only augmentation: global additive shift per axis.
        if self.training and self.shift_coords is not None:
            shift_hw = torch.empty(2, **dd).uniform_(-self.shift_coords, self.shift_coords)
            coords += shift_hw[None, :]

        # Train-only augmentation: per-axis log-uniform scale in
        # [1/jitter_coords, jitter_coords].
        if self.training and self.jitter_coords is not None:
            jitter_max = np.log(self.jitter_coords)
            jitter_min = -jitter_max
            jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp()
            coords *= jitter_hw[None, :]

        # Train-only augmentation: one log-uniform scale shared by both axes.
        if self.training and self.rescale_coords is not None:
            rescale_max = np.log(self.rescale_coords)
            rescale_min = -rescale_max
            rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp()
            coords *= rescale_hw

        # (HW, 2, D_head//4) angles -> flatten to (HW, D_head//2), then tile(2)
        # to duplicate into the two-half layout expected by rope rotation.
        angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :]
        angles = angles.flatten(1, 2)
        angles = angles.tile(2)
        cos = torch.cos(angles)
        sin = torch.sin(angles)
        return (sin, cos)

    def _init_weights(self):
        """Fill the `periods` buffer from `base` or the explicit period range."""
        device = self.periods.device
        dtype = self.dtype
        if self.base is not None:
            periods = self.base ** (
                2 * torch.arange(self.D_head // 4, device=device, dtype=dtype) / (self.D_head // 2)
            )
        else:
            # Log-spaced periods covering [min_period, max_period].
            base = self.max_period / self.min_period
            exponents = torch.linspace(0, 1, self.D_head // 4, device=device, dtype=dtype)
            periods = base ** exponents  # range [1, max/min]
            periods = periods / base  # range [min/max, 1]
            periods = periods * self.max_period  # range [min_period, max_period]
        self.periods.data = periods
|
|
|
|
| |
|
|
class ListForwardMixin(object):
    """Mixin adding forward_list: run forward() over several tensors at once.

    The tensors are flattened into a single token matrix so the subclass's
    forward() executes as one call, then the results are split back into the
    original shapes.
    """

    def forward(self, x: Tensor):
        raise NotImplementedError

    def forward_list(self, x_list: List[Tensor]) -> List[Tensor]:
        flat, shapes, counts = cat_keep_shapes(x_list)
        out = self.forward(flat)
        return uncat_with_shapes(out, shapes, counts)
|
|
|
|
class Mlp(nn.Module, ListForwardMixin):
    """Standard transformer MLP: fc1 -> activation -> dropout -> fc2 -> dropout.

    hidden_features and out_features default to in_features when omitted.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
        device=None,
    ) -> None:
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, device=device)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, device=device)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
|
|
|
class SwiGLUFFN(nn.Module, ListForwardMixin):
    """SwiGLU feed-forward: w3(silu(w1(x)) * w2(x)).

    The effective hidden size is 2/3 of `hidden_features` (to keep parameter
    count comparable to a plain MLP), rounded up to a multiple of `align_to`.
    `act_layer` and `drop` are accepted for signature compatibility but unused.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Optional[Callable[..., nn.Module]] = None,
        drop: float = 0.0,
        bias: bool = True,
        align_to: int = 8,
        device=None,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        base_hidden = int(hidden_features * 2 / 3)
        # Round up to the next multiple of align_to for GEMM-friendly sizes.
        aligned_hidden = base_hidden + (-base_hidden % align_to)
        self.w1 = nn.Linear(in_features, aligned_hidden, bias=bias, device=device)
        self.w2 = nn.Linear(in_features, aligned_hidden, bias=bias, device=device)
        self.w3 = nn.Linear(aligned_hidden, out_features, bias=bias, device=device)

    def forward(self, x: Tensor) -> Tensor:
        gate = F.silu(self.w1(x))
        return self.w3(gate * self.w2(x))
|
|
|
|
| |
|
|
def rope_rotate_half(x: Tensor) -> Tensor:
    """Rotate the two-half layout: (x1, x2) -> (-x2, x1) along the last dim."""
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat((-second, first), dim=-1)




def rope_apply(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor:
    """Apply a rotary embedding: x*cos + rotate_half(x)*sin (broadcasting)."""
    return x * cos + rope_rotate_half(x) * sin
|
|
|
|
class LinearKMaskedBias(nn.Linear):
    """Linear layer whose bias is elementwise-multiplied by a persistent mask.

    Used for fused QKV projections (out_features must be divisible by 3) where
    the key-slice bias is zeroed via the mask. The mask buffer starts as NaN
    so a forgotten initialization is loud rather than silent.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.out_features % 3 == 0
        if self.bias is not None:
            self.register_buffer("bias_mask", torch.full_like(self.bias, fill_value=math.nan))

    def forward(self, input: Tensor) -> Tensor:
        if self.bias is None:
            return F.linear(input, self.weight, None)
        return F.linear(input, self.weight, self.bias * self.bias_mask.to(self.bias.dtype))
|
|
|
|
class SelfAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection and optional RoPE.

    Attention itself runs through torch SDPA. NOTE(review): `self.scale` and
    `self.attn_drop` are stored but never passed to
    scaled_dot_product_attention — SDPA's default scale (head_dim**-0.5)
    matches `self.scale`, but attention-probability dropout is effectively
    inactive here; confirm this is intended.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        mask_k_bias: bool = False,
        device=None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        # mask_k_bias swaps in a Linear whose key-slice bias can be masked out.
        linear_class = LinearKMaskedBias if mask_k_bias else nn.Linear
        self.qkv = linear_class(dim, dim * 3, bias=qkv_bias, device=device)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias, device=device)
        self.proj_drop = nn.Dropout(proj_drop)

    def apply_rope(self, q: Tensor, k: Tensor, rope) -> Tuple[Tensor, Tensor]:
        """Rotate q/k by (sin, cos); leading position-less tokens pass through.

        `rope` covers only the trailing tokens (the patch tokens); the first
        N - rope_len tokens (CLS + storage tokens) are concatenated back
        unrotated. Math runs in the rope dtype, results are cast back to the
        original q/k dtypes.
        """
        q_dtype = q.dtype
        k_dtype = k.dtype
        sin, cos = rope
        rope_dtype = sin.dtype
        q = q.to(dtype=rope_dtype)
        k = k.to(dtype=rope_dtype)
        N = q.shape[-2]
        prefix = N - sin.shape[-2]  # number of tokens without grid positions
        assert prefix >= 0
        q_prefix = q[:, :, :prefix, :]
        q = rope_apply(q[:, :, prefix:, :], sin, cos)
        q = torch.cat((q_prefix, q), dim=-2)
        k_prefix = k[:, :, :prefix, :]
        k = rope_apply(k[:, :, prefix:, :], sin, cos)
        k = torch.cat((k_prefix, k), dim=-2)
        q = q.to(dtype=q_dtype)
        k = k.to(dtype=k_dtype)
        return q, k

    def forward(self, x: Tensor, attn_bias=None, rope=None) -> Tensor:
        """Attention over a single (B, N, C) token tensor."""
        qkv = self.qkv(x)
        attn_v = self.compute_attention(qkv=qkv, attn_bias=attn_bias, rope=rope)
        x = self.proj(attn_v)
        x = self.proj_drop(x)
        return x

    def forward_list(self, x_list, attn_bias=None, rope_list=None) -> List[Tensor]:
        """Attention over a list of token tensors (e.g. crops of mixed sizes).

        The QKV and output projections are batched across the whole list via
        one flattened matmul each; only the attention itself runs per-tensor.
        NOTE: proj_drop is not applied on this path (unlike forward()).
        """
        assert len(x_list) == len(rope_list)
        x_flat, shapes, num_tokens = cat_keep_shapes(x_list)
        qkv_flat = self.qkv(x_flat)
        qkv_list = uncat_with_shapes(qkv_flat, shapes, num_tokens)
        att_out = []
        for _, (qkv, _, rope) in enumerate(zip(qkv_list, shapes, rope_list)):
            att_out.append(self.compute_attention(qkv, attn_bias=attn_bias, rope=rope))
        x_flat, shapes, num_tokens = cat_keep_shapes(att_out)
        x_flat = self.proj(x_flat)
        return uncat_with_shapes(x_flat, shapes, num_tokens)

    def compute_attention(self, qkv: Tensor, attn_bias=None, rope=None) -> Tensor:
        """Split fused QKV, optionally apply RoPE, run SDPA; returns (B, N, C)."""
        assert attn_bias is None  # accepted for API symmetry, not supported
        B, N, _ = qkv.shape
        C = self.qkv.in_features
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = torch.unbind(qkv, 2)
        q, k, v = [t.transpose(1, 2) for t in [q, k, v]]  # (B, heads, N, head_dim)
        if rope is not None:
            q, k = self.apply_rope(q, k, rope)
        x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        x = x.transpose(1, 2)
        return x.reshape([B, N, C])
|
|
|
|
| |
|
|
class SelfAttentionBlock(nn.Module):
    """Pre-norm transformer block:
    x + ls1(attn(norm1(x))), then x + ls2(mlp(norm2(x))).

    `drop_path` implements stochastic depth by running each residual branch on
    a random per-batch subset of samples and adding the result back with a
    compensating scale (see _forward_list), instead of zeroing branch outputs.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        ffn_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = SelfAttention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        mask_k_bias: bool = False,
        device=None,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            mask_k_bias=mask_k_bias,
            device=device,
        )
        # LayerScale only when init_values is truthy; otherwise identity.
        self.ls1 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * ffn_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
            device=device,
        )
        self.ls2 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity()
        # Fraction of each batch skipped per residual branch during training.
        self.sample_drop_ratio = drop_path

    @staticmethod
    def _maybe_index_rope(rope, indices: Tensor):
        # Index per-sample rope tensors when they carry a batch dim (ndim == 4);
        # batch-independent rope (shared (sin, cos)) is returned untouched.
        if rope is None:
            return None
        sin, cos = rope
        assert sin.ndim == cos.ndim
        if sin.ndim == 4:
            return sin[indices], cos[indices]
        return sin, cos

    def _forward_list(self, x_list: List[Tensor], rope_list=None) -> List[Tensor]:
        """Apply the block to a list of (B_i, N_i, C) token tensors.

        With training-time stochastic depth, each residual branch processes a
        random subset of size ceil-ish b*(1-ratio) (min 1) of each batch, and
        torch.index_add writes the branch output back scaled by b/subset — an
        unbiased estimate of the full-batch residual.
        """
        b_list = [x.shape[0] for x in x_list]
        sample_subset_sizes = [max(int(b * (1 - self.sample_drop_ratio)), 1) for b in b_list]

        if self.training and self.sample_drop_ratio > 0.0:
            residual_scale_factors = [b / s for b, s in zip(b_list, sample_subset_sizes)]
            # Independent random subset per tensor for the attention branch.
            indices_1_list = [
                torch.randperm(b, device=x.device)[:s]
                for x, b, s in zip(x_list, b_list, sample_subset_sizes)
            ]
            x_subset_1_list = [x[i] for x, i in zip(x_list, indices_1_list)]
            if rope_list is not None:
                rope_subset_list = [
                    self._maybe_index_rope(r, i) for r, i in zip(rope_list, indices_1_list)
                ]
            else:
                rope_subset_list = rope_list

            # norm1 is batched across the list; attention runs per-tensor.
            flattened, shapes, num_tokens = cat_keep_shapes(x_subset_1_list)
            norm1 = uncat_with_shapes(self.norm1(flattened), shapes, num_tokens)
            residual_1_list = self.attn.forward_list(norm1, rope_list=rope_subset_list)

            # Scatter the scaled attention residual back onto the full batch.
            x_attn_list = [
                torch.index_add(x, dim=0, source=self.ls1(r1), index=i1, alpha=rsf)
                for x, r1, i1, rsf in zip(x_list, residual_1_list, indices_1_list, residual_scale_factors)
            ]

            # Fresh, independent subsets for the FFN branch.
            indices_2_list = [
                torch.randperm(b, device=x.device)[:s]
                for x, b, s in zip(x_list, b_list, sample_subset_sizes)
            ]
            x_subset_2_list = [x[i] for x, i in zip(x_attn_list, indices_2_list)]
            flattened, shapes, num_tokens = cat_keep_shapes(x_subset_2_list)
            norm2_list = uncat_with_shapes(self.norm2(flattened), shapes, num_tokens)
            residual_2_list = self.mlp.forward_list(norm2_list)

            x_ffn = [
                torch.index_add(xa, dim=0, source=self.ls2(r2), index=i2, alpha=rsf)
                for xa, r2, i2, rsf in zip(x_attn_list, residual_2_list, indices_2_list, residual_scale_factors)
            ]
        else:
            # Eval / no stochastic depth: plain residual block per tensor.
            x_out = []
            for x, rope in zip(x_list, rope_list):
                x_attn = x + self.ls1(self.attn(self.norm1(x), rope=rope))
                x_ffn = x_attn + self.ls2(self.mlp(self.norm2(x_attn)))
                x_out.append(x_ffn)
            x_ffn = x_out
        return x_ffn

    def forward(self, x_or_x_list, rope_or_rope_list=None) -> List[Tensor]:
        """Accepts a single (B, N, C) tensor or a list; return type matches."""
        if isinstance(x_or_x_list, Tensor):
            return self._forward_list([x_or_x_list], rope_list=[rope_or_rope_list])[0]
        elif isinstance(x_or_x_list, list):
            if rope_or_rope_list is None:
                rope_or_rope_list = [None for _ in x_or_x_list]
            return self._forward_list(x_or_x_list, rope_list=rope_or_rope_list)
        raise AssertionError
|
|
|
|
| |
|
|
# FFN implementations selectable by config string; the swiglu{N} variants pad
# the SwiGLU hidden width up to a multiple of N (the `align_to` argument).
ffn_layer_dict = {
    "mlp": Mlp,
    "swiglu": SwiGLUFFN,
    "swiglu32": partial(SwiGLUFFN, align_to=32),
    "swiglu64": partial(SwiGLUFFN, align_to=64),
    "swiglu128": partial(SwiGLUFFN, align_to=128),
}


# Normalization layers selectable by config string.
# NOTE(review): "layernormbf16" differs only in eps (1e-5 vs 1e-6); the name
# suggests it is intended for bf16 runs — confirm.
norm_layer_dict = {
    "layernorm": partial(nn.LayerNorm, eps=1e-6),
    "layernormbf16": partial(nn.LayerNorm, eps=1e-5),
    "rmsnorm": RMSNorm,
}


# Short dtype names used by config fields (e.g. pos_embed_rope_dtype).
dtype_dict = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}
|
|
|
|
def init_weights_vit(module: nn.Module, name: str = ""):
    """Per-module ViT init, intended for use with named_apply.

    Linear layers get trunc-normal weights (std 0.02) and zero biases; a
    LinearKMaskedBias-style `bias_mask` is set to 1 everywhere except the
    middle third (the key slice of a fused QKV bias), which is zeroed.
    Norm and scale modules are reset via their own reset_parameters().
    """
    if isinstance(module, nn.Linear):
        torch.nn.init.trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        mask = getattr(module, "bias_mask", None)
        if mask is not None:
            out = module.out_features
            mask.fill_(1)
            mask[out // 3 : 2 * out // 3].fill_(0)
    if isinstance(module, (nn.LayerNorm, LayerScale, PatchEmbed, RMSNorm)):
        module.reset_parameters()
|
|
|
|
class DinoVisionTransformer(nn.Module):
    """EUPE/DINO-family Vision Transformer backbone.

    Token layout is [CLS, storage tokens, patch tokens]; positions are
    injected at every attention layer via 2-D RoPE (there is no absolute
    position embedding). `forward_features` returns a dict with normalized
    CLS / storage / patch tokens plus the pre-norm activations; lists of
    inputs (multi-crop) yield a list of such dicts.
    """

    def __init__(
        self,
        *,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        pos_embed_rope_base: float = 100.0,
        pos_embed_rope_min_period: Optional[float] = None,
        pos_embed_rope_max_period: Optional[float] = None,
        pos_embed_rope_normalize_coords: Literal["min", "max", "separate"] = "separate",
        pos_embed_rope_shift_coords: Optional[float] = None,
        pos_embed_rope_jitter_coords: Optional[float] = None,
        pos_embed_rope_rescale_coords: Optional[float] = None,
        pos_embed_rope_dtype: str = "bf16",
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        ffn_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_path_rate: float = 0.0,
        layerscale_init: Optional[float] = None,
        norm_layer: str = "layernorm",
        ffn_layer: str = "mlp",
        ffn_bias: bool = True,
        proj_bias: bool = True,
        n_storage_tokens: int = 0,
        mask_k_bias: bool = False,
        untie_cls_and_patch_norms: bool = False,
        untie_global_and_local_cls_norm: bool = False,
        device: Any = None,
        **ignored_kwargs,
    ):
        super().__init__()
        # Unknown config keys are accepted and discarded (config compatibility).
        del ignored_kwargs

        norm_layer_cls = norm_layer_dict[norm_layer]

        self.num_features = self.embed_dim = embed_dim
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size

        # flatten_embedding=False keeps (B, H', W', D) so the patch-grid size
        # can be recovered for the rope in prepare_tokens_with_masks.
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            flatten_embedding=False,
        )

        self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, device=device))
        self.n_storage_tokens = n_storage_tokens
        if self.n_storage_tokens > 0:
            self.storage_tokens = nn.Parameter(torch.empty(1, n_storage_tokens, embed_dim, device=device))

        # (sin, cos) are rebuilt per forward pass from the actual grid size.
        self.rope_embed = RopePositionEmbedding(
            embed_dim=embed_dim,
            num_heads=num_heads,
            base=pos_embed_rope_base,
            min_period=pos_embed_rope_min_period,
            max_period=pos_embed_rope_max_period,
            normalize_coords=pos_embed_rope_normalize_coords,
            shift_coords=pos_embed_rope_shift_coords,
            jitter_coords=pos_embed_rope_jitter_coords,
            rescale_coords=pos_embed_rope_rescale_coords,
            dtype=dtype_dict[pos_embed_rope_dtype],
            device=device,
        )

        ffn_layer_cls = ffn_layer_dict[ffn_layer]
        # Uniform FFN ratio at every depth (kept as a sequence for flexibility).
        ffn_ratio_sequence = [ffn_ratio] * depth
        blocks_list = [
            SelfAttentionBlock(
                dim=embed_dim,
                num_heads=num_heads,
                ffn_ratio=ffn_ratio_sequence[i],
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=drop_path_rate,
                norm_layer=norm_layer_cls,
                act_layer=nn.GELU,
                ffn_layer=ffn_layer_cls,
                init_values=layerscale_init,
                mask_k_bias=mask_k_bias,
                device=device,
            )
            for i in range(depth)
        ]

        self.chunked_blocks = False
        self.blocks = nn.ModuleList(blocks_list)
        self.norm = norm_layer_cls(embed_dim)

        # Optional separate final norm for the CLS+storage tokens...
        self.untie_cls_and_patch_norms = untie_cls_and_patch_norms
        self.cls_norm = norm_layer_cls(embed_dim) if untie_cls_and_patch_norms else None

        # ...and another one used for local-crop CLS tokens during training.
        self.untie_global_and_local_cls_norm = untie_global_and_local_cls_norm
        self.local_cls_norm = norm_layer_cls(embed_dim) if untie_global_and_local_cls_norm else None

        self.head = nn.Identity()
        # Learned token substituted for masked patches (masked-image modeling).
        self.mask_token = nn.Parameter(torch.empty(1, embed_dim, device=device))

    def init_weights(self):
        """Initialize rope periods, CLS/storage/mask tokens, and all submodules."""
        self.rope_embed._init_weights()
        nn.init.normal_(self.cls_token, std=0.02)
        if self.n_storage_tokens > 0:
            nn.init.normal_(self.storage_tokens, std=0.02)
        nn.init.zeros_(self.mask_token)
        named_apply(init_weights_vit, self)

    def prepare_tokens_with_masks(self, x: Tensor, masks=None) -> Tuple[Tensor, Tuple[int, int]]:
        """Patchify and prepend CLS + storage tokens.

        Returns the (B, 1 + n_storage + H*W, D) token sequence and the patch
        grid size (H, W) needed for the rope.
        """
        x = self.patch_embed(x)  # (B, H, W, D) since flatten_embedding=False
        B, H, W, _ = x.shape
        x = x.flatten(1, 2)

        if masks is not None:
            # Replace masked patch positions with the learned mask token.
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
            cls_token = self.cls_token
        else:
            # Adding 0 * mask_token keeps mask_token in the autograd graph even
            # when no masks are given — presumably for DDP unused-parameter
            # accounting; confirm before simplifying.
            cls_token = self.cls_token + 0 * self.mask_token

        if self.n_storage_tokens > 0:
            storage_tokens = self.storage_tokens
        else:
            # Zero-width placeholder so the cat below is unconditional.
            storage_tokens = torch.empty(
                1, 0, cls_token.shape[-1],
                dtype=cls_token.dtype, device=cls_token.device,
            )

        x = torch.cat(
            [cls_token.expand(B, -1, -1), storage_tokens.expand(B, -1, -1), x],
            dim=1,
        )
        return x, (H, W)

    def forward_features_list(self, x_list: List[Tensor], masks_list: List[Tensor]) -> List[Dict[str, Tensor]]:
        """Run the backbone over a list of image batches (e.g. multi-crop)."""
        x = []
        rope = []
        for t_x, t_masks in zip(x_list, masks_list):
            t2_x, hw_tuple = self.prepare_tokens_with_masks(t_x, t_masks)
            x.append(t2_x)
            rope.append(hw_tuple)
        for blk in self.blocks:
            # (sin, cos) recomputed per block: in training the rope coordinate
            # augmentations draw fresh randomness for every block.
            # NOTE(review): in eval mode this recomputes identical tensors each
            # block — a caching opportunity.
            if self.rope_embed is not None:
                rope_sincos = [self.rope_embed(H=H, W=W) for H, W in rope]
            else:
                rope_sincos = [None for _ in rope]
            x = blk(x, rope_sincos)
        all_x = x
        output = []
        for idx, (x, masks) in enumerate(zip(all_x, masks_list)):
            if self.untie_cls_and_patch_norms or self.untie_global_and_local_cls_norm:
                # NOTE(review): idx == 1 is treated as the local-crop entry of
                # the multi-crop list — confirm against the training pipeline.
                if self.untie_global_and_local_cls_norm and self.training and idx == 1:
                    x_norm_cls_reg = self.local_cls_norm(x[:, : self.n_storage_tokens + 1])
                elif self.untie_cls_and_patch_norms:
                    x_norm_cls_reg = self.cls_norm(x[:, : self.n_storage_tokens + 1])
                else:
                    x_norm_cls_reg = self.norm(x[:, : self.n_storage_tokens + 1])
                x_norm_patch = self.norm(x[:, self.n_storage_tokens + 1 :])
            else:
                # Single shared final norm, then split CLS+storage vs patches.
                x_norm = self.norm(x)
                x_norm_cls_reg = x_norm[:, : self.n_storage_tokens + 1]
                x_norm_patch = x_norm[:, self.n_storage_tokens + 1 :]
            output.append({
                "x_norm_clstoken": x_norm_cls_reg[:, 0],
                "x_storage_tokens": x_norm_cls_reg[:, 1:],
                "x_norm_patchtokens": x_norm_patch,
                "x_prenorm": x,
                "masks": masks,
            })
        return output

    def forward_features(self, x, masks: Optional[Tensor] = None):
        """Single tensor in -> single dict out; list in -> list of dicts."""
        if isinstance(x, torch.Tensor):
            return self.forward_features_list([x], [masks])[0]
        return self.forward_features_list(x, masks)

    def forward(self, *args, is_training: bool = False, **kwargs):
        """Return the full feature dict when is_training, else head(CLS).

        NOTE(review): the non-training path indexes `ret` as a dict, so it
        only supports a single-tensor input (not the list form) — confirm.
        """
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        return self.head(ret["x_norm_clstoken"])
|
|
|
|
def build_eupe_vitb16() -> DinoVisionTransformer:
    """Instantiate the EUPE ViT-B/16 backbone configuration.

    ViT-B dims (768 / 12 layers / 12 heads), fp32 rope with train-time coord
    rescaling, 4 storage tokens, masked key bias, and bf16-flavored LayerNorm.
    """
    config = dict(
        img_size=224,
        patch_size=16,
        in_chans=3,
        pos_embed_rope_base=100,
        pos_embed_rope_normalize_coords="separate",
        pos_embed_rope_rescale_coords=2,
        pos_embed_rope_dtype="fp32",
        embed_dim=768,
        depth=12,
        num_heads=12,
        ffn_ratio=4,
        qkv_bias=True,
        drop_path_rate=0.0,
        layerscale_init=1.0e-05,
        norm_layer="layernormbf16",
        ffn_layer="mlp",
        ffn_bias=True,
        proj_bias=True,
        n_storage_tokens=4,
        mask_k_bias=True,
    )
    return DinoVisionTransformer(**config)
|
|
|
|
| |
| |
| |
|
|
def make_eupe_transform(resize_size: int):
    """EUPE preprocessing: resize to a square, scale to [0, 1], then apply
    ImageNet mean/std normalization."""
    steps = [
        v2.ToImage(),
        v2.Resize((resize_size, resize_size), antialias=True),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
    return v2.Compose(steps)
|
|
|
|
def _normalize_image_input(image_or_images) -> Tuple[bool, list]:
    """Returns (was_single, [images]). Accepts a PIL.Image or an iterable of them.

    Raises ValueError on an empty iterable and TypeError when any element is
    not a PIL.Image.
    """
    if isinstance(image_or_images, Image.Image):
        return True, [image_or_images]
    images = list(image_or_images)
    if not images:
        raise ValueError("empty image list")
    for i, img in enumerate(images):
        if isinstance(img, Image.Image):
            continue
        raise TypeError(f"images[{i}] is {type(img).__name__}, expected PIL.Image")
    return False, images
|
|
|
|
class _BackboneExportWrapper(nn.Module):
    """ONNX-friendly wrapper: returns (cls, spatial) instead of a dict."""

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        features = self.backbone.forward_features(x)
        cls_token = features["x_norm_clstoken"]
        patch_tokens = features["x_norm_patchtokens"]
        batch, n_tokens, dim = patch_tokens.shape
        # NOTE(review): assumes a square patch grid (n_tokens is a perfect
        # square) — confirm before exporting with non-square inputs.
        side = int(n_tokens ** 0.5)
        spatial = patch_tokens.transpose(1, 2).reshape(batch, dim, side, side)
        return cls_token, spatial
|
|
|
|
class SegmentationHead(nn.Module):
    """Linear segmentation probe: BatchNorm over the spatial feature map,
    then a 1x1 conv to per-class logits (same spatial size as the input)."""

    def __init__(self, in_dim: int = 768, num_classes: int = 150):
        super().__init__()
        self.batchnorm_layer = nn.BatchNorm2d(in_dim)
        self.conv = nn.Conv2d(in_dim, num_classes, kernel_size=1)

    def forward(self, x: Tensor) -> Tensor:
        normalized = self.batchnorm_layer(x)
        return self.conv(normalized)
|
|
|
|
class DepthHead(nn.Module):
    """Depth probe: per-pixel distribution over uniformly spaced depth bins,
    reduced to a scalar depth via the expectation over bin centers.

    Output is (B, 1, H, W) with values inside [min_depth, max_depth].
    """

    def __init__(self, in_dim: int = 768, n_bins: int = 256,
                 min_depth: float = 0.001, max_depth: float = 10.0):
        super().__init__()
        self.batchnorm_layer = nn.BatchNorm2d(in_dim)
        self.conv_depth = nn.Conv2d(in_dim, n_bins, kernel_size=1)
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.n_bins = n_bins

    def forward(self, x: Tensor) -> Tensor:
        raw = self.conv_depth(self.batchnorm_layer(x))
        # ReLU + 0.1 keeps every bin weight strictly positive before the
        # normalization, so the division below cannot hit zero.
        weights = torch.relu(raw) + 0.1
        probs = weights / weights.sum(dim=1, keepdim=True)
        centers = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=x.device)
        depth = torch.einsum("bkhw,k->bhw", probs, centers)
        return depth.unsqueeze(1)
|
|
|
|
| |
| |
| |
|
|
# Stride (input pixels per feature cell) of each pyramid level P3..P7, in order.
FPN_STRIDES = [8, 16, 32, 64, 128]


# The 80 COCO detection class names, index-aligned with the classifier output.
COCO_CLASSES = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
    "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
    "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra",
    "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
    "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
    "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
    "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
    "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
    "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier",
    "toothbrush",
]
|
|
|
|
| class SimpleFeaturePyramid(nn.Module): |
| """ViTDet-style simple FPN: a single stride-16 ViT feature map -> P3..P7.""" |
|
|
| def __init__(self, in_channels: int = 768, fpn_channels: int = 256): |
| super().__init__() |
| self.fpn_channels = fpn_channels |
| self.p3 = nn.Sequential( |
| nn.ConvTranspose2d(in_channels, in_channels, 2, stride=2), |
| nn.GroupNorm(32, in_channels), |
| nn.GELU(), |
| nn.Conv2d(in_channels, fpn_channels, 1), |
| nn.GroupNorm(32, fpn_channels), |
| nn.Conv2d(fpn_channels, fpn_channels, 3, padding=1), |
| nn.GroupNorm(32, fpn_channels), |
| ) |
| self.p4 = nn.Sequential( |
| nn.Conv2d(in_channels, fpn_channels, 1), |
| nn.GroupNorm(32, fpn_channels), |
| nn.Conv2d(fpn_channels, fpn_channels, 3, padding=1), |
| nn.GroupNorm(32, fpn_channels), |
| ) |
| self.p5 = nn.Sequential( |
| nn.Conv2d(in_channels, in_channels, 3, stride=2, padding=1), |
| nn.GroupNorm(32, in_channels), |
| nn.GELU(), |
| nn.Conv2d(in_channels, fpn_channels, 1), |
| nn.GroupNorm(32, fpn_channels), |
| nn.Conv2d(fpn_channels, fpn_channels, 3, padding=1), |
| nn.GroupNorm(32, fpn_channels), |
| ) |
| self.p6 = nn.Sequential( |
| nn.Conv2d(fpn_channels, fpn_channels, 3, stride=2, padding=1), |
| nn.GroupNorm(32, fpn_channels), |
| ) |
| self.p7 = nn.Sequential( |
| nn.Conv2d(fpn_channels, fpn_channels, 3, stride=2, padding=1), |
| nn.GroupNorm(32, fpn_channels), |
| ) |
|
|
| def forward(self, x: Tensor) -> List[Tensor]: |
| p3 = self.p3(x) |
| p4 = self.p4(x) |
| p5 = self.p5(x) |
| p6 = self.p6(p5) |
| p7 = self.p7(p6) |
| return [p3, p4, p5, p6, p7] |
|
|
|
|
class FCOSHead(nn.Module):
    """Shared classification / box regression / centerness towers across pyramid levels."""

    def __init__(self, fpn_channels: int = 256, num_classes: int = 80, num_convs: int = 4):
        super().__init__()
        self.num_classes = num_classes

        def build_tower():
            # num_convs repeats of conv3x3 -> GroupNorm -> GELU.
            layers = []
            for _ in range(num_convs):
                layers.append(nn.Conv2d(fpn_channels, fpn_channels, 3, padding=1))
                layers.append(nn.GroupNorm(32, fpn_channels))
                layers.append(nn.GELU())
            return layers

        self.cls_tower = nn.Sequential(*build_tower())
        self.reg_tower = nn.Sequential(*build_tower())

        self.cls_pred = nn.Conv2d(fpn_channels, num_classes, 3, padding=1)
        self.reg_pred = nn.Conv2d(fpn_channels, 4, 3, padding=1)
        self.center_pred = nn.Conv2d(fpn_channels, 1, 3, padding=1)

        # One learnable scalar per pyramid level, applied to the box regression.
        self.scales = nn.Parameter(torch.ones(len(FPN_STRIDES)))

        # Focal-loss-style bias prior: initial classification probability ~0.01.
        prior = 0.01
        nn.init.constant_(self.cls_pred.bias, -math.log((1 - prior) / prior))
        nn.init.zeros_(self.reg_pred.weight)
        nn.init.zeros_(self.reg_pred.bias)
        nn.init.zeros_(self.center_pred.weight)
        nn.init.zeros_(self.center_pred.bias)

    def forward(self, fpn_features: List[Tensor]) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
        cls_logits: List[Tensor] = []
        box_regs: List[Tensor] = []
        centernesses: List[Tensor] = []
        for level_idx, feat in enumerate(fpn_features):
            cls_logits.append(self.cls_pred(self.cls_tower(feat)))
            reg_feat = self.reg_tower(feat)
            # Scale, clamp for numerical safety, then exponentiate so the
            # predicted distances are strictly positive.
            scaled = (self.reg_pred(reg_feat) * self.scales[level_idx]).clamp(min=-10.0, max=10.0)
            box_regs.append(scaled.exp())
            centernesses.append(self.center_pred(reg_feat))
        return cls_logits, box_regs, centernesses
|
|
|
|
class DetectionHead(nn.Module):
    """Combined SFP + FCOS head."""

    def __init__(self, in_channels: int = 768, fpn_channels: int = 256, num_classes: int = 80, num_convs: int = 4):
        """Wire a SimpleFeaturePyramid into a shared FCOSHead."""
        super().__init__()
        self.fpn = SimpleFeaturePyramid(in_channels=in_channels, fpn_channels=fpn_channels)
        self.head = FCOSHead(fpn_channels=fpn_channels, num_classes=num_classes, num_convs=num_convs)
        self.num_classes = num_classes

    def forward(self, spatial_features: Tensor) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
        """Build the pyramid, then run the FCOS towers on every level."""
        return self.head(self.fpn(spatial_features))
|
|
|
|
| def _make_locations(feature_sizes: List[Tuple[int, int]], strides: List[int], device) -> List[Tensor]: |
| """Per-level center coordinates of feature-map locations in image space.""" |
| all_locs = [] |
| for (h, w), s in zip(feature_sizes, strides): |
| ys = (torch.arange(h, device=device, dtype=torch.float32) + 0.5) * s |
| xs = (torch.arange(w, device=device, dtype=torch.float32) + 0.5) * s |
| grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij") |
| locs = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1) |
| all_locs.append(locs) |
| return all_locs |
|
|
|
|
@torch.inference_mode()
def _decode_detections(
    cls_logits_per_level: List[Tensor],
    box_regs_per_level: List[Tensor],
    centernesses_per_level: List[Tensor],
    locations_per_level: List[Tensor],
    image_sizes: List[Tuple[int, int]],
    score_thresh: float = 0.05,
    nms_thresh: float = 0.5,
    max_per_level: int = 1000,
    max_per_image: int = 100,
) -> List[Dict[str, Tensor]]:
    """Convert per-level logits/regs/centerness into per-image detections (xyxy boxes).

    Args:
        cls_logits_per_level: per-level (B, C, H, W) class logits.
        box_regs_per_level: per-level (B, 4, H, W) positive (l, t, r, b) distances.
        centernesses_per_level: per-level (B, 1, H, W) centerness logits.
        locations_per_level: per-level (H*W, 2) pixel-space (x, y) cell centers.
        image_sizes: (H, W) per image, used to clip decoded boxes.
        score_thresh: minimum combined cls*centerness score for a candidate.
        nms_thresh: IoU threshold for per-class NMS.
        max_per_level: cap on candidates kept per pyramid level.
        max_per_image: cap on final detections per image.

    Returns:
        One dict per image with "boxes" (N, 4), "scores" (N,), "labels" (N,).
    """
    B = cls_logits_per_level[0].shape[0]
    num_classes = cls_logits_per_level[0].shape[1]
    device = cls_logits_per_level[0].device

    per_image_results = []
    for image_idx in range(B):
        all_boxes, all_scores, all_labels = [], [], []
        for cls_l, reg_l, ctr_l, locs_l in zip(
            cls_logits_per_level, box_regs_per_level, centernesses_per_level, locations_per_level
        ):
            # Flatten spatial dims: (C, H, W) -> (H*W, C) so rows align with locs_l.
            cls = cls_l[image_idx].permute(1, 2, 0).reshape(-1, num_classes)
            reg = reg_l[image_idx].permute(1, 2, 0).reshape(-1, 4)
            ctr = ctr_l[image_idx].permute(1, 2, 0).reshape(-1)

            # Score each (location, class) by class probability modulated by centerness.
            cls_prob = torch.sigmoid(cls)
            ctr_prob = torch.sigmoid(ctr)
            scores = cls_prob * ctr_prob[:, None]

            mask = scores > score_thresh
            if not mask.any():
                continue
            cand_loc, cand_cls = mask.nonzero(as_tuple=True)
            cand_scores = scores[cand_loc, cand_cls]

            # Keep only the strongest candidates on this level.
            if cand_scores.numel() > max_per_level:
                top = cand_scores.topk(max_per_level)
                cand_scores = top.values
                idx = top.indices
                cand_loc = cand_loc[idx]
                cand_cls = cand_cls[idx]

            # Decode (l, t, r, b) distances around each cell center into xyxy.
            cand_locs_xy = locs_l[cand_loc]
            cand_reg = reg[cand_loc]
            boxes = torch.stack([
                cand_locs_xy[:, 0] - cand_reg[:, 0],
                cand_locs_xy[:, 1] - cand_reg[:, 1],
                cand_locs_xy[:, 0] + cand_reg[:, 2],
                cand_locs_xy[:, 1] + cand_reg[:, 3],
            ], dim=-1)
            all_boxes.append(boxes)
            all_scores.append(cand_scores)
            all_labels.append(cand_cls)

        if all_boxes:
            boxes = torch.cat(all_boxes, dim=0)
            scores = torch.cat(all_scores, dim=0)
            labels = torch.cat(all_labels, dim=0)

            # Clip to the image bounds before NMS.
            H, W = image_sizes[image_idx]
            boxes[:, 0::2] = boxes[:, 0::2].clamp(0, W)
            boxes[:, 1::2] = boxes[:, 1::2].clamp(0, H)

            # Class-wise NMS: suppress overlaps only within the same label.
            keep_all = []
            for c in labels.unique():
                cm = labels == c
                keep = nms(boxes[cm], scores[cm], nms_thresh)
                # Map NMS indices (local to this class) back to global indices.
                keep_idx = cm.nonzero(as_tuple=True)[0][keep]
                keep_all.append(keep_idx)
            keep_all = torch.cat(keep_all, dim=0)

            boxes = boxes[keep_all]
            scores = scores[keep_all]
            labels = labels[keep_all]

            # Global per-image cap across all classes.
            if scores.numel() > max_per_image:
                top = scores.topk(max_per_image)
                boxes = boxes[top.indices]
                scores = top.values
                labels = labels[top.indices]
        else:
            # No candidate survived the score threshold on any level.
            boxes = torch.zeros((0, 4), device=device)
            scores = torch.zeros((0,), device=device)
            labels = torch.zeros((0,), dtype=torch.long, device=device)

        per_image_results.append({"boxes": boxes, "scores": scores, "labels": labels})

    return per_image_results
|
|
|
|
def _letterbox_to_square(image: Image.Image, resolution: int) -> Tuple[Image.Image, float, Tuple[int, int]]:
    """Resize preserving aspect ratio and pad bottom/right with black. Matches the training transform.

    Returns the padded square canvas, the applied scale factor, and the
    original (width, height) so boxes can be mapped back later.
    """
    orig_w, orig_h = image.size
    scale = resolution / max(orig_h, orig_w)
    scaled = image.resize(
        (int(round(orig_w * scale)), int(round(orig_h * scale))),
        Image.BILINEAR,
    )
    canvas = Image.new("RGB", (resolution, resolution), (0, 0, 0))
    canvas.paste(scaled, (0, 0))
    return canvas, scale, (orig_w, orig_h)
|
|
|
|
| |
| |
| |
|
|
|
|
class ArgusConfig(PretrainedConfig):
    """Configuration for the Argus multi-task perception model.

    Groups backbone geometry, per-task head hyper-parameters, and the label
    vocabularies for classification and detection.
    """

    model_type = "argus"

    def __init__(
        self,
        embed_dim: int = 768,
        patch_size: int = 16,
        num_seg_classes: int = 150,
        depth_n_bins: int = 256,
        depth_min_depth: float = 0.001,
        depth_max_depth: float = 10.0,
        num_imagenet_classes: int = 1000,
        class_ids: Optional[list] = None,
        class_names: Optional[list] = None,
        detection_num_classes: int = 80,
        detection_fpn_channels: int = 256,
        detection_num_convs: int = 4,
        detection_class_names: Optional[list] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Backbone geometry (defaults match the ViT-B/16 backbone built below).
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        # Dense-prediction heads: segmentation classes and depth binning range.
        self.num_seg_classes = num_seg_classes
        self.depth_n_bins = depth_n_bins
        self.depth_min_depth = depth_min_depth
        self.depth_max_depth = depth_max_depth
        # Classification vocabulary; buffers in Argus are sized by this count.
        # NOTE: `x or []` replaces both None and an empty list with a fresh [].
        self.num_imagenet_classes = num_imagenet_classes
        self.class_ids = class_ids or []
        self.class_names = class_names or []
        # Detection head (FCOS on a simple feature pyramid).
        self.detection_num_classes = detection_num_classes
        self.detection_fpn_channels = detection_fpn_channels
        self.detection_num_convs = detection_num_convs
        self.detection_class_names = detection_class_names or list(COCO_CLASSES)
|
|
|
|
class Argus(PreTrainedModel):
    """Multi-task perception model: a frozen EUPE ViT-B backbone feeding
    segmentation, depth, and detection heads plus buffer-based classification."""

    config_class = ArgusConfig
    base_model_prefix = "argus"
    # Backbone is frozen at construction time, so checkpointing is unnecessary.
    supports_gradient_checkpointing = False
    # No weights are tied anywhere in this model.
    _tied_weights_keys: list = []
    all_tied_weights_keys: dict = {}
|
|
| def __init__(self, config: ArgusConfig): |
| super().__init__(config) |
| self.backbone = build_eupe_vitb16() |
| self.seg_head = SegmentationHead(config.embed_dim, config.num_seg_classes) |
| self.depth_head = DepthHead( |
| in_dim=config.embed_dim, |
| n_bins=config.depth_n_bins, |
| min_depth=config.depth_min_depth, |
| max_depth=config.depth_max_depth, |
| ) |
| self.register_buffer( |
| "class_prototypes", |
| torch.zeros(config.num_imagenet_classes, config.embed_dim), |
| persistent=True, |
| ) |
| self.register_buffer( |
| "class_logit_weight", |
| torch.zeros(config.num_imagenet_classes, config.embed_dim), |
| persistent=True, |
| ) |
| self.register_buffer( |
| "class_logit_bias", |
| torch.zeros(config.num_imagenet_classes), |
| persistent=True, |
| ) |
| self.detection_head = DetectionHead( |
| in_channels=config.embed_dim, |
| fpn_channels=config.detection_fpn_channels, |
| num_classes=config.detection_num_classes, |
| num_convs=config.detection_num_convs, |
| ) |
|
|
| for p in self.backbone.parameters(): |
| p.requires_grad = False |
| self.backbone.eval() |
| self.seg_head.eval() |
| self.depth_head.eval() |
| self.detection_head.eval() |
|
|
| def _init_weights(self, module): |
| |
| |
| |
| |
| if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)): |
| nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") |
| if module.bias is not None: |
| nn.init.zeros_(module.bias) |
| elif isinstance(module, nn.GroupNorm): |
| nn.init.ones_(module.weight) |
| nn.init.zeros_(module.bias) |
|
|
| if module is self: |
| for name in ("class_prototypes", "class_logit_weight", "class_logit_bias"): |
| if hasattr(self, name): |
| buf = getattr(self, name) |
| if torch.isnan(buf).any() or torch.isinf(buf).any(): |
| buf.data.zero_() |
|
|
    @property
    def class_ids(self):
        """Class ids aligned with the classification buffers (from config)."""
        return self.config.class_ids
|
|
    @property
    def class_names(self):
        """Human-readable class names aligned with `class_ids` (from config)."""
        return self.config.class_names
|
|
| @torch.inference_mode() |
| def _extract(self, image_tensor: Tensor) -> Tuple[Tensor, Tensor]: |
| with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"): |
| out = self.backbone.forward_features(image_tensor) |
| cls = out["x_norm_clstoken"].float() |
| patches = out["x_norm_patchtokens"].float() |
| B, N, D = patches.shape |
| h = w = int(N ** 0.5) |
| spatial = patches.permute(0, 2, 1).reshape(B, D, h, w) |
| return cls, spatial |
|
|
| @torch.inference_mode() |
| def classify( |
| self, |
| image_or_images, |
| top_k: int = 5, |
| method: Literal["knn", "softmax"] = "knn", |
| ): |
| single, images = _normalize_image_input(image_or_images) |
| transform = make_eupe_transform(224) |
| batch = torch.stack([transform(img) for img in images]).to(self.device) |
| cls, _ = self._extract(batch) |
| cls = F.normalize(cls, dim=-1) |
|
|
| if method == "knn": |
| proto = self.class_prototypes.to(cls.dtype) |
| scores_full = cls @ proto.T |
| elif method == "softmax": |
| w = self.class_logit_weight.to(cls.dtype) |
| b = self.class_logit_bias.to(cls.dtype) |
| logits = F.linear(cls, w, b) |
| scores_full = F.softmax(logits, dim=-1) |
| else: |
| raise ValueError(f"unknown classification method: {method!r} (expected 'knn' or 'softmax')") |
|
|
| topk = scores_full.topk(top_k, dim=-1) |
| top2 = scores_full.topk(2, dim=-1) |
| margins = (top2.values[:, 0] - top2.values[:, 1]).tolist() |
|
|
| results = [] |
| for b in range(len(images)): |
| entries = [] |
| for score, idx in zip(topk.values[b].tolist(), topk.indices[b].tolist()): |
| entries.append({ |
| "class_id": self.class_ids[idx], |
| "class_name": self.class_names[idx], |
| "score": float(score), |
| }) |
| entries[0]["margin"] = float(margins[b]) |
| results.append(entries) |
| return results[0] if single else results |
|
|
| @torch.inference_mode() |
| def segment(self, image_or_images, resolution: int = 512, return_confidence: bool = False): |
| single, images = _normalize_image_input(image_or_images) |
| transform = make_eupe_transform(resolution) |
| batch = torch.stack([transform(img) for img in images]).to(self.device) |
| _, spatial = self._extract(batch) |
| with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"): |
| logits = self.seg_head(spatial) |
| logits = F.interpolate(logits, size=(resolution, resolution), mode="bilinear", align_corners=False) |
| seg_maps = logits.argmax(dim=1) |
|
|
| if return_confidence: |
| probs = F.softmax(logits.float(), dim=1) |
| conf_maps = probs.max(dim=1).values |
| if single: |
| return seg_maps[0], conf_maps[0] |
| return [(seg_maps[i], conf_maps[i]) for i in range(len(images))] |
|
|
| if single: |
| return seg_maps[0] |
| return [seg_maps[i] for i in range(len(images))] |
|
|
    @torch.inference_mode()
    def depth(self, image_or_images, resolution: int = 416, return_confidence: bool = False):
        """Monocular depth estimation.

        Returns a (resolution, resolution) depth map per image; with
        ``return_confidence``, also a per-pixel standard deviation derived
        from the head's depth-bin distribution.
        """
        single, images = _normalize_image_input(image_or_images)
        transform = make_eupe_transform(resolution)
        batch = torch.stack([transform(img) for img in images]).to(self.device)
        _, spatial = self._extract(batch)

        with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"):
            if return_confidence:
                # Re-run the head's internal layers to expose the per-bin
                # distribution. NOTE(review): assumes DepthHead.forward is
                # batchnorm_layer -> conv_depth -> relu+0.1 -> normalize ->
                # expectation over linspace bins — confirm against DepthHead.
                normed = self.depth_head.batchnorm_layer(spatial)
                bin_logits = self.depth_head.conv_depth(normed)
                distribution = torch.relu(bin_logits) + 0.1
                distribution = distribution / distribution.sum(dim=1, keepdim=True)
                bins = torch.linspace(
                    self.depth_head.min_depth, self.depth_head.max_depth,
                    self.depth_head.n_bins, device=spatial.device,
                )
                # Mean depth = E[bins]; uncertainty = sqrt(E[bins^2] - E[bins]^2).
                depth_b = torch.einsum("bkhw,k->bhw", distribution, bins).unsqueeze(1)
                mean_sq = torch.einsum("bkhw,k->bhw", distribution, bins ** 2)
                variance = (mean_sq - depth_b.squeeze(1) ** 2).clamp(min=0)
                std_b = torch.sqrt(variance).unsqueeze(1)
            else:
                depth_b = self.depth_head(spatial)
                std_b = None

        # Upsample predictions back to the input resolution.
        depth_b = F.interpolate(depth_b, size=(resolution, resolution), mode="bilinear", align_corners=False)
        if std_b is not None:
            std_b = F.interpolate(std_b, size=(resolution, resolution), mode="bilinear", align_corners=False)

        depth_squeezed = depth_b[:, 0].float()

        if return_confidence:
            std_squeezed = std_b[:, 0].float()
            if single:
                return depth_squeezed[0], std_squeezed[0]
            return [(depth_squeezed[i], std_squeezed[i]) for i in range(len(images))]

        if single:
            return depth_squeezed[0]
        return [depth_squeezed[i] for i in range(len(images))]
|
|
| @torch.inference_mode() |
| def correspond( |
| self, |
| src_image: Image.Image, |
| tgt_image: Image.Image, |
| src_keypoints: list, |
| resolution: int = 512, |
| ): |
| sw, sh = src_image.size |
| tw, th = tgt_image.size |
| transform = make_eupe_transform(resolution) |
| src_t = transform(src_image).unsqueeze(0).to(self.device) |
| tgt_t = transform(tgt_image).unsqueeze(0).to(self.device) |
|
|
| _, src_feats = self._extract(src_t) |
| _, tgt_feats = self._extract(tgt_t) |
|
|
| src_feats = F.interpolate(src_feats, size=(resolution, resolution), mode="bilinear", align_corners=False) |
| tgt_feats = F.interpolate(tgt_feats, size=(resolution, resolution), mode="bilinear", align_corners=False) |
|
|
| src_feats = F.normalize(src_feats[0].permute(1, 2, 0), dim=-1) |
| tgt_feats = F.normalize(tgt_feats[0].permute(1, 2, 0), dim=-1) |
|
|
| preds = [] |
| for kp in src_keypoints: |
| sx = min(max(int(kp[0] / sw * resolution), 0), resolution - 1) |
| sy = min(max(int(kp[1] / sh * resolution), 0), resolution - 1) |
| src_vec = src_feats[sy, sx] |
| sim_map = torch.einsum("d,hwd->hw", src_vec, tgt_feats) |
| flat = sim_map.argmax().item() |
| py, px = flat // resolution, flat % resolution |
| preds.append([px / resolution * tw, py / resolution * th]) |
| return preds |
|
|
    @torch.inference_mode()
    def detect(
        self,
        image_or_images,
        resolution: int = 640,
        score_thresh: float = 0.05,
        nms_thresh: float = 0.5,
        max_per_image: int = 100,
    ):
        """Object detection; returns per-image lists of {box, score, label, class_name}.

        Boxes are xyxy in the original image's pixel coordinates.
        """
        single, images = _normalize_image_input(image_or_images)

        # Detection uses an aspect-preserving letterbox (pad bottom/right)
        # instead of the square-resize transform used by the other tasks, so
        # boxes can be mapped back to the original image via one scale factor.
        canvases, scales, orig_sizes = [], [], []
        for img in images:
            canvas, scale, orig = _letterbox_to_square(img, resolution)
            canvases.append(canvas)
            scales.append(scale)
            orig_sizes.append(orig)

        # ImageNet mean/std normalization applied to the letterboxed canvas.
        det_normalize = v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])
        batch = torch.stack([det_normalize(c) for c in canvases]).to(self.device)

        _, spatial = self._extract(batch)
        with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"):
            cls_logits, box_regs, centernesses = self.detection_head(spatial)
        # Cast back to fp32 before sigmoid/NMS decoding.
        cls_logits = [c.float() for c in cls_logits]
        box_regs = [b.float() for b in box_regs]
        centernesses = [c.float() for c in centernesses]

        feature_sizes = [(cl.shape[2], cl.shape[3]) for cl in cls_logits]
        locations = _make_locations(feature_sizes, FPN_STRIDES, spatial.device)
        image_sizes = [(resolution, resolution)] * len(images)

        results = _decode_detections(
            cls_logits, box_regs, centernesses, locations,
            image_sizes=image_sizes,
            score_thresh=score_thresh,
            nms_thresh=nms_thresh,
            max_per_image=max_per_image,
        )

        # Undo the letterbox: divide by the resize scale, then clip to the
        # original image bounds (padding was bottom/right, so no offset).
        class_names = self.config.detection_class_names
        formatted = []
        for i, r in enumerate(results):
            scale = scales[i]
            orig_w, orig_h = orig_sizes[i]
            boxes = r["boxes"].cpu().numpy() / scale
            boxes[:, 0::2] = boxes[:, 0::2].clip(0, orig_w)
            boxes[:, 1::2] = boxes[:, 1::2].clip(0, orig_h)

            detections = []
            for box, score, label in zip(
                boxes, r["scores"].cpu().numpy(), r["labels"].cpu().numpy()
            ):
                detections.append({
                    "box": [float(v) for v in box.tolist()],
                    "score": float(score),
                    "label": int(label),
                    "class_name": class_names[int(label)] if int(label) < len(class_names) else f"class_{int(label)}",
                })
            formatted.append(detections)

        return formatted[0] if single else formatted
|
|
| def perceive(self, image_or_images, return_confidence: bool = False): |
| single, images = _normalize_image_input(image_or_images) |
|
|
| t0 = time.time() |
| classif = self.classify(images, top_k=5) |
| t1 = time.time() |
| seg_out = self.segment(images, resolution=512, return_confidence=return_confidence) |
| t2 = time.time() |
| depth_out = self.depth(images, resolution=416, return_confidence=return_confidence) |
| t3 = time.time() |
|
|
| if return_confidence: |
| seg_maps = [s for s, _ in seg_out] |
| seg_confs = [c for _, c in seg_out] |
| depth_maps = [d for d, _ in depth_out] |
| depth_uncerts = [u for _, u in depth_out] |
| else: |
| seg_maps = seg_out |
| depth_maps = depth_out |
| seg_confs = depth_uncerts = None |
|
|
| timings = { |
| "classify": (t1 - t0) * 1000, |
| "segment": (t2 - t1) * 1000, |
| "depth": (t3 - t2) * 1000, |
| "total": (t3 - t0) * 1000, |
| } |
|
|
| results = [] |
| for i in range(len(images)): |
| entry = { |
| "classification": classif[i], |
| "segmentation": seg_maps[i].cpu().numpy(), |
| "depth": depth_maps[i].cpu().numpy(), |
| "timings_ms": timings, |
| } |
| if return_confidence: |
| entry["segmentation_confidence"] = seg_confs[i].cpu().numpy() |
| entry["depth_uncertainty"] = depth_uncerts[i].cpu().numpy() |
| results.append(entry) |
| return results[0] if single else results |
|
|
    def export_onnx(
        self,
        out_dir: str,
        backbone_resolution: int = 224,
        dynamic_batch: bool = True,
        verify: bool = True,
        tolerance: float = 5e-2,
        opset_version: int = 17,
    ) -> dict:
        """Export backbone, seg head, and depth head to ONNX. kNN classification and correspondence run on top of the backbone output and need no separate graph.

        Args:
            out_dir: directory receiving the three .onnx files (created if absent).
            backbone_resolution: square input size; must be divisible by patch_size.
            dynamic_batch: mark the batch dimension dynamic in all graphs.
            verify: re-run the graphs with onnxruntime and compare to PyTorch.
            tolerance: max allowed elementwise deviation during verification.
            opset_version: ONNX opset for all three exports.

        Returns:
            Dict of output paths (and a "verification" report when verify=True).

        Raises:
            ValueError: if backbone_resolution is not a multiple of patch_size.
            ImportError: if verify=True but onnxruntime is not installed.
            RuntimeError: if verification deviation exceeds `tolerance`.
        """
        import os
        os.makedirs(out_dir, exist_ok=True)

        if backbone_resolution % self.config.patch_size != 0:
            raise ValueError(
                f"backbone_resolution ({backbone_resolution}) must be a multiple of patch_size ({self.config.patch_size})"
            )
        # Side length of the patch-token grid the heads consume.
        spatial_resolution = backbone_resolution // self.config.patch_size

        wrapper = _BackboneExportWrapper(self.backbone).to(self.device).eval()

        # Dummy inputs used only for tracing the graphs.
        dummy_image = torch.randn(
            1, 3, backbone_resolution, backbone_resolution,
            device=self.device, dtype=torch.float32,
        )
        dummy_spatial = torch.randn(
            1, self.config.embed_dim, spatial_resolution, spatial_resolution,
            device=self.device, dtype=torch.float32,
        )

        backbone_path = os.path.join(out_dir, "argus_backbone.onnx")
        seg_path = os.path.join(out_dir, "argus_seg_head.onnx")
        depth_path = os.path.join(out_dir, "argus_depth_head.onnx")

        backbone_axes = None
        head_axes = None
        if dynamic_batch:
            backbone_axes = {
                "image": {0: "batch"},
                "cls_token": {0: "batch"},
                "spatial_features": {0: "batch"},
            }
            head_axes = {
                "spatial_features": {0: "batch"},
                "seg_logits": {0: "batch"},
                "depth_map": {0: "batch"},
            }

        # dynamo=False selects the legacy TorchScript-based exporter.
        with torch.inference_mode():
            torch.onnx.export(
                wrapper, dummy_image, backbone_path,
                input_names=["image"],
                output_names=["cls_token", "spatial_features"],
                dynamic_axes=backbone_axes,
                opset_version=opset_version,
                do_constant_folding=True,
                dynamo=False,
            )
            torch.onnx.export(
                self.seg_head, dummy_spatial, seg_path,
                input_names=["spatial_features"],
                output_names=["seg_logits"],
                dynamic_axes={"spatial_features": head_axes["spatial_features"], "seg_logits": head_axes["seg_logits"]} if head_axes else None,
                opset_version=opset_version,
                do_constant_folding=True,
                dynamo=False,
            )
            torch.onnx.export(
                self.depth_head, dummy_spatial, depth_path,
                input_names=["spatial_features"],
                output_names=["depth_map"],
                dynamic_axes={"spatial_features": head_axes["spatial_features"], "depth_map": head_axes["depth_map"]} if head_axes else None,
                opset_version=opset_version,
                do_constant_folding=True,
                dynamo=False,
            )

        result = {
            "backbone": backbone_path,
            "seg_head": seg_path,
            "depth_head": depth_path,
        }

        if verify:
            try:
                import onnxruntime as ort
            except ImportError as e:
                raise ImportError("onnxruntime not installed; pip install onnxruntime") from e

            # Verify on CPU with batch size 2 to exercise the dynamic axis.
            providers = ["CPUExecutionProvider"]
            verify_image = torch.randn(2, 3, backbone_resolution, backbone_resolution, dtype=torch.float32)
            verify_spatial = torch.randn(2, self.config.embed_dim, spatial_resolution, spatial_resolution, dtype=torch.float32)

            with torch.inference_mode():
                ref_cls, ref_spatial = wrapper(verify_image.to(self.device))
                ref_seg = self.seg_head(verify_spatial.to(self.device))
                ref_depth = self.depth_head(verify_spatial.to(self.device))

            sess = ort.InferenceSession(backbone_path, providers=providers)
            ort_cls, ort_spatial = sess.run(None, {"image": verify_image.numpy()})
            cls_diff = float(np.abs(ort_cls - ref_cls.cpu().numpy()).max())
            spatial_diff = float(np.abs(ort_spatial - ref_spatial.cpu().numpy()).max())

            sess = ort.InferenceSession(seg_path, providers=providers)
            ort_seg = sess.run(None, {"spatial_features": verify_spatial.numpy()})[0]
            seg_diff = float(np.abs(ort_seg - ref_seg.cpu().numpy()).max())

            sess = ort.InferenceSession(depth_path, providers=providers)
            ort_depth = sess.run(None, {"spatial_features": verify_spatial.numpy()})[0]
            depth_diff = float(np.abs(ort_depth - ref_depth.cpu().numpy()).max())

            verification = {
                "backbone_cls_max_diff": cls_diff,
                "backbone_spatial_max_diff": spatial_diff,
                "seg_head_max_diff": seg_diff,
                "depth_head_max_diff": depth_diff,
                "tolerance": tolerance,
                "verified_batch_size": 2,
            }
            for key, val in verification.items():
                if key.endswith("_max_diff") and val > tolerance:
                    raise RuntimeError(
                        f"ONNX/PyTorch divergence in {key}: {val:.2e} > tolerance {tolerance:.2e}"
                    )
            result["verification"] = verification

        return result
|
|