matejpekar committed on
Commit
fbfa225
·
verified ·
1 Parent(s): 9c0349c

Upload processor

Browse files
image_processing_fast.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from PIL import Image, ImageDraw
4
+ from torch import Tensor
5
+ from transformers import BaseImageProcessorFast
6
+
7
+
8
class LSPDetrImageProcessorFast(BaseImageProcessorFast):
    """Image processor with helpers that turn raw model outputs into polygons and instance masks."""

    def post_process(self, outputs: dict[str, Tensor]) -> list[dict[str, Tensor]]:
        """Converts the raw output into polygons.

        Args:
            outputs: Raw model output. The keys "radial_distances", "absolute_points" and
                "logits" are read.

        Returns:
            A list of dictionaries, each containing:
                - "polygons": A tensor of shape (N, num_radial_distances, 2) representing the polygons.
                - "labels": A tensor of shape (N,) representing the labels for each polygon.
        """
        # NOTE(review): expm1 presumably inverts a log1p encoding of the distances — confirm against the model.
        radial_distances = outputs["radial_distances"].expm1()

        # Evenly spaced fractions of a full turn, one per radial distance (endpoint 1.0 is
        # dropped because it coincides with 0.0). The grid is created on the same device as
        # the model output so the multiplication below never mixes devices.
        t = torch.linspace(
            0, 1, radial_distances.size(-1) + 1, device=radial_distances.device
        )[:-1]
        cos = torch.cos(2 * torch.pi * t)
        sin = torch.sin(2 * torch.pi * t)

        # Offsets of every polygon vertex relative to its centre point.
        polar = radial_distances.unsqueeze(-1) * torch.stack([sin, cos], dim=-1)
        polygons = outputs["absolute_points"].unsqueeze(-2) + polar

        # The last class index is treated as "no object" and filtered out.
        labels = outputs["logits"].argmax(dim=-1)
        non_no_object_indices = labels != outputs["logits"].size(-1) - 1

        return [
            {"polygons": polygons[b, indices], "labels": labels[b, indices]}
            for b, indices in enumerate(non_no_object_indices)
        ]

    def post_process_instance(
        self,
        results: list[dict[str, Tensor]],
        height: int,
        width: int,
    ) -> list[dict[str, Tensor]]:
        """Converts the output into actual instance segmentation predictions.

        Args:
            results: Results list obtained by `post_process`, to which "masks" results will be added.
            height: Height of the input image.
            width: Width of the input image.

        Returns:
            The same `results` list (modified in place); every entry gains a "masks" boolean
            tensor of shape (N, height, width) on the device of its "polygons" tensor.
        """
        for result in results:
            polygons = result["polygons"]

            # Rasterisation happens on the CPU with PIL; the masks are moved to the
            # polygons' device once at the end instead of once per polygon.
            masks = torch.zeros((len(polygons), height, width), dtype=torch.bool)

            # A dedicated index name is used here: reusing the outer loop variable would
            # make the final assignment target the wrong entry of `results`.
            for polygon_index, polygon in enumerate(polygons):
                img = Image.new("1", (width, height), 0)
                canvas = ImageDraw.Draw(img)
                canvas.polygon(xy=polygon.flatten().tolist(), outline=1, fill=1)
                # np.array (not asarray) yields a writable copy, which torch.from_numpy expects.
                masks[polygon_index] = torch.from_numpy(np.array(img))

            result["masks"] = masks.to(polygons.device)

        return results
preprocessor_config.json CHANGED
@@ -1,6 +1,6 @@
1
  {
2
  "auto_map": {
3
- "AutoImageProcessor": "image_processing.LSPDetrImageProcessor"
4
  },
5
  "crop_size": null,
6
  "data_format": "channels_first",
@@ -16,7 +16,7 @@
16
  0.456,
17
  0.406
18
  ],
19
- "image_processor_type": "LSPDetrImageProcessor",
20
  "image_std": [
21
  0.229,
22
  0.224,
 
1
  {
2
  "auto_map": {
3
+ "AutoImageProcessor": "image_processing_fast.LSPDetrImageProcessorFast"
4
  },
5
  "crop_size": null,
6
  "data_format": "channels_first",
 
16
  0.456,
17
  0.406
18
  ],
19
+ "image_processor_type": "LSPDetrImageProcessorFast",
20
  "image_std": [
21
  0.229,
22
  0.224,