Spaces:
Running
Running
| import onnxruntime | |
| import numpy as np | |
| import cv2 | |
| from typing import Tuple, List, Union | |
| from .base_onnx import BaseONNX | |
class Counting(BaseONNX):
    """Wrapper around an ONNX counting model: preprocess, infer, and draw points.

    The ONNX session and input name (``self.session`` / ``self.input_name``)
    are expected to be provided by ``BaseONNX`` -- they are not set here.
    """

    # Images whose longest side exceeds this many pixels are downscaled.
    UPPER_BOUND = 2560
    # Height and width fed to the model are snapped down to a multiple of this.
    MULTIPLE_OF = 32
    # Per-channel normalization constants, applied after dividing pixels by 255.
    _NORM_MEAN = (0.485, 0.456, 0.406)
    _NORM_STD = (0.229, 0.224, 0.225)

    def __init__(self, model_path):
        super().__init__(model_path)

    def preprocess_image(
        self, img: np.ndarray, is_rgb: bool = True
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Convert, normalize and resize an image into a model input tensor.

        Args:
            img: Input image as an (H, W, 3) array, in BGR or RGB channel order.
            is_rgb: True if ``img`` is already RGB (default); False if it is
                BGR and must be converted first.

        Returns:
            A ``(tensor, image)`` pair:
              * ``tensor`` -- normalized float32 array of shape (1, 3, H', W').
              * ``image``  -- un-normalized RGB copy of the input, resized to
                (H', W') so that it always matches the tensor's spatial size
                (e.g. for drawing predictions on it).
        """
        if not is_rgb:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_copy = img.copy()

        # astype() allocates a new array, so the in-place arithmetic below
        # never modifies the caller's image.
        img = img.astype(np.float32)
        img /= 255.0
        img -= np.array(self._NORM_MEAN)
        img /= np.array(self._NORM_STD)

        # Downscale when the longest side exceeds the upper bound.
        origin_h, origin_w = img.shape[:2]
        max_size = max(origin_h, origin_w)
        if max_size > self.UPPER_BOUND:
            scale = self.UPPER_BOUND / max_size
            img = cv2.resize(img, None, fx=scale, fy=scale)

        # Snap both sides down to a multiple of MULTIPLE_OF. Clamp to at least
        # one multiple so a very small side does not become 0, which
        # cv2.resize would reject.
        h, w = img.shape[:2]
        new_h = max(self.MULTIPLE_OF, (h // self.MULTIPLE_OF) * self.MULTIPLE_OF)
        new_w = max(self.MULTIPLE_OF, (w // self.MULTIPLE_OF) * self.MULTIPLE_OF)
        if (h, w) != (new_h, new_w):
            img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

        # Keep the returned copy in sync with the tensor. This must also run
        # when only the UPPER_BOUND downscale changed the size (the downscaled
        # size may already be a multiple of MULTIPLE_OF).
        if tuple(img_copy.shape[:2]) != (new_h, new_w):
            img_copy = cv2.resize(
                img_copy, (new_w, new_h), interpolation=cv2.INTER_LINEAR
            )

        # (H, W, C) -> (C, H, W), then add the batch dimension.
        img = np.transpose(img, (2, 0, 1))
        img = np.expand_dims(img, axis=0)
        return img, img_copy

    def run_inference(self, image: np.ndarray) -> List:
        """Run the ONNX session on a preprocessed input tensor.

        Args:
            image: Input tensor, fed to the session under ``self.input_name``.

        Returns:
            Whatever ``self.session.run`` returns for all model outputs
            (presumably a list with one array per output -- confirm against
            ``BaseONNX`` / the exported model).
        """
        return self.session.run(None, {self.input_name: image})

    def pred(
        self, image: Union[np.ndarray, str], is_rgb: bool = True
    ) -> Tuple[List[float], List[List[float]]]:
        """Run the full pipeline (load, preprocess, infer) on one image.

        Args:
            image: Either an image array or a path to an image file.
            is_rgb: Channel order of an array input (True = RGB). Ignored for
                file paths, which are loaded with ``cv2.imread`` and therefore
                always treated as BGR.

        Returns:
            The two model outputs, unpacked as ``(scores, points)``.
            NOTE(review): assumes the model has exactly two outputs in this
            order -- confirm against the exported model.

        Raises:
            ValueError: If ``image`` is a path that cannot be read as an image.
        """
        if isinstance(image, str):
            loaded = cv2.imread(image)
            if loaded is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise ValueError(f"Could not read image file: {image!r}")
            is_rgb = False
        else:
            # No defensive copy needed: preprocess_image never mutates its input.
            loaded = image

        processed_image, _ = self.preprocess_image(loaded, is_rgb)
        scores, points = self.run_inference(processed_image)
        return scores, points

    def draw_pred(
        self, image: np.ndarray, scores: List[float], points: List[List[float]]
    ) -> np.ndarray:
        """Return a copy of ``image`` with a filled circle at each in-bounds point.

        Args:
            image: Image to annotate; it is copied, not modified.
            scores: Per-point scores. Only used to pair with ``points``
                (iteration stops at the shorter of the two sequences).
            points: Sequence of points whose first two entries are (x, y).

        Returns:
            The annotated image copy.
        """
        # np.array() already copies its input, so one copy is enough.
        marked_img = np.array(image)
        img_h, img_w = marked_img.shape[:2]
        for point, _score in zip(points, scores):
            x, y = int(point[0]), int(point[1])
            # Skip points that fall outside the image.
            if 0 <= x < img_w and 0 <= y < img_h:
                cv2.circle(marked_img, (x, y), 5, (255, 0, 0), -1)
        return marked_img