File size: 3,306 Bytes
da2b98d 5ddae77 33241cf 5ddae77 33241cf c2feb3e b1f0e98 33241cf b1f0e98 5ddae77 c2feb3e b1f0e98 5ddae77 da2b98d b1f0e98 5ddae77 b1f0e98 5ddae77 da2b98d b1f0e98 d481329 b1f0e98 da2b98d b1f0e98 da2b98d b1f0e98 da2b98d b1f0e98 da2b98d 33241cf b1f0e98 da2b98d 33241cf b1f0e98 d481329 33241cf da2b98d b1f0e98 da2b98d b1f0e98 da2b98d b1f0e98 da2b98d be5bf87 5ddae77 c2feb3e be5bf87 c2feb3e b1f0e98 da2b98d 33241cf c2feb3e da2b98d b1f0e98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
"""ResNet inference service implementation."""
import base64
import os
from io import BytesIO
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, ResNetForImageClassification # type: ignore[import-untyped]
from app.core.logging import logger
from app.services.base import InferenceService
from app.api.models import BinaryMask, ImageRequest, Labels, PredictionResponse
class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
    """ResNet-18 inference service for image classification.

    The model and its image processor are read from a local directory
    (``models/<model_name>``) by :meth:`load_model`; nothing is loaded at
    construction time. :meth:`predict` decodes a base64 image, runs the
    classifier and returns demonstration-quality log-probabilities together
    with a dummy rectangular localization mask.
    """

    def __init__(self, model_name: str = "microsoft/resnet-18"):
        """Record the model location without loading any weights.

        Args:
            model_name: Model identifier; it is also used as the sub-path
                under the ``models`` directory where the files are expected.
        """
        self.model_name = model_name
        # Both are populated by load_model(); None means "not loaded yet".
        self.model = None
        self.processor = None
        self._is_loaded = False
        self.model_path = os.path.join("models", model_name)
        logger.info(f"Initializing ResNet service: {self.model_path}")

    def load_model(self) -> None:
        """Load the image processor and classifier from ``self.model_path``.

        Calling this again after a successful load is a no-op.

        Raises:
            FileNotFoundError: If the model directory or its ``config.json``
                does not exist.
        """
        if self._is_loaded:
            return

        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model not found: {self.model_path}")

        config_path = os.path.join(self.model_path, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config not found: {config_path}")

        logger.info(f"Loading model from {self.model_path}")

        import warnings

        # Silence FutureWarning noise emitted while the checkpoint is loaded.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning)
            # local_files_only=True: never reach out to the network.
            self.processor = AutoImageProcessor.from_pretrained(
                self.model_path, local_files_only=True
            )
            self.model = ResNetForImageClassification.from_pretrained(
                self.model_path, local_files_only=True
            )

        assert self.model is not None
        self._is_loaded = True
        logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")  # pyright: ignore

    def predict(self, request: ImageRequest) -> PredictionResponse:
        """Classify the image carried by ``request``.

        Args:
            request: Request whose ``image.data`` field holds the
                base64-encoded image bytes.

        Returns:
            A ``PredictionResponse`` with one log-probability per entry of
            ``Labels`` and a placeholder localization mask.

        Raises:
            RuntimeError: If :meth:`load_model` has not completed.

        Errors raised while decoding the base64 payload or opening the image
        propagate unchanged.
        """
        if not self.is_loaded:
            raise RuntimeError("model is not loaded")

        # Narrow the Optional attributes for static type checkers.
        assert self.processor is not None
        assert self.model is not None

        image_data = base64.b64decode(request.image.data)
        image = Image.open(BytesIO(image_data))

        if image.mode != 'RGB':
            image = image.convert('RGB')

        inputs = self.processor(image, return_tensors="pt")

        with torch.no_grad():
            # Drop only the batch axis; a bare squeeze() would also collapse
            # a size-1 class axis and break the slicing below.
            logits = self.model(**inputs).logits.squeeze(0)  # pyright: ignore

        # Convert the ImageNet output vector of dimension 1000 to the expected
        # output format. This is for demonstration purposes
        # and obviously will not perform well on the actual task.
        # `dim` is passed explicitly: relying on the implicit dimension is
        # deprecated in PyTorch and emits a warning on every call.
        logprobs = torch.nn.functional.log_softmax(
            logits[:len(Labels)], dim=-1
        ).tolist()

        # Dummy localization mask: a rectangle approximately in the middle
        x = image.width // 3
        y = image.height // 3
        # Row-major order
        mask = np.zeros((image.height, image.width), dtype=np.uint8)
        mask[y:(2*y), x:(2*x)] = 1

        mask_obj = BinaryMask.from_numpy(mask)

        return PredictionResponse(
            logprobs=logprobs,
            localizationMask=mask_obj,
        )

    @property
    def is_loaded(self) -> bool:
        """Whether :meth:`load_model` has completed successfully."""
        return self._is_loaded
|