RAIL-KNUST committed on
Commit
fc931ab
·
verified ·
1 Parent(s): c2714c7

Uploading logic files

Browse files
Files changed (5) hide show
  1. README.md +5 -6
  2. app.py +135 -0
  3. gitattributes +35 -0
  4. networks.py +394 -0
  5. requirements.txt +7 -0
README.md CHANGED
@@ -1,14 +1,13 @@
1
  ---
2
- title: Attndecgan
3
- emoji: 😻
4
  colorFrom: gray
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 6.5.0
8
  app_file: app.py
9
  pinned: false
10
- license: mit
11
- short_description: Att-DeCGAN demo
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Attndecgan Demo
3
+ emoji: 🔥
4
  colorFrom: gray
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 5.41.1
8
  app_file: app.py
9
  pinned: false
10
+ short_description: A demo for the attention diversity enhanced GAN
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from PIL import Image
5
+ import torchvision.transforms as transforms
6
+ import numpy as np
7
+ from huggingface_hub import hf_hub_download
8
+ import os
9
+
10
+ # Import your networks (you'll need to upload networks.py to your Space)
11
+ from networks import HPBGenerator # Adjust this import based on your networks.py structure
12
+
13
class CycleGANInference:
    """Loads CycleGAN generator checkpoints from the Hugging Face Hub and
    exposes single-image translation in one or both directions.

    NOTE: constructing this class performs network I/O (checkpoint download
    via ``hf_hub_download``) and loads weights onto the selected device.
    """

    def __init__(self, model_repo_id, checkpoint_filename_AtoB, checkpoint_filename_BtoA=None):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Download model checkpoints from Hugging Face Hub
        checkpoint_path_AtoB = hf_hub_download(
            repo_id=model_repo_id,
            filename=checkpoint_filename_AtoB
        )

        # Initialize generators
        # Adjust these parameters based on your model architecture
        self.netG_A2B = HPBGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=9)  # A to B

        # The reverse generator is optional: a falsy filename leaves it None.
        if checkpoint_filename_BtoA:
            checkpoint_path_BtoA = hf_hub_download(
                repo_id=model_repo_id,
                filename=checkpoint_filename_BtoA
            )
            self.netG_B2A = HPBGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=9)  # B to A
        else:
            self.netG_B2A = None

        # Load model weights
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # checkpoints from a trusted repo (consider weights_only=True).
        self.netG_A2B.load_state_dict(torch.load(checkpoint_path_AtoB, map_location=self.device))
        if self.netG_B2A and checkpoint_filename_BtoA:
            self.netG_B2A.load_state_dict(torch.load(checkpoint_path_BtoA, map_location=self.device))

        # Set to evaluation mode (disables dropout, fixes norm behavior)
        self.netG_A2B.eval()
        if self.netG_B2A:
            self.netG_B2A.eval()

        # Move to device
        self.netG_A2B.to(self.device)
        if self.netG_B2A:
            self.netG_B2A.to(self.device)

        # Preprocessing: resize, convert to tensor, scale pixels to [-1, 1].
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),  # Adjust size based on your model
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # Postprocessing: invert the normalization ([-1, 1] -> [0, 1]),
        # then convert back to a PIL image.
        self.inverse_transform = transforms.Compose([
            transforms.Normalize((-1, -1, -1), (2, 2, 2)),  # Denormalize
            transforms.ToPILImage()
        ])

    def transform_image(self, image, direction="A_to_B"):
        """Translate a PIL image.

        Parameters
        ----------
        image : PIL.Image.Image
            Input image; resized to 256x256 before inference.
        direction : str
            "A_to_B" or "B_to_A".

        Returns
        -------
        PIL.Image.Image
            The translated image.

        Raises
        ------
        ValueError
            If ``direction`` is unknown, or "B_to_A" is requested but the
            reverse generator was not loaded.
        """
        # Preprocess
        input_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            if direction == "A_to_B":
                output_tensor = self.netG_A2B(input_tensor)
            elif direction == "B_to_A" and self.netG_B2A:
                output_tensor = self.netG_B2A(input_tensor)
            else:
                raise ValueError("Invalid direction or model not available")

        # Postprocess
        output_image = self.inverse_transform(output_tensor.squeeze(0).cpu())
        return output_image
78
+
79
# Initialize your model
# Replace these with your actual Hugging Face repo ID and checkpoint filenames
MODEL_REPO_ID = "profmatthew/Attn-DeCGAN"  # Replace with your repo
CHECKPOINT_A2B = "latest_net_G_A.pth"  # Replace with your checkpoint filename
CHECKPOINT_B2A = "latest_net_G_B.pth"  # Replace with your checkpoint filename (optional)

