Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| from typing import Literal | |
| # ----------------------------------------------------------------------------- | |
| # Bounding Box Simulation | |
| # ----------------------------------------------------------------------------- | |
class UniformBBox:
    """
    Sample bounding boxes with jitter randomly sampled from a uniform
    distribution [0, max_jitter)

    Args:
        max_jitter: exclusive upper bound (in pixels) of the random amount by
            which each side of the tight box is pushed outwards. 0 disables
            jitter.
        train: selects what is returned when the whole input contains no
            foreground: an all-zero box if True, ``None`` if False.
    """

    def __init__(self, max_jitter: int = 20, train: bool = True):
        self.max_jitter = max_jitter
        self.train = train

    def attrs(self):
        """Return a dict describing the jitter configuration."""
        return {
            "jitter": 'uniform',
            "max_jitter": self.max_jitter,
        }

    def sample_bbox(self, seg: torch.Tensor) -> "torch.Tensor | None":
        """
        Sample bounding boxes for a batch of segmentations

        Args:
            seg: (b,1,H,W) or (1,H,W) mask; pixels > 0 are foreground
        Returns:
            Boxes [x_min, y_min, x_max, y_max] with shape (b,1,4) for a 4-D
            input or (1,4) for a 3-D input, where x indexes the last (W)
            axis and y the second-to-last (H) axis. Samples without any
            foreground get the box [0, 0, 0, 0]. If the whole input has no
            foreground, an all-zero tensor is returned when ``self.train``
            is True and ``None`` otherwise.
        """
        H, W = seg.shape[-2:]
        device = seg.device
        in_ndim = len(seg.shape)
        if in_ndim == 3:
            seg = seg.unsqueeze(0)
        bs = seg.shape[0]

        fg = seg > 0

        if not bool(fg.any()):
            # No foreground anywhere in the input
            if not self.train:
                return None
            # Mirror the shape of the regular return value for this input rank
            empty_shape = (bs, 1, 4) if in_ndim == 4 else (bs, 4)
            return torch.zeros(empty_shape, device=device)

        # Broadcastable coordinate vectors: x runs along W, y runs along H.
        xs = torch.arange(W, device=device).view(1, 1, 1, W)
        ys = torch.arange(H, device=device).view(1, 1, H, 1)

        # Background must not win the reduction: fill it with a value past
        # the last index for the minimum and with 0 for the maximum.
        x_min = torch.where(fg, xs, W).reshape(bs, -1).min(-1).values
        x_max = torch.where(fg, xs, 0).reshape(bs, -1).max(-1).values
        y_min = torch.where(fg, ys, H).reshape(bs, -1).min(-1).values
        y_max = torch.where(fg, ys, 0).reshape(bs, -1).max(-1).values

        if self.max_jitter == 0:
            # NOTE: float zeros make the returned box floating point in this
            # branch (integer otherwise); kept as-is for existing callers.
            x_jitter = torch.zeros((2, bs), device=device)
            y_jitter = torch.zeros((2, bs), device=device)
        else:
            # Row 0 jitters the min side, row 1 the max side.
            x_jitter = torch.randint(0, self.max_jitter, size=(2, bs), device=device)
            y_jitter = torch.randint(0, self.max_jitter, size=(2, bs), device=device)

        # Jitter only ever enlarges the box; keep it inside the image extent.
        x_min = torch.clamp(x_min - x_jitter[0], min=0, max=W)
        x_max = torch.clamp(x_max + x_jitter[1], min=0, max=W)
        y_min = torch.clamp(y_min - y_jitter[0], min=0, max=H)
        y_max = torch.clamp(y_max + y_jitter[1], min=0, max=H)

        box = torch.stack([x_min, y_min, x_max, y_max], dim=1)

        # Samples without foreground inside an otherwise non-empty batch get
        # the all-zero box instead of a box built from the fill values.
        has_fg = fg.reshape(bs, -1).any(-1)
        box = box.masked_fill(~has_fg.unsqueeze(-1), 0)

        if in_ndim == 4:
            box = box.unsqueeze(1)
        # shape: (b,1,4) or (1,4)
        return box

    def __call__(self, seg: torch.Tensor) -> "torch.Tensor | None":
        """
        Args:
            seg: (b,1,H,W) or (1,H,W) mask in [0,1] to fit a bounding box to
        Returns:
            bbox (torch.Tensor): bounding box coordinates
                [x_min, y_min, x_max, y_max] with shape (b,1,4) or (1,4)
        Note: if the given mask is empty a box [0, 0, 0, 0] is returned
            (or None when the sampler was created with train=False)
        Raises:
            AssertionError: if seg is not 3-D/4-D or does not have exactly
                one channel
        """
        assert len(seg.shape) in [3,4], \
            f"mask must be Bx1xHxW or 1xHxW. currently {seg.shape}"
        assert seg.shape[-3] == 1, \
            f"mask must have 1 channel. currently {seg.shape[-3]}"

        return self.sample_bbox(seg)