RAIL-KNUST committed on
Commit
9054e98
·
verified ·
1 Parent(s): 8f5fcc5

adding files for the application

Browse files

Adding files for the model deployment

Files changed (5) hide show
  1. README.md +8 -5
  2. app.py +135 -0
  3. gitattributes +35 -0
  4. networks.py +601 -0
  5. requirements.txt +7 -0
README.md CHANGED
@@ -1,10 +1,13 @@
1
  ---
2
- title: README
3
- emoji: 😻
4
- colorFrom: purple
5
- colorTo: red
6
  sdk: gradio
 
 
7
  pinned: false
 
8
  ---
9
 
10
- Edit this `README.md` markdown file to author your organization card.
 
1
  ---
2
+ title: Decgan Demo
3
+ emoji: 📈
4
+ colorFrom: yellow
5
+ colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 5.41.1
8
+ app_file: app.py
9
  pinned: false
10
+ short_description: Diversity enhanced cycle gan
11
  ---
12
 
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from PIL import Image
5
+ import torchvision.transforms as transforms
6
+ import numpy as np
7
+ from huggingface_hub import hf_hub_download
8
+ import os
9
+
10
+ # Import your networks (you'll need to upload networks.py to your Space)
11
+ from networks import ResnetGenerator # Adjust this import based on your networks.py structure
12
+
13
class CycleGANInference:
    """Load CycleGAN generator checkpoints from the Hugging Face Hub and run
    single-image translation (A->B and, optionally, B->A)."""

    def __init__(self, model_repo_id, checkpoint_filename_AtoB, checkpoint_filename_BtoA=None):
        """Download the checkpoint(s) from `model_repo_id` and build the generators.

        If `checkpoint_filename_BtoA` is None, only the A->B direction is available.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Download model checkpoints from Hugging Face Hub
        checkpoint_path_AtoB = hf_hub_download(
            repo_id=model_repo_id,
            filename=checkpoint_filename_AtoB
        )

        # Initialize generators.
        # NOTE(review): ngf=64 / n_blocks=9 must match the architecture the
        # checkpoints were trained with — confirm against the training config.
        self.netG_A2B = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=9)  # A to B

        if checkpoint_filename_BtoA:
            checkpoint_path_BtoA = hf_hub_download(
                repo_id=model_repo_id,
                filename=checkpoint_filename_BtoA
            )
            self.netG_B2A = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=9)  # B to A
        else:
            self.netG_B2A = None

        # Load model weights
        self.netG_A2B.load_state_dict(torch.load(checkpoint_path_AtoB, map_location=self.device))
        if self.netG_B2A and checkpoint_filename_BtoA:
            self.netG_B2A.load_state_dict(torch.load(checkpoint_path_BtoA, map_location=self.device))

        # Set to evaluation mode
        self.netG_A2B.eval()
        if self.netG_B2A:
            self.netG_B2A.eval()

        # Move to device
        self.netG_A2B.to(self.device)
        if self.netG_B2A:
            self.netG_B2A.to(self.device)

        # Preprocessing: resize, convert to tensor, scale pixels to [-1, 1].
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),  # Adjust size based on your model
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # Postprocessing: map [-1, 1] back to [0, 1] (i.e. (x + 1) / 2) and to PIL.
        self.inverse_transform = transforms.Compose([
            transforms.Normalize((-1, -1, -1), (2, 2, 2)),  # Denormalize
            transforms.ToPILImage()
        ])

    def transform_image(self, image, direction="A_to_B"):
        """Translate one PIL image; `direction` is "A_to_B" or "B_to_A".

        Raises ValueError for an unknown direction or a missing B->A model.
        """
        # Preprocess
        input_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            if direction == "A_to_B":
                output_tensor = self.netG_A2B(input_tensor)
            elif direction == "B_to_A" and self.netG_B2A:
                output_tensor = self.netG_B2A(input_tensor)
            else:
                raise ValueError("Invalid direction or model not available")

        # Postprocess
        output_image = self.inverse_transform(output_tensor.squeeze(0).cpu())
        return output_image
78
+
79
# Initialize your model
# Replace these with your actual Hugging Face repo ID and checkpoint filenames
MODEL_REPO_ID = "profmatthew/decgan"  # Replace with your repo
CHECKPOINT_A2B = "200_net_G_A.pth"  # Replace with your checkpoint filename
CHECKPOINT_B2A = "200_net_G_B.pth"  # Replace with your checkpoint filename (optional)

# NOTE: this runs at import time and downloads weights from the Hub, so the
# Space fails to start if the repo or files are unavailable.
cyclegan_model = CycleGANInference(
    model_repo_id=MODEL_REPO_ID,
    checkpoint_filename_AtoB=CHECKPOINT_A2B,
    checkpoint_filename_BtoA=CHECKPOINT_B2A  # Set to None if you only have one direction
)
90
+
91
def generate_image(input_image, direction):
    """Gradio callback: translate `input_image` in the requested `direction`.

    Returns a PIL image on success.  On failure raises `gr.Error` so Gradio
    shows the message in the UI — the original returned an error *string*,
    which the `gr.Image` output component cannot render.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")
    try:
        return cyclegan_model.transform_image(input_image, direction)
    except Exception as e:
        raise gr.Error(f"Error: {e}") from e