# NOTE: this runs at import time and downloads both checkpoints from the Hub,
# so the Space may take a while to start on a cold boot.
cyclegan_model = CycleGANInference(
    model_repo_id=MODEL_REPO_ID,
    checkpoint_filename_AtoB=CHECKPOINT_A2B,
    checkpoint_filename_BtoA=CHECKPOINT_B2A  # Set to None if you only have one direction
)
90
+
91
def generate_image(input_image, direction):
    """Gradio callback: translate ``input_image`` in the given direction.

    Parameters
    ----------
    input_image : PIL.Image.Image
        Image uploaded through the UI.
    direction : str
        "A_to_B" or "B_to_A".

    Returns
    -------
    PIL.Image.Image
        The translated image.

    Raises
    ------
    gr.Error
        Shown in the UI when inference fails. (Bug fix: the previous
        version returned the error text as a plain string, but the output
        component is a ``gr.Image``, which cannot render text — users saw
        a broken image instead of the message.)
    """
    try:
        return cyclegan_model.transform_image(input_image, direction)
    except Exception as e:
        # gr.Error surfaces the message as a visible alert in the interface.
        raise gr.Error(f"Error: {e}") from e
97
+
98
# Create Gradio interface: two-column layout, input + direction on the left,
# generated result on the right.
with gr.Blocks(title="CycleGAN Image Translation") as demo:
    gr.Markdown("# CycleGAN Image Translation")
    gr.Markdown("Upload an image and select the transformation direction.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            direction = gr.Dropdown(
                choices=["A_to_B", "B_to_A"],
                value="A_to_B",
                label="Translation Direction"
            )
            generate_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            output_image = gr.Image(type="pil", label="Generated Image")

    # Wire the button to the inference callback defined above.
    generate_btn.click(
        fn=generate_image,
        inputs=[input_image, direction],
        outputs=output_image
    )

    # Add some examples if you have them
    # gr.Examples(
    #     examples=[
    #         # Add paths to example images here
    #         # ["example1.jpg", "A_to_B"],
    #         # ["example2.jpg", "B_to_A"],
    #     ],
    #     inputs=[input_image, direction],
    #     outputs=output_image,
    #     fn=generate_image,
    # )

if __name__ == "__main__":
    demo.launch()
gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
networks.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import init
4
+ import functools
5
+ from torch.optim import lr_scheduler
6
+ import torch.nn.functional as F
7
+ from torch import nn, einsum
8
+ from einops import rearrange, reduce, repeat
9
+
10
+ ###############################################################################
11
+ # Helper Functions
12
+ ###############################################################################
13
+
14
+
15
def get_norm_layer(norm_type='instance'):
    """Return a normalization-layer constructor for the requested type.

    'batch'    -> BatchNorm2d with learnable affine parameters
    'instance' -> InstanceNorm2d without affine parameters or running stats
    'none'     -> None (caller skips normalization)

    Raises NotImplementedError for any other ``norm_type``.
    """
    factories = {
        'batch': functools.partial(nn.BatchNorm2d, affine=True),
        'instance': functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False),
        'none': None,
    }
    if norm_type not in factories:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return factories[norm_type]
