| import torch |
| import numpy as np |
| import cv2 |
| import io |
| from model.model import UNetInpaint |
|
|
| |
# Build the inpainting network once at import time and restore its trained
# weights. map_location forces every stored tensor onto the CPU, so the
# checkpoint loads on machines without a GPU.
model = UNetInpaint()
_state_dict = torch.load(
    "model/pytorch_model.bin",
    map_location=torch.device("cpu"),
)
model.load_state_dict(_state_dict)
# Inference only: switch the module to evaluation mode.
model.eval()
|
|
def bytes_to_cv2_image(image_bytes, mode='color'):
    """Decode an in-memory encoded image into an OpenCV array.

    Args:
        image_bytes: Bytes-like object holding an encoded image file.
        mode: ``'color'`` decodes with ``cv2.IMREAD_COLOR``; ``'gray'``
            decodes with ``cv2.IMREAD_GRAYSCALE``.

    Returns:
        Whatever ``cv2.imdecode`` produces for the buffer. OpenCV reports
        undecodable data by returning ``None`` instead of raising, so
        callers must check for that.

    Raises:
        ValueError: If ``mode`` is neither ``'color'`` nor ``'gray'``.
    """
    # Validate first: an unknown mode used to fall through both branches
    # and fail with UnboundLocalError at the return statement.
    if mode not in ('color', 'gray'):
        raise ValueError(f"mode must be 'color' or 'gray', got {mode!r}")

    flag = cv2.IMREAD_COLOR if mode == 'color' else cv2.IMREAD_GRAYSCALE
    np_arr = np.frombuffer(image_bytes, np.uint8)
    return cv2.imdecode(np_arr, flag)
|
|
def preprocess(image_bytes_rgb, image_bytes_gray):
    """Build the network input tensor from two encoded images.

    Both images are decoded, resized to 256x256, scaled from 0..255 to
    0..1 as float32, stacked along the channel axis (colour channels first,
    grayscale last) and rearranged to a batched channels-first tensor.

    Args:
        image_bytes_rgb: Encoded bytes of the colour image.
        image_bytes_gray: Encoded bytes of the single-channel image.

    Returns:
        A float32 ``torch.Tensor`` laid out as (1, channels, 256, 256).

    Raises:
        ValueError: If either buffer cannot be decoded as an image.
    """
    rgb = bytes_to_cv2_image(image_bytes_rgb, mode='color')
    gray = bytes_to_cv2_image(image_bytes_gray, mode='gray')

    # cv2.imdecode signals failure by returning None. Fail here with a
    # clear message instead of an opaque error from cv2.resize below.
    if rgb is None:
        raise ValueError("could not decode the 'rgb' image bytes")
    if gray is None:
        raise ValueError("could not decode the 'gray' image bytes")

    # NOTE(review): cv2.IMREAD_COLOR yields BGR channel order despite the
    # "rgb" naming; presumably the model was trained the same way - confirm.
    rgb = cv2.resize(rgb, (256, 256)).astype(np.float32) / 255.0
    gray = cv2.resize(gray, (256, 256)).astype(np.float32) / 255.0

    # (H, W) -> (H, W, 1) so it can be appended as an extra channel.
    gray = np.expand_dims(gray, axis=2)
    input_np = np.concatenate((rgb, gray), axis=2)

    # HWC -> CHW, then add the leading batch dimension.
    return torch.from_numpy(input_np).permute(2, 0, 1).unsqueeze(0)
|
|
def postprocess(output_tensor):
    """Convert a network output tensor into PNG-encoded bytes.

    Args:
        output_tensor: Channels-first image tensor, either batched with a
            single item (1, C, H, W) or unbatched (C, H, W). Values are
            interpreted on a 0..1 scale; anything outside is clamped.

    Returns:
        The PNG file contents as ``bytes``.

    Raises:
        RuntimeError: If OpenCV reports that PNG encoding failed.
    """
    output = output_tensor.detach().cpu()

    # Drop only the batch axis. A bare squeeze() would also remove a
    # channel or spatial axis of size 1 and break the permute below.
    if output.dim() == 4:
        output = output.squeeze(0)

    # CHW -> HWC, the layout OpenCV expects.
    output = output.permute(1, 2, 0).numpy()

    # Clamp before the uint8 cast: without it, values slightly outside
    # 0..1 wrap around (e.g. 1.01 * 255 -> 1) and show up as speckles.
    output = (np.clip(output, 0.0, 1.0) * 255).astype(np.uint8)

    ok, img_encoded = cv2.imencode('.png', output)
    if not ok:
        raise RuntimeError("cv2.imencode failed to encode the output as PNG")
    return img_encoded.tobytes()
|
|
| |
def predict(payload: dict):
    """Run the model on a hex-encoded request and return a hex-encoded image.

    Args:
        payload: Mapping with ``"rgb"`` and ``"gray"`` entries, each a hex
            string of encoded image bytes.

    Returns:
        A dict with a single ``"image"`` entry holding the hex string of
        the bytes produced by ``postprocess``.
    """
    # Decode the two hex fields in the same order as they are named.
    rgb_bytes, gray_bytes = (
        bytes.fromhex(payload[key]) for key in ("rgb", "gray")
    )

    batch = preprocess(rgb_bytes, gray_bytes)

    # Gradients are never needed for inference.
    with torch.no_grad():
        prediction = model(batch)

    encoded = postprocess(prediction)
    return {"image": encoded.hex()}
|
|