"""ResNet inference service implementation."""

import base64
import os
import random
from io import BytesIO

import torch
from PIL import Image
from transformers import AutoImageProcessor, ResNetForImageClassification  # type: ignore[import-untyped]

from app.core.logging import logger
from app.services.base import InferenceService
from app.api.models import ImageRequest, Labels, LocalizationMask, PredictionResponse


class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
    """ResNet-18 inference service for image classification."""

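    def __init__(self, model_name: str = "microsoft/resnet-18"):
        """Configure the service; weights are read lazily by ``load_model``.

        The model directory is ``models/<model_name>``, e.g. ``models/microsoft/resnet-18``.
        """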
        self.model_name = model_name
        self.model = None
        self.processor = None
        self._is_loaded = False
        self.model_path = os.path.join("models", model_name)
        logger.info(f"Initializing ResNet service: {self.model_path}")

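    def load_model(self) -> None:
        """Load the image processor and model weights from the local model directory.

        Idempotent: returns immediately if the model is already loaded. Only local
        files are used (``local_files_only=True``); nothing is downloaded.
        """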
        if self._is_loaded:
            return

        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model not found: {self.model_path}")

        config_path = os.path.join(self.model_path, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config not found: {config_path}")

        logger.info(f"Loading model from {self.model_path}")

        import warnings
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning)
            self.processor = AutoImageProcessor.from_pretrained(
                self.model_path, local_files_only=True
            )
            self.model = ResNetForImageClassification.from_pretrained(
                self.model_path, local_files_only=True
            )
            assert self.model is not None

        self._is_loaded = True
        logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")  # pyright: ignore

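    def predict(self, request: ImageRequest) -> PredictionResponse:
        """Classify a base64-encoded image.

        Returns log-probabilities over the first ``len(Labels)`` model classes and a
        localization mask with the same width and height as the input image.
        """
        if not self.is_loaded:
            raise RuntimeError("Model is not loaded; call load_model() first")
        assert self.processor is not None
        assert self.model is not None

        image_data = base64.b64decode(request.image.data)
        image = Image.open(BytesIO(image_data))
        width, height = image.size

        if image.mode != "RGB":
            image = image.convert("RGB")

        inputs = self.processor(image, return_tensors="pt")

        with torch.no_grad():
            logits = self.model(**inputs).logits  # pyright: ignore

        # logits has shape (batch=1, num_classes). Take the single batch row, keep the
        # first len(Labels) classes, and renormalize the log-probabilities over that subset.
        class_logits = logits[0, : len(Labels)]
        logprobs = torch.nn.functional.log_softmax(class_logits, dim=-1).tolist()

        # Placeholder mask: random bits, one per pixel in row-major order, packed into
        # (width * height + 7) // 8 bytes. It is not derived from the model output.
        mask_bytes = random.randbytes((width * height + 7) // 8)
        mask_bits = base64.b64encode(mask_bytes).decode("utf-8")

        return PredictionResponse(
            logprobs=logprobs,
            localizationMask=LocalizationMask(
                width=width, height=height, bitsRowMajor=mask_bits
            ),
        )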

    @property
    def is_loaded(self) -> bool:
        """Whether ``load_model`` has completed successfully."""
        return self._is_loaded
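

# Hypothetical helpers, not part of the original service: a minimal sketch of how
# the ``bitsRowMajor`` payload that ``predict`` emits could be built from, and read
# back into, a 2-D boolean mask. They ASSUME pixel (x, y) maps to bit index
# ``y * width + x`` and that bits are packed MSB-first within each byte; the real
# LocalizationMask contract may specify a different bit order. ``predict`` currently
# fills the mask with random bytes, so these only illustrate the size/layout.


def pack_mask_bits(mask: list[list[bool]]) -> str:
    """Pack a row-major boolean mask into base64 (assumed MSB-first bit order)."""
    height = len(mask)
    width = len(mask[0]) if height else 0
    # Same byte-count formula as predict(): one bit per pixel, rounded up.
    packed = bytearray((width * height + 7) // 8)
    for y, row in enumerate(mask):
        for x, bit in enumerate(row):
            if bit:
                i = y * width + x
                packed[i // 8] |= 0x80 >> (i % 8)
    return base64.b64encode(bytes(packed)).decode("utf-8")


def unpack_mask_bits(bits_row_major: str, width: int, height: int) -> list[list[bool]]:
    """Inverse of ``pack_mask_bits`` under the same bit-order assumption."""
    packed = base64.b64decode(bits_row_major)
    if len(packed) != (width * height + 7) // 8:
        raise ValueError("mask length does not match width * height")
    return [
        [
            bool(packed[(y * width + x) // 8] & (0x80 >> ((y * width + x) % 8)))
            for x in range(width)
        ]
        for y in range(height)
    ]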