File size: 1,369 Bytes
f647f94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import io
import os

import numpy as np
import torch
from PIL import Image

from model import UNetInpaint

class EndpointHandler:
    """Inference handler for UNet-based image inpainting.

    Loads a ``UNetInpaint`` checkpoint on CPU and exposes ``__call__``,
    which takes raw image/mask bytes and returns the inpainted image as
    PNG bytes.
    """

    def __init__(self, path=""):
        """Load model weights.

        Args:
            path: Directory containing ``model.pth``. Defaults to ``""``
                (i.e. the current working directory), preserving the old
                behavior. NOTE(review): the original ignored ``path``
                entirely and always read ``"model.pth"`` from the CWD,
                which breaks when the runtime provides a model directory
                (the usual custom-handler convention) — fixed by joining.
        """
        self.model = UNetInpaint()
        weights_file = os.path.join(path, "model.pth")
        # map_location="cpu" so checkpoints saved on GPU still load here.
        self.model.load_state_dict(torch.load(weights_file, map_location="cpu"))
        self.model.eval()

    def __call__(self, data):
        """Run inpainting on one image/mask pair.

        Args:
            data: Mapping with keys ``"image"`` (RGB image bytes) and
                ``"mask"`` (grayscale mask bytes; pixels > 127 are treated
                as the region to inpaint).

        Returns:
            dict: ``{"image": <PNG-encoded inpainted image bytes>}``.

        Raises:
            ValueError: If either required key is missing. (The original
            failed with an opaque ``TypeError`` inside ``io.BytesIO``.)
        """
        image_bytes = data.get("image")
        mask_bytes = data.get("mask")
        if image_bytes is None or mask_bytes is None:
            raise ValueError("request must contain both 'image' and 'mask' bytes")

        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        mask = Image.open(io.BytesIO(mask_bytes)).convert("L")

        image_np = np.asarray(image, dtype=np.float32) / 255.0
        mask_np = np.asarray(mask, dtype=np.float32) / 255.0
        # Binarize the mask: 1.0 marks pixels to be inpainted.
        mask_np = (mask_np > 0.5).astype(np.float32)

        # HWC -> CHW for the image; add a channel axis to the mask (1, H, W).
        image_np = image_np.transpose(2, 0, 1)
        mask_np = mask_np[np.newaxis, ...]

        # Build the mask tensor once (the original created it twice).
        mask_t = torch.from_numpy(mask_np)
        # Zero out the masked region so the network sees holes, then stack
        # the mask as a fourth input channel: (1, 4, H, W).
        image_t = torch.from_numpy(image_np) * (1.0 - mask_t)
        input_tensor = torch.cat([image_t, mask_t], dim=0).unsqueeze(0)

        with torch.no_grad():
            output = self.model(input_tensor).squeeze(0).numpy().transpose(1, 2, 0)

        # Postprocessing needs no autograd context: clamp to [0, 1], scale
        # to 8-bit, and encode as PNG.
        output = (np.clip(output, 0.0, 1.0) * 255.0).astype(np.uint8)
        result = Image.fromarray(output)

        buf = io.BytesIO()
        result.save(buf, format="PNG")
        return {"image": buf.getvalue()}