Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import os | |
| import huggingface_hub | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import yaml # type: ignore | |
| from mmdet.apis import inference_detector, init_detector | |
class Model:
    """Thin wrapper around an MMDetection detector.

    Loads the detector, runs inference on single images and renders the
    detections onto the image.  Every model name currently maps to the same
    config/checkpoint pair (``CONFIG_PATH`` / ``CHECKPOINT_PATH``);
    ``model_name`` is tracked so that repeated requests for the current
    name skip reloading.
    """

    # Config and weights used for every model name.  These were previously
    # inlined in ``_load_model``; as class attributes a subclass can point at
    # other files without overriding the loader.
    CONFIG_PATH = 'configs/_base_/faster-rcnn_r50_fpn_1x_coco.py'
    CHECKPOINT_PATH = 'models/orgaquant_pretrained_new.pth'

    def __init__(self, model_name: str) -> None:
        """Select a device (``cuda:0`` when CUDA is available, else CPU) and load."""
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model_name = model_name
        self.model = self._load_model(model_name)

    def _load_model(self, name: str) -> nn.Module:
        """Build the detector from the class config/checkpoint on ``self.device``.

        ``name`` is accepted so ``set_model`` can stay generic, but it is
        ignored here: the same files are loaded for every name.
        """
        return init_detector(self.CONFIG_PATH,
                             self.CHECKPOINT_PATH,
                             device=self.device)

    def set_model(self, name: str) -> None:
        """Make ``name`` the active model, reloading only when it changes.

        The replacement is loaded before any attribute is touched.  If
        ``_load_model`` raises, ``model_name`` and ``model`` still describe
        the previous detector.  A retry with the same name is therefore not
        mistaken for "already loaded".
        """
        if name == self.model_name:
            return
        new_model = self._load_model(name)
        self.model = new_model
        self.model_name = name

    def detect_and_visualize(
        self, image: np.ndarray, score_threshold: float
    ) -> tuple[list[np.ndarray] | tuple[list[np.ndarray],
                                        list[list[np.ndarray]]]
               | dict[str, np.ndarray], np.ndarray]:
        """Run detection on ``image`` and return ``(raw_result, rendering)``.

        ``score_threshold`` is only forwarded to
        ``visualize_detection_results``; the raw result is returned
        unfiltered.
        """
        out = self.detect(image)
        vis = self.visualize_detection_results(image, out, score_threshold)
        return out, vis

    def detect(
        self, image: np.ndarray
    ) -> list[np.ndarray] | tuple[
            list[np.ndarray], list[list[np.ndarray]]] | dict[str, np.ndarray]:
        """Return the unfiltered output of ``inference_detector`` for ``image``.

        The concrete structure (see the return annotation) depends on the
        model and MMDetection version.
        """
        return inference_detector(self.model, image)

    def visualize_detection_results(
        self,
        image: np.ndarray,
        detection_results: list[np.ndarray]
        | tuple[list[np.ndarray], list[list[np.ndarray]]]
        | dict[str, np.ndarray],
        score_threshold: float = 0.3,
    ) -> np.ndarray:
        """Draw ``detection_results`` onto ``image`` and return the rendering.

        ``score_threshold`` is passed to ``show_result`` as ``score_thr``.
        """
        import logging

        # This was a bare ``print`` that dumped every array to stdout on each
        # request.  Logging at DEBUG with lazy %-formatting makes it opt-in.
        logging.getLogger(__name__).debug('Detection results %s',
                                          detection_results)
        # NOTE(review): ``show_result`` is the MMDetection 2.x detector API;
        # 3.x replaced it with a visualizer -- confirm the pinned version.
        return self.model.show_result(image,
                                      detection_results,
                                      score_thr=score_threshold,
                                      bbox_color=None,
                                      text_color=(200, 200, 200),
                                      mask_color=None)
class AppModel(Model):
    """``Model`` with a single entry point that also handles model switching."""

    def run(
        self,
        model_name: str,
        image: np.ndarray,
        score_threshold: float,
    ) -> tuple[list[np.ndarray]
               | tuple[list[np.ndarray], list[list[np.ndarray]]]
               | dict[str, np.ndarray], np.ndarray]:
        """Activate ``model_name`` (reloading if needed), then detect and render.

        Returns whatever ``detect_and_visualize`` returns for ``image`` and
        ``score_threshold``.
        """
        self.set_model(model_name)
        result = self.detect_and_visualize(image, score_threshold)
        return result