# app.py — updated by smiler488 (commit 4cf2bd5, verified)
import tempfile
from typing import List, Tuple, Optional, Dict, Any
import cv2
import gradio as gr
import numpy as np
import pandas as pd
# -----------------------------
# Global configuration
# -----------------------------
# Maximum allowed image side (pixels) to avoid OOM / heavy CPU usage.
# Reduced from 2048 to 1024 for better performance (as in demo.py).
MAX_SIDE = 1024
# -----------------------------
# Utility functions
# -----------------------------
def downscale_bgr(img: np.ndarray) -> Tuple[np.ndarray, float]:
    """Shrink *img* so that its longest side does not exceed MAX_SIDE.

    Returns
    -------
    img_resized : np.ndarray
        The (possibly) downscaled BGR image.
    scale : float
        Scale factor that was applied (1.0 when no resize was needed).
    """
    height, width = img.shape[:2]
    longest = width if width > height else height
    if longest <= MAX_SIDE:
        return img, 1.0
    factor = MAX_SIDE / float(longest)
    new_size = (int(width * factor), int(height * factor))
    # INTER_AREA is the recommended interpolation for shrinking.
    shrunk = cv2.resize(img, new_size, interpolation=cv2.INTER_AREA)
    return shrunk, factor
def normalize_angle(angle: float, size_w: float, size_h: float) -> float:
    """Map an OpenCV ``minAreaRect`` angle onto [0, 180) degrees.

    OpenCV's reported angle depends on whether the rect's width is smaller
    than its height; compensate so the long side is always treated as the
    length axis, then wrap the result into [0, 180).
    """
    adjusted = angle + 90.0 if size_w < size_h else angle
    return ((adjusted % 180.0) + 180.0) % 180.0
# -----------------------------
# Reference object detection
# -----------------------------
def build_foreground_mask(img_bgr: np.ndarray) -> np.ndarray:
    """Build a simple foreground mask (ported from demo.py).

    The background colour is estimated from the four image corners in LAB
    space; pixels far from that colour are kept as foreground via Otsu
    thresholding, then cleaned up with morphological open/close.
    """
    rows, cols = img_bgr.shape[:2]
    # Work in LAB: colour distance is better behaved than raw BGR.
    lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2LAB)
    # Sample the four corner patches to estimate the background colour.
    patch = min(rows, cols) // 10
    corner_patches = [
        lab[:patch, :patch],
        lab[:patch, -patch:],
        lab[-patch:, :patch],
        lab[-patch:, -patch:],
    ]
    stacked = np.vstack([p.reshape(-1, 3) for p in corner_patches])
    background = np.mean(stacked, axis=0)
    # Per-pixel Euclidean distance to the background colour.
    delta = lab.astype(np.float32) - background
    distance = np.sqrt(np.sum(delta * delta, axis=2))
    # Otsu split on an amplified, clipped distance map.
    dist_u8 = np.clip(distance * 3, 0, 255).astype(np.uint8)
    _, mask = cv2.threshold(dist_u8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # Open removes specks, close fills small holes.
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, ellipse)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, ellipse)
    return mask
