ColabWan / preprocessing /sam3 /perflib /masks_ops.py
1ripon1's picture
Upload folder using huggingface_hub
7344bef verified
Raw
History Blame Contribute Delete
2.83 kB
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
# pyre-unsafe
import torch
def masks_to_boxes(masks: torch.Tensor, obj_ids: list[int]):
    """
    Compute one axis-aligned bounding box per mask without any Python-level loop.

    Modeled on torchvision's masks_to_boxes. Every non-zero element of a mask
    is treated as belonging to the object.

    Args:
      - masks: (N, H, W) Tensor, one mask per object
      - obj_ids: list of N object ids; only its length is checked against masks
    Returns:
      - boxes: (N, 4) float Tensor of (x_min, y_min, x_max, y_max) pixel indices
        (max indices are inclusive), on the same device as masks. A mask with
        no non-zero element yields (W, H, 0, 0). When masks holds no elements
        at all, a (0, 4) tensor is returned.
    """
    with torch.autograd.profiler.record_function("perflib: masks_to_boxes"):
        # Sanity check based on callsite for replacement
        assert masks.shape[0] == len(obj_ids)
        assert masks.dim() == 3

        # Same early exit as torchvision's masks_to_boxes
        if masks.numel() == 0:
            return torch.zeros((0, 4), device=masks.device, dtype=torch.float)

        num_masks, height, width = masks.shape
        dev = masks.device
        col_index = torch.arange(width, device=dev).view(1, width)
        row_index = torch.arange(height, device=dev).view(1, height)

        occupied = masks != 0  # N, H, W
        col_hit = occupied.amax(dim=1)  # N, W (which columns have objects)
        row_hit = occupied.amax(dim=2)  # N, H (which rows have objects)

        # Coordinate where the column/row is occupied, 0 elsewhere.
        col_coord = col_hit * col_index  # N, W
        row_coord = row_hit * row_index  # N, H

        # For the minima, unoccupied slots are lifted to W / H (one past the
        # last valid index) so that they can never win the reduction.
        x_min = ((~col_hit) * width + col_coord).amin(dim=1)
        y_min = ((~row_hit) * height + row_coord).amin(dim=1)
        x_max = col_coord.amax(dim=1)
        y_max = row_coord.amax(dim=1)

        bounding_boxes = torch.stack([x_min, y_min, x_max, y_max], dim=1).to(
            dtype=torch.float
        )

        assert bounding_boxes.shape == (num_masks, 4)
        assert bounding_boxes.device == masks.device
        assert bounding_boxes.dtype == torch.float
        return bounding_boxes
def mask_iou(pred_masks: torch.Tensor, gt_masks: torch.Tensor) -> torch.Tensor:
    """
    Pairwise IoU (Intersection over Union) between two sets of binary masks.

    Each mask is flattened to a float row vector, so all N x M intersection
    counts come out of a single matrix product (float mm, intended to run on
    Tensor Cores).

    Args:
      - pred_masks: (N, H, W) bool Tensor, containing binary predicted segmentation masks
      - gt_masks: (M, H, W) bool Tensor, containing binary ground truth segmentation masks

    Returns:
      - ious: (N, M) float Tensor, containing IoUs for each pair of predicted and ground truth masks.
        A pair whose union is empty gets an IoU of 0.
    """
    assert pred_masks.dtype == gt_masks.dtype == torch.bool
    assert pred_masks.shape[1:] == gt_masks.shape[1:]

    pred_flat = pred_masks.flatten(1).float()  # N, H*W
    gt_flat = gt_masks.flatten(1).float()  # M, H*W

    # Per-mask pixel counts
    pred_area = pred_flat.sum(dim=1)  # N
    gt_area = gt_flat.sum(dim=1)  # M

    # Dot product of two 0/1 vectors counts the pixels set in both masks
    overlap = pred_flat.mm(gt_flat.t())  # N, M

    union = pred_area.unsqueeze(1) + gt_area.unsqueeze(0) - overlap
    # Clamp keeps the division defined when both masks are empty (0 / 1 = 0)
    return overlap / torch.clamp(union, min=1)