import io
import os

import numpy as np
import torch
from PIL import Image

from model import UNetInpaint


class EndpointHandler:
    """Inference-endpoint wrapper around a ``UNetInpaint`` inpainting model.

    Each request supplies an encoded image and an encoded mask. Pixels where
    the mask is brighter than 50% are blanked out, the blanked image plus the
    binary mask are fed to the model, and the model output is returned as PNG
    bytes.
    """

    def __init__(self, path=""):
        """Build the model and load its weights.

        Args:
            path: Directory containing ``model.pth``. With the default ``""``
                the file is resolved relative to the current working
                directory.
        """
        self.model = UNetInpaint()
        weights_path = os.path.join(path, "model.pth")
        # NOTE(review): torch.load unpickles the checkpoint, which can execute
        # arbitrary code; only load trusted files. Consider weights_only=True
        # if the deployed torch version supports it.
        state_dict = torch.load(weights_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def __call__(self, data):
        """Inpaint the masked region of an image.

        Args:
            data: Mapping with ``"image"`` and ``"mask"`` keys, each holding
                encoded image bytes that PIL can open. The mask must have the
                same width and height as the image.

        Returns:
            ``{"image": bytes}`` holding the model output encoded as PNG.

        Raises:
            ValueError: If either field is missing or the mask size differs
                from the image size. Undecodable bytes still raise PIL's own
                error.
        """
        image = self._open_image(data, "image", "RGB")
        mask = self._open_image(data, "mask", "L")
        if image.size != mask.size:
            raise ValueError(
                f"mask size {mask.size} does not match image size {image.size}"
            )

        input_tensor = self._build_input(image, mask)
        with torch.no_grad():
            output = self.model(input_tensor)
        return {"image": self._encode_png(output)}

    @staticmethod
    def _open_image(data, key, mode):
        """Decode ``data[key]`` as an image and convert it to PIL ``mode``."""
        payload = data.get(key)
        if payload is None:
            raise ValueError(f"request is missing required field {key!r}")
        return Image.open(io.BytesIO(payload)).convert(mode)

    @staticmethod
    def _build_input(image, mask):
        """Return the (1, 4, H, W) model input: masked RGB stacked with the mask."""
        rgb = torch.from_numpy(np.asarray(image, dtype=np.float32) / 255.0)
        rgb = rgb.permute(2, 0, 1)  # (H, W, 3) -> (3, H, W)
        # Binarise the mask: pixels brighter than 50% mark the hole to fill.
        hole_np = (np.asarray(mask, dtype=np.float32) / 255.0 > 0.5).astype(np.float32)
        hole = torch.from_numpy(hole_np).unsqueeze(0)  # (H, W) -> (1, H, W)
        # Zero out the hole so the model never sees the original pixels there.
        masked_rgb = rgb * (1 - hole)
        return torch.cat([masked_rgb, hole], dim=0).unsqueeze(0)

    @staticmethod
    def _encode_png(output):
        """Convert a batch-of-one (1, C, H, W) model output to PNG bytes.

        Values are expected in [0, 1]; anything outside is clipped.
        """
        pixels = output.squeeze(0).cpu().numpy().transpose(1, 2, 0)
        # Round (not truncate) so each value maps to the nearest 8-bit level.
        pixels = np.rint(np.clip(pixels, 0, 1) * 255).astype(np.uint8)
        buf = io.BytesIO()
        Image.fromarray(pixels).save(buf, format="PNG")
        return buf.getvalue()