Spaces:
Sleeping
Sleeping
| import math | |
| import torch | |
| from torch import nn | |
| from torch.nn import Module | |
| import torch.nn.functional as F | |
| from models.psp.stylegan2.model import EqualLinear, PixelNorm | |
class Mapper(Module):
    """Stack of equalized-learning-rate linear layers, optionally pixel-normalized first.

    The stack is ``PixelNorm`` (only when ``norm`` is true) followed by
    ``num_layers`` ``EqualLinear`` layers. The first linear layer maps
    ``in_channel`` -> ``out_channel``. Every later one maps
    ``out_channel`` -> ``out_channel``. All linear layers use ``lr_mul=0.01``
    and the ``'fused_lrelu'`` activation. A ``num_layers`` below 1 still yields
    exactly one linear layer.

    Args:
        in_channel: Width of the input features.
        out_channel: Output width of every linear layer.
        norm: Whether to prepend ``PixelNorm`` to the stack.
        num_layers: Number of ``EqualLinear`` layers.

    NOTE(review): ``PixelNorm`` and ``EqualLinear`` are defined in
    ``models.psp.stylegan2.model``. Which axis they normalize or treat as the
    channel axis for 3-D ``(batch, rows, channels)`` inputs is decided there,
    so confirm before relying on that behaviour.
    """

    def __init__(self, in_channel=512, out_channel=512, norm=True, num_layers=4):
        super().__init__()
        # Input width of each linear layer, in build order. Only the first
        # differs. range() with a non-positive count is empty, so the first
        # layer is always built.
        fan_ins = [in_channel] + [out_channel for _ in range(num_layers - 1)]
        modules = [PixelNorm()] if norm else []
        modules.extend(
            EqualLinear(fan_in, out_channel, lr_mul=0.01, activation='fused_lrelu')
            for fan_in in fan_ins
        )
        # The attribute name ``mapping`` and the module order determine
        # state_dict keys, so they must not change.
        self.mapping = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the mapping stack to ``x`` and return the result."""
        return self.mapping(x)
class DeltaMapper(Module):
    """Fuse a flat S-space feature vector with a conditioning vector.

    The flat input ``sspace_feat`` is split into three groups, each handled by
    its own branch:

    * coarse: the first ``3 * 512`` channels, viewed as 3 rows of 512;
    * medium: the next ``4 * 512`` channels, viewed as 4 rows of 512;
    * fine: the remaining ``2464`` channels, kept flat.

    The branches are built as follows:

    * Each group is passed through a style mapper (``sm_*``).
    * ``clip_feat`` is passed through one condition mapper per group
      (``cm_*``) and repeated across the rows of that group.
    * Style and condition features are concatenated on the channel axis and
      passed through a fusion mapper (``fm_*``).
    * The three fused groups are flattened and concatenated back into the
      same layout as ``sspace_feat``.

    NOTE(review): the names suggest ``sspace_feat`` holds StyleGAN2 S-space
    style channels, ``clip_feat`` a CLIP-derived vector, and the output an
    edit ("delta") in S-space. Confirm this against the callers.
    """

    # Layout of the flat S-space vector (input and output).
    COARSE_ROWS = 3
    MEDIUM_ROWS = 4
    ROW_CHANNELS = 512
    FINE_CHANNELS = 2464
    SSPACE_CHANNELS = (COARSE_ROWS + MEDIUM_ROWS) * ROW_CHANNELS + FINE_CHANNELS  # 6048
    # Width of the conditioning vector consumed by the cm_* mappers.
    CLIP_CHANNELS = 1024

    def __init__(self):
        super().__init__()
        row = self.ROW_CHANNELS
        fine = self.FINE_CHANNELS
        cond = self.CLIP_CHANNELS
        # Submodule names, arguments and construction order are the same as
        # before, so existing checkpoints load and seeded init is unchanged.
        # Style Module (sm): one mapper per S-space group.
        self.sm_coarse = Mapper(row, row)
        self.sm_medium = Mapper(row, row)
        self.sm_fine = Mapper(fine, fine)
        # Condition Module (cm): maps the conditioning vector to each group's width.
        self.cm_coarse = Mapper(cond, row)
        self.cm_medium = Mapper(cond, row)
        self.cm_fine = Mapper(cond, fine)
        # Fusion Module (fm): consumes [style, condition] concatenated on the
        # channel axis, without a leading PixelNorm (norm=False).
        self.fm_coarse = Mapper(row * 2, row, norm=False)
        self.fm_medium = Mapper(row * 2, row, norm=False)
        self.fm_fine = Mapper(fine * 2, fine, norm=False)

    def forward(self, sspace_feat, clip_feat):
        """Fuse ``sspace_feat`` with ``clip_feat`` and return the result.

        Args:
            sspace_feat: Tensor of shape ``(batch, 6048)``. It holds 3*512
                coarse channels, then 4*512 medium channels, then 2464 fine
                channels.
            clip_feat: Tensor of shape ``(batch, 1024)``.

        Returns:
            Tensor of shape ``(batch, 6048)`` in the same coarse/medium/fine
            layout as ``sspace_feat``.

        Raises:
            ValueError: If either input does not have the shape described above.
        """
        if sspace_feat.dim() != 2 or sspace_feat.size(1) != self.SSPACE_CHANNELS:
            raise ValueError(
                f"sspace_feat must have shape (batch, {self.SSPACE_CHANNELS}), "
                f"got {tuple(sspace_feat.shape)}"
            )
        batch = sspace_feat.size(0)
        if tuple(clip_feat.shape) != (batch, self.CLIP_CHANNELS):
            raise ValueError(
                f"clip_feat must have shape ({batch}, {self.CLIP_CHANNELS}), "
                f"got {tuple(clip_feat.shape)}"
            )

        rows_c, rows_m = self.COARSE_ROWS, self.MEDIUM_ROWS
        width = self.ROW_CHANNELS
        coarse_end = rows_c * width
        medium_end = coarse_end + rows_m * width

        # Use the explicit batch size rather than -1: view(-1, ...) is
        # ambiguous and raises for an empty (batch == 0) tensor.
        s_coarse = sspace_feat[:, :coarse_end].view(batch, rows_c, width)
        s_medium = sspace_feat[:, coarse_end:medium_end].view(batch, rows_m, width)
        s_fine = sspace_feat[:, medium_end:]

        # NOTE(review): the coarse/medium mappers receive 3-D (batch, rows, 512)
        # tensors. How PixelNorm/EqualLinear treat the axes of 3-D input is
        # defined in models.psp.stylegan2.model, so confirm it is intended.
        s_coarse = self.sm_coarse(s_coarse)
        s_medium = self.sm_medium(s_medium)
        s_fine = self.sm_fine(s_fine)

        c_coarse = self.cm_coarse(clip_feat)
        c_medium = self.cm_medium(clip_feat)
        c_fine = self.cm_fine(clip_feat)

        # Repeat the per-sample condition over every row:
        # (batch, 512) -> (batch, rows, 512), then concat -> (batch, rows, 1024).
        x_coarse = torch.cat([s_coarse, c_coarse.unsqueeze(1).expand(-1, rows_c, -1)], dim=2)
        x_medium = torch.cat([s_medium, c_medium.unsqueeze(1).expand(-1, rows_m, -1)], dim=2)
        x_fine = torch.cat([s_fine, c_fine], dim=1)  # (batch, 2464 * 2)

        x_coarse = self.fm_coarse(x_coarse).reshape(batch, coarse_end)
        x_medium = self.fm_medium(x_medium).reshape(batch, medium_end - coarse_end)
        x_fine = self.fm_fine(x_fine)
        return torch.cat([x_coarse, x_medium, x_fine], dim=1)