manbeast3b committed on
Commit
a89c2e3
·
1 Parent(s): ad160ca
Files changed (2) hide show
  1. src/pipeline.py +2 -2
  2. src/utils.py +64 -0
src/pipeline.py CHANGED
@@ -27,7 +27,7 @@ import numpy as np
27
  import torch.nn as nn
28
  import torch.nn.functional as F
29
  from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
30
- from model import _load
31
  import torchvision
32
  import os
33
 
@@ -92,4 +92,4 @@ def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator)
92
  sample=1
93
  empty_cache()
94
  image=pipeline(request.prompt,generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pt").images[0]
95
- return torchvision.transforms.functional.to_pil_image(image.to(torch.float32).mul_(2).sub_(1))# torchvision.transforms.functional.to_pil_image(image)
 
27
  import torch.nn as nn
28
  import torch.nn.functional as F
29
  from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
30
+ from utils import _load
31
  import torchvision
32
  import os
33
 
 
92
  sample=1
93
  empty_cache()
94
  image=pipeline(request.prompt,generator=generator, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, height=request.height, width=request.width, output_type="pt").images[0]
95
+ return torchvision.transforms.functional.to_pil_image(image.to(torch.float32).mul_(2).sub_(1))# torchvision.transforms.functional.to_pil_image(image)
src/utils.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ A=None
2
+ e_sd_pt="ko.pth"
3
+ d_sd_pt="ok.pth"
4
+ import torch as t, torch.nn as nn, torch.nn.functional as F
5
+ def C(n_in, n_out, **kwargs):
6
+ return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
7
+ class Clamp(nn.Module):
8
+ def forward(self, x):
9
+ return t.tanh(x / 3) * 3
10
+ class B(nn.Module):
11
+ def __init__(self, n_in, n_out):
12
+ super().__init__()
13
+ self.conv = nn.Sequential(C(n_in, n_out), nn.ReLU(), C(n_out, n_out), nn.ReLU(), C(n_out, n_out))
14
+ self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
15
+ self.fuse = nn.ReLU()
16
+ def forward(self, x):
17
+ return self.fuse(self.conv(x) + self.skip(x))
18
+ def E(latent_channels=4):
19
+ return nn.Sequential(
20
+ C(3, 64), B(64, 64),
21
+ C(64, 64, stride=2, bias=False), B(64, 64), B(64, 64), B(64, 64),
22
+ C(64, 64, stride=2, bias=False), B(64, 64), B(64, 64), B(64, 64),
23
+ C(64, 64, stride=2, bias=False), B(64, 64), B(64, 64), B(64, 64),
24
+ C(64, latent_channels),
25
+ )
26
+ def D(latent_channels=16):
27
+ return nn.Sequential(
28
+ Clamp(),
29
+ C(latent_channels, 48),nn.ReLU(),B(48, 48), B(48, 48),
30
+ nn.Upsample(scale_factor=2), C(48, 48, bias=False),B(48, 48), B(48, 48),
31
+ nn.Upsample(scale_factor=2), C(48, 48, bias=False),B(48, 48),
32
+ nn.Upsample(scale_factor=2), C(48, 48, bias=False),B(48, 48),
33
+ C(48, 3),
34
+ )
35
+ class M(nn.Module):
36
+ lm, ls = 3, 0.5
37
+ def __init__(s, ep="encoder.pth", dp="decoder.pth", lc=None):
38
+ super().__init__()
39
+ if lc is None: lc = s.glc(str(ep))
40
+ s.e, s.d = E(lc), D(lc)
41
+ def f(sd, mod, pfx):
42
+ f_sd = {k.strip(pfx): v for k, v in sd.items() if k.strip(pfx) in mod.state_dict() and v.size() == mod.state_dict()[k.strip(pfx)].size()}
43
+ mod.load_state_dict(f_sd, strict=False)
44
+ if ep: f(t.load(ep, map_location="cpu", weights_only=True), s.e, "encoder.")
45
+ if dp: f(t.load(dp, map_location="cpu", weights_only=True), s.d, "decoder.")
46
+ s.e.requires_grad_(False)
47
+ s.d.requires_grad_(False)
48
+ def glc(s, ep): return 16 if "taef1" in ep or "taesd3" in ep else 4
49
+ @staticmethod
50
+ def sl(x): return x.div(2 * M.lm).add(M.ls).clamp(0, 1)
51
+ @staticmethod
52
+ def ul(x): return x.sub(M.ls).mul(2 * M.lm)
53
+ def forward(s, x, rl=False):
54
+ l, o = s.e(x), s.d(s.e(x))
55
+ return (o.clamp(0, 1), l) if rl else o.clamp(0, 1)
56
+ def filter_state_dict(model, state_dict_path):
57
+ global E
58
+ state_dict = t.load(state_dict_path, map_location="cpu", weights_only=True)
59
+ prefix = 'encoder.' if type(model) == E else 'decoder.'
60
+ return {k.strip(prefix): v for k, v in state_dict.items() if k.strip(prefix) in model.state_dict() and v.size() == model.state_dict()[k.strip(prefix)].size()}
61
+ def _load(model, name, dtype=t.bfloat16):
62
+ model = E(16) if name=="E" else D(16)
63
+ model.load_state_dict(filter_state_dict(model, e_sd_pt if name=="E" else d_sd_pt), strict=False)
64
+ model.requires_grad_(False).to(dtype=dtype)