| import numpy as np | |
| import torch | |
| from PIL import Image, ImageDraw | |
| from torch import Tensor | |
| from transformers import BaseImageProcessorFast | |
class LSPDetrImageProcessor(BaseImageProcessorFast):
    """Image processor with polygon and mask post-processing for LSP-DETR outputs.

    Besides the preprocessing defaults declared below, the class offers two
    post-processing steps:

    * `post_process` turns raw model outputs (centre points plus radial
      distances) into polygons and drops "no object" predictions.
    * `post_process_instance` rasterizes those polygons into boolean masks.
    """

    # Per-channel normalization statistics (these match the widely used
    # ImageNet RGB mean / standard deviation).
    image_mean = (0.485, 0.456, 0.406)
    image_std = (0.229, 0.224, 0.225)
    # Preprocessing switches; presumably read by `BaseImageProcessorFast`
    # -- confirm against the base class.
    do_rescale = True
    do_normalize = True

    def post_process(self, outputs: dict[str, Tensor]) -> list[dict[str, Tensor]]:
        """Converts the raw output into polygons.

        Args:
            outputs: Mapping that must provide the following tensors
                (shapes are what the indexing below assumes -- batch first,
                then one entry per prediction):

                - "radial_distances": distances from each centre point to its
                  polygon vertices; the values are passed through `expm1`
                  before use.
                - "absolute_points": centre point of every prediction, last
                  dimension of size 2.
                - "logits": class scores; the LAST class index is treated as
                  "no object".
                - "embeddings": per-prediction embeddings, filtered the same
                  way as the polygons.

        Returns:
            A list with one dictionary per batch element, each containing only
            the predictions whose arg-max class is not "no object":

            - "polygons": A tensor of shape (N, num_radial_distances, 2)
              representing the polygons.
            - "labels": A tensor of shape (N,) representing the labels for
              each polygon.
            - "embeddings": The rows of `outputs["embeddings"]` belonging to
              the kept predictions.

        Raises:
            KeyError: If one of the required keys is missing from `outputs`.
        """
        radial_distances = outputs["radial_distances"].expm1()

        # Evenly spaced angles over one full turn. The endpoint (2*pi) is
        # dropped because it coincides with angle 0.
        num_rays = radial_distances.size(-1)
        t = torch.linspace(0, 1, num_rays + 1, device=radial_distances.device)[:-1]
        angles = 2 * torch.pi * t

        # The offset of vertex k is (r * sin(angle_k), r * cos(angle_k)): the
        # first coordinate uses sin, the second cos.
        directions = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)
        polar = radial_distances.unsqueeze(-1) * directions
        polygons = outputs["absolute_points"].unsqueeze(-2) + polar

        logits = outputs["logits"]
        labels = logits.argmax(dim=-1)
        # The last class index is the "no object" class.
        keep = labels != logits.size(-1) - 1

        return [
            {
                "polygons": polygons[b, kept],
                "labels": labels[b, kept],
                "embeddings": outputs["embeddings"][b, kept],
            }
            for b, kept in enumerate(keep)
        ]

    def post_process_instance(
        self,
        results: list[dict[str, Tensor]],
        height: int,
        width: int,
    ) -> list[dict[str, Tensor]]:
        """Converts the output into actual instance segmentation predictions.

        Every polygon is rasterized (outline and interior filled) onto its own
        `height` x `width` canvas. Each polygon vertex is handed to PIL as an
        (x, y) pair, i.e. the first coordinate is the column and the second
        the row.

        Args:
            results: Results list obtained by `post_process`, to which "masks"
                results will be added. The dictionaries are modified in place.
            height: Height of the input image.
            width: Width of the input image.

        Returns:
            The same `results` list; every dictionary additionally holds
            "masks", a boolean tensor of shape (N, height, width) located on
            the same device as that dictionary's "polygons".
        """
        for result in results:
            polygons = result["polygons"]
            # Move all vertices to the host in one transfer instead of syncing
            # with the device once per polygon.
            host_polygons = polygons.detach().cpu()

            # Rasters are collected on the host and shipped to the target
            # device once at the end, avoiding per-polygon round trips.
            rasters = np.zeros((len(host_polygons), height, width), dtype=bool)
            for j, polygon in enumerate(host_polygons):
                # PIL sizes are (width, height); mode "1" is a 1-bit image.
                img = Image.new("1", (width, height), 0)
                canvas = ImageDraw.Draw(img)
                canvas.polygon(xy=polygon.flatten().tolist(), outline=1, fill=1)
                rasters[j] = np.asarray(img)

            result["masks"] = torch.from_numpy(rasters).to(polygons.device)
        return results