def detect_reference(
    img_bgr: np.ndarray,
    mode: str,
    ref_size_mm: Optional[float],
) -> Tuple[float, Optional[Tuple[int, int]], Optional[str], Optional[Tuple[int, int, int, int]]]:
    """Detect the reference object: the first blob in the top-left corner.

    Parameters
    ----------
    img_bgr : np.ndarray
        BGR image.
    mode : str
        Reference mode ("auto", "coin", "square"); unused by this
        simplified detector but kept for interface compatibility.
    ref_size_mm : Optional[float]
        Side length of the reference bounding box in millimetres.

    Returns
    -------
    (px_per_mm, ref_center, ref_type, ref_bbox); the last three are None
    when no reference is found, in which case px_per_mm falls back to 4.0.
    """
    rows, cols = img_bgr.shape[:2]
    mask = build_foreground_mask(img_bgr)
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask)

    area_lo = (rows * cols) // 500  # minimum plausible reference area
    area_hi = (rows * cols) // 20   # maximum plausible reference area
    found = []
    for idx in range(1, num_labels):
        x, y, bw, bh, blob_area = stats[idx]
        if not (area_lo <= blob_area <= area_hi):
            continue
        # Must sit in the top-left 40% of the frame.
        if x > cols * 0.4 or y > rows * 0.4:
            continue
        # Reference should be roughly square-ish (aspect ratio <= 3).
        if max(bw, bh) / (min(bw, bh) + 1e-6) > 3.0:
            continue
        cx, cy = centroids[idx]
        # Smaller x+y means closer to the top-left corner.
        found.append((x + y, idx, (x, y, bw, bh), blob_area, (int(cx), int(cy))))

    if not found:
        # No reference detected — fall back to a safe default scale.
        return 4.0, None, None, None

    found.sort(key=lambda item: item[0])
    _, _, (x, y, bw, bh), _, center = found[0]
    side_mm = ref_size_mm if ref_size_mm and ref_size_mm > 0 else 25.0
    px_per_mm = max(bw, bh) / side_mm
    return px_per_mm, center, "square", (x, y, bw, bh)
# -----------------------------
# Segmentation & measurements
# -----------------------------
def build_mask_hsv(
    img_bgr: np.ndarray,
    sample_type: str,
    hsv_low_h: int,
    hsv_high_h: int,
    color_tol: int,
) -> np.ndarray:
    """Build a binary mask from HSV thresholds.

    Leaves are gated on the user-supplied hue window plus minimum
    saturation/value; seeds/grains simply keep every pixel that is neither
    near-white nor near-black. The mask is cleaned with open/close.
    """
    hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
    hue = hsv[:, :, 0]
    sat = hsv[:, :, 1]
    val = hsv[:, :, 2]
    if sample_type == "leaves":
        lo = int(max(0, hsv_low_h))
        hi = int(min(179, hsv_high_h))
        in_hue = cv2.inRange(hue, lo, hi)
        # Drop pixels that are too grey or too dark to be leaf tissue.
        in_sat = cv2.inRange(sat, 30, 255)
        in_val = cv2.inRange(val, 30, 255)
        mask = cv2.bitwise_and(in_hue, cv2.bitwise_and(in_sat, in_val))
    else:
        # Seeds / grains: keep anything that is not near-white background.
        in_sat = cv2.inRange(sat, 20, 255)
        in_val = cv2.inRange(val, 20, 255)
        mask = cv2.bitwise_and(in_sat, in_val)
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, ellipse)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, ellipse)
    return mask
