hbyecoding's picture
Upload 143 files
b2c5353 verified
from typing import Literal, Optional

import numpy as np
import torch
# -----------------------------------------------------------------------------
# Bounding Box Simulation
# -----------------------------------------------------------------------------
class UniformBBox:
    """
    Sample bounding boxes with jitter randomly sampled from a uniform
    distribution [0, max_jitter)

    The tight box around the foreground (pixels with value > 0) is computed
    per sample, then each of its four sides is pushed outward by an
    independently sampled integer jitter and clamped to the image.
    """

    def __init__(self, max_jitter: int = 20, train: bool = True):
        # max_jitter: exclusive upper bound of the per-side jitter; 0 disables it
        # train: result for an all-empty batch -- zero boxes if True, None if False
        self.max_jitter = max_jitter
        self.train = train

    @property
    def attrs(self):
        """Configuration of this sampler as a plain dict."""
        return {
            "jitter": 'uniform',
            "max_jitter": self.max_jitter,
        }

    def sample_bbox(self, seg: torch.Tensor) -> Optional[torch.Tensor]:
        """
        Sample bounding boxes for a batch of segmentations

        Args:
            seg: (b,1,H,W) or (1,H,W) mask; pixels with value > 0 are foreground
        Returns:
            Boxes [x_min, y_min, x_max, y_max] (x = column, y = row) with shape
            (b,1,4) for 4-D input or (1,4) for 3-D input. Coordinates are
            clamped to [0, W] / [0, H]. Samples with an empty mask get
            [0, 0, 0, 0]. If the whole batch is empty and ``self.train`` is
            False, None is returned instead.
            The dtype is float when max_jitter == 0 or the batch is empty
            (torch.zeros default), and int64 otherwise.
        """
        H, W = seg.shape[-2:]
        device = seg.device
        in_ndim = len(seg.shape)
        if in_ndim == 3:
            seg = seg.unsqueeze(0)
        bs = seg.shape[0]

        # Foreground flags per sample, flattened over the single-channel image.
        fg = (seg > 0).reshape(bs, -1)
        has_fg = fg.any(dim=-1)

        if not bool(has_fg.any()):
            # No segmentation anywhere in the batch
            if self.train:
                # Keep the same rank as the non-empty result: (b,1,4) or (1,4).
                out_shape = (bs, 1, 4) if in_ndim == 4 else (bs, 4)
                return torch.zeros(out_shape, device=device)
            return None

        # Row (y) and column (x) index of every pixel. 'ij' indexing yields
        # (H, W) grids, so they line up with the mask for non-square images.
        ys, xs = torch.meshgrid(
            torch.arange(H, device=device),
            torch.arange(W, device=device),
            indexing='ij',
        )
        xs = xs.reshape(1, -1)
        ys = ys.reshape(1, -1)

        # Fill background pixels with a sentinel that can never win the
        # reduction, so min/max only see foreground coordinates.
        x_min = torch.where(fg, xs, W).min(dim=-1).values
        x_max = torch.where(fg, xs, -1).max(dim=-1).values
        y_min = torch.where(fg, ys, H).min(dim=-1).values
        y_max = torch.where(fg, ys, -1).max(dim=-1).values

        # Row 0 pushes the min side outward, row 1 pushes the max side outward.
        if self.max_jitter == 0:
            x_jitter = torch.zeros((2, bs), device=device)
            y_jitter = torch.zeros((2, bs), device=device)
        else:
            x_jitter = torch.randint(0, self.max_jitter, size=(2, bs), device=device)
            y_jitter = torch.randint(0, self.max_jitter, size=(2, bs), device=device)

        x_min = torch.clamp(x_min - x_jitter[0], min=0, max=W)
        x_max = torch.clamp(x_max + x_jitter[1], min=0, max=W)
        y_min = torch.clamp(y_min - y_jitter[0], min=0, max=H)
        y_max = torch.clamp(y_max + y_jitter[1], min=0, max=H)

        box = torch.stack([x_min, y_min, x_max, y_max], dim=1)
        # Empty samples inside an otherwise non-empty batch get [0, 0, 0, 0]
        # rather than sentinel-derived coordinates.
        box[~has_fg] = 0

        if in_ndim == 4:
            box = box.unsqueeze(1)
        # shape: (b,1,4) or (1,4)
        return box

    def __call__(self, seg: torch.Tensor) -> Optional[torch.Tensor]:
        """
        Args:
            seg: (b,1,H,W) or (1,H,W) mask in [0,1] to fit a bounding box to
        Returns:
            bbox (torch.Tensor): bounding box coordinates
                [x_min, y_min, x_max, y_max] with shape (b,1,4) or (1,4)
            Note: a sample whose mask is empty gets the box [0, 0, 0, 0];
            if every mask in the batch is empty and ``train`` is False,
            None is returned instead.
        Raises:
            AssertionError: if seg is not 3-D/4-D or has more than one channel
        """
        assert len(seg.shape) in [3,4], \
            f"mask must be Bx1xHxW or 1xHxW. currently {seg.shape}"
        assert seg.shape[-3] == 1, \
            f"mask must have 1 channel. currently {seg.shape[-3]}"
        return self.sample_bbox(seg)