shunk031 commited on
Commit
f3cdeee
·
1 Parent(s): c64d770

deploy: fb8481effdf5a0b23ff86fad414906046d7620bd

Browse files
Files changed (1) hide show
  1. layout-overlay.py +55 -19
layout-overlay.py CHANGED
@@ -12,14 +12,23 @@ Computes the average IoU of all pairs of elements except for underlay.
12
 
13
  _KWARGS_DESCRIPTION = """\
14
  Args:
15
- predictions (`list` of `lists` of `float`): A list of lists of floats representing normalized `ltrb`-format bounding boxes.
16
- gold_labels (`list` of `lists` of `int`): A list of lists of integers representing class labels.
17
-
18
- Ruturns:
19
- float: Average IoU except decoration (i.e., underlay) elements (used in PosterLayout).
20
-
21
- Examples::
22
- FIXME
 
 
 
 
 
 
 
 
 
23
  """
24
 
25
  _CITATION = """\
@@ -37,8 +46,8 @@ _CITATION = """\
37
  class LayoutOverlay(evaluate.Metric):
38
  def __init__(
39
  self,
40
- canvas_width: int,
41
- canvas_height: int,
42
  decoration_label_index: int = 3,
43
  **kwargs,
44
  ) -> None:
@@ -64,20 +73,24 @@ class LayoutOverlay(evaluate.Metric):
64
  )
65
 
66
  def get_rid_of_invalid(
67
- self, predictions: npt.NDArray[np.float64], gold_labels: npt.NDArray[np.int64]
 
 
 
 
68
  ) -> npt.NDArray[np.int64]:
69
  assert len(predictions) == len(gold_labels)
70
 
71
- w = self.canvas_width / 100
72
- h = self.canvas_height / 100
73
 
74
  for i, prediction in enumerate(predictions):
75
  for j, b in enumerate(prediction):
76
  xl, yl, xr, yr = b
77
  xl = max(0, xl)
78
  yl = max(0, yl)
79
- xr = min(self.canvas_width, xr)
80
- yr = min(self.canvas_height, yr)
81
  if abs((xr - xl) * (yr - yl)) < w * h * 10:
82
  if gold_labels[i, j]:
83
  gold_labels[i, j] = 0
@@ -111,15 +124,38 @@ class LayoutOverlay(evaluate.Metric):
111
  *,
112
  predictions: Union[npt.NDArray[np.float64], List[List[float]]],
113
  gold_labels: Union[npt.NDArray[np.int64], List[int]],
 
 
 
114
  ) -> float:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  predictions = np.array(predictions)
116
  gold_labels = np.array(gold_labels)
117
 
118
- predictions[:, :, ::2] *= self.canvas_width
119
- predictions[:, :, 1::2] *= self.canvas_height
120
 
121
  gold_labels = self.get_rid_of_invalid(
122
- predictions=predictions, gold_labels=gold_labels
 
 
 
123
  )
124
 
125
  score = 0.0
@@ -128,7 +164,7 @@ class LayoutOverlay(evaluate.Metric):
128
  ove = 0.0
129
 
130
  cond1 = (gold_label > 0).reshape(-1)
131
- cond2 = (gold_label != self.decoration_label_index).reshape(-1)
132
 
133
  mask = cond1 & cond2
134
  mask_box = prediction[mask]
 
12
 
