"""Gradio app that measures leaves / seeds in a photo against a reference object."""

import logging
import tempfile
from typing import Any, Dict, List, Optional, Tuple

import cv2
import gradio as gr
import numpy as np
import pandas as pd

logger = logging.getLogger(__name__)

# -----------------------------
# Global configuration
# -----------------------------
# Maximum allowed image side (pixels) to avoid OOM / heavy CPU usage.
# Reduced from 2048 to 1024 for better performance (as in demo.py).
MAX_SIDE = 1024

# Reference side length (mm) used when the caller supplies no positive value.
DEFAULT_REF_SIZE_MM = 25.0

# Scale (pixels per millimetre) used when no reference object is found.
DEFAULT_PX_PER_MM = 4.0


# -----------------------------
# Utility functions
# -----------------------------
def downscale_bgr(img: np.ndarray) -> Tuple[np.ndarray, float]:
    """Downscale image so that the longest side is <= MAX_SIDE.

    Returns
    -------
    img_resized : np.ndarray
        Possibly downscaled BGR image (the input itself when no resize is
        needed).
    scale : float
        Applied scale factor (<= 1).
    """
    h, w = img.shape[:2]
    max_hw = max(h, w)
    if max_hw <= MAX_SIDE:
        return img, 1.0

    scale = MAX_SIDE / float(max_hw)
    # Clamp to 1 px: for extreme aspect ratios the short side would
    # otherwise truncate to 0, which cv2.resize rejects.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    img_resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
    return img_resized, scale


def normalize_angle(angle: float, size_w: float, size_h: float) -> float:
    """Normalize an OpenCV minAreaRect angle to [0, 180) degrees.

    When ``size_w < size_h`` the angle is shifted by 90 degrees so that it
    refers to the *long side* of the rectangle.
    """
    a = angle
    if size_w < size_h:
        a += 90.0
    # The second modulo folds a float-rounding result of exactly 180.0 back
    # to 0.0.
    a = ((a % 180.0) + 180.0) % 180.0
    return a


