Spaces:
Sleeping
Sleeping
| import tempfile | |
| from typing import List, Tuple, Optional, Dict, Any | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
# -----------------------------
# Global configuration
# -----------------------------
# Maximum allowed image side (pixels) to avoid OOM / heavy CPU usage.
# Reduced from 2048 to 1024 for better performance (as in demo.py).
MAX_SIDE = 1024
# -----------------------------
# Utility functions
# -----------------------------
def downscale_bgr(img: np.ndarray) -> Tuple[np.ndarray, float]:
    """Downscale image so that the longest side is <= MAX_SIDE.

    Parameters
    ----------
    img : np.ndarray
        Input BGR image.

    Returns
    -------
    img_resized : np.ndarray
        Possibly downscaled BGR image.
    scale : float
        Applied scale factor (<= 1).
    """
    h, w = img.shape[:2]
    max_hw = max(h, w)
    if max_hw <= MAX_SIDE:
        return img, 1.0
    scale = MAX_SIDE / float(max_hw)
    # Clamp to >= 1 px: for extreme aspect ratios int(w * scale) could be 0,
    # which would make cv2.resize raise.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    # INTER_AREA is the recommended interpolation for shrinking.
    img_resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
    return img_resized, scale
def normalize_angle(angle: float, size_w: float, size_h: float) -> float:
    """Normalize an OpenCV ``minAreaRect`` angle to the range [0, 180).

    OpenCV's reported angle depends on whether the rect's width is smaller
    than its height; when it is, the axes are swapped, so the angle is
    rotated by 90 degrees to keep the *long side* as the length direction.
    """
    adjusted = angle + 90.0 if size_w < size_h else angle
    # Python's % operator already yields a value in [0, 180) for any float.
    return adjusted % 180.0
