Argus v1.1: add trained FCOS detection head as the fifth task
Browse filesAdds object detection to Argus via an FCOS-style anchor-free detector
built on a ViTDet-style simple feature pyramid, trained on COCO 2017
train (117,266 images, 80 classes) with the EUPE-ViT-B backbone frozen.
Detection results on COCO val2017 (5,000 images)
-------------------------------------------------
mAP@[0.5:0.95] = 41.0
mAP@0.50 = 64.8
mAP@0.75 = 43.2
mAP small/med/lg = 21.4 / 44.9 / 62.1
For context, FCOS with a fully-trained ResNet-50-FPN backbone achieves
39.1 mAP on the same benchmark. The frozen EUPE-ViT-B backbone exceeds
that baseline at 41.0 mAP while sharing its features with four other
task heads simultaneously.
Architecture
------------
The simple feature pyramid takes the backbone stride-16 spatial features
and synthesizes five levels (P3 through P7, strides 8 through 128) via
a transposed convolution for P3, identity with channel reduction for P4,
and chained stride-2 convolutions for P5-P7, each with 256 channels and
GroupNorm. Two shared four-layer conv towers (classification and
regression) with GroupNorm and GELU process each level. Three prediction
heads output 80 classification channels, 4 box regression channels
(left/top/right/bottom distances, exponentiated with learned per-level
scale), and 1 centerness channel. 16.14M trainable parameters total.
Training recipe
---------------
640px input with letterbox padding, batch 64, AdamW lr 1e-3, cosine
schedule with 3% warmup, weight decay 1e-4, gradient clipping at 10.0,
8 epochs, full FP32 throughout. Focal loss (alpha 0.25, gamma 2.0) for
classification, GIoU for boxes, BCE for centerness. ~6 hours wall clock
on a single RTX 6000 Ada at 0.7 it/s with 23 GB peak VRAM.
API
---
model.detect(image) returns a list of dicts:
[{"box": [x1,y1,x2,y2], "score": float, "label": int, "class_name": str}]
Detection uses a separate forward pass at 640px (the other tasks use
224/512/416), so it lives in its own method rather than in perceive().
Accepts single images or batches. Configurable score_thresh, nms_thresh,
and max_per_image.
Backward compatibility
----------------------
All existing methods (classify, segment, depth, perceive, correspond)
return identical results to v1.0. The detection head adds 16.14M
parameters and 62 MB to the checkpoint (334 MB to 396 MB). perceive()
does not include detection in its output.
Files changed
-------------
argus.py: +SimpleFeaturePyramid, +FCOSHead, +DetectionHead,
+detect() method, +_make_locations, +_decode_detections,
+_letterbox_to_square, +COCO_CLASSES, +FPN_STRIDES,
extended _init_weights for Conv2d/GroupNorm
model.safetensors: +79 detection_head.* tensors (334 MB to 396 MB)
config.json: +detection_num_classes, +detection_fpn_channels,
+detection_num_convs
README.md: detection in architecture diagram, mAP table,
detect() usage example, head specs, training details
|
@@ -21,6 +21,7 @@ import torch.nn.functional as F
|
|
| 21 |
import torch.nn.init
|
| 22 |
from PIL import Image
|
| 23 |
from torch import Tensor, nn
|
|
|
|
| 24 |
from torchvision.transforms import v2
|
| 25 |
from transformers import PretrainedConfig, PreTrainedModel
|
| 26 |
|
|
@@ -874,6 +875,254 @@ class DepthHead(nn.Module):
|
|
| 874 |
return torch.einsum("bkhw,k->bhw", logit, bins).unsqueeze(1)
|
| 875 |
|
| 876 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 877 |
# ===========================================================================
|
| 878 |
# Argus model (transformers-compatible)
|
| 879 |
# ===========================================================================
|
|
@@ -893,6 +1142,10 @@ class ArgusConfig(PretrainedConfig):
|
|
| 893 |
num_imagenet_classes: int = 1000,
|
| 894 |
class_ids: Optional[list] = None,
|
| 895 |
class_names: Optional[list] = None,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 896 |
**kwargs,
|
| 897 |
):
|
| 898 |
super().__init__(**kwargs)
|
|
@@ -905,6 +1158,10 @@ class ArgusConfig(PretrainedConfig):
|
|
| 905 |
self.num_imagenet_classes = num_imagenet_classes
|
| 906 |
self.class_ids = class_ids or []
|
| 907 |
self.class_names = class_names or []
|
|
|
|
|
|
|
|
|
|
|
|
|
| 908 |
|
| 909 |
|
| 910 |
class Argus(PreTrainedModel):
|
|
@@ -939,16 +1196,33 @@ class Argus(PreTrainedModel):
|
|
| 939 |
torch.zeros(config.num_imagenet_classes),
|
| 940 |
persistent=True,
|
| 941 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 942 |
|
| 943 |
for p in self.backbone.parameters():
|
| 944 |
p.requires_grad = False
|
| 945 |
self.backbone.eval()
|
| 946 |
self.seg_head.eval()
|
| 947 |
self.depth_head.eval()
|
|
|
|
| 948 |
|
| 949 |
def _init_weights(self, module):
|
| 950 |
-
# HF reallocates missing buffers with torch.empty()
|
| 951 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 952 |
if module is self:
|
| 953 |
for name in ("class_prototypes", "class_logit_weight", "class_logit_bias"):
|
| 954 |
if hasattr(self, name):
|
|
@@ -989,9 +1263,12 @@ class Argus(PreTrainedModel):
|
|
| 989 |
cls = F.normalize(cls, dim=-1)
|
| 990 |
|
| 991 |
if method == "knn":
|
| 992 |
-
|
|
|
|
| 993 |
elif method == "softmax":
|
| 994 |
-
|
|
|
|
|
|
|
| 995 |
scores_full = F.softmax(logits, dim=-1) # in [0, 1]
|
| 996 |
else:
|
| 997 |
raise ValueError(f"unknown classification method: {method!r} (expected 'knn' or 'softmax')")
|
|
@@ -1111,6 +1388,76 @@ class Argus(PreTrainedModel):
|
|
| 1111 |
preds.append([px / resolution * tw, py / resolution * th])
|
| 1112 |
return preds
|
| 1113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1114 |
def perceive(self, image_or_images, return_confidence: bool = False):
|
| 1115 |
single, images = _normalize_image_input(image_or_images)
|
| 1116 |
|
|
|
|
| 21 |
import torch.nn.init
|
| 22 |
from PIL import Image
|
| 23 |
from torch import Tensor, nn
|
| 24 |
+
from torchvision.ops import nms
|
| 25 |
from torchvision.transforms import v2
|
| 26 |
from transformers import PretrainedConfig, PreTrainedModel
|
| 27 |
|
|
|
|
| 875 |
return torch.einsum("bkhw,k->bhw", logit, bins).unsqueeze(1)
|
| 876 |
|
| 877 |
|
| 878 |
+
# ===========================================================================
|
| 879 |
+
# Detection (FCOS with ViTDet-style simple feature pyramid)
|
| 880 |
+
# ===========================================================================
|
| 881 |
+
|
| 882 |
+
FPN_STRIDES = [8, 16, 32, 64, 128]
|
| 883 |
+
|
| 884 |
+
COCO_CLASSES = [
|
| 885 |
+
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
|
| 886 |
+
"boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
|
| 887 |
+
"bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra",
|
| 888 |
+
"giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
|
| 889 |
+
"skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
|
| 890 |
+
"skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
|
| 891 |
+
"fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
|
| 892 |
+
"broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
|
| 893 |
+
"potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse",
|
| 894 |
+
"remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
|
| 895 |
+
"refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier",
|
| 896 |
+
"toothbrush",
|
| 897 |
+
]
|
| 898 |
+
|
| 899 |
+
|
| 900 |
+
class SimpleFeaturePyramid(nn.Module):
|
| 901 |
+
"""ViTDet-style simple FPN: a single stride-16 ViT feature map -> P3..P7."""
|
| 902 |
+
|
| 903 |
+
def __init__(self, in_channels: int = 768, fpn_channels: int = 256):
|
| 904 |
+
super().__init__()
|
| 905 |
+
self.fpn_channels = fpn_channels
|
| 906 |
+
self.p3 = nn.Sequential(
|
| 907 |
+
nn.ConvTranspose2d(in_channels, in_channels, 2, stride=2),
|
| 908 |
+
nn.GroupNorm(32, in_channels),
|
| 909 |
+
nn.GELU(),
|
| 910 |
+
nn.Conv2d(in_channels, fpn_channels, 1),
|
| 911 |
+
nn.GroupNorm(32, fpn_channels),
|
| 912 |
+
nn.Conv2d(fpn_channels, fpn_channels, 3, padding=1),
|
| 913 |
+
nn.GroupNorm(32, fpn_channels),
|
| 914 |
+
)
|
| 915 |
+
self.p4 = nn.Sequential(
|
| 916 |
+
nn.Conv2d(in_channels, fpn_channels, 1),
|
| 917 |
+
nn.GroupNorm(32, fpn_channels),
|
| 918 |
+
nn.Conv2d(fpn_channels, fpn_channels, 3, padding=1),
|
| 919 |
+
nn.GroupNorm(32, fpn_channels),
|
| 920 |
+
)
|
| 921 |
+
self.p5 = nn.Sequential(
|
| 922 |
+
nn.Conv2d(in_channels, in_channels, 3, stride=2, padding=1),
|
| 923 |
+
nn.GroupNorm(32, in_channels),
|
| 924 |
+
nn.GELU(),
|
| 925 |
+
nn.Conv2d(in_channels, fpn_channels, 1),
|
| 926 |
+
nn.GroupNorm(32, fpn_channels),
|
| 927 |
+
nn.Conv2d(fpn_channels, fpn_channels, 3, padding=1),
|
| 928 |
+
nn.GroupNorm(32, fpn_channels),
|
| 929 |
+
)
|
| 930 |
+
self.p6 = nn.Sequential(
|
| 931 |
+
nn.Conv2d(fpn_channels, fpn_channels, 3, stride=2, padding=1),
|
| 932 |
+
nn.GroupNorm(32, fpn_channels),
|
| 933 |
+
)
|
| 934 |
+
self.p7 = nn.Sequential(
|
| 935 |
+
nn.Conv2d(fpn_channels, fpn_channels, 3, stride=2, padding=1),
|
| 936 |
+
nn.GroupNorm(32, fpn_channels),
|
| 937 |
+
)
|
| 938 |
+
|
| 939 |
+
def forward(self, x: Tensor) -> List[Tensor]:
|
| 940 |
+
p3 = self.p3(x)
|
| 941 |
+
p4 = self.p4(x)
|
| 942 |
+
p5 = self.p5(x)
|
| 943 |
+
p6 = self.p6(p5)
|
| 944 |
+
p7 = self.p7(p6)
|
| 945 |
+
return [p3, p4, p5, p6, p7]
|
| 946 |
+
|
| 947 |
+
|
| 948 |
+
class FCOSHead(nn.Module):
|
| 949 |
+
"""Shared classification / box regression / centerness towers across pyramid levels."""
|
| 950 |
+
|
| 951 |
+
def __init__(self, fpn_channels: int = 256, num_classes: int = 80, num_convs: int = 4):
|
| 952 |
+
super().__init__()
|
| 953 |
+
self.num_classes = num_classes
|
| 954 |
+
|
| 955 |
+
cls_tower, reg_tower = [], []
|
| 956 |
+
for _ in range(num_convs):
|
| 957 |
+
cls_tower += [
|
| 958 |
+
nn.Conv2d(fpn_channels, fpn_channels, 3, padding=1),
|
| 959 |
+
nn.GroupNorm(32, fpn_channels),
|
| 960 |
+
nn.GELU(),
|
| 961 |
+
]
|
| 962 |
+
reg_tower += [
|
| 963 |
+
nn.Conv2d(fpn_channels, fpn_channels, 3, padding=1),
|
| 964 |
+
nn.GroupNorm(32, fpn_channels),
|
| 965 |
+
nn.GELU(),
|
| 966 |
+
]
|
| 967 |
+
self.cls_tower = nn.Sequential(*cls_tower)
|
| 968 |
+
self.reg_tower = nn.Sequential(*reg_tower)
|
| 969 |
+
|
| 970 |
+
self.cls_pred = nn.Conv2d(fpn_channels, num_classes, 3, padding=1)
|
| 971 |
+
self.reg_pred = nn.Conv2d(fpn_channels, 4, 3, padding=1)
|
| 972 |
+
self.center_pred = nn.Conv2d(fpn_channels, 1, 3, padding=1)
|
| 973 |
+
|
| 974 |
+
self.scales = nn.Parameter(torch.ones(len(FPN_STRIDES)))
|
| 975 |
+
|
| 976 |
+
prior = 0.01
|
| 977 |
+
nn.init.constant_(self.cls_pred.bias, -math.log((1 - prior) / prior))
|
| 978 |
+
nn.init.zeros_(self.reg_pred.weight)
|
| 979 |
+
nn.init.zeros_(self.reg_pred.bias)
|
| 980 |
+
nn.init.zeros_(self.center_pred.weight)
|
| 981 |
+
nn.init.zeros_(self.center_pred.bias)
|
| 982 |
+
|
| 983 |
+
def forward(self, fpn_features: List[Tensor]) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
|
| 984 |
+
cls_logits, box_regs, centernesses = [], [], []
|
| 985 |
+
for level_idx, feat in enumerate(fpn_features):
|
| 986 |
+
cls_feat = self.cls_tower(feat)
|
| 987 |
+
reg_feat = self.reg_tower(feat)
|
| 988 |
+
cls_logits.append(self.cls_pred(cls_feat))
|
| 989 |
+
reg_raw = self.reg_pred(reg_feat) * self.scales[level_idx]
|
| 990 |
+
reg_raw = reg_raw.clamp(min=-10.0, max=10.0)
|
| 991 |
+
box_regs.append(torch.exp(reg_raw))
|
| 992 |
+
centernesses.append(self.center_pred(reg_feat))
|
| 993 |
+
return cls_logits, box_regs, centernesses
|
| 994 |
+
|
| 995 |
+
|
| 996 |
+
class DetectionHead(nn.Module):
|
| 997 |
+
"""Combined SFP + FCOS head."""
|
| 998 |
+
|
| 999 |
+
def __init__(self, in_channels: int = 768, fpn_channels: int = 256, num_classes: int = 80, num_convs: int = 4):
|
| 1000 |
+
super().__init__()
|
| 1001 |
+
self.fpn = SimpleFeaturePyramid(in_channels=in_channels, fpn_channels=fpn_channels)
|
| 1002 |
+
self.head = FCOSHead(fpn_channels=fpn_channels, num_classes=num_classes, num_convs=num_convs)
|
| 1003 |
+
self.num_classes = num_classes
|
| 1004 |
+
|
| 1005 |
+
def forward(self, spatial_features: Tensor) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
|
| 1006 |
+
fpn = self.fpn(spatial_features)
|
| 1007 |
+
return self.head(fpn)
|
| 1008 |
+
|
| 1009 |
+
|
| 1010 |
+
def _make_locations(feature_sizes: List[Tuple[int, int]], strides: List[int], device) -> List[Tensor]:
|
| 1011 |
+
"""Per-level center coordinates of feature-map locations in image space."""
|
| 1012 |
+
all_locs = []
|
| 1013 |
+
for (h, w), s in zip(feature_sizes, strides):
|
| 1014 |
+
ys = (torch.arange(h, device=device, dtype=torch.float32) + 0.5) * s
|
| 1015 |
+
xs = (torch.arange(w, device=device, dtype=torch.float32) + 0.5) * s
|
| 1016 |
+
grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
|
| 1017 |
+
locs = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1)
|
| 1018 |
+
all_locs.append(locs)
|
| 1019 |
+
return all_locs
|
| 1020 |
+
|
| 1021 |
+
|
| 1022 |
+
@torch.inference_mode()
|
| 1023 |
+
def _decode_detections(
|
| 1024 |
+
cls_logits_per_level: List[Tensor],
|
| 1025 |
+
box_regs_per_level: List[Tensor],
|
| 1026 |
+
centernesses_per_level: List[Tensor],
|
| 1027 |
+
locations_per_level: List[Tensor],
|
| 1028 |
+
image_sizes: List[Tuple[int, int]],
|
| 1029 |
+
score_thresh: float = 0.05,
|
| 1030 |
+
nms_thresh: float = 0.5,
|
| 1031 |
+
max_per_level: int = 1000,
|
| 1032 |
+
max_per_image: int = 100,
|
| 1033 |
+
) -> List[Dict[str, Tensor]]:
|
| 1034 |
+
"""Convert per-level logits/regs/centerness into per-image detections (xyxy boxes)."""
|
| 1035 |
+
B = cls_logits_per_level[0].shape[0]
|
| 1036 |
+
num_classes = cls_logits_per_level[0].shape[1]
|
| 1037 |
+
device = cls_logits_per_level[0].device
|
| 1038 |
+
|
| 1039 |
+
per_image_results = []
|
| 1040 |
+
for image_idx in range(B):
|
| 1041 |
+
all_boxes, all_scores, all_labels = [], [], []
|
| 1042 |
+
for cls_l, reg_l, ctr_l, locs_l in zip(
|
| 1043 |
+
cls_logits_per_level, box_regs_per_level, centernesses_per_level, locations_per_level
|
| 1044 |
+
):
|
| 1045 |
+
cls = cls_l[image_idx].permute(1, 2, 0).reshape(-1, num_classes)
|
| 1046 |
+
reg = reg_l[image_idx].permute(1, 2, 0).reshape(-1, 4)
|
| 1047 |
+
ctr = ctr_l[image_idx].permute(1, 2, 0).reshape(-1)
|
| 1048 |
+
|
| 1049 |
+
cls_prob = torch.sigmoid(cls)
|
| 1050 |
+
ctr_prob = torch.sigmoid(ctr)
|
| 1051 |
+
scores = cls_prob * ctr_prob[:, None]
|
| 1052 |
+
|
| 1053 |
+
mask = scores > score_thresh
|
| 1054 |
+
if not mask.any():
|
| 1055 |
+
continue
|
| 1056 |
+
cand_loc, cand_cls = mask.nonzero(as_tuple=True)
|
| 1057 |
+
cand_scores = scores[cand_loc, cand_cls]
|
| 1058 |
+
|
| 1059 |
+
if cand_scores.numel() > max_per_level:
|
| 1060 |
+
top = cand_scores.topk(max_per_level)
|
| 1061 |
+
cand_scores = top.values
|
| 1062 |
+
idx = top.indices
|
| 1063 |
+
cand_loc = cand_loc[idx]
|
| 1064 |
+
cand_cls = cand_cls[idx]
|
| 1065 |
+
|
| 1066 |
+
cand_locs_xy = locs_l[cand_loc]
|
| 1067 |
+
cand_reg = reg[cand_loc]
|
| 1068 |
+
boxes = torch.stack([
|
| 1069 |
+
cand_locs_xy[:, 0] - cand_reg[:, 0],
|
| 1070 |
+
cand_locs_xy[:, 1] - cand_reg[:, 1],
|
| 1071 |
+
cand_locs_xy[:, 0] + cand_reg[:, 2],
|
| 1072 |
+
cand_locs_xy[:, 1] + cand_reg[:, 3],
|
| 1073 |
+
], dim=-1)
|
| 1074 |
+
all_boxes.append(boxes)
|
| 1075 |
+
all_scores.append(cand_scores)
|
| 1076 |
+
all_labels.append(cand_cls)
|
| 1077 |
+
|
| 1078 |
+
if all_boxes:
|
| 1079 |
+
boxes = torch.cat(all_boxes, dim=0)
|
| 1080 |
+
scores = torch.cat(all_scores, dim=0)
|
| 1081 |
+
labels = torch.cat(all_labels, dim=0)
|
| 1082 |
+
|
| 1083 |
+
H, W = image_sizes[image_idx]
|
| 1084 |
+
boxes[:, 0::2] = boxes[:, 0::2].clamp(0, W)
|
| 1085 |
+
boxes[:, 1::2] = boxes[:, 1::2].clamp(0, H)
|
| 1086 |
+
|
| 1087 |
+
keep_all = []
|
| 1088 |
+
for c in labels.unique():
|
| 1089 |
+
cm = labels == c
|
| 1090 |
+
keep = nms(boxes[cm], scores[cm], nms_thresh)
|
| 1091 |
+
keep_idx = cm.nonzero(as_tuple=True)[0][keep]
|
| 1092 |
+
keep_all.append(keep_idx)
|
| 1093 |
+
keep_all = torch.cat(keep_all, dim=0)
|
| 1094 |
+
|
| 1095 |
+
boxes = boxes[keep_all]
|
| 1096 |
+
scores = scores[keep_all]
|
| 1097 |
+
labels = labels[keep_all]
|
| 1098 |
+
|
| 1099 |
+
if scores.numel() > max_per_image:
|
| 1100 |
+
top = scores.topk(max_per_image)
|
| 1101 |
+
boxes = boxes[top.indices]
|
| 1102 |
+
scores = top.values
|
| 1103 |
+
labels = labels[top.indices]
|
| 1104 |
+
else:
|
| 1105 |
+
boxes = torch.zeros((0, 4), device=device)
|
| 1106 |
+
scores = torch.zeros((0,), device=device)
|
| 1107 |
+
labels = torch.zeros((0,), dtype=torch.long, device=device)
|
| 1108 |
+
|
| 1109 |
+
per_image_results.append({"boxes": boxes, "scores": scores, "labels": labels})
|
| 1110 |
+
|
| 1111 |
+
return per_image_results
|
| 1112 |
+
|
| 1113 |
+
|
| 1114 |
+
def _letterbox_to_square(image: Image.Image, resolution: int) -> Tuple[Image.Image, float, Tuple[int, int]]:
|
| 1115 |
+
"""Resize preserving aspect ratio and pad bottom/right with black. Matches the training transform."""
|
| 1116 |
+
W0, H0 = image.size
|
| 1117 |
+
scale = resolution / max(H0, W0)
|
| 1118 |
+
new_w = int(round(W0 * scale))
|
| 1119 |
+
new_h = int(round(H0 * scale))
|
| 1120 |
+
resized = image.resize((new_w, new_h), Image.BILINEAR)
|
| 1121 |
+
canvas = Image.new("RGB", (resolution, resolution), (0, 0, 0))
|
| 1122 |
+
canvas.paste(resized, (0, 0))
|
| 1123 |
+
return canvas, scale, (W0, H0)
|
| 1124 |
+
|
| 1125 |
+
|
| 1126 |
# ===========================================================================
|
| 1127 |
# Argus model (transformers-compatible)
|
| 1128 |
# ===========================================================================
|
|
|
|
| 1142 |
num_imagenet_classes: int = 1000,
|
| 1143 |
class_ids: Optional[list] = None,
|
| 1144 |
class_names: Optional[list] = None,
|
| 1145 |
+
detection_num_classes: int = 80,
|
| 1146 |
+
detection_fpn_channels: int = 256,
|
| 1147 |
+
detection_num_convs: int = 4,
|
| 1148 |
+
detection_class_names: Optional[list] = None,
|
| 1149 |
**kwargs,
|
| 1150 |
):
|
| 1151 |
super().__init__(**kwargs)
|
|
|
|
| 1158 |
self.num_imagenet_classes = num_imagenet_classes
|
| 1159 |
self.class_ids = class_ids or []
|
| 1160 |
self.class_names = class_names or []
|
| 1161 |
+
self.detection_num_classes = detection_num_classes
|
| 1162 |
+
self.detection_fpn_channels = detection_fpn_channels
|
| 1163 |
+
self.detection_num_convs = detection_num_convs
|
| 1164 |
+
self.detection_class_names = detection_class_names or list(COCO_CLASSES)
|
| 1165 |
|
| 1166 |
|
| 1167 |
class Argus(PreTrainedModel):
|
|
|
|
| 1196 |
torch.zeros(config.num_imagenet_classes),
|
| 1197 |
persistent=True,
|
| 1198 |
)
|
| 1199 |
+
self.detection_head = DetectionHead(
|
| 1200 |
+
in_channels=config.embed_dim,
|
| 1201 |
+
fpn_channels=config.detection_fpn_channels,
|
| 1202 |
+
num_classes=config.detection_num_classes,
|
| 1203 |
+
num_convs=config.detection_num_convs,
|
| 1204 |
+
)
|
| 1205 |
|
| 1206 |
for p in self.backbone.parameters():
|
| 1207 |
p.requires_grad = False
|
| 1208 |
self.backbone.eval()
|
| 1209 |
self.seg_head.eval()
|
| 1210 |
self.depth_head.eval()
|
| 1211 |
+
self.detection_head.eval()
|
| 1212 |
|
| 1213 |
def _init_weights(self, module):
|
| 1214 |
+
# HF reallocates missing buffers and parameters with torch.empty()
|
| 1215 |
+
# (uninitialized memory) on from_pretrained. Populate sensible defaults
|
| 1216 |
+
# for the standard layer types used by the detection head, and zero any
|
| 1217 |
+
# Argus-level buffer that came back NaN.
|
| 1218 |
+
if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
|
| 1219 |
+
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
|
| 1220 |
+
if module.bias is not None:
|
| 1221 |
+
nn.init.zeros_(module.bias)
|
| 1222 |
+
elif isinstance(module, nn.GroupNorm):
|
| 1223 |
+
nn.init.ones_(module.weight)
|
| 1224 |
+
nn.init.zeros_(module.bias)
|
| 1225 |
+
|
| 1226 |
if module is self:
|
| 1227 |
for name in ("class_prototypes", "class_logit_weight", "class_logit_bias"):
|
| 1228 |
if hasattr(self, name):
|
|
|
|
| 1263 |
cls = F.normalize(cls, dim=-1)
|
| 1264 |
|
| 1265 |
if method == "knn":
|
| 1266 |
+
proto = self.class_prototypes.to(cls.dtype)
|
| 1267 |
+
scores_full = cls @ proto.T # cosine similarity in [-1, 1]
|
| 1268 |
elif method == "softmax":
|
| 1269 |
+
w = self.class_logit_weight.to(cls.dtype)
|
| 1270 |
+
b = self.class_logit_bias.to(cls.dtype)
|
| 1271 |
+
logits = F.linear(cls, w, b)
|
| 1272 |
scores_full = F.softmax(logits, dim=-1) # in [0, 1]
|
| 1273 |
else:
|
| 1274 |
raise ValueError(f"unknown classification method: {method!r} (expected 'knn' or 'softmax')")
|
|
|
|
| 1388 |
preds.append([px / resolution * tw, py / resolution * th])
|
| 1389 |
return preds
|
| 1390 |
|
| 1391 |
+
@torch.inference_mode()
|
| 1392 |
+
def detect(
|
| 1393 |
+
self,
|
| 1394 |
+
image_or_images,
|
| 1395 |
+
resolution: int = 640,
|
| 1396 |
+
score_thresh: float = 0.05,
|
| 1397 |
+
nms_thresh: float = 0.5,
|
| 1398 |
+
max_per_image: int = 100,
|
| 1399 |
+
):
|
| 1400 |
+
single, images = _normalize_image_input(image_or_images)
|
| 1401 |
+
|
| 1402 |
+
# Letterbox each image to match the training transform (resize long side
|
| 1403 |
+
# to `resolution`, pad bottom/right with black). Box coordinates are
|
| 1404 |
+
# recovered after decoding by unscaling.
|
| 1405 |
+
canvases, scales, orig_sizes = [], [], []
|
| 1406 |
+
for img in images:
|
| 1407 |
+
canvas, scale, orig = _letterbox_to_square(img, resolution)
|
| 1408 |
+
canvases.append(canvas)
|
| 1409 |
+
scales.append(scale)
|
| 1410 |
+
orig_sizes.append(orig)
|
| 1411 |
+
|
| 1412 |
+
det_normalize = v2.Compose([
|
| 1413 |
+
v2.ToImage(),
|
| 1414 |
+
v2.ToDtype(torch.float32, scale=True),
|
| 1415 |
+
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
| 1416 |
+
])
|
| 1417 |
+
batch = torch.stack([det_normalize(c) for c in canvases]).to(self.device)
|
| 1418 |
+
|
| 1419 |
+
_, spatial = self._extract(batch)
|
| 1420 |
+
with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"):
|
| 1421 |
+
cls_logits, box_regs, centernesses = self.detection_head(spatial)
|
| 1422 |
+
cls_logits = [c.float() for c in cls_logits]
|
| 1423 |
+
box_regs = [b.float() for b in box_regs]
|
| 1424 |
+
centernesses = [c.float() for c in centernesses]
|
| 1425 |
+
|
| 1426 |
+
feature_sizes = [(cl.shape[2], cl.shape[3]) for cl in cls_logits]
|
| 1427 |
+
locations = _make_locations(feature_sizes, FPN_STRIDES, spatial.device)
|
| 1428 |
+
image_sizes = [(resolution, resolution)] * len(images)
|
| 1429 |
+
|
| 1430 |
+
results = _decode_detections(
|
| 1431 |
+
cls_logits, box_regs, centernesses, locations,
|
| 1432 |
+
image_sizes=image_sizes,
|
| 1433 |
+
score_thresh=score_thresh,
|
| 1434 |
+
nms_thresh=nms_thresh,
|
| 1435 |
+
max_per_image=max_per_image,
|
| 1436 |
+
)
|
| 1437 |
+
|
| 1438 |
+
class_names = self.config.detection_class_names
|
| 1439 |
+
formatted = []
|
| 1440 |
+
for i, r in enumerate(results):
|
| 1441 |
+
scale = scales[i]
|
| 1442 |
+
orig_w, orig_h = orig_sizes[i]
|
| 1443 |
+
boxes = r["boxes"].cpu().numpy() / scale
|
| 1444 |
+
boxes[:, 0::2] = boxes[:, 0::2].clip(0, orig_w)
|
| 1445 |
+
boxes[:, 1::2] = boxes[:, 1::2].clip(0, orig_h)
|
| 1446 |
+
|
| 1447 |
+
detections = []
|
| 1448 |
+
for box, score, label in zip(
|
| 1449 |
+
boxes, r["scores"].cpu().numpy(), r["labels"].cpu().numpy()
|
| 1450 |
+
):
|
| 1451 |
+
detections.append({
|
| 1452 |
+
"box": [float(v) for v in box.tolist()],
|
| 1453 |
+
"score": float(score),
|
| 1454 |
+
"label": int(label),
|
| 1455 |
+
"class_name": class_names[int(label)] if int(label) < len(class_names) else f"class_{int(label)}",
|
| 1456 |
+
})
|
| 1457 |
+
formatted.append(detections)
|
| 1458 |
+
|
| 1459 |
+
return formatted[0] if single else formatted
|
| 1460 |
+
|
| 1461 |
def perceive(self, image_or_images, return_confidence: bool = False):
|
| 1462 |
single, images = _normalize_image_input(image_or_images)
|
| 1463 |
|