detectivejoewest committed
Commit aecc64d · verified · 1 Parent(s): f9da6e2

Upload 5 files

Files changed (5)
  1. handler.py +75 -0
  2. model.py +109 -0
  3. noise_scheduler.py +46 -0
  4. requirements.txt +8 -0
  5. unet.pt +3 -0
handler.py ADDED
@@ -0,0 +1,75 @@
+ import io
+ import base64
+ from typing import Any, Dict
+
+ import torch
+ from PIL import Image
+ from diffusers import AutoencoderKL
+ from transformers import CLIPProcessor, CLIPModel
+
+ from model import Model
+ from noise_scheduler import NoiseSchedule
+
+ LDM = True
+ image_size = 512
+ latent_size = 64
+ filters = [64, 128, 256, 512]
+ latent_dim = 4
+ t_dim = 512
+ T = 1000
+ depth = 2
+
+ class CLIP:
+     """Frozen CLIP wrapper used to rank generated images against a text prompt."""
+     def __init__(self):
+         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+         self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+         self.model.eval()
+         for param in self.model.parameters():
+             param.requires_grad = False
+
+     @torch.inference_mode()
+     def embed_images(self, images):
+         inputs = self.processor(images=images, return_tensors="pt").to(self.model.device)
+         return self.model.get_image_features(**inputs)
+
+     @torch.inference_mode()
+     def embed_text(self, text):
+         inputs = self.processor(text=text, padding=True, return_tensors="pt").to(self.model.device)
+         return self.model.get_text_features(**inputs)
+
+ class Inference:
+     def __init__(self, prompt="A photo of a cat"):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.clip = CLIP()
+         self.ae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").to(self.device)
+         self.ae.eval()
+         for param in self.ae.parameters():
+             param.requires_grad = False
+         self.unet = Model(T=T, filters=filters, t_dim=t_dim, depth=depth, LDM=LDM)
+         self.unet.load_state_dict(torch.load("unet.pt", weights_only=False, map_location="cpu"))
+         self.unet.to(self.device)
+         self.unet.eval()
+         for param in self.unet.parameters():
+             param.requires_grad = False
+         # Keep the schedule's buffers on the same device as the U-Net and VAE.
+         self.noise_scheduler = NoiseSchedule(T=T, shape=(latent_dim, latent_size, latent_size), ddim_mod=50, trainer_mode=(self.device == "cpu"))
+         self.target_vector = self.clip.embed_text(prompt)[0]
+         self.target_vector = self.target_vector / self.target_vector.norm(p=2, dim=-1, keepdim=True)
+
+     @torch.inference_mode()
+     def __call__(self, num_images=8):
+         # Sample latents with DDIM, decode them with the VAE, then rank by CLIP similarity to the prompt.
+         latents = self.noise_scheduler.generate(self.unet, num_images=num_images, device=self.device)
+         images = []
+         for latent in latents:
+             image = self.ae.decode(latent.unsqueeze(0) / self.ae.config.scaling_factor)[0][0].cpu().permute(1, 2, 0) / 2 + 0.5
+             image = torch.clamp(image, 0.0, 1.0)
+             # Convert to uint8 PIL so the CLIP processor rescales correctly and the handler can encode a PNG.
+             images.append(Image.fromarray((image.numpy() * 255).round().astype("uint8")))
+         embeddings = self.clip.embed_images(images)
+         scores = (embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)) @ self.target_vector
+         i = torch.argmax(scores).item()
+         return images[i], scores[i], scores
+
+ class EndpointHandler:
+     def __init__(self, path: str = ""):
+         # path -> repo directory on the endpoint container
+         # you can read files via Path(path)/"unet.pt" if needed
+         self.engine = Inference(prompt="A photo of a cat")
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         # The endpoint passes a request payload; this handler ignores it and always samples one image.
+         image, score, _ = self.engine(num_images=1)
+         buffer = io.BytesIO()
+         image.save(buffer, format="PNG")
+         b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
+         return {"image": b64, "score": float(score)}
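A minimal local smoke test for the handler (illustrative only, not part of the commit), assuming the five files above sit in the working directory; the data payload is accepted by the endpoint interface but unused here:

    import base64
    from handler import EndpointHandler

    handler = EndpointHandler(path=".")        # expects unet.pt next to handler.py
    result = handler({"inputs": ""})           # payload is ignored by this handler
    print(result["score"])                     # CLIP similarity of the returned sample
    with open("sample.png", "wb") as f:
        f.write(base64.b64decode(result["image"]))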
model.py ADDED
@@ -0,0 +1,109 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+
+ class SpatialAttention(nn.Module):
+     def __init__(self, in_c):
+         super().__init__()
+         self.norm = nn.GroupNorm(num_groups=32, num_channels=in_c, eps=1e-6, affine=True)
+         self.Q = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=0)
+         self.K = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=0)
+         self.V = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=0)
+         self.proj = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=0)
+
+     def forward(self, x):
+         b, c, h, w = x.shape
+         R = self.norm(x)
+         q, v, k = self.Q(R), self.V(R), self.K(R)
+         q, v, k = q.reshape(b, c, h * w), v.reshape(b, c, h * w), k.reshape(b, c, h * w)
+         q = q.permute(0, 2, 1)
+         R = torch.bmm(q, k) * (1.0 / math.sqrt(c))   # (b, hw, hw) attention logits
+         R = F.softmax(R, dim=2)
+         R = torch.bmm(v, R)
+         R = R.reshape(b, c, h, w)
+         return self.proj(R) + x                      # residual connection
+
+ class ResBlock(nn.Module):
+     def __init__(self, in_c, out_c):
+         super().__init__()
+         self.reshape = in_c != out_c
+         if self.reshape:
+             self.conv_reshape = nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1)
+         self.norm1 = nn.GroupNorm(num_groups=32, num_channels=out_c, eps=1e-6, affine=True)
+         self.conv1 = nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1)
+         self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_c, eps=1e-6, affine=True)
+         self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1)
+
+     def forward(self, x):
+         if self.reshape:
+             x = self.conv_reshape(x)
+         res = x
+         x = self.norm1(x)
+         x = x * torch.sigmoid(x)    # SiLU/Swish activation
+         x = self.conv1(x)
+         x = self.norm2(x)
+         x = x * torch.sigmoid(x)
+         x = self.conv2(x)
+         return x + res
+
+ class Model(nn.Module):
+     def __init__(self, T=1000, filters=[32, 64, 96, 128], depth=2, t_dim=512, LDM=False):
+         super().__init__()
+         self.t_dim = t_dim
+         self.T = T
+         # The timestep embedding is concatenated channel-wise, so conv_in sees image/latent channels + t_dim.
+         self.conv_in = nn.Conv2d((4 if LDM else 3) + self.t_dim, filters[0], kernel_size=1)
+         self.down = nn.ModuleList([])
+         for i in range(1, len(filters)):
+             block = nn.Module()
+             block.Blocks = nn.ModuleList([ResBlock(filters[i - 1], filters[i])])
+             for _ in range(1, depth):
+                 block.Blocks.append(ResBlock(filters[i], filters[i]))
+             block.DownSample = nn.Conv2d(filters[i], filters[i], kernel_size=3, stride=2, padding=1)
+             self.down.append(block)
+
+         self.mid = nn.Sequential(ResBlock(filters[-1], filters[-1]),
+                                  SpatialAttention(filters[-1]),
+                                  ResBlock(filters[-1], filters[-1]))
+
+         self.up = nn.ModuleList([])
+         filters = filters[::-1]
+         for i in range(1, len(filters)):
+             block = nn.Module()
+             # Input channels are doubled by the skip connection concatenated in forward().
+             block.Blocks = nn.ModuleList([ResBlock(filters[i - 1] * 2, filters[i])])
+             for _ in range(1, depth):
+                 block.Blocks.append(ResBlock(filters[i], filters[i]))
+             block.UpSample = nn.Upsample(scale_factor=2, mode="bilinear")
+             self.up.append(block)
+         self.conv_out = nn.Conv2d(filters[-1], 4 if LDM else 3, kernel_size=3, padding=1)
+
+     def get_sinusoidal_emb(self, t):
+         """Receives a (B,) tensor of scalar timesteps, returns (B, t_dim) sinusoidal embeddings."""
+         freqs = torch.exp(-math.log(self.T) * torch.arange(start=0, end=self.t_dim // 2, dtype=torch.float32) / (self.t_dim // 2)).to(device=t.device)
+         args = t[:, None].float() * freqs[None]
+         return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+
+     def forward(self, x, t):
+         t_emb = self.get_sinusoidal_emb(t)
+         B, C, H, W = x.shape
+
+         # Broadcast the embedding over the spatial grid and concatenate it as extra channels.
+         t_emb = t_emb.unsqueeze(-1).unsqueeze(-1).expand(B, self.t_dim, H, W)
+         x = torch.cat((x, t_emb), 1)
+         x = self.conv_in(x)
+
+         cache = []
+         for block in self.down:
+             for resblock in block.Blocks:
+                 x = resblock(x)
+             cache.append(x.clone())   # skip connection for the matching up block
+             x = block.DownSample(x)
+         x = self.mid(x)
+         for block in self.up:
+             x = block.UpSample(x)
+             x = torch.cat((x, cache.pop()), 1)
+             for resblock in block.Blocks:
+                 x = resblock(x)
+
+         return self.conv_out(x)
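A quick shape check for the U-Net (illustrative only, not part of the commit); with LDM=True the model maps a batch of 4x64x64 latents plus per-sample timesteps to a noise prediction of the same shape:

    import torch
    from model import Model

    unet = Model(T=1000, filters=[64, 128, 256, 512], depth=2, t_dim=512, LDM=True)
    x = torch.randn(2, 4, 64, 64)             # batch of latents (LDM=True -> 4 channels)
    t = torch.randint(0, 1000, (2,))          # one scalar timestep per sample
    eps = unet(x, t)
    assert eps.shape == x.shape               # predicted noise matches the latent shape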
noise_scheduler.py ADDED
@@ -0,0 +1,46 @@
+ import torch
+
+ class NoiseSchedule:
+     """
+     Handles:
+     - DDIM inference (with a ddim_mod to skip steps)
+     - DDPM inference
+     - Forward noising
+     - Linear beta schedule
+     - Classifier-free guidance (w is the guidance-weight hyperparameter; stored but not applied in generate())
+     """
+     def __init__(self, T, std=1, shape=(4, 64, 64), ddim_mod=10, trainer_mode=False):
+         self.T = T
+         self.std = std
+         self.ddim_mod = ddim_mod
+         device = 'cpu' if trainer_mode else 'cuda'
+         self.beta = torch.linspace(1e-4, 0.02, T, dtype=torch.float32, device=device)
+         self.alpha = 1 - self.beta
+         self.alpha_bar = self.alpha.cumprod(dim=0)
+         self.w = torch.full((T,), 7.5, device=device)
+         self.shape = shape
+
+     def noise(self, x, t):
+         # Forward process: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps
+         eps = torch.randn_like(x) * self.std
+         return (self.alpha_bar[t] ** 0.5) * x + ((1 - self.alpha_bar[t]) ** 0.5) * eps, eps
+
+     def ddim_step(self, xt, t, eps):
+         x0 = (xt - (1 - self.alpha_bar[t]).sqrt() * eps) / self.alpha_bar[t].sqrt()
+         x0 = x0.clamp(-1, 1)
+         # note that eps = (xt - sqrt(abar[t]) * x0) / sqrt(1 - abar[t])
+         t_prev = max(0, t - self.ddim_mod)
+         xt_1 = self.alpha_bar[t_prev].sqrt() * x0 + (1 - self.alpha_bar[t_prev]).sqrt() * eps
+         return xt_1
+
+     def ddpm_step(self, x, eps, t, var=None):
+         var = self.beta[t] if var is None else var
+         # Posterior mean: (1/sqrt(alpha_t)) * (x_t - beta_t / sqrt(1 - abar_t) * eps), plus sqrt(var)-scaled noise.
+         mean = (self.alpha[t] ** -0.5) * (x - (self.beta[t] / (1 - self.alpha_bar[t]).sqrt()) * eps)
+         return mean + (var ** 0.5) * torch.randn_like(x)
+
+     def generate(self, model, num_images=16, device="cuda"):
+         with torch.no_grad():
+             x = torch.randn((num_images, *self.shape), device=device) * self.std
+             for t in range(self.T - 1, -1, -self.ddim_mod):
+                 t_tensor = torch.full((num_images,), t, device=device)
+                 epsilons = model(x, t=t_tensor)
+                 x = self.ddim_step(x, t=t, eps=epsilons)
+             return x
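A small usage sketch for the schedule (illustrative only, not part of the commit); trainer_mode=True keeps every buffer on CPU, and with T=1000 and ddim_mod=50 generate() runs 20 DDIM steps per image:

    import torch
    from noise_scheduler import NoiseSchedule

    ns = NoiseSchedule(T=1000, shape=(4, 64, 64), ddim_mod=50, trainer_mode=True)
    x0 = torch.randn(2, 4, 64, 64)            # stand-in clean latents
    xt, eps = ns.noise(x0, t=500)             # forward-noise to step 500
    # Sampling would call ns.generate(unet, num_images=2, device="cpu") with the trained U-Net.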
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch
+ transformers
+ diffusers
+ accelerate
+ safetensors
+ Pillow
+ opencv-python-headless
+ numpy
unet.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c7045782f1dfbb51037ec23ca06142d4b5d60dedbfd28cafd3dbe5e07cead738
+ size 135132829