|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Prompt encoder.""" |
|
|
|
|
|
import math

import torch
from torch import nn
|
|
|
|
|
|
|
|
class PromptEncoder(nn.Module):
    """Module to encode geometric prompts (points and boxes).

    Prompt coordinates are normalized to ``[0, 1]``, projected through a
    non-trainable random Gaussian matrix and encoded as ``[sin, cos]``
    features, to which a learned per-label embedding is added.
    """

    def __init__(self, embed_dim, image_size):
        """Create the encoder.

        Args:
            embed_dim: Width of the label embeddings. The positional encoding
                produces ``embed_dim // 2`` sin and ``embed_dim // 2`` cos
                features, so ``embed_dim`` is expected to be even.
            image_size: Side length of the (square) input image in pixels,
                used to normalize prompt coordinates.
        """
        super().__init__()
        # One learned embedding per prompt label (indices 0-4).
        self.point_embed = nn.Embedding(5, embed_dim)
        # Labels assigned to the two corner points of every box prompt.
        self.corner_labels = torch.tensor([[2, 3]], dtype=torch.int64)
        # Random projection for the positional encoding; registered as a buffer
        # so it is stored in the state dict and follows ``.to(device)``.
        self.register_buffer("coord_matrix", torch.randn((2, embed_dim // 2)))
        self.img_pos, self.img_size = None, [image_size] * 2
        # (grid size, device, dtype) that the cached ``img_pos`` was built for.
        self._img_pos_key = None

    def as_tensor(self, input):
        """Convert ``input`` into a tensor on the device of ``coord_matrix``."""
        return torch.as_tensor(input, device=self.coord_matrix.device)

    def _normalize_coords(self, coords):
        """Map pixel coords to ``[0, 1]`` and return them as a float32 tensor.

        ``coords`` is expected to be a NumPy-style array (``clip``/``astype``
        are used). The ``+ 0.5`` presumably moves pixel indices to pixel
        centers. ``img_size`` is reversed to line up with the (x, y) ordering
        of the last axis, i.e. it assumes ``img_size`` is (height, width).
        """
        coords = coords + 0.5
        coords /= self.img_size[::-1]
        return self.as_tensor(coords.clip(0, 1).astype("float32"))

    def to_points(self, points=None, boxes=None):
        """Convert points or boxes to point prompts.

        Args:
            points: Either a ``(coords, labels)`` pair, or a 3-D array whose
                last axis holds two coordinates followed by a label.
                Takes precedence over ``boxes`` when both are given.
            boxes: Array of boxes; reshaped to ``(-1, 2, 2)`` so every box
                becomes two corner points labelled ``corner_labels``.

        Returns:
            ``(coords, labels)``: float32 coords normalized to ``[0, 1]`` and
            int64 labels, or ``None`` if neither input is given.
        """
        if points is not None:
            if isinstance(points, (tuple, list)):
                coords, labels = points
            else:
                coords, labels = points[:, :, :2], points[:, :, 2]
            coords = self._normalize_coords(coords)
            labels = self.as_tensor(labels.astype("int64"))
            return coords, labels
        if boxes is not None:
            coords = self._normalize_coords(boxes.reshape((-1, 2, 2)))
            labels = self.as_tensor(self.corner_labels)
            return coords, labels
        return None

    def encode_coords(self, coords):
        """Return the sin/cos positional embedding of ``[0, 1]``-normalized coords.

        Coords are mapped to ``[-1, 1]``, scaled by ``2 * pi`` and projected by
        ``coord_matrix``. The result is cast to the dtype of ``point_embed``.
        """
        # The projection is kept in float32 even if the module was cast to a
        # lower precision (e.g. ``.half()``); the assignment updates the buffer.
        if self.coord_matrix.dtype != torch.float32:
            self.coord_matrix = self.coord_matrix.float()
        # (2 * c - 1) * 2pi == c * 4pi - 2pi; the projection is linear, so the
        # scale can be applied before the matmul.
        rad = coords.mul(4 * math.pi).sub_(2 * math.pi) @ self.coord_matrix
        dtype = self.point_embed.weight.dtype
        return torch.cat([rad.sin(), rad.cos()], dim=-1).to(dtype=dtype)

    def encode_points(self, coords, labels):
        """Return the embedding for given points.

        Each point gets its positional encoding plus the embedding of its
        label. Points labelled ``4`` (presumably padding / "not a point") have
        their positional component zeroed and keep only the label embedding.
        """
        embed = self.encode_coords(coords)
        keep = labels.ne(4).unsqueeze(-1).to(dtype=embed.dtype)
        embed.mul_(keep)
        return embed.add_(self.point_embed(labels))

    def encode_grid(self, grid_size):
        """Return the positional embedding of every cell center of a grid.

        Args:
            grid_size: ``(h, w)`` size of the grid.

        Returns:
            Tensor of shape ``(h, w, 2 * (embed_dim // 2))``.
        """
        # Build directly on the module's device to avoid a host-to-device copy.
        grid = torch.ones(*grid_size, dtype=torch.float32, device=self.coord_matrix.device)
        # cumsum of ones yields 1..n along each axis; -0.5 then /n gives the
        # cell centers in (0, 1).
        y = grid.cumsum(dim=0).sub_(0.5).div_(grid_size[0])
        x = grid.cumsum(dim=1).sub_(0.5).div_(grid_size[1])
        coords = torch.stack([x, y], dim=-1)
        return self.encode_coords(coords)

    def forward(self, inputs):
        """Encode the prompts of ``inputs``.

        Args:
            inputs: Dict with optional ``"boxes"`` and/or ``"points"`` prompts
                (see ``to_points``) and a required ``"img_embeds"`` tensor
                whose ``shape[2:-1]`` is taken as the ``(h, w)`` grid size.

        Returns:
            Dict with ``"sparse_embeds"`` (box embeddings followed by point
            embeddings, concatenated along dim 1) and ``"img_pos"``, the grid
            positional embedding flattened to ``(h * w, D)``.

        Raises:
            ValueError: If neither points nor boxes are provided.
        """
        sparse_embeds = []
        if inputs.get("boxes") is not None:
            coords, labels = self.to_points(boxes=inputs["boxes"])
            sparse_embeds.append(self.encode_points(coords, labels))
        if inputs.get("points") is not None:
            coords, labels = self.to_points(points=inputs["points"])
            sparse_embeds.append(self.encode_points(coords, labels))
        if not sparse_embeds:
            raise ValueError("Expected ``points`` or ``boxes`` prompts.")
        if len(sparse_embeds) > 1:
            sparse_embeds = [torch.cat(sparse_embeds, dim=1)]
        grid_size = torch.Size(inputs["img_embeds"].shape[2:-1])
        # Rebuild the cache when the grid shape (not merely its cell count),
        # the module's device, or the embedding dtype has changed. getattr
        # keeps instances unpickled from older versions working.
        cache_key = (grid_size, self.coord_matrix.device, self.point_embed.weight.dtype)
        if self.img_pos is None or getattr(self, "_img_pos_key", None) != cache_key:
            self.img_pos = self.encode_grid(grid_size).flatten(0, 1)
            self._img_pos_key = cache_key
        return {"sparse_embeds": sparse_embeds[0], "img_pos": self.img_pos}
|
|
|