Spaces:
Sleeping
Sleeping
| import torch | |
| from typing import Tuple, Literal | |
def bbox_shaded(boxes: torch.Tensor,
                shape: Tuple[int,int] = (128,128),
                device='cuda') -> torch.Tensor:
    """
    Represent each bounding box as a binary image with 1 inside the bbox and 0 outside.

    Coordinates are truncated to integers and interpreted as half-open pixel
    ranges: columns x_min <= x < x_max and rows y_min <= y < y_max are set to 1.
    Boxes that extend past the image are clipped to it; a box with
    x_max <= x_min or y_max <= y_min produces an all-zero mask.

    Args:
        boxes: Bx1x4 [x_min, y_min, x_max, y_max]. Only index 0 of the second
            dimension is read. A tensor, or anything torch.as_tensor accepts.
        shape (tuple): (H,W)
        device (str): 'cuda' or 'cpu'; device of the returned tensor
    Returns:
        bbox_embed (torch.Tensor): float32 Bx1xHxW according to shape
    """
    assert len(shape)==2, "shape must be 2D"
    height, width = shape

    # Work on integer corners on the target device; no numpy round-trip needed.
    corners = torch.as_tensor(boxes).int().to(device)[:, 0, :]
    x_min, y_min, x_max, y_max = (corners[:, k].reshape(-1, 1, 1) for k in range(4))

    rows = torch.arange(height, device=device).reshape(1, -1, 1)
    cols = torch.arange(width, device=device).reshape(1, 1, -1)

    # Comparing against pixel grids (instead of slicing) clips out-of-image
    # boxes naturally: negative coordinates can no longer wrap around the way
    # negative slice bounds do.
    inside = (rows >= y_min) & (rows < y_max) & (cols >= x_min) & (cols < x_max)

    return inside.unsqueeze(1).to(torch.float32)
def click_onehot(point_coords: torch.Tensor,
                 point_labels: torch.Tensor,
                 shape: Tuple[int,int] = (128,128),
                 indexing: Literal['xy','uv'] = 'xy') -> torch.Tensor:
    """
    Represent clicks as masks of zeros with values written at the click locations.

    Channel 0 receives each click's label and channel 1 receives 1 - label, so
    with 0/1 labels channel 0 marks label-1 clicks and channel 1 marks label-0
    clicks.

    Args:
        point_coords (torch.Tensor): BxNx2 tensor of click coordinates.
            Floating-point coordinates are truncated to integer pixel indices.
        point_labels (torch.Tensor): BxN tensor of click labels
        shape (tuple): (H,W)
        indexing (str): 'xy' reads each point as (column, row); any other
            value ('uv') reads it as (row, column)
    Returns:
        embed (torch.Tensor): Bx2xHxW tensor of clicks, on point_coords' device
    """
    assert len(point_coords.shape) == 3, "point_coords must be BxNx2"
    assert point_coords.shape[-1] == 2, "point_coords must be BxNx2"
    assert point_labels.shape[-1] == point_coords.shape[1], "point_labels must be BxN"
    assert len(shape)==2, f"shape must be 2D: {shape}"

    device = point_coords.device
    batch_size, n_points = point_coords.shape[:2]

    # tuple(shape) so a list-valued shape works, matching bbox_shaded.
    embed = torch.zeros((batch_size, 2) + tuple(shape), device=device)
    labels = point_labels.to(device).flatten().float()

    # One batch index per point: [0,0,...,1,1,...], aligned with the flattened points.
    batch_idx = torch.arange(batch_size, device=device).repeat_interleave(n_points)

    # Index tensors must be integral; float coordinates would otherwise raise.
    flat_coords = point_coords.reshape(-1, 2).long()

    if indexing=='xy':
        col_idx, row_idx = flat_coords[:, 0], flat_coords[:, 1]
    else:
        row_idx, col_idx = flat_coords[:, 0], flat_coords[:, 1]

    embed[batch_idx, 0, row_idx, col_idx] = labels
    embed[batch_idx, 1, row_idx, col_idx] = 1.0 - labels

    return embed