detectivejoewest commited on
Commit
b293fbc
·
verified ·
1 Parent(s): f4b5046

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +16 -4
handler.py CHANGED
@@ -1,6 +1,9 @@
1
  from diffusers import AutoencoderKL
 
 
2
  from transformers import CLIPProcessor, CLIPModel
3
  from model import Model
 
4
  from noise_scheduler import NoiseSchedule
5
  import torch
6
  import base64
@@ -35,13 +38,16 @@ class CLIP:
35
 
36
  class Inference:
37
  def __init__(self):
 
 
 
38
  self.clip = CLIP()
39
  self.ae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").to('cuda' if torch.cuda.is_available() else "cpu")
40
  self.ae.eval()
41
  for name, param in self.ae.named_parameters():
42
  param.requires_grad = False
43
  self.unet = Model(T=T, filters=[64,128,256,512], t_dim=t_dim, depth=depth, LDM=LDM)
44
- self.unet.load_state_dict(torch.load("repository/unet.pt", weights_only=False, map_location=torch.device('cpu')))
45
  self.unet.eval()
46
  for name, param in self.unet.named_parameters():
47
  param.requires_grad = False
@@ -65,11 +71,17 @@ class Inference:
65
 
66
  class EndpointHandler:
67
  def __init__(self, path: str = ""):
68
- # path -> repo directory on the endpoint container
69
- # you can read files via Path(path)/"unet.pt" if needed
70
  self.engine = Inference()
71
 
72
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
73
- png_bytes, score = self.engine(num_images=1)
 
 
 
 
 
 
 
74
  b64 = base64.b64encode(png_bytes).decode("utf-8")
75
  return {"image": b64, "score": float(score)}
 
 
1
  from diffusers import AutoencoderKL
2
+ from PIL import Image
3
+ import io
4
  from transformers import CLIPProcessor, CLIPModel
5
  from model import Model
6
+ from pathlib import Path
7
  from noise_scheduler import NoiseSchedule
8
  import torch
9
  import base64
 
38
 
39
  class Inference:
40
  def __init__(self):
41
+ here = Path(__file__).resolve().parent
42
+ ckpt_path = here / "unet.pt"
43
+
44
  self.clip = CLIP()
45
  self.ae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").to('cuda' if torch.cuda.is_available() else "cpu")
46
  self.ae.eval()
47
  for name, param in self.ae.named_parameters():
48
  param.requires_grad = False
49
  self.unet = Model(T=T, filters=[64,128,256,512], t_dim=t_dim, depth=depth, LDM=LDM)
50
+ self.unet.load_state_dict(torch.load(ckpt_path, weights_only=False, map_location=torch.device('cpu')))
51
  self.unet.eval()
52
  for name, param in self.unet.named_parameters():
53
  param.requires_grad = False
 
71
 
72
  class EndpointHandler:
73
  def __init__(self, path: str = ""):
 
 
74
  self.engine = Inference()
75
 
76
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
77
+ img_tensor, score, _ = self.engine(num_images=1) # (H,W,C) in [0,1]
78
+ img_uint8 = (img_tensor.clamp(0,1).numpy() * 255).astype("uint8")
79
+ pil_img = Image.fromarray(img_uint8)
80
+
81
+ buf = io.BytesIO()
82
+ pil_img.save(buf, format="PNG")
83
+ png_bytes = buf.getvalue()
84
+
85
  b64 = base64.b64encode(png_bytes).decode("utf-8")
86
  return {"image": b64, "score": float(score)}
87
+