shunk031 commited on
Commit
dbadb25
·
1 Parent(s): 3a2b4b6

deploy: 9adc1dc23f68a9b65ccfc5e5f8dce4b230a575ce

Browse files
Files changed (2) hide show
  1. layout_alignment.py +185 -0
  2. requirements.txt +89 -0
layout_alignment.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Tuple, Union
2
+
3
+ import datasets as ds
4
+ import evaluate
5
+ import numpy as np
6
+ import numpy.typing as npt
7
+
8
+ _DESCRIPTION = """\
9
+ Computes some alignment metrics that are different to each other in previous works.
10
+ """
11
+
12
+ _CITATION = """\
13
+ @inproceedings{lee2020neural,
14
+ title={Neural design network: Graphic layout generation with constraints},
15
+ author={Lee, Hsin-Ying and Jiang, Lu and Essa, Irfan and Le, Phuong B and Gong, Haifeng and Yang, Ming-Hsuan and Yang, Weilong},
16
+ booktitle={Computer Vision--ECCV 2020: 16th European Conference, Glasgow, UK, August 23--28, 2020, Proceedings, Part III 16},
17
+ pages={491--506},
18
+ year={2020},
19
+ organization={Springer}
20
+ }
21
+
22
+ @article{li2020attribute,
23
+ title={Attribute-conditioned layout gan for automatic graphic design},
24
+ author={Li, Jianan and Yang, Jimei and Zhang, Jianming and Liu, Chang and Wang, Christina and Xu, Tingfa},
25
+ journal={IEEE Transactions on Visualization and Computer Graphics},
26
+ volume={27},
27
+ number={10},
28
+ pages={4039--4048},
29
+ year={2020},
30
+ publisher={IEEE}
31
+ }
32
+
33
+ @inproceedings{kikuchi2021constrained,
34
+ title={Constrained graphic layout generation via latent optimization},
35
+ author={Kikuchi, Kotaro and Simo-Serra, Edgar and Otani, Mayu and Yamaguchi, Kota},
36
+ booktitle={Proceedings of the 29th ACM International Conference on Multimedia},
37
+ pages={88--96},
38
+ year={2021}
39
+ }
40
+ """
41
+
42
+
43
+ def convert_xywh_to_ltrb(
44
+ batch_bbox: npt.NDArray[np.float64],
45
+ ) -> Tuple[
46
+ npt.NDArray[np.float64],
47
+ npt.NDArray[np.float64],
48
+ npt.NDArray[np.float64],
49
+ npt.NDArray[np.float64],
50
+ ]:
51
+ xc, yc, w, h = batch_bbox
52
+ x1 = xc - w / 2
53
+ y1 = yc - h / 2
54
+ x2 = xc + w / 2
55
+ y2 = yc + h / 2
56
+ return (x1, y1, x2, y2)
57
+
58
+
59
+ class LayoutAlignment(evaluate.Metric):
60
+ def _info(self) -> evaluate.EvaluationModuleInfo:
61
+ return evaluate.MetricInfo(
62
+ description=_DESCRIPTION,
63
+ citation=_CITATION,
64
+ features=ds.Features(
65
+ {
66
+ "batch_bbox": ds.Sequence(ds.Sequence(ds.Value("float64"))),
67
+ "batch_mask": ds.Sequence(ds.Value("bool")),
68
+ }
69
+ ),
70
+ codebase_urls=[
71
+ "https://github.com/ktrk115/const_layout/blob/master/metric.py#L167-L188",
72
+ "https://github.com/CyberAgentAILab/layout-dm/blob/main/src/trainer/trainer/helpers/metric.py#L98-L147",
73
+ ],
74
+ )
75
+
76
+ def _compute_ac_layout_gan(
77
+ self,
78
+ S: int,
79
+ xl: npt.NDArray[np.float64],
80
+ xc: npt.NDArray[np.float64],
81
+ xr: npt.NDArray[np.float64],
82
+ yt: npt.NDArray[np.float64],
83
+ yc: npt.NDArray[np.float64],
84
+ yb: npt.NDArray[np.float64],
85
+ batch_mask: npt.NDArray,
86
+ ) -> npt.NDArray[np.float64]:
87
+ # shape: (B, 6, S)
88
+ X = np.stack((xl, xc, xr, yt, yc, yb), axis=1)
89
+ # shape: (B, 6, S, 1) - (B, 6, 1, S) = (B, 6 S, S)
90
+ X = X[:, :, :, None] - X[:, :, None, :]
91
+
92
+ # shape: (S,)
93
+ indices = np.arange(S)
94
+ X[:, :, indices, indices] = 1.0
95
+ # shape: (B, 6, S, S -> (B, S, 6, S)
96
+ X = np.abs(X).transpose(0, 2, 1, 3)
97
+ X[~batch_mask] = 1.0
98
+
99
+ # shape: (B, S)
100
+ X = X.min(axis=-1).min(axis=-1)
101
+ X[X == 1.0] = 0.0
102
+ X = -np.log(1 - X)
103
+
104
+ # shape: (B, S) -> (B,)
105
+ return X.sum(axis=1)
106
+
107
+ def _compute_layout_gan_pp(
108
+ self,
109
+ score_ac_layout_gan: npt.NDArray[np.float64],
110
+ batch_mask: npt.NDArray[np.bool_],
111
+ ) -> npt.NDArray[np.float64]:
112
+ # shape: (B, S) -> (B,)
113
+ batch_mask = batch_mask.sum(axis=1)
114
+
115
+ # shape: (B,)
116
+ score_normalized = score_ac_layout_gan / batch_mask
117
+ score_normalized[np.isnan(score_normalized)] = 0.0
118
+ return score_normalized
119
+
120
+ def _compute_neural_design_network(
121
+ self,
122
+ xl: npt.NDArray[np.float64],
123
+ xc: npt.NDArray[np.float64],
124
+ xr: npt.NDArray[np.float64],
125
+ batch_mask: npt.NDArray[np.bool_],
126
+ S: int,
127
+ ):
128
+ # shape: (B, 3, S)
129
+ Y = np.stack((xl, xc, xr), axis=1)
130
+ # shape: (B, 3, S, S)
131
+ Y = Y[:, :, None, :] - Y[:, :, :, None]
132
+
133
+ # shape: (B, S) -> (B, S, S)
134
+ batch_mask = ~batch_mask[:, None, :] | ~batch_mask[:, :, None]
135
+ # shape: (B,)
136
+ indices = np.arange(S)
137
+ batch_mask[:, indices, indices] = True
138
+
139
+ # shape: (B, S, S) -> (B, 1, S, S) -> (B, 3, S, S)
140
+ batch_mask = np.repeat(batch_mask[:, None, :, :], repeats=3, axis=1)
141
+ Y[batch_mask] = 1.0
142
+
143
+ # shape: (B, 3, S, S) -> (B, S, S) -> (B, S)
144
+ Y = np.abs(Y).min(axis=1).min(axis=2)
145
+ Y[Y == 1.0] = 0.0
146
+
147
+ # shape: (B, S) -> (B,)
148
+ score = Y.sum(axis=1)
149
+ return score
150
+
151
+ def _compute(
152
+ self,
153
+ *,
154
+ batch_bbox: Union[npt.NDArray[np.float64], List[List[int]]],
155
+ batch_mask: Union[npt.NDArray[np.bool_], List[List[bool]]],
156
+ ) -> Dict[str, npt.NDArray[np.float64]]:
157
+ # shape: (B, model_max_length, C)
158
+ batch_bbox = np.array(batch_bbox)
159
+ # shape: (B, model_max_length)
160
+ batch_mask = np.array(batch_mask)
161
+
162
+ # S: model_max_length
163
+ _, S, _ = batch_bbox.shape
164
+
165
+ # shape: (B, S, C) -> (C, B, S)
166
+ batch_bbox = batch_bbox.transpose(2, 0, 1)
167
+ xl, yt, xr, yb = convert_xywh_to_ltrb(batch_bbox)
168
+ xc, yc = batch_bbox[0], batch_bbox[1]
169
+
170
+ # shape: (B,)
171
+ score_ac_layout_gan = self._compute_ac_layout_gan(
172
+ S=S, xl=xl, xc=xc, xr=xr, yt=yt, yc=yc, yb=yb, batch_mask=batch_mask
173
+ )
174
+ # shape: (B,)
175
+ score_layout_gan_pp = self._compute_layout_gan_pp(
176
+ score_ac_layout_gan=score_ac_layout_gan, batch_mask=batch_mask
177
+ )
178
+ score_ndn = self._compute_neural_design_network(
179
+ xl=xl, xc=xc, xr=xr, batch_mask=batch_mask, S=S
180
+ )
181
+ return {
182
+ "alignment-ACLayoutGAN": score_ac_layout_gan,
183
+ "alignment-LayoutGAN++": score_layout_gan_pp,
184
+ "alignment-NDN": score_ndn,
185
+ }
requirements.txt ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1 ; python_version >= "3.9" and python_version < "4.0"
2
+ aiohttp==3.9.3 ; python_version >= "3.9" and python_version < "4.0"
3
+ aiosignal==1.3.1 ; python_version >= "3.9" and python_version < "4.0"
4
+ altair==5.2.0 ; python_version >= "3.9" and python_version < "4.0"
5
+ annotated-types==0.6.0 ; python_version >= "3.9" and python_version < "4.0"
6
+ anyio==4.2.0 ; python_version >= "3.9" and python_version < "4.0"
7
+ arrow==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
8
+ async-timeout==4.0.3 ; python_version >= "3.9" and python_version < "3.11"
9
+ attrs==23.2.0 ; python_version >= "3.9" and python_version < "4.0"
10
+ binaryornot==0.4.4 ; python_version >= "3.9" and python_version < "4.0"
11
+ certifi==2024.2.2 ; python_version >= "3.9" and python_version < "4.0"
12
+ chardet==5.2.0 ; python_version >= "3.9" and python_version < "4.0"
13
+ charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "4.0"
14
+ click==8.1.7 ; python_version >= "3.9" and python_version < "4.0"
15
+ colorama==0.4.6 ; python_version >= "3.9" and python_version < "4.0"
16
+ contourpy==1.2.0 ; python_version >= "3.9" and python_version < "4.0"
17
+ cookiecutter==2.5.0 ; python_version >= "3.9" and python_version < "4.0"
18
+ cycler==0.12.1 ; python_version >= "3.9" and python_version < "4.0"
19
+ datasets==2.17.0 ; python_version >= "3.9" and python_version < "4.0"
20
+ dill==0.3.8 ; python_version >= "3.9" and python_version < "4.0"
21
+ evaluate[template]==0.4.1 ; python_version >= "3.9" and python_version < "4.0"
22
+ exceptiongroup==1.2.0 ; python_version >= "3.9" and python_version < "3.11"
23
+ fastapi==0.109.2 ; python_version >= "3.9" and python_version < "4.0"
24
+ ffmpy==0.3.1 ; python_version >= "3.9" and python_version < "4.0"
25
+ filelock==3.13.1 ; python_version >= "3.9" and python_version < "4.0"
26
+ fonttools==4.48.1 ; python_version >= "3.9" and python_version < "4.0"
27
+ frozenlist==1.4.1 ; python_version >= "3.9" and python_version < "4.0"
28
+ fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "4.0"
29
+ fsspec[http]==2023.10.0 ; python_version >= "3.9" and python_version < "4.0"
30
+ gradio-client==0.9.0 ; python_version >= "3.9" and python_version < "4.0"
31
+ gradio==4.17.0 ; python_version >= "3.9" and python_version < "4.0"
32
+ h11==0.14.0 ; python_version >= "3.9" and python_version < "4.0"
33
+ httpcore==1.0.2 ; python_version >= "3.9" and python_version < "4.0"
34
+ httpx==0.26.0 ; python_version >= "3.9" and python_version < "4.0"
35
+ huggingface-hub==0.20.3 ; python_version >= "3.9" and python_version < "4.0"
36
+ idna==3.6 ; python_version >= "3.9" and python_version < "4.0"
37
+ importlib-resources==6.1.1 ; python_version >= "3.9" and python_version < "4.0"
38
+ jinja2==3.1.3 ; python_version >= "3.9" and python_version < "4.0"
39
+ jsonschema-specifications==2023.12.1 ; python_version >= "3.9" and python_version < "4.0"
40
+ jsonschema==4.21.1 ; python_version >= "3.9" and python_version < "4.0"
41
+ kiwisolver==1.4.5 ; python_version >= "3.9" and python_version < "4.0"
42
+ markdown-it-py==3.0.0 ; python_version >= "3.9" and python_version < "4.0"
43
+ markupsafe==2.1.5 ; python_version >= "3.9" and python_version < "4.0"
44
+ matplotlib==3.8.2 ; python_version >= "3.9" and python_version < "4.0"
45
+ mdurl==0.1.2 ; python_version >= "3.9" and python_version < "4.0"
46
+ multidict==6.0.5 ; python_version >= "3.9" and python_version < "4.0"
47
+ multiprocess==0.70.16 ; python_version >= "3.9" and python_version < "4.0"
48
+ numpy==1.26.4 ; python_version >= "3.9" and python_version < "4.0"
49
+ orjson==3.9.13 ; python_version >= "3.9" and python_version < "4.0"
50
+ packaging==23.2 ; python_version >= "3.9" and python_version < "4.0"
51
+ pandas==2.2.0 ; python_version >= "3.9" and python_version < "4.0"
52
+ pillow==10.2.0 ; python_version >= "3.9" and python_version < "4.0"
53
+ pyarrow-hotfix==0.6 ; python_version >= "3.9" and python_version < "4.0"
54
+ pyarrow==15.0.0 ; python_version >= "3.9" and python_version < "4.0"
55
+ pydantic-core==2.16.2 ; python_version >= "3.9" and python_version < "4.0"
56
+ pydantic==2.6.1 ; python_version >= "3.9" and python_version < "4.0"
57
+ pydub==0.25.1 ; python_version >= "3.9" and python_version < "4.0"
58
+ pygments==2.17.2 ; python_version >= "3.9" and python_version < "4.0"
59
+ pyparsing==3.1.1 ; python_version >= "3.9" and python_version < "4.0"
60
+ python-dateutil==2.8.2 ; python_version >= "3.9" and python_version < "4.0"
61
+ python-multipart==0.0.7 ; python_version >= "3.9" and python_version < "4.0"
62
+ python-slugify==8.0.4 ; python_version >= "3.9" and python_version < "4.0"
63
+ pytz==2024.1 ; python_version >= "3.9" and python_version < "4.0"
64
+ pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "4.0"
65
+ referencing==0.33.0 ; python_version >= "3.9" and python_version < "4.0"
66
+ requests==2.31.0 ; python_version >= "3.9" and python_version < "4.0"
67
+ responses==0.18.0 ; python_version >= "3.9" and python_version < "4.0"
68
+ rich==13.7.0 ; python_version >= "3.9" and python_version < "4.0"
69
+ rpds-py==0.17.1 ; python_version >= "3.9" and python_version < "4.0"
70
+ ruff==0.2.1 ; python_version >= "3.9" and python_version < "4.0"
71
+ semantic-version==2.10.0 ; python_version >= "3.9" and python_version < "4.0"
72
+ shellingham==1.5.4 ; python_version >= "3.9" and python_version < "4.0"
73
+ six==1.16.0 ; python_version >= "3.9" and python_version < "4.0"
74
+ sniffio==1.3.0 ; python_version >= "3.9" and python_version < "4.0"
75
+ starlette==0.36.3 ; python_version >= "3.9" and python_version < "4.0"
76
+ text-unidecode==1.3 ; python_version >= "3.9" and python_version < "4.0"
77
+ tomlkit==0.12.0 ; python_version >= "3.9" and python_version < "4.0"
78
+ toolz==0.12.1 ; python_version >= "3.9" and python_version < "4.0"
79
+ tqdm==4.66.1 ; python_version >= "3.9" and python_version < "4.0"
80
+ typer[all]==0.9.0 ; python_version >= "3.9" and python_version < "4.0"
81
+ types-python-dateutil==2.8.19.20240106 ; python_version >= "3.9" and python_version < "4.0"
82
+ typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "4.0"
83
+ tzdata==2023.4 ; python_version >= "3.9" and python_version < "4.0"
84
+ urllib3==2.2.0 ; python_version >= "3.9" and python_version < "4.0"
85
+ uvicorn==0.27.0.post1 ; python_version >= "3.9" and python_version < "4.0"
86
+ websockets==11.0.3 ; python_version >= "3.9" and python_version < "4.0"
87
+ xxhash==3.4.1 ; python_version >= "3.9" and python_version < "4.0"
88
+ yarl==1.9.4 ; python_version >= "3.9" and python_version < "4.0"
89
+ zipp==3.17.0 ; python_version >= "3.9" and python_version < "3.10"