# -----------------------------
# Reference object detection
# -----------------------------
def build_foreground_mask(img_bgr: np.ndarray) -> np.ndarray:
    """Build a simple foreground mask (approach taken from demo.py).

    The background colour is estimated from the four image corners in LAB
    space; pixels are then separated by their colour distance to it using an
    Otsu threshold, followed by morphological open and close.

    Returns
    -------
    np.ndarray
        uint8 mask with values 0 / 255, same height and width as the input.
    """
    h, w = img_bgr.shape[:2]

    # Work in the LAB colour space.
    lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2LAB)

    # Estimate the background colour from the four corners. At least one
    # pixel per corner: a size of 0 would make ``lab[-0:]`` select the whole
    # image and ``lab[:0]`` select nothing.
    corner_size = max(1, min(h, w) // 10)
    corners = [
        lab[:corner_size, :corner_size],
        lab[:corner_size, -corner_size:],
        lab[-corner_size:, :corner_size],
        lab[-corner_size:, -corner_size:],
    ]
    corner_pixels = np.vstack([c.reshape(-1, 3) for c in corners])
    bg_color = np.mean(corner_pixels, axis=0)

    # Per-pixel Euclidean distance to the background colour.
    diff = lab.astype(np.float32) - bg_color
    dist = np.sqrt(np.sum(diff * diff, axis=2))

    # Otsu threshold on the distance map (scaled by 3 and clipped to uint8).
    dist_uint8 = np.clip(dist * 3, 0, 255).astype(np.uint8)
    _, mask = cv2.threshold(dist_uint8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Morphological clean-up.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)

    return mask


def detect_reference(
    img_bgr: np.ndarray,
    mode: str,
    ref_size_mm: Optional[float],
) -> Tuple[float, Optional[Tuple[int, int]], Optional[str], Optional[Tuple[int, int, int, int]]]:
    """Detect the reference object: the top-left-most plausible component.

    Parameters
    ----------
    img_bgr : np.ndarray
        BGR image.
    mode : str
        Reference mode ("auto", "coin", "square"). Accepted for interface
        compatibility; the current implementation does not branch on it.
    ref_size_mm : float or None
        Side length of the reference bounding box in millimetres. Falls back
        to ``DEFAULT_REF_SIZE_MM`` when missing or not positive.

    Returns
    -------
    px_per_mm : float
        Pixels per millimetre (``DEFAULT_PX_PER_MM`` when nothing is found).
    ref_center : tuple or None
        Centroid of the reference component.
    ref_type : str or None
        Always "square" when a reference is found.
    ref_bbox : tuple or None
        Axis-aligned bounding box ``(x, y, w, h)`` of the reference.
    """
    h, w = img_bgr.shape[:2]

    mask = build_foreground_mask(img_bgr)
    num_labels, _, stats, centroids = cv2.connectedComponentsWithStats(mask)

    min_area = (h * w) // 500
    max_area = (h * w) // 20

    candidates = []
    for i in range(1, num_labels):  # label 0 is the background
        x, y, ww, hh, area = (int(v) for v in stats[i])

        # Area filter.
        if area < min_area or area > max_area:
            continue

        # Position filter: must start in the top-left region.
        if x > w * 0.4 or y > h * 0.4:
            continue

        # Shape filter: the reference should be roughly square.
        aspect_ratio = max(ww, hh) / (min(ww, hh) + 1e-6)
        if aspect_ratio > 3.0:
            continue

        cx, cy = centroids[i]
        # Lower score means closer to the top-left corner.
        score = x + y
        candidates.append((score, (x, y, ww, hh), (int(cx), int(cy))))

    if not candidates:
        # No reference found: fall back to a safe default scale.
        return DEFAULT_PX_PER_MM, None, None, None

    # Stable sort on the score only, so ties keep label order.
    candidates.sort(key=lambda c: c[0])
    _, (x, y, ww, hh), center = candidates[0]

    ref_size = ref_size_mm if ref_size_mm and ref_size_mm > 0 else DEFAULT_REF_SIZE_MM
    px_per_mm = max(ww, hh) / ref_size

    return px_per_mm, center, "square", (x, y, ww, hh)


# -----------------------------
# Segmentation & measurements
# -----------------------------
def build_mask_hsv(
    img_bgr: np.ndarray,
    sample_type: str,
    hsv_low_h: int,
    hsv_high_h: int,
    color_tol: int,
) -> np.ndarray:
    """Build a binary mask using HSV thresholds.

    For ``sample_type == "leaves"`` the hue must lie in
    ``[hsv_low_h, hsv_high_h]`` (clamped to 0..179) and saturation / value
    must be >= 30. For any other sample type only saturation / value >= 20
    is required. ``color_tol`` is accepted but not used.
    """
    hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
    h_channel = hsv[:, :, 0]
    s_channel = hsv[:, :, 1]
    v_channel = hsv[:, :, 2]

    if sample_type == "leaves":
        low_h = int(max(0, hsv_low_h))
        high_h = int(min(179, hsv_high_h))
        mask_h = cv2.inRange(h_channel, low_h, high_h)
        # Remove very desaturated or very dark pixels.
        mask_s = cv2.inRange(s_channel, 30, 255)
        mask_v = cv2.inRange(v_channel, 30, 255)
        mask = cv2.bitwise_and(mask_h, cv2.bitwise_and(mask_s, mask_v))
    else:
        # Seeds / grains: keep non-white pixels.
        mask_s = cv2.inRange(s_channel, 20, 255)
        mask_v = cv2.inRange(v_channel, 20, 255)
        mask = cv2.bitwise_and(mask_s, mask_v)

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    return mask


def _pca_oriented_box(points: np.ndarray) -> Dict[str, Any]:
    """Fit a PCA-aligned bounding box to an (N, 2) array of contour points.

    The returned projection bounds are expressed in the frame
    ``long_dir = (cos a, sin a)``, ``short_dir = (-sin a, cos a)`` where ``a``
    is the returned angle in [0, 180). ``render_overlay`` rebuilds exactly
    this frame from the angle, so the bounds must not be taken in the raw
    eigenvector frame (whose sign and handedness are arbitrary).
    """
    pts = points.astype(np.float64)
    cx = float(np.mean(pts[:, 0]))
    cy = float(np.mean(pts[:, 1]))
    centered = pts - np.array([cx, cy])

    # Principal directions of the contour points, largest variance first.
    eigenvalues, eigenvectors = np.linalg.eigh(np.cov(centered.T))
    order = np.argsort(eigenvalues)[::-1]
    eigenvectors = eigenvectors[:, order]

    # The long axis is the direction with the larger extent, which is not
    # necessarily the one with the larger variance.
    extent_main = float(np.ptp(centered @ eigenvectors[:, 0]))
    extent_secondary = float(np.ptp(centered @ eigenvectors[:, 1]))
    if extent_main >= extent_secondary:
        long_vec = eigenvectors[:, 0]
        length_px, width_px = extent_main, extent_secondary
    else:
        long_vec = eigenvectors[:, 1]
        length_px, width_px = extent_secondary, extent_main

    angle = float(np.arctan2(long_vec[1], long_vec[0]) * 180.0 / np.pi)
    # Normalize to [0, 180).
    angle = ((angle % 180.0) + 180.0) % 180.0

    theta = np.radians(angle)
    long_dir = np.array([np.cos(theta), np.sin(theta)])
    short_dir = np.array([-np.sin(theta), np.cos(theta)])
    proj_long = centered @ long_dir
    proj_short = centered @ short_dir

    return {
        "pca_center": (cx, cy),
        "angle": angle,
        "length_px": length_px,
        "width_px": width_px,
        "min_long_proj": float(np.min(proj_long)),
        "max_long_proj": float(np.max(proj_long)),
        "min_short_proj": float(np.min(proj_short)),
        "max_short_proj": float(np.max(proj_short)),
    }


def segment(
    img_bgr: np.ndarray,
    sample_type: str,
    hsv_low_h: int,
    hsv_high_h: int,
    color_tol: int,
    min_area_px: float,
    max_area_px: float,
    ref_bbox: Optional[Tuple[int, int, int, int]] = None,
) -> List[Dict[str, Any]]:
    """Segment objects and compute basic geometric descriptors.

    Uses the simplified foreground-mask segmentation from demo.py.
    ``sample_type``, ``hsv_low_h``, ``hsv_high_h`` and ``color_tol`` are kept
    for interface compatibility and are not used.

    Parameters
    ----------
    min_area_px, max_area_px : float
        Components whose pixel area lies outside this range are dropped.
    ref_bbox : tuple or None
        Bounding box ``(x, y, w, h)`` of the reference object as returned by
        ``detect_reference`` for the same image. The component with exactly
        this bounding box is excluded. When omitted, or when no component
        matches, the first component (ordered mostly left to right) is
        skipped if it sits in the top-left region and is roughly square.

    Returns
    -------
    list of dict
        One entry per sample with contour, min-area rectangle, pixel area and
        perimeter, centre, long-axis angle, and PCA box extents / bounds.
    """
    mask = build_foreground_mask(img_bgr)
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask)

    ref_key = None if ref_bbox is None else tuple(int(v) for v in ref_bbox)
    ref_excluded = False

    all_objects = []
    for i in range(1, num_labels):  # label 0 is the background
        x, y, ww, hh, area = (int(v) for v in stats[i])

        if not ref_excluded and ref_key is not None and (x, y, ww, hh) == ref_key:
            ref_excluded = True
            continue

        # Area filter.
        if area < min_area_px or area > max_area_px:
            continue

        # Position score: left to right, x dominates.
        score = x + y * 0.1
        all_objects.append((score, i, (x, y, ww, hh), area))

    if not all_objects:
        return []

    all_objects.sort(key=lambda obj: obj[0])

    start_idx = 0
    if not ref_excluded:
        # Heuristic fallback: treat a squarish first object in the top-left
        # region as the reference and skip it.
        x, y, ww, hh = all_objects[0][2]
        h, w = img_bgr.shape[:2]
        is_topleft = x < w * 0.3 and y < h * 0.3
        aspect_ratio = max(ww, hh) / (min(ww, hh) + 1e-6)
        if is_topleft and aspect_ratio < 3.0:
            start_idx = 1

    components: List[Dict[str, Any]] = []
    for _, label_idx, _, area in all_objects[start_idx:]:
        component_mask = (labels == label_idx).astype(np.uint8) * 255
        cnts, _ = cv2.findContours(component_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if len(cnts) == 0:
            continue
        cnt = max(cnts, key=cv2.contourArea)

        contour_points = cnt.reshape(-1, 2)
        if len(contour_points) < 2:
            # A single point has no covariance, hence no orientation.
            continue

        rect = cv2.minAreaRect(cnt)
        box = cv2.boxPoints(rect).astype(np.int32)
        peri = cv2.arcLength(cnt, True)

        # minAreaRect does not reliably say which side is the long one, so
        # the long / short axes come from a PCA of the contour points.
        obb = _pca_oriented_box(contour_points)
        cx, cy = obb["pca_center"]

        components.append({
            "contour": cnt,
            "rect": rect,
            "box": box,
            "area_px": float(area),
            "peri_px": float(peri),
            "center": (int(cx), int(cy)),  # integer PCA centroid
            "pca_center": (cx, cy),  # sub-pixel PCA centroid
            "angle": float(obb["angle"]),
            "length_px": float(obb["length_px"]),
            "width_px": float(obb["width_px"]),
            # Projection bounds used to draw the oriented bounding box.
            "min_long_proj": obb["min_long_proj"],
            "max_long_proj": obb["max_long_proj"],
            "min_short_proj": obb["min_short_proj"],
            "max_short_proj": obb["max_short_proj"],
        })

    return components


def compute_color_metrics(
    img_bgr: np.ndarray, mask: np.ndarray
) -> Tuple[float, float, float, int, int, int, float, float]:
    """Compute mean RGB / HSV and simple colour indices in a mask region.

    Returns
    -------
    tuple
        ``(mean_r, mean_g, mean_b, h, s, v, green_index, brown_index)`` where
        HSV is the OpenCV 8-bit conversion of the rounded mean colour,
        ``green_index = (2G - R - B) / (R + G + B)`` and
        ``brown_index = (R - B) / (R + G + B)``.
    """
    mean_bgr = cv2.mean(img_bgr, mask=mask)
    mean_b, mean_g, mean_r = mean_bgr[0], mean_bgr[1], mean_bgr[2]

    # Round rather than truncate so HSV matches the rounded meanR/G/B that
    # compute_metrics reports.
    rgb = np.clip(np.rint([[[mean_r, mean_g, mean_b]]]), 0, 255).astype(np.uint8)
    hsv = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV)[0, 0]
    h, s, v = int(hsv[0]), int(hsv[1]), int(hsv[2])

    denom = mean_r + mean_g + mean_b + 1e-6  # epsilon guards an all-black region
    green_index = (2.0 * mean_g - mean_r - mean_b) / denom
    brown_index = (mean_r - mean_b) / denom

    return mean_r, mean_g, mean_b, h, s, v, green_index, brown_index


