shunk031 commited on
Commit
01a7a8b
·
1 Parent(s): 44ea5b0

deploy: 5dc3b259ecceeb40bc939de424152cbbf9555c53

Browse files
Files changed (3) hide show
  1. README.md +0 -1
  2. layout-unreadability.py +182 -0
  3. requirements.txt +90 -0
README.md CHANGED
@@ -8,5 +8,4 @@ sdk_version: 4.36.1
8
  app_file: app.py
9
  pinned: false
10
  ---
11
-
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
8
  app_file: app.py
9
  pinned: false
10
  ---
 
11
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
layout-unreadability.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Literal, Optional, Union
3
+
4
+ import cv2
5
+ import datasets as ds
6
+ import evaluate
7
+ import numpy as np
8
+ import numpy.typing as npt
9
+ from PIL import Image
10
+ from PIL.Image import Image as PilImage
11
+
12
+ _DESCRIPTION = r"""\
13
+ Computes the non-flatness of regions that text elements are solely put on, referring to CGL-GAN.
14
+ """
15
+
16
+ _KWARGS_DESCRIPTION = """\
17
+ """
18
+
19
+ _CITATION = """\
20
+ @inproceedings{hsu2023posterlayout,
21
+ title={Posterlayout: A new benchmark and approach for content-aware visual-textual presentation layout},
22
+ author={Hsu, Hsiao Yuan and He, Xiangteng and Peng, Yuxin and Kong, Hao and Zhang, Qing},
23
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
24
+ pages={6018--6026},
25
+ year={2023}
26
+ }
27
+ """
28
+
29
+ ReqType = Literal["pil2cv", "cv2pil"]
30
+
31
+
32
+ class LayoutUnreadability(evaluate.Metric):
33
+ def __init__(
34
+ self,
35
+ canvas_width: int,
36
+ canvas_height: int,
37
+ text_label_index: int = 1,
38
+ decoration_label_index: int = 3,
39
+ **kwargs,
40
+ ) -> None:
41
+ super().__init__(**kwargs)
42
+ self.canvas_width = canvas_width
43
+ self.canvas_height = canvas_height
44
+
45
+ self.text_label_index = text_label_index
46
+ self.decoration_label_index = decoration_label_index
47
+
48
+ def _info(self) -> evaluate.EvaluationModuleInfo:
49
+ return evaluate.MetricInfo(
50
+ description=_DESCRIPTION,
51
+ citation=_CITATION,
52
+ inputs_description=_KWARGS_DESCRIPTION,
53
+ features=ds.Features(
54
+ {
55
+ "predictions": ds.Sequence(ds.Sequence(ds.Value("float64"))),
56
+ "gold_labels": ds.Sequence(ds.Sequence(ds.Value("int64"))),
57
+ "image_canvases": ds.Sequence(ds.Value("string")),
58
+ }
59
+ ),
60
+ codebase_urls=[
61
+ "https://github.com/PKU-ICST-MIPL/PosterLayout-CVPR2023/blob/main/eval.py#L144-L171"
62
+ ],
63
+ )
64
+
65
+ def cvt_pilcv(
66
+ self,
67
+ img: Union[PilImage, npt.NDArray[np.float64]],
68
+ req: ReqType = "pil2cv",
69
+ color_code: Optional[int] = None,
70
+ ) -> Union[PilImage, npt.NDArray[np.float64]]:
71
+ if req == "pil2cv":
72
+ assert isinstance(img, PilImage)
73
+ color_code = color_code or cv2.COLOR_RGB2BGR
74
+ return cv2.cvtColor(np.asarray(img), color_code)
75
+ elif req == "cv2pil":
76
+ assert isinstance(img, np.ndarray)
77
+ color_code = color_code or cv2.COLOR_BGR2RGB
78
+ return Image.fromarray(cv2.cvtColor(img, color_code))
79
+ else:
80
+ raise ValueError("req should be 'pil2cv' or 'cv2pil'")
81
+
82
+ def img_to_g_xy(self, img):
83
+ img_cv_gs = self.cvt_pilcv(img, req="pil2cv", color_code=cv2.COLOR_RGB2GRAY)
84
+ assert isinstance(img_cv_gs, np.ndarray)
85
+ img_cv_gs = np.uint8(img_cv_gs)
86
+
87
+ # Sobel(src, ddepth, dx, dy)
88
+ grad_x = cv2.Sobel(img_cv_gs, -1, 1, 0)
89
+ grad_y = cv2.Sobel(img_cv_gs, -1, 0, 1)
90
+ grad_xy = ((grad_x**2 + grad_y**2) / 2) ** 0.5
91
+ grad_xy = grad_xy / np.max(grad_xy) * 255
92
+ img_g_xy = Image.fromarray(grad_xy).convert("L")
93
+ return img_g_xy
94
+
95
+ def load_image_canvas(
96
+ self,
97
+ filepath: Union[os.PathLike, List[os.PathLike]],
98
+ ) -> npt.NDArray[np.float64]:
99
+ if isinstance(filepath, list):
100
+ assert len(filepath) == 1, filepath
101
+ filepath = filepath[0]
102
+
103
+ canvas_pil = Image.open(filepath) # type: ignore
104
+ canvas_pil = canvas_pil.convert("RGB")
105
+ if canvas_pil.size != (self.canvas_width, self.canvas_height):
106
+ canvas_pil = canvas_pil.resize((self.canvas_width, self.canvas_height))
107
+
108
+ canvas_pil = self.img_to_g_xy(canvas_pil)
109
+ assert isinstance(canvas_pil, PilImage)
110
+ canvas_arr = np.array(canvas_pil) / 255.0
111
+
112
+ return canvas_arr
113
+
114
+ def get_rid_of_invalid(
115
+ self, predictions: npt.NDArray[np.float64], gold_labels: npt.NDArray[np.int64]
116
+ ) -> npt.NDArray[np.int64]:
117
+ assert len(predictions) == len(gold_labels)
118
+
119
+ w = self.canvas_width / 100
120
+ h = self.canvas_height / 100
121
+
122
+ for i, prediction in enumerate(predictions):
123
+ for j, b in enumerate(prediction):
124
+ xl, yl, xr, yr = b
125
+ xl = max(0, xl)
126
+ yl = max(0, yl)
127
+ xr = min(self.canvas_width, xr)
128
+ yr = min(self.canvas_height, yr)
129
+ if abs((xr - xl) * (yr - yl)) < w * h * 10:
130
+ if gold_labels[i, j]:
131
+ gold_labels[i, j] = 0
132
+ return gold_labels
133
+
134
+ def _compute(
135
+ self,
136
+ *,
137
+ predictions: Union[npt.NDArray[np.float64], List[List[float]]],
138
+ gold_labels: Union[npt.NDArray[np.int64], List[int]],
139
+ image_canvases: List[os.PathLike],
140
+ ):
141
+ predictions = np.array(predictions)
142
+ gold_labels = np.array(gold_labels)
143
+
144
+ predictions[:, :, ::2] *= self.canvas_width
145
+ predictions[:, :, 1::2] *= self.canvas_height
146
+
147
+ gold_labels = self.get_rid_of_invalid(
148
+ predictions=predictions, gold_labels=gold_labels
149
+ )
150
+ score = 0.0
151
+
152
+ assert len(predictions) == len(gold_labels) == len(image_canvases)
153
+ num_predictions = len(predictions)
154
+ it = zip(predictions, gold_labels, image_canvases)
155
+
156
+ for prediction, gold_label, image_canvas in it:
157
+ canvas_arr = self.load_image_canvas(
158
+ image_canvas,
159
+ )
160
+ cal_mask = np.zeros_like(canvas_arr)
161
+
162
+ prediction = np.array(prediction, dtype=int)
163
+ gold_label = np.array(gold_label, dtype=int)
164
+
165
+ is_text = (gold_label == self.text_label_index).reshape(-1)
166
+ prediction_text = prediction[is_text]
167
+
168
+ is_decoration = (gold_label == self.decoration_label_index).reshape(-1)
169
+ prediction_deco = prediction[is_decoration]
170
+
171
+ for mp in prediction_text:
172
+ xl, yl, xr, yr = mp
173
+ cal_mask[yl:yr, xl:xr] = 1
174
+ for mp in prediction_deco:
175
+ xl, yl, xr, yr = mp
176
+ cal_mask[yl:yr, xl:xr] = 0
177
+
178
+ total_area = np.sum(cal_mask)
179
+ total_grad = np.sum(canvas_arr[cal_mask == 1])
180
+ if total_area and total_grad:
181
+ score += total_grad / total_area
182
+ return score / num_predictions
requirements.txt ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1 ; python_version >= "3.9" and python_version < "4.0"
2
+ aiohttp==3.9.3 ; python_version >= "3.9" and python_version < "4.0"
3
+ aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "4.0"
4
+ altair==5.2.0 ; python_version >= "3.9" and python_version < "4.0"
5
+ annotated-types==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
6
+ anyio==4.2.0 ; python_version >= "3.9" and python_version < "4.0"
7
+ arrow==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
8
+ async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.11"
9
+ attrs==23.2.0 ; python_version >= "3.9" and python_version < "4.0"
10
+ binaryornot==0.4.4 ; python_version >= "3.9" and python_version < "4.0"
11
+ certifi==2024.2.2 ; python_version >= "3.9" and python_version < "4.0"
12
+ chardet==5.2.0 ; python_version >= "3.9" and python_version < "4.0"
13
+ charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "4.0"
14
+ click==8.1.7 ; python_version >= "3.9" and python_version < "4.0"
15
+ colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0"
16
+ contourpy==1.2.0 ; python_version >= "3.9" and python_version < "4.0"
17
+ cookiecutter==2.5.0 ; python_version >= "3.9" and python_version < "4.0"
18
+ cycler==0.12.1 ; python_version >= "3.9" and python_version < "4.0"
19
+ datasets==2.17.0 ; python_version >= "3.9" and python_version < "4.0"
20
+ dill==0.3.8 ; python_version >= "3.9" and python_version < "4.0"
21
+ evaluate[template]==0.4.1 ; python_version >= "3.9" and python_version < "4.0"
22
+ exceptiongroup==1.2.0 ; python_version >= "3.9" and python_version < "3.11"
23
+ fastapi==0.109.2 ; python_version >= "3.9" and python_version < "4.0"
24
+ ffmpy==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
25
+ filelock==3.13.1 ; python_version >= "3.9" and python_version < "4.0"
26
+ fonttools==4.48.1 ; python_version >= "3.9" and python_version < "4.0"
27
+ frozenlist==1.4.1 ; python_version >= "3.9" and python_version < "4.0"
28
+ fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "4.0"
29
+ fsspec[http]==2023.10.0 ; python_version >= "3.9" and python_version < "4.0"
30
+ gradio-client==0.10.0 ; python_version >= "3.9" and python_version < "4.0"
31
+ gradio==4.18.0 ; python_version >= "3.9" and python_version < "4.0"
32
+ h11==0.14.0 ; python_version >= "3.9" and python_version < "4.0"
33
+ httpcore==1.0.2 ; python_version >= "3.9" and python_version < "4.0"
34
+ httpx==0.26.0 ; python_version >= "3.9" and python_version < "4.0"
35
+ huggingface-hub==0.20.3 ; python_version >= "3.9" and python_version < "4.0"
36
+ idna==3.6 ; python_version >= "3.9" and python_version < "4.0"
37
+ importlib-resources==6.1.1 ; python_version >= "3.9" and python_version < "4.0"
38
+ jinja2==3.1.3 ; python_version >= "3.9" and python_version < "4.0"
39
+ jsonschema-specifications==2023.12.1 ; python_version >= "3.9" and python_version < "4.0"
40
+ jsonschema==4.21.1 ; python_version >= "3.9" and python_version < "4.0"
41
+ kiwisolver==1.4.5 ; python_version >= "3.9" and python_version < "4.0"
42
+ markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "4.0"
43
+ markupsafe==2.1.5 ; python_version >= "3.9" and python_version < "4.0"
44
+ matplotlib==3.8.2 ; python_version >= "3.9" and python_version < "4.0"
45
+ mdurl==0.1.2 ; python_version >= "3.9" and python_version < "4.0"
46
+ multidict==6.0.5 ; python_version >= "3.9" and python_version < "4.0"
47
+ multiprocess==0.70.16 ; python_version >= "3.9" and python_version < "4.0"
48
+ numpy==1.26.4 ; python_version >= "3.9" and python_version < "4.0"
49
+ opencv-python==4.10.0.84 ; python_version >= "3.9" and python_version < "4.0"
50
+ orjson==3.9.13 ; python_version >= "3.9" and python_version < "4.0"
51
+ packaging==23.2 ; python_version >= "3.9" and python_version < "4.0"
52
+ pandas==2.2.0 ; python_version >= "3.9" and python_version < "4.0"
53
+ pillow==10.2.0 ; python_version >= "3.9" and python_version < "4.0"
54
+ pyarrow-hotfix==0.6 ; python_version >= "3.9" and python_version < "4.0"
55
+ pyarrow==15.0.0 ; python_version >= "3.9" and python_version < "4.0"
56
+ pydantic-core==2.16.2 ; python_version >= "3.9" and python_version < "4.0"
57
+ pydantic==2.6.1 ; python_version >= "3.9" and python_version < "4.0"
58
+ pydub==0.25.1 ; python_version >= "3.9" and python_version < "4.0"
59
+ pygments==2.17.2 ; python_version >= "3.9" and python_version < "4.0"
60
+ pyparsing==3.1.1 ; python_version >= "3.9" and python_version < "4.0"
61
+ python-dateutil==2.8.2 ; python_version >= "3.9" and python_version < "4.0"
62
+ python-multipart==0.0.9 ; python_version >= "3.9" and python_version < "4.0"
63
+ python-slugify==8.0.4 ; python_version >= "3.9" and python_version < "4.0"
64
+ pytz==2024.1 ; python_version >= "3.9" and python_version < "4.0"
65
+ pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "4.0"
66
+ referencing==0.33.0 ; python_version >= "3.9" and python_version < "4.0"
67
+ requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
68
+ responses==0.18.0 ; python_version >= "3.9" and python_version < "4.0"
69
+ rich==13.7.0 ; python_version >= "3.9" and python_version < "4.0"
70
+ rpds-py==0.17.1 ; python_version >= "3.9" and python_version < "4.0"
71
+ ruff==0.2.1 ; python_version >= "3.9" and python_version < "4.0"
72
+ semantic-version==2.10.0 ; python_version >= "3.9" and python_version < "4.0"
73
+ shellingham==1.5.4 ; python_version >= "3.9" and python_version < "4.0"
74
+ six==1.16.0 ; python_version >= "3.9" and python_version < "4.0"
75
+ sniffio==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
76
+ starlette==0.36.3 ; python_version >= "3.9" and python_version < "4.0"
77
+ text-unidecode==1.3 ; python_version >= "3.9" and python_version < "4.0"
78
+ tomlkit==0.12.0 ; python_version >= "3.9" and python_version < "4.0"
79
+ toolz==0.12.1 ; python_version >= "3.9" and python_version < "4.0"
80
+ tqdm==4.66.2 ; python_version >= "3.9" and python_version < "4.0"
81
+ typer[all]==0.9.0 ; python_version >= "3.9" and python_version < "4.0"
82
+ types-python-dateutil==2.8.19.20240106 ; python_version >= "3.9" and python_version < "4.0"
83
+ typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "4.0"
84
+ tzdata==2024.1 ; python_version >= "3.9" and python_version < "4.0"
85
+ urllib3==2.2.0 ; python_version >= "3.9" and python_version < "4.0"
86
+ uvicorn==0.27.1 ; python_version >= "3.9" and python_version < "4.0"
87
+ websockets==11.0.3 ; python_version >= "3.9" and python_version < "4.0"
88
+ xxhash==3.4.1 ; python_version >= "3.9" and python_version < "4.0"
89
+ yarl==1.9.4 ; python_version >= "3.9" and python_version < "4.0"
90
+ zipp==3.17.0 ; python_version >= "3.9" and python_version < "3.10"