| # ----------------------------- | |
| # Reference object detection | |
| # ----------------------------- | |
def build_foreground_mask(img_bgr: np.ndarray) -> np.ndarray:
    """Build a simple foreground mask (adapted from demo.py).

    The background color is estimated from the four image corners in LAB
    space; pixels far from that color are classified as foreground via
    Otsu thresholding, then cleaned up with morphological open/close.

    Parameters
    ----------
    img_bgr : np.ndarray
        Input BGR image.

    Returns
    -------
    np.ndarray
        Binary mask (uint8, 0 or 255) marking foreground pixels.
    """
    h, w = img_bgr.shape[:2]
    # LAB is better suited than BGR for perceptual color distances.
    lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2LAB)
    # Estimate the background color from the four corner patches.
    # Clamp to >= 1: for images smaller than 10 px on a side the integer
    # division would give 0, producing empty slices and a NaN mean.
    corner_size = max(1, min(h, w) // 10)
    corners = [
        lab[:corner_size, :corner_size],
        lab[:corner_size, -corner_size:],
        lab[-corner_size:, :corner_size],
        lab[-corner_size:, -corner_size:],
    ]
    corner_pixels = np.vstack([c.reshape(-1, 3) for c in corners])
    bg_color = np.mean(corner_pixels, axis=0)
    # Per-pixel Euclidean distance from the estimated background color.
    diff = lab.astype(np.float32) - bg_color
    dist = np.sqrt(np.sum(diff * diff, axis=2))
    # Amplify distances (x3) before Otsu so subtle foreground remains separable.
    dist_uint8 = np.clip(dist * 3, 0, 255).astype(np.uint8)
    _, mask = cv2.threshold(dist_uint8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # Morphological cleanup: open removes specks, close fills small holes.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    return mask
def detect_reference(
    img_bgr: np.ndarray,
    mode: str,
    ref_size_mm: Optional[float],
) -> Tuple[float, Optional[Tuple[int, int]], Optional[str], Optional[Tuple[int, int, int, int]]]:
    """Detect the reference object: the first blob in the top-left corner.

    Parameters
    ----------
    img_bgr : BGR image.
    mode : reference mode ("auto", "coin", "square"); currently informational.
    ref_size_mm : side length of the reference bounding box in millimetres.

    Returns
    -------
    px_per_mm : pixels-per-millimetre scale factor.
    ref_center : centre of the reference object (or None).
    ref_type : detected reference type (or None).
    ref_bbox : axis-aligned bounding box of the reference (or None).
    """
    img_h, img_w = img_bgr.shape[:2]
    # Simple foreground mask followed by connected-component analysis.
    fg_mask = build_foreground_mask(img_bgr)
    num_labels, _labels, stats, centroids = cv2.connectedComponentsWithStats(fg_mask)

    area_lo = (img_h * img_w) // 500  # minimum plausible area
    area_hi = (img_h * img_w) // 20   # maximum plausible area

    found = []
    for idx in range(1, num_labels):
        x, y, bw, bh, comp_area = stats[idx]
        # Reject components outside the plausible area range.
        if comp_area < area_lo or comp_area > area_hi:
            continue
        # The reference must lie in the top-left region of the image.
        if x > img_w * 0.4 or y > img_h * 0.4:
            continue
        # It should also be roughly square-ish (not overly elongated).
        if max(bw, bh) / (min(bw, bh) + 1e-6) > 3.0:
            continue
        cx, cy = centroids[idx]
        # Lower score == closer to the top-left corner.
        found.append((x + y, idx, (x, y, bw, bh), comp_area, (int(cx), int(cy))))

    if not found:
        # No reference found: fall back to a safe default scale.
        return 4.0, None, None, None

    # Pick the candidate closest to the top-left corner.
    found.sort(key=lambda item: item[0])
    _score, _idx, (x, y, bw, bh), _area, ref_center = found[0]

    # Derive the pixel/millimetre ratio from the larger bbox side.
    side_mm = ref_size_mm if ref_size_mm and ref_size_mm > 0 else 25.0
    px_per_mm = max(bw, bh) / side_mm
    return px_per_mm, ref_center, "square", (x, y, bw, bh)
| # ----------------------------- | |
| # Segmentation & measurements | |
| # ----------------------------- | |
def build_mask_hsv(
    img_bgr: np.ndarray,
    sample_type: str,
    hsv_low_h: int,
    hsv_high_h: int,
    color_tol: int,
) -> np.ndarray:
    """Build a binary foreground mask from HSV thresholds.

    For "leaves" the hue window [hsv_low_h, hsv_high_h] is combined with
    minimum saturation/value cut-offs; for seeds/grains only the
    saturation/value cut-offs are applied (keeping non-white pixels).
    """
    hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
    hue = hsv[:, :, 0]
    sat = hsv[:, :, 1]
    val = hsv[:, :, 2]

    if sample_type == "leaves":
        lo = int(max(0, hsv_low_h))
        hi = int(min(179, hsv_high_h))
        # Keep pixels inside the hue window ...
        in_hue = cv2.inRange(hue, lo, hi)
        # ... and drop very desaturated or very dark ones.
        in_sat = cv2.inRange(sat, 30, 255)
        in_val = cv2.inRange(val, 30, 255)
        mask = cv2.bitwise_and(in_hue, cv2.bitwise_and(in_sat, in_val))
    else:
        # Seeds / grains: keep everything that is not near-white.
        mask = cv2.bitwise_and(cv2.inRange(sat, 20, 255), cv2.inRange(val, 20, 255))

    # Morphological open + close to remove specks and fill small holes.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    return mask
def segment(
    img_bgr: np.ndarray,
    sample_type: str,
    hsv_low_h: int,
    hsv_high_h: int,
    color_tol: int,
    min_area_px: float,
    max_area_px: float,
) -> List[Dict[str, Any]]:
    """Segment objects and compute basic geometric descriptors.

    Uses the simplified segmentation algorithm from demo.py while keeping
    the HSV parameters for interface compatibility (hsv_low_h, hsv_high_h
    and color_tol are currently unused here).

    Returns a list of per-component dicts containing the contour, oriented
    bounding box, PCA-derived axis lengths/angle and projection bounds
    used later for overlay drawing.
    """
    # Simple foreground mask (demo.py method).
    mask = build_foreground_mask(img_bgr)
    # Connected-component analysis.
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask)
    components: List[Dict[str, Any]] = []
    # Collect objects sorted by position; the first one (usually the
    # reference object) may be skipped below.
    all_objects = []
    for i in range(1, num_labels):
        x, y, ww, hh, area = stats[i]
        # Area filter.
        if area < min_area_px or area > max_area_px:
            continue
        cx, cy = centroids[i]
        # Simple position score: left-to-right, x dominates over y.
        score = x + y * 0.1
        all_objects.append((score, i, (x, y, ww, hh), area, (int(cx), int(cy))))
    if len(all_objects) == 0:
        return []
    # Sort by score and possibly skip the first object (the reference).
    all_objects.sort(key=lambda obj: obj[0])
    # Heuristic: decide whether the first object is the reference.
    skip_first = False
    if len(all_objects) > 0:
        _, _, (x, y, ww, hh), area, _ = all_objects[0]
        h, w = img_bgr.shape[:2]
        # Skip it if it sits in the top-left corner with a reasonable shape.
        is_topleft = (x < w * 0.3 and y < h * 0.3)
        aspect_ratio = max(ww, hh) / (min(ww, hh) + 1e-6)
        is_reasonable_shape = aspect_ratio < 3.0
        skip_first = is_topleft and is_reasonable_shape
    # Process the remaining objects.
    start_idx = 1 if skip_first else 0
    for obj_data in all_objects[start_idx:]:
        _, label_idx, bbox, area, center = obj_data
        # Extract the external contour of this component.
        component_mask = (labels == label_idx).astype(np.uint8) * 255
        cnts, _ = cv2.findContours(component_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if len(cnts) == 0:
            continue
        cnt = cnts[0]
        # Basic geometric features.
        rect = cv2.minAreaRect(cnt)
        box = cv2.boxPoints(rect).astype(np.int32)
        peri = cv2.arcLength(cnt, True)
        # Fix the long/short axis ambiguity of cv2.minAreaRect using PCA.
        # Contour points as an (N, 2) float32 array.
        contour_points = cnt.reshape(-1, 2).astype(np.float32)
        # Centroid of the contour points.
        cx = np.mean(contour_points[:, 0])
        cy = np.mean(contour_points[:, 1])
        # Covariance matrix of the centered points.
        centered_points = contour_points - np.array([cx, cy])
        cov_matrix = np.cov(centered_points.T)
        # Eigen-decomposition of the (symmetric) covariance matrix.
        eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
        # Sort eigenpairs by eigenvalue, descending.
        idx = np.argsort(eigenvalues)[::-1]
        eigenvalues = eigenvalues[idx]
        eigenvectors = eigenvectors[:, idx]
        # Main direction = eigenvector of the largest eigenvalue.
        main_direction = eigenvectors[:, 0]
        # Project the points onto the main and secondary directions.
        proj_main = np.dot(centered_points, main_direction)
        proj_secondary = np.dot(centered_points, eigenvectors[:, 1])
        # Projection extents along each direction.
        min_main = np.min(proj_main)
        max_main = np.max(proj_main)
        min_secondary = np.min(proj_secondary)
        max_secondary = np.max(proj_secondary)
        # True long/short axis lengths from the projection extents.
        length_main = max_main - min_main
        length_secondary = max_secondary - min_secondary
        # Ensure the long axis is the longer direction and keep the
        # matching projection bounds.
        if length_main >= length_secondary:
            w_obb = length_main
            h_obb = length_secondary
            angle = np.arctan2(main_direction[1], main_direction[0]) * 180.0 / np.pi
            # Long axis is the main direction.
            long_direction = main_direction
            short_direction = eigenvectors[:, 1]
            min_long_proj = min_main
            max_long_proj = max_main
            min_short_proj = min_secondary
            max_short_proj = max_secondary
        else:
            w_obb = length_secondary
            h_obb = length_main
            secondary_direction = eigenvectors[:, 1]
            angle = np.arctan2(secondary_direction[1], secondary_direction[0]) * 180.0 / np.pi
            # Long axis is the secondary direction.
            long_direction = eigenvectors[:, 1]
            short_direction = main_direction
            min_long_proj = min_secondary
            max_long_proj = max_secondary
            min_short_proj = min_main
            max_short_proj = max_main
        # Normalize the angle to [0, 180).
        angle = ((angle % 180.0) + 180.0) % 180.0
        components.append({
            "contour": cnt,
            "rect": rect,
            "box": box,
            "area_px": float(area),
            "peri_px": float(peri),
            "center": (int(cx), int(cy)),  # PCA centroid (rounded)
            "pca_center": (cx, cy),  # exact PCA centroid
            "angle": float(angle),
            "length_px": float(w_obb),
            "width_px": float(h_obb),
            # Projection bounds kept for accurate OBB drawing later.
            "min_long_proj": float(min_long_proj),
            "max_long_proj": float(max_long_proj),
            "min_short_proj": float(min_short_proj),
            "max_short_proj": float(max_short_proj),
        })
    return components
def compute_color_metrics(img_bgr: np.ndarray, mask: np.ndarray) -> Tuple[float, float, float, int, int, int, float, float]:
    """Compute mean RGB / HSV and simple color indices inside *mask*.

    Returns (meanR, meanG, meanB, H, S, V, green_index, brown_index).
    """
    mean_b, mean_g, mean_r = cv2.mean(img_bgr, mask=mask)[:3]
    # Convert the single mean color to HSV via a 1x1 pixel image.
    pixel = np.array([[[mean_r, mean_g, mean_b]]], dtype=np.uint8)
    hsv_pixel = cv2.cvtColor(pixel, cv2.COLOR_RGB2HSV)[0, 0]
    hue, sat, val = int(hsv_pixel[0]), int(hsv_pixel[1]), int(hsv_pixel[2])
    # Simple normalized greenness / brownness indices.
    total = mean_r + mean_g + mean_b + 1e-6
    green_index = (2.0 * mean_g - mean_r - mean_b) / total
    brown_index = (mean_r - mean_b) / total
    return mean_r, mean_g, mean_b, hue, sat, val, green_index, brown_index
def compute_metrics(
    img_bgr: np.ndarray,
    components: List[Dict[str, Any]],
    px_per_mm: float,
) -> pd.DataFrame:
    """Compute all morphological + color metrics for each component.

    Pixel measurements are converted to millimetres via *px_per_mm*;
    color statistics are sampled inside each component's filled contour.
    """
    records: List[Dict[str, Any]] = []
    for idx, comp in enumerate(components, start=1):
        # Size metrics from the PCA-derived axis lengths.
        length_mm = comp["length_px"] / px_per_mm
        width_mm = comp["width_px"] / px_per_mm
        area_mm2 = comp["area_px"] / (px_per_mm * px_per_mm)
        perimeter_mm = comp["peri_px"] / px_per_mm
        aspect_ratio = length_mm / (width_mm + 1e-6)
        # Circularity = 4*pi*area / perimeter^2 (1.0 for a perfect circle).
        circularity = (4.0 * np.pi * area_mm2) / (perimeter_mm * perimeter_mm + 1e-6)
        # Rasterize the contour so colors are sampled inside this object only.
        comp_mask = np.zeros(img_bgr.shape[:2], dtype=np.uint8)
        cv2.drawContours(comp_mask, [comp["contour"]], -1, 255, thickness=-1)
        mean_r, mean_g, mean_b, hue, sat, val, gi, bi = compute_color_metrics(img_bgr, comp_mask)
        records.append({
            "label": f"s{idx}",
            "centerX_px": int(comp["center"][0]),
            "centerY_px": int(comp["center"][1]),
            "length_mm": round(length_mm, 2),
            "width_mm": round(width_mm, 2),
            "area_mm2": round(area_mm2, 2),
            "perimeter_mm": round(perimeter_mm, 2),
            "aspect_ratio": round(aspect_ratio, 2),
            "circularity": round(circularity, 3),
            "angle_deg": round(float(comp["angle"]), 1),
            "meanR": int(round(mean_r)),
            "meanG": int(round(mean_g)),
            "meanB": int(round(mean_b)),
            "hue": hue,
            "saturation": sat,
            "value": val,
            "greenIndex": round(float(gi), 3),
            "brownIndex": round(float(bi), 3),
        })
    # An empty DataFrame is returned when no components were measured.
    return pd.DataFrame(records) if records else pd.DataFrame()
def render_overlay(
    img_bgr: np.ndarray,
    px_per_mm: float,
    ref: Tuple[Optional[Tuple[int, int]], Optional[str]],
    components: List[Dict[str, Any]],
    df: pd.DataFrame,
    ref_bbox: Optional[Tuple[int, int, int, int]] = None,
) -> np.ndarray:
    """Draw reference + sample annotations on the image.

    Uses the clear visualisation approach from demo.py. Returns the
    annotated image converted to RGB for display in Gradio.
    NOTE(review): px_per_mm and df are currently unused here; they appear
    to be kept for interface stability — confirm before removing.
    """
    out = img_bgr.copy()
    # Reference object: red rectangle with a "REF" label.
    ref_center, ref_type = ref
    if ref_bbox is not None:
        x, y, w, h = ref_bbox
        cv2.rectangle(out, (int(x), int(y)), (int(x + w), int(y + h)), (0, 0, 255), 2)
        cv2.putText(
            out,
            "REF",
            (int(x), int(y) - 10),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            (0, 0, 255),
            2,
            cv2.LINE_AA,
        )
    # Sample objects: full annotation (contour, OBB, axes, label).
    for i, comp in enumerate(components, start=1):
        # 1. Full contour (blue, thick).
        cv2.drawContours(out, [comp["contour"]], -1, (255, 0, 0), 3)
        # 2. Corrected oriented bounding box (OBB).
        # Use the precise PCA centroid.
        cx, cy = comp["pca_center"]
        length_px = comp["length_px"]  # long-axis length
        width_px = comp["width_px"]  # short-axis length
        angle_deg = comp["angle"]  # long-axis angle (degrees)
        # Degrees -> radians.
        angle_rad = np.radians(angle_deg)
        # Build the bounding box from the stored projection bounds.
        corners = []
        # Retrieve the projection bounds saved by segment().
        min_long_proj = comp["min_long_proj"]
        max_long_proj = comp["max_long_proj"]
        min_short_proj = comp["min_short_proj"]
        max_short_proj = comp["max_short_proj"]
        # Unit vectors along the long and short axes.
        long_dir = np.array([np.cos(angle_rad), np.sin(angle_rad)])
        short_dir = np.array([-np.sin(angle_rad), np.cos(angle_rad)])
        # Four corner points from the actual projection bounds.
        for long_proj, short_proj in [(max_long_proj, max_short_proj),  # top-right
                                      (min_long_proj, max_short_proj),  # top-left
                                      (min_long_proj, min_short_proj),  # bottom-left
                                      (max_long_proj, min_short_proj)]:  # bottom-right
            # Move from the centroid along the long/short axes to the corner.
            corner_point = np.array([cx, cy]) + long_proj * long_dir + short_proj * short_dir
            corners.append([int(corner_point[0]), int(corner_point[1])])
        # Draw the OBB.
        corners = np.array(corners, dtype=np.int32)
        cv2.drawContours(out, [corners], -1, (255, 0, 0), 2)
        # 3. Long/short axes drawn between opposite edge midpoints.
        # Midpoints of the four OBB edges.
        edge_mids = []
        for edge_idx in range(4):
            next_edge_idx = (edge_idx + 1) % 4
            mid_x = (corners[edge_idx][0] + corners[next_edge_idx][0]) / 2
            mid_y = (corners[edge_idx][1] + corners[next_edge_idx][1]) / 2
            edge_mids.append((int(mid_x), int(mid_y)))
        # Edge lengths, used to find the longest edge.
        edge_lengths = []
        for edge_idx in range(4):
            next_edge_idx = (edge_idx + 1) % 4
            length = np.sqrt((corners[next_edge_idx][0] - corners[edge_idx][0])**2 + (corners[next_edge_idx][1] - corners[edge_idx][1])**2)
            edge_lengths.append(length)
        # Longest edge and its opposite.
        max_edge_idx = np.argmax(edge_lengths)
        opposite_edge_idx = (max_edge_idx + 2) % 4
        # Long axis: connect the midpoints of the longest edge and its opposite.
        long_mid1 = edge_mids[max_edge_idx]
        long_mid2 = edge_mids[opposite_edge_idx]
        cv2.line(out, long_mid1, long_mid2, (255, 0, 0), 3)
        # Short axis: connect the midpoints of the remaining two edges.
        short_edge1_idx = (max_edge_idx + 1) % 4
        short_edge2_idx = (max_edge_idx + 3) % 4
        short_mid1 = edge_mids[short_edge1_idx]
        short_mid2 = edge_mids[short_edge2_idx]
        cv2.line(out, short_mid1, short_mid2, (255, 0, 0), 2)
        # 4. Center dot and label.
        # Precise PCA centroid again for label placement.
        label_cx, label_cy = comp["pca_center"]
        cv2.circle(out, (int(label_cx), int(label_cy)), 15, (0, 0, 0), -1)
        cv2.putText(
            out,
            f"s{i}",
            (int(label_cx) - 10, int(label_cy) + 5),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (255, 255, 255),
            2,
            cv2.LINE_AA,
        )
    # BGR -> RGB for display in Gradio.
    return cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
def analyze(
    image: Optional[np.ndarray],
    sample_type: str,
    expected_count: int,
    ref_mode: str,
    ref_size_mm: float,
    min_area_px: float,
    max_area_px: float,
    color_tol: int,
    hsv_low_h: int,
    hsv_high_h: int,
) -> Tuple[Optional[np.ndarray], pd.DataFrame, Optional[str], List[Dict[str, Any]], Dict[str, Any]]:
    """Main analysis pipeline integrating the optimized demo.py algorithms.

    Steps: RGB->BGR conversion, downscaling, reference detection,
    segmentation, sorting, metric computation, overlay rendering and
    CSV/JSON export. Returns (overlay, dataframe, csv_path, json_records,
    state_dict); on any error a mostly-empty tuple with the error message
    in the JSON slot is returned instead of raising.
    """
    try:
        if image is None:
            return None, pd.DataFrame(), None, [], {}

        # Gradio delivers RGB; OpenCV works in BGR.
        img_rgb = np.array(image)
        img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
        # Cap the resolution to keep CPU / memory usage bounded.
        img_bgr, scale = downscale_bgr(img_bgr)

        # Reference object: first blob in the top-left corner.
        px_per_mm, ref_center, ref_type, ref_bbox = detect_reference(img_bgr, ref_mode, ref_size_mm)

        # Segment all sample objects.
        comps = segment(
            img_bgr,
            sample_type=sample_type,
            hsv_low_h=hsv_low_h,
            hsv_high_h=hsv_high_h,
            color_tol=color_tol,
            min_area_px=min_area_px,
            max_area_px=max_area_px,
        )

        # Order the samples depending on the sample type.
        if sample_type == "leaves":
            comps.sort(key=lambda c: c["center"][0])
        else:
            comps.sort(key=lambda c: c["center"][1] * 0.3 + c["center"][0] * 0.7)

        # Honour the expected-count limit when one was given.
        if expected_count and expected_count > 0:
            comps = comps[: int(expected_count)]

        # Measurements + annotated overlay.
        df = compute_metrics(img_bgr, comps, px_per_mm)
        overlay = render_overlay(
            img_bgr.copy(),
            px_per_mm,
            (ref_center, ref_type),
            comps,
            df,
            ref_bbox,
        )

        # Export CSV to a temp file for the gr.File component.
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
        tmp.write(df.to_csv(index=False).encode("utf-8"))
        tmp.close()

        # JSON preview of the same records.
        js = df.to_dict(orient="records")

        # State kept around for interactive corrections.
        state_dict: Dict[str, Any] = {
            "img_bgr": img_bgr,
            "sample_type": sample_type,
            "px_per_mm": px_per_mm,
            "ref_center": ref_center,
            "ref_type": ref_type,
            "ref_bbox": ref_bbox,
            "components": comps,
            "expected_count": expected_count,
            "ref_size_mm": ref_size_mm,
        }
        # All components start out as active samples.
        state_dict["active_indices"] = list(range(len(comps)))

        return overlay, df, tmp.name, js, state_dict
    except Exception as e:
        return None, pd.DataFrame(), None, [{"error": str(e)}], {}
# --- Interactive correction helper ---
def apply_corrections(
    click_event,
    state_dict: Dict[str, Any],
    correction_mode: str,
) -> Tuple[Dict[str, Any], Optional[np.ndarray], pd.DataFrame, Optional[str], List[Dict[str, Any]]]:
    """
    Apply interactive corrections based on a click on the annotated image.

    correction_mode:
        - "none": do nothing
        - "set-ref": treat the clicked object as the new reference
        - "toggle-sample": toggle the clicked object between active/inactive sample

    Returns (state_dict, overlay, dataframe, csv_path, json_records);
    on a no-op or error the overlay slot is None and outputs are empty.
    """
    # If no valid state or no correction requested, do nothing.
    if not state_dict or "img_bgr" not in state_dict or correction_mode == "none" or click_event is None:
        return state_dict, None, pd.DataFrame(), None, []
    try:
        # Gradio SelectData usually provides (x, y) in .index
        if hasattr(click_event, "index"):
            x, y = click_event.index
        else:
            # Fallback: assume click_event is a tuple
            x, y = click_event
        img_bgr = state_dict["img_bgr"]
        components: List[Dict[str, Any]] = state_dict.get("components", [])
        if not components:
            return state_dict, None, pd.DataFrame(), None, []
        # Find nearest component center to the click (squared distance
        # is enough for comparison; no sqrt needed).
        min_dist = 1e9
        nearest_idx = -1
        for i, comp in enumerate(components):
            cx, cy = comp["center"]
            d = (cx - x) ** 2 + (cy - y) ** 2
            if d < min_dist:
                min_dist = d
                nearest_idx = i
        if nearest_idx < 0:
            return state_dict, None, pd.DataFrame(), None, []
        # Pull current calibration / reference info out of the state.
        px_per_mm = state_dict.get("px_per_mm", 4.0)
        ref_center = state_dict.get("ref_center")
        ref_type = state_dict.get("ref_type", "square")
        ref_bbox = state_dict.get("ref_bbox")
        ref_size_mm = state_dict.get("ref_size_mm", 20.0)
        sample_type = state_dict.get("sample_type", "leaves")
        active_indices = state_dict.get("active_indices", list(range(len(components))))
        if correction_mode == "set-ref":
            # Use this component as the new reference object
            comp = components[nearest_idx]
            box = comp["box"]
            xs = box[:, 0]
            ys = box[:, 1]
            x0, y0 = int(xs.min()), int(ys.min())
            w0, h0 = int(xs.max() - xs.min()), int(ys.max() - ys.min())
            ref_bbox = (x0, y0, w0, h0)
            ref_center = (int(comp["center"][0]), int(comp["center"][1]))
            # Update px_per_mm using the largest side as diameter/side length
            side_px = float(max(w0, h0))
            px_per_mm = max(side_px / (ref_size_mm if ref_size_mm > 0 else 20.0), 1e-6)
            ref_type = "square"
            # Remove this component from active samples (reference is not a sample)
            new_components = []
            for i, c in enumerate(components):
                if i != nearest_idx:
                    new_components.append(c)
            components = new_components
            # Rebuild active_indices to cover all remaining components
            active_indices = list(range(len(components)))
            state_dict["components"] = components
            state_dict["ref_bbox"] = ref_bbox
            state_dict["ref_center"] = ref_center
            state_dict["px_per_mm"] = px_per_mm
            state_dict["ref_type"] = ref_type
            state_dict["active_indices"] = active_indices
        elif correction_mode == "toggle-sample":
            # Toggle this component in/out of the active sample set
            if nearest_idx in active_indices:
                active_indices = [idx for idx in active_indices if idx != nearest_idx]
            else:
                active_indices.append(nearest_idx)
                active_indices = sorted(set(active_indices))
            state_dict["active_indices"] = active_indices
        # Rebuild the list of active components
        active_components = [components[i] for i in active_indices]
        # Recompute metrics and overlay using the updated state
        df = compute_metrics(img_bgr, active_components, px_per_mm)
        overlay = render_overlay(
            img_bgr.copy(),
            px_per_mm,
            (state_dict.get("ref_center"), state_dict.get("ref_type")),
            active_components,
            df,
            state_dict.get("ref_bbox"),
        )
        # Re-export CSV / JSON for the updated selection.
        csv = df.to_csv(index=False)
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
        tmp.write(csv.encode("utf-8"))
        tmp.close()
        js = df.to_dict(orient="records")
        return state_dict, overlay, df, tmp.name, js
    except Exception:
        # In case of any error, do not break the app; just keep current state
        return state_dict, None, pd.DataFrame(), None, []
# Gradio UI: left column holds inputs/parameters, right column the results.
with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown("# Biological Sample Quantifier (Leaves / Seeds)")
    # Per-session analysis state used by the interactive correction handlers.
    state = gr.State({})
    with gr.Row():
        with gr.Column(scale=1):
            image = gr.Image(type="numpy", label="Upload image")
            sample_type = gr.Radio(["leaves", "seeds-grains"], value="leaves", label="Sample type")
            expected = gr.Slider(1, 500, value=5, step=1, label="Expected count")
            ref_mode = gr.Radio(["auto", "coin", "square"], value="auto", label="Reference mode")
            ref_size = gr.Slider(1, 100, value=25.0, step=0.1, label="Reference size (mm)")
            min_area = gr.Slider(10, 5000, value=500, step=10, label="Min area (px²)")
            max_area = gr.Slider(1000, 200000, value=50000, step=1000, label="Max area (px²)")
            color_tol = gr.Slider(5, 100, value=40, step=1, label="Color tolerance")
            hsv_low = gr.Slider(0, 179, value=35, step=1, label="HSV H lower (leaves)")
            hsv_high = gr.Slider(0, 179, value=85, step=1, label="HSV H upper (leaves)")
            correction_mode = gr.Radio(
                ["none", "set-ref", "toggle-sample"],
                value="none",
                label="Correction mode (click on image)"
            )
            run = gr.Button("Analyze")
            reset = gr.Button("Reset")
        with gr.Column(scale=2):
            overlay = gr.Image(label="Annotated", interactive=True)
            table = gr.Dataframe(label="Metrics", wrap=True)
            csv_out = gr.File(label="CSV export")
            json_out = gr.JSON(label="JSON preview")

    # Thin wrapper so the click handler signature matches the inputs list.
    def _analyze(image, sample_type, expected, ref_mode, ref_size, min_area, max_area, color_tol, hsv_low, hsv_high):
        overlay_img, df, csv_path, js, state_dict = analyze(
            image, sample_type, expected, ref_mode, ref_size, min_area, max_area, color_tol, hsv_low, hsv_high
        )
        return overlay_img, df, csv_path, js, state_dict

    run.click(
        _analyze,
        [image, sample_type, expected, ref_mode, ref_size, min_area, max_area, color_tol, hsv_low, hsv_high],
        [overlay, table, csv_out, json_out, state],
    )

    # Clear all outputs and the stored state.
    def _reset():
        return None, pd.DataFrame(), None, [], {}

    reset.click(_reset, None, [overlay, table, csv_out, json_out, state])

    # Click handler on the annotated image for interactive corrections.
    def _on_select(evt, current_state, correction_mode):
        # Apply corrections based on a click on the annotated image
        new_state, overlay_img, df, csv_path, js = apply_corrections(evt, current_state or {}, correction_mode)
        # If overlay_img is None, keep the existing outputs unchanged by returning gr.update()
        if overlay_img is None:
            return gr.update(), gr.update(), gr.update(), gr.update(), new_state
        return overlay_img, df, csv_path, js, new_state

    overlay.select(
        _on_select,
        [state, correction_mode],
        [overlay, table, csv_out, json_out, state],
    )

if __name__ == "__main__":
    demo.launch()