File size: 2,286 Bytes
9ba7979 0f5ff39 9ba7979 0f5ff39 9ba7979 0f5ff39 9ba7979 d0b38dd 9ba7979 0f5ff39 9ba7979 0f5ff39 9ba7979 0f5ff39 9ba7979 0f5ff39 9ba7979 0f5ff39 9ba7979 0f5ff39 9ba7979 0f5ff39 9ba7979 0f5ff39 d0b38dd 9ba7979 0f5ff39 d0b38dd 9ba7979 0f5ff39 9ba7979 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 | import torch as t, torch.nn as nn, torch.nn.functional as F
def cv(n_i, n_o, **kw):
    """Return a 3x3 ``nn.Conv2d`` from ``n_i`` to ``n_o`` channels with padding 1.

    With the default stride and dilation the spatial size is preserved. Extra
    keyword arguments (e.g. ``stride``, ``bias``) are forwarded to ``nn.Conv2d``.
    """
    return nn.Conv2d(n_i, n_o, kernel_size=3, padding=1, **kw)
class C(nn.Module):
    """Soft clamp: smoothly squashes inputs into (-3, 3) as ``3 * tanh(x / 3)``."""

    def forward(self, x):
        limit = 3
        return limit * t.tanh(x / limit)
class B(nn.Module):
    """Residual block: conv-ReLU-conv-ReLU-conv plus a skip path, fused by a ReLU.

    The skip path is a bias-free 1x1 conv when ``n_i != n_o``, otherwise identity.
    The attribute names ``c``, ``s`` and ``f`` (and the layer order inside ``c``)
    determine the state-dict keys, so they must stay stable for checkpoint loading.
    """

    def __init__(self, n_i, n_o):
        super().__init__()
        body = [cv(n_i, n_o), nn.ReLU(), cv(n_o, n_o), nn.ReLU(), cv(n_o, n_o)]
        self.c = nn.Sequential(*body)
        if n_i == n_o:
            self.s = nn.Identity()
        else:
            # Project the input to the output channel count so it can be added.
            self.s = nn.Conv2d(n_i, n_o, 1, bias=False)
        self.f = nn.ReLU()

    def forward(self, x):
        residual = self.c(x)
        shortcut = self.s(x)
        return self.f(residual + shortcut)
def E(lc=4):
    """Build the encoder as an ``nn.Sequential``: 3 input channels -> ``lc`` channels.

    Layout: an input conv and one residual block at 64 channels, then three
    stages of (stride-2 bias-free conv + three residual blocks), then a final
    conv to ``lc`` channels. The layer order fixes the state-dict indices.
    """
    def stage():
        # One stride-2 downsampling conv followed by three residual blocks.
        return [cv(64, 64, stride=2, bias=False), B(64, 64), B(64, 64), B(64, 64)]

    layers = [cv(3, 64), B(64, 64)]
    for _ in range(3):
        layers += stage()
    layers.append(cv(64, lc))
    return nn.Sequential(*layers)
def D(lc=16):
    """Build the decoder as an ``nn.Sequential``: ``lc`` channels -> 3 output channels.

    Layout: soft clamp ``C``, a conv to 48 channels + ReLU and two residual
    blocks, then three stages of (2x upsample, bias-free conv, N residual
    blocks) with N = 2, 1, 1, then a final conv to 3 channels. The layer order
    fixes the state-dict indices.

    NOTE(review): the default ``lc=16`` differs from ``E``'s default of 4 —
    confirm this is intended.
    """
    width = 48
    layers = [C(), cv(lc, width), nn.ReLU(), B(width, width), B(width, width)]
    for n_blocks in (2, 1, 1):
        layers.append(nn.Upsample(scale_factor=2))
        layers.append(cv(width, width, bias=False))
        layers.extend(B(width, width) for _ in range(n_blocks))
    layers.append(cv(width, 3))
    return nn.Sequential(*layers)
class M(nn.Module):
    """Encoder/decoder pair (``E`` then ``D``) with optional checkpoint loading and frozen weights."""

    # Affine constants for sl/ul: sl maps x -> x / (2 * lm) + ls, clamped to [0, 1].
    lm, ls = 3, 0.5

    def __init__(self, ep="encoder.pth", dp="decoder.pth", lc=None):
        """Build the encoder/decoder and optionally load their weights.

        Args:
            ep: Encoder checkpoint path; a falsy value skips loading.
            dp: Decoder checkpoint path; a falsy value skips loading.
            lc: Number of latent channels; when None it is inferred from ``ep``
                via ``glc``.
        """
        super().__init__()
        if lc is None:
            lc = self.glc(str(ep))
        self.e, self.d = E(lc), D(lc)
        if ep:
            self._load_filtered(t.load(ep, map_location="cpu", weights_only=True), self.e, "encoder.")
        if dp:
            self._load_filtered(t.load(dp, map_location="cpu", weights_only=True), self.d, "decoder.")
        # Freeze all parameters of both halves.
        self.e.requires_grad_(False)
        self.d.requires_grad_(False)

    @staticmethod
    def _load_filtered(sd, mod, pfx):
        """Load the entries of ``sd`` whose name and shape match a parameter/buffer of ``mod``.

        A leading ``pfx`` is removed from each key before matching. The original
        code used ``str.strip``, which strips a *character set* from both ends and
        corrupts keys such as ``"encoder.scale"``. Unknown names and shape
        mismatches are skipped; ``load_state_dict`` runs with ``strict=False``.
        """
        own = mod.state_dict()  # build once instead of once per key
        f_sd = {}
        for k, v in sd.items():
            name = k[len(pfx):] if k.startswith(pfx) else k
            if name in own and v.size() == own[name].size():
                f_sd[name] = v
        print(f"num keys: {len(f_sd)} of {len(own)}")
        mod.load_state_dict(f_sd, strict=False)

    def glc(self, ep):
        """Return 16 latent channels if ``ep`` contains "taef1" or "taesd3", else 4."""
        return 16 if "taef1" in ep or "taesd3" in ep else 4

    @staticmethod
    def sl(x):
        """Map latents to [0, 1]: ``x / (2 * lm) + ls``, clamped."""
        return x.div(2 * M.lm).add(M.ls).clamp(0, 1)

    @staticmethod
    def ul(x):
        """Inverse affine map of ``sl`` (ignoring its clamp): ``(x - ls) * 2 * lm``."""
        return x.sub(M.ls).mul(2 * M.lm)

    def forward(self, x, rl=False):
        """Encode then decode ``x``; the reconstruction is clamped to [0, 1].

        Returns the reconstruction, or ``(reconstruction, latents)`` when ``rl``
        is true.
        """
        latents = self.e(x)  # encode once; the original ran the encoder twice
        out = self.d(latents).clamp(0, 1)
        return (out, latents) if rl else out