Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from gradio_bbox_annotator import BBoxAnnotator | |
| from PIL import Image | |
| import numpy as np | |
| import torch | |
| import os | |
| import shutil | |
| import time | |
| import json | |
| import uuid | |
| from pathlib import Path | |
| import tempfile | |
| import zipfile | |
| from skimage import measure | |
| from matplotlib import cm | |
| # ===== 导入三个推理模块 ===== | |
| from inference_seg import load_model as load_seg_model, run as run_seg | |
| from inference_count import load_model as load_count_model, run as run_count | |
| from inference_track import load_model as load_track_model, run as run_track | |
# ===== Clear the cache directory =====
# Removes the whole ~/.cache tree at startup. Best-effort: a failure is reported
# but does not stop the app from starting.
# NOTE(review): if the model loaders download weights into ~/.cache, those weights
# will be re-fetched on every start — confirm this is intended.
print("===== 清理缓存 =====")
cache_path = os.path.expanduser("~/.cache")
if os.path.exists(cache_path):
    try:
        shutil.rmtree(cache_path)
        print("✅ Deleted ~/.cache")
    except OSError as exc:
        # Only filesystem errors are tolerated; KeyboardInterrupt/SystemExit must
        # still propagate, and the reason for the failure is worth logging.
        print(f"⚠️ Failed to delete ~/.cache: {exc}")
# ===== Global model handles =====
# Filled in by load_all_models(); every model starts unloaded on the CPU device.
SEG_MODEL = None
COUNT_MODEL = None
TRACK_MODEL = None

SEG_DEVICE = torch.device("cpu")
COUNT_DEVICE = torch.device("cpu")
TRACK_DEVICE = torch.device("cpu")
def _print_banner(title):
    """Print ``title`` between a leading blank line + 60-char '=' rule and a closing rule."""
    rule = "=" * 60
    print("\n" + rule)
    print(title)
    print(rule)


def load_all_models():
    """Load the segmentation, counting and tracking models into the module globals.

    Each loader is called with ``use_box=False`` and returns a ``(model, device)``
    pair, which is stored in the matching ``*_MODEL`` / ``*_DEVICE`` global.
    """
    global SEG_MODEL, SEG_DEVICE
    global COUNT_MODEL, COUNT_DEVICE
    global TRACK_MODEL, TRACK_DEVICE

    _print_banner("📦 Loading Segmentation Model")
    SEG_MODEL, SEG_DEVICE = load_seg_model(use_box=False)

    _print_banner("📦 Loading Counting Model")
    COUNT_MODEL, COUNT_DEVICE = load_count_model(use_box=False)

    _print_banner("📦 Loading Tracking Model")
    TRACK_MODEL, TRACK_DEVICE = load_track_model(use_box=False)

    _print_banner("✅ All Models Loaded Successfully")


# Models are loaded eagerly, once, when this module is imported.
load_all_models()
# ===== Persist user feedback =====
# Feedback lives in solver_cache/<query_id>/feedback.json, relative to the working directory.
DATASET_DIR = Path("solver_cache")
DATASET_DIR.mkdir(parents=True, exist_ok=True)


def save_feedback(query_id, feedback_type, feedback_text=None, img_path=None, bboxes=None):
    """Append one feedback record to ``DATASET_DIR/<query_id>/feedback.json``.

    The file holds a JSON list of records. A legacy file holding a single JSON
    object is upgraded to a one-element list before the new record is appended.

    Args:
        query_id: Identifier used as the per-query sub-directory name.
        feedback_type: Short label for the feedback (e.g. ``"score_5"``).
        feedback_text: Optional free-text comment.
        img_path: Optional path of the image the feedback refers to.
        bboxes: Optional bounding boxes drawn by the user.

    Raises:
        TypeError: If the record contains values that are not JSON-serializable;
            the existing file is left untouched in that case.
        json.JSONDecodeError: If an existing feedback file is not valid JSON.
        OSError: If the directory or file cannot be created, read or written.
    """
    record = {
        "query_id": query_id,
        "feedback_type": feedback_type,
        "feedback_text": feedback_text,
        "image": img_path,
        "bboxes": bboxes,
        "datetime": time.strftime("%Y%m%d_%H%M%S"),
    }
    feedback_file = DATASET_DIR / str(query_id) / "feedback.json"
    feedback_file.parent.mkdir(parents=True, exist_ok=True)

    records = []
    if feedback_file.exists():
        # Explicit UTF-8: the file is written with ensure_ascii=False and may hold
        # non-ASCII text, which a locale-dependent default encoding can reject.
        with feedback_file.open("r", encoding="utf-8") as f:
            existing = json.load(f)
        records = existing if isinstance(existing, list) else [existing]
    records.append(record)

    # Serialize BEFORE opening for write: opening with "w" truncates immediately,
    # so a serialization error mid-dump would otherwise wipe all earlier feedback.
    payload = json.dumps(records, indent=4, ensure_ascii=False)
    feedback_file.write_text(payload, encoding="utf-8")