25
+
26
+
27
def get_scheduler(optimizer, opt):
    """Build a learning-rate scheduler for ``optimizer`` from training options.

    ``opt`` must provide ``lr_policy`` and, depending on the policy:
      'lambda'  -> epoch_count, niter, niter_decay (constant, then linear decay to 0)
      'step'    -> lr_decay_iters (decay by 10x every lr_decay_iters epochs)
      'plateau' -> no extra fields (reduce on stalled validation metric)
      'cosine'  -> niter (cosine annealing to 0)

    Raises NotImplementedError for an unknown policy.
    """
    if opt.lr_policy == 'lambda':
        def lambda_rule(epoch):
            # Full LR for the first `niter` epochs, then linear decay to zero
            # over the following `niter_decay` epochs.
            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif opt.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
    else:
        # Bug fix: the exception instance was previously *returned*, not
        # raised, so an unknown policy silently handed the caller an
        # exception object in place of a scheduler.
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
42
+
43
+
44
def init_weights(net, init_type='normal', gain=0.02):
    """Initialize all Conv/Linear weights of ``net`` in place.

    Supported schemes: 'normal', 'xavier', 'kaiming', 'orthogonal'.
    Conv/Linear biases are zeroed; BatchNorm2d layers always get
    N(1, gain) weights and zero bias, regardless of the scheme.
    Raises NotImplementedError for an unknown ``init_type``.
    """
    def _apply(module):
        name = module.__class__.__name__
        targets_weight = hasattr(module, 'weight') and ('Conv' in name or 'Linear' in name)
        if targets_weight:
            if init_type == 'normal':
                init.normal_(module.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(module.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(module.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(module.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if getattr(module, 'bias', None) is not None:
                init.constant_(module.bias.data, 0.0)
        elif 'BatchNorm2d' in name:
            # Norm layers start at identity scale with small jitter.
            init.normal_(module.weight.data, 1.0, gain)
            init.constant_(module.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(_apply)
66
+
67
+
68
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Place ``net`` on the requested GPUs and initialize its weights.

    When ``gpu_ids`` is non-empty the network is moved to the first GPU and
    wrapped in DataParallel (CUDA must be available). Weights are then
    initialized via ``init_weights``. Returns the (possibly wrapped) net.

    NOTE: the mutable default is kept for caller compatibility; it is never
    mutated here.
    """
    if gpu_ids:
        assert torch.cuda.is_available()
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)
    init_weights(net, init_type, gain=init_gain)
    return net
75
+
76
+
77
+
78
+
79
+ ##############################################################################
80
+ # Classes
81
+ ##############################################################################
82
+
83
+
84
+ # Defines the GAN loss which uses either LSGAN or the regular GAN.
85
+ # When LSGAN is used, it is basically same as MSELoss,
86
+ # but it abstracts away the need to create the target label tensor
87
+ # that has the same size as the input
88
class GANLoss(nn.Module):
    """GAN objective that builds real/fake target tensors automatically.

    With ``use_lsgan=True`` this is MSE against the labels (LSGAN); otherwise
    it is BCE, which expects discriminator outputs already in (0, 1).
    The labels are registered as buffers so they follow ``.to(device)``.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.loss = nn.MSELoss() if use_lsgan else nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return the real/fake label broadcast to ``input``'s shape."""
        label = self.real_label if target_is_real else self.fake_label
        return label.expand_as(input)

    def __call__(self, input, target_is_real):
        """Compute the loss of ``input`` against the appropriate label."""
        return self.loss(input, self.get_target_tensor(input, target_is_real))
107
+
108
#################################################################################
#                  Critic Loss for Wasserstein GAN-GP                           #
#################################################################################
class GradPenalty(nn.Module):
    """Gradient penalty term for WGAN-GP.

    Penalizes the critic when the L2 norm of its gradient, taken at random
    points interpolated between real and fake samples, deviates from 1.
    """

    def __init__(self, use_cuda):
        super(GradPenalty, self).__init__()
        # Whether intermediate tensors should be placed on the GPU.
        self.use_cuda = use_cuda

    def forward(self, critic, real_data, fake_data):
        """Return the scalar gradient penalty for one batch.

        critic    : callable mapping a batch of samples to critic scores
        real_data : batch of real samples
        fake_data : batch of generated samples (detached before mixing)
        """
        to_dev = lambda x: x.cuda() if self.use_cuda else x

        # Element-wise random mixing coefficients in [0, 1).
        alpha = to_dev(torch.rand_like(real_data))

        # Random points on the segments between real and fake samples.
        interpolates = alpha * real_data + (1 - alpha) * fake_data.detach()
        # Fix: torch.autograd.Variable is deprecated; make a fresh leaf that
        # requires grad instead.
        interpolates = to_dev(interpolates).detach().requires_grad_(True)

        critic_interpolates = critic(interpolates)

        gradients = torch.autograd.grad(
            outputs=critic_interpolates,
            inputs=interpolates,
            # ones_like keeps device and dtype consistent with the critic output
            grad_outputs=to_dev(torch.ones_like(critic_interpolates)),
            create_graph=True, retain_graph=True, only_inputs=True
        )[0]
        # Penalty is ((||grad||_2 - 1)^2), averaged over the batch.
        gradients = gradients.view(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty
136
+
137
+ #####
138
+ #####
139
+
140
+ #################################################################################
141
+ # Hybrid Perception Block and DPSA Layer                                       #
142
+ #################################################################################
143
+
144
+
145
# helper functions

def exists(val):
    """True when ``val`` is not None."""
    return val is not None

def default(val, d):
    """Return ``val`` unless it is None, in which case return fallback ``d``."""
    return d if val is None else val

def l2norm(t):
    """L2-normalize ``t`` along its last dimension."""
    return F.normalize(t, dim=-1)
155
+
156
+ # helper classes
157
+
158
class Residual(nn.Module):
    """Skip connection around an arbitrary callable: y = fn(x) + x."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # Extra keyword arguments are forwarded to the wrapped callable.
        out = self.fn(x, **kwargs)
        return out + x
165
+
166
class ChanLayerNorm(nn.Module):
    """Layer normalization over the channel dimension of NCHW tensors.

    Each spatial position is normalized across its channels using the biased
    variance, then rescaled by learnable per-channel gain ``g`` and bias ``b``.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        mu = x.mean(dim=1, keepdim=True)
        sigma2 = x.var(dim=1, unbiased=False, keepdim=True)
        normed = (x - mu) / (sigma2 + self.eps).sqrt()
        return normed * self.g + self.b
177
+
178
+ # classes
179
+
180
class HPB(nn.Module):
    """ Hybrid Perception Block

    Runs a dual-pruned self-attention branch and a depthwise-conv branch in
    parallel over the same input, fuses them with a 1x1 conv plus a residual
    connection, then applies an inverted-bottleneck feed-forward network.
    """

    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 8,
        ff_mult = 4,
        attn_height_top_k = 8,
        attn_width_top_k = 8,
        attn_dropout = 0.,
        ff_dropout = 0.
    ):
        super().__init__()

        # Attention branch: dual-pruned self-attention.
        self.attn = DPSA(
            dim = dim,
            heads = heads,
            dim_head = dim_head,
            height_top_k = attn_height_top_k,
            width_top_k = attn_width_top_k,
            dropout = attn_dropout
        )

        # Convolution branch: depthwise 3x3 over the same input.
        self.dwconv = nn.Conv2d(dim, dim, 3, padding = 1, groups = dim)
        self.attn_parallel_combine_out = nn.Conv2d(dim * 2, dim, 1)

        ff_inner_dim = dim * ff_mult

        # Feed-forward: expand -> depthwise residual -> project back to dim.
        self.ff = nn.Sequential(
            nn.Conv2d(dim, ff_inner_dim, 1),
            nn.InstanceNorm2d(ff_inner_dim),
            nn.GELU(),
            nn.Dropout(ff_dropout),
            Residual(nn.Sequential(
                nn.Conv2d(ff_inner_dim, ff_inner_dim, 3, padding = 1, groups = ff_inner_dim),
                nn.InstanceNorm2d(ff_inner_dim),
                nn.GELU(),
                nn.Dropout(ff_dropout)
            )),
            nn.Conv2d(ff_inner_dim, dim, 1),
            # Bug fix: this norm follows the projection back to `dim` channels
            # but was constructed with `ff_inner_dim` features. Non-affine
            # InstanceNorm2d ignores num_features at run time (warning only),
            # so correcting it is behavior-preserving.
            nn.InstanceNorm2d(dim)
        )

    def forward(self, x):
        """Fuse attention and conv branches (residual), then feed-forward."""
        attn_branch_out = self.attn(x)
        conv_branch_out = self.dwconv(x)

        concatted_branches = torch.cat((attn_branch_out, conv_branch_out), dim = 1)
        attn_out = self.attn_parallel_combine_out(concatted_branches) + x

        return self.ff(attn_out)
233
+
234
class DPSA(nn.Module):
    """ Dual-pruned Self-attention Block

    Cosine-similarity multi-head self-attention over NCHW feature maps that
    prunes keys/values to the top-k rows and top-k columns (ranked by a
    summed query probe) before computing attention, reducing the attention
    cost on large spatial grids.
    """

    def __init__(
        self,
        dim,
        height_top_k = 8,
        width_top_k = 8,
        dim_head = 32,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.dim_head = dim_head
        # NOTE(review): self.scale is defined but unused in forward — the
        # l2-normalized (cosine) attention below has no temperature applied.
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head

        self.norm = ChanLayerNorm(dim)
        # Single 1x1 conv producing q, k, v stacked along channels.
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)

        self.height_top_k = height_top_k
        self.width_top_k = width_top_k

        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, x):
        # x: (batch, dim, height, width)
        b, c, h, w = x.shape

        x = self.norm(x)

        q, k, v = self.to_qkv(x).chunk(3, dim = 1)

        # fold out heads: (b, heads*d, h, w) -> (b*heads, d, h, w)

        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) c x y', h = self.heads), (q, k, v))

        # they used l2 normalized queries and keys, cosine sim attention basically

        q, k = map(l2norm, (q, k))

        # calculate whether to select and rank along height and width
        # (pruning only happens when the grid is larger than top_k)

        need_height_select_and_rank = self.height_top_k < h
        need_width_select_and_rank = self.width_top_k < w

        # select and rank keys / values, probing with query (reduced along height and width) and keys reduced along row and column respectively

        if need_width_select_and_rank or need_height_select_and_rank:
            # q_probe: per-head query summed over all spatial positions
            q_probe = reduce(q, 'b h w d -> b d', 'sum')

            # gather along height, then width

            if need_height_select_and_rank:
                # Score each row by probe·(row-summed keys); keep top rows.
                k_height = reduce(k, 'b h w d -> b h d', 'sum')

                top_h_indices = einsum('b d, b h d -> b h', q_probe, k_height).topk(k = self.height_top_k, dim = -1).indices

                top_h_indices = repeat(top_h_indices, 'b h -> b h w d', d = self.dim_head, w = k.shape[-2])

                k, v = map(lambda t: t.gather(1, top_h_indices), (k, v)) # first gather across height

            if need_width_select_and_rank:
                # Score each column by probe·(column-summed keys); keep top cols.
                k_width = reduce(k, 'b h w d -> b w d', 'sum')

                top_w_indices = einsum('b d, b w d -> b w', q_probe, k_width).topk(k = self.width_top_k, dim = -1).indices

                top_w_indices = repeat(top_w_indices, 'b w -> b h w d', d = self.dim_head, h = k.shape[1])

                k, v = map(lambda t: t.gather(2, top_w_indices), (k, v)) # then gather along width

        # select the appropriate keys and values
        # flatten the spatial grid: (b, h, w, d) -> (b, h*w, d)

        q, k, v = map(lambda t: rearrange(t, 'b ... d -> b (...) d'), (q, k, v))

        # cosine similarities

        sim = einsum('b i d, b j d -> b i j', q, k)

        # attention

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        # aggregate out

        out = einsum('b i j, b j d -> b i d', attn, v)

        # merge heads and combine out

        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', x = h, y = w, h = self.heads)
        return self.to_out(out)
327
+
328
+ #####
329
+ #####
330
+
331
+ # New HybridPerceptionBlockGenerator
332
+
333
class HPBGenerator(nn.Module):
    """Encoder/decoder image generator whose bottleneck is built from
    Hybrid Perception Blocks.

    Layout: 7x7 conv stem -> two stride-2 downsampling convs -> ``n_blocks``
    HPB blocks -> two stride-2 transposed convs -> 7x7 conv head with Tanh.
    The nn.Sequential module ordering matches the original implementation,
    so existing state dicts remain loadable.
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.InstanceNorm2d, use_dropout=False, n_blocks=9, padding_type='reflect'):
        assert n_blocks >= 0
        super(HPBGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf

        # InstanceNorm carries no learnable shift by default, so the convs
        # keep their own bias in that case.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        # Stem: reflection-padded 7x7 conv from input_nc to ngf channels.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
            norm_layer(ngf),
            nn.GELU(),
        ]

        # Two downsampling stages: ngf -> 2*ngf -> 4*ngf, each halving H and W.
        n_downsampling = 2
        for stage in range(n_downsampling):
            in_ch = ngf * (2 ** stage)
            layers += [
                nn.Conv2d(in_ch, in_ch * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                norm_layer(in_ch * 2),
                nn.GELU(),
            ]

        # Bottleneck: n_blocks HPBs at 4*ngf channels (ngf is the head dim).
        bottleneck_ch = ngf * (2 ** n_downsampling)
        layers += [HPB(bottleneck_ch, ngf) for _ in range(n_blocks)]

        # Mirror-image upsampling back down to ngf channels.
        for stage in range(n_downsampling):
            in_ch = ngf * (2 ** (n_downsampling - stage))
            layers += [
                nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=3, stride=2,
                                   padding=1, output_padding=1, bias=use_bias),
                norm_layer(in_ch // 2),
                nn.GELU(),
            ]

        # Head: reflection-padded 7x7 conv to output_nc, squashed to [-1, 1].
        layers += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Run the generator on an NCHW batch; output is in [-1, 1] (Tanh)."""
        return self.model(input)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ Pillow
5
+ numpy
6
+ huggingface_hub
7
+ einops