def segment(
    img_bgr: np.ndarray,
    sample_type: str,
    hsv_low_h: int,
    hsv_high_h: int,
    color_tol: int,
    min_area_px: float,
    max_area_px: float,
) -> List[Dict[str, Any]]:
    """Segment objects and compute basic geometric descriptors.

    Uses the simplified segmentation from demo.py. The HSV parameters
    (sample_type, hsv_low_h, hsv_high_h, color_tol) are kept in the
    signature for compatibility but are not used by this implementation.

    Returns a list of per-object dicts carrying the contour, minAreaRect
    data, PCA-derived length/width/angle, and the axis projection bounds
    that render_overlay later uses to draw the oriented bounding box.
    """
    # Simple foreground mask (demo.py approach).
    mask = build_foreground_mask(img_bgr)
    # Connected-component analysis.
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask)
    components: List[Dict[str, Any]] = []
    # Collect candidates sorted by position; the first (usually the
    # reference object) may be skipped below.
    all_objects = []
    for i in range(1, num_labels):
        x, y, ww, hh, area = stats[i]
        # Area filter.
        if area < min_area_px or area > max_area_px:
            continue
        cx, cy = centroids[i]
        # Simple position score: left-to-right (x dominates, y is a tiebreak).
        score = x + y * 0.1
        all_objects.append((score, i, (x, y, ww, hh), area, (int(cx), int(cy))))
    if len(all_objects) == 0:
        return []
    # Sort, then decide whether the first object is the reference.
    all_objects.sort(key=lambda obj: obj[0])
    # Heuristic: skip the first object if it looks like the reference.
    skip_first = False
    if len(all_objects) > 0:
        _, _, (x, y, ww, hh), area, _ = all_objects[0]
        h, w = img_bgr.shape[:2]
        # Skip it when it sits in the top-left corner with a sane shape.
        is_topleft = (x < w * 0.3 and y < h * 0.3)
        aspect_ratio = max(ww, hh) / (min(ww, hh) + 1e-6)
        is_reasonable_shape = aspect_ratio < 3.0
        skip_first = is_topleft and is_reasonable_shape
    # Process the remaining objects.
    start_idx = 1 if skip_first else 0
    for obj_data in all_objects[start_idx:]:
        _, label_idx, bbox, area, center = obj_data
        # Extract the contour of this single component.
        component_mask = (labels == label_idx).astype(np.uint8) * 255
        cnts, _ = cv2.findContours(component_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if len(cnts) == 0:
            continue
        cnt = cnts[0]
        # Basic geometric descriptors.
        rect = cv2.minAreaRect(cnt)
        box = cv2.boxPoints(rect).astype(np.int32)
        peri = cv2.arcLength(cnt, True)
        # Fix OpenCV minAreaRect's long/short-axis ambiguity via PCA.
        # Contour points as an (N, 2) float array.
        contour_points = cnt.reshape(-1, 2).astype(np.float32)
        # Centroid of the contour points.
        cx = np.mean(contour_points[:, 0])
        cy = np.mean(contour_points[:, 1])
        # Covariance matrix of the centred points.
        centered_points = contour_points - np.array([cx, cy])
        cov_matrix = np.cov(centered_points.T)
        # Eigen-decomposition (symmetric 2x2 matrix).
        eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
        # Sort eigenpairs by eigenvalue, descending.
        idx = np.argsort(eigenvalues)[::-1]
        eigenvalues = eigenvalues[idx]
        eigenvectors = eigenvectors[:, idx]
        # Principal direction (largest eigenvalue).
        main_direction = eigenvectors[:, 0]
        # Project points onto the principal and secondary directions.
        proj_main = np.dot(centered_points, main_direction)
        proj_secondary = np.dot(centered_points, eigenvectors[:, 1])
        # Projection extents along each axis.
        min_main = np.min(proj_main)
        max_main = np.max(proj_main)
        min_secondary = np.min(proj_secondary)
        max_secondary = np.max(proj_secondary)
        # True long/short extents along the PCA axes.
        length_main = max_main - min_main
        length_secondary = max_secondary - min_secondary
        # Ensure the long axis maps to the longer extent, and record the
        # matching projection bounds for drawing the oriented box later.
        if length_main >= length_secondary:
            w_obb = length_main
            h_obb = length_secondary
            angle = np.arctan2(main_direction[1], main_direction[0]) * 180.0 / np.pi
            # Long axis is the principal direction.
            long_direction = main_direction
            short_direction = eigenvectors[:, 1]
            min_long_proj = min_main
            max_long_proj = max_main
            min_short_proj = min_secondary
            max_short_proj = max_secondary
        else:
            w_obb = length_secondary
            h_obb = length_main
            secondary_direction = eigenvectors[:, 1]
            angle = np.arctan2(secondary_direction[1], secondary_direction[0]) * 180.0 / np.pi
            # Long axis is the secondary direction.
            long_direction = eigenvectors[:, 1]
            short_direction = main_direction
            min_long_proj = min_secondary
            max_long_proj = max_secondary
            min_short_proj = min_main
            max_short_proj = max_main
        # Normalise the angle into [0, 180).
        angle = ((angle % 180.0) + 180.0) % 180.0
        components.append({
            "contour": cnt,
            "rect": rect,
            "box": box,
            "area_px": float(area),
            "peri_px": float(peri),
            "center": (int(cx), int(cy)),  # PCA-derived centroid (rounded)
            "pca_center": (cx, cy),  # exact PCA centroid (float)
            "angle": float(angle),
            "length_px": float(w_obb),
            "width_px": float(h_obb),
            # Projection bounds kept for correct bounding-box drawing.
            "min_long_proj": float(min_long_proj),
            "max_long_proj": float(max_long_proj),
            "min_short_proj": float(min_short_proj),
            "max_short_proj": float(max_short_proj),
        })
    return components
def compute_color_metrics(img_bgr: np.ndarray, mask: np.ndarray) -> Tuple[float, float, float, int, int, int, float, float]:
    """Compute mean RGB / HSV and simple color indices in a mask region.

    Returns
    -------
    (mean_r, mean_g, mean_b, h, s, v, green_index, brown_index), where the
    HSV triple follows OpenCV's 8-bit convention (H in [0, 179]).
    """
    mean_bgr = cv2.mean(img_bgr, mask=mask)
    mean_b, mean_g, mean_r = mean_bgr[0], mean_bgr[1], mean_bgr[2]
    # Round before casting to uint8: a direct cast truncates (e.g. 127.9
    # -> 127), which can skew the derived HSV values by up to one unit.
    rgb = np.clip(np.rint([[[mean_r, mean_g, mean_b]]]), 0, 255).astype(np.uint8)
    hsv = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV)[0, 0]
    h, s, v = int(hsv[0]), int(hsv[1]), int(hsv[2])
    denom = (mean_r + mean_g + mean_b + 1e-6)
    # Excess-green and red-minus-blue indices, normalised by intensity.
    green_index = (2.0 * mean_g - mean_r - mean_b) / denom
    brown_index = (mean_r - mean_b) / denom
    return mean_r, mean_g, mean_b, h, s, v, green_index, brown_index
