# yanp: Upload folder using huggingface_hub (commit 0f42082, verified)
"""ResNet inference service implementation."""
import base64
import os
import random
from io import BytesIO
import torch
from PIL import Image
from transformers import AutoImageProcessor, ResNetForImageClassification # type: ignore[import-untyped]
from app.core.logging import logger
from app.services.base import InferenceService
from app.api.models import ImageRequest, Labels, LocalizationMask, PredictionResponse
class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
    """ResNet-18 inference service for image classification.

    Loads a ``ResNetForImageClassification`` checkpoint and its
    ``AutoImageProcessor`` from a local ``models/<model_name>`` directory
    (``local_files_only=True``, so nothing is downloaded). It turns a
    base64-encoded image into per-label log-probabilities plus a
    localization mask.
    """

    def __init__(self, model_name: str = "microsoft/resnet-18"):
        """Record where the model lives; weights are loaded by ``load_model``.

        Args:
            model_name: Model identifier. It is joined onto a relative
                ``models`` directory, so it is resolved against the current
                working directory.
        """
        self.model_name = model_name
        self.model = None
        self.processor = None
        self._is_loaded = False
        self.model_path = os.path.join("models", model_name)
        logger.info(f"Initializing ResNet service: {self.model_path}")

    def load_model(self) -> None:
        """Load the image processor and classifier from ``self.model_path``.

        Calling this again after a successful load is a no-op.

        Raises:
            FileNotFoundError: If the model directory or its ``config.json``
                is missing.
        """
        if self._is_loaded:
            return

        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model not found: {self.model_path}")

        config_path = os.path.join(self.model_path, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config not found: {config_path}")

        logger.info(f"Loading model from {self.model_path}")

        import warnings

        # Suppress FutureWarnings raised while loading the checkpoint files.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FutureWarning)
            self.processor = AutoImageProcessor.from_pretrained(
                self.model_path, local_files_only=True
            )
            self.model = ResNetForImageClassification.from_pretrained(
                self.model_path, local_files_only=True
            )

        assert self.model is not None  # narrows the Optional for type checkers
        self._is_loaded = True
        logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")  # pyright: ignore

    def predict(self, request: ImageRequest) -> PredictionResponse:
        """Classify a single base64-encoded image.

        Args:
            request: Request whose ``image.data`` holds the base64-encoded
                image bytes.

        Returns:
            A ``PredictionResponse``:
            - ``logprobs`` is a flat list with one log-probability per
              ``Labels`` member. It is a log-softmax over the first
              ``len(Labels)`` model logits.
            - ``localizationMask`` has the original image's width and height.

        Raises:
            RuntimeError: If ``load_model`` has not completed successfully.
        """
        if not self.is_loaded:
            raise RuntimeError("model is not loaded")
        assert self.processor is not None
        assert self.model is not None

        image = Image.open(BytesIO(base64.b64decode(request.image.data)))
        # Mask dimensions follow the decoded image, not the processor's
        # resized tensor.
        width, height = image.size
        if image.mode != "RGB":
            image = image.convert("RGB")

        inputs = self.processor(image, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(**inputs).logits  # pyright: ignore

        # The logits are batched (a single image here), so index the batch
        # row explicitly and truncate the *class* axis. Slicing the leading
        # axis with [:len(Labels)] would act on the batch dimension, keep
        # every model class, and produce a nested list.
        label_logits = logits[0, : len(Labels)]
        logprobs = torch.nn.functional.log_softmax(label_logits, dim=-1).tolist()

        # NOTE: placeholder localization. This is random data, one bit per
        # pixel rounded up to whole bytes; it is not derived from the model.
        mask_bytes = random.randbytes((width * height + 7) // 8)
        mask_bits = base64.b64encode(mask_bytes).decode("utf-8")

        return PredictionResponse(
            logprobs=logprobs,
            localizationMask=LocalizationMask(
                width=width, height=height, bitsRowMajor=mask_bits
            ),
        )

    @property
    def is_loaded(self) -> bool:
        """Whether ``load_model`` has completed successfully."""
        return self._is_loaded