Commit 66a1d29 by koesan · 1 parent: 513de3e

Initial commit: Manga Layout Generator with model
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.pth.tar filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,229 @@
+ import os
+ import gradio as gr
+ import torch
+ import numpy as np
+ from PIL import Image, ImageDraw
+ import pickle
+
+ # Model imports
+ from model.layoutganpp import Generator
+ from util import set_seed
+
+ # Configuration
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ MODEL_PATH = "model_best.pth.tar"
+
+ # Load model
+ def load_model():
+     """Load pretrained LayoutGAN++ model"""
+     if not os.path.exists(MODEL_PATH):
+         raise FileNotFoundError(f"Model file not found: {MODEL_PATH}. Please ensure model_best.pth.tar is in the same directory.")
+
+     checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
+     args = checkpoint['args']
+
+     # Initialize model
+     num_label = 6  # For manga panels
+     model = Generator(args['latent_size'], num_label,
+                       d_model=args['G_d_model'],
+                       nhead=args['G_nhead']).to(DEVICE)
+
+     model.load_state_dict(checkpoint['model_state_dict'])
+     model.eval()
+
+     return model, args
+
+ # Initialize model
+ print("Loading model...")
+ model, model_args = load_model()
+ print("Model loaded successfully!")
+
+ def convert_layout_to_image(bbox, canvas_size=(256, 256)):
+     """Convert bounding boxes to visualization image"""
+     W, H = canvas_size
+     img = Image.new('RGB', (W, H), color=(255, 255, 255))
+     draw = ImageDraw.Draw(img, 'RGBA')
+
+     # Colors for different panels
+     colors = [
+         (255, 100, 100, 180),  # Red
+         (100, 255, 100, 180),  # Green
+         (100, 100, 255, 180),  # Blue
+         (255, 255, 100, 180),  # Yellow
+         (255, 100, 255, 180),  # Magenta
+         (100, 255, 255, 180),  # Cyan
+         (255, 150, 100, 180),  # Orange
+         (150, 100, 255, 180),  # Purple
+         (100, 255, 150, 180),  # Light Green
+         (255, 100, 150, 180),  # Pink
+         (150, 255, 100, 180),  # Lime
+         (100, 150, 255, 180),  # Light Blue
+     ]
+
+     # Sort by area (largest first)
+     areas = [(b[2] - b[0]) * (b[3] - b[1]) for b in bbox]
+     indices = sorted(range(len(areas)), key=lambda i: areas[i], reverse=True)
+
+     for idx, i in enumerate(indices):
+         x1, y1, x2, y2 = bbox[i]
+         x1 = int(x1 * W)
+         y1 = int(y1 * H)
+         x2 = int(x2 * W)
+         y2 = int(y2 * H)
+
+         color = colors[idx % len(colors)]
+         draw.rectangle([x1, y1, x2, y2], fill=color, outline=(0, 0, 0), width=2)
+
+         # Add panel number
+         text = f"Panel {idx + 1}"
+         text_bbox = draw.textbbox((0, 0), text)
+         text_w = text_bbox[2] - text_bbox[0]
+         text_h = text_bbox[3] - text_bbox[1]
+         text_x = x1 + (x2 - x1 - text_w) // 2
+         text_y = y1 + (y2 - y1 - text_h) // 2
+         draw.text((text_x, text_y), text, fill=(0, 0, 0))
+
+     return img
+
+ def xywh_to_ltrb(bbox):
+     """Convert from center format (xc, yc, w, h) to corners (x1, y1, x2, y2)"""
+     xc, yc, w, h = bbox
+     x1 = xc - w / 2
+     y1 = yc - h / 2
+     x2 = xc + w / 2
+     y2 = yc + h / 2
+     return [x1, y1, x2, y2]
+
+ def generate_manga_layout(num_panels, seed=None):
+     """Generate manga panel layout"""
+     try:
+         if seed is not None:
+             set_seed(seed)
+
+         # Clamp num_panels to the supported range; cast to int since the
+         # Gradio slider value may arrive as a float
+         num_panels = int(max(1, min(12, num_panels)))
+
+         # Create input
+         z = torch.randn(1, num_panels, model_args['latent_size'], device=DEVICE)
+         label = torch.zeros(1, num_panels, dtype=torch.long, device=DEVICE)
+         padding_mask = torch.zeros(1, num_panels, dtype=torch.bool, device=DEVICE)
+
+         # Generate layout
+         with torch.no_grad():
+             bbox = model(z, label, padding_mask)
+
+         # Convert to numpy
+         bbox = bbox[0].cpu().numpy()
+
+         # Convert from xywh to ltrb
+         bbox_ltrb = [xywh_to_ltrb(b) for b in bbox]
+
+         # Clip to [0, 1]
+         bbox_ltrb = [[max(0, min(1, coord)) for coord in box] for box in bbox_ltrb]
+
+         # Create visualization
+         img = convert_layout_to_image(bbox_ltrb, canvas_size=(512, 512))
+
+         info = f"✅ Generated layout with {num_panels} panels\n"
+         info += "📐 Canvas size: 512x512px\n"
+         info += "🎨 Panel colors are randomly assigned"
+
+         return img, info
+
+     except Exception as e:
+         error_img = Image.new('RGB', (512, 512), color=(255, 200, 200))
+         draw = ImageDraw.Draw(error_img)
+         draw.text((20, 250), f"Error: {e}", fill=(255, 0, 0))
+         return error_img, f"❌ Error: {e}"
+
+ # Gradio Interface
+ with gr.Blocks(theme=gr.themes.Soft(), title="Manga Panel Layout Generator") as demo:
+     gr.Markdown("""
+     # 🎨 Manga Panel Layout Generator
+
+     **AI-powered manga panel position prediction using LayoutGAN++**
+
+     This tool automatically generates panel layouts for manga pages. Select the number of panels,
+     and the model predicts an arrangement learned from thousands of manga pages.
+
+     ### 🔗 Links
+     - 📚 [GitHub Repository](https://github.com/koesan/Manga-Panel-LayoutGAN)
+     - 🤗 [Hugging Face Space](https://huggingface.co/spaces/koesan/manga-layout-generator)
+     - 📄 [LayoutGAN++ Paper](https://arxiv.org/abs/1908.07785)
+
+     ### 📖 About
+     This project uses **LayoutGAN++**, a transformer-based GAN architecture, trained on manga panel data
+     from the [MangaZero dataset](https://huggingface.co/datasets/jianzongwu/MangaZero). It predicts
+     panel positions and sizes to create aesthetically pleasing manga page layouts.
+
+     **Why?** Automating manga panel layout can help manga artists, designers, and AI systems generate
+     structured comic pages more efficiently.
+     """)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             num_panels = gr.Slider(
+                 minimum=1,
+                 maximum=12,
+                 value=3,
+                 step=1,
+                 label="Number of Panels",
+                 info="Select how many panels you want (1-12)"
+             )
+
+             seed = gr.Number(
+                 label="Random Seed (Optional)",
+                 value=None,
+                 precision=0,
+                 info="Leave empty for random generation"
+             )
+
+             generate_btn = gr.Button("🎨 Generate Layout", variant="primary", size="lg")
+
+             gr.Markdown("""
+             ### 💡 Tips
+             - **3-4 panels**: Simple, clean layouts
+             - **5-6 panels**: Standard manga page
+             - **7-12 panels**: Complex, dynamic layouts
+
+             Click generate multiple times for different variations!
+             """)
+
+         with gr.Column(scale=2):
+             output_image = gr.Image(label="Generated Layout", type="pil", height=512)
+             output_info = gr.Textbox(label="Generation Info", lines=3)
+
+     # Examples
+     gr.Examples(
+         examples=[
+             [3, None],
+             [6, None],
+             [8, None],
+             [4, 42],
+         ],
+         inputs=[num_panels, seed],
+         outputs=[output_image, output_info],
+         fn=generate_manga_layout,
+         cache_examples=False,
+         label="📋 Try These Examples"
+     )
+
+     # Event handler
+     generate_btn.click(
+         fn=generate_manga_layout,
+         inputs=[num_panels, seed],
+         outputs=[output_image, output_info]
+     )
+
+     gr.Markdown("""
+     ---
+     ### 🙏 Credits
+     - **Dataset**: [MangaZero](https://huggingface.co/datasets/jianzongwu/MangaZero) by [jianzongwu](https://github.com/jianzongwu)
+     - **Model**: [LayoutGAN++](https://arxiv.org/abs/1908.07785)
+     - **Framework**: PyTorch, Gradio, Hugging Face Spaces
+
+     Made with ❤️ for the manga and AI community
+     """)
+
+ if __name__ == "__main__":
+     demo.launch()
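
The Space wraps everything in Gradio, but the generation path can also be exercised directly. Below is a minimal sketch (not part of the commit) mirroring load_model() and generate_manga_layout() from app.py, assuming the same checkpoint keys ('args' as a dict, 'model_state_dict') and model_best.pth.tar in the working directory:

# Minimal sketch: drive the generator without the Gradio UI.
import torch
from model.layoutganpp import Generator
from util import set_seed

ckpt = torch.load("model_best.pth.tar", map_location="cpu")
args = ckpt["args"]  # assumed to be a dict, as in app.py
model = Generator(args["latent_size"], 6,
                  d_model=args["G_d_model"], nhead=args["G_nhead"])
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

set_seed(42)
n = 4                                                # number of panels
z = torch.randn(1, n, args["latent_size"])           # latent codes
label = torch.zeros(1, n, dtype=torch.long)          # single panel class
padding_mask = torch.zeros(1, n, dtype=torch.bool)   # False = real panel
with torch.no_grad():
    bbox = model(z, label, padding_mask)             # [1, n, 4] as (xc, yc, w, h) in [0, 1]
print(bbox[0])
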
model/__pycache__/layoutganpp.cpython-38.pyc ADDED
Binary file (2.72 kB).
 
model/__pycache__/layoutnet.cpython-38.pyc ADDED
Binary file (2.02 kB).
 
model/__pycache__/util.cpython-38.pyc ADDED
Binary file (1.18 kB).
 
model/layoutganpp.py ADDED
@@ -0,0 +1,93 @@
+ import torch
+ import torch.nn as nn
+
+ from model.util import TransformerWithToken
+
+
+ class Generator(nn.Module):
+     def __init__(self, dim_latent, num_label,
+                  d_model=512, nhead=8, num_layers=4):
+         super().__init__()
+
+         self.fc_z = nn.Linear(dim_latent, d_model // 2)
+         self.emb_label = nn.Embedding(num_label, d_model // 2)
+         self.fc_in = nn.Linear(d_model, d_model)
+
+         te = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
+                                         dim_feedforward=d_model // 2)
+         self.transformer = nn.TransformerEncoder(te, num_layers=num_layers)
+
+         self.fc_out = nn.Linear(d_model, 4)
+
+     def forward(self, z, label, padding_mask):
+         z = self.fc_z(z)
+         l = self.emb_label(label)
+         x = torch.cat([z, l], dim=-1)
+         x = torch.relu(self.fc_in(x)).permute(1, 0, 2)
+
+         x = self.transformer(x, src_key_padding_mask=padding_mask)
+
+         x = self.fc_out(x.permute(1, 0, 2))
+         x = torch.sigmoid(x)
+
+         return x
+
+
+ class Discriminator(nn.Module):
+     def __init__(self, num_label, d_model=512,
+                  nhead=8, num_layers=4, max_bbox=50):
+         super().__init__()
+
+         # encoder
+         self.emb_label = nn.Embedding(num_label, d_model)
+         self.fc_bbox = nn.Linear(4, d_model)
+         self.enc_fc_in = nn.Linear(d_model * 2, d_model)
+
+         self.enc_transformer = TransformerWithToken(d_model=d_model,
+                                                     dim_feedforward=d_model // 2,
+                                                     nhead=nhead, num_layers=num_layers)
+
+         self.fc_out_disc = nn.Linear(d_model, 1)
+
+         # decoder
+         self.pos_token = nn.Parameter(torch.rand(max_bbox, 1, d_model))
+         self.dec_fc_in = nn.Linear(d_model * 2, d_model)
+
+         te = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
+                                         dim_feedforward=d_model // 2)
+         self.dec_transformer = nn.TransformerEncoder(te,
+                                                      num_layers=num_layers)
+
+         self.fc_out_cls = nn.Linear(d_model, num_label)
+         self.fc_out_bbox = nn.Linear(d_model, 4)
+
+     def forward(self, bbox, label, padding_mask, reconst=False):
+         B, N, _ = bbox.size()
+         b = self.fc_bbox(bbox)
+         l = self.emb_label(label)
+         x = self.enc_fc_in(torch.cat([b, l], dim=-1))
+         x = torch.relu(x).permute(1, 0, 2)
+
+         x = self.enc_transformer(x, src_key_padding_mask=padding_mask)
+         x = x[0]
+
+         # logit_disc: [B,]
+         logit_disc = self.fc_out_disc(x).squeeze(-1)
+
+         if not reconst:
+             return logit_disc
+
+         else:
+             x = x.unsqueeze(0).expand(N, -1, -1)
+             t = self.pos_token[:N].expand(-1, B, -1)
+             x = torch.cat([x, t], dim=-1)
+             x = torch.relu(self.dec_fc_in(x))
+
+             x = self.dec_transformer(x, src_key_padding_mask=padding_mask)
+             x = x.permute(1, 0, 2)[~padding_mask]
+
+             # logit_cls: [M, L]    bbox_pred: [M, 4]
+             logit_cls = self.fc_out_cls(x)
+             bbox_pred = torch.sigmoid(self.fc_out_bbox(x))
+
+             return logit_disc, logit_cls, bbox_pred
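
The tensor shapes follow directly from the forward() code above; a quick sanity-check sketch (the hyperparameters here are arbitrary, not the trained configuration):

# Sketch: tensor shapes through Generator and Discriminator.
import torch
from model.layoutganpp import Generator, Discriminator

B, N, latent = 2, 5, 4
G = Generator(dim_latent=latent, num_label=6, d_model=256, nhead=4)
D = Discriminator(num_label=6, d_model=256, nhead=4)

z = torch.randn(B, N, latent)
label = torch.randint(0, 6, (B, N))
mask = torch.zeros(B, N, dtype=torch.bool)   # False = valid box

bbox = G(z, label, mask)                     # [B, N, 4]
logit = D(bbox, label, mask)                 # [B] real/fake logits
logit, logit_cls, bbox_pred = D(bbox, label, mask, reconst=True)
# logit_cls: [M, num_label], bbox_pred: [M, 4], M = count of unmasked boxes
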
model/layoutnet.py ADDED
@@ -0,0 +1,64 @@
+ import torch
+ import torch.nn as nn
+
+ from model.util import TransformerWithToken
+
+
+ class LayoutNet(nn.Module):
+     def __init__(self, num_label):
+         super().__init__()
+
+         d_model = 256
+         nhead = 4
+         num_layers = 4
+         max_bbox = 50
+
+         # encoder
+         self.emb_label = nn.Embedding(num_label, d_model)
+         self.fc_bbox = nn.Linear(4, d_model)
+         self.enc_fc_in = nn.Linear(d_model * 2, d_model)
+
+         self.enc_transformer = TransformerWithToken(d_model=d_model,
+                                                     dim_feedforward=d_model // 2,
+                                                     nhead=nhead, num_layers=num_layers)
+
+         self.fc_out_disc = nn.Linear(d_model, 1)
+
+         # decoder
+         self.pos_token = nn.Parameter(torch.rand(max_bbox, 1, d_model))
+         self.dec_fc_in = nn.Linear(d_model * 2, d_model)
+
+         te = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,
+                                         dim_feedforward=d_model // 2)
+         self.dec_transformer = nn.TransformerEncoder(te, num_layers=num_layers)
+
+         self.fc_out_cls = nn.Linear(d_model, num_label)
+         self.fc_out_bbox = nn.Linear(d_model, 4)
+
+     def extract_features(self, bbox, label, padding_mask):
+         b = self.fc_bbox(bbox)
+         l = self.emb_label(label)
+         x = self.enc_fc_in(torch.cat([b, l], dim=-1))
+         x = torch.relu(x).permute(1, 0, 2)
+         x = self.enc_transformer(x, padding_mask)
+         return x[0]
+
+     def forward(self, bbox, label, padding_mask):
+         B, N, _ = bbox.size()
+         x = self.extract_features(bbox, label, padding_mask)
+
+         logit_disc = self.fc_out_disc(x).squeeze(-1)
+
+         x = x.unsqueeze(0).expand(N, -1, -1)
+         t = self.pos_token[:N].expand(-1, B, -1)
+         x = torch.cat([x, t], dim=-1)
+         x = torch.relu(self.dec_fc_in(x))
+
+         x = self.dec_transformer(x, src_key_padding_mask=padding_mask)
+         x = x.permute(1, 0, 2)[~padding_mask]
+
+         # logit_cls: [M, L]    bbox_pred: [M, 4]
+         logit_cls = self.fc_out_cls(x)
+         bbox_pred = torch.sigmoid(self.fc_out_bbox(x))
+
+         return logit_disc, logit_cls, bbox_pred
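
LayoutNet mirrors the Discriminator at a fixed d_model of 256; extract_features() returns the pooled token embedding, presumably used as a fixed-size layout representation for evaluation (an assumption — training and evaluation code are not part of this commit). A minimal sketch:

# Sketch: pooled layout embedding from LayoutNet.extract_features.
import torch
from model.layoutnet import LayoutNet

net = LayoutNet(num_label=6).eval()
bbox = torch.rand(2, 5, 4)                    # [B, N, 4] as (xc, yc, w, h)
label = torch.randint(0, 6, (2, 5))
mask = torch.zeros(2, 5, dtype=torch.bool)    # False = valid box
with torch.no_grad():
    feat = net.extract_features(bbox, label, mask)   # [B, 256]
print(feat.shape)
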
model/util.py ADDED
@@ -0,0 +1,35 @@
+ import torch
+ import torch.nn as nn
+
+
+ class TransformerWithToken(nn.Module):
+     def __init__(self, d_model, nhead, dim_feedforward, num_layers):
+         super().__init__()
+
+         self.token = nn.Parameter(torch.randn(1, 1, d_model))
+         token_mask = torch.zeros(1, 1, dtype=torch.bool)
+         self.register_buffer('token_mask', token_mask)
+
+         self.core = nn.TransformerEncoder(
+             nn.TransformerEncoderLayer(
+                 d_model=d_model, nhead=nhead,
+                 dim_feedforward=dim_feedforward,
+             ), num_layers=num_layers)
+
+     def forward(self, x, src_key_padding_mask):
+         # x: [N, B, E]
+         # padding_mask: [B, N]
+         #   `False` for valid values
+         #   `True` for padded values
+
+         B = x.size(1)
+
+         token = self.token.expand(-1, B, -1)
+         x = torch.cat([token, x], dim=0)
+
+         token_mask = self.token_mask.expand(B, -1)
+         padding_mask = torch.cat([token_mask, src_key_padding_mask], dim=1)
+
+         x = self.core(x, src_key_padding_mask=padding_mask)
+
+         return x
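
The wrapper prepends one learned token before running the encoder, so the output gains one row and row 0 serves as a pooled summary of the sequence; a small sketch:

# Sketch: TransformerWithToken adds one learned token at position 0.
import torch
from model.util import TransformerWithToken

enc = TransformerWithToken(d_model=64, nhead=4, dim_feedforward=32, num_layers=1)
x = torch.randn(5, 2, 64)                     # [N, B, E]
mask = torch.zeros(2, 5, dtype=torch.bool)    # [B, N], False = valid
out = enc(x, mask)                            # [N + 1, B, E]
pooled = out[0]                               # [B, E] token embedding
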
model_best.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3c7da885a431c1fcacedf9616eb9d43bc3ae4dee2eaf76b4ac531080c6932059
+ size 99615356
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ gradio==4.19.2
+ torch==1.8.1
+ torchvision==0.9.1
+ numpy==1.21.0
+ Pillow==9.5.0
+ scipy==1.7.3
util.py ADDED
@@ -0,0 +1,104 @@
+ import json
+ import random
+ import shutil
+ import numpy as np
+ from pathlib import Path
+ from datetime import datetime
+ from PIL import Image, ImageDraw
+
+ import torch
+ import torchvision.utils as vutils
+ import torchvision.transforms as T
+
+
+ def set_seed(seed):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     print("Random Seed:", seed)
+
+
+ def init_experiment(args, prefix):
+     if args.seed is None:
+         args.seed = random.randint(0, 10000)
+
+     set_seed(args.seed)
+
+     if not args.name:
+         args.name = datetime.now().strftime('%Y%m%d%H%M%S%f')
+
+     out_dir = Path('output') / args.dataset / prefix / args.name
+     out_dir.mkdir(parents=True, exist_ok=True)
+
+     json_path = out_dir / 'args.json'
+     with json_path.open('w') as f:
+         json.dump(vars(args), f, indent=2)
+
+     return out_dir
+
+
+ def save_checkpoint(state, is_best, out_dir):
+     out_path = Path(out_dir) / 'checkpoint.pth.tar'
+     torch.save(state, out_path)
+
+     if is_best:
+         best_path = Path(out_dir) / 'model_best.pth.tar'
+         shutil.copyfile(out_path, best_path)
+
+
+ def convert_xywh_to_ltrb(bbox):
+     xc, yc, w, h = bbox
+     x1 = xc - w / 2
+     y1 = yc - h / 2
+     x2 = xc + w / 2
+     y2 = yc + h / 2
+     return [x1, y1, x2, y2]
+
+
+ def convert_layout_to_image(boxes, labels, colors, canvas_size):
+     H, W = canvas_size
+     img = Image.new('RGB', (int(W), int(H)), color=(255, 255, 255))
+     draw = ImageDraw.Draw(img, 'RGBA')
+
+     # draw from larger boxes
+     area = [b[2] * b[3] for b in boxes]
+     indices = sorted(range(len(area)),
+                      key=lambda i: area[i],
+                      reverse=True)
+
+     for i in indices:
+         bbox, color = boxes[i], colors[labels[i]]
+         c_fill = color + (100,)
+         x1, y1, x2, y2 = convert_xywh_to_ltrb(bbox)
+         x1, x2 = x1 * (W - 1), x2 * (W - 1)
+         y1, y2 = y1 * (H - 1), y2 * (H - 1)
+         draw.rectangle([x1, y1, x2, y2],
+                        outline=color,
+                        fill=c_fill)
+     return img
+
+
+ def save_image(batch_boxes, batch_labels, batch_mask,
+                dataset_colors, out_path, canvas_size=(60, 40),
+                nrow=None):
+     # batch_boxes: [B, N, 4]
+     # batch_labels: [B, N]
+     # batch_mask: [B, N]
+
+     imgs = []
+     B = batch_boxes.size(0)
+     to_tensor = T.ToTensor()
+     for i in range(B):
+         mask_i = batch_mask[i]
+         boxes = batch_boxes[i][mask_i]
+         labels = batch_labels[i][mask_i]
+         img = convert_layout_to_image(boxes, labels,
+                                       dataset_colors,
+                                       canvas_size)
+         imgs.append(to_tensor(img))
+     image = torch.stack(imgs)
+
+     if nrow is None:
+         nrow = int(np.ceil(np.sqrt(B)))
+
+     vutils.save_image(image, out_path, normalize=False, nrow=nrow)
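
save_image renders each layout (keeping only the boxes where the mask is True — the opposite convention from the padding masks above) and tiles the batch into a grid via torchvision. A usage sketch; the color table and label count are illustrative assumptions:

# Sketch: render a batch of random layouts to a PNG grid.
import torch
from util import save_image, set_seed

set_seed(0)
B, N = 4, 3
boxes = torch.rand(B, N, 4) * 0.4 + 0.3        # (xc, yc, w, h) in [0, 1]
labels = torch.zeros(B, N, dtype=torch.long)   # every box uses label 0
mask = torch.ones(B, N, dtype=torch.bool)      # True = keep this box
colors = [(66, 135, 245)]                      # one RGB tuple per label
save_image(boxes, labels, mask, colors, "layouts.png",
           canvas_size=(120, 80))              # (H, W)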