def compute_metrics(
    img_bgr: np.ndarray,
    components: List[Dict[str, Any]],
    px_per_mm: float,
) -> pd.DataFrame:
    """Compute all morphological + colour metrics for each component.

    Rows are labelled ``s1``, ``s2``, ... in the order of ``components``.
    Returns an empty DataFrame when there are no components.
    """
    rows: List[Dict[str, Any]] = []

    for i, comp in enumerate(components, start=1):
        length_mm = comp["length_px"] / px_per_mm
        width_mm = comp["width_px"] / px_per_mm
        area_mm2 = comp["area_px"] / (px_per_mm * px_per_mm)
        perimeter_mm = comp["peri_px"] / px_per_mm

        aspect_ratio = length_mm / (width_mm + 1e-6)
        # Circularity: 4 * pi * area / perimeter^2.
        circularity = (4.0 * np.pi * area_mm2) / (perimeter_mm * perimeter_mm + 1e-6)

        # Colour metrics over the filled contour.
        mask_single = np.zeros(img_bgr.shape[:2], dtype=np.uint8)
        cv2.drawContours(mask_single, [comp["contour"]], -1, 255, thickness=-1)
        mean_r, mean_g, mean_b, h, s, v, gi, bi = compute_color_metrics(img_bgr, mask_single)

        rows.append(
            {
                "label": f"s{i}",
                "centerX_px": int(comp["center"][0]),
                "centerY_px": int(comp["center"][1]),
                "length_mm": round(length_mm, 2),
                "width_mm": round(width_mm, 2),
                "area_mm2": round(area_mm2, 2),
                "perimeter_mm": round(perimeter_mm, 2),
                "aspect_ratio": round(aspect_ratio, 2),
                "circularity": round(circularity, 3),
                "angle_deg": round(float(comp["angle"]), 1),
                "meanR": int(round(mean_r)),
                "meanG": int(round(mean_g)),
                "meanB": int(round(mean_b)),
                "hue": h,
                "saturation": s,
                "value": v,
                "greenIndex": round(float(gi), 3),
                "brownIndex": round(float(bi), 3),
            }
        )

    if not rows:
        return pd.DataFrame()

    return pd.DataFrame(rows)


