import base64
import io
from pathlib import Path
from typing import Any, Dict

import torch
from diffusers import AutoencoderKL
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

from model import Model
from noise_scheduler import NoiseSchedule

# Model / sampling hyper-parameters. These presumably have to match the values
# the checkpoint in ``unet.pt`` was trained with — confirm before changing.
LDM = True
image_size = 512  # NOTE(review): not referenced by the code in this file
latent_size = 64
filters = [64, 128, 256, 512]
latent_dim = 4
t_dim = 512
T = 1000
depth = 2


def _to_uint8(image):
    """Convert a float image tensor with values in [0, 1] to a uint8 NumPy array.

    Values are clamped to [0, 1], scaled to [0, 255] and rounded to the nearest
    integer (plain truncation would bias every pixel downwards). The layout of
    ``image`` is preserved, e.g. (H, W, C) stays (H, W, C).
    """
    return (image.clamp(0.0, 1.0) * 255).round().to(torch.uint8).cpu().numpy()


class CLIP:
    """Frozen CLIP ViT-B/32 wrapper that embeds images and text."""

    def __init__(self):
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.model.eval()
        # Used for inference only: gradients are never needed.
        self.model.requires_grad_(False)

    @torch.inference_mode()
    def embed_images(self, images):
        """Return unnormalized CLIP image features, one row per input image.

        ``images`` may be anything ``CLIPProcessor`` accepts (PIL images,
        uint8 arrays, ...). With its default settings the processor rescales
        pixel values by 1/255, so pass uint8 data, not floats already in [0, 1].
        """
        inputs = self.processor(images=images, return_tensors="pt").to(self.model.device)
        return self.model.get_image_features(**inputs)

    @torch.inference_mode()
    def embed_text(self, text):
        """Return unnormalized CLIP text features, one row per input string."""
        # Pass ``text`` by keyword: newer ``transformers`` processors take
        # ``images`` as their first positional argument.
        inputs = self.processor(text=text, padding=True, return_tensors="pt").to(self.model.device)
        return self.model.get_text_features(**inputs)


class Inference:
    """Sample latents with the UNet, decode them with the SD VAE, and rank them with CLIP.

    Images are scored by cosine similarity between their CLIP embedding and the
    CLIP embedding of the prompt "A photo of a cat".
    """

    def __init__(self):
        here = Path(__file__).resolve().parent
        ckpt_path = here / "unet.pt"
        device = "cuda" if torch.cuda.is_available() else "cpu"

        self.clip = CLIP()

        self.ae = AutoencoderKL.from_pretrained(
            "runwayml/stable-diffusion-v1-5", subfolder="vae"
        ).to(device)
        self.ae.eval()
        self.ae.requires_grad_(False)

        self.unet = Model(T=T, filters=filters, t_dim=t_dim, depth=depth, LDM=LDM)
        # NOTE(review): weights_only=False unpickles arbitrary objects; only load
        # trusted checkpoints. If unet.pt is a plain state dict, consider
        # weights_only=True.
        state_dict = torch.load(ckpt_path, weights_only=False, map_location=torch.device("cpu"))
        self.unet.load_state_dict(state_dict)
        self.unet.eval()
        self.unet.requires_grad_(False)

        self.noise_scheduler = NoiseSchedule(
            T=T,
            shape=(latent_dim, latent_size, latent_size),
            ddim_mod=50,
            trainer_mode=True,
        )

        # Unit-length text embedding, so a dot product with a unit-length image
        # embedding is their cosine similarity.
        target = self.clip.embed_text("A photo of a cat")[0]
        self.target_vector = target / target.norm(p=2, dim=-1, keepdim=True)

    @torch.inference_mode()
    def __call__(self, num_images=8):
        """Generate ``num_images`` images and rank them against the text target.

        Returns:
            ``(best_image, best_score, scores)``:
            - ``best_image`` is an (H, W, C) float CPU tensor with values in [0, 1];
            - ``best_score`` is its 0-d cosine-similarity score;
            - ``scores`` holds every image's cosine similarity, shape ``(num_images,)``.

        Raises:
            ValueError: if ``num_images`` is less than 1.
        """
        if num_images < 1:
            raise ValueError("num_images must be at least 1")

        # Assumes ``generate`` returns a batch of latents iterable along dim 0,
        # presumably (num_images, latent_dim, latent_size, latent_size) — confirm.
        latents = self.noise_scheduler.generate(self.unet, num_images=num_images, device="cpu")
        images = [self._decode(latent) for latent in latents]

        # CLIP's processor expects uint8 pixels by default; feeding the [0, 1]
        # float images directly would rescale them a second time (by 1/255).
        embeddings = self.clip.embed_images([_to_uint8(image) for image in images])
        embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)
        scores = embeddings @ self.target_vector

        best = torch.argmax(scores).item()
        return images[best], scores[best], scores

    def _decode(self, latent):
        """Decode a single latent into an (H, W, C) float CPU image in [0, 1]."""
        # The UNet samples on the CPU, while the VAE may live on the GPU.
        latent = latent.unsqueeze(0).to(self.ae.device)
        decoded = self.ae.decode(latent / self.ae.config.scaling_factor).sample[0]
        # Map the decoder output (assumed to be in [-1, 1]) to [0, 1], channels last.
        image = decoded.cpu().permute(1, 2, 0) / 2 + 0.5
        return torch.clamp(image, 0.0, 1.0)


class EndpointHandler:
    """Request handler, presumably for Hugging Face Inference Endpoints — confirm."""

    def __init__(self, path: str = ""):
        # ``path`` is accepted for the handler interface but unused: the UNet
        # checkpoint is located relative to this file by ``Inference``.
        self.engine = Inference()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate one image and return it as a base64 PNG plus its CLIP score.

        The request payload ``data`` is currently ignored.
        """
        img_tensor, score, _ = self.engine(num_images=1)  # (H, W, C) in [0, 1]
        pil_img = Image.fromarray(_to_uint8(img_tensor))
        buf = io.BytesIO()
        pil_img.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
        return {"image": b64, "score": float(score)}