Spaces:
Runtime error
Runtime error
| """ResNet inference service implementation.""" | |
| import base64 | |
| import os | |
| import random | |
| from io import BytesIO | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoImageProcessor, ResNetForImageClassification # type: ignore[import-untyped] | |
| from app.core.logging import logger | |
| from app.services.base import InferenceService | |
| from app.api.models import ImageRequest, Labels, LocalizationMask, PredictionResponse | |
class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
    """ResNet-18 inference service for image classification."""

    def __init__(self, model_name: str = "microsoft/resnet-18"):
        """Record paths and lazy-load state; no weights are read here.

        Args:
            model_name: HF-style model identifier, also used as the
                subdirectory under ``models/`` where the weights live.
        """
        self.model_name = model_name
        self.model = None       # set by load_model()
        self.processor = None   # set by load_model()
        self._is_loaded = False
        self.model_path = os.path.join("models", model_name)
        logger.info(f"Initializing ResNet service: {self.model_path}")

    def load_model(self) -> None:
        """Load the image processor and classifier from the local model dir.

        Idempotent: calling again after a successful load is a no-op.

        Raises:
            FileNotFoundError: if the model directory or its config.json
                is missing.
        """
        if self._is_loaded:
            return
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model not found: {self.model_path}")
        config_path = os.path.join(self.model_path, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config not found: {config_path}")
        logger.info(f"Loading model from {self.model_path}")
        import warnings

        with warnings.catch_warnings():
            # transformers emits FutureWarnings while loading older
            # checkpoints; silence them for the load only.
            warnings.filterwarnings("ignore", category=FutureWarning)
            self.processor = AutoImageProcessor.from_pretrained(
                self.model_path, local_files_only=True
            )
            self.model = ResNetForImageClassification.from_pretrained(
                self.model_path, local_files_only=True
            )
        assert self.model is not None
        self._is_loaded = True
        logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")  # pyright: ignore

    def predict(self, request: ImageRequest) -> PredictionResponse:
        """Classify a base64-encoded image.

        Args:
            request: request whose ``image.data`` holds the base64-encoded
                image bytes.

        Returns:
            A PredictionResponse with per-label log-probabilities and a
            (currently random placeholder) localization mask.

        Raises:
            RuntimeError: if load_model() has not completed.
        """
        # BUG FIX: the original tested ``not self.is_loaded`` — a bound
        # method object is always truthy, so the guard could never fire.
        if not self.is_loaded():
            raise RuntimeError("model is not loaded")
        assert self.processor is not None
        assert self.model is not None
        image_data = base64.b64decode(request.image.data)
        image = Image.open(BytesIO(image_data))
        width, height = image.size
        if image.mode != 'RGB':
            # The processor expects 3-channel input (handles RGBA/L/P etc.).
            image = image.convert('RGB')
        inputs = self.processor(image, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(**inputs).logits  # pyright: ignore
        # BUG FIX: the original sliced ``logits[:len(Labels)]``, which cuts
        # the *batch* dimension (a no-op for batch size 1) and produced a
        # nested 2-D list over every model class. Normalize across the class
        # axis, then take the first ``len(Labels)`` scores of the single
        # batch item.
        # NOTE(review): assumes the first len(Labels) class indices map onto
        # the Labels enum — confirm against the model's id2label mapping.
        logprobs = torch.nn.functional.log_softmax(logits, dim=-1)[0, : len(Labels)].tolist()
        # Placeholder localization mask: random bits, one per pixel, packed
        # row-major into ceil(width*height / 8) bytes.
        mask_bytes = random.randbytes((width * height + 7) // 8)
        mask_bits = base64.b64encode(mask_bytes).decode("utf-8")
        return PredictionResponse(
            logprobs=logprobs,
            localizationMask=LocalizationMask(
                width=width, height=height, bitsRowMajor=mask_bits
            )
        )

    def is_loaded(self) -> bool:
        """Return True once load_model() has completed successfully."""
        return self._is_loaded