File size: 2,396 Bytes
95b1715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import math

import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F

from models.psp.stylegan2.model import EqualLinear, PixelNorm

class Mapper(Module):

    def __init__(self, in_channel=512, out_channel=512, norm=True, num_layers=4):
        super(Mapper, self).__init__()

        layers = [PixelNorm()] if norm else []
        
        layers.append(EqualLinear(in_channel, out_channel, lr_mul=0.01, activation='fused_lrelu'))
        for _ in range(num_layers-1):
            layers.append(EqualLinear(out_channel, out_channel, lr_mul=0.01, activation='fused_lrelu'))
        self.mapping = nn.Sequential(*layers)

    def forward(self, x):
        x = self.mapping(x)
        return x

class DeltaMapper(Module):

    def __init__(self):
        super(DeltaMapper, self).__init__()

        #Style Module(sm)
        self.sm_coarse = Mapper(512,  512)
        self.sm_medium = Mapper(512,  512)
        self.sm_fine   = Mapper(2464, 2464)

        #Condition Module(cm)
        self.cm_coarse = Mapper(1024, 512)
        self.cm_medium = Mapper(1024, 512)
        self.cm_fine   = Mapper(1024, 2464)

        #Fusion Module(fm)
        self.fm_coarse = Mapper(512*2,  512,  norm=False)
        self.fm_medium = Mapper(512*2,  512,  norm=False)
        self.fm_fine   = Mapper(2464*2, 2464, norm=False)
        
    def forward(self, sspace_feat, clip_feat):

        s_coarse = sspace_feat[:, :3*512].view(-1,3,512)
        s_medium = sspace_feat[:, 3*512:7*512].view(-1,4,512)
        s_fine   = sspace_feat[:, 7*512:] #channels:2464

        s_coarse = self.sm_coarse(s_coarse)
        s_medium = self.sm_medium(s_medium)
        s_fine   = self.sm_fine(s_fine)

        c_coarse = self.cm_coarse(clip_feat)
        c_medium = self.cm_medium(clip_feat)
        c_fine   = self.cm_fine(clip_feat)

        x_coarse = torch.cat([s_coarse, torch.stack([c_coarse]*3, dim=1)], dim=2) #[b,3,1024]
        x_medium = torch.cat([s_medium, torch.stack([c_medium]*4, dim=1)], dim=2) #[b,4,1024]
        x_fine   = torch.cat([s_fine, c_fine], dim=1) #[b,2464*2]

        x_coarse = self.fm_coarse(x_coarse)
        x_coarse = x_coarse.view(-1,3*512)

        x_medium = self.fm_medium(x_medium)
        x_medium = x_medium.view(-1,4*512)

        x_fine   = self.fm_fine(x_fine)

        out = torch.cat([x_coarse, x_medium, x_fine], dim=1)
        return out