File size: 3,306 Bytes
da2b98d
 
5ddae77
33241cf
5ddae77
33241cf
c2feb3e
b1f0e98
 
33241cf
b1f0e98
 
5ddae77
c2feb3e
b1f0e98
 
5ddae77
da2b98d
b1f0e98
5ddae77
b1f0e98
 
 
 
5ddae77
da2b98d
b1f0e98
d481329
b1f0e98
 
 
da2b98d
 
b1f0e98
da2b98d
 
 
b1f0e98
da2b98d
b1f0e98
da2b98d
 
 
 
 
 
 
 
 
33241cf
b1f0e98
da2b98d
33241cf
b1f0e98
d481329
33241cf
 
 
 
 
da2b98d
 
b1f0e98
da2b98d
 
b1f0e98
da2b98d
b1f0e98
da2b98d
be5bf87
5ddae77
c2feb3e
 
be5bf87
 
c2feb3e
 
 
 
 
 
 
 
b1f0e98
da2b98d
33241cf
c2feb3e
da2b98d
b1f0e98
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""ResNet inference service implementation."""

import base64
import os
from io import BytesIO

import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, ResNetForImageClassification  # type: ignore[import-untyped]

from app.core.logging import logger
from app.services.base import InferenceService
from app.api.models import BinaryMask, ImageRequest, Labels, PredictionResponse


class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
    """ResNet-18 inference service for image classification.

    The model and its image processor are read from a local directory
    (``models/<model_name>``); nothing is downloaded at load time because
    both are loaded with ``local_files_only=True``.
    """

    def __init__(self, model_name: str = "microsoft/resnet-18") -> None:
        """Record where the model lives; the weights are not loaded here.

        Args:
            model_name: Model identifier; also the sub-path under ``models/``
                from which the model files are loaded.
        """
        self.model_name = model_name
        self.model = None
        self.processor = None
        self._is_loaded = False
        self.model_path = os.path.join("models", model_name)
        logger.info(f"Initializing ResNet service: {self.model_path}")

    def load_model(self) -> None:
        """Load the image processor and the model from ``self.model_path``.

        Does nothing if the model has already been loaded.

        Raises:
            FileNotFoundError: If the model directory or its ``config.json``
                does not exist.
        """
        if self._is_loaded:
            return

        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model not found: {self.model_path}")

        config_path = os.path.join(self.model_path, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config not found: {config_path}")

        logger.info(f"Loading model from {self.model_path}")

        import warnings
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning)
            # Load into locals first so a failure in the second call cannot
            # leave the service with a processor but no model.
            processor = AutoImageProcessor.from_pretrained(
                self.model_path, local_files_only=True
            )
            model = ResNetForImageClassification.from_pretrained(
                self.model_path, local_files_only=True
            )

        # Publish everything together only after both loads succeeded.
        self.processor = processor
        self.model = model
        self._is_loaded = True
        logger.info(f"Model loaded: {len(model.config.id2label)} classes")  # pyright: ignore

    def predict(self, request: ImageRequest) -> PredictionResponse:
        """Classify the image in ``request`` and return a dummy localization mask.

        Args:
            request: Request whose ``image.data`` holds the base64-encoded
                image bytes.

        Returns:
            A ``PredictionResponse`` with log-probabilities over the first
            ``len(Labels)`` model outputs and a rectangular placeholder mask.

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        # Explicit check rather than ``assert``: assertions are stripped when
        # Python runs with -O, and this also narrows the Optional attributes.
        if not self.is_loaded or self.processor is None or self.model is None:
            raise RuntimeError("model is not loaded")

        image_data = base64.b64decode(request.image.data)
        image = Image.open(BytesIO(image_data))

        if image.mode != 'RGB':
            image = image.convert('RGB')

        inputs = self.processor(image, return_tensors="pt")

        with torch.no_grad():
            # A single image is passed in, so take the first (only) item of
            # the batch instead of squeezing every size-1 dimension.
            logits = self.model(**inputs).logits[0]   # pyright: ignore

        # Convert the ImageNet output vector of dimension 1000 to the expected
        # output format. This is for demonstration purposes
        # and obviously will not perform well on the actual task.
        # ``dim`` is given explicitly: relying on the implicit dimension is
        # deprecated in PyTorch and emits a warning on every call.
        logprobs = torch.nn.functional.log_softmax(
            logits[:len(Labels)], dim=-1
        ).tolist()

        # Dummy localization mask: a rectangle approximately in the middle
        # (the central third of the image along each axis).
        x = image.width // 3
        y = image.height // 3
        # Row-major order: first axis is rows (height), second is columns (width).
        mask = np.zeros((image.height, image.width), dtype=np.uint8)
        mask[y:(2*y), x:(2*x)] = 1
        mask_obj = BinaryMask.from_numpy(mask)

        return PredictionResponse(
            logprobs=logprobs,
            localizationMask=mask_obj,
        )

    @property
    def is_loaded(self) -> bool:
        """Whether ``load_model`` has completed successfully."""
        return self._is_loaded