| # ===== 辅助函数 ===== | |
def parse_first_bbox(bboxes):
    """Return the first bounding box as ``(xmin, ymin, xmax, ymax)`` floats.

    Two entry formats are accepted:

    * a dict with ``x``, ``y``, ``width`` and ``height`` keys (missing keys count as 0);
    * a list/tuple whose first four items are two corner points ``x1, y1, x2, y2``;
      any further items are ignored.

    Corners are reordered so that ``xmin <= xmax`` and ``ymin <= ymax``, which
    also covers boxes given with swapped corners or negative width/height.

    Returns:
        The normalized tuple, or ``None`` when ``bboxes`` is empty, the first entry
        has an unsupported shape, or its coordinates are not numeric.
    """
    if not bboxes:
        return None
    first = bboxes[0]
    try:
        if isinstance(first, dict):
            x, y = float(first.get("x", 0)), float(first.get("y", 0))
            w, h = float(first.get("width", 0)), float(first.get("height", 0))
            x1, y1, x2, y2 = x, y, x + w, y + h
        elif isinstance(first, (list, tuple)) and len(first) >= 4:
            x1, y1, x2, y2 = (float(v) for v in first[:4])
        else:
            return None
    except (TypeError, ValueError):
        # Non-numeric / None coordinates: report "no usable box" instead of letting
        # the error escape into callers that parse the box outside any try block.
        return None
    return min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)
def colorize_mask(mask: np.ndarray, num_colors: int = 512) -> np.ndarray:
    """Map an instance-label mask to an RGB image of dtype ``uint8`` (shape ``mask.shape + (3,)``).

    Label 0 is background and is rendered black. Every non-zero label gets one of
    ``num_colors - 1`` saturated hues. Labels beyond that range wrap around the hue
    palette but never land on the black background entry, so no instance can be
    mistaken for background.

    Args:
        mask: Array of instance labels; non-integer dtypes are truncated to int.
        num_colors: Palette size, including the black background entry.
            With ``num_colors < 2`` there are no hues and every pixel is black.
    """
    def hsv_to_rgb(h, s, v):
        i = int(h * 6.0)
        f = h * 6.0 - i
        i = i % 6
        p = v * (1 - s)
        q = v * (1 - f * s)
        t = v * (1 - (1 - f) * s)
        if i == 0:
            r, g, b = v, t, p
        elif i == 1:
            r, g, b = q, v, p
        elif i == 2:
            r, g, b = p, v, t
        elif i == 3:
            r, g, b = p, q, v
        elif i == 4:
            r, g, b = t, p, v
        else:
            r, g, b = v, p, q
        return int(r * 255), int(g * 255), int(b * 255)

    palette = [(0, 0, 0)]
    for i in range(1, num_colors):
        palette.append(hsv_to_rgb(i / float(num_colors), 1.0, 0.95))
    palette_arr = np.array(palette, dtype=np.uint8)

    labels = np.asarray(mask).astype(np.int64)
    if num_colors < 2:
        color_idx = np.zeros_like(labels)
    else:
        # Wrap non-zero labels into 1..num_colors-1. A plain `labels % num_colors`
        # would send labels that are multiples of num_colors to entry 0 (black).
        # Labels below num_colors keep exactly the same colour as before.
        n_hues = num_colors - 1
        color_idx = np.where(labels != 0, (labels - 1) % n_hues + 1, 0)
    return palette_arr[color_idx]
