# SVS_oe3 / inference.py
# shivamkunkolikar
# May 5, 8:51 PM
# 4f75b6d
import torch
import numpy as np
import cv2
import io
from model.model import UNetInpaint
# Load model
def _load_model():
    """Build UNetInpaint, load its checkpoint mapped onto the CPU, and return it in eval mode.

    The checkpoint path is relative, so it is resolved against the current
    working directory (NOTE(review): presumably the repository root when the
    service runs — confirm).
    """
    network = UNetInpaint()
    cpu = torch.device("cpu")
    # NOTE(review): torch.load unpickles the file. That is acceptable only because
    # the checkpoint ships with the repo; consider weights_only=True (torch>=1.13).
    state_dict = torch.load("model/pytorch_model.bin", map_location=cpu)
    network.load_state_dict(state_dict)
    network.eval()
    return network


# Created once at import time and reused by every predict() call.
model = _load_model()
def bytes_to_cv2_image(image_bytes, mode='color'):
    """Decode an in-memory encoded image (PNG, JPEG, ...) into a NumPy array.

    Args:
        image_bytes: Raw bytes of an encoded image file.
        mode: ``'color'`` to decode with ``cv2.IMREAD_COLOR`` (3 channels, in
            OpenCV's BGR order) or ``'gray'`` to decode with
            ``cv2.IMREAD_GRAYSCALE`` (single-channel 2-D array).

    Returns:
        The decoded ``uint8`` image array.

    Raises:
        ValueError: if ``mode`` is not ``'color'`` or ``'gray'``, if
            ``image_bytes`` is empty, or if OpenCV cannot decode the data.
    """
    # Validate before touching OpenCV. Previously an unknown mode fell through
    # both branches and crashed with UnboundLocalError.
    if mode not in ('color', 'gray'):
        raise ValueError(f"mode must be 'color' or 'gray', got {mode!r}")
    np_arr = np.frombuffer(image_bytes, np.uint8)
    # cv2.imdecode asserts on an empty buffer, so give a clearer error instead.
    if np_arr.size == 0:
        raise ValueError("image_bytes is empty; nothing to decode")
    flag = cv2.IMREAD_COLOR if mode == 'color' else cv2.IMREAD_GRAYSCALE
    image = cv2.imdecode(np_arr, flag)
    # imdecode signals failure by returning None rather than raising. Surface it
    # here instead of letting cv2.resize fail later with an opaque error.
    if image is None:
        raise ValueError(f"could not decode {np_arr.size} bytes as an image")
    return image
def preprocess(image_bytes_rgb, image_bytes_gray):
    """Turn an encoded color image plus an encoded grayscale image into model input.

    Both images are decoded, resized to 256x256 and rescaled from ``uint8``
    ``[0, 255]`` to ``float32`` ``[0, 1]``. The grayscale image is then appended
    as a fourth channel after the three color channels.

    NOTE(review): OpenCV decodes color images in BGR order, so despite the
    parameter name the first three channels are B, G, R. Presumably this matches
    how the model was trained — confirm.

    Args:
        image_bytes_rgb: Encoded bytes of the color image.
        image_bytes_gray: Encoded bytes of the grayscale image (presumably the
            inpainting mask — verify against the caller).

    Returns:
        A ``float32`` tensor of shape ``(1, 4, 256, 256)`` (batch, channel, H, W).
    """
    target_size = (256, 256)
    color_img = bytes_to_cv2_image(image_bytes_rgb, mode='color')
    gray_img = bytes_to_cv2_image(image_bytes_gray, mode='gray')
    resized_color = cv2.resize(color_img, target_size)
    resized_gray = cv2.resize(gray_img, target_size)
    # Normalize to [0, 1]. The gray image also gets an explicit trailing channel
    # axis so it can be stacked with the color channels.
    color_chan = resized_color.astype(np.float32) / 255.0
    gray_chan = (resized_gray.astype(np.float32) / 255.0)[:, :, np.newaxis]
    hwc = np.concatenate([color_chan, gray_chan], axis=2)  # (256, 256, 4)
    # HWC -> CHW, then prepend a batch axis of size 1.
    return torch.from_numpy(hwc).permute(2, 0, 1).unsqueeze(0)
def postprocess(output_tensor):
    """Convert the model's output tensor into PNG-encoded bytes.

    Args:
        output_tensor: Model output, assumed channels-first with a leading
            batch axis of 1, i.e. ``(1, C, H, W)`` (TODO confirm against the
            model). Values are expected in ``[0, 1]``; anything outside that
            range is clipped.

    Returns:
        The image encoded as PNG, as ``bytes``.

    Raises:
        ValueError: if OpenCV reports that PNG encoding failed.
    """
    # Drop only the batch axis. A bare squeeze() would also remove a size-1
    # channel axis (single-channel output) and break the permute below.
    chw = output_tensor.detach().cpu().squeeze(0)
    hwc = chw.permute(1, 2, 0).numpy()
    # Clip before the uint8 cast. Casting out-of-range floats to uint8 is not
    # saturating in NumPy; values such as 1.01 or -0.01 would wrap or become
    # platform-dependent garbage.
    pixels = np.clip(hwc * 255, 0, 255).astype(np.uint8)
    # permute() yields a non C-contiguous view; hand OpenCV a packed HWC buffer.
    ok, img_encoded = cv2.imencode('.png', np.ascontiguousarray(pixels))
    if not ok:
        raise ValueError("cv2.imencode failed to encode the output as PNG")
    return img_encoded.tobytes()
# Required function for Hugging Face inference API
def predict(payload: dict):
    """Run the inpainting model on a hex-encoded request payload.

    Args:
        payload: Mapping whose ``"rgb"`` and ``"gray"`` entries are hex strings
            of the encoded color and grayscale images.

    Returns:
        ``{"image": <hex string of the PNG-encoded output image>}``.

    Raises:
        KeyError: if ``"rgb"`` or ``"gray"`` is missing from ``payload``.
        ValueError: if either entry is not a valid hex string. Errors raised by
            preprocess/postprocess propagate unchanged.
    """
    model_input = preprocess(
        bytes.fromhex(payload["rgb"]),
        bytes.fromhex(payload["gray"]),
    )
    # Inference only: no gradients are needed.
    with torch.no_grad():
        model_output = model(model_input)
    png_bytes = postprocess(model_output)
    return {"image": png_bytes.hex()}