phanerozoic committed
Commit 911ed47 (verified) · 1 Parent(s): 4f2319a

Argus v1.1: add trained FCOS detection head as the fifth task


Adds object detection to Argus via an FCOS-style anchor-free detector
built on a ViTDet-style simple feature pyramid, trained on COCO 2017
train (117,266 images, 80 classes) with the EUPE-ViT-B backbone frozen.

Detection results on COCO val2017 (5,000 images)
-------------------------------------------------
mAP@[0.5:0.95] = 41.0
mAP@0.50 = 64.8
mAP@0.75 = 43.2
mAP small/medium/large = 21.4 / 44.9 / 62.1

For context, FCOS with a fully-trained ResNet-50-FPN backbone achieves
39.1 mAP on the same benchmark. The frozen EUPE-ViT-B backbone exceeds
that baseline at 41.0 mAP while sharing its features with four other
task heads simultaneously.

Architecture
------------
The simple feature pyramid takes the backbone stride-16 spatial features
and synthesizes five levels (P3 through P7, strides 8 through 128) via
a transposed convolution for P3, identity with channel reduction for P4,
and chained stride-2 convolutions for P5-P7, each with 256 channels and
GroupNorm. Two shared four-layer conv towers (classification and
regression) with GroupNorm and GELU process each level. Three prediction
heads output 80 classification channels, 4 box regression channels
(left/top/right/bottom distances, exponentiated with learned per-level
scale), and 1 centerness channel. 16.14M trainable parameters total.
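As a quick check on the geometry above, here is a hypothetical sketch (not code from argus.py; `decode_box` is an illustrative name): a 640px input yields 80×80 down to 5×5 grids across the five levels, and FCOS turns each location's four predicted edge distances into an xyxy box.

```python
# Pyramid geometry for a 640px input, per the strides described above.
strides = [8, 16, 32, 64, 128]                    # P3..P7
res = 640
grids = [(res // s, res // s) for s in strides]   # per-level feature-map sizes
num_locations = sum(h * w for h, w in grids)      # total candidate box centers

def decode_box(cx, cy, l, t, r, b):
    # FCOS regresses distances from a location (cx, cy) to the four box edges.
    return [cx - l, cy - t, cx + r, cy + b]

print(grids[0], grids[-1], num_locations)   # (80, 80) (5, 5) 8525
```

The 8,525 locations are what the score threshold and NMS then prune down to at most 100 detections per image.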

Training recipe
---------------
640px input with letterbox padding, batch 64, AdamW lr 1e-3, cosine
schedule with 3% warmup, weight decay 1e-4, gradient clipping at 10.0,
8 epochs, full FP32 throughout. Focal loss (alpha 0.25, gamma 2.0) for
classification, GIoU for boxes, BCE for centerness. ~6 hours wall clock
on a single RTX 6000 Ada at 0.7 it/s with 23 GB peak VRAM.
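The schedule above can be sketched as follows; `lr_at` is a hypothetical helper for illustration, not the actual training code, assuming linear warmup over the first 3% of steps followed by cosine decay to zero.

```python
import math

def lr_at(step, total_steps, base_lr=1e-3, warmup_frac=0.03):
    # Linear warmup for the first 3% of steps, then cosine decay to zero.
    warmup = max(1, int(total_steps * warmup_frac))
    if step < warmup:
        return base_lr * (step + 1) / warmup
    progress = (step - warmup) / max(1, total_steps - warmup)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```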

API
---
model.detect(image) returns a list of dicts:
[{"box": [x1,y1,x2,y2], "score": float, "label": int, "class_name": str}]

Detection uses a separate forward pass at 640px (the other tasks use
224/512/416), so it lives in its own method rather than in perceive().
Accepts single images or batches. Configurable score_thresh, nms_thresh,
and max_per_image.
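The return format lends itself to simple post-filtering. A sketch (the `group_by_class` helper and sample values are hypothetical; only the dict schema comes from the API above):

```python
def group_by_class(detections, min_score=0.5):
    # Bucket boxes from a detect() result by class_name,
    # dropping detections below a score threshold.
    grouped = {}
    for d in detections:
        if d["score"] >= min_score:
            grouped.setdefault(d["class_name"], []).append(d["box"])
    return grouped

# Hypothetical sample in the documented schema:
sample = [
    {"box": [10.0, 20.0, 110.0, 220.0], "score": 0.91, "label": 0, "class_name": "person"},
    {"box": [50.0, 60.0, 90.0, 100.0], "score": 0.32, "label": 16, "class_name": "dog"},
]
print(group_by_class(sample))   # {'person': [[10.0, 20.0, 110.0, 220.0]]}
```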

Backward compatibility
----------------------
All existing methods (classify, segment, depth, perceive, correspond)
return identical results to v1.0. The detection head adds 16.14M
parameters and 62 MB to the checkpoint (334 MB to 396 MB). perceive()
does not include detection in its output.
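The checkpoint delta is consistent with the parameter count; a back-of-the-envelope check (not from the repo):

```python
params = 16.14e6                 # trainable parameters in the detection head
delta_mib = params * 4 / 2**20   # FP32 storage: ~61.6 MiB, i.e. the ~62 MB growth
print(round(delta_mib, 1))       # 61.6
```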

Files changed
-------------
argus.py: +SimpleFeaturePyramid, +FCOSHead, +DetectionHead,
+detect() method, +_make_locations, +_decode_detections,
+_letterbox_to_square, +COCO_CLASSES, +FPN_STRIDES,
extended _init_weights for Conv2d/GroupNorm
model.safetensors: +79 detection_head.* tensors (334 MB to 396 MB)
config.json: +detection_num_classes, +detection_fpn_channels,
+detection_num_convs
README.md: detection in architecture diagram, mAP table,
detect() usage example, head specs, training details

argus.py CHANGED (+351, -4)
@@ -21,6 +21,7 @@ import torch.nn.functional as F
 import torch.nn.init
 from PIL import Image
 from torch import Tensor, nn
+from torchvision.ops import nms
 from torchvision.transforms import v2
 from transformers import PretrainedConfig, PreTrainedModel
 
@@ -874,6 +875,254 @@ class DepthHead(nn.Module):
         return torch.einsum("bkhw,k->bhw", logit, bins).unsqueeze(1)
 
 
+# ===========================================================================
+# Detection (FCOS with ViTDet-style simple feature pyramid)
+# ===========================================================================
+
+FPN_STRIDES = [8, 16, 32, 64, 128]
+
+COCO_CLASSES = [
+    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
+    "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
+    "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra",
+    "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
+    "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
+    "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
+    "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
+    "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
+    "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse",
+    "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
+    "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier",
+    "toothbrush",
+]
+
+
+class SimpleFeaturePyramid(nn.Module):
+    """ViTDet-style simple FPN: a single stride-16 ViT feature map -> P3..P7."""
+
+    def __init__(self, in_channels: int = 768, fpn_channels: int = 256):
+        super().__init__()
+        self.fpn_channels = fpn_channels
+        self.p3 = nn.Sequential(
+            nn.ConvTranspose2d(in_channels, in_channels, 2, stride=2),
+            nn.GroupNorm(32, in_channels),
+            nn.GELU(),
+            nn.Conv2d(in_channels, fpn_channels, 1),
+            nn.GroupNorm(32, fpn_channels),
+            nn.Conv2d(fpn_channels, fpn_channels, 3, padding=1),
+            nn.GroupNorm(32, fpn_channels),
+        )
+        self.p4 = nn.Sequential(
+            nn.Conv2d(in_channels, fpn_channels, 1),
+            nn.GroupNorm(32, fpn_channels),
+            nn.Conv2d(fpn_channels, fpn_channels, 3, padding=1),
+            nn.GroupNorm(32, fpn_channels),
+        )
+        self.p5 = nn.Sequential(
+            nn.Conv2d(in_channels, in_channels, 3, stride=2, padding=1),
+            nn.GroupNorm(32, in_channels),
+            nn.GELU(),
+            nn.Conv2d(in_channels, fpn_channels, 1),
+            nn.GroupNorm(32, fpn_channels),
+            nn.Conv2d(fpn_channels, fpn_channels, 3, padding=1),
+            nn.GroupNorm(32, fpn_channels),
+        )
+        self.p6 = nn.Sequential(
+            nn.Conv2d(fpn_channels, fpn_channels, 3, stride=2, padding=1),
+            nn.GroupNorm(32, fpn_channels),
+        )
+        self.p7 = nn.Sequential(
+            nn.Conv2d(fpn_channels, fpn_channels, 3, stride=2, padding=1),
+            nn.GroupNorm(32, fpn_channels),
+        )
+
+    def forward(self, x: Tensor) -> List[Tensor]:
+        p3 = self.p3(x)
+        p4 = self.p4(x)
+        p5 = self.p5(x)
+        p6 = self.p6(p5)
+        p7 = self.p7(p6)
+        return [p3, p4, p5, p6, p7]
+
+
+class FCOSHead(nn.Module):
+    """Shared classification / box regression / centerness towers across pyramid levels."""
+
+    def __init__(self, fpn_channels: int = 256, num_classes: int = 80, num_convs: int = 4):
+        super().__init__()
+        self.num_classes = num_classes
+
+        cls_tower, reg_tower = [], []
+        for _ in range(num_convs):
+            cls_tower += [
+                nn.Conv2d(fpn_channels, fpn_channels, 3, padding=1),
+                nn.GroupNorm(32, fpn_channels),
+                nn.GELU(),
+            ]
+            reg_tower += [
+                nn.Conv2d(fpn_channels, fpn_channels, 3, padding=1),
+                nn.GroupNorm(32, fpn_channels),
+                nn.GELU(),
+            ]
+        self.cls_tower = nn.Sequential(*cls_tower)
+        self.reg_tower = nn.Sequential(*reg_tower)
+
+        self.cls_pred = nn.Conv2d(fpn_channels, num_classes, 3, padding=1)
+        self.reg_pred = nn.Conv2d(fpn_channels, 4, 3, padding=1)
+        self.center_pred = nn.Conv2d(fpn_channels, 1, 3, padding=1)
+
+        self.scales = nn.Parameter(torch.ones(len(FPN_STRIDES)))
+
+        prior = 0.01
+        nn.init.constant_(self.cls_pred.bias, -math.log((1 - prior) / prior))
+        nn.init.zeros_(self.reg_pred.weight)
+        nn.init.zeros_(self.reg_pred.bias)
+        nn.init.zeros_(self.center_pred.weight)
+        nn.init.zeros_(self.center_pred.bias)
+
+    def forward(self, fpn_features: List[Tensor]) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
+        cls_logits, box_regs, centernesses = [], [], []
+        for level_idx, feat in enumerate(fpn_features):
+            cls_feat = self.cls_tower(feat)
+            reg_feat = self.reg_tower(feat)
+            cls_logits.append(self.cls_pred(cls_feat))
+            reg_raw = self.reg_pred(reg_feat) * self.scales[level_idx]
+            reg_raw = reg_raw.clamp(min=-10.0, max=10.0)
+            box_regs.append(torch.exp(reg_raw))
+            centernesses.append(self.center_pred(reg_feat))
+        return cls_logits, box_regs, centernesses
+
+
+class DetectionHead(nn.Module):
+    """Combined SFP + FCOS head."""
+
+    def __init__(self, in_channels: int = 768, fpn_channels: int = 256, num_classes: int = 80, num_convs: int = 4):
+        super().__init__()
+        self.fpn = SimpleFeaturePyramid(in_channels=in_channels, fpn_channels=fpn_channels)
+        self.head = FCOSHead(fpn_channels=fpn_channels, num_classes=num_classes, num_convs=num_convs)
+        self.num_classes = num_classes
+
+    def forward(self, spatial_features: Tensor) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
+        fpn = self.fpn(spatial_features)
+        return self.head(fpn)
+
+
+def _make_locations(feature_sizes: List[Tuple[int, int]], strides: List[int], device) -> List[Tensor]:
+    """Per-level center coordinates of feature-map locations in image space."""
+    all_locs = []
+    for (h, w), s in zip(feature_sizes, strides):
+        ys = (torch.arange(h, device=device, dtype=torch.float32) + 0.5) * s
+        xs = (torch.arange(w, device=device, dtype=torch.float32) + 0.5) * s
+        grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
+        locs = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1)
+        all_locs.append(locs)
+    return all_locs
+
+
+@torch.inference_mode()
+def _decode_detections(
+    cls_logits_per_level: List[Tensor],
+    box_regs_per_level: List[Tensor],
+    centernesses_per_level: List[Tensor],
+    locations_per_level: List[Tensor],
+    image_sizes: List[Tuple[int, int]],
+    score_thresh: float = 0.05,
+    nms_thresh: float = 0.5,
+    max_per_level: int = 1000,
+    max_per_image: int = 100,
+) -> List[Dict[str, Tensor]]:
+    """Convert per-level logits/regs/centerness into per-image detections (xyxy boxes)."""
+    B = cls_logits_per_level[0].shape[0]
+    num_classes = cls_logits_per_level[0].shape[1]
+    device = cls_logits_per_level[0].device
+
+    per_image_results = []
+    for image_idx in range(B):
+        all_boxes, all_scores, all_labels = [], [], []
+        for cls_l, reg_l, ctr_l, locs_l in zip(
+            cls_logits_per_level, box_regs_per_level, centernesses_per_level, locations_per_level
+        ):
+            cls = cls_l[image_idx].permute(1, 2, 0).reshape(-1, num_classes)
+            reg = reg_l[image_idx].permute(1, 2, 0).reshape(-1, 4)
+            ctr = ctr_l[image_idx].permute(1, 2, 0).reshape(-1)
+
+            cls_prob = torch.sigmoid(cls)
+            ctr_prob = torch.sigmoid(ctr)
+            scores = cls_prob * ctr_prob[:, None]
+
+            mask = scores > score_thresh
+            if not mask.any():
+                continue
+            cand_loc, cand_cls = mask.nonzero(as_tuple=True)
+            cand_scores = scores[cand_loc, cand_cls]
+
+            if cand_scores.numel() > max_per_level:
+                top = cand_scores.topk(max_per_level)
+                cand_scores = top.values
+                idx = top.indices
+                cand_loc = cand_loc[idx]
+                cand_cls = cand_cls[idx]
+
+            cand_locs_xy = locs_l[cand_loc]
+            cand_reg = reg[cand_loc]
+            boxes = torch.stack([
+                cand_locs_xy[:, 0] - cand_reg[:, 0],
+                cand_locs_xy[:, 1] - cand_reg[:, 1],
+                cand_locs_xy[:, 0] + cand_reg[:, 2],
+                cand_locs_xy[:, 1] + cand_reg[:, 3],
+            ], dim=-1)
+            all_boxes.append(boxes)
+            all_scores.append(cand_scores)
+            all_labels.append(cand_cls)
+
+        if all_boxes:
+            boxes = torch.cat(all_boxes, dim=0)
+            scores = torch.cat(all_scores, dim=0)
+            labels = torch.cat(all_labels, dim=0)
+
+            H, W = image_sizes[image_idx]
+            boxes[:, 0::2] = boxes[:, 0::2].clamp(0, W)
+            boxes[:, 1::2] = boxes[:, 1::2].clamp(0, H)
+
+            keep_all = []
+            for c in labels.unique():
+                cm = labels == c
+                keep = nms(boxes[cm], scores[cm], nms_thresh)
+                keep_idx = cm.nonzero(as_tuple=True)[0][keep]
+                keep_all.append(keep_idx)
+            keep_all = torch.cat(keep_all, dim=0)
+
+            boxes = boxes[keep_all]
+            scores = scores[keep_all]
+            labels = labels[keep_all]
+
+            if scores.numel() > max_per_image:
+                top = scores.topk(max_per_image)
+                boxes = boxes[top.indices]
+                scores = top.values
+                labels = labels[top.indices]
+        else:
+            boxes = torch.zeros((0, 4), device=device)
+            scores = torch.zeros((0,), device=device)
+            labels = torch.zeros((0,), dtype=torch.long, device=device)
+
+        per_image_results.append({"boxes": boxes, "scores": scores, "labels": labels})
+
+    return per_image_results
+
+
+def _letterbox_to_square(image: Image.Image, resolution: int) -> Tuple[Image.Image, float, Tuple[int, int]]:
+    """Resize preserving aspect ratio and pad bottom/right with black. Matches the training transform."""
+    W0, H0 = image.size
+    scale = resolution / max(H0, W0)
+    new_w = int(round(W0 * scale))
+    new_h = int(round(H0 * scale))
+    resized = image.resize((new_w, new_h), Image.BILINEAR)
+    canvas = Image.new("RGB", (resolution, resolution), (0, 0, 0))
+    canvas.paste(resized, (0, 0))
+    return canvas, scale, (W0, H0)
+
+
 # ===========================================================================
 # Argus model (transformers-compatible)
 # ===========================================================================
@@ -893,6 +1142,10 @@ class ArgusConfig(PretrainedConfig):
         num_imagenet_classes: int = 1000,
         class_ids: Optional[list] = None,
         class_names: Optional[list] = None,
+        detection_num_classes: int = 80,
+        detection_fpn_channels: int = 256,
+        detection_num_convs: int = 4,
+        detection_class_names: Optional[list] = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -905,6 +1158,10 @@ class ArgusConfig(PretrainedConfig):
         self.num_imagenet_classes = num_imagenet_classes
         self.class_ids = class_ids or []
         self.class_names = class_names or []
+        self.detection_num_classes = detection_num_classes
+        self.detection_fpn_channels = detection_fpn_channels
+        self.detection_num_convs = detection_num_convs
+        self.detection_class_names = detection_class_names or list(COCO_CLASSES)
 
 
 class Argus(PreTrainedModel):
@@ -939,16 +1196,33 @@ class Argus(PreTrainedModel):
             torch.zeros(config.num_imagenet_classes),
             persistent=True,
         )
+        self.detection_head = DetectionHead(
+            in_channels=config.embed_dim,
+            fpn_channels=config.detection_fpn_channels,
+            num_classes=config.detection_num_classes,
+            num_convs=config.detection_num_convs,
+        )
 
         for p in self.backbone.parameters():
             p.requires_grad = False
        self.backbone.eval()
         self.seg_head.eval()
         self.depth_head.eval()
+        self.detection_head.eval()
 
     def _init_weights(self, module):
-        # HF reallocates missing buffers with torch.empty() (uninitialized memory).
-        # Zero any buffer that came back NaN; leave loaded buffers untouched.
+        # HF reallocates missing buffers and parameters with torch.empty()
+        # (uninitialized memory) on from_pretrained. Populate sensible defaults
+        # for the standard layer types used by the detection head, and zero any
+        # Argus-level buffer that came back NaN.
+        if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
+            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.GroupNorm):
+            nn.init.ones_(module.weight)
+            nn.init.zeros_(module.bias)
+
         if module is self:
             for name in ("class_prototypes", "class_logit_weight", "class_logit_bias"):
                 if hasattr(self, name):
@@ -989,9 +1263,12 @@ class Argus(PreTrainedModel):
         cls = F.normalize(cls, dim=-1)
 
         if method == "knn":
-            scores_full = cls @ self.class_prototypes.T  # cosine similarity in [-1, 1]
+            proto = self.class_prototypes.to(cls.dtype)
+            scores_full = cls @ proto.T  # cosine similarity in [-1, 1]
         elif method == "softmax":
-            logits = F.linear(cls, self.class_logit_weight, self.class_logit_bias)
+            w = self.class_logit_weight.to(cls.dtype)
+            b = self.class_logit_bias.to(cls.dtype)
+            logits = F.linear(cls, w, b)
             scores_full = F.softmax(logits, dim=-1)  # in [0, 1]
         else:
             raise ValueError(f"unknown classification method: {method!r} (expected 'knn' or 'softmax')")
@@ -1111,6 +1388,76 @@ class Argus(PreTrainedModel):
             preds.append([px / resolution * tw, py / resolution * th])
         return preds
 
+    @torch.inference_mode()
+    def detect(
+        self,
+        image_or_images,
+        resolution: int = 640,
+        score_thresh: float = 0.05,
+        nms_thresh: float = 0.5,
+        max_per_image: int = 100,
+    ):
+        single, images = _normalize_image_input(image_or_images)
+
+        # Letterbox each image to match the training transform (resize long side
+        # to `resolution`, pad bottom/right with black). Box coordinates are
+        # recovered after decoding by unscaling.
+        canvases, scales, orig_sizes = [], [], []
+        for img in images:
+            canvas, scale, orig = _letterbox_to_square(img, resolution)
+            canvases.append(canvas)
+            scales.append(scale)
+            orig_sizes.append(orig)
+
+        det_normalize = v2.Compose([
+            v2.ToImage(),
+            v2.ToDtype(torch.float32, scale=True),
+            v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+        ])
+        batch = torch.stack([det_normalize(c) for c in canvases]).to(self.device)
+
+        _, spatial = self._extract(batch)
+        with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"):
+            cls_logits, box_regs, centernesses = self.detection_head(spatial)
+        cls_logits = [c.float() for c in cls_logits]
+        box_regs = [b.float() for b in box_regs]
+        centernesses = [c.float() for c in centernesses]
+
+        feature_sizes = [(cl.shape[2], cl.shape[3]) for cl in cls_logits]
+        locations = _make_locations(feature_sizes, FPN_STRIDES, spatial.device)
+        image_sizes = [(resolution, resolution)] * len(images)
+
+        results = _decode_detections(
+            cls_logits, box_regs, centernesses, locations,
+            image_sizes=image_sizes,
+            score_thresh=score_thresh,
+            nms_thresh=nms_thresh,
+            max_per_image=max_per_image,
+        )
+
+        class_names = self.config.detection_class_names
+        formatted = []
+        for i, r in enumerate(results):
+            scale = scales[i]
+            orig_w, orig_h = orig_sizes[i]
+            boxes = r["boxes"].cpu().numpy() / scale
+            boxes[:, 0::2] = boxes[:, 0::2].clip(0, orig_w)
+            boxes[:, 1::2] = boxes[:, 1::2].clip(0, orig_h)
+
+            detections = []
+            for box, score, label in zip(
+                boxes, r["scores"].cpu().numpy(), r["labels"].cpu().numpy()
+            ):
+                detections.append({
+                    "box": [float(v) for v in box.tolist()],
+                    "score": float(score),
+                    "label": int(label),
+                    "class_name": class_names[int(label)] if int(label) < len(class_names) else f"class_{int(label)}",
+                })
+            formatted.append(detections)
+
+        return formatted[0] if single else formatted
+
     def perceive(self, image_or_images, return_confidence: bool = False):
         single, images = _normalize_image_input(image_or_images)
 