def render_overlay(
    img_bgr: np.ndarray,
    px_per_mm: float,
    ref: Tuple[Optional[Tuple[int, int]], Optional[str]],
    components: List[Dict[str, Any]],
    df: pd.DataFrame,
    ref_bbox: Optional[Tuple[int, int, int, int]] = None,
) -> np.ndarray:
    """Draw reference + sample annotations on a copy of the image.

    ``px_per_mm``, ``ref`` and ``df`` are accepted for interface
    compatibility; only ``ref_bbox`` and ``components`` are drawn.

    Returns
    -------
    np.ndarray
        Annotated image converted to RGB.
    """
    out = img_bgr.copy()

    # Reference object: red rectangle plus label.
    if ref_bbox is not None:
        x, y, w, h = ref_bbox
        cv2.rectangle(out, (int(x), int(y)), (int(x + w), int(y + h)), (0, 0, 255), 2)
        cv2.putText(
            out,
            "REF",
            # Keep the baseline inside the image for references at the top edge.
            (int(x), max(int(y) - 10, 15)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            (0, 0, 255),
            2,
            cv2.LINE_AA,
        )

    for i, comp in enumerate(components, start=1):
        # 1. Full contour (colour (255, 0, 0) in BGR, bold).
        cv2.drawContours(out, [comp["contour"]], -1, (255, 0, 0), 3)

        # 2. Oriented bounding box rebuilt from the PCA centre, the long-axis
        #    angle and the stored projection bounds.
        center = np.array(comp["pca_center"], dtype=np.float64)
        angle_rad = np.radians(comp["angle"])
        long_dir = np.array([np.cos(angle_rad), np.sin(angle_rad)])
        short_dir = np.array([-np.sin(angle_rad), np.cos(angle_rad)])

        min_long = comp["min_long_proj"]
        max_long = comp["max_long_proj"]
        min_short = comp["min_short_proj"]
        max_short = comp["max_short_proj"]

        def to_pixel(long_proj: float, short_proj: float) -> Tuple[int, int]:
            point = center + long_proj * long_dir + short_proj * short_dir
            return int(point[0]), int(point[1])

        corners = np.array(
            [
                to_pixel(max_long, max_short),
                to_pixel(min_long, max_short),
                to_pixel(min_long, min_short),
                to_pixel(max_long, min_short),
            ],
            dtype=np.int32,
        )
        cv2.drawContours(out, [corners], -1, (255, 0, 0), 2)

        # 3. Axes through the middle of the box. The long axis joins the
        #    midpoints of the two SHORT edges (thick line); the short axis
        #    joins the midpoints of the two long edges (thin line).
        mid_long = (min_long + max_long) / 2.0
        mid_short = (min_short + max_short) / 2.0
        cv2.line(out, to_pixel(min_long, mid_short), to_pixel(max_long, mid_short), (255, 0, 0), 3)
        cv2.line(out, to_pixel(mid_long, min_short), to_pixel(mid_long, max_short), (255, 0, 0), 2)

        # 4. Centre marker and label at the PCA centroid.
        label_cx, label_cy = comp["pca_center"]
        cv2.circle(out, (int(label_cx), int(label_cy)), 15, (0, 0, 0), -1)
        cv2.putText(
            out,
            f"s{i}",
            (int(label_cx) - 10, int(label_cy) + 5),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (255, 255, 255),
            2,
            cv2.LINE_AA,
        )

    return cv2.cvtColor(out, cv2.COLOR_BGR2RGB)


def _write_csv(df: pd.DataFrame) -> str:
    """Write ``df`` as UTF-8 CSV to a persistent temp file and return its path."""
    # delete=False: the path is handed to the UI after this function returns.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
        tmp.write(df.to_csv(index=False).encode("utf-8"))
    return tmp.name


def analyze(
    image: Optional[np.ndarray],
    sample_type: str,
    expected_count: int,
    ref_mode: str,
    ref_size_mm: float,
    min_area_px: float,
    max_area_px: float,
    color_tol: int,
    hsv_low_h: int,
    hsv_high_h: int,
) -> Tuple[Optional[np.ndarray], pd.DataFrame, Optional[str], List[Dict[str, Any]], Dict[str, Any]]:
    """Run the full pipeline: reference detection, segmentation, metrics, overlay.

    Returns
    -------
    overlay : np.ndarray or None
        Annotated RGB image.
    df : pd.DataFrame
        One row of metrics per sample.
    csv_path : str or None
        Path of a temp CSV file with the same metrics.
    records : list of dict
        The metrics as records; ``[{"error": ...}]`` on failure.
    state : dict
        Data needed by ``apply_corrections``; empty on failure or no image.
    """
    try:
        if image is None:
            return None, pd.DataFrame(), None, [], {}

        # Convert to BGR, accepting grayscale and RGBA input as well.
        img_rgb = np.array(image)
        if img_rgb.ndim == 2:
            img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_GRAY2BGR)
        elif img_rgb.shape[2] == 4:
            img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGBA2BGR)
        else:
            img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)

        # Moderate downsampling.
        img_bgr, _ = downscale_bgr(img_bgr)

        # Detect the reference (top-left-most object).
        px_per_mm, ref_center, ref_type, ref_bbox = detect_reference(img_bgr, ref_mode, ref_size_mm)

        # Segment all samples, excluding exactly the detected reference.
        comps = segment(
            img_bgr,
            sample_type=sample_type,
            hsv_low_h=hsv_low_h,
            hsv_high_h=hsv_high_h,
            color_tol=color_tol,
            min_area_px=min_area_px,
            max_area_px=max_area_px,
            ref_bbox=ref_bbox,
        )

        # Ordering depends on the sample type.
        if sample_type == "leaves":
            comps.sort(key=lambda c: c["center"][0])
        else:
            comps.sort(key=lambda c: c["center"][1] * 0.3 + c["center"][0] * 0.7)

        # Keep at most the expected number of samples.
        if expected_count and expected_count > 0:
            comps = comps[:int(expected_count)]

        df = compute_metrics(img_bgr, comps, px_per_mm)
        overlay = render_overlay(
            img_bgr.copy(), px_per_mm, (ref_center, ref_type), comps, df, ref_bbox
        )

        csv_path = _write_csv(df)
        js = df.to_dict(orient="records")

        # State kept for interactive corrections.
        state_dict: Dict[str, Any] = {
            "img_bgr": img_bgr,
            "sample_type": sample_type,
            "px_per_mm": px_per_mm,
            "ref_center": ref_center,
            "ref_type": ref_type,
            "ref_bbox": ref_bbox,
            "components": comps,
            "expected_count": expected_count,
            "ref_size_mm": ref_size_mm,
            # By default every component is an active sample.
            "active_indices": list(range(len(comps))),
        }

        return overlay, df, csv_path, js, state_dict
    except Exception as e:  # UI boundary: report the error instead of crashing
        logger.exception("analyze failed")
        return None, pd.DataFrame(), None, [{"error": str(e)}], {}


