Spaces:
Runtime error
Runtime error
| from model import UNetInpaint | |
| from PIL import Image | |
| import torch | |
| import numpy as np | |
| import io | |
class EndpointHandler:
    """Inference handler that fills the masked region of an RGB image.

    The handler wraps a ``UNetInpaint`` model. It is constructed once with the
    directory holding the weights and then called once per request.
    """

    def __init__(self, path=""):
        """Build the model and load its weights.

        Args:
            path: Directory containing ``model.pth``. An empty value (the
                default) resolves to ``model.pth`` in the current working
                directory, which matches the previous behavior.
        """
        import os  # local import: only needed to resolve the weights path

        self.model = UNetInpaint()
        # Honor the directory handed to the handler instead of assuming the
        # process was started from inside the model repository.
        weights_path = os.path.join(path or "", "model.pth")
        self.model.load_state_dict(torch.load(weights_path, map_location="cpu"))
        self.model.eval()

    def __call__(self, data):
        """Run inpainting for one request.

        Args:
            data: Mapping with two entries, ``"image"`` and ``"mask"``, each
                holding the encoded bytes of an image file. The mask is
                converted to grayscale and thresholded at 0.5; pixels above
                the threshold are blanked out of the image before it is fed
                to the model.

        Returns:
            ``{"image": <PNG-encoded bytes of the model output>}``.

        Raises:
            ValueError: If ``"image"`` or ``"mask"`` is missing or empty.
        """
        image_bytes = data.get("image")
        mask_bytes = data.get("mask")
        missing = [
            key
            for key, value in (("image", image_bytes), ("mask", mask_bytes))
            if not value
        ]
        if missing:
            raise ValueError(
                f"missing or empty required input(s): {', '.join(missing)}"
            )

        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        mask = Image.open(io.BytesIO(mask_bytes)).convert("L")
        if mask.size != image.size:
            # The mask is multiplied element-wise with the image below, so the
            # two must share spatial dimensions. Nearest-neighbor keeps the
            # mask's hard edges.
            mask = mask.resize(image.size, Image.NEAREST)

        image_np = np.array(image).astype(np.float32) / 255.0
        mask_np = np.array(mask).astype(np.float32) / 255.0
        mask_np = (mask_np > 0.5).astype(np.float32)
        mask_np = np.expand_dims(mask_np, axis=-1)

        # HWC -> CHW
        image_np = np.transpose(image_np, (2, 0, 1))
        mask_np = np.transpose(mask_np, (2, 0, 1))

        mask_tensor = torch.tensor(mask_np)
        # Zero out the pixels selected by the mask.
        image_tensor = torch.tensor(image_np) * (1 - mask_tensor)
        # Stack masked image and mask along the channel axis, then add a
        # leading batch axis.
        input_tensor = torch.cat([image_tensor, mask_tensor], dim=0).unsqueeze(0)

        with torch.no_grad():
            output = self.model(input_tensor).squeeze(0).numpy().transpose(1, 2, 0)

        output = (np.clip(output, 0, 1) * 255).astype(np.uint8)
        result = Image.fromarray(output)
        buf = io.BytesIO()
        result.save(buf, format="PNG")
        return {"image": buf.getvalue()}