def compute_metrics(
    img_bgr: np.ndarray,
    components: List[Dict[str, Any]],
    px_per_mm: float,
) -> pd.DataFrame:
    """Compute all morphological + colour metrics for each component.

    Returns an empty DataFrame when *components* is empty.
    """
    records: List[Dict[str, Any]] = []
    px_per_mm_sq = px_per_mm * px_per_mm
    for idx, comp in enumerate(components, start=1):
        # Convert the PCA-derived pixel measurements to millimetres.
        length_mm = comp["length_px"] / px_per_mm
        width_mm = comp["width_px"] / px_per_mm
        area_mm2 = comp["area_px"] / px_per_mm_sq
        perimeter_mm = comp["peri_px"] / px_per_mm
        aspect_ratio = length_mm / (width_mm + 1e-6)
        # Circularity: 4*pi*area / perimeter^2 (1.0 for a perfect circle).
        circularity = (4.0 * np.pi * area_mm2) / (perimeter_mm * perimeter_mm + 1e-6)
        # Colour metrics restricted to this component's filled contour.
        single_mask = np.zeros(img_bgr.shape[:2], dtype=np.uint8)
        cv2.drawContours(single_mask, [comp["contour"]], -1, 255, thickness=-1)
        mean_r, mean_g, mean_b, h, s, v, gi, bi = compute_color_metrics(img_bgr, single_mask)
        records.append({
            "label": f"s{idx}",
            "centerX_px": int(comp["center"][0]),
            "centerY_px": int(comp["center"][1]),
            "length_mm": round(length_mm, 2),
            "width_mm": round(width_mm, 2),
            "area_mm2": round(area_mm2, 2),
            "perimeter_mm": round(perimeter_mm, 2),
            "aspect_ratio": round(aspect_ratio, 2),
            "circularity": round(circularity, 3),
            "angle_deg": round(float(comp["angle"]), 1),
            "meanR": int(round(mean_r)),
            "meanG": int(round(mean_g)),
            "meanB": int(round(mean_b)),
            "hue": h,
            "saturation": s,
            "value": v,
            "greenIndex": round(float(gi), 3),
            "brownIndex": round(float(bi), 3),
        })
    if not records:
        return pd.DataFrame()
    return pd.DataFrame(records)