| # ===== 分割功能 ===== | |
def _instance_cmap(name, n):
    """Return the Matplotlib colormap ``name`` resampled to ``n`` discrete colours."""
    try:
        from matplotlib import colormaps  # registry available since Matplotlib 3.5
        return colormaps[name].resampled(n)  # Colormap.resampled: Matplotlib >= 3.6
    except (ImportError, AttributeError):
        # Older Matplotlib still provides cm.get_cmap (it was removed in 3.9).
        return cm.get_cmap(name, n)


def segment_with_choice(use_box_choice, annot_value):
    """Run instance segmentation and render each instance in its own colour with a yellow outline.

    Args:
        use_box_choice: ``"Yes"`` to pass the first annotated bounding box to the
            model as a prompt; any other value runs without a box.
        annot_value: Annotator value ``(image_path, bboxes)``; ``bboxes`` may be absent.

    Returns:
        ``(overlay, mask_path)``: ``overlay`` is a PIL RGB image and ``mask_path`` is
        a temporary ``.tif`` holding the raw instance mask as ``uint16``.
        ``(None, None)`` is returned on missing input or any failure. A solid red
        placeholder with ``None`` is returned when the mask contains no instance.
    """
    if annot_value is None or len(annot_value) < 1:
        print("❌ No annotation input")
        return None, None

    img_path = annot_value[0]
    bboxes = annot_value[1] if len(annot_value) > 1 else []
    print(f"🖼️ 图像路径: {img_path}")

    box_array = None
    if use_box_choice == "Yes" and bboxes:
        box = parse_first_bbox(bboxes)
        if box is not None:
            xmin, ymin, xmax, ymax = map(int, box)
            box_array = [[xmin, ymin, xmax, ymax]]
            print(f"📦 使用边界框: {box_array}")

    # Run the segmentation model.
    try:
        mask = run_seg(SEG_MODEL, img_path, box=box_array, device=SEG_DEVICE)
        print("📏 mask shape:", mask.shape, "dtype:", mask.dtype, "unique:", np.unique(mask))
    except Exception as e:
        print(f"❌ 推理失败: {str(e)}")
        return None, None

    # Save the raw mask as a TIF. The temp-file handle is closed before PIL writes
    # to the path, so no descriptor is leaked (and the path is writable on Windows).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".tif") as tmp:
        mask_path = tmp.name
    Image.fromarray(mask.astype(np.uint16)).save(mask_path)
    print(f"💾 原始mask保存到: {mask_path}")

    # Read the original image.
    try:
        img = Image.open(img_path)
        print("📷 Image mode:", img.mode, "size:", img.size)
    except Exception as e:
        print(f"❌ Failed to open image: {e}")
        return None, None

    with img:  # release the file handle once the pixels are converted
        try:
            img_rgb = img.convert("RGB").resize(mask.shape[::-1], resample=Image.BILINEAR)
        except Exception as e:
            print(f"❌ Error in image conversion/resizing: {e}")
            return None, None
    # After convert("RGB") pixels are always uint8 in [0, 255], so always rescale.
    # A `max() > 1.5` heuristic would leave near-black images (max 1) unscaled,
    # and they would then render as white.
    img_np = np.asarray(img_rgb, dtype=np.float32) / 255.0

    inst_mask = np.asarray(mask).astype(np.int32)
    unique_ids = np.unique(inst_mask)
    instance_ids = unique_ids[unique_ids != 0]
    num_instances = len(instance_ids)
    print(f"✅ Instance IDs found: {unique_ids}, Total instances: {num_instances}")
    if num_instances == 0:
        print("⚠️ No instance found, returning dummy red image")
        return Image.new("RGB", mask.shape[::-1], (255, 0, 0)), None

    # ==== Colour overlay: one colour per instance ====
    overlay = img_np.copy()
    alpha = 0.5
    cmap = _instance_cmap("nipy_spectral", num_instances + 1)
    for rank, inst_id in enumerate(instance_ids, start=1):
        region = inst_mask == inst_id
        # Index the colormap by rank (1..N), not by raw id. Ids need not be
        # contiguous, and raw ids above N would all clamp to the same colour.
        color = np.asarray(cmap(rank)[:3])  # RGB only, drop alpha
        overlay[region] = (1 - alpha) * overlay[region] + alpha * color

        # Draw the instance outline, clipping contour points to the image bounds.
        for contour in measure.find_contours(region.astype(np.uint8), 0.5):
            rows = np.clip(contour[:, 0].astype(np.int32), 0, overlay.shape[0] - 1)
            cols = np.clip(contour[:, 1].astype(np.int32), 0, overlay.shape[1] - 1)
            overlay[rows, cols] = [1.0, 1.0, 0.0]  # yellow outline

    overlay = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)
    return Image.fromarray(overlay), mask_path
| # ===== 计数功能 ===== | |
| def count_cells_handler(use_box_choice, annot_value): | |
| """计数处理函数 - 支持边界框,只返回密度图""" | |
| if annot_value is None or len(annot_value) < 1: | |
| return None, "⚠️ 请先上传图像" | |
| image_path = annot_value[0] | |
| bboxes = annot_value[1] if len(annot_value) > 1 else [] | |
| print(f"🖼️ 图像路径: {image_path}") | |
| box_array = None | |
| if use_box_choice == "Yes" and bboxes: | |
| box = parse_first_bbox(bboxes) | |
| if box: | |
| xmin, ymin, xmax, ymax = map(int, box) | |
| box_array = [[xmin, ymin, xmax, ymax]] | |
| print(f"📦 使用边界框: {box_array}") | |
| try: | |
| print(f"🔢 Counting - Image: {image_path}") | |
| result = run_count( | |
| COUNT_MODEL, | |
| image_path, | |
| box=box_array, | |
| device=COUNT_DEVICE, | |
| visualize=True | |
| ) | |
| if 'error' in result: | |
| return None, f"❌ 计数失败: {result['error']}" | |
| count = result['count'] | |
| # 只提取密度图部分(假设visualized_path是拼接图,我们只要右半部分) | |
| viz_path = result.get('visualized_path') | |
| # 如果有density_map_path,直接使用 | |
| if 'density_map_path' in result: | |
| density_path = result['density_map_path'] | |
| elif viz_path and os.path.exists(viz_path): | |
| # 如果是拼接图,提取右半部分(密度图) | |
| try: | |
| viz_img = Image.open(viz_path) | |
| w, h = viz_img.size | |
| # 取右半部分 | |
| density_img = viz_img.crop((w//2, 0, w, h)) | |
| # 保存为新文件 | |
| temp_density = tempfile.NamedTemporaryFile(delete=False, suffix=".png") | |
| density_img.save(temp_density.name) | |
| density_path = temp_density.name | |
| except: | |
| density_path = viz_path | |
| else: | |
| density_path = viz_path | |
| result_text = f"✅ 检测到 {count:.1f} 个细胞" | |
| print(f"✅ Counting done - Count: {count:.1f}") | |
| return density_path, result_text | |
| except Exception as e: | |
| print(f"❌ Counting error: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return None, f"❌ 计数失败: {str(e)}" | |
| # ===== 跟踪功能 ===== | |
def find_tif_dir(root_dir):
    """Return the first directory under ``root_dir`` (inclusive) that contains a ``.tif`` file.

    The walk is top-down with sub-directories visited in sorted order, so the
    result is deterministic. Two kinds of entry are skipped: ``__MACOSX``
    folders, and ``._*`` AppleDouble files, which macOS archivers commonly add
    to ZIPs. Their metadata stubs (e.g. ``._t000.tif``) would otherwise be
    mistaken for frames.

    Returns:
        The matching directory path, or ``None`` if no ``.tif`` file is found.
    """
    for dirpath, dirnames, filenames in os.walk(root_dir):
        # Prune in place so os.walk neither descends into __MACOSX nor depends on
        # filesystem ordering when choosing between sibling directories.
        dirnames[:] = sorted(d for d in dirnames if d != "__MACOSX")
        if any(f.lower().endswith('.tif') and not f.startswith("._") for f in filenames):
            return dirpath
    return None
def track_video_handler(zip_file_obj):
    """Extract an uploaded ZIP of frames and run the tracking model on the ``.tif`` frames.

    Args:
        zip_file_obj: The uploaded ZIP. Either a filesystem path (``str`` /
            ``os.PathLike``, which is what ``gr.File`` delivers in Gradio 4+) or a
            file-wrapper object exposing the path as ``.name`` (older Gradio).

    Returns:
        ``(None, message)``. The first slot feeds a hidden placeholder component;
        ``message`` is the user-facing status or summary text.
    """
    if zip_file_obj is None:
        return None, "⚠️ 请上传包含视频帧的压缩包 (.zip)"

    # Accept plain paths as well as legacy wrappers. Path objects also have a
    # `.name` (the basename only), so check for path types first.
    if isinstance(zip_file_obj, (str, os.PathLike)):
        zip_path = zip_file_obj
    else:
        zip_path = zip_file_obj.name

    temp_dir = None
    try:
        temp_dir = tempfile.mkdtemp()
        print(f"📦 解压到临时目录: {temp_dir}")
        # NOTE(review): the archive is user-supplied. ZipFile.extractall strips
        # absolute paths and ".." components, but the total uncompressed size is
        # not bounded here.
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(temp_dir)

        tif_dir = find_tif_dir(temp_dir)
        if tif_dir is None:
            return None, "❌ 解压后未找到任何 .tif 图像"
        print(f"🎬 Tracking - Found .tif in: {tif_dir}")

        result = run_track(
            TRACK_MODEL,
            video_dir=tif_dir,
            box=None,
            device=TRACK_DEVICE,
            output_dir="tracked_results",
        )
        if 'error' in result:
            return None, f"❌ 跟踪失败: {result['error']}"

        num_tracks = result['num_tracks']
        output_dir = result['output_dir']
        result_text = (
            "✅ 跟踪完成!\n"
            f"🎯 跟踪轨迹数量: {num_tracks}\n"
            f"📁 结果保存在: {output_dir}\n"
            "包含的文件:\n"
            "- res_track.txt (CTC格式轨迹)\n"
            "- 其他跟踪数据文件\n"
        )
        print(f"✅ Tracking done - {num_tracks} tracks")
        return None, result_text
    except zipfile.BadZipFile:
        return None, "❌ 上传的文件不是有效的 ZIP 压缩包"
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"❌ 跟踪失败: {str(e)}"
    finally:
        # The extracted frames are a per-request copy; results are written to
        # "tracked_results", so the copy is removed to avoid filling the disk.
        # NOTE(review): assumes run_track keeps no references into tif_dir after
        # returning — confirm against inference_track.
        if temp_dir is not None:
            shutil.rmtree(temp_dir, ignore_errors=True)
# ===== Example images =====
# Bare relative file names (presumably shipped next to this script) that seed
# both example galleries.
example_images = [
    "003_img.png",
    "1977_Well_F-5_Field_1.png",
]
# ===== Gradio UI =====


def add_to_gallery(img_path, current_imgs):
    """Return the example list with ``img_path`` appended.

    Empty paths and paths that are already present are ignored. A new list is
    returned rather than mutating the session-state object in place.
    """
    updated = list(current_imgs or [])
    if img_path and img_path not in updated:
        updated.append(img_path)
    return updated


def load_from_gallery(evt: gr.SelectData, all_imgs):
    """Return the example the user clicked in a gallery, or ``None`` if out of range."""
    if evt.index is not None and evt.index < len(all_imgs):
        return all_imgs[evt.index]
    return None


# The counting tab uses exactly the same gallery behaviour; these aliases keep
# its former module-level callback names available.
add_to_count_gallery = add_to_gallery
load_from_count_gallery = load_from_gallery


def submit_user_feedback(query_id, score, comment, annot_val):
    """Save one feedback entry and return a visible status update with the outcome."""
    try:
        img_path = annot_val[0] if annot_val and len(annot_val) > 0 else None
        bboxes = annot_val[1] if annot_val and len(annot_val) > 1 else []
        save_feedback(
            # Fall back to a fresh id if the per-session id has not been assigned yet.
            query_id=query_id or str(uuid.uuid4()),
            feedback_type=f"score_{int(score)}",
            feedback_text=comment,
            img_path=img_path,
            bboxes=bboxes,
        )
        message = "✅ 反馈已提交,感谢您的评价!"
    except Exception as e:
        message = f"❌ 提交失败: {str(e)}"
    # Text and visibility are sent in a single update for a single output. Listing
    # the same component twice in `outputs` relies on how duplicates get merged.
    return gr.update(value=message, visible=True)


def _wire_example_gallery(blocks, gallery, uploader, examples_state, target):
    """Wire an example gallery: fill it on load, add uploads to it, load clicked items into ``target``.

    Must be called while ``blocks`` is the active ``gr.Blocks`` context.
    """
    blocks.load(fn=lambda: example_images.copy(), outputs=gallery)
    uploader.change(
        fn=add_to_gallery,
        inputs=[uploader, examples_state],
        outputs=examples_state,
    ).then(
        fn=lambda imgs: imgs,
        inputs=examples_state,
        outputs=gallery,
    )
    gallery.select(
        fn=load_from_gallery,
        inputs=examples_state,
        outputs=target,
    )


with gr.Blocks(title="Microscopy Analysis Suite", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🔬 显微图像分析工具套件
        支持三种分析模式:
        - 🎨 **分割 (Segmentation)**: 实例分割,每个细胞不同颜色
        - 🔢 **计数 (Counting)**: 基于密度图的细胞计数
        - 🎬 **跟踪 (Tracking)**: 视频序列中的细胞运动跟踪
        """
    )

    # ---- Session state ----
    # The feedback id is generated per page load. Writing gr.State(str(uuid.uuid4()))
    # would evaluate uuid4() once, when the UI is built, so every session would
    # share one id and one feedback folder.
    current_query_id = gr.State()
    demo.load(fn=lambda: str(uuid.uuid4()), outputs=current_query_id)
    # Per-session segmentation gallery contents, seeded with the built-in examples.
    user_uploaded_examples = gr.State(example_images.copy())

    with gr.Tabs():
        # ===== Tab 1: Segmentation =====
        with gr.Tab("🎨 分割 (Segmentation)"):
            gr.Markdown("## 细胞实例分割 - 每个细胞一个颜色")
            with gr.Row():
                with gr.Column(scale=1):
                    annotator = BBoxAnnotator(
                        label="🖼️ 上传图像 (可选标注边界框)",
                        categories=["cell"],
                    )
                    # Example gallery
                    example_gallery = gr.Gallery(
                        label="📁 示例图片",
                        columns=3,
                        object_fit="cover",
                        height=150,
                    )
                    # Upload a new example into the gallery
                    image_uploader = gr.Image(
                        label="➕ 上传新示例到Gallery",
                        type="filepath",
                    )
                    with gr.Row():
                        use_box_radio = gr.Radio(
                            choices=["Yes", "No"],
                            value="No",
                            label="🔲 使用边界框?",
                        )
                    run_seg_btn = gr.Button("▶️ 运行分割", variant="primary", size="lg")
                    gr.Markdown(
                        """
                        **使用说明:**
                        1. 上传图像或从Gallery选择示例
                        2. (可选) 标注边界框并选择 "Yes"
                        3. 点击 "运行分割"
                        """
                    )
                with gr.Column(scale=2):
                    seg_output = gr.Image(
                        type="pil",
                        label="📸 分割结果",
                        height=400,
                    )
                    # Download the raw prediction
                    download_mask_btn = gr.File(
                        label="📥 下载原始预测 (.tif 格式)",
                        visible=True,
                    )
                    # Satisfaction score
                    score_slider = gr.Slider(
                        minimum=1,
                        maximum=5,
                        step=1,
                        value=5,
                        label="🌟 满意度评分 (1-5)",
                    )
                    # Free-text feedback
                    feedback_box = gr.Textbox(
                        placeholder="请输入您的反馈意见...",
                        lines=2,
                        label="💬 反馈意见",
                    )
                    submit_feedback_btn = gr.Button("💾 提交反馈", variant="secondary")
                    feedback_status = gr.Textbox(
                        label="✅ 提交状态",
                        lines=1,
                        visible=False,
                    )

            run_seg_btn.click(
                fn=segment_with_choice,
                inputs=[use_box_radio, annotator],
                outputs=[seg_output, download_mask_btn],
            )
            _wire_example_gallery(
                demo, example_gallery, image_uploader, user_uploaded_examples, annotator
            )
            submit_feedback_btn.click(
                fn=submit_user_feedback,
                inputs=[current_query_id, score_slider, feedback_box, annotator],
                outputs=feedback_status,
            )

        # ===== Tab 2: Counting =====
        with gr.Tab("🔢 计数 (Counting)"):
            gr.Markdown("## 细胞计数分析 - 基于密度图")
            with gr.Row():
                with gr.Column(scale=1):
                    count_annotator = BBoxAnnotator(
                        label="🖼️ 上传图像 (可选标注边界框)",
                        categories=["cell"],
                    )
                    # Example gallery (same examples as the segmentation tab)
                    count_example_gallery = gr.Gallery(
                        label="📁 示例图片",
                        columns=3,
                        object_fit="cover",
                        height=150,
                    )
                    # Upload a new example into the gallery
                    count_image_uploader = gr.Image(
                        label="➕ 上传新示例到Gallery",
                        type="filepath",
                    )
                    with gr.Row():
                        count_use_box_radio = gr.Radio(
                            choices=["Yes", "No"],
                            value="No",
                            label="🔲 使用边界框?",
                        )
                    count_btn = gr.Button("▶️ 运行计数", variant="primary", size="lg")
                    gr.Markdown(
                        """
                        **使用说明:**
                        1. 上传图像或从Gallery选择示例
                        2. (可选) 标注边界框并选择 "Yes"
                        3. 点击 "运行计数"
                        """
                    )
                with gr.Column(scale=2):
                    count_output = gr.Image(
                        label="📸 密度图",
                        type="filepath",
                        height=400,
                    )
                    count_status = gr.Textbox(
                        label="📊 统计信息",
                        lines=2,
                    )

            count_btn.click(
                fn=count_cells_handler,
                inputs=[count_use_box_radio, count_annotator],
                outputs=[count_output, count_status],
            )
            # Per-session counting gallery contents, independent of the segmentation tab.
            count_user_examples = gr.State(example_images.copy())
            _wire_example_gallery(
                demo, count_example_gallery, count_image_uploader, count_user_examples, count_annotator
            )

        # ===== Tab 3: Tracking =====
        with gr.Tab("🎬 跟踪 (Tracking)"):
            gr.Markdown("## 视频细胞跟踪 - 支持 ZIP 压缩包上传")
            with gr.Row():
                with gr.Column(scale=1):
                    track_zip_upload = gr.File(
                        label="📦 上传视频帧 ZIP 文件",
                        file_types=[".zip"],
                    )
                    track_btn = gr.Button("▶️ 运行跟踪", variant="primary", size="lg")
                    gr.Markdown(
                        """
                        **使用说明:**
                        1. 上传包含 `.tif` 图像的 ZIP 压缩包
                        2. 点击 "运行跟踪"
                        3. 结果保存到 `tracked_results/` 目录
                        """
                    )
                with gr.Column(scale=2):
                    track_output = gr.Textbox(
                        label="📊 跟踪信息",
                        lines=12,
                        interactive=False,
                    )
                    # Hidden sink for the handler's unused first return value.
                    dummy = gr.Textbox(visible=False)

            track_btn.click(
                fn=track_video_handler,
                inputs=track_zip_upload,
                outputs=[dummy, track_output],
            )

    gr.Markdown(
        """
        ---
        ### 💡 技术说明
        **分割 (Segmentation)** - 基于 Stable Diffusion 特征的实例分割
        **计数 (Counting)** - 密度图估计
        **跟踪 (Tracking)** - Trackastra 跟踪算法
        """
    )
# Entry point: serve the app with the request queue enabled.
if __name__ == "__main__":
    launch_kwargs = {
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        "share": True,
        "show_error": True,
    }
    queued_app = demo.queue()
    queued_app.launch(**launch_kwargs)