matejpekar commited on
Commit
06e41e7
·
verified ·
1 Parent(s): 1de0261

Delete image_processing_fast.py

Browse files
Files changed (1) hide show
  1. image_processing_fast.py +0 -62
image_processing_fast.py DELETED
@@ -1,62 +0,0 @@
1
- import numpy as np
2
- import torch
3
- from PIL import Image, ImageDraw
4
- from torch import Tensor
5
- from transformers import BaseImageProcessorFast
6
-
7
-
8
- class LSPDetrImageProcessorFast(BaseImageProcessorFast):
9
- def post_process(self, outputs: dict[str, Tensor]) -> list[dict[str, Tensor]]:
10
- """Converts the raw output into polygons.
11
-
12
- Returns:
13
- A list of dictionaries, each containing:
14
- - "polygons": A tensor of shape (N, num_radial_distances, 2) representing the polygons.
15
- - "labels": A tensor of shape (N,) representing the labels for each polygon.
16
- """
17
- radial_distances = outputs["radial_distances"].expm1()
18
-
19
- t = torch.linspace(0, 1, radial_distances.size(-1) + 1, device=self.device)[:-1]
20
- cos = torch.cos(2 * torch.pi * t)
21
- sin = torch.sin(2 * torch.pi * t)
22
-
23
- polar = radial_distances.unsqueeze(-1) * torch.stack([sin, cos], dim=-1)
24
- polygons = outputs["absolute_points"].unsqueeze(-2) + polar
25
-
26
- labels = outputs["logits"].argmax(dim=-1)
27
- non_no_object_indices = labels != outputs["logits"].size(-1) - 1
28
-
29
- return [
30
- {"polygons": polygons[b, indices], "labels": labels[b, indices]}
31
- for b, indices in enumerate(non_no_object_indices)
32
- ]
33
-
34
- def post_process_instance(
35
- self,
36
- results: list[dict[str, Tensor]],
37
- height: int,
38
- width: int,
39
- ) -> list[dict[str, Tensor]]:
40
- """Converts the output into actual instance segmentation predictions.
41
-
42
- Args:
43
- results: Results list obtained by `post_process`, to which "masks" results will be added.
44
- height: Height of the input image.
45
- width: Width of the input image.
46
- """
47
- for i, result in enumerate(results):
48
- masks = torch.zeros(
49
- (len(result["polygons"]), height, width),
50
- dtype=torch.bool,
51
- device=result["polygons"].device,
52
- )
53
-
54
- for i, polygon in enumerate(result["polygons"]):
55
- img = Image.fromarray(masks[i].cpu().numpy())
56
- canvas = ImageDraw.Draw(img)
57
- canvas.polygon(xy=polygon.flatten().tolist(), outline=1, fill=1)
58
- masks[i] = torch.from_numpy(np.asarray(img))
59
-
60
- results[i]["masks"] = masks
61
-
62
- return results