example / app /services /inference.py
jessehostetler's picture
Clean up docs. Fix test script incorrect path.
a12ee73
"""ResNet inference service implementation."""
import base64
import os
from io import BytesIO
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, ResNetForImageClassification # type: ignore[import-untyped]
from app.core.logging import logger
from app.services.base import InferenceService
from app.api.models import BinaryMask, ImageRequest, Labels, PredictionResponse
class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
    """ResNet inference service for image classification (defaults to ResNet-18).

    Wraps a Hugging Face ``ResNetForImageClassification`` checkpoint stored on
    local disk under ``models/<model_name>`` (a relative path, so it resolves
    against the current working directory). :meth:`load_model` must be called
    before :meth:`predict`.

    NOTE: this is a demonstration service. The class scores are taken from the
    leading entries of the ImageNet classification head, and the localization
    mask is a fixed placeholder rectangle; neither is trained for the target
    task.
    """

    def __init__(self, model_name: str = "microsoft/resnet-18"):
        """Configure the service; no files are read until :meth:`load_model`.

        Args:
            model_name: Model identifier, joined onto the local ``models``
                directory to locate the checkpoint.
        """
        self.model_name = model_name
        self.model = None
        self.processor = None
        self._is_loaded = False
        self.model_path = os.path.join("models", model_name)
        logger.info(f"Initializing ResNet service: {self.model_path}")

    def load_model(self) -> None:
        """Load the image processor and model weights from ``self.model_path``.

        Returns immediately if the model is already loaded. Loading uses
        ``local_files_only=True``, so nothing is fetched over the network.

        Raises:
            FileNotFoundError: If the model directory or its ``config.json``
                does not exist.
        """
        if self._is_loaded:
            return

        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model not found: {self.model_path}")

        config_path = os.path.join(self.model_path, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config not found: {config_path}")

        logger.info(f"Loading model from {self.model_path}")

        import warnings

        # Suppress FutureWarning only while the processor/model are loading;
        # the filter is restored when the context manager exits.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning)
            self.processor = AutoImageProcessor.from_pretrained(
                self.model_path, local_files_only=True
            )
            self.model = ResNetForImageClassification.from_pretrained(
                self.model_path, local_files_only=True
            )

        assert self.model is not None
        self._is_loaded = True
        logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")  # pyright: ignore

    def predict(self, request: ImageRequest) -> PredictionResponse:
        """Classify a base64-encoded image and attach a placeholder mask.

        Args:
            request: Request whose ``image.data`` holds base64-encoded bytes of
                an image in any format Pillow can open.

        Returns:
            A ``PredictionResponse`` whose ``logprobs`` are the log-softmax of
            the first ``len(Labels)`` model logits (assumes the model emits at
            least that many), and whose ``localizationMask`` has the same
            height and width as the decoded image.

        Raises:
            RuntimeError: If :meth:`load_model` has not been called.
        """
        if not self.is_loaded:
            raise RuntimeError("model is not loaded")
        # Narrow the Optional attributes for type checkers; both are set
        # together in load_model before _is_loaded becomes True.
        assert self.processor is not None
        assert self.model is not None

        image = self._decode_image(request.image.data)

        inputs = self.processor(image, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(**inputs).logits.squeeze()  # pyright: ignore

        # Convert the ImageNet output vector of dimension 1000 to the expected
        # output format by keeping only the first len(Labels) logits. This is
        # for demonstration purposes and will not perform well on the actual
        # task. `dim` is explicit because omitting it relies on PyTorch's
        # deprecated implicit-dimension choice (and warns on every call).
        logprobs = torch.nn.functional.log_softmax(
            logits[: len(Labels)], dim=-1
        ).tolist()

        mask = self._dummy_center_mask(image.height, image.width)
        return PredictionResponse(
            logprobs=logprobs,
            localizationMask=BinaryMask.from_numpy(mask),
        )

    @staticmethod
    def _decode_image(data) -> Image.Image:
        """Decode base64 image bytes into a PIL image in RGB mode."""
        image = Image.open(BytesIO(base64.b64decode(data)))
        if image.mode != "RGB":
            image = image.convert("RGB")
        return image

    @staticmethod
    def _dummy_center_mask(height: int, width: int) -> np.ndarray:
        """Build a (height, width) uint8 mask with a placeholder rectangle set to 1.

        The rectangle covers rows ``[height // 3, 2 * (height // 3))`` and
        columns ``[width // 3, 2 * (width // 3))``, i.e. roughly the middle of
        the image. The array is row-major, so axis 0 is height. Images smaller
        than 3 pixels in either dimension yield an all-zero mask.
        """
        x = width // 3
        y = height // 3
        mask = np.zeros((height, width), dtype=np.uint8)
        mask[y : 2 * y, x : 2 * x] = 1
        return mask

    @property
    def is_loaded(self) -> bool:
        """Whether :meth:`load_model` has completed successfully."""
        return self._is_loaded