def render_overlay(
    img_bgr: np.ndarray,
    px_per_mm: float,
    ref: Tuple[Optional[Tuple[int, int]], Optional[str]],
    components: List[Dict[str, Any]],
    df: pd.DataFrame,
    ref_bbox: Optional[Tuple[int, int, int, int]] = None,
) -> np.ndarray:
    """Draw reference + sample annotations on the image.

    Uses the clear visualisation style from demo.py: the reference gets a
    red rectangle, each sample gets its contour, its PCA-oriented bounding
    box with axis lines, and an "sN" label. Returns an RGB image (Gradio
    expects RGB, OpenCV works in BGR).
    """
    out = img_bgr.copy()
    # Draw the reference object (red rectangle + "REF" label).
    ref_center, ref_type = ref
    if ref_bbox is not None:
        x, y, w, h = ref_bbox
        cv2.rectangle(out, (int(x), int(y)), (int(x + w), int(y + h)), (0, 0, 255), 2)
        cv2.putText(
            out,
            "REF",
            (int(x), int(y) - 10),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            (0, 0, 255),
            2,
            cv2.LINE_AA,
        )
    # Draw each sample object with full annotations.
    for i, comp in enumerate(components, start=1):
        # 1. Full contour (blue, thick).
        cv2.drawContours(out, [comp["contour"]], -1, (255, 0, 0), 3)
        # 2. Corrected oriented bounding box (OBB).
        # Use the precise PCA centroid.
        cx, cy = comp["pca_center"]
        length_px = comp["length_px"]  # long-axis length
        width_px = comp["width_px"]  # short-axis length
        angle_deg = comp["angle"]  # long-axis angle (degrees)
        # Convert to radians.
        angle_rad = np.radians(angle_deg)
        # Build the OBB from the actual projection bounds.
        corners = []
        # Projection bounds saved by segment().
        min_long_proj = comp["min_long_proj"]
        max_long_proj = comp["max_long_proj"]
        min_short_proj = comp["min_short_proj"]
        max_short_proj = comp["max_short_proj"]
        # Unit vectors along the long and short axes.
        long_dir = np.array([np.cos(angle_rad), np.sin(angle_rad)])
        short_dir = np.array([-np.sin(angle_rad), np.cos(angle_rad)])
        # Four corner points from the projection extents.
        for long_proj, short_proj in [(max_long_proj, max_short_proj),  # top-right
                                      (min_long_proj, max_short_proj),  # top-left
                                      (min_long_proj, min_short_proj),  # bottom-left
                                      (max_long_proj, min_short_proj)]:  # bottom-right
            # Walk from the centroid along the axis directions.
            corner_point = np.array([cx, cy]) + long_proj * long_dir + short_proj * short_dir
            corners.append([int(corner_point[0]), int(corner_point[1])])
        # Draw the OBB outline.
        corners = np.array(corners, dtype=np.int32)
        cv2.drawContours(out, [corners], -1, (255, 0, 0), 2)
        # 3. Draw the long/short axis lines (edge midlines of the OBB).
        # Midpoints of the four OBB edges.
        edge_mids = []
        for edge_idx in range(4):
            next_edge_idx = (edge_idx + 1) % 4
            mid_x = (corners[edge_idx][0] + corners[next_edge_idx][0]) / 2
            mid_y = (corners[edge_idx][1] + corners[next_edge_idx][1]) / 2
            edge_mids.append((int(mid_x), int(mid_y)))
        # Edge lengths, used to find the longest edge.
        edge_lengths = []
        for edge_idx in range(4):
            next_edge_idx = (edge_idx + 1) % 4
            length = np.sqrt((corners[next_edge_idx][0] - corners[edge_idx][0])**2 + (corners[next_edge_idx][1] - corners[edge_idx][1])**2)
            edge_lengths.append(length)
        # Longest edge and its opposite.
        max_edge_idx = np.argmax(edge_lengths)
        opposite_edge_idx = (max_edge_idx + 2) % 4
        # "Long axis" line: joins the midpoints of the longest edge and its
        # opposite. NOTE(review): this segment actually spans the box's
        # SHORT dimension (the gap between the two long edges), so the
        # long/short naming here appears swapped — cosmetic only; confirm
        # the intended axis rendering.
        long_mid1 = edge_mids[max_edge_idx]
        long_mid2 = edge_mids[opposite_edge_idx]
        cv2.line(out, long_mid1, long_mid2, (255, 0, 0), 3)
        # "Short axis" line: joins the midpoints of the remaining two edges.
        short_edge1_idx = (max_edge_idx + 1) % 4
        short_edge2_idx = (max_edge_idx + 3) % 4
        short_mid1 = edge_mids[short_edge1_idx]
        short_mid2 = edge_mids[short_edge2_idx]
        cv2.line(out, short_mid1, short_mid2, (255, 0, 0), 2)
        # 4. Centre dot and label.
        # Use the precise PCA centroid.
        label_cx, label_cy = comp["pca_center"]
        cv2.circle(out, (int(label_cx), int(label_cy)), 15, (0, 0, 0), -1)
        cv2.putText(
            out,
            f"s{i}",
            (int(label_cx) - 10, int(label_cy) + 5),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (255, 255, 255),
            2,
            cv2.LINE_AA,
        )
    return cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
