Upload 5 files
Browse files- handler.py +75 -0
- model.py +109 -0
- noise_scheduler.py +46 -0
- requirements.txt +8 -0
- unet.pt +3 -0
handler.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from diffusers import AutoencoderKL
|
| 2 |
+
from transformers import CLIPProcessor, CLIPModel
|
| 3 |
+
from model import Model
|
| 4 |
+
from noise_scheduler import NoiseSchedule
|
| 5 |
+
import torch
|
| 6 |
+
import base64
|
| 7 |
+
from typing import Any, Dict
|
| 8 |
+
|
| 9 |
+
# Global configuration matching the trained checkpoint (unet.pt).
LDM = True                      # latent diffusion: UNet works in VAE latent space, not pixels
image_size = 512                # decoded RGB output resolution; presumably 8x the latent size — TODO confirm
latent_size = 64                # spatial size of the VAE latent grid
filters = [64, 128, 256, 512]   # UNet channel widths per resolution level
latent_dim = 4                  # VAE latent channels
t_dim = 512                     # timestep-embedding dimension fed to the UNet
T = 1000                        # diffusion timesteps the model was trained with
depth = 2                       # ResBlocks per UNet level
|
| 17 |
+
|
| 18 |
+
class CLIP:
    """Frozen CLIP ViT-B/32 wrapper exposing image and text embedding helpers."""

    def __init__(self):
        checkpoint = "openai/clip-vit-base-patch32"
        self.processor = CLIPProcessor.from_pretrained(checkpoint)
        self.model = CLIPModel.from_pretrained(checkpoint)
        self.model.eval()
        # CLIP is only used for scoring, never trained: freeze every weight.
        for _, weight in self.model.named_parameters():
            weight.requires_grad = False

    @torch.inference_mode()
    def embed_images(self, images):
        """Return CLIP image features for a batch of images."""
        batch = self.processor(images=images, return_tensors="pt").to(self.model.device)
        return self.model.get_image_features(**batch)

    @torch.inference_mode()
    def embed_text(self, text):
        """Return CLIP text features for a string or list of strings."""
        tokens = self.processor(text, padding=True, return_tensors="pt").to(self.model.device)
        return self.model.get_text_features(**tokens)
|
| 35 |
+
|
| 36 |
+
class Inference:
    """Generate latent-diffusion samples and pick the one that best matches a CLIP prompt.

    Args:
        prompt: text used as the CLIP scoring target. Default preserves the
            original hard-coded behavior, and lets EndpointHandler pass
            ``prompt=...`` (the previous signature rejected that keyword).
    """

    def __init__(self, prompt: str = "A photo of a cat"):
        self.clip = CLIP()
        device = 'cuda' if torch.cuda.is_available() else "cpu"
        self.ae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").to(device)
        self.ae.eval()
        for _, param in self.ae.named_parameters():
            param.requires_grad = False
        self.unet = Model(T=T, filters=filters, t_dim=t_dim, depth=depth, LDM=LDM)
        # NOTE(review): weights_only=False deserializes arbitrary pickle code;
        # the checkpoint ships with this repo, but prefer weights_only=True if
        # unet.pt contains only tensors.
        self.unet.load_state_dict(torch.load("unet.pt", weights_only=False, map_location=torch.device('cpu')))
        self.unet.eval()
        for _, param in self.unet.named_parameters():
            param.requires_grad = False
        self.noise_scheduler = NoiseSchedule(T=T, shape=(latent_dim, latent_size, latent_size), ddim_mod=50, trainer_mode=True)
        # Pre-compute the unit-norm CLIP text embedding once.
        self.target_vector = self.clip.embed_text(prompt)[0]
        self.target_vector = self.target_vector / self.target_vector.norm(p=2, dim=-1, keepdim=True)

    @torch.inference_mode()
    def __call__(self, num_images=8):
        """Sample ``num_images`` latents, decode them, and return
        ``(best_image, best_score, all_scores)`` ranked by CLIP similarity.

        ``best_image`` is an (H, W, 3) float tensor in [0, 1].
        """
        latents = self.noise_scheduler.generate(self.unet, num_images=num_images, device='cpu')
        images = []
        for latent in latents:
            # Undo the SD latent scaling, decode, then map [-1, 1] -> [0, 1] HWC.
            image = self.ae.decode(latent.unsqueeze(0) / self.ae.config.scaling_factor)[0][0].cpu().permute(1, 2, 0) / 2 + 0.5
            images.append(torch.clamp(image, 0.0, 1.0))
        embeddings = self.clip.embed_images(images)
        # Cosine similarity: image embeddings are normalized; target already is.
        scores = (embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)) @ self.target_vector.T
        i = torch.argmax(scores).item()
        return images[i], scores[i], scores
