cfm_svc / vits_decoder /discriminator.py
Hector Li
Initial commit for Hugging Face
df93d13
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from .msd import ScaleDiscriminator
from .mpd import MultiPeriodDiscriminator
from .mrd import MultiResolutionDiscriminator
class Discriminator(nn.Module):
def __init__(self, hp):
super(Discriminator, self).__init__()
self.MRD = MultiResolutionDiscriminator(hp)
self.MPD = MultiPeriodDiscriminator(hp)
self.MSD = ScaleDiscriminator()
def forward(self, x):
r = self.MRD(x)
p = self.MPD(x)
s = self.MSD(x)
return r + p + s
if __name__ == '__main__':
hp = OmegaConf.load('../config/base.yaml')
model = Discriminator(hp)
x = torch.randn(3, 1, 16384)
print(x.shape)
output = model(x)
for features, score in output:
for feat in features:
print(feat.shape)
print(score.shape)
pytorch_total_params = sum(p.numel()
for p in model.parameters() if p.requires_grad)
print(pytorch_total_params)