def analyze(
    image: Optional[np.ndarray],
    sample_type: str,
    expected_count: int,
    ref_mode: str,
    ref_size_mm: float,
    min_area_px: float,
    max_area_px: float,
    color_tol: int,
    hsv_low_h: int,
    hsv_high_h: int,
) -> Tuple[Optional[np.ndarray], pd.DataFrame, Optional[str], List[Dict[str, Any]], Dict[str, Any]]:
    """Main analysis pipeline: reference detection -> segmentation ->
    metrics -> annotated overlay -> CSV/JSON export.

    Returns (overlay_rgb, metrics_df, csv_path, json_records, state_dict);
    on failure, returns empty outputs with the error text in the JSON slot.
    """
    try:
        if image is None:
            return None, pd.DataFrame(), None, [], {}
        # Gradio delivers RGB; OpenCV routines expect BGR.
        bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        # Cap the resolution to keep CPU/memory usage bounded.
        bgr, _scale = downscale_bgr(bgr)
        # Locate the reference object (first blob near the top-left corner).
        px_per_mm, ref_center, ref_type, ref_bbox = detect_reference(bgr, ref_mode, ref_size_mm)
        # Segment every sample object.
        comps = segment(
            bgr,
            sample_type=sample_type,
            hsv_low_h=hsv_low_h,
            hsv_high_h=hsv_high_h,
            color_tol=color_tol,
            min_area_px=min_area_px,
            max_area_px=max_area_px,
        )
        # Order samples: leaves strictly left-to-right, seeds mostly so.
        if sample_type == "leaves":
            comps.sort(key=lambda c: c["center"][0])
        else:
            comps.sort(key=lambda c: c["center"][1] * 0.3 + c["center"][0] * 0.7)
        # Honour the user's expected object count.
        if expected_count and expected_count > 0:
            comps = comps[:int(expected_count)]
        # Measurements and annotated image.
        df = compute_metrics(bgr, comps, px_per_mm)
        overlay = render_overlay(
            bgr.copy(),
            px_per_mm,
            (ref_center, ref_type),
            comps,
            df,
            ref_bbox,
        )
        # Persist the metrics table as a downloadable CSV.
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
        tmp.write(df.to_csv(index=False).encode("utf-8"))
        tmp.close()
        records = df.to_dict(orient="records")
        # Everything the interactive-correction handler needs later.
        state_dict: Dict[str, Any] = {
            "img_bgr": bgr,
            "sample_type": sample_type,
            "px_per_mm": px_per_mm,
            "ref_center": ref_center,
            "ref_type": ref_type,
            "ref_bbox": ref_bbox,
            "components": comps,
            "expected_count": expected_count,
            "ref_size_mm": ref_size_mm,
            # By default every segmented component is an active sample.
            "active_indices": list(range(len(comps))),
        }
        return overlay, df, tmp.name, records, state_dict
    except Exception as e:
        return None, pd.DataFrame(), None, [{"error": str(e)}], {}
