"""ResNet inference service implementation."""

import base64
import os
import random
import warnings
from io import BytesIO

import torch
from PIL import Image
from transformers import AutoImageProcessor, ResNetForImageClassification  # type: ignore[import-untyped]

from app.core.logging import logger
from app.services.base import InferenceService
from app.api.models import ImageRequest, Labels, LocalizationMask, PredictionResponse


class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
    """ResNet-18 inference service for image classification."""

    def __init__(self, model_name: str = "microsoft/resnet-18"):
        """Record where the model lives on disk; nothing is loaded yet.

        Args:
            model_name: Model identifier. It is joined under the local
                ``models`` directory to form the on-disk model path.
        """
        self.model_name = model_name
        self.model = None
        self.processor = None
        self._is_loaded = False
        self.model_path = os.path.join("models", model_name)
        logger.info(f"Initializing ResNet service: {self.model_path}")

    def load_model(self) -> None:
        """Load the image processor and classification model from ``self.model_path``.

        Returns immediately if the model has already been loaded. Only local
        files are used (``local_files_only=True``); nothing is downloaded.

        Raises:
            FileNotFoundError: If the model directory or its ``config.json``
                does not exist.
        """
        if self._is_loaded:
            return

        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model not found: {self.model_path}")

        config_path = os.path.join(self.model_path, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config not found: {config_path}")

        logger.info(f"Loading model from {self.model_path}")

        # Suppress FutureWarnings emitted while loading; other warnings still surface.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning)
            self.processor = AutoImageProcessor.from_pretrained(
                self.model_path, local_files_only=True
            )
            self.model = ResNetForImageClassification.from_pretrained(
                self.model_path, local_files_only=True
            )
        assert self.model is not None
        self._is_loaded = True
        logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")  # pyright: ignore

    def predict(self, request: ImageRequest) -> PredictionResponse:
        """Classify the base64-encoded image in ``request``.

        The image is decoded, converted to RGB if needed, preprocessed and
        run through the model. The response carries:

        - ``logprobs``: a flat list of ``len(Labels)`` log-probabilities,
          computed by log-softmax over the first ``len(Labels)`` class logits.
        - ``localizationMask``: a width x height bit mask.

        Raises:
            RuntimeError: If ``load_model`` has not been called successfully.
        """
        if not self.is_loaded:
            raise RuntimeError("model is not loaded")
        assert self.processor is not None
        assert self.model is not None

        image_data = base64.b64decode(request.image.data)
        image = Image.open(BytesIO(image_data))
        # Size is read before conversion; convert() keeps the dimensions.
        width, height = image.size
        if image.mode != 'RGB':
            image = image.convert('RGB')

        inputs = self.processor(image, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(**inputs).logits  # pyright: ignore

        # logits has shape (batch, num_classes), and batch == 1 because a single
        # image was processed. Index the image first, then keep the first
        # len(Labels) classes. `logits[:len(Labels)]` would slice the batch axis
        # instead and yield every class wrapped in an extra list.
        # NOTE(review): assumes the first len(Labels) model classes map, in order,
        # onto Labels — TODO confirm against the Labels definition.
        class_logits = logits[0, : len(Labels)]
        logprobs = torch.nn.functional.log_softmax(class_logits, dim=-1).tolist()

        # One bit per pixel, rounded up to whole bytes. The bits are random:
        # presumably a placeholder until a real localization output exists.
        mask_bytes = random.randbytes((width * height + 7) // 8)
        mask_bits = base64.b64encode(mask_bytes).decode("utf-8")

        return PredictionResponse(
            logprobs=logprobs,
            localizationMask=LocalizationMask(
                width=width,
                height=height,
                bitsRowMajor=mask_bits,
            ),
        )

    @property
    def is_loaded(self) -> bool:
        """Whether ``load_model`` has completed successfully."""
        return self._is_loaded