detectivejoewest committed on
Commit 6c662fd · verified · 1 Parent(s): aecc64d

Update handler.py

Files changed (1)
  1. handler.py +75 -75
handler.py CHANGED
@@ -1,75 +1,75 @@
- from diffusers import AutoencoderKL
- from transformers import CLIPProcessor, CLIPModel
- from model import Model
- from noise_scheduler import NoiseSchedule
- import torch
- import base64
- from typing import Any, Dict
-
- LDM = True
- image_size = 512
- latent_size = 64
- filters = [64, 128, 256, 512]
- latent_dim = 4
- t_dim = 512
- T = 1000
- depth = 2
-
- class CLIP:
-     def __init__(self):
-         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
-         self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
-         self.model.eval()
-         for name, param in self.model.named_parameters():
-             param.requires_grad = False
-
-     @torch.inference_mode()
-     def embed_images(self, images):
-         image = self.processor(images=images, return_tensors="pt").to(self.model.device)
-         return self.model.get_image_features(**image)
-
-     @torch.inference_mode()
-     def embed_text(self, text):
-         text = self.processor(text, padding=True, return_tensors="pt").to(self.model.device)
-         return self.model.get_text_features(**text)
-
- class Inference:
-     def __init__(self):
-         self.clip = CLIP()
-         self.ae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").to('cuda' if torch.cuda.is_available() else "cpu")
-         self.ae.eval()
-         for name, param in self.ae.named_parameters():
-             param.requires_grad = False
-         self.unet = Model(T=T, filters=[64,128,256,512], t_dim=t_dim, depth=depth, LDM=LDM)
-         self.unet.load_state_dict(torch.load("unet.pt", weights_only=False, map_location=torch.device('cpu')))
-         self.unet.eval()
-         for name, param in self.unet.named_parameters():
-             param.requires_grad = False
-         self.noise_scheduler = NoiseSchedule(T=1000, shape=(4,64,64), ddim_mod=50, trainer_mode=True)
-         self.target_vector = self.clip.embed_text("A photo of a cat")[0]
-         self.target_vector = self.target_vector / self.target_vector.norm(p=2, dim=-1, keepdim=True)
-     @torch.inference_mode()
-     def __call__(self, num_images=8):
-         imgs = self.noise_scheduler.generate(self.unet, num_images=num_images, device='cpu')
-         max_img = None
-         max_score = -1
-         images = []
-         for img in imgs:
-             image = self.ae.decode(img.unsqueeze(0) / self.ae.config.scaling_factor)[0][0].cpu().permute(1,2,0)/2 + 0.5
-             image = torch.clamp(image, 0.0, 1.0)
-             images.append(image)
-         embeddings = self.clip.embed_images(images)
-         scores = (embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)) @ self.target_vector.T
-         i = torch.argmax(scores).item()
-         return images[i], scores[i], scores
-
- class EndpointHandler:
-     def __init__(self, path: str = ""):
-         # path -> repo directory on the endpoint container
-         # you can read files via Path(path)/"unet.pt" if needed
-         self.engine = Inference(prompt="A photo of a cat")
-
-     def __call__(self) -> Dict[str, Any]:
-         png_bytes, score = self.engine(num_images=1)
-         b64 = base64.b64encode(png_bytes).decode("utf-8")
-         return {"image": b64, "score": float(score)}
+ from diffusers import AutoencoderKL
+ from transformers import CLIPProcessor, CLIPModel
+ from model import Model
+ from noise_scheduler import NoiseSchedule
+ import torch
+ import base64
+ from typing import Any, Dict
+
+ LDM = True
+ image_size = 512
+ latent_size = 64
+ filters = [64, 128, 256, 512]
+ latent_dim = 4
+ t_dim = 512
+ T = 1000
+ depth = 2
+
+ class CLIP:
+     def __init__(self):
+         self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+         self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+         self.model.eval()
+         for name, param in self.model.named_parameters():
+             param.requires_grad = False
+
+     @torch.inference_mode()
+     def embed_images(self, images):
+         image = self.processor(images=images, return_tensors="pt").to(self.model.device)
+         return self.model.get_image_features(**image)
+
+     @torch.inference_mode()
+     def embed_text(self, text):
+         text = self.processor(text, padding=True, return_tensors="pt").to(self.model.device)
+         return self.model.get_text_features(**text)
+
+ class Inference:
+     def __init__(self):
+         self.clip = CLIP()
+         self.ae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").to('cuda' if torch.cuda.is_available() else "cpu")
+         self.ae.eval()
+         for name, param in self.ae.named_parameters():
+             param.requires_grad = False
+         self.unet = Model(T=T, filters=[64,128,256,512], t_dim=t_dim, depth=depth, LDM=LDM)
+         self.unet.load_state_dict(torch.load("unet.pt", weights_only=False, map_location=torch.device('cpu')))
+         self.unet.eval()
+         for name, param in self.unet.named_parameters():
+             param.requires_grad = False
+         self.noise_scheduler = NoiseSchedule(T=1000, shape=(4,64,64), ddim_mod=50, trainer_mode=True)
+         self.target_vector = self.clip.embed_text("A photo of a cat")[0]
+         self.target_vector = self.target_vector / self.target_vector.norm(p=2, dim=-1, keepdim=True)
+     @torch.inference_mode()
+     def __call__(self, num_images=8):
+         imgs = self.noise_scheduler.generate(self.unet, num_images=num_images, device='cpu')
+         max_img = None
+         max_score = -1
+         images = []
+         for img in imgs:
+             image = self.ae.decode(img.unsqueeze(0) / self.ae.config.scaling_factor)[0][0].cpu().permute(1,2,0)/2 + 0.5
+             image = torch.clamp(image, 0.0, 1.0)
+             images.append(image)
+         embeddings = self.clip.embed_images(images)
+         scores = (embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)) @ self.target_vector.T
+         i = torch.argmax(scores).item()
+         return images[i], scores[i], scores
+
+ class EndpointHandler:
+     def __init__(self, path: str = ""):
+         # path -> repo directory on the endpoint container
+         # you can read files via Path(path)/"unet.pt" if needed
+         self.engine = Inference()
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         png_bytes, score = self.engine(num_images=1)
+         b64 = base64.b64encode(png_bytes).decode("utf-8")
+         return {"image": b64, "score": float(score)}