|
|
from diffusers import AutoencoderKL |
|
|
from PIL import Image |
|
|
import io |
|
|
from transformers import CLIPProcessor, CLIPModel |
|
|
from model import Model |
|
|
from pathlib import Path |
|
|
from noise_scheduler import NoiseSchedule |
|
|
import torch |
|
|
import base64 |
|
|
from typing import Any, Dict |
|
|
|
|
|
# --- Hyperparameters shared by the classes below ---------------------------

# Forwarded to ``Model(LDM=...)``.
LDM: bool = True

# Spatial sizes. Presumably the side of a decoded image and of a latent map
# (512 / 64 would match an 8x-downsampling VAE) -- TODO confirm against the
# training configuration.
image_size: int = 512
latent_size: int = 64

# Presumably the number of latent channels -- TODO confirm.
latent_dim: int = 4

# U-Net configuration: ``t_dim`` and ``depth`` are forwarded to ``Model``;
# ``filters`` holds the same values ``Inference`` passes as ``filters=``.
filters: list = [64, 128, 256, 512]
t_dim: int = 512
depth: int = 2

# Forwarded to ``Model(T=...)``; ``NoiseSchedule`` is built with the same value.
T: int = 1000
|
|
|
|
|
class CLIP:
    """Frozen CLIP encoder exposing image- and text-embedding helpers.

    Loads a Hugging Face ``CLIPProcessor`` / ``CLIPModel`` pair, switches the
    model to eval mode and disables gradients on all of its parameters.
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        """Load the processor and the model for ``model_name``.

        Args:
            model_name: Identifier handed to ``from_pretrained`` for both the
                processor and the model. Defaults to the checkpoint that was
                previously hard-coded, so ``CLIP()`` behaves as before.
        """
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.model = CLIPModel.from_pretrained(model_name)
        self.model.eval()
        # Inference only: no parameter should ever accumulate gradients.
        self.model.requires_grad_(False)

    @torch.inference_mode()
    def embed_images(self, images):
        """Return ``get_image_features`` of the model for ``images``.

        ``images`` is handed to the processor unchanged and the resulting
        batch is moved to the model's device.

        NOTE(review): with its default settings the CLIP image processor
        rescales pixel values by 1/255, so inputs are expected in the 0-255
        range (PIL images or uint8 arrays) -- confirm before passing floats
        that are already in [0, 1].
        """
        batch = self.processor(images=images, return_tensors="pt").to(self.model.device)
        return self.model.get_image_features(**batch)

    @torch.inference_mode()
    def embed_text(self, text):
        """Return ``get_text_features`` of the model for ``text``.

        ``text`` is tokenized with ``padding=True`` and the resulting batch is
        moved to the model's device.
        """
        batch = self.processor(text, padding=True, return_tensors="pt").to(self.model.device)
        return self.model.get_text_features(**batch)
|
|
|
|
|
class Inference:
    """Sample latents with the U-Net, decode them with the VAE and rank by CLIP.

    Construction loads every model once: a ``CLIP`` scorer, the Stable
    Diffusion v1.5 VAE, the U-Net checkpoint ``unet.pt`` stored next to this
    file, and a ``NoiseSchedule`` sampler. It also caches the L2-normalised
    CLIP embedding of the prompt "A photo of a cat" as ``target_vector``.
    """

    def __init__(self):
        """Load all models, freeze them and precompute the target embedding."""
        ckpt_path = Path(__file__).resolve().parent / "unet.pt"

        # Only the VAE is placed on the accelerator; the U-Net, the sampler
        # and CLIP stay on CPU. ``__call__`` moves latents across explicitly.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.clip = CLIP()

        self.ae = self._freeze(
            AutoencoderKL.from_pretrained(
                "runwayml/stable-diffusion-v1-5", subfolder="vae"
            ).to(self.device)
        )

        # Hyperparameters come from the module-level constants rather than
        # repeated literals; a copy of ``filters`` is passed so the model
        # cannot mutate the shared list.
        self.unet = Model(T=T, filters=list(filters), t_dim=t_dim, depth=depth, LDM=LDM)
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # objects. Tolerable only because the checkpoint ships alongside this
        # file; switch to weights_only=True if it holds a plain state dict.
        state_dict = torch.load(
            ckpt_path, weights_only=False, map_location=torch.device("cpu")
        )
        self.unet.load_state_dict(state_dict)
        self._freeze(self.unet)

        self.noise_scheduler = NoiseSchedule(
            T=T,
            shape=(latent_dim, latent_size, latent_size),
            ddim_mod=50,
            trainer_mode=True,
        )

        # Unit-length prompt embedding, so a dot product with a unit-length
        # image embedding is their cosine similarity.
        target = self.clip.embed_text("A photo of a cat")[0]
        self.target_vector = target / target.norm(p=2, dim=-1, keepdim=True)

    @staticmethod
    def _freeze(module):
        """Put ``module`` in eval mode, disable its gradients and return it."""
        module.eval()
        for _name, param in module.named_parameters():
            param.requires_grad = False
        return module

    @torch.inference_mode()
    def __call__(self, num_images=8):
        """Generate ``num_images`` samples and select the best CLIP match.

        Args:
            num_images: Number of latents requested from the sampler.

        Returns:
            A tuple ``(best_image, best_score, scores)``: the decoded image
            with the highest similarity to ``target_vector`` (CPU float
            tensor clamped to [0, 1], channels last after the permute), its
            score, and the scores of all generated images.

        Raises:
            ValueError: If ``num_images`` is smaller than 1.
        """
        if num_images < 1:
            raise ValueError("num_images must be at least 1")

        latents = self.noise_scheduler.generate(
            self.unet, num_images=num_images, device="cpu"
        )

        images = []
        for latent in latents:
            # Latents are produced on CPU while the VAE may live on CUDA, so
            # move each one to the VAE's device before decoding.
            scaled = latent.unsqueeze(0).to(self.device) / self.ae.config.scaling_factor
            # First decoder output, first batch item; channels moved last and
            # values mapped from [-1, 1] to [0, 1].
            decoded = self.ae.decode(scaled)[0][0].cpu().permute(1, 2, 0) / 2 + 0.5
            images.append(torch.clamp(decoded, 0.0, 1.0))

        # The CLIP processor rescales pixel values by 1/255 by default, so it
        # is fed 0-255 uint8 copies; the float images are what gets returned.
        clip_inputs = [(img * 255).round().to(torch.uint8).numpy() for img in images]
        embeddings = self.clip.embed_images(clip_inputs)
        embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)

        # ``target_vector`` is 1-D, so a plain matmul yields one score per image.
        scores = embeddings @ self.target_vector
        best = torch.argmax(scores).item()
        return images[best], scores[best], scores
|
|
|
|
|
class EndpointHandler:
    """Callable request handler that serves one generated image per request."""

    def __init__(self, path: str = ""):
        """Build the inference engine; ``path`` is accepted but not used."""
        self.engine = Inference()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate one image and return it as a base64-encoded PNG.

        Args:
            data: Request payload. It is not read by this handler.

        Returns:
            A dict with ``"image"`` (base64 text of the PNG bytes) and
            ``"score"`` (the engine's score for that image, as a float).
        """
        best_image, best_score, _all_scores = self.engine(num_images=1)

        # Map [0, 1] floats onto 0-255 bytes for PIL.
        pixels = best_image.clamp(0, 1).numpy() * 255
        picture = Image.fromarray(pixels.astype("uint8"))

        with io.BytesIO() as stream:
            picture.save(stream, format="PNG")
            encoded = base64.b64encode(stream.getvalue())

        return {"image": encoded.decode("utf-8"), "score": float(best_score)}
|
|
|
|
|
|