File size: 3,313 Bytes
8bd3ef8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
"""ResNet inference service implementation."""
import base64
import os
import warnings
from io import BytesIO

import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, ResNetForImageClassification  # type: ignore[import-untyped]

from app.api.models import BinaryMask, ImageRequest, Labels, PredictionResponse
from app.core.logging import logger
from app.services.base import InferenceService
class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
    """ResNet-18 inference service for image classification.

    Loads a locally stored HuggingFace ResNet checkpoint from ``models/<name>``
    and adapts its ImageNet logits to the project's ``PredictionResponse``
    format (per-label log-probabilities plus a placeholder localization mask).
    """

    def __init__(self, model_name: str = "microsoft/resnet-18") -> None:
        """Prepare the service; model weights are loaded lazily by ``load_model``.

        Args:
            model_name: HuggingFace-style model identifier; also used as the
                subdirectory under ``models/`` where the checkpoint lives.
        """
        self.model_name = model_name
        # Populated by load_model(); None until then.
        self.model: ResNetForImageClassification | None = None
        self.processor: AutoImageProcessor | None = None
        self._is_loaded = False
        self.model_path = os.path.join("models", model_name)
        logger.info(f"Initializing ResNet service: {self.model_path}")

    def load_model(self) -> None:
        """Load the image processor and model weights from ``self.model_path``.

        Idempotent: returns immediately if the model is already loaded.

        Raises:
            FileNotFoundError: If the model directory or its ``config.json``
                is missing.
        """
        if self._is_loaded:
            return
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model not found: {self.model_path}")
        config_path = os.path.join(self.model_path, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config not found: {config_path}")
        logger.info(f"Loading model from {self.model_path}")
        # transformers may emit FutureWarnings for older checkpoints; they are
        # not actionable here, so suppress them for the duration of the load.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning)
            self.processor = AutoImageProcessor.from_pretrained(
                self.model_path, local_files_only=True
            )
            self.model = ResNetForImageClassification.from_pretrained(
                self.model_path, local_files_only=True
            )
        assert self.model is not None
        self._is_loaded = True
        logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")  # pyright: ignore

    def predict(self, request: ImageRequest) -> PredictionResponse:
        """Classify a base64-encoded image.

        Args:
            request: Request whose ``image.data`` field holds the
                base64-encoded image bytes.

        Returns:
            A ``PredictionResponse`` with per-label log-probabilities and a
            dummy localization mask covering the middle third of the image.

        Raises:
            RuntimeError: If called before ``load_model``.
        """
        if not self.is_loaded:
            raise RuntimeError("model is not loaded")
        assert self.processor is not None
        assert self.model is not None
        image_data = base64.b64decode(request.image.data)
        image = Image.open(BytesIO(image_data))
        if image.mode != 'RGB':
            # The processor expects 3-channel input (handles grayscale/RGBA).
            image = image.convert('RGB')
        inputs = self.processor(image, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(**inputs).logits.squeeze()  # pyright: ignore
        # Convert the ImageNet output vector of dimension 1000 to the expected
        # output format by keeping the first len(Labels) logits. This is for
        # demonstration purposes and obviously will not perform well on the
        # actual task.
        # Fix: pass dim explicitly — implicit dim selection in log_softmax is
        # deprecated in PyTorch and emits a warning. logits is 1-D after
        # squeeze(), so dim=-1 matches the previous behavior exactly.
        logprobs = torch.nn.functional.log_softmax(logits[:len(Labels)], dim=-1).tolist()
        # Dummy localization mask: a rectangle approximately in the middle
        # third of the image (row-major, 1 inside the rectangle, 0 outside).
        x = image.width // 3
        y = image.height // 3
        mask = np.zeros((image.height, image.width), dtype=np.uint8)
        mask[y:(2 * y), x:(2 * x)] = 1
        mask_obj = BinaryMask.from_numpy(mask)
        return PredictionResponse(
            logprobs=logprobs,
            localizationMask=mask_obj,
        )

    @property
    def is_loaded(self) -> bool:
        """Whether ``load_model`` has completed successfully."""
        return self._is_loaded
|