File size: 1,805 Bytes
ecc7af3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import base64
import io
import json
from typing import Any

import torch
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms

from model import Generator


class EndpointHandler:
    """Inference Endpoint handler for the sat2map generator.

    On construction it downloads ``generator.pth`` and ``config.json`` from a
    Hugging Face Hub repository, builds ``Generator`` from the config and loads
    the weights. Each call takes a base64-encoded image under ``inputs`` and
    returns the generated image as a base64-encoded PNG under ``output``.
    """

    def __init__(self, path: str = '', repo_id: str = "Kiwinicki/sat2map-generator"):
        """Download model artifacts and prepare the generator for inference.

        Args:
            path: Model directory passed in by the endpoint runtime. Currently
                unused; artifacts are always fetched from ``repo_id``.
            repo_id: Hub repository containing ``generator.pth`` and
                ``config.json``.
        """
        # PIL image -> float tensor (C, H, W) scaled to [0, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

        generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")

        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        cfg = OmegaConf.create(config_dict)

        self.generator = Generator(cfg)
        # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only host.
        # NOTE(review): torch.load unpickles a file fetched from a remote repo;
        # consider weights_only=True if the deployed torch version supports it.
        state_dict = torch.load(generator_path, map_location="cpu")
        self.generator.load_state_dict(state_dict)
        self.generator.eval()

    def __call__(self, data: dict[str, Any]) -> dict[str, str]:
        """Translate one base64-encoded image with the generator.

        Args:
            data: Request payload; ``data["inputs"]`` must hold a base64 string
                (optionally a ``data:<mime>;base64,`` URI) of an image file.

        Returns:
            ``{"output": <base64-encoded PNG of the generated image>}``.

        Raises:
            ValueError: If ``inputs`` is missing or empty.
        """
        base64_image = data.get('inputs')
        if not base64_image:
            raise ValueError(
                "Request payload must contain a non-empty 'inputs' field "
                "holding a base64-encoded image."
            )
        image = self._decode_base64_image(base64_image)
        output = self._generate(image)
        return {"output": self._encode_base64_png(output)}

    def _generate(self, image: torch.Tensor) -> torch.Tensor:
        """Run the generator on a batched image tensor without tracking gradients."""
        # Serving never backpropagates, so skip building the autograd graph.
        with torch.no_grad():
            return self.generator(image)

    def _encode_base64_png(self, output: torch.Tensor) -> str:
        """Convert a batched generator output into a base64-encoded RGB PNG string."""
        # Drop the batch dimension added in _decode_base64_image.
        # NOTE(review): assumes the generator targets [0, 1] like its ToTensor
        # input — confirm. ToPILImage multiplies float tensors by 255 and casts to
        # uint8, so out-of-range values would wrap; clamp before converting.
        output = output.squeeze(0).clamp(0.0, 1.0)
        output_image = transforms.ToPILImage()(output).convert('RGB')
        buffer = io.BytesIO()
        output_image.save(buffer, format="png")
        return base64.b64encode(buffer.getvalue()).decode('utf-8')

    def _decode_base64_image(self, base64_image: str) -> torch.Tensor:
        """Decode a base64 image into a float tensor with a leading batch dim of 1.

        A ``data:<mime>;base64,`` prefix, if present, is stripped first.
        """
        if isinstance(base64_image, str) and base64_image.startswith("data:"):
            # Characters of the URI header would otherwise be decoded as payload.
            base64_image = base64_image.partition(",")[2]
        image_bytes = base64.b64decode(base64_image)
        with Image.open(io.BytesIO(image_bytes)) as img:
            image = img.convert('RGB')
        image_tensor: torch.Tensor = self.transform(image)
        return image_tensor.unsqueeze(0)