Update handler.py
Browse files- handler.py +16 -4
handler.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
| 1 |
from diffusers import AutoencoderKL
|
|
|
|
|
|
|
| 2 |
from transformers import CLIPProcessor, CLIPModel
|
| 3 |
from model import Model
|
|
|
|
| 4 |
from noise_scheduler import NoiseSchedule
|
| 5 |
import torch
|
| 6 |
import base64
|
|
@@ -35,13 +38,16 @@ class CLIP:
|
|
| 35 |
|
| 36 |
class Inference:
|
| 37 |
def __init__(self):
|
|
|
|
|
|
|
|
|
|
| 38 |
self.clip = CLIP()
|
| 39 |
self.ae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").to('cuda' if torch.cuda.is_available() else "cpu")
|
| 40 |
self.ae.eval()
|
| 41 |
for name, param in self.ae.named_parameters():
|
| 42 |
param.requires_grad = False
|
| 43 |
self.unet = Model(T=T, filters=[64,128,256,512], t_dim=t_dim, depth=depth, LDM=LDM)
|
| 44 |
-
self.unet.load_state_dict(torch.load(
|
| 45 |
self.unet.eval()
|
| 46 |
for name, param in self.unet.named_parameters():
|
| 47 |
param.requires_grad = False
|
|
@@ -65,11 +71,17 @@ class Inference:
|
|
| 65 |
|
| 66 |
class EndpointHandler:
|
| 67 |
def __init__(self, path: str = ""):
|
| 68 |
-
# path -> repo directory on the endpoint container
|
| 69 |
-
# you can read files via Path(path)/"unet.pt" if needed
|
| 70 |
self.engine = Inference()
|
| 71 |
|
| 72 |
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
b64 = base64.b64encode(png_bytes).decode("utf-8")
|
| 75 |
return {"image": b64, "score": float(score)}
|
|
|
|
|
|
| 1 |
from diffusers import AutoencoderKL
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import io
|
| 4 |
from transformers import CLIPProcessor, CLIPModel
|
| 5 |
from model import Model
|
| 6 |
+
from pathlib import Path
|
| 7 |
from noise_scheduler import NoiseSchedule
|
| 8 |
import torch
|
| 9 |
import base64
|
|
|
|
| 38 |
|
| 39 |
class Inference:
|
| 40 |
def __init__(self):
    """Load the frozen VAE, the CLIP scorer, and the trained U-Net.

    The U-Net checkpoint (``unet.pt``) is resolved relative to this file
    so the handler works regardless of the process working directory.
    """
    here = Path(__file__).resolve().parent
    ckpt_path = here / "unet.pt"

    # Decide the device once instead of re-evaluating the availability
    # check at each call site.
    device = 'cuda' if torch.cuda.is_available() else "cpu"

    self.clip = CLIP()

    # Pretrained SD-1.5 VAE used as a fixed encoder/decoder: eval mode
    # plus frozen parameters — no gradients are needed at inference.
    self.ae = AutoencoderKL.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="vae"
    ).to(device)
    self.ae.eval()
    # Module-level requires_grad_ freezes every parameter; equivalent to
    # looping over named_parameters() and clearing each flag.
    self.ae.requires_grad_(False)

    # Diffusion U-Net. T, t_dim, depth and LDM are module-level
    # hyperparameters defined elsewhere in this file — TODO confirm.
    self.unet = Model(T=T, filters=[64, 128, 256, 512], t_dim=t_dim, depth=depth, LDM=LDM)
    # NOTE(review): weights_only=False lets torch.load execute arbitrary
    # pickled code — acceptable only because unet.pt ships with this repo;
    # never load untrusted checkpoints this way. map_location='cpu' keeps
    # the load device-agnostic.
    self.unet.load_state_dict(
        torch.load(ckpt_path, weights_only=False, map_location=torch.device('cpu'))
    )
    self.unet.eval()
    self.unet.requires_grad_(False)
|
|
|
|
| 71 |
|
| 72 |
class EndpointHandler:
    """Hugging Face Inference Endpoints entry point.

    The endpoints runtime constructs this class once per container and
    then invokes the instance for every request.
    """

    def __init__(self, path: str = ""):
        # `path` is the repo directory on the endpoint container; it is
        # accepted only to satisfy the runtime's constructor contract —
        # the Inference engine resolves its own checkpoint path.
        self.engine = Inference()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate one image and return it as a base64-encoded PNG.

        `data` (the request payload) is currently ignored: every call
        samples exactly one image.

        Returns:
            dict with "image" (base64 PNG string) and "score" (float).
        """
        # Engine returns (image tensor, score, extra); assumes the tensor
        # is (H, W, C) in [0, 1] — TODO confirm against Inference.__call__.
        img_tensor, score, _ = self.engine(num_images=1)

        # detach().cpu() makes the numpy conversion safe when the engine
        # returns a CUDA tensor or one that requires grad (where .numpy()
        # would raise RuntimeError); it is a no-op for a plain CPU tensor,
        # so CPU behavior is unchanged.
        img_uint8 = (img_tensor.detach().cpu().clamp(0, 1).numpy() * 255).astype("uint8")
        pil_img = Image.fromarray(img_uint8)

        # Encode the image as PNG entirely in memory.
        buf = io.BytesIO()
        pil_img.save(buf, format="PNG")
        png_bytes = buf.getvalue()

        b64 = base64.b64encode(png_bytes).decode("utf-8")
        return {"image": b64, "score": float(score)}
|
| 87 |
+
|