97
+
98
# Create Gradio interface: image + direction on the left, result on the right.
with gr.Blocks(title="CycleGAN Image Translation") as demo:
    gr.Markdown("# CycleGAN Image Translation")
    gr.Markdown("Upload an image and select the transformation direction.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            direction = gr.Dropdown(
                choices=["A_to_B", "B_to_A"],
                value="A_to_B",
                label="Translation Direction"
            )
            generate_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            output_image = gr.Image(type="pil", label="Generated Image")

    # Wire the button to the inference callback.
    generate_btn.click(
        fn=generate_image,
        inputs=[input_image, direction],
        outputs=output_image
    )

    # Add some examples if you have them
    # gr.Examples(
    #     examples=[
    #         # Add paths to example images here
    #         # ["example1.jpg", "A_to_B"],
    #         # ["example2.jpg", "B_to_A"],
    #     ],
    #     inputs=[input_image, direction],
    #     outputs=output_image,
    #     fn=generate_image,
    # )

# Standard script entry point.
if __name__ == "__main__":
    demo.launch()
gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
networks.py ADDED
@@ -0,0 +1,601 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import init
4
+ import functools
5
+ from torch.optim import lr_scheduler
6
+ import torch.nn.functional as F
7
+ from torch import nn, einsum
8
+ from einops import rearrange, reduce, repeat
9
+
10
+ ###############################################################################
11
+ # Helper Functions
12
+ ###############################################################################
13
+
14
class SelfAttention(nn.Module):
    """Self attention layer (SAGAN-style) over the spatial positions of an
    NCHW feature map.

    `gamma` is initialized to zero, so the layer starts as an identity and
    learns how much attention output to blend in.

    FIX: removed the debug `print("Attention Mechanism!")` that fired on every
    forward pass.
    """

    def __init__(self, input_channel, activation="relu"):
        super(SelfAttention, self).__init__()
        self.chanel_in = input_channel  # (sic) kept for checkpoint/attribute compatibility
        self.activation = activation    # stored but unused by forward()

        # 1x1 convs produce query/key at C//8 channels and value at C channels.
        self.query_conv = nn.Conv2d(input_channel, input_channel // 8, 1)
        self.key_conv = nn.Conv2d(input_channel, input_channel // 8, 1)
        self.value_conv = nn.Conv2d(input_channel, input_channel, 1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return gamma * attention(x) + x; shape is preserved."""
        batch, channels, height, width = x.size()
        n = height * width
        q = self.query_conv(x).view(batch, -1, n).permute(0, 2, 1)  # B x N x C//8
        k = self.key_conv(x).view(batch, -1, n)                     # B x C//8 x N
        attention = self.softmax(torch.bmm(q, k))                   # B x N x N
        v = self.value_conv(x).view(batch, -1, n)                   # B x C x N

        out = torch.bmm(v, attention.permute(0, 2, 1))
        out = out.view(batch, channels, height, width)
        return self.gamma * out + x
43
+
44
def get_norm_layer(norm_type='instance'):
    """Return a normalization-layer constructor for `norm_type`.

    'batch'    -> BatchNorm2d with learnable affine parameters
    'instance' -> InstanceNorm2d without affine params or running stats
    'none'     -> None (caller skips normalization)
    Raises NotImplementedError for any other name.
    """
    factories = {
        'batch': functools.partial(nn.BatchNorm2d, affine=True),
        'instance': functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False),
        'none': None,
    }
    if norm_type not in factories:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return factories[norm_type]
54
+
55
+
56
def get_scheduler(optimizer, opt):
    """Build a learning-rate scheduler for `optimizer` from `opt.lr_policy`.

    Supported policies: 'lambda' (linear decay after opt.niter), 'step',
    'plateau', 'cosine'.

    FIX: the original *returned* a NotImplementedError instance for unknown
    policies (callers got the exception object as their "scheduler") and
    passed the policy as a second exception argument instead of formatting it;
    now it raises with a formatted message.
    """
    if opt.lr_policy == 'lambda':
        def lambda_rule(epoch):
            # Constant LR for the first opt.niter epochs, then linear decay to 0
            # over opt.niter_decay epochs.
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
71
+
72
+
73
def init_weights(net, init_type='normal', gain=0.02):
    """Initialize `net`'s weights in place.

    Conv/Linear weights use the chosen scheme ('normal', 'xavier', 'kaiming',
    'orthogonal'); their biases are zeroed.  BatchNorm2d weights are drawn from
    N(1, gain) and biases zeroed.  Raises NotImplementedError for an unknown
    `init_type`.
    """
    weight_inits = {
        'normal': lambda w: init.normal_(w, 0.0, gain),
        'xavier': lambda w: init.xavier_normal_(w, gain=gain),
        'kaiming': lambda w: init.kaiming_normal_(w, a=0, mode='fan_in'),
        'orthogonal': lambda w: init.orthogonal_(w, gain=gain),
    }

    def init_func(m):
        name = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in name or 'Linear' in name):
            if init_type not in weight_inits:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            weight_inits[init_type](m.weight.data)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in name:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Optionally move `net` to GPU(s) (wrapping in DataParallel), then
    initialize its weights and return it."""
    if gpu_ids:
        assert torch.cuda.is_available()
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)
    init_weights(net, init_type, gain=init_gain)
    return net
104
+
105
+
106
+
107
+ ##############################################################################
108
+ # Classes
109
+ ##############################################################################
110
+
111
+
112
+ # Defines the GAN loss which uses either LSGAN or the regular GAN.
113
+ # When LSGAN is used, it is basically same as MSELoss,
114
+ # but it abstracts away the need to create the target label tensor
115
+ # that has the same size as the input
116
class GANLoss(nn.Module):
    """GAN objective that builds constant target tensors matching the
    discriminator output's shape.

    use_lsgan=True uses MSELoss (LSGAN); False uses BCELoss (vanilla GAN,
    expecting sigmoid outputs).
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
        super(GANLoss, self).__init__()
        # Buffers so the labels follow the module across devices.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.loss = nn.MSELoss() if use_lsgan else nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return the real/fake label broadcast to `input`'s shape."""
        label = self.real_label if target_is_real else self.fake_label
        return label.expand_as(input)

    def __call__(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)
135
+
136
#################################################################################
# Critic Loss for Wasserstein GAN-GP                                            #
#################################################################################
class GradPenalty(nn.Module):
    """Gradient-penalty term for WGAN-GP: penalizes the critic's gradient
    norm at random interpolations between real and fake samples."""

    def __init__(self, use_cuda):
        super(GradPenalty, self).__init__()
        # Whether intermediate tensors should be moved to the GPU.
        self.use_cuda = use_cuda

    def forward(self, critic, real_data, fake_data):
        """Return mean((||d critic/d x_interp||_2 - 1)^2) over the batch."""
        # Per-element interpolation coefficients (same shape as real_data).
        alpha = torch.rand_like(real_data)

        assignGPU = lambda x: x.cuda() if self.use_cuda else x
        alpha = assignGPU(alpha)

        # Random points on the line between real and (detached) fake samples.
        interpolates = alpha*real_data + (1-alpha)*fake_data.detach()
        interpolates = assignGPU(interpolates)
        # NOTE(review): torch.autograd.Variable is deprecated; a call to
        # .requires_grad_() would be the modern equivalent.
        interpolates = torch.autograd.Variable(interpolates, requires_grad = True)

        critic_interpolates = critic(interpolates)

        # Gradient of the critic output w.r.t. the interpolated inputs;
        # create_graph=True so the penalty itself is differentiable.
        gradients = torch.autograd.grad(
            outputs=critic_interpolates,
            inputs=interpolates,
            grad_outputs=assignGPU(torch.ones(critic_interpolates.size())),
            create_graph=True, retain_graph=True, only_inputs=True
        )[0]
        gradients = gradients.view(gradients.size(0), -1)
        # Penalize deviation of each sample's gradient 2-norm from 1.
        gradient_penalty = ((gradients.norm(2, dim=1)-1)**2).mean()
        return gradient_penalty
164
+
165
+ #####
166
+ #####
167
+
168
+ #################################################################################
169
+ # Hybrid Perception Block and DPSA Layer #
170
+ #################################################################################
171
+
172
+
173
# helper functions

def exists(val):
    """True unless `val` is None."""
    return val is not None


def default(val, d):
    """Return `val` if it is not None, otherwise the fallback `d`."""
    return d if val is None else val


def l2norm(t):
    """L2-normalize `t` along its last dimension."""
    return F.normalize(t, dim=-1)
183
+
184
+ # helper classes
185
+
186
class Residual(nn.Module):
    """Wrap `fn` with a skip connection: forward(x) == fn(x) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        return out + x
193
+
194
class ChanLayerNorm(nn.Module):
    """Layer normalization over the channel dimension of an NCHW tensor,
    with learnable per-channel scale `g` and shift `b`."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        # Biased (unbiased=False) variance, matching LayerNorm semantics.
        mean = torch.mean(x, dim=1, keepdim=True)
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        normed = (x - mean) / (var + self.eps).sqrt()
        return normed * self.g + self.b
205
+
206
+ # classes
207
+
208
+
209
+ # Defines the generator that consists of Resnet blocks between a few
210
+ # downsampling/upsampling operations.
211
+ # Code and idea from Justin Johnson's architecture.
212
+ # https://github.com/jcjohnson/fast-neural-style/
213
+
214
class ResnetGenerator(nn.Module):
    """Resnet-based generator: reflection-padded stem, two stride-2
    downsampling convs, `n_blocks` residual blocks, two upsampling
    transposed convs, and a Tanh output head.

    Architecture from Johnson et al. (fast-neural-style),
    https://github.com/jcjohnson/fast-neural-style/
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.InstanceNorm2d,
                 use_dropout=False, n_blocks=6, padding_type='reflect', use_attention=False):
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        # InstanceNorm has no affine shift by default, so convs keep their bias.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
            norm_layer(ngf),
            nn.ReLU(True),
        ]

        # Two stride-2 downsampling stages: ngf -> 2*ngf -> 4*ngf channels.
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            layers += [
                nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                          stride=2, padding=1, bias=use_bias),
                norm_layer(ngf * mult * 2),
                nn.ReLU(True),
            ]

        # Residual trunk at the bottleneck resolution.
        mult = 2 ** n_downsampling
        layers += [
            ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
                        use_dropout=use_dropout, use_bias=use_bias)
            for _ in range(n_blocks)
        ]

        # Mirror-image upsampling stages.
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            layers += [
                nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3,
                                   stride=2, padding=1, output_padding=1, bias=use_bias),
                norm_layer(int(ngf * mult / 2)),
                nn.ReLU(True),
            ]
            # Optional self-attention after the first upsampling stage
            # (128 channels when ngf=64, matching SelfAttention(128)).
            # NOTE(review): source indentation was lost; `i == 0` can only ever
            # hold inside this loop, so the guard was reconstructed here — confirm.
            if use_attention and i == 0:
                layers += [SelfAttention(128, 'relu')]

        layers += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        return self.model(input)


class ResnetBlock(nn.Module):
    """Residual block: x + conv_block(x), with configurable padding mode."""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Assemble pad -> conv -> norm -> ReLU [-> dropout] -> pad -> conv -> norm."""
        pad_layers = {'reflect': nn.ReflectionPad2d, 'replicate': nn.ReplicationPad2d}

        def padding():
            # Returns (explicit padding layers, conv padding amount).
            if padding_type in pad_layers:
                return [pad_layers[padding_type](1)], 0
            if padding_type == 'zero':
                return [], 1
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        block = []
        pre, p = padding()
        block += pre + [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
            norm_layer(dim),
            nn.ReLU(True),
        ]
        if use_dropout:
            block += [nn.Dropout(0.5)]

        pre, p = padding()
        block += pre + [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
            norm_layer(dim),
        ]
        return nn.Sequential(*block)

    def forward(self, x):
        return x + self.conv_block(x)
333
+
334
+
335
+ # Defines the Unet generator.
336
+ # |num_downs|: number of downsamplings in UNet. For example,
337
+ # if |num_downs| == 7, image of size 128x128 will become of size 1x1
338
+ # at the bottleneck
339
class UnetGenerator(nn.Module):
    """U-Net generator built recursively from UnetSkipConnectionBlock.

    |num_downs| is the number of downsamplings; e.g. num_downs == 7 maps a
    128x128 image to a 1x1 bottleneck.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetGenerator, self).__init__()

        # Innermost (bottleneck) block.
        block = UnetSkipConnectionBlock(
            ngf * 8, ngf * 8, input_nc=None, submodule=None,
            norm_layer=norm_layer, innermost=True)
        # Extra ngf*8 stages (with optional dropout) for deeper nets.
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(
                ngf * 8, ngf * 8, input_nc=None, submodule=block,
                norm_layer=norm_layer, use_dropout=use_dropout)
        # Widen back out toward the image resolution.
        for outer, inner in ((ngf * 4, ngf * 8), (ngf * 2, ngf * 4), (ngf, ngf * 2)):
            block = UnetSkipConnectionBlock(
                outer, inner, input_nc=None, submodule=block,
                norm_layer=norm_layer)
        # Outermost block maps input_nc -> output_nc and applies Tanh.
        self.model = UnetSkipConnectionBlock(
            output_nc, ngf, input_nc=input_nc, submodule=block,
            outermost=True, norm_layer=norm_layer)

    def forward(self, input):
        return self.model(input)


class UnetSkipConnectionBlock(nn.Module):
    """One down/up stage of the U-Net with a skip connection.

        X -------------------identity---------------------- X
          |-- downsampling -- |submodule| -- upsampling --|

    The outermost block returns the model output directly; every other block
    concatenates its input with the upsampled result along the channel axis.
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None,
                 outermost=False, innermost=False, norm_layer=nn.BatchNorm2d,
                 use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # Convs keep a bias only when the norm has no learnable shift (InstanceNorm).
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc

        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            # Input channels double because inner blocks concatenate skips.
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2, padding=1)
            model = [downconv, submodule, uprelu, upconv, nn.Tanh()]
        elif innermost:
            # No submodule and no concatenation below this level.
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2, padding=1, bias=use_bias)
            model = [downrelu, downconv, uprelu, upconv, upnorm]
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2, padding=1, bias=use_bias)
            model = [downrelu, downconv, downnorm, submodule, uprelu, upconv, upnorm]
            if use_dropout:
                model.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        return torch.cat([x, self.model(x)], 1)
