denisevaldivia commited on
Commit
0ddcd53
·
1 Parent(s): 1f4fd26

initial deploy

Browse files
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from inference.animegan_inference import load_animegan, run_animegan
4
+ from inference.apdrawing_inference import load_apdrawing, run_apdrawing
5
+
6
+ device = "cuda" if torch.cuda.is_available() else "cpu"
7
+
8
+ # Load all models once at startup
9
+ G_texture = load_animegan("models/animegan/weights/GeneratorV2_live_action_cartoon_texture.pt", device)
10
+ G_color = load_animegan("models/animegan/weights/GeneratorV2_live_action_cartoon_color_2.pt", device)
11
+ G_sketch = load_apdrawing("models/apdrawing/weights/apdrawing_200.pt", device)
12
+
13
+ # Inference functions for Gardio
14
+
15
+ def predict_live_action(image, mode, alpha):
16
+ if image is None:
17
+ return None
18
+ return run_animegan(image, G_texture, G_color, mode=mode, alpha=alpha, device=device)
19
+
20
+ def predict_sketch(image):
21
+ if image is None:
22
+ return None
23
+ return run_apdrawing(image, G_sketch, device)
24
+
25
+ # UI
26
+
27
+ with gr.Blocks(title="Disney Style Transfer") as demo:
28
+
29
+ gr.Markdown("""
30
+ # 🎨 Disney Style Transfer
31
+ Transform real faces or sketches into Disney-style images using two different GAN models.
32
+ """)
33
+
34
+ with gr.Tabs():
35
+
36
+ # Tab 1 — Live Action to Disney
37
+ with gr.Tab("🎬 Live Action → Disney"):
38
+ gr.Markdown("Upload a real face photo and convert it to Disney style.")
39
+ with gr.Row():
40
+ with gr.Column():
41
+ input_img = gr.Image(type="pil", label="Input Image", image_mode="RGB")
42
+ mode = gr.Radio(
43
+ choices=["ensemble", "texture_only", "color_only"],
44
+ value="ensemble",
45
+ label="Generation mode",
46
+ info="Ensemble combines both models. Texture only keeps stylization. Color only keeps palette."
47
+ )
48
+ alpha = gr.Slider(
49
+ minimum=0.0, maximum=1.0, value=0.8, step=0.05,
50
+ label="Color transfer strength (alpha)",
51
+ info="Only applies in ensemble mode. Higher = more color from color model.",
52
+ visible=True
53
+ )
54
+ btn1 = gr.Button("Generate", variant="primary")
55
+ with gr.Column():
56
+ output_img1 = gr.Image(type="pil", label="Generated Image")
57
+
58
+ # Show/hide alpha slider based on mode
59
+ mode.change(
60
+ fn=lambda m: gr.update(visible=(m == "ensemble")),
61
+ inputs=mode,
62
+ outputs=alpha
63
+ )
64
+
65
+ btn1.click(
66
+ fn=predict_live_action,
67
+ inputs=[input_img, mode, alpha],
68
+ outputs=output_img1
69
+ )
70
+
71
+ gr.Examples(
72
+ examples=[["examples/live1.jpg", "ensemble", 0.8]],
73
+ inputs=[input_img, mode, alpha],
74
+ outputs=output_img1,
75
+ fn=predict_live_action,
76
+ label="Examples"
77
+ )
78
+
79
+ # Tab 2 — Sketch to Disney
80
+ with gr.Tab("✏️ Sketch → Disney"):
81
+ gr.Markdown("Upload a sketch or line drawing and convert it to a colored Disney-style image.")
82
+ with gr.Row():
83
+ with gr.Column():
84
+ input_sketch = gr.Image(type="pil", label="Input Sketch", image_mode="RGB")
85
+ btn2 = gr.Button("Generate", variant="primary")
86
+ with gr.Column():
87
+ output_img2 = gr.Image(type="pil", label="Generated Image")
88
+
89
+ btn2.click(
90
+ fn=predict_sketch,
91
+ inputs=input_sketch,
92
+ outputs=output_img2
93
+ )
94
+
95
+ demo.launch()
inference/animegan_inference.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+ def load_animegan(weights_path, device):
7
+ from models.animegan.generator import GeneratorV2
8
+ checkpoint = torch.load(weights_path, map_location=device)
9
+ state_dict = checkpoint["model_state_dict"] if "model_state_dict" in checkpoint else checkpoint
10
+ state_dict = {k[len("module."):] if k.startswith("module.") else k: v for k, v in state_dict.items()}
11
+ G = GeneratorV2().to(device)
12
+ G.load_state_dict(state_dict, strict=True)
13
+ G.eval()
14
+ return G
15
+
16
+ def preprocess(pil_img, device, size=(256, 256)):
17
+ img = np.array(pil_img.convert("RGB").resize(size))
18
+ tensor = torch.from_numpy(img).float().permute(2, 0, 1).unsqueeze(0)
19
+ tensor = (tensor / 127.5) - 1.0
20
+ return tensor.to(device)
21
+
22
+ def postprocess(tensor):
23
+ img = ((tensor.squeeze(0).permute(1, 2, 0) + 1) * 127.5).clamp(0, 255).byte().cpu().numpy()
24
+ return img # RGB numpy
25
+
26
+ def transfer_palette(source, target, alpha=1.0):
27
+ source_lab = cv2.cvtColor(source, cv2.COLOR_RGB2LAB).astype(np.float32)
28
+ target_lab = cv2.cvtColor(target, cv2.COLOR_RGB2LAB).astype(np.float32)
29
+ result_lab = target_lab.copy()
30
+ result_lab[:, :, 1] = source_lab[:, :, 1] * alpha + target_lab[:, :, 1] * (1 - alpha)
31
+ result_lab[:, :, 2] = source_lab[:, :, 2] * alpha + target_lab[:, :, 2] * (1 - alpha)
32
+ return cv2.cvtColor(result_lab.astype(np.uint8), cv2.COLOR_LAB2RGB)
33
+
34
+ def run_animegan(pil_img, G_texture, G_color, mode="ensemble", alpha=0.8, device="cpu"):
35
+ inp = preprocess(pil_img, device)
36
+ with torch.no_grad():
37
+ if mode == "texture_only":
38
+ result = postprocess(G_texture(inp))
39
+ elif mode == "color_only":
40
+ result = postprocess(G_color(inp))
41
+ else: # ensemble
42
+ out_texture = postprocess(G_texture(inp))
43
+ out_color = postprocess(G_color(inp))
44
+ result = transfer_palette(source=out_color, target=out_texture, alpha=alpha)
45
+ return Image.fromarray(result)
inference/apdrawing_inference.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+
6
+ def load_apdrawing(weights_path, device):
7
+ from models.apdrawing.generator import define_G
8
+ netG = define_G(1, 3, 64, 'unet_256', 'batch',
9
+ use_dropout=False, init_type='normal',
10
+ init_gain=0.02, gpu_ids=[])
11
+ ckpt = torch.load(weights_path, map_location='cpu')
12
+ state_dict = ckpt['G'] if isinstance(ckpt, dict) and 'G' in ckpt else ckpt
13
+ netG.load_state_dict(state_dict, strict=True)
14
+ netG.to(device).eval()
15
+ return netG
16
+
17
+ def preprocess_sketch(pil_img):
18
+ img = pil_img.convert('RGB').resize((256, 256), Image.BICUBIC)
19
+ tensor = transforms.ToTensor()(img)
20
+ tensor = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(tensor)
21
+ gray = tensor[0] * 0.299 + tensor[1] * 0.587 + tensor[2] * 0.114
22
+ return gray.unsqueeze(0).unsqueeze(0) # (1, 1, 256, 256)
23
+
24
+ def run_apdrawing(pil_img, netG, device):
25
+ inp = preprocess_sketch(pil_img).to(device)
26
+ with torch.no_grad():
27
+ out = netG(inp)
28
+ out = out.squeeze(0).permute(1, 2, 0).cpu().numpy()
29
+ out = ((out + 1) * 127.5).clip(0, 255).astype(np.uint8)
30
+ return Image.fromarray(out)
models/animegan/generator.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from functools import partial
5
+
6
+ # NOTE!!! This code is an adaptation form the original repo, we are not owners nor the creators behind the design, we justa dapted it to our needs
7
+ # Original: https://github.com/ptran1203/pytorch-animeGAN/tree/master
8
+
9
+ def initialize_weights(net):
10
+ for m in net.modules():
11
+ try:
12
+ if isinstance(m, nn.Conv2d):
13
+ # m.weight.data.normal_(0, 0.02)
14
+ torch.nn.init.xavier_uniform_(m.weight)
15
+ m.bias.data.zero_()
16
+ elif isinstance(m, nn.ConvTranspose2d):
17
+ # m.weight.data.normal_(0, 0.02)
18
+ torch.nn.init.xavier_uniform_(m.weight)
19
+ m.bias.data.zero_()
20
+ elif isinstance(m, nn.Linear):
21
+ # m.weight.data.normal_(0, 0.02)
22
+ torch.nn.init.xavier_uniform_(m.weight)
23
+ m.bias.data.zero_()
24
+ elif isinstance(m, nn.BatchNorm2d):
25
+ m.weight.data.fill_(1)
26
+ m.bias.data.zero_()
27
+ except Exception as e:
28
+ # print(f'SKip layer {m}, {e}')
29
+ pass
30
+
31
+ class LayerNorm2d(nn.LayerNorm):
32
+ """ LayerNorm for channels of '2D' spatial NCHW tensors """
33
+ def __init__(self, num_channels, eps=1e-6, affine=True):
34
+ super().__init__(num_channels, eps=eps, elementwise_affine=affine)
35
+
36
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
37
+ x = x.permute(0, 2, 3, 1)
38
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
39
+ x = x.permute(0, 3, 1, 2)
40
+ return x
41
+
42
+
43
+ def get_norm(norm_type, channels):
44
+ if norm_type == "instance":
45
+ return nn.InstanceNorm2d(channels)
46
+ elif norm_type == "layer":
47
+ # return LayerNorm2d
48
+ return nn.GroupNorm(num_groups=1, num_channels=channels, affine=True)
49
+ # return partial(nn.GroupNorm, 1, out_ch, 1e-5, True)
50
+ else:
51
+ raise ValueError(norm_type)
52
+
53
+ class ConvBlock(nn.Module):
54
+ """Stack of Conv2D + Norm + LeakyReLU"""
55
+ def __init__(
56
+ self,
57
+ channels,
58
+ out_channels,
59
+ kernel_size=3,
60
+ stride=1,
61
+ groups=1,
62
+ padding=1,
63
+ bias=False,
64
+ norm_type="instance"
65
+ ):
66
+ super(ConvBlock, self).__init__()
67
+
68
+ # if kernel_size == 3 and stride == 1:
69
+ # self.pad = nn.ReflectionPad2d((1, 1, 1, 1))
70
+ # elif kernel_size == 7 and stride == 1:
71
+ # self.pad = nn.ReflectionPad2d((3, 3, 3, 3))
72
+ # elif stride == 2:
73
+ # self.pad = nn.ReflectionPad2d((0, 1, 1, 0))
74
+ # else:
75
+ # self.pad = None
76
+
77
+ self.pad = nn.ReflectionPad2d(padding)
78
+ self.conv = nn.Conv2d(
79
+ channels,
80
+ out_channels,
81
+ kernel_size=kernel_size,
82
+ stride=stride,
83
+ groups=groups,
84
+ padding=0,
85
+ bias=bias
86
+ )
87
+ self.ins_norm = get_norm(norm_type, out_channels)
88
+ self.activation = nn.LeakyReLU(0.2, True)
89
+
90
+ # initialize_weights(self)
91
+
92
+ def forward(self, x):
93
+ if self.pad is not None:
94
+ x = self.pad(x)
95
+ out = self.conv(x)
96
+ out = self.ins_norm(out)
97
+ out = self.activation(out)
98
+ return out
99
+
100
+
101
+ class InvertedResBlock(nn.Module):
102
+ def __init__(
103
+ self,
104
+ channels=256,
105
+ out_channels=256,
106
+ expand_ratio=2,
107
+ norm_type="instance",
108
+ ):
109
+ super(InvertedResBlock, self).__init__()
110
+ bottleneck_dim = round(expand_ratio * channels)
111
+ self.conv_block = ConvBlock(
112
+ channels,
113
+ bottleneck_dim,
114
+ kernel_size=1,
115
+ padding=0,
116
+ norm_type=norm_type,
117
+ bias=False
118
+ )
119
+ self.conv_block2 = ConvBlock(
120
+ bottleneck_dim,
121
+ bottleneck_dim,
122
+ groups=bottleneck_dim,
123
+ norm_type=norm_type,
124
+ bias=True
125
+ )
126
+ self.conv = nn.Conv2d(
127
+ bottleneck_dim,
128
+ out_channels,
129
+ kernel_size=1,
130
+ padding=0,
131
+ bias=False
132
+ )
133
+ self.norm = get_norm(norm_type, out_channels)
134
+
135
+ def forward(self, x):
136
+ out = self.conv_block(x)
137
+ out = self.conv_block2(out)
138
+ # out = self.activation(out)
139
+ out = self.conv(out)
140
+ out = self.norm(out)
141
+
142
+ if out.shape[1] != x.shape[1]:
143
+ # Only concate if same shape
144
+ return out
145
+ return out + x
146
+
147
+ class GeneratorV2(nn.Module):
148
+ def __init__(self, dataset=''):
149
+ super(GeneratorV2, self).__init__()
150
+ self.name = f'{self.__class__.__name__}_{dataset}'
151
+
152
+ self.conv_block1 = nn.Sequential(
153
+ ConvBlock(3, 32, kernel_size=7, stride=1, padding=3, norm_type="layer"),
154
+ ConvBlock(32, 64, kernel_size=3, stride=2, padding=(0, 1, 0, 1), norm_type="layer"),
155
+ ConvBlock(64, 64, kernel_size=3, stride=1, norm_type="layer"),
156
+ )
157
+
158
+ self.conv_block2 = nn.Sequential(
159
+ ConvBlock(64, 128, kernel_size=3, stride=2, padding=(0, 1, 0, 1), norm_type="layer"),
160
+ ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"),
161
+ )
162
+
163
+ self.res_blocks = nn.Sequential(
164
+ ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"),
165
+ InvertedResBlock(128, 256, expand_ratio=2, norm_type="layer"),
166
+ InvertedResBlock(256, 256, expand_ratio=2, norm_type="layer"),
167
+ InvertedResBlock(256, 256, expand_ratio=2, norm_type="layer"),
168
+ InvertedResBlock(256, 256, expand_ratio=2, norm_type="layer"),
169
+ ConvBlock(256, 128, kernel_size=3, stride=1, norm_type="layer"),
170
+ )
171
+
172
+ self.conv_block3 = nn.Sequential(
173
+ # UpConvLNormLReLU(128, 128, norm_type="layer"),
174
+ ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"),
175
+ ConvBlock(128, 128, kernel_size=3, stride=1, norm_type="layer"),
176
+ )
177
+
178
+ self.conv_block4 = nn.Sequential(
179
+ # UpConvLNormLReLU(128, 64, norm_type="layer"),
180
+ ConvBlock(128, 64, kernel_size=3, stride=1, norm_type="layer"),
181
+ ConvBlock(64, 64, kernel_size=3, stride=1, norm_type="layer"),
182
+ ConvBlock(64, 32, kernel_size=7, padding=3, stride=1, norm_type="layer"),
183
+ )
184
+
185
+ self.decode_blocks = nn.Sequential(
186
+ nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
187
+ nn.Tanh(),
188
+ )
189
+
190
+ initialize_weights(self)
191
+
192
+ def forward(self, x):
193
+ out = self.conv_block1(x)
194
+ out = self.conv_block2(out)
195
+ out = self.res_blocks(out)
196
+ out = F.interpolate(out, scale_factor=2, mode="bilinear")
197
+ out = self.conv_block3(out)
198
+ out = F.interpolate(out, scale_factor=2, mode="bilinear")
199
+ out = self.conv_block4(out)
200
+ img = self.decode_blocks(out)
201
+
202
+ return img
models/animegan/weights/GeneratorV2_live_action_cartoon_color.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:409e71f59fa121ac447968b79daa910ba1347de5de8510e73e9e784424a71b6e
3
+ size 25827551
models/animegan/weights/GeneratorV2_live_action_cartoon_color_2.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a16e6ab6697009852dff912a9aa798fef5e1b5e936192544409be26ff871e76
3
+ size 25827551
models/animegan/weights/GeneratorV2_live_action_cartoon_texture.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:67cd02cd47208c05e1209459e2830cc42e6b7c6aeba698c6c8c8afc9a13d4661
3
+ size 25827551
models/apdrawing/generator.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import init
4
+ import functools
5
+ from torch.optim import lr_scheduler
6
+
7
+ # NOTE!!! This code does not originally belong to us, it´s just an adaptation for our fine tunning
8
+ # Original repo: https://github.com/yiranran/APDrawingGAN.git
9
+
10
+ # Helper Functions
11
+
12
+ def get_norm_layer(norm_type='instance'):
13
+ if norm_type == 'batch':
14
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
15
+ elif norm_type == 'instance':
16
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
17
+ elif norm_type == 'none':
18
+ norm_layer = None
19
+ else:
20
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
21
+ return norm_layer
22
+
23
+
24
+ def get_scheduler(optimizer, opt):
25
+ if opt.lr_policy == 'lambda':
26
+ def lambda_rule(epoch):
27
+ lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
28
+ return lr_l
29
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
30
+ elif opt.lr_policy == 'step':
31
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
32
+ elif opt.lr_policy == 'plateau':
33
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
34
+ elif opt.lr_policy == 'cosine':
35
+ scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
36
+ else:
37
+ return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
38
+ return scheduler
39
+
40
+
41
+ def init_weights(net, init_type='normal', gain=0.02):
42
+ def init_func(m):
43
+ classname = m.__class__.__name__
44
+ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
45
+ if init_type == 'normal':
46
+ init.normal_(m.weight.data, 0.0, gain)
47
+ elif init_type == 'xavier':
48
+ init.xavier_normal_(m.weight.data, gain=gain)
49
+ elif init_type == 'kaiming':
50
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
51
+ elif init_type == 'orthogonal':
52
+ init.orthogonal_(m.weight.data, gain=gain)
53
+ else:
54
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
55
+ if hasattr(m, 'bias') and m.bias is not None:
56
+ init.constant_(m.bias.data, 0.0)
57
+ elif classname.find('BatchNorm2d') != -1:
58
+ init.normal_(m.weight.data, 1.0, gain)
59
+ init.constant_(m.bias.data, 0.0)
60
+
61
+ print('initialize network with %s' % init_type)
62
+ net.apply(init_func)
63
+
64
+
65
+ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
66
+ if len(gpu_ids) > 0:
67
+ assert(torch.cuda.is_available())
68
+ net.to(gpu_ids[0])
69
+ net = torch.nn.DataParallel(net, gpu_ids)
70
+ init_weights(net, init_type, gain=init_gain)
71
+ return net
72
+
73
+
74
+ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], nnG=9):
75
+ net = None
76
+ norm_layer = get_norm_layer(norm_type=norm)
77
+
78
+ if netG == 'resnet_9blocks':
79
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
80
+ elif netG == 'resnet_6blocks':
81
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
82
+ elif netG == 'resnet_nblocks':
83
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=nnG)
84
+ elif netG == 'unet_128':
85
+ net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
86
+ elif netG == 'unet_256':#default for pix2pix
87
+ net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
88
+ elif netG == 'unet_512':
89
+ net = UnetGenerator(input_nc, output_nc, 9, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
90
+ elif netG == 'unet_ndown':
91
+ net = UnetGenerator(input_nc, output_nc, nnG, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
92
+ elif netG == 'partunet':
93
+ net = PartUnet(input_nc, output_nc, nnG, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
94
+ elif netG == 'partunet2':
95
+ net = PartUnet2(input_nc, output_nc, nnG, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
96
+ elif netG == 'combiner':
97
+ net = Combiner(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=2)
98
+ else:
99
+ raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
100
+ return init_net(net, init_type, init_gain, gpu_ids)
101
+
102
+
103
+ def define_D(input_nc, ndf, netD,
104
+ n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
105
+ net = None
106
+ norm_layer = get_norm_layer(norm_type=norm)
107
+
108
+ if netD == 'basic':
109
+ net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
110
+ elif netD == 'n_layers':
111
+ net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
112
+ elif netD == 'pixel':
113
+ net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
114
+ else:
115
+ raise NotImplementedError('Discriminator model name [%s] is not recognized' % net)
116
+ return init_net(net, init_type, init_gain, gpu_ids)
117
+
118
+
119
+ # Classes
120
+
121
+ # Defines the GAN loss which uses either LSGAN or the regular GAN.
122
+ # When LSGAN is used, it is basically same as MSELoss,
123
+ # but it abstracts away the need to create the target label tensor
124
+ # that has the same size as the input
125
+ class GANLoss(nn.Module):
126
+ def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
127
+ super(GANLoss, self).__init__()
128
+ self.register_buffer('real_label', torch.tensor(target_real_label))
129
+ self.register_buffer('fake_label', torch.tensor(target_fake_label))
130
+ if use_lsgan:
131
+ self.loss = nn.MSELoss()
132
+ else:#no_lsgan
133
+ self.loss = nn.BCELoss()
134
+
135
+ def get_target_tensor(self, input, target_is_real):
136
+ if target_is_real:
137
+ target_tensor = self.real_label
138
+ else:
139
+ target_tensor = self.fake_label
140
+ return target_tensor.expand_as(input)
141
+
142
+ def __call__(self, input, target_is_real):
143
+ target_tensor = self.get_target_tensor(input, target_is_real)
144
+ return self.loss(input, target_tensor)
145
+
146
+ # Defines the generator that consists of Resnet blocks between a few
147
+ # downsampling/upsampling operations.
148
+ # Code and idea originally from Justin Johnson's architecture.
149
+ # https://github.com/jcjohnson/fast-neural-style/
150
+ class ResnetGenerator(nn.Module):
151
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
152
+ assert(n_blocks >= 0)
153
+ super(ResnetGenerator, self).__init__()
154
+ self.input_nc = input_nc
155
+ self.output_nc = output_nc
156
+ self.ngf = ngf
157
+ if type(norm_layer) == functools.partial:
158
+ use_bias = norm_layer.func == nn.InstanceNorm2d
159
+ else:
160
+ use_bias = norm_layer == nn.InstanceNorm2d
161
+
162
+ model = [nn.ReflectionPad2d(3),
163
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
164
+ bias=use_bias),
165
+ norm_layer(ngf),
166
+ nn.ReLU(True)]
167
+
168
+ n_downsampling = 2
169
+ for i in range(n_downsampling):
170
+ mult = 2**i
171
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
172
+ stride=2, padding=1, bias=use_bias),
173
+ norm_layer(ngf * mult * 2),
174
+ nn.ReLU(True)]
175
+
176
+ mult = 2**n_downsampling
177
+ for i in range(n_blocks):
178
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
179
+
180
+ for i in range(n_downsampling):
181
+ mult = 2**(n_downsampling - i)
182
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
183
+ kernel_size=3, stride=2,
184
+ padding=1, output_padding=1,
185
+ bias=use_bias),
186
+ norm_layer(int(ngf * mult / 2)),
187
+ nn.ReLU(True)]
188
+ model += [nn.ReflectionPad2d(3)]
189
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
190
+ model += [nn.Tanh()]
191
+
192
+ self.model = nn.Sequential(*model)
193
+
194
+ def forward(self, input):
195
+ return self.model(input)
196
+
197
+ class Combiner(nn.Module):
198
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
199
+ assert(n_blocks >= 0)
200
+ super(Combiner, self).__init__()
201
+ self.input_nc = input_nc
202
+ self.output_nc = output_nc
203
+ self.ngf = ngf
204
+ if type(norm_layer) == functools.partial:
205
+ use_bias = norm_layer.func == nn.InstanceNorm2d
206
+ else:
207
+ use_bias = norm_layer == nn.InstanceNorm2d
208
+
209
+ model = [nn.ReflectionPad2d(3),
210
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
211
+ bias=use_bias),
212
+ norm_layer(ngf),
213
+ nn.ReLU(True)]
214
+
215
+ for i in range(n_blocks):
216
+ model += [ResnetBlock(ngf, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
217
+
218
+ model += [nn.ReflectionPad2d(3)]
219
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
220
+ model += [nn.Tanh()]
221
+
222
+ self.model = nn.Sequential(*model)
223
+
224
+ def forward(self, input):
225
+ return self.model(input)
226
+
227
+ # Define a resnet block
228
+ class ResnetBlock(nn.Module):
229
+ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
230
+ super(ResnetBlock, self).__init__()
231
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
232
+
233
+ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
234
+ conv_block = []
235
+ p = 0
236
+ if padding_type == 'reflect':
237
+ conv_block += [nn.ReflectionPad2d(1)]
238
+ elif padding_type == 'replicate':
239
+ conv_block += [nn.ReplicationPad2d(1)]
240
+ elif padding_type == 'zero':
241
+ p = 1
242
+ else:
243
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
244
+
245
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
246
+ norm_layer(dim),
247
+ nn.ReLU(True)]
248
+ if use_dropout:
249
+ conv_block += [nn.Dropout(0.5)]
250
+
251
+ p = 0
252
+ if padding_type == 'reflect':
253
+ conv_block += [nn.ReflectionPad2d(1)]
254
+ elif padding_type == 'replicate':
255
+ conv_block += [nn.ReplicationPad2d(1)]
256
+ elif padding_type == 'zero':
257
+ p = 1
258
+ else:
259
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
260
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
261
+ norm_layer(dim)]
262
+
263
+ return nn.Sequential(*conv_block)
264
+
265
+ def forward(self, x):
266
+ out = x + self.conv_block(x)
267
+ return out
268
+
269
+
270
+ # Defines the Unet generator.
271
+ # |num_downs|: number of downsamplings in UNet. For example,
272
+ # if |num_downs| == 7, image of size 128x128 will become of size 1x1
273
+ # at the bottleneck
274
+ class UnetGenerator(nn.Module):
275
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64,
276
+ norm_layer=nn.BatchNorm2d, use_dropout=False):
277
+ super(UnetGenerator, self).__init__()
278
+
279
+ # construct unet structure
280
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
281
+ for i in range(num_downs - 5):
282
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
283
+ unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
284
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
285
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
286
+ unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
287
+
288
+ self.model = unet_block
289
+
290
+ def forward(self, input):
291
+ return self.model(input)
292
+
293
+ class PartUnet(nn.Module):
294
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64,
295
+ norm_layer=nn.BatchNorm2d, use_dropout=False):
296
+ super(PartUnet, self).__init__()
297
+
298
+ # construct unet structure
299
+ # 3 downs
300
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
301
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
302
+ unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
303
+
304
+ self.model = unet_block
305
+
306
+ def forward(self, input):
307
+ return self.model(input)
308
+
309
+ class PartUnet2(nn.Module):
310
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64,
311
+ norm_layer=nn.BatchNorm2d, use_dropout=False):
312
+ super(PartUnet2, self).__init__()
313
+
314
+ # construct unet structure
315
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 2, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
316
+ for i in range(num_downs - 3):
317
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
318
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
319
+ unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
320
+
321
+ self.model = unet_block
322
+
323
+ def forward(self, input):
324
+ return self.model(input)
325
+
326
+
327
+ # Defines the submodule with skip connection.
328
+ # X -------------------identity---------------------- X
329
+ # |-- downsampling -- |submodule| -- upsampling --|
330
+ class UnetSkipConnectionBlock(nn.Module):
331
+ def __init__(self, outer_nc, inner_nc, input_nc=None,
332
+ submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
333
+ super(UnetSkipConnectionBlock, self).__init__()
334
+ self.outermost = outermost
335
+ if type(norm_layer) == functools.partial:
336
+ use_bias = norm_layer.func == nn.InstanceNorm2d
337
+ else:
338
+ use_bias = norm_layer == nn.InstanceNorm2d
339
+ if input_nc is None:
340
+ input_nc = outer_nc
341
+ downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
342
+ stride=2, padding=1, bias=use_bias)
343
+ downrelu = nn.LeakyReLU(0.2, True)
344
+ downnorm = norm_layer(inner_nc)
345
+ uprelu = nn.ReLU(True)
346
+ upnorm = norm_layer(outer_nc)
347
+
348
+ if outermost:
349
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
350
+ kernel_size=4, stride=2,
351
+ padding=1)
352
+ down = [downconv]
353
+ up = [uprelu, upconv, nn.Tanh()]
354
+ model = down + [submodule] + up
355
+ elif innermost:
356
+ upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
357
+ kernel_size=4, stride=2,
358
+ padding=1, bias=use_bias)
359
+ down = [downrelu, downconv]
360
+ up = [uprelu, upconv, upnorm]
361
+ model = down + up
362
+ else:
363
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
364
+ kernel_size=4, stride=2,
365
+ padding=1, bias=use_bias)
366
+ down = [downrelu, downconv, downnorm]
367
+ up = [uprelu, upconv, upnorm]
368
+
369
+ if use_dropout:
370
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
371
+ else:
372
+ model = down + [submodule] + up
373
+
374
+ self.model = nn.Sequential(*model)
375
+
376
+ def forward(self, x):
377
+ if self.outermost:
378
+ return self.model(x)
379
+ else:
380
+ return torch.cat([x, self.model(x)], 1)
381
+
382
+
383
+ # Defines the PatchGAN discriminator with the specified arguments.
384
+ class NLayerDiscriminator(nn.Module):
385
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
386
+ super(NLayerDiscriminator, self).__init__()
387
+ if type(norm_layer) == functools.partial:
388
+ use_bias = norm_layer.func == nn.InstanceNorm2d
389
+ else:
390
+ use_bias = norm_layer == nn.InstanceNorm2d
391
+
392
+ kw = 4
393
+ padw = 1
394
+ sequence = [
395
+ nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
396
+ nn.LeakyReLU(0.2, True)
397
+ ]
398
+
399
+ nf_mult = 1
400
+ nf_mult_prev = 1
401
+ for n in range(1, n_layers):
402
+ nf_mult_prev = nf_mult
403
+ nf_mult = min(2**n, 8)
404
+ sequence += [
405
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
406
+ kernel_size=kw, stride=2, padding=padw, bias=use_bias),
407
+ norm_layer(ndf * nf_mult),
408
+ nn.LeakyReLU(0.2, True)
409
+ ]
410
+
411
+ nf_mult_prev = nf_mult
412
+ nf_mult = min(2**n_layers, 8)
413
+ sequence += [
414
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
415
+ kernel_size=kw, stride=1, padding=padw, bias=use_bias),
416
+ norm_layer(ndf * nf_mult),
417
+ nn.LeakyReLU(0.2, True)
418
+ ]
419
+
420
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
421
+
422
+ if use_sigmoid:#no_lsgan, use sigmoid before calculating bceloss(binary cross entropy)
423
+ sequence += [nn.Sigmoid()]
424
+
425
+ self.model = nn.Sequential(*sequence)
426
+
427
+ def forward(self, input):
428
+ return self.model(input)
429
+
430
+
431
+ class PixelDiscriminator(nn.Module):
432
+ def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
433
+ super(PixelDiscriminator, self).__init__()
434
+ if type(norm_layer) == functools.partial:
435
+ use_bias = norm_layer.func == nn.InstanceNorm2d
436
+ else:
437
+ use_bias = norm_layer == nn.InstanceNorm2d
438
+
439
+ self.net = [
440
+ nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
441
+ nn.LeakyReLU(0.2, True),
442
+ nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
443
+ norm_layer(ndf * 2),
444
+ nn.LeakyReLU(0.2, True),
445
+ nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
446
+
447
+ if use_sigmoid:
448
+ self.net.append(nn.Sigmoid())
449
+
450
+ self.net = nn.Sequential(*self.net)
451
+
452
+ def forward(self, input):
453
+ return self.net(input)
models/apdrawing/weights/apdrawing_200.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c662146e101b06a6cf8a70de79ae86f9bfd58425ecb0a7c9d5251efb05932c4
3
+ size 217718925
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ opencv-python-headless
5
+ Pillow
6
+ numpy
7
+ scikit-image