Spaces:
Configuration error
Configuration error
| from torch import nn | |
| from visualizr.networks.encoder import Encoder | |
| from visualizr.networks.styledecoder import Synthesis | |
| class Generator(nn.Module): | |
| def __init__( | |
| self, | |
| size, | |
| style_dim=512, | |
| motion_dim=20, | |
| channel_multiplier=1, | |
| blur_kernel=[1, 3, 3, 1], | |
| ): | |
| super(Generator, self).__init__() | |
| # encoder | |
| self.enc = Encoder(size, style_dim, motion_dim) | |
| self.dec = Synthesis( | |
| size, style_dim, motion_dim, blur_kernel, channel_multiplier | |
| ) | |
| def get_direction(self): | |
| return self.dec.direction(None) | |
| def synthesis(self, wa, alpha, feat): | |
| img = self.dec(wa, alpha, feat) | |
| return img | |
| def forward(self, img_source, img_drive, h_start=None): | |
| wa, alpha, feats = self.enc(img_source, img_drive, h_start) | |
| # import pdb;pdb.set_trace() | |
| img_recon = self.dec(wa, alpha, feats) | |
| return img_recon | |