|
| 65 |
+
|
| 66 |
+
class EndpointHandler:
    """Hugging Face Inference Endpoints entry point."""

    def __init__(self, path: str = ""):
        # path -> repo directory on the endpoint container
        # you can read files via Path(path)/"unet.pt" if needed
        self.engine = Inference(prompt="A photo of a cat")

    def __call__(self, data: Dict[str, Any] = None) -> Dict[str, Any]:
        """Generate image(s) and return the best one as base64 PNG plus its CLIP score.

        ``data`` (the endpoint payload) may carry {"inputs": {"num_images": int}};
        defaults to 1. Fixes two defects in the previous version: Inference
        returns a 3-tuple (image, score, all_scores), and the image is a float
        HWC tensor that must be PNG-encoded before base64.
        """
        num_images = 1
        if data:
            inputs = data.get("inputs") or {}
            if isinstance(inputs, dict):
                num_images = int(inputs.get("num_images", 1))
        image, score, _ = self.engine(num_images=num_images)
        # Encode the [0, 1] float HWC tensor as PNG (Pillow is in requirements.txt).
        import io
        from PIL import Image
        arr = (image * 255).round().clamp(0, 255).to(torch.uint8).numpy()
        buf = io.BytesIO()
        Image.fromarray(arr).save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
        return {"image": b64, "score": float(score)}
