"""ResNet inference service implementation."""

import base64
import os
import warnings
from io import BytesIO

import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, ResNetForImageClassification  # type: ignore[import-untyped]

from app.api.models import BinaryMask, ImageRequest, Labels, PredictionResponse
from app.core.logging import logger
from app.services.base import InferenceService


class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
    """ResNet-18 inference service for image classification."""

    def __init__(self, model_name: str = "microsoft/resnet-18"):
        """Configure the service; the model itself is loaded lazily by load_model().

        Args:
            model_name: Hugging Face model identifier; also used as the
                directory name under ``models/`` holding the local snapshot.
        """
        self.model_name = model_name
        self.model = None
        self.processor = None
        self._is_loaded = False
        # Local-only loading: the snapshot must already exist under models/<name>.
        self.model_path = os.path.join("models", model_name)
        logger.info(f"Initializing ResNet service: {self.model_path}")

    def load_model(self) -> None:
        """Load the image processor and classification model from disk.

        Idempotent: returns immediately if the model is already loaded.

        Raises:
            FileNotFoundError: If the model directory or its config.json
                is missing.
        """
        if self._is_loaded:
            return

        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model not found: {self.model_path}")
        config_path = os.path.join(self.model_path, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config not found: {config_path}")

        logger.info(f"Loading model from {self.model_path}")
        # transformers may emit FutureWarnings for older snapshot formats;
        # they are not actionable here, so suppress them during loading only.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning)
            self.processor = AutoImageProcessor.from_pretrained(
                self.model_path, local_files_only=True
            )
            self.model = ResNetForImageClassification.from_pretrained(
                self.model_path, local_files_only=True
            )
        assert self.model is not None
        self._is_loaded = True
        logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")  # pyright: ignore

    def predict(self, request: ImageRequest) -> PredictionResponse:
        """Classify a base64-encoded image and build the prediction response.

        Args:
            request: Request whose ``image.data`` field holds the image bytes
                encoded as base64.

        Returns:
            PredictionResponse with per-label log-probabilities and a dummy
            localization mask.

        Raises:
            RuntimeError: If load_model() has not completed successfully.
        """
        if not self.is_loaded:
            raise RuntimeError("model is not loaded")
        assert self.processor is not None
        assert self.model is not None

        image_data = base64.b64decode(request.image.data)
        image = Image.open(BytesIO(image_data))
        if image.mode != 'RGB':
            image = image.convert('RGB')

        inputs = self.processor(image, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(**inputs).logits.squeeze()  # pyright: ignore

        # Convert the ImageNet output vector of dimension 1000 to the expected
        # output format. This is for demonstration purposes
        # and obviously will not perform well on the actual task.
        # FIX: pass dim explicitly — implicit-dim log_softmax is deprecated in
        # PyTorch; dim=-1 is equivalent for this 1-D tensor of truncated logits.
        logprobs = torch.nn.functional.log_softmax(
            logits[:len(Labels)], dim=-1
        ).tolist()

        # Dummy localization mask: a rectangle approximately in the middle.
        # Note: for images narrower/shorter than 3 px, x or y is 0 and the
        # marked region is empty, which is acceptable for this placeholder.
        x = image.width // 3
        y = image.height // 3
        # Row-major order: (height, width), matching the numpy convention.
        mask = np.zeros((image.height, image.width), dtype=np.uint8)
        mask[y:(2 * y), x:(2 * x)] = 1
        mask_obj = BinaryMask.from_numpy(mask)

        return PredictionResponse(
            logprobs=logprobs,
            localizationMask=mask_obj,
        )

    @property
    def is_loaded(self) -> bool:
        """Whether load_model() has completed successfully."""
        return self._is_loaded