Commit: "Update modeling_falcon_perception.py" — diff summary for modeling_falcon_perception.py (changed, +70 lines / −10 lines).
@@ -817,12 +817,64 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
|
|
| 817 |
all_hw.append(v)
|
| 818 |
return torch.tensor(all_xy), torch.tensor(all_hw)
|
| 819 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 820 |
def _postprocess_aux(
|
| 821 |
self,
|
| 822 |
aux_list: list,
|
| 823 |
pixel_mask_hw: T,
|
| 824 |
orig_hw: tuple[int, int],
|
| 825 |
threshold: float,
|
|
|
|
| 826 |
) -> list[dict]:
|
| 827 |
"""Convert raw aux outputs into structured detections with RLE masks."""
|
| 828 |
orig_h, orig_w = orig_hw
|
|
@@ -838,8 +890,8 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
|
|
| 838 |
min_h = min_w = 0
|
| 839 |
act_h = act_w = None
|
| 840 |
|
| 841 |
-
# Group into triplets: coord, size, mask
|
| 842 |
-
|
| 843 |
step = 3 # coord, size, mask
|
| 844 |
for i in range(0, len(aux_list), step):
|
| 845 |
if i + 2 >= len(aux_list):
|
|
@@ -861,15 +913,23 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
|
|
| 861 |
|
| 862 |
# Threshold
|
| 863 |
binary_mask = (torch.sigmoid(mask_logits) > threshold).bool()
|
|
|
|
| 864 |
|
| 865 |
-
|
| 866 |
-
|
| 867 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 868 |
|
| 869 |
-
|
| 870 |
-
|
| 871 |
-
|
| 872 |
-
|
| 873 |
-
}
|
|
|
|
| 874 |
|
| 875 |
return detections
|
|
|
|
| 817 |
all_hw.append(v)
|
| 818 |
return torch.tensor(all_xy), torch.tensor(all_hw)
|
| 819 |
|
| 820 |
+
@staticmethod
|
| 821 |
+
def _mask_nms(
|
| 822 |
+
binary_masks: list[torch.Tensor],
|
| 823 |
+
iou_threshold: float = 0.6,
|
| 824 |
+
nms_max_side: int = 256,
|
| 825 |
+
) -> list[int]:
|
| 826 |
+
"""
|
| 827 |
+
Fast vectorised mask NMS on binary (H, W) tensors.
|
| 828 |
+
|
| 829 |
+
Returns the list of kept indices ordered by descending mask score.
|
| 830 |
+
The IoU matrix is computed via a single batched matmul; suppression
|
| 831 |
+
uses one GPU boolean op per kept mask — no .item() in the inner loop.
|
| 832 |
+
"""
|
| 833 |
+
N = len(binary_masks)
|
| 834 |
+
if N <= 1:
|
| 835 |
+
return list(range(N))
|
| 836 |
+
|
| 837 |
+
device = binary_masks[0].device
|
| 838 |
+
base_h, base_w = binary_masks[0].shape
|
| 839 |
+
scale = min(1.0, nms_max_side / max(base_h, base_w))
|
| 840 |
+
th = max(1, int(round(base_h * scale)))
|
| 841 |
+
tw = max(1, int(round(base_w * scale)))
|
| 842 |
+
|
| 843 |
+
resized = []
|
| 844 |
+
for m in binary_masks:
|
| 845 |
+
m = m.float()
|
| 846 |
+
if m.shape != (th, tw):
|
| 847 |
+
m = F.interpolate(
|
| 848 |
+
m[None, None], size=(th, tw), mode="bilinear", align_corners=False
|
| 849 |
+
).squeeze()
|
| 850 |
+
resized.append(m)
|
| 851 |
+
|
| 852 |
+
binary = torch.stack(resized) # (N, th, tw)
|
| 853 |
+
flat = binary.view(N, -1) # (N, th*tw)
|
| 854 |
+
areas = flat.sum(dim=1) # (N,)
|
| 855 |
+
scores = areas # larger mask = higher priority
|
| 856 |
+
intersection = flat @ flat.T # (N, N)
|
| 857 |
+
union = areas[:, None] + areas[None, :] - intersection
|
| 858 |
+
iou = intersection / union.clamp(min=1)
|
| 859 |
+
|
| 860 |
+
order = scores.argsort(descending=True)
|
| 861 |
+
suppressed = torch.zeros(N, dtype=torch.bool, device=device)
|
| 862 |
+
keep = []
|
| 863 |
+
for idx in order.tolist():
|
| 864 |
+
if suppressed[idx]:
|
| 865 |
+
continue
|
| 866 |
+
keep.append(idx)
|
| 867 |
+
suppressed |= iou[idx] > iou_threshold
|
| 868 |
+
|
| 869 |
+
return keep
|
| 870 |
+
|
| 871 |
def _postprocess_aux(
|
| 872 |
self,
|
| 873 |
aux_list: list,
|
| 874 |
pixel_mask_hw: T,
|
| 875 |
orig_hw: tuple[int, int],
|
| 876 |
threshold: float,
|
| 877 |
+
nms_iou_threshold: float = 0.6,
|
| 878 |
) -> list[dict]:
|
| 879 |
"""Convert raw aux outputs into structured detections with RLE masks."""
|
| 880 |
orig_h, orig_w = orig_hw
|
|
|
|
| 890 |
min_h = min_w = 0
|
| 891 |
act_h = act_w = None
|
| 892 |
|
| 893 |
+
# Group into triplets: coord, size, mask — build binary masks first
|
| 894 |
+
candidates = []
|
| 895 |
step = 3 # coord, size, mask
|
| 896 |
for i in range(0, len(aux_list), step):
|
| 897 |
if i + 2 >= len(aux_list):
|
|
|
|
| 913 |
|
| 914 |
# Threshold
|
| 915 |
binary_mask = (torch.sigmoid(mask_logits) > threshold).bool()
|
| 916 |
+
candidates.append({"xy": xy, "hw": hw, "binary_mask": binary_mask})
|
| 917 |
|
| 918 |
+
if not candidates:
|
| 919 |
+
return []
|
| 920 |
+
|
| 921 |
+
# NMS on binary masks before RLE encoding
|
| 922 |
+
keep_indices = self._mask_nms(
|
| 923 |
+
[c["binary_mask"] for c in candidates],
|
| 924 |
+
iou_threshold=nms_iou_threshold,
|
| 925 |
+
)
|
| 926 |
+
candidates = [candidates[i] for i in keep_indices]
|
| 927 |
|
| 928 |
+
# Encode survivors as COCO RLE
|
| 929 |
+
detections = []
|
| 930 |
+
for c in candidates:
|
| 931 |
+
rle_list = self._mask_to_coco_rle(c["binary_mask"].unsqueeze(0))
|
| 932 |
+
mask_rle = rle_list[0] if rle_list else {"counts": "", "size": [orig_h, orig_w]}
|
| 933 |
+
detections.append({"xy": c["xy"], "hw": c["hw"], "mask_rle": mask_rle})
|
| 934 |
|
| 935 |
return detections
|