|
model.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
class SpatialAttention(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map,
    pre-normed, with 1x1-conv projections and a residual connection."""

    def __init__(self, in_c):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_c, eps=1e-6, affine=True)
        # 1x1 convs act as per-position linear projections for query/key/value.
        self.Q = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=0)
        self.K = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=0)
        self.V = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_c, in_c, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        b, c, h, w = x.shape
        R = self.norm(x)
        q, v, k = self.Q(R), self.V(R), self.K(R)
        # Flatten spatial dims: each pixel becomes one token of dim c.
        q, v, k = q.reshape(b, c, h*w), v.reshape(b, c, h*w), k.reshape(b, c, h*w)
        q, v, k = q.permute(0, 2, 1), v, k          # q: (b, hw, c); k, v stay (b, c, hw)
        R = torch.bmm(q, k) * (1.0 / math.sqrt(c))  # (b, hw_q, hw_k) scaled dot products
        R = F.softmax(R, dim=2)                     # normalize over key positions
        # NOTE(review): standard attention (e.g. the CompVis VAE AttnBlock) would
        # transpose R before bmm(v, R) so each query aggregates values over keys;
        # as written, positions are mixed along the query axis instead. The
        # shipped checkpoint was trained with this exact form, so it is left
        # unchanged — confirm before "fixing", as a fix would invalidate unet.pt.
        R = torch.bmm(v, R)
        R = R.reshape(b, c, h, w)
        # Output projection plus residual skip.
        return self.proj(R) + x
|
| 26 |
+
|
| 27 |
+
class ResBlock(nn.Module):
    """Two-conv residual block (GroupNorm + swish), with an optional 3x3
    projection applied first when input and output channel counts differ."""

    def __init__(self, in_c, out_c):
        super().__init__()
        self.reshape = in_c != out_c
        if self.reshape:
            # Bring the input to out_c channels before the residual branch.
            self.conv_reshape = nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1)
        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=out_c, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_c, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1)

    @staticmethod
    def _swish(t):
        # x * sigmoid(x), written out to match the original computation exactly.
        return t * torch.sigmoid(t)

    def forward(self, x):
        if self.reshape:
            x = self.conv_reshape(x)
        skip = x
        out = self.conv1(self._swish(self.norm1(x)))
        out = self.conv2(self._swish(self.norm2(out)))
        return out + skip
|
| 51 |
+
|
| 52 |
+
class Model(nn.Module):
    """UNet noise (epsilon) predictor for DDPM/DDIM sampling.

    The timestep enters by broadcasting its sinusoidal embedding over the
    spatial grid and concatenating it to the input channels before conv_in.

    Args:
        T: number of diffusion timesteps (scales the embedding frequencies).
        filters: channel widths per resolution level.
        depth: ResBlocks per level.
        t_dim: sinusoidal timestep-embedding dimension.
        LDM: if True, input/output have 4 latent channels instead of 3 RGB.
    """

    def __init__(self, T=1000, filters=[32, 64, 96, 128], depth=2, t_dim=512, LDM=False):
        # NOTE(review): mutable default argument `filters` — harmless here since
        # it is never mutated in place, but a tuple default would be safer.
        super().__init__()
        self.t_dim = t_dim
        self.T = T
        # Input channels = data channels (4 latent / 3 RGB) + broadcast t-embedding.
        self.conv_in = nn.Conv2d(4 + self.t_dim if LDM else 3 + self.t_dim, filters[0], kernel_size=1)
        self.down = nn.ModuleList([])
        for i in range(1, len(filters)):
            block = nn.Module()
            # First block changes channel count; remaining depth-1 keep it.
            block.Blocks = nn.ModuleList([ResBlock(filters[i-1], filters[i])])
            for _ in range(1, depth):
                block.Blocks.append(ResBlock(filters[i], filters[i]))
            # Strided conv halves the spatial resolution.
            block.DownSample = nn.Conv2d(filters[i], filters[i], kernel_size=3, stride=2, padding=1)
            self.down.append(block)

        self.mid = nn.Sequential(ResBlock(filters[-1], filters[-1]),
                                 SpatialAttention(filters[-1]),
                                 ResBlock(filters[-1], filters[-1]))

        self.up = nn.ModuleList([])
        filters = filters[::-1]  # reversed channel ladder for the decoder side
        for i in range(1, len(filters)):
            block = nn.Module()
            # *2 input channels: upsampled features are concatenated with the skip.
            block.Blocks = nn.ModuleList([ResBlock(filters[i-1]*2, filters[i])])
            for _ in range(1, depth):
                block.Blocks.append(ResBlock(filters[i], filters[i]))
            block.UpSample = nn.Upsample(scale_factor=2, mode="bilinear")
            self.up.append(block)
        self.conv_out = nn.Conv2d(filters[-1], 4 if LDM else 3, kernel_size=3, padding=1)

    def get_sinusoidal_emb(self, t):
        """Receives a (B,) tensor of scalar timesteps, returns (B, t_dim) embeddings.

        (Shape per `t[:, None]` below and the (num_images,) tensor built by
        NoiseSchedule.generate — a (B, 1) input would break the expand in forward.)
        """
        # Geometric frequency ladder from 1 down to 1/T across t_dim/2 channels.
        freqs = torch.exp(-math.log(self.T) * torch.arange(start=0, end=self.t_dim // 2, dtype=torch.float32) / (self.t_dim // 2)).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

    def forward(self, x, t):
        """Predict the noise for batch x at timesteps t (shape (B,))."""
        t_emb = self.get_sinusoidal_emb(t)
        B, C, H, W = x.shape

        # Broadcast the embedding over the spatial grid and feed it as extra channels.
        t_emb = t_emb.unsqueeze(-1).unsqueeze(-1).expand(B, self.t_dim, H, W)
        x = torch.cat((x, t_emb), 1)
        x = self.conv_in(x)

        cache = []  # skip connections, one snapshot per resolution level
        for block in self.down:
            for resblock in block.Blocks:
                x = resblock(x)
            cache.append(x.clone())  # captured before downsampling
            x = block.DownSample(x)
        x = self.mid(x)
        for block in self.up:
            x = block.UpSample(x)
            x = torch.cat((x, cache.pop()), 1)  # concat matching-resolution skip
            for resblock in block.Blocks:
                x = resblock(x)

        return (self.conv_out(x))
|
noise_scheduler.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from IPython.display import clear_output
|
| 4 |
+
|
| 5 |
+
class NoiseSchedule:
    """
    Handles:
    - DDIM inference (with a ddim_mod to skip steps)
    - DDPM inference
    - Forward Noising
    - Linear beta schedule
    - Classifier Free Guidance (w is a hyperparameter for cfg schedule)
    """

    def __init__(self, T, std=1, shape=(4, 64, 64), ddim_mod=10, trainer_mode=False):
        """Linear beta schedule over T steps; tensors live on CPU in trainer_mode."""
        device = 'cpu' if trainer_mode else 'cuda'
        self.T = T
        self.std = std
        self.ddim_mod = ddim_mod  # DDIM stride: timesteps skipped per sampling step
        self.beta = torch.tensor(np.linspace(1e-4, 0.02, T), dtype=torch.float32, device=device)
        self.alpha = 1 - self.beta
        self.alpha_bar = self.alpha.cumprod(dim=0)
        self.w = torch.full((T,), 7.5, device=device)  # per-step CFG guidance scale
        self.shape = shape

    def noise(self, x, t):
        """Forward-noise clean x to timestep t; returns (x_t, eps)."""
        eps = torch.randn_like(x) * self.std
        return (self.alpha_bar[t]**0.5) * x + ((1-self.alpha_bar[t])**0.5) * eps, eps

    def ddim_step(self, xt, t, eps):
        """One deterministic DDIM step from timestep t to max(0, t - ddim_mod)."""
        # Recover x0 from the eps-prediction, clamped to the data range.
        x0 = (xt - (1 - self.alpha_bar[t]).sqrt() * eps) / self.alpha_bar[t].sqrt()
        x0 = x0.clamp(-1, 1)
        # note that eps = (xt - sqrt(abar[t]) * x0) / sqrt(1 - abar[t])
        t_prev = max(0, t - self.ddim_mod)
        xt_1 = self.alpha_bar[t_prev].sqrt() * x0 + (1 - self.alpha_bar[t_prev]).sqrt() * eps
        return xt_1

    def ddpm_step(self, x, eps, t, var=None):
        """One stochastic ancestral (DDPM) step from x_t toward x_{t-1}.

        Fixed two defects against the DDPM posterior (Ho et al. 2020, eq. 11):
        the eps coefficient is beta_t / sqrt(1 - alpha_bar_t), not
        sqrt(1 - alpha_bar_t); and the injected noise is scaled by the standard
        deviation sqrt(var), not the variance itself.
        """
        var = self.beta[t] if var is None else var
        mean = (self.alpha[t]**-0.5) * (x - (self.beta[t] / (1 - self.alpha_bar[t])**0.5) * eps)
        return mean + (var**0.5) * torch.randn_like(x)

    def generate(self, model, num_images=16, device="cuda"):
        """DDIM-sample num_images tensors of self.shape using model(x, t)."""
        with torch.no_grad():
            x = torch.randn((num_images, *self.shape), device=device) * self.std
            # Walk t = T-1, T-1-mod, ... down toward 0, skipping ddim_mod steps.
            for t in range(self.T-1, -1, -self.ddim_mod):
                t_tensor = torch.full((num_images,), t, device=device)
                epsilons = model(x, t=t_tensor)
                x = self.ddim_step(x, t=t, eps=epsilons)
            return x
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
transformers
|
| 3 |
+
diffusers
|
| 4 |
+
accelerate
|
| 5 |
+
safetensors
|
| 6 |
+
Pillow
|
| 7 |
+
opencv-python-headless
|
| 8 |
+
numpy
|
unet.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c7045782f1dfbb51037ec23ca06142d4b5d60dedbfd28cafd3dbe5e07cead738
|
| 3 |
+
size 135132829
|