# --- Interactive correction helper ---
def apply_corrections(
    click_event,
    state_dict: Dict[str, Any],
    correction_mode: str,
) -> Tuple[Dict[str, Any], Optional[np.ndarray], pd.DataFrame, Optional[str], List[Dict[str, Any]]]:
    """
    Apply interactive corrections based on a click on the annotated image.
    correction_mode:
        - "none": do nothing
        - "set-ref": treat the clicked object as the new reference
        - "toggle-sample": toggle the clicked object between active/inactive sample

    Returns (state_dict, overlay_rgb, metrics_df, csv_path, json_records);
    overlay_rgb is None when nothing changed, so the caller can keep the
    current UI outputs untouched.
    """
    # If no valid state or no correction requested, do nothing
    if not state_dict or "img_bgr" not in state_dict or correction_mode == "none" or click_event is None:
        return state_dict, None, pd.DataFrame(), None, []
    try:
        # Gradio SelectData usually provides (x, y) in .index
        if hasattr(click_event, "index"):
            x, y = click_event.index
        else:
            # Fallback: assume click_event is a tuple
            x, y = click_event
        img_bgr = state_dict["img_bgr"]
        components: List[Dict[str, Any]] = state_dict.get("components", [])
        if not components:
            return state_dict, None, pd.DataFrame(), None, []
        # Find nearest component center to the click (squared distance).
        min_dist = 1e9
        nearest_idx = -1
        for i, comp in enumerate(components):
            cx, cy = comp["center"]
            d = (cx - x) ** 2 + (cy - y) ** 2
            if d < min_dist:
                min_dist = d
                nearest_idx = i
        if nearest_idx < 0:
            return state_dict, None, pd.DataFrame(), None, []
        # Pull current calibration/state with safe defaults.
        px_per_mm = state_dict.get("px_per_mm", 4.0)
        ref_center = state_dict.get("ref_center")
        ref_type = state_dict.get("ref_type", "square")
        ref_bbox = state_dict.get("ref_bbox")
        ref_size_mm = state_dict.get("ref_size_mm", 20.0)
        sample_type = state_dict.get("sample_type", "leaves")
        active_indices = state_dict.get("active_indices", list(range(len(components))))
        if correction_mode == "set-ref":
            # Use this component as the new reference object
            comp = components[nearest_idx]
            # Axis-aligned bbox derived from the minAreaRect corner points.
            box = comp["box"]
            xs = box[:, 0]
            ys = box[:, 1]
            x0, y0 = int(xs.min()), int(ys.min())
            w0, h0 = int(xs.max() - xs.min()), int(ys.max() - ys.min())
            ref_bbox = (x0, y0, w0, h0)
            ref_center = (int(comp["center"][0]), int(comp["center"][1]))
            # Update px_per_mm using the largest side as diameter/side length
            side_px = float(max(w0, h0))
            px_per_mm = max(side_px / (ref_size_mm if ref_size_mm > 0 else 20.0), 1e-6)
            ref_type = "square"
            # Remove this component from active samples (reference is not a sample)
            new_components = []
            for i, c in enumerate(components):
                if i != nearest_idx:
                    new_components.append(c)
            components = new_components
            # Rebuild active_indices to cover all remaining components
            active_indices = list(range(len(components)))
            state_dict["components"] = components
            state_dict["ref_bbox"] = ref_bbox
            state_dict["ref_center"] = ref_center
            state_dict["px_per_mm"] = px_per_mm
            state_dict["ref_type"] = ref_type
            state_dict["active_indices"] = active_indices
        elif correction_mode == "toggle-sample":
            # Toggle this component in/out of the active sample set
            if nearest_idx in active_indices:
                active_indices = [idx for idx in active_indices if idx != nearest_idx]
            else:
                active_indices.append(nearest_idx)
            active_indices = sorted(set(active_indices))
            state_dict["active_indices"] = active_indices
        # Rebuild the list of active components
        active_components = [components[i] for i in active_indices]
        # Recompute metrics and overlay using the updated state
        df = compute_metrics(img_bgr, active_components, px_per_mm)
        overlay = render_overlay(
            img_bgr.copy(),
            px_per_mm,
            (state_dict.get("ref_center"), state_dict.get("ref_type")),
            active_components,
            df,
            state_dict.get("ref_bbox"),
        )
        # Regenerate the downloadable CSV and JSON preview.
        csv = df.to_csv(index=False)
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".csv")
        tmp.write(csv.encode("utf-8"))
        tmp.close()
        js = df.to_dict(orient="records")
        return state_dict, overlay, df, tmp.name, js
    except Exception:
        # In case of any error, do not break the app; just keep current state
        return state_dict, None, pd.DataFrame(), None, []
