|
|
"""ResNet inference service implementation.""" |
|
|
|
|
|
import base64 |
|
|
import os |
|
|
from io import BytesIO |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from PIL import Image |
|
|
from transformers import AutoImageProcessor, ResNetForImageClassification |
|
|
|
|
|
from app.core.logging import logger |
|
|
from app.services.base import InferenceService |
|
|
from app.api.models import BinaryMask, ImageRequest, Labels, PredictionResponse |
|
|
|
|
|
|
|
|
class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
    """ResNet-18 inference service for image classification.

    Loads a locally cached Hugging Face ResNet checkpoint (under
    ``models/<model_name>``) and serves single-image classification
    requests, returning log-probabilities over the service's label set
    plus a fixed center-region localization mask.
    """

    def __init__(self, model_name: str = "microsoft/resnet-18"):
        """Store configuration only; weights are loaded lazily by load_model().

        Args:
            model_name: Hugging Face model identifier, also used as the
                subdirectory name under the local ``models`` directory.
        """
        self.model_name = model_name
        self.model = None
        self.processor = None
        self._is_loaded = False
        # Weights are expected to be pre-downloaded to models/<model_name>;
        # load_model() refuses to fetch from the network (local_files_only).
        self.model_path = os.path.join("models", model_name)
        logger.info(f"Initializing ResNet service: {self.model_path}")

    def load_model(self) -> None:
        """Load the image processor and model weights from the local path.

        Idempotent: returns immediately if a previous call succeeded.

        Raises:
            FileNotFoundError: if the model directory or its ``config.json``
                does not exist.
        """
        if self._is_loaded:
            return

        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model not found: {self.model_path}")

        config_path = os.path.join(self.model_path, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config not found: {config_path}")

        logger.info(f"Loading model from {self.model_path}")

        import warnings
        with warnings.catch_warnings():
            # Silence transformers' FutureWarnings emitted during checkpoint
            # loading; they are not actionable at this call site.
            warnings.filterwarnings("ignore", category=FutureWarning)
            self.processor = AutoImageProcessor.from_pretrained(
                self.model_path, local_files_only=True
            )
            self.model = ResNetForImageClassification.from_pretrained(
                self.model_path, local_files_only=True
            )
        assert self.model is not None  # narrow Optional for type checkers

        self._is_loaded = True
        logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")

    def predict(self, request: ImageRequest) -> PredictionResponse:
        """Run classification on a single base64-encoded image.

        Args:
            request: request whose ``image.data`` field holds the
                base64-encoded image bytes.

        Returns:
            PredictionResponse with log-probabilities over the first
            ``len(Labels)`` model classes and a placeholder localization
            mask covering the center third of the image.

        Raises:
            RuntimeError: if load_model() has not been called successfully.
        """
        if not self.is_loaded:
            raise RuntimeError("model is not loaded")
        assert self.processor is not None  # narrow Optional for type checkers
        assert self.model is not None

        image_data = base64.b64decode(request.image.data)
        image = Image.open(BytesIO(image_data))

        # ResNet expects 3-channel input; normalize palettized/grayscale/RGBA.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        inputs = self.processor(image, return_tensors="pt")

        with torch.no_grad():
            logits = self.model(**inputs).logits.squeeze()

        # Only the first len(Labels) model classes map onto the service's
        # label set; renormalize over that slice. dim=-1 is explicit because
        # implicit-dim log_softmax is deprecated and warns in PyTorch.
        logprobs = torch.nn.functional.log_softmax(
            logits[:len(Labels)], dim=-1
        ).tolist()

        # Placeholder localization: mark the center third of the image.
        x = image.width // 3
        y = image.height // 3

        mask = np.zeros((image.height, image.width), dtype=np.uint8)
        mask[y:(2*y), x:(2*x)] = 1
        mask_obj = BinaryMask.from_numpy(mask)

        return PredictionResponse(
            logprobs=logprobs,
            localizationMask=mask_obj,
        )

    @property
    def is_loaded(self) -> bool:
        """Whether load_model() has completed successfully."""
        return self._is_loaded
|
|
|