13
  _KWARGS_DESCRIPTION = """\
14
  Args:
15
+ predictions (`list` of `list` of `float`): A list of lists of floats representing normalized `ltrb`-format bounding boxes.
16
+ gold_labels (`list` of `list` of `int`): A list of lists of integers representing class labels.
17
+ canvas_width (`int`, *optional*): Width of the canvas in pixels. Can be provided at initialization or during computation.
18
+ canvas_height (`int`, *optional*): Height of the canvas in pixels. Can be provided at initialization or during computation.
19
+ decoration_label_index (`int`, *optional*, defaults to 3): The label index for decoration (underlay) elements to exclude from overlay computation.
20
+
21
+ Returns:
22
+ float: Average IoU (Intersection over Union) of all pairs of elements except decoration (underlay) elements. Higher values indicate more overlap between elements.
23
+
24
+ Examples:
25
+ >>> import evaluate
26
+ >>> metric = evaluate.load("creative-graphic-design/layout-overlay")
27
+ >>> # Normalized bounding boxes (left, top, right, bottom)
28
+ >>> predictions = [[[0.1, 0.1, 0.5, 0.5], [0.3, 0.3, 0.7, 0.7]]] # Overlapping elements
29
+ >>> gold_labels = [[1, 2]] # Both are non-decoration elements
30
+ >>> result = metric.compute(predictions=predictions, gold_labels=gold_labels, canvas_width=512, canvas_height=512)
31
+ >>> print(f"Overlay score: {result:.4f}")
32
  """
33
 
34
  _CITATION = """\
 
46
  class LayoutOverlay(evaluate.Metric):
47
  def __init__(
48
  self,
49
+ canvas_width: int | None = None,
50
+ canvas_height: int | None = None,
51
  decoration_label_index: int = 3,
52
  **kwargs,
53
  ) -> None:
 
73
  )
74
 
75
  def get_rid_of_invalid(
76
+ self,
77
+ predictions: npt.NDArray[np.float64],
78
+ gold_labels: npt.NDArray[np.int64],
79
+ canvas_width: int,
80
+ canvas_height: int,
81
  ) -> npt.NDArray[np.int64]:
82
  assert len(predictions) == len(gold_labels)
83
 
84
+ w = canvas_width / 100
85
+ h = canvas_height / 100
86
 
87
  for i, prediction in enumerate(predictions):
88
  for j, b in enumerate(prediction):
89
  xl, yl, xr, yr = b
90
  xl = max(0, xl)
91
  yl = max(0, yl)
92
+ xr = min(canvas_width, xr)
93
+ yr = min(canvas_height, yr)
94
  if abs((xr - xl) * (yr - yl)) < w * h * 10:
95
  if gold_labels[i, j]:
96
  gold_labels[i, j] = 0
 
124
  *,
125
  predictions: Union[npt.NDArray[np.float64], List[List[float]]],
126
  gold_labels: Union[npt.NDArray[np.int64], List[int]],
127
+ canvas_width: int | None = None,
128
+ canvas_height: int | None = None,
129
+ decoration_label_index: int | None = None,
130
  ) -> float:
131
+ # パラメータの優先順位処理
132
+ canvas_width = canvas_width if canvas_width is not None else self.canvas_width
133
+ canvas_height = (
134
+ canvas_height if canvas_height is not None else self.canvas_height
135
+ )
136
+ decoration_label_index = (
137
+ decoration_label_index
138
+ if decoration_label_index is not None
139
+ else self.decoration_label_index
140
+ )
141
+
142
+ if canvas_width is None or canvas_height is None:
143
+ raise ValueError(
144
+ "canvas_width and canvas_height must be provided either "
145
+ "at initialization or during computation"
146
+ )
147
+
148
  predictions = np.array(predictions)
149
  gold_labels = np.array(gold_labels)
150
 
151
+ predictions[:, :, ::2] *= canvas_width
152
+ predictions[:, :, 1::2] *= canvas_height
153
 
154
  gold_labels = self.get_rid_of_invalid(
155
+ predictions=predictions,
156
+ gold_labels=gold_labels,
157
+ canvas_width=canvas_width,
158
+ canvas_height=canvas_height,
159
  )
160
 
161
  score = 0.0
 
164
  ove = 0.0
165
 
166
  cond1 = (gold_label > 0).reshape(-1)
167
+ cond2 = (gold_label != decoration_label_index).reshape(-1)
168
 
169
  mask = cond1 & cond2
170
  mask_box = prediction[mask]