# Gradio UI: inputs on the left, annotated image + metrics on the right.
with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown("# Biological Sample Quantifier (Leaves / Seeds)")
    # Per-session analysis state; produced by analyze(), mutated by corrections.
    state = gr.State({})
    with gr.Row():
        with gr.Column(scale=1):
            image = gr.Image(type="numpy", label="Upload image")
            sample_type = gr.Radio(["leaves", "seeds-grains"], value="leaves", label="Sample type")
            expected = gr.Slider(1, 500, value=5, step=1, label="Expected count")
            ref_mode = gr.Radio(["auto", "coin", "square"], value="auto", label="Reference mode")
            ref_size = gr.Slider(1, 100, value=25.0, step=0.1, label="Reference size (mm)")
            min_area = gr.Slider(10, 5000, value=500, step=10, label="Min area (px²)")
            max_area = gr.Slider(1000, 200000, value=50000, step=1000, label="Max area (px²)")
            color_tol = gr.Slider(5, 100, value=40, step=1, label="Color tolerance")
            hsv_low = gr.Slider(0, 179, value=35, step=1, label="HSV H lower (leaves)")
            hsv_high = gr.Slider(0, 179, value=85, step=1, label="HSV H upper (leaves)")
            correction_mode = gr.Radio(
                ["none", "set-ref", "toggle-sample"],
                value="none",
                label="Correction mode (click on image)"
            )
            run = gr.Button("Analyze")
            reset = gr.Button("Reset")
        with gr.Column(scale=2):
            overlay = gr.Image(label="Annotated", interactive=True)
            table = gr.Dataframe(label="Metrics", wrap=True)
            csv_out = gr.File(label="CSV export")
            json_out = gr.JSON(label="JSON preview")

    def _analyze(image, sample_type, expected, ref_mode, ref_size, min_area, max_area, color_tol, hsv_low, hsv_high):
        # Thin wrapper so the callback signature matches the inputs list.
        overlay_img, df, csv_path, js, state_dict = analyze(
            image, sample_type, expected, ref_mode, ref_size, min_area, max_area, color_tol, hsv_low, hsv_high
        )
        return overlay_img, df, csv_path, js, state_dict

    run.click(
        _analyze,
        [image, sample_type, expected, ref_mode, ref_size, min_area, max_area, color_tol, hsv_low, hsv_high],
        [overlay, table, csv_out, json_out, state],
    )

    def _reset():
        # Clear every output component and the session state.
        return None, pd.DataFrame(), None, [], {}

    reset.click(_reset, None, [overlay, table, csv_out, json_out, state])

    # NOTE(review): the gr.SelectData annotation tells Gradio to inject the
    # click event rather than treat `evt` as a regular input — confirm this
    # matches the installed Gradio version's event-data convention.
    def _on_select(evt: gr.SelectData, current_state, correction_mode):
        # Apply corrections based on a click on the annotated image
        new_state, overlay_img, df, csv_path, js = apply_corrections(evt, current_state or {}, correction_mode)
        # If overlay_img is None, keep the existing outputs unchanged by returning gr.update()
        if overlay_img is None:
            return gr.update(), gr.update(), gr.update(), gr.update(), new_state
        return overlay_img, df, csv_path, js, new_state

    overlay.select(
        _on_select,
        [state, correction_mode],
        [overlay, table, csv_out, json_out, state],
    )

if __name__ == "__main__":
    demo.launch()