470
+
471
+
472
+ # Defines the PatchGAN discriminator with the specified arguments.
473
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: a stack of stride-2 convs followed by two
    stride-1 convs, producing a spatial map of per-patch scores."""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
                 use_sigmoid=False, use_attention=False):
        super(NLayerDiscriminator, self).__init__()
        # Convs keep a bias only when the norm has no learnable shift.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw, padw = 4, 1
        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]

        # Stride-2 stages, doubling channels up to a cap of 8*ndf.
        nf_mult = 1
        for n in range(1, n_layers):
            nf_mult_prev, nf_mult = nf_mult, min(2 ** n, 8)
            layers += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw,
                          stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]

        # One stride-1 stage before the score head.
        nf_mult_prev, nf_mult = nf_mult, min(2 ** n_layers, 8)
        layers += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw,
                      stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]
        if use_attention:
            layers += [SelfAttention(512, 'relu')]
        layers += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        if use_sigmoid:
            layers += [nn.Sigmoid()]

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        return self.model(input)
526
+
527
class NLayerDiscriminatorSN(nn.Module):
    """PatchGAN discriminator with spectral normalization on every conv.

    FIX: the original wrapped each conv in `SpectralNorm`, a name that is
    never defined or imported in this module, so constructing the class
    raised NameError.  Use `torch.nn.utils.spectral_norm` instead, which
    applies the same weight spectral normalization.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, use_attention=False):
        super(NLayerDiscriminatorSN, self).__init__()
        sn = nn.utils.spectral_norm  # weight-spectral-normalization wrapper
        use_bias = False

        kw, padw = 4, 1
        sequence = [
            sn(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)),
            nn.LeakyReLU(0.2, True),
        ]

        # Stride-2 stages, doubling channels up to a cap of 8*ndf.
        nf_mult = 1
        for n in range(1, n_layers):
            nf_mult_prev, nf_mult = nf_mult, min(2 ** n, 8)
            sequence += [
                sn(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw,
                             stride=2, padding=padw, bias=use_bias)),
                nn.LeakyReLU(0.2, True),
            ]

        # One stride-1 stage before the score head.
        nf_mult_prev, nf_mult = nf_mult, min(2 ** n_layers, 8)
        sequence += [
            sn(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw,
                         stride=1, padding=padw, bias=use_bias)),
            nn.LeakyReLU(0.2, True),
        ]
        if use_attention:
            sequence += [SelfAttention(512, 'relu')]
        sequence += [sn(nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw))]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        return self.model(input)
578
+
579
class PixelDiscriminator(nn.Module):
    """Per-pixel (1x1-receptive-field) PatchGAN discriminator: three 1x1
    convs that classify each pixel independently."""

    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
        super(PixelDiscriminator, self).__init__()
        # Convs keep a bias only when the norm has no learnable shift.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias),
        ]
        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.net = nn.Sequential(*layers)

    def forward(self, input):
        return self.net(input)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ Pillow
5
+ numpy
6
+ huggingface_hub
7
+ einops