Spaces:
Sleeping
Sleeping
| import logging | |
| from typing import Tuple | |
| import torch | |
| from fastapi import HTTPException, status | |
| from PIL import Image | |
| from transformers import PreTrainedModel | |
| from transformers.image_processing_utils import BaseImageProcessor | |
| logger = logging.getLogger(__name__) | |
async def classify_mushroom_in_image_svc(
    img: Image.Image, model: PreTrainedModel, preprocessor: BaseImageProcessor
) -> Tuple[str, str, float]:
    """Classify the mushroom shown in an image.

    The image is run through ``preprocessor`` and ``model``; the highest-scoring
    class name (looked up in ``model.config.id2label``) is split at its last
    underscore into a mushroom type and a toxicity profile. The softmax
    probability of the predicted class, rounded to 4 decimals, is returned as
    the classification confidence.

    :param img: the input image of the mushroom to be classified
    :type img: Image.Image
    :param model: the pretrained classification model
    :type model: PreTrainedModel
    :param preprocessor: the image processor applied to ``img`` before inference
    :type preprocessor: BaseImageProcessor
    :raises HTTPException: 500 Internal Server Error if any step of the
        classification fails
    :return: mushroom_type, toxicity_profile, classification_confidence
    :rtype: Tuple[str, str, float]
    """
    try:
        logger.debug("Loading classification model.")
        # Preprocess the image into tensors on the same device as the model.
        inputs = preprocessor(img, return_tensors="pt").to(model.device)

        # Evaluation mode + inference mode: no dropout updates, no autograd.
        model.eval()
        with torch.inference_mode():
            logger.debug("Starting classification process...")
            outputs = model(inputs["pixel_values"])

            # The forward pass may return an output object exposing `.logits`,
            # a plain tuple whose first element is the logits, or a bare
            # tensor; normalise all three to a logits tensor.
            if isinstance(outputs, (tuple, list)):
                logits = outputs[0]
            else:
                logits = getattr(outputs, "logits", outputs)

            # Logits -> per-class probabilities (multi-class classification).
            probs = torch.softmax(logits, dim=1)

            # Index of the highest-scoring class; `.item()` requires that a
            # single image was classified.
            predicted_label = logits.argmax(dim=1).item()

        # Map the class index to its label name via the model config.
        class_name = model.config.id2label[predicted_label]

        # Split at the LAST underscore: "<mushroom type>_<toxicity profile>".
        class_type, toxicity = class_name.rsplit("_", 1)

        # Confidence of the predicted class, 4 decimal points precision.
        prob = round(probs[0, predicted_label].item(), 4)

        logger.debug("Finished classification process...")
        return class_type, toxicity, prob
    except Exception as e:
        # Service boundary: log with traceback, hide internals from the client.
        logger.exception("Classification process error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Classification process failed due to an internal error. Contact support if this persists.",
        ) from e