File size: 2,623 Bytes
1de0261
 
 
 
 
 
 
 
7e51321
 
1de0261
 
 
 
 
 
 
 
 
 
 
 
 
0a15a33
 
 
1de0261
 
 
 
 
 
 
 
 
 
3b3a0ee
 
 
 
 
1de0261
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d7601c
 
1de0261
 
4d7601c
1de0261
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import numpy as np
import torch
from PIL import Image, ImageDraw
from torch import Tensor
from transformers import BaseImageProcessorFast


class LSPDetrImageProcessor(BaseImageProcessorFast):
    """Image processor for LSP-DETR with polygon and mask post-processing.

    Preprocessing enables rescaling and normalization with the commonly used
    ImageNet mean/std values below. `post_process` turns raw model outputs into
    per-image polygons; `post_process_instance` rasterizes those polygons into
    boolean instance masks.
    """

    image_mean = (0.485, 0.456, 0.406)
    image_std = (0.229, 0.224, 0.225)
    do_rescale = True
    do_normalize = True

    def post_process(self, outputs: dict[str, Tensor]) -> list[dict[str, Tensor]]:
        """Converts the raw output into polygons.

        Each query's radial distances are laid out at evenly spaced angles over
        one full turn and offset from that query's absolute point. Queries whose
        arg-max class is the last logit index ("no object") are dropped.

        Args:
            outputs: Raw model outputs. Expected keys (shapes inferred from the
                broadcasting and indexing below):
                    - "radial_distances": ``(B, Q, R)``; decoded with ``expm1``,
                      so they are expected in ``log1p`` space.
                    - "absolute_points": ``(B, Q, 2)`` polygon centers.
                    - "logits": ``(B, Q, C)``; class ``C - 1`` is "no object".
                    - "embeddings": ``(B, Q, ...)`` per-query embeddings.

        Returns:
            A list with one dictionary per image, each containing:
                - "polygons": A tensor of shape (N, num_radial_distances, 2) representing the polygons.
                - "labels": A tensor of shape (N,) representing the labels for each polygon.
                - "embeddings": The embeddings of the N kept queries.
        """
        radial_distances = outputs["radial_distances"].expm1()
        num_rays = radial_distances.size(-1)

        # num_rays fractions of a full turn in [0, 1); the endpoint is dropped
        # because 1.0 (= 2*pi) would duplicate the 0 angle.
        t = torch.linspace(0, 1, num_rays + 1, device=radial_distances.device)[:-1]
        cos = torch.cos(2 * torch.pi * t)
        sin = torch.sin(2 * torch.pi * t)

        # Offsets are (r * sin, r * cos) per ray.
        # NOTE(review): confirm this component order matches "absolute_points".
        polar = radial_distances.unsqueeze(-1) * torch.stack([sin, cos], dim=-1)
        polygons = outputs["absolute_points"].unsqueeze(-2) + polar

        labels = outputs["logits"].argmax(dim=-1)
        # The last class index is reserved for "no object".
        non_no_object_indices = labels != outputs["logits"].size(-1) - 1

        return [
            {
                "polygons": polygons[b, indices],
                "labels": labels[b, indices],
                "embeddings": outputs["embeddings"][b, indices],
            }
            for b, indices in enumerate(non_no_object_indices)
        ]

    def post_process_instance(
        self,
        results: list[dict[str, Tensor]],
        height: int,
        width: int,
    ) -> list[dict[str, Tensor]]:
        """Converts the output into actual instance segmentation predictions.

        Args:
            results: Results list obtained by `post_process`. Each dictionary is
                updated in place with a "masks" entry.
            height: Height of the input image.
            width: Width of the input image.

        Returns:
            The same ``results`` list. Each dictionary now also holds "masks":
            a boolean tensor of shape ``(N, height, width)`` on the same device
            as its "polygons", with every polygon drawn filled and outlined.
        """
        for result in results:
            polygons = result["polygons"]
            # Rasterize on the CPU and transfer to the target device once,
            # instead of a device round-trip (and sync) for every polygon.
            masks = torch.zeros((len(polygons), height, width), dtype=torch.bool)

            for j, polygon in enumerate(polygons.detach().cpu()):
                # 1-bit canvas initialized to 0; PIL sizes are (width, height).
                canvas = Image.new("1", (width, height))
                ImageDraw.Draw(canvas).polygon(
                    xy=polygon.flatten().tolist(), outline=1, fill=1
                )
                masks[j] = torch.from_numpy(np.array(canvas))

            result["masks"] = masks.to(polygons.device)

        return results