# --- Interactive correction helper ---
def apply_corrections(
    click_event,
    state_dict: Dict[str, Any],
    correction_mode: str,
) -> Tuple[Dict[str, Any], Optional[np.ndarray], pd.DataFrame, Optional[str], List[Dict[str, Any]]]:
    """Apply an interactive correction based on a click on the annotated image.

    Parameters
    ----------
    click_event
        Object with an ``index`` attribute holding ``(x, y)``, or an
        ``(x, y)`` pair itself.
    state_dict : dict
        State produced by ``analyze``; updated in place.
    correction_mode : str
        - "none": do nothing
        - "set-ref": treat the clicked object as the new reference
        - "toggle-sample": toggle the clicked object between active/inactive

    Returns
    -------
    tuple
        ``(state, overlay, df, csv_path, records)``. ``overlay`` is ``None``
        when nothing was changed or an error occurred.
    """
    unchanged = (state_dict, None, pd.DataFrame(), None, [])

    # No valid state or no correction requested: do nothing.
    if (
        not state_dict
        or "img_bgr" not in state_dict
        or correction_mode == "none"
        or click_event is None
    ):
        return unchanged

    try:
        if hasattr(click_event, "index"):
            x, y = click_event.index
        else:
            # Fallback: assume click_event is an (x, y) pair.
            x, y = click_event

        img_bgr = state_dict["img_bgr"]
        components: List[Dict[str, Any]] = state_dict.get("components", [])
        if not components:
            return unchanged

        # Nearest component centre to the click (first one wins on ties).
        nearest_idx = min(
            range(len(components)),
            key=lambda i: (components[i]["center"][0] - x) ** 2
            + (components[i]["center"][1] - y) ** 2,
        )

        px_per_mm = state_dict.get("px_per_mm", DEFAULT_PX_PER_MM)
        ref_size_mm = state_dict.get("ref_size_mm", DEFAULT_REF_SIZE_MM)
        active_indices = list(state_dict.get("active_indices", range(len(components))))

        if correction_mode == "set-ref":
            # Use this component as the new reference object.
            comp = components[nearest_idx]
            box = comp["box"]
            xs = box[:, 0]
            ys = box[:, 1]
            x0, y0 = int(xs.min()), int(ys.min())
            w0, h0 = int(xs.max() - xs.min()), int(ys.max() - ys.min())

            # Scale from the largest side of the axis-aligned box.
            side_px = float(max(w0, h0))
            ref_size = ref_size_mm if ref_size_mm and ref_size_mm > 0 else DEFAULT_REF_SIZE_MM
            px_per_mm = max(side_px / ref_size, 1e-6)

            # The reference is not a sample: drop it and reactivate the rest.
            components = [c for i, c in enumerate(components) if i != nearest_idx]
            active_indices = list(range(len(components)))

            state_dict["components"] = components
            state_dict["ref_bbox"] = (x0, y0, w0, h0)
            state_dict["ref_center"] = (int(comp["center"][0]), int(comp["center"][1]))
            state_dict["px_per_mm"] = px_per_mm
            state_dict["ref_type"] = "square"
            state_dict["active_indices"] = active_indices

        elif correction_mode == "toggle-sample":
            # Toggle this component in/out of the active sample set.
            if nearest_idx in active_indices:
                active_indices = [idx for idx in active_indices if idx != nearest_idx]
            else:
                active_indices.append(nearest_idx)
            active_indices = sorted(set(active_indices))
            state_dict["active_indices"] = active_indices

        # Recompute metrics and overlay from the updated state.
        active_components = [components[i] for i in active_indices]
        df = compute_metrics(img_bgr, active_components, px_per_mm)
        overlay = render_overlay(
            img_bgr.copy(),
            px_per_mm,
            (state_dict.get("ref_center"), state_dict.get("ref_type")),
            active_components,
            df,
            state_dict.get("ref_bbox"),
        )

        csv_path = _write_csv(df)
        js = df.to_dict(orient="records")

        return state_dict, overlay, df, csv_path, js
    except Exception:
        # Do not break the app on a bad click; keep the current state.
        logger.exception("apply_corrections failed")
        return state_dict, None, pd.DataFrame(), None, []


