SeaSky1027 committed on
Commit
cbf69ea
·
verified ·
1 Parent(s): 7ffe09b

Update gan_model.py

Browse files
Files changed (1) hide show
  1. gan_model.py +0 -104
gan_model.py CHANGED
@@ -1,5 +1,3 @@
1
- import logging
2
-
3
  import torch
4
  from torch import nn
5
  from torch.nn import init
@@ -216,105 +214,3 @@ class Generator(nn.Module):
216
  out = self.colorize(out) # [bsz, 1, 80, 344]
217
 
218
  return out
219
-
220
-
221
- class Discriminator(nn.Module):
222
- def __init__(self, in_channel=1, channel=32, num_classes=7, embedding_dim=128):
223
- super().__init__()
224
- self.num_classes = num_classes
225
-
226
- def conv(in_channel, out_channel, downsample=True):
227
- return ConvBlock(in_channel, out_channel,
228
- bn=False,
229
- upsample=False, downsample=downsample)
230
-
231
- gain = 2 ** 0.5
232
-
233
- self.pre_conv = nn.Sequential(spectral_init(nn.Conv2d(in_channel, channel, 3,
234
- padding=1),
235
- gain=gain),
236
- nn.ReLU(),
237
- spectral_init(nn.Conv2d(channel, channel, 3,
238
- padding=1),
239
- gain=gain),
240
- nn.AvgPool2d(2))
241
- self.pre_skip = spectral_init(nn.Conv2d(in_channel, channel, 1))
242
-
243
- self.conv1 = conv(channel, channel * 2)
244
- self.conv2 = conv(channel * 2, channel * 2, downsample=False)
245
- self.attention = SelfAttention(channel * 2)
246
- self.conv3 = conv(channel * 2, channel * 4)
247
- self.conv4 = conv(channel * 4, channel * 4)
248
-
249
- self.linear = spectral_init(nn.Linear(channel * 4, 1))
250
-
251
- self.projection = nn.Sequential(
252
- spectral_init(nn.Linear(channel * 4, channel * 4)),
253
- nn.ReLU(),
254
- spectral_init(nn.Linear(channel * 4, channel * 4))
255
- )
256
-
257
- self.embedding = spectral_norm(nn.Embedding(num_embeddings=num_classes, embedding_dim=channel * 4))
258
-
259
- def forward(self, input, label):
260
- out = self.pre_conv(input) # [bsz, 32, 40, 172]
261
- out = out + self.pre_skip(F.avg_pool2d(input, 2)) # [bsz, 32, 40, 172]
262
-
263
- out = self.conv1(out) # [bsz, 64, 20, 86]
264
- out = self.conv2(out) # [bsz, 64, 20, 86]
265
- out, attention_map = self.attention(out) # [bsz, 64, 20, 86]
266
- out = self.conv3(out) # [bsz, 128, 10, 43]
267
- out = self.conv4(out) # [bsz, 128, 5, 21]
268
-
269
- out = F.relu(out)
270
- out = out.view(out.size(0), out.size(1), -1) # [bsz, 128, 105]
271
- out = out.sum(2) # [bsz, 128]
272
- adv_output = self.linear(out).squeeze(1) # [bsz, 1]
273
-
274
- condition = self.embedding(label) # [bsz, 128]
275
- prod = (out * condition).sum(1) # [bsz, 1]
276
- adv_output += prod
277
-
278
- contrastive_feature = self.projection(out) # [bsz, 128]
279
-
280
- return adv_output, contrastive_feature, condition
281
-
282
- def count_parameters(module):
283
- num_params = sum(p.numel() for p in module.parameters())
284
- return num_params
285
-
286
- if __name__ == '__main__':
287
- from HiFiGanWrapper import HiFiGanWrapper
288
- import numpy as np
289
-
290
- generator = Generator().eval()
291
- num_params = count_parameters(generator)
292
- print(f"Number of generator parameters: {num_params / 1000000:.2f} M")
293
- print()
294
-
295
- discriminator = Discriminator().eval()
296
- num_params = count_parameters(discriminator)
297
- print(f"Number of discriminator parameters: {num_params / 1000000:.2f} M")
298
- print()
299
-
300
- vocoder = HiFiGanWrapper(ckpt_path='./pretrained_checkpoints')
301
- num_params = count_parameters(vocoder.generator)
302
- print(f"Number of vocoder parameters: {num_params / 1000000:.2f} M")
303
- print()
304
-
305
- image = torch.randn(4, 1, 80, 344)
306
- labels = torch.LongTensor([0, 0, 1, 2])
307
-
308
- out, contrastive_feature, proxy = discriminator(image, labels)
309
- print('discriminator :', out.shape)
310
- print('contrastive_feature :', contrastive_feature.shape)
311
- print('proxy :', proxy.shape)
312
- print()
313
-
314
- out = generator(labels)
315
- print('generator :', out.shape)
316
- print()
317
-
318
- fake_sound = vocoder.generate_audio(out[0])
319
- fake_sound = np.concatenate((fake_sound, fake_sound[-136:]), axis=0)
320
- print('generated sound :', fake_sound.shape)
 
 
 
1
  import torch
2
  from torch import nn
3
  from torch.nn import init
 
214
  out = self.colorize(out) # [bsz, 1, 80, 344]
215
 
216
  return out