dummy / handler.py
Serkan Ozturk
gsfglşsdkif1
65c97a3
from typing import Dict, List, Any
import base64
import io
import torch
import numpy as np
import torch.nn.functional as F
from serkan import SimpleUpscaleModel
import os
from PIL import Image
def decode_image(base64_str: str) -> np.ndarray:
    """Convert a base64-encoded image payload into a numpy array.

    The string is base64-decoded, the resulting bytes are wrapped in an
    in-memory buffer and opened with Pillow, and the opened image is handed
    to ``np.array``.

    Args:
        base64_str: Base64 text of an encoded image file.

    Returns:
        The decoded image as produced by ``np.array`` on the Pillow image.
    """
    stream = io.BytesIO(base64.b64decode(base64_str))
    return np.array(Image.open(stream))
class EndpointHandler():
    """Inference endpoint that runs ``SimpleUpscaleModel`` on a base64 image.

    The handler is constructed once with the directory holding the model
    weights and is then called once per request.
    """

    def __init__(self, path="."):
        """Build the model and load its weights.

        Args:
            path: Directory containing ``model_weights.pth``.
        """
        self.model = SimpleUpscaleModel()
        model_path = os.path.join(path, "model_weights.pth")
        # The model is never moved off the CPU, so map the checkpoint there;
        # otherwise weights saved from a GPU fail to load on a CPU-only host.
        state_dict = torch.load(model_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        # Inference only: put layers such as dropout / batch-norm in eval mode.
        self.model.eval()

    def __call__(self, data: Any) -> Dict[str, str]:
        """Run the model on the request image and return the result as PNG.

        Args:
            data: Request payload. If it has an ``"inputs"`` key, that entry
                is removed from ``data`` and used as the inputs; otherwise
                ``data`` itself is used. The inputs must provide an
                ``"image"`` entry holding a base64-encoded image.

        Returns:
            A dict ``{"image": <str>}`` whose value is the base64 text of the
            PNG-encoded model output.

        Note:
            The decoded image is permuted as a 3-D (H, W, C) array, so an
            image that decodes without a channel axis is rejected by
            ``permute``. The channel count the model accepts is not visible
            here -- TODO confirm against ``SimpleUpscaleModel``.
        """
        inputs = data.pop("inputs", data)
        img = decode_image(inputs["image"])
        # (H, W, C) array -> (1, C, H, W) float tensor.
        img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float()

        # Without no_grad the output tracks gradients and ``.numpy()`` below
        # raises RuntimeError; it also avoids building an autograd graph.
        with torch.no_grad():
            upscaled = self.model(img)

        # (1, C, H, W) tensor -> (H, W, C) array with values in 0..255.
        upscaled = upscaled.squeeze(0).permute(1, 2, 0).numpy()
        upscaled = np.clip(upscaled, 0, 255).astype(np.uint8)

        # Encode the result as PNG in memory and return it as base64 text.
        buffered = io.BytesIO()
        Image.fromarray(upscaled).save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {
            "image": img_str
        }