with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown("# Biological Sample Quantifier (Leaves / Seeds)")

    state = gr.State({})

    with gr.Row():
        with gr.Column(scale=1):
            image = gr.Image(type="numpy", label="Upload image")
            sample_type = gr.Radio(["leaves", "seeds-grains"], value="leaves", label="Sample type")
            expected = gr.Slider(1, 500, value=5, step=1, label="Expected count")
            ref_mode = gr.Radio(["auto", "coin", "square"], value="auto", label="Reference mode")
            ref_size = gr.Slider(1, 100, value=25.0, step=0.1, label="Reference size (mm)")
            min_area = gr.Slider(10, 5000, value=500, step=10, label="Min area (px²)")
            max_area = gr.Slider(1000, 200000, value=50000, step=1000, label="Max area (px²)")
            color_tol = gr.Slider(5, 100, value=40, step=1, label="Color tolerance")
            hsv_low = gr.Slider(0, 179, value=35, step=1, label="HSV H lower (leaves)")
            hsv_high = gr.Slider(0, 179, value=85, step=1, label="HSV H upper (leaves)")
            correction_mode = gr.Radio(
                ["none", "set-ref", "toggle-sample"],
                value="none",
                label="Correction mode (click on image)"
            )
            run = gr.Button("Analyze")
            reset = gr.Button("Reset")
        with gr.Column(scale=2):
            overlay = gr.Image(label="Annotated", interactive=True)
            table = gr.Dataframe(label="Metrics", wrap=True)
            csv_out = gr.File(label="CSV export")
            json_out = gr.JSON(label="JSON preview")

    def _analyze(image, sample_type, expected, ref_mode, ref_size,
                 min_area, max_area, color_tol, hsv_low, hsv_high):
        """Forward the widget values to ``analyze``."""
        return analyze(
            image, sample_type, expected, ref_mode, ref_size,
            min_area, max_area, color_tol, hsv_low, hsv_high
        )

    run.click(
        _analyze,
        [image, sample_type, expected, ref_mode, ref_size,
         min_area, max_area, color_tol, hsv_low, hsv_high],
        [overlay, table, csv_out, json_out, state],
    )

    def _reset():
        """Clear every output and the stored state."""
        return None, pd.DataFrame(), None, [], {}

    reset.click(_reset, None, [overlay, table, csv_out, json_out, state])

    def _on_select(current_state, correction_mode, evt: gr.SelectData):
        """Apply a correction for a click on the annotated image.

        The ``gr.SelectData`` annotation is what makes Gradio inject the click
        data; the other two parameters are filled from the listed inputs.
        """
        new_state, overlay_img, df, csv_path, js = apply_corrections(
            evt, current_state or {}, correction_mode
        )
        # No new overlay: leave the existing outputs untouched.
        if overlay_img is None:
            return gr.update(), gr.update(), gr.update(), gr.update(), new_state
        return overlay_img, df, csv_path, js, new_state

    overlay.select(
        _on_select,
        [state, correction_mode],
        [overlay, table, csv_out, json_out, state],
    )


if __name__ == "__main__":
    demo.launch()