yasserDahou committed on
Commit
ad720cd
·
verified ·
1 Parent(s): 9abbbe0

Update modeling_falcon_perception.py

Browse files
Files changed (1) hide show
  1. modeling_falcon_perception.py +70 -10
modeling_falcon_perception.py CHANGED
@@ -817,12 +817,64 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
817
  all_hw.append(v)
818
  return torch.tensor(all_xy), torch.tensor(all_hw)
819
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
820
  def _postprocess_aux(
821
  self,
822
  aux_list: list,
823
  pixel_mask_hw: T,
824
  orig_hw: tuple[int, int],
825
  threshold: float,
 
826
  ) -> list[dict]:
827
  """Convert raw aux outputs into structured detections with RLE masks."""
828
  orig_h, orig_w = orig_hw
@@ -838,8 +890,8 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
838
  min_h = min_w = 0
839
  act_h = act_w = None
840
 
841
- # Group into triplets: coord, size, mask
842
- detections = []
843
  step = 3 # coord, size, mask
844
  for i in range(0, len(aux_list), step):
845
  if i + 2 >= len(aux_list):
@@ -861,15 +913,23 @@ class FalconPerceptionForSegmentation(PreTrainedModel):
861
 
862
  # Threshold
863
  binary_mask = (torch.sigmoid(mask_logits) > threshold).bool()
 
864
 
865
- # Encode as COCO RLE
866
- rle_list = self._mask_to_coco_rle(binary_mask.unsqueeze(0))
867
- mask_rle = rle_list[0] if rle_list else {"counts": "", "size": [orig_h, orig_w]}
 
 
 
 
 
 
868
 
869
- detections.append({
870
- "xy": xy,
871
- "hw": hw,
872
- "mask_rle": mask_rle,
873
- })
 
874
 
875
  return detections
 
817
  all_hw.append(v)
818
  return torch.tensor(all_xy), torch.tensor(all_hw)
819
 
820
@staticmethod
def _mask_nms(
    binary_masks: list[torch.Tensor],
    iou_threshold: float = 0.6,
    nms_max_side: int = 256,
) -> list[int]:
    """
    Greedy mask NMS on binary (H, W) tensors.

    Every mask is brought to a common working resolution derived from the
    first mask (its longer side capped at ``nms_max_side``), the pairwise
    IoU matrix is computed with a single matmul, and masks are visited from
    largest to smallest area; each kept mask suppresses every mask whose
    IoU with it exceeds ``iou_threshold``.

    Args:
        binary_masks: 2-D mask tensors, one per candidate.
        iou_threshold: IoU strictly above this value causes suppression.
        nms_max_side: upper bound on the longer side of the working
            resolution used for the IoU computation.

    Returns:
        Indices into ``binary_masks`` of the kept masks, ordered by
        descending area; equal areas keep their input order.
    """
    N = len(binary_masks)
    if N <= 1:
        return list(range(N))

    base_h, base_w = binary_masks[0].shape
    scale = min(1.0, nms_max_side / max(base_h, base_w))
    th = max(1, int(round(base_h * scale)))
    tw = max(1, int(round(base_w * scale)))

    resized = []
    for m in binary_masks:
        m = m.float()
        if m.shape != (th, tw):
            # Index [0, 0] instead of squeeze(): squeeze() would also drop a
            # size-1 spatial dim and make torch.stack fail on mixed inputs.
            m = F.interpolate(
                m[None, None], size=(th, tw), mode="bilinear", align_corners=False
            )[0, 0]
            # Bilinear resampling yields fractional values; re-binarise so the
            # matmul below counts pixels and identical masks get IoU == 1.
            m = (m >= 0.5).float()
        resized.append(m)

    flat = torch.stack(resized).flatten(1)  # (N, th*tw)
    areas = flat.sum(dim=1)  # (N,) pixel counts
    intersection = flat @ flat.T  # (N, N)
    union = areas[:, None] + areas[None, :] - intersection
    # clamp only guards the 0/0 case of two empty masks (IoU -> 0)
    iou = intersection / union.clamp(min=1)

    # Larger mask = higher priority; stable sort keeps ties deterministic.
    order = torch.sort(areas, descending=True, stable=True).indices.tolist()

    # Move the thresholded matrix to the host once so the greedy loop below
    # never triggers a per-iteration device synchronisation.
    overlaps = (iou > iou_threshold).cpu()
    suppressed = torch.zeros(N, dtype=torch.bool)
    keep = []
    for idx in order:
        if suppressed[idx]:
            continue
        keep.append(idx)
        suppressed |= overlaps[idx]

    return keep
870
+
871
  def _postprocess_aux(
872
  self,
873
  aux_list: list,
874
  pixel_mask_hw: T,
875
  orig_hw: tuple[int, int],
876
  threshold: float,
877
+ nms_iou_threshold: float = 0.6,
878
  ) -> list[dict]:
879
  """Convert raw aux outputs into structured detections with RLE masks."""
880
  orig_h, orig_w = orig_hw
 
890
  min_h = min_w = 0
891
  act_h = act_w = None
892
 
893
+ # Group into triplets: coord, size, mask — build binary masks first
894
+ candidates = []
895
  step = 3 # coord, size, mask
896
  for i in range(0, len(aux_list), step):
897
  if i + 2 >= len(aux_list):
 
913
 
914
  # Threshold
915
  binary_mask = (torch.sigmoid(mask_logits) > threshold).bool()
916
+ candidates.append({"xy": xy, "hw": hw, "binary_mask": binary_mask})
917
 
918
+ if not candidates:
919
+ return []
920
+
921
+ # NMS on binary masks before RLE encoding
922
+ keep_indices = self._mask_nms(
923
+ [c["binary_mask"] for c in candidates],
924
+ iou_threshold=nms_iou_threshold,
925
+ )
926
+ candidates = [candidates[i] for i in keep_indices]
927
 
928
+ # Encode survivors as COCO RLE
929
+ detections = []
930
+ for c in candidates:
931
+ rle_list = self._mask_to_coco_rle(c["binary_mask"].unsqueeze(0))
932
+ mask_rle = rle_list[0] if rle_list else {"counts": "", "size": [orig_h, orig_w]}
933
+ detections.append({"xy": c["xy"], "hw": c["hw"], "mask_rle": mask_rle})
934
 
935
  return detections