# TAESD-style tiny autoencoder: a small convolutional encoder/decoder pair
# that maps RGB images to and from a low-channel latent space.
import torch as t
import torch.nn as nn

def cv(n_i, n_o, **kw): return nn.Conv2d(n_i, n_o, 3, padding=1, **kw)  # 3x3 conv, preserves spatial size

class C(nn.Module):
    # Soft clamp: squashes activations into roughly [-3, 3] while keeping gradients smooth.
    def forward(self, x): return t.tanh(x / 3) * 3

class B(nn.Module):
    # Residual block: three 3x3 convs plus a 1x1 (or identity) skip connection.
    def __init__(s, n_i, n_o):
        super().__init__()
        s.c = nn.Sequential(cv(n_i, n_o), nn.ReLU(), cv(n_o, n_o), nn.ReLU(), cv(n_o, n_o))
        s.s = nn.Conv2d(n_i, n_o, 1, bias=False) if n_i != n_o else nn.Identity()
        s.f = nn.ReLU()
    def forward(s, x): return s.f(s.c(x) + s.s(x))

def E(lc=4):
    # Encoder: three stride-2 convs give 8x spatial downsampling into `lc` latent channels.
    return nn.Sequential(
        cv(3, 64), B(64, 64), cv(64, 64, stride=2, bias=False), B(64, 64), B(64, 64), B(64, 64),
        cv(64, 64, stride=2, bias=False), B(64, 64), B(64, 64), B(64, 64),
        cv(64, 64, stride=2, bias=False), B(64, 64), B(64, 64), B(64, 64),
        cv(64, lc),
    )

def D(lc=16):
    # Decoder: soft-clamp the incoming latents, then upsample 2x three times back to RGB.
    return nn.Sequential(
        C(), cv(lc, 48), nn.ReLU(), B(48, 48), B(48, 48), nn.Upsample(scale_factor=2),
        cv(48, 48, bias=False), B(48, 48), B(48, 48), nn.Upsample(scale_factor=2),
        cv(48, 48, bias=False), B(48, 48), nn.Upsample(scale_factor=2),
        cv(48, 48, bias=False), B(48, 48), cv(48, 3),
    )
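
# Shape note (follows from the code above): E's three stride-2 convs and D's
# three 2x upsamples mirror each other, so a (B, 3, H, W) image maps to
# (B, lc, H/8, W/8) latents and back.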

class M(nn.Module):
    # Full autoencoder. lm/ls are the latent magnitude/shift used to squeeze
    # raw latents into [0, 1] (see sl/ul below).
    lm, ls = 3, 0.5
    def __init__(s, ep="encoder.pth", dp="decoder.pth", lc=None):
        super().__init__()
        if lc is None: lc = s.glc(str(ep))  # guess the latent channel count from the checkpoint name
        s.e, s.d = E(lc), D(lc)

        def f(sd, mod, pfx):
            # Keep only checkpoint tensors whose names (with the `pfx` prefix
            # removed, if present) and shapes match this module, then load leniently.
            msd = mod.state_dict()
            f_sd = {k.removeprefix(pfx): v for k, v in sd.items()
                    if k.removeprefix(pfx) in msd and v.size() == msd[k.removeprefix(pfx)].size()}
            print(f"num keys: {len(f_sd)} of {len(msd)}")
            mod.load_state_dict(f_sd, strict=False)

        # Load pretrained weights for each half, if checkpoint paths were given.
        if ep: f(t.load(ep, map_location="cpu", weights_only=True), s.e, "encoder.")
        if dp: f(t.load(dp, map_location="cpu", weights_only=True), s.d, "decoder.")

        s.e.requires_grad_(False)  # freeze both halves: this model is inference-only
        s.d.requires_grad_(False)

    def glc(s, ep): return 16 if "taef1" in ep or "taesd3" in ep else 4  # FLUX.1/SD3 use 16-channel latents; SD1.x/SDXL use 4

    @staticmethod
    def sl(x): return x.div(2 * M.lm).add(M.ls).clamp(0, 1)  # raw latents -> [0, 1]

    @staticmethod
    def ul(x): return x.sub(M.ls).mul(2 * M.lm)  # [0, 1] -> raw latents

    def forward(s, x, rl=False):
        # Encode once, decode the latents, and optionally return them alongside the output.
        l = s.e(x)
        o = s.d(l).clamp(0, 1)
        return (o, l) if rl else o
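

# Minimal usage sketch, assuming `encoder.pth` and `decoder.pth` checkpoints
# exist next to this file (the filenames and the 256x256 input are assumptions
# for illustration; any 3-channel batch with H and W divisible by 8 works):
if __name__ == "__main__":
    m = M()                         # build, load checkpoints, freeze both halves
    x = t.rand(1, 3, 256, 256)      # stand-in for a real image batch in [0, 1]
    with t.no_grad():
        y, l = m(x, rl=True)        # reconstruction and raw latents
    print(tuple(x.shape), "->", tuple(l.shape), "->", tuple(y.shape))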