hbyecoding's picture
Upload 143 files
b2c5353 verified
import torch
from typing import Tuple, Literal
def bbox_shaded(boxes: torch.Tensor,
                shape: Tuple[int, int] = (128, 128),
                device='cuda') -> torch.Tensor:
    """
    Represent a bounding box as a binary image with 1 inside the bbox and 0 outside.

    A pixel at row ``y`` / column ``x`` is set when
    ``x_min <= x < x_max`` and ``y_min <= y < y_max`` (max edges are exclusive).
    Coordinates are truncated to integers first. Any part of a box outside the
    image is clipped away, so negative coordinates do not wrap around.

    Args:
        boxes: Bx1x4 (or Bx4) boxes as [x_min, y_min, x_max, y_max]. A tensor,
            numpy array or nested list is accepted. For a BxKx4 input only the
            first box of each sample is used.
        shape (tuple): (H, W) of the output image.
        device (str): 'cuda' or 'cpu'; device of the returned tensor.
    Returns:
        bbox_embed (torch.Tensor): Bx1xHxW float32 mask according to shape.
    Raises:
        ValueError: if ``boxes`` is None.
    """
    assert len(shape) == 2, "shape must be 2D"
    if boxes is None:
        raise ValueError("boxes must not be None")
    height, width = (int(s) for s in shape)

    # Truncate to integer pixel coordinates (same rounding as ``.int()``).
    boxes = torch.as_tensor(boxes).int().to(device)
    if boxes.dim() == 3:
        # Bx1x4 (documented layout): keep the first box of each sample.
        boxes = boxes[:, 0, :]
    x_min, y_min, x_max, y_max = boxes.unbind(dim=-1)  # each of shape (B,)

    # Compare pixel indices against the box edges instead of slicing.
    # Out-of-range or negative coordinates are thereby clipped to the image
    # rather than interpreted as Python negative indices.
    rows = torch.arange(height, device=device)
    cols = torch.arange(width, device=device)
    inside_y = (rows[None, :] >= y_min[:, None]) & (rows[None, :] < y_max[:, None])  # (B, H)
    inside_x = (cols[None, :] >= x_min[:, None]) & (cols[None, :] < x_max[:, None])  # (B, W)
    mask = inside_y[:, :, None] & inside_x[:, None, :]  # (B, H, W)
    return mask.unsqueeze(1).to(torch.float32)
def click_onehot(point_coords: torch.Tensor,
                 point_labels: torch.Tensor,
                 shape: Tuple[int, int] = (128, 128),
                 indexing: Literal['xy', 'uv'] = 'xy') -> torch.Tensor:
    """
    Represent clicks as masks of zeros with the click labels written at the click locations.

    At each click pixel, channel 0 receives ``label`` and channel 1 receives
    ``1 - label``. With 0/1 labels, a label-1 click therefore sets channel 0 and
    a label-0 click sets channel 1. All other pixels are 0.

    Args:
        point_coords (torch.Tensor): BxNx2 tensor of pixel coordinates. Float
            coordinates are truncated to integers before indexing.
        point_labels (torch.Tensor): BxN tensor of click labels.
        shape (tuple): (H, W) of the output.
        indexing (str): 'xy' means each point is (column, row). Any other value
            (e.g. 'uv') means each point is (row, column).
    Returns:
        embed (torch.Tensor): Bx2xHxW float tensor of clicks, on the device of
        point_coords.
    """
    assert len(point_coords.shape) == 3, "point_coords must be BxNx2"
    assert point_coords.shape[-1] == 2, "point_coords must be BxNx2"
    assert point_labels.shape[-1] == point_coords.shape[1], "point_labels must be BxN"
    assert len(shape)==2, f"shape must be 2D: {shape}"
    device = point_coords.device
    batch_size, n_points, _ = point_coords.shape
    # tuple() so that a list-valued ``shape`` is accepted as well.
    embed = torch.zeros((batch_size, 2) + tuple(shape), device=device)
    # Match the embed's device/dtype so the index_put assignment below is valid.
    labels = point_labels.reshape(-1).to(device=device, dtype=embed.dtype)

    # Index tensors must be integral; float click coordinates are truncated.
    coords = point_coords.long().reshape(-1, 2)
    # Batch index for every flattened point: 0,0,...,1,1,... (N copies each).
    batch_idx = torch.arange(batch_size, device=device).repeat_interleave(n_points)
    if indexing == 'xy':
        rows, cols = coords[:, 1], coords[:, 0]
    else:
        rows, cols = coords[:, 0], coords[:, 1]
    embed[batch_idx, 0, rows, cols] = labels
    embed[batch_idx, 1, rows, cols] = 1.0 - labels
    return embed