LSP-DETR / image_processing.py
matejpekar's picture
Upload processor
7e51321 verified
raw
history blame
2.51 kB
import numpy as np
import torch
from PIL import Image, ImageDraw
from torch import Tensor
from transformers import BaseImageProcessorFast
class LSPDetrImageProcessor(BaseImageProcessorFast):
    """Image processor for LSP-DETR with polygon and mask post-processing.

    The class attributes below are the preprocessing defaults picked up by
    `BaseImageProcessorFast`: pixel rescaling followed by per-channel
    normalization with the widely used ImageNet mean/std.
    """

    image_mean = (0.485, 0.456, 0.406)
    image_std = (0.229, 0.224, 0.225)
    do_rescale = True
    do_normalize = True

    def post_process(self, outputs: dict[str, Tensor]) -> list[dict[str, Tensor]]:
        """Converts the raw output into polygons.

        Args:
            outputs: Raw model output with the keys "radial_distances",
                "absolute_points" and "logits". The indexing below expects
                them to be batched per query, i.e. (B, Q, K), (B, Q, 2) and
                (B, Q, C) respectively, with the last class index of
                "logits" meaning "no object".

        Returns:
            A list of dictionaries, each containing:
                - "polygons": A tensor of shape (N, num_radial_distances, 2) representing the polygons.
                - "labels": A tensor of shape (N,) representing the labels for each polygon.
        """
        # The distances are decoded with expm1, i.e. they are stored as log1p(r).
        radial_distances = outputs["radial_distances"].expm1()

        # K equally spaced fractions of a full turn in [0, 1); the endpoint is
        # dropped so the ray at 2*pi does not duplicate the ray at 0.
        t = torch.linspace(
            0, 1, radial_distances.size(-1) + 1, device=radial_distances.device
        )[:-1]
        angles = 2 * torch.pi * t
        # Unit direction per ray, ordered (sin, cos) to match the point layout.
        directions = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)

        # Each vertex is the query's reference point plus its ray offset.
        offsets = radial_distances.unsqueeze(-1) * directions
        polygons = outputs["absolute_points"].unsqueeze(-2) + offsets

        logits = outputs["logits"]
        labels = logits.argmax(dim=-1)
        # Queries whose best class is the last index are "no object" and dropped.
        keep = labels != logits.size(-1) - 1

        return [
            {"polygons": polygons[b, kept], "labels": labels[b, kept]}
            for b, kept in enumerate(keep)
        ]

    def post_process_instance(
        self,
        results: list[dict[str, Tensor]],
        height: int,
        width: int,
    ) -> list[dict[str, Tensor]]:
        """Converts the output into actual instance segmentation predictions.

        Every polygon is rasterized with PIL into a boolean mask. Drawing
        happens on the host; the stacked masks are moved to the polygons'
        device once per result instead of once per polygon.

        Args:
            results: Results list obtained by `post_process`, to which "masks" results will be added.
            height: Height of the input image.
            width: Width of the input image.

        Returns:
            The same `results` list (modified in place); each entry gains a
            "masks" boolean tensor of shape (N, height, width) located on the
            same device as its "polygons".
        """
        for result in results:
            polygons = result["polygons"]
            masks = np.zeros((len(polygons), height, width), dtype=bool)

            # One device-to-host transfer for all vertices; each row is the
            # flat [x0, y0, x1, y1, ...] sequence PIL expects.
            flat_vertices = polygons.flatten(start_dim=1).tolist()

            for j, xy in enumerate(flat_vertices):
                # Mode "1" is a 1-bit image; PIL sizes are (width, height).
                img = Image.new("1", (width, height))
                ImageDraw.Draw(img).polygon(xy=xy, outline=1, fill=1)
                masks[j] = np.asarray(img)

            result["masks"] = torch.from_numpy(masks).to(polygons.device)
        return results