# object-eraser-model / handler.py
# shivamkunkolikar
# gradio update
# f647f94
import io
import os

import numpy as np
import torch
from PIL import Image

from model import UNetInpaint
class EndpointHandler:
    """Inference handler that erases masked regions of an image with ``UNetInpaint``.

    Called with a mapping holding encoded ``"image"`` and ``"mask"`` bytes, it
    returns ``{"image": <PNG bytes>}`` containing the model's output.
    """

    def __init__(self, path=""):
        """Build the model and load its weights from ``model.pth``.

        Args:
            path: Directory containing ``model.pth``. The default ``""``
                resolves the file against the current working directory,
                exactly as before this argument was honoured.
        """
        weights_path = os.path.join(path, "model.pth")
        self.model = UNetInpaint()
        # weights_only=True limits unpickling to tensors and plain containers,
        # which is all a state_dict needs (requires torch >= 1.13).
        state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def __call__(self, data):
        """Inpaint the masked region of an image.

        Args:
            data: Mapping with ``"image"`` and ``"mask"`` entries holding
                encoded image bytes (any format PIL can open). Mask pixels
                brighter than half intensity mark the region to erase.

        Returns:
            ``{"image": bytes}``: the model output encoded as PNG.

        Raises:
            ValueError: if ``"image"`` or ``"mask"`` is missing from ``data``.
        """
        image = self._decode(data, "image", "RGB")
        mask = self._decode(data, "mask", "L")
        if mask.size != image.size:
            # The mask must align pixel-for-pixel with the image; otherwise the
            # channel concatenation below fails. Nearest-neighbour resampling
            # avoids inventing intermediate grey levels before thresholding.
            mask = mask.resize(image.size, Image.NEAREST)

        input_tensor = self._build_input(image, mask)
        with torch.no_grad():
            output = self.model(input_tensor)
        return {"image": self._encode_png(output)}

    @staticmethod
    def _decode(data, key, mode):
        """Open the encoded image stored under ``data[key]`` and convert it to ``mode``."""
        raw = data.get(key)
        if raw is None:
            # io.BytesIO(None) silently yields an empty buffer, which PIL then
            # reports as an unidentifiable image; fail with a clearer message.
            raise ValueError(f"request is missing required field {key!r}")
        return Image.open(io.BytesIO(raw)).convert(mode)

    @staticmethod
    def _build_input(image, mask):
        """Return a (1, 4, H, W) float32 tensor: masked RGB channels followed by the binary mask."""
        rgb = np.asarray(image, dtype=np.float32) / 255.0   # H x W x 3, values in [0, 1]
        hole = np.asarray(mask, dtype=np.float32) / 255.0 > 0.5
        hole = hole.astype(np.float32)[..., None]           # H x W x 1, 1.0 = erase
        # Zero out the pixels to be erased, then append the mask as a 4th channel.
        stacked = np.concatenate([rgb * (1.0 - hole), hole], axis=-1)
        chw = np.ascontiguousarray(stacked.transpose(2, 0, 1))
        return torch.from_numpy(chw).unsqueeze(0)

    @staticmethod
    def _encode_png(output):
        """Convert a batch-of-one model output into PNG bytes."""
        # NOTE(review): assumes output is (1, 3, H, W) with values roughly in
        # [0, 1] -- confirm against UNetInpaint; values are clipped to [0, 1].
        hwc = output.squeeze(0).cpu().numpy().transpose(1, 2, 0)
        pixels = (np.clip(hwc, 0, 1) * 255).astype(np.uint8)
        buf = io.BytesIO()
        Image.fromarray(pixels).save(buf, format="PNG")
        return buf.getvalue()