# FinalVision / app.py
# (Hugging Face Space page header preserved from the scrape:
#  Shengxiao0709's picture — "Update app.py", commit cf8c65d verified,
#  raw / history / blame, 25.6 kB)
# import gradio as gr
# from gradio_bbox_annotator import BBoxAnnotator
# from PIL import Image
# import numpy as np
# import torch
# import os
# import shutil
# import subprocess
# import time, json, uuid
# from pathlib import Path
# import tempfile
# from inference import load_model, run
# from skimage import measure
# # === 图像处理依赖 ===
# from scipy.ndimage import label
# from matplotlib import cm
# # ===== 清理缓存目录 =====
# print("===== Space Usage =====")
# subprocess.run("du -sh *", shell=True)
# print("===== ~/.cache =====")
# subprocess.run("ls -lh ~/.cache", shell=True)
# cache_path = os.path.expanduser("~/.cache")
# if os.path.exists(cache_path):
# shutil.rmtree(cache_path)
# print("✅ Deleted ~/.cache to free space.")
# # ===== 模型初始化 =====
# MODEL = None
# DEVICE = torch.device("cpu")
# CUDA_READY = False
# def load_model_cpu():
# global MODEL, DEVICE
# MODEL, DEVICE = load_model(use_box=False)
# load_model_cpu()
# def prepare_cuda():
# global MODEL, DEVICE, CUDA_READY
# if torch.cuda.is_available() and not CUDA_READY:
# MODEL.to("cuda")
# DEVICE = torch.device("cuda")
# CUDA_READY = True
# _ = torch.zeros(1, device=DEVICE)
# # ===== BBox 解析 =====
# def parse_first_bbox(bboxes):
# if not bboxes:
# return None
# b = bboxes[0]
# if isinstance(b, dict):
# x, y = float(b.get("x", 0)), float(b.get("y", 0))
# w, h = float(b.get("width", 0)), float(b.get("height", 0))
# return x, y, x + w, y + h
# if isinstance(b, (list, tuple)) and len(b) >= 4:
# return float(b[0]), float(b[1]), float(b[2]), float(b[3])
# return None
# # ===== 保存用户反馈 =====
# DATASET_DIR = Path("solver_cache")
# DATASET_DIR.mkdir(parents=True, exist_ok=True)
# def save_feedback(query_id, feedback_type, feedback_text=None, img_path=None, bboxes=None):
# feedback_data = {
# "query_id": query_id,
# "feedback_type": feedback_type,
# "feedback_text": feedback_text,
# "image": img_path,
# "bboxes": bboxes,
# "datetime": time.strftime("%Y%m%d_%H%M%S")
# }
# feedback_file = DATASET_DIR / query_id / "feedback.json"
# feedback_file.parent.mkdir(parents=True, exist_ok=True)
# if feedback_file.exists():
# with feedback_file.open("r") as f:
# existing = json.load(f)
# if not isinstance(existing, list):
# existing = [existing]
# existing.append(feedback_data)
# feedback_data = existing
# else:
# feedback_data = [feedback_data]
# with feedback_file.open("w") as f:
# json.dump(feedback_data, f, indent=4, ensure_ascii=False)
# # ===== 彩色 mask 可视化 =====
# def colorize_mask(mask: np.ndarray, num_colors: int = 512) -> np.ndarray:
# mask = mask.astype(np.int32)
# def hsv_to_rgb(hh, ss, vv):
# i = int(hh * 6.0)
# f = hh * 6.0 - i
# p = vv * (1.0 - ss)
# q = vv * (1.0 - f * ss)
# t = vv * (1.0 - (1.0 - f) * ss)
# i = i % 6
# if i == 0: r, g, b = vv, t, p
# elif i == 1: r, g, b = q, vv, p
# elif i == 2: r, g, b = p, vv, t
# elif i == 3: r, g, b = p, q, vv
# elif i == 4: r, g, b = t, p, vv
# else: r, g, b = vv, p, q
# return int(r*255), int(g*255), int(b*255)
# palette = [(0, 0, 0)]
# for k in range(1, num_colors):
# hue = (k % num_colors) / float(num_colors)
# palette.append(hsv_to_rgb(hue, 1.0, 0.95))
# color_idx = mask % num_colors
# palette_arr = np.array(palette, dtype=np.uint8)
# return palette_arr[color_idx]
# # ===== 推理 + 实例彩色可视化 =====
# def segment_with_choice(use_box_choice, annot_value, mode="Overlay"):
# prepare_cuda()
# if annot_value is None or len(annot_value) < 1:
# print("❌ No annotation input")
# return None
# img_path = annot_value[0]
# bboxes = annot_value[1] if len(annot_value) > 1 else []
# print(f"🖼️ Image path: {img_path}")
# box_array = None
# if use_box_choice == "Yes" and bboxes:
# box = parse_first_bbox(bboxes)
# if box:
# xmin, ymin, xmax, ymax = map(int, box)
# box_array = [[xmin, ymin, xmax, ymax]]
# print(f"📦 Using box: {box_array}")
# try:
# mask = run(MODEL, img_path, box=box_array, device=DEVICE)
# print("📏 Mask shape:", mask.shape, "dtype:", mask.dtype, "unique:", np.unique(mask))
# except Exception as e:
# print(f"❌ Error during inference: {e}")
# return None
# try:
# img = Image.open(img_path)
# print("📷 Image mode:", img.mode, "size:", img.size)
# except Exception as e:
# print(f"❌ Failed to open image: {e}")
# return None
# try:
# img_rgb = img.convert("RGB").resize(mask.shape[::-1], resample=Image.BILINEAR)
# img_np = np.array(img_rgb, dtype=np.float32)
# if img_np.max() > 1.5:
# img_np = img_np / 255.0
# except Exception as e:
# print(f"❌ Error in image conversion/resizing: {e}")
# return None
# mask_np = np.array(mask)
# inst_mask = mask_np.astype(np.int32)
# unique_ids = np.unique(inst_mask)
# num_instances = len(unique_ids[unique_ids != 0])
# print(f"✅ Instance IDs found: {unique_ids}, Total instances: {num_instances}")
# if num_instances == 0:
# print("⚠️ No instance found, returning dummy red image")
# return Image.new("RGB", mask.shape[::-1], (255, 0, 0))
# # ==== Color Overlay (每个实例一个颜色) ====
# overlay = img_np.copy()
# alpha = 0.5
# cmap = cm.get_cmap("nipy_spectral", num_instances + 1)
# for inst_id in np.unique(inst_mask):
# if inst_id == 0:
# continue
# binary_mask = (inst_mask == inst_id).astype(np.uint8)
# color = np.array(cmap(inst_id / (num_instances + 1))[:3]) # RGB only, ignore alpha
# overlay[binary_mask == 1] = (1 - alpha) * overlay[binary_mask == 1] + alpha * color
# # 可选:绘制轮廓
# contours = measure.find_contours(binary_mask, 0.5)
# for contour in contours:
# contour = contour.astype(np.int32)
# overlay[contour[:, 0], contour[:, 1]] = [1.0, 1.0, 0.0] # 黄色轮廓
# overlay = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)
# if mode == "Instance Mask Only":
# return Image.fromarray(colorize_mask(inst_mask, num_colors=512))
# return Image.fromarray(overlay)
# # ===== 示例图像 =====
# example_data = [
# ("003_img.png", [(50, 60, 120, 150, "cell")]),
# ("1977_Well_F-5_Field_1.png", [(30, 40, 100, 130, "cell")]),
# ]
# gallery_images = [p for p, _ in example_data]
# # ===== Gradio UI =====
# with gr.Blocks(title="Microscopy Cell Segmentation") as demo:
# gr.Markdown("## 🧬 Microscopy Image Segmentation — One Cell, One Color")
# with gr.Row():
# with gr.Column(scale=1):
# annotator = BBoxAnnotator(label="🖼️ Upload & Annotate", categories=["cell"])
# example_gallery = gr.Gallery(
# value=gallery_images,
# label="📁 Example Inputs",
# columns=[3], object_fit="cover", height=128
# )
# image_uploader = gr.Image(label="➕ Upload Image", type="filepath")
# run_btn = gr.Button("▶️ Run Segmentation")
# use_box_radio = gr.Radio(choices=["Yes", "No"], label="🔲 Use Bounding Box?", visible=False)
# confirm_btn = gr.Button("✅ Confirm", visible=False)
# mode_radio = gr.Radio(choices=["Overlay", "Instance Mask Only"], value="Overlay",
# label="🎨 Display Mode")
# with gr.Column(scale=2):
# image_output = gr.Image(type="pil", label="📸 Segmentation Result", height=400)
# score = gr.Slider(1, 5, step=1, value=3, label="🌟 Satisfaction (1–5)")
# comment_box = gr.Textbox(placeholder="Type your feedback...", lines=2, label="💬 Feedback")
# submit_score = gr.Button("💾 Submit Rating")
# user_uploaded_images = gr.State([])
# def add_uploaded_image(img_path, current_gallery):
# if not img_path:
# return current_gallery
# try:
# img = Image.open(img_path)
# img.thumbnail((128, 128))
# temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
# img.save(temp_file.name, format="PNG")
# thumb_path = temp_file.name
# if thumb_path not in current_gallery:
# current_gallery.append(thumb_path)
# except Exception as e:
# print(f"❌ Failed image: {e}")
# return current_gallery
# image_uploader.upload(add_uploaded_image, [image_uploader, user_uploaded_images], [example_gallery, user_uploaded_images])
# def on_gallery_select(evt: gr.SelectData, gallery_images):
# index = evt.index
# if index < len(example_data):
# selected_path, selected_boxes = example_data[index]
# return selected_path, selected_boxes
# else:
# selected_path = gallery_images[index]
# return selected_path, []
# example_gallery.select(on_gallery_select, inputs=[user_uploaded_images], outputs=[annotator])
# def show_radio():
# return gr.update(visible=True), gr.update(visible=True)
# run_btn.click(fn=show_radio, outputs=[use_box_radio, confirm_btn])
# confirm_btn.click(fn=segment_with_choice,
# inputs=[use_box_radio, annotator, mode_radio],
# outputs=image_output)
# def handle_comment(comment, annot_value):
# save_feedback(time.strftime("%Y%m%d_%H%M%S") + "_" + str(uuid.uuid4())[:8], "comment", comment, annot_value[0], annot_value[1])
# return ""
# def handle_rating(score, annot_value):
# save_feedback(time.strftime("%Y%m%d_%H%M%S") + "_" + str(uuid.uuid4())[:8], "rating", f"Satisfaction Score: {score}", annot_value[0], annot_value[1])
# return 3
# comment_box.submit(fn=handle_comment, inputs=[comment_box, annotator], outputs=[comment_box])
# submit_score.click(fn=handle_rating, inputs=[score, annotator], outputs=[score])
# if __name__ == "__main__":
# demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True, show_error=True)
import gradio as gr
from gradio_bbox_annotator import BBoxAnnotator
from PIL import Image
import numpy as np
import torch
import os
import shutil
import subprocess
import time
import json
import uuid
from pathlib import Path
import tempfile
from skimage import measure
from matplotlib import cm
# ===== 导入三个推理模块 =====
from inference_seg import load_model as load_seg_model, run as run_seg
from inference_count import load_model as load_count_model, run as run_count
from inference_track import load_model as load_track_model, run as run_track
# --- Startup disk-usage report -------------------------------------------
# Spaces have a small disk quota; log the biggest directories at startup so
# "out of space" failures can be diagnosed from the build log.
import subprocess  # NOTE(review): duplicate of the top-of-file import
print("\n===== 🔍 TOP 20 Disk Usage in your Space =====")
subprocess.run("du -sh /* /home/* /home/user/* | sort -hr | head -n 20", shell=True)
print("\n===== 🔍 Inside .cache =====")
subprocess.run("du -sh ~/.cache/* | sort -hr | head -n 10", shell=True)
print("\n===== 🔍 Inside current working dir =====")
subprocess.run("du -sh ./* | sort -hr | head -n 10", shell=True)
# --- Clear the cache directory to reclaim space --------------------------
print("===== Space Usage =====")
subprocess.run("du -sh *", shell=True)
print("===== ~/.cache =====")
subprocess.run("ls -lh ~/.cache", shell=True)
cache_path = os.path.expanduser("~/.cache")
if os.path.exists(cache_path):
    # NOTE(review): this runs BEFORE load_all_models(); if the models are
    # fetched via the HF hub cache they will be re-downloaded — confirm
    # this trade-off (space vs. startup time) is intended.
    shutil.rmtree(cache_path)
    print("✅ Deleted ~/.cache to free space.")
# ===== Global model state =====
# Each (model, device) pair is populated by load_all_models() at startup;
# until then the models are None and every device defaults to CPU.
SEG_MODEL = None
SEG_DEVICE = torch.device("cpu")
COUNT_MODEL = None
COUNT_DEVICE = torch.device("cpu")
TRACK_MODEL = None
TRACK_DEVICE = torch.device("cpu")
def load_all_models():
    """Load the segmentation, counting and tracking models at startup.

    Populates the module-level (MODEL, DEVICE) pairs via the three
    ``inference_*`` loaders; each loader is called with ``use_box=False``.
    """
    global SEG_MODEL, SEG_DEVICE
    global COUNT_MODEL, COUNT_DEVICE
    global TRACK_MODEL, TRACK_DEVICE

    def _banner(message):
        # Visual separator in the startup log.
        print("\n" + "=" * 60)
        print(message)
        print("=" * 60)

    _banner("📦 Loading Segmentation Model")
    SEG_MODEL, SEG_DEVICE = load_seg_model(use_box=False)

    _banner("📦 Loading Counting Model")
    COUNT_MODEL, COUNT_DEVICE = load_count_model(use_box=False)

    _banner("📦 Loading Tracking Model")
    TRACK_MODEL, TRACK_DEVICE = load_track_model(use_box=False)

    _banner("✅ All Models Loaded Successfully")
# Load all three models once at import time (Space startup).
load_all_models()
# ===== 辅助函数 =====
def parse_first_bbox(bboxes):
    """Extract the first bounding box as an (xmin, ymin, xmax, ymax) tuple.

    Accepts either the dict form ({x, y, width, height}) or a 4+-element
    sequence of corner coordinates. Returns None when no box is usable.
    """
    if not bboxes:
        return None
    first = bboxes[0]
    if isinstance(first, dict):
        # Dict boxes are stored as top-left corner plus size.
        left = float(first.get("x", 0))
        top = float(first.get("y", 0))
        width = float(first.get("width", 0))
        height = float(first.get("height", 0))
        return left, top, left + width, top + height
    if isinstance(first, (list, tuple)) and len(first) >= 4:
        # Sequence boxes are already corner coordinates; extra fields
        # (e.g. a trailing category label) are ignored.
        return tuple(float(coord) for coord in first[:4])
    return None
def colorize_mask(mask: np.ndarray, num_colors: int = 512) -> np.ndarray:
    """Map an integer instance mask to an RGB image (uint8).

    Instance id 0 (background) stays black; every other id is assigned a
    fully saturated hue from a fixed palette of ``num_colors`` entries,
    with ids wrapping around via modulo.
    """
    ids = mask.astype(np.int32)

    def _hsv_to_rgb(hue, sat, val):
        # Classic sector-based HSV -> RGB conversion (6 hue sectors).
        sector = int(hue * 6.0)
        frac = hue * 6.0 - sector
        low = val * (1.0 - sat)
        mid_down = val * (1.0 - frac * sat)
        mid_up = val * (1.0 - (1.0 - frac) * sat)
        rgb_by_sector = (
            (val, mid_up, low),
            (mid_down, val, low),
            (low, val, mid_up),
            (low, mid_down, val),
            (mid_up, low, val),
            (val, low, mid_down),
        )
        r, g, b = rgb_by_sector[sector % 6]
        return int(r * 255), int(g * 255), int(b * 255)

    # Entry 0 is reserved black; the rest sweep the hue circle once.
    palette = [(0, 0, 0)] + [
        _hsv_to_rgb((k % num_colors) / float(num_colors), 1.0, 0.95)
        for k in range(1, num_colors)
    ]
    lut = np.array(palette, dtype=np.uint8)
    return lut[ids % num_colors]
# ===== 分割功能 =====
def segment_with_choice(use_box_choice, annot_value, mode="Overlay"):
    """Run instance segmentation and build a visualization for the UI.

    Args:
        use_box_choice: "Yes"/"No" radio value — whether to forward the
            first annotated bounding box to the model.
        annot_value: BBoxAnnotator value, expected as (image_path, bboxes).
        mode: "Overlay" blends instance colors over the input image;
            "Instance Mask Only" returns just the colorized mask.

    Returns:
        (PIL.Image or None, status string) — image is None on failure.
    """
    if annot_value is None or len(annot_value) < 1:
        return None, "⚠️ 请先上传图像"
    img_path = annot_value[0]
    bboxes = annot_value[1] if len(annot_value) > 1 else []
    print(f"🖼️ Segmentation - Image: {img_path}")
    box_array = None
    if use_box_choice == "Yes" and bboxes:
        box = parse_first_bbox(bboxes)
        if box:
            xmin, ymin, xmax, ymax = map(int, box)
            box_array = (xmin, ymin, xmax, ymax)
            print(f"📦 Using box: {box_array}")
    try:
        # Run segmentation. The mask is used below as a 2D integer
        # instance-id map with 0 = background — TODO confirm against
        # inference_seg.run's actual contract.
        mask = run_seg(SEG_MODEL, img_path, box=box_array, device=SEG_DEVICE)
        if mask is None:
            return None, "❌ 分割失败"
        print(f"✅ Segmentation done - Mask shape: {mask.shape}")
    except Exception as e:
        print(f"❌ Segmentation error: {e}")
        return None, f"分割失败: {str(e)}"
    try:
        # Load the original image and resize it to the mask resolution;
        # mask.shape is (H, W) while PIL.resize expects (W, H).
        img = Image.open(img_path).convert("RGB")
        img_rgb = img.resize(mask.shape[::-1], resample=Image.BILINEAR)
        img_np = np.array(img_rgb, dtype=np.float32)
        # Normalize to [0, 1] so the alpha blend below works in float space.
        if img_np.max() > 1.5:
            img_np = img_np / 255.0
    except Exception as e:
        print(f"❌ Image processing error: {e}")
        return None, f"图像处理失败: {str(e)}"
    # Build the visualization.
    mask_np = np.array(mask)
    inst_mask = mask_np.astype(np.int32)
    unique_ids = np.unique(inst_mask)
    num_instances = len(unique_ids[unique_ids != 0])  # ignore background id 0
    if num_instances == 0:
        result_text = "⚠️ 未检测到细胞"
        # Pale-red placeholder so the UI still shows an image on no-detection.
        return Image.new("RGB", mask.shape[::-1], (255, 200, 200)), result_text
    # Alpha-blend one distinct colormap color per instance over the image.
    overlay = img_np.copy()
    alpha = 0.5
    cmap_vis = cm.get_cmap("nipy_spectral", num_instances + 1)
    for inst_id in unique_ids:
        if inst_id == 0:
            continue
        binary_mask = (inst_mask == inst_id).astype(np.uint8)
        color = np.array(cmap_vis(inst_id / (num_instances + 1))[:3])  # drop alpha
        overlay[binary_mask == 1] = (1 - alpha) * overlay[binary_mask == 1] + alpha * color
        # Draw yellow instance contours on top of the blend.
        contours = measure.find_contours(binary_mask, 0.5)
        for contour in contours:
            contour = contour.astype(np.int32)
            overlay[contour[:, 0], contour[:, 1]] = [1.0, 1.0, 0.0]
    overlay = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)
    result_text = f"✅ 检测到 {num_instances} 个细胞"
    if mode == "Instance Mask Only":
        return Image.fromarray(colorize_mask(inst_mask, num_colors=512)), result_text
    return Image.fromarray(overlay), result_text
# ===== 计数功能 =====
import zipfile
import tempfile
import shutil
def track_video_handler(zip_file_obj):
    """Tracking handler that accepts a ZIP archive of video frames.

    Args:
        zip_file_obj: value from gr.File — either a file-like wrapper with
            a ``.name`` path or a plain filesystem path string; None when
            nothing was uploaded.

    Returns:
        (None, status string) — only the text part is shown in the UI.
    """
    if zip_file_obj is None:
        return None, "⚠️ 请上传包含视频帧的压缩包 (.zip)"
    if TRACK_MODEL is None:
        return None, "❌ 跟踪模型未加载"
    # gr.File may hand over either a tempfile wrapper or a bare path string.
    zip_path = zip_file_obj if isinstance(zip_file_obj, str) else zip_file_obj.name
    temp_dir = tempfile.mkdtemp()
    try:
        print(f"📦 解压到临时目录: {temp_dir}")
        # NOTE(review): extractall trusts archive member names (zip-slip);
        # acceptable for a demo Space but do not reuse on untrusted input.
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(temp_dir)
        print(f"🎬 Tracking - Video frame folder: {temp_dir}")
        result = run_track(
            TRACK_MODEL,
            video_dir=temp_dir,  # directory of extracted frames
            box=None,
            device=TRACK_DEVICE,
            output_dir="tracked_results"
        )
        if 'error' in result:
            return None, f"❌ 跟踪失败: {result['error']}"
        num_tracks = result['num_tracks']
        output_dir = result['output_dir']
        result_text = f"""✅ 跟踪完成!
🎯 跟踪轨迹数量: {num_tracks}
📁 结果保存在: {output_dir}
包含的文件:
- res_track.txt (CTC格式轨迹)
- 其他跟踪数据文件
"""
        print(f"✅ Tracking done - {num_tracks} tracks")
        return None, result_text
    except zipfile.BadZipFile:
        return None, "❌ 上传的文件不是有效的 ZIP 压缩包"
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"❌ 跟踪失败: {str(e)}"
    finally:
        # Fix: the extracted frames were never deleted, leaking disk space
        # on every request — the very problem the startup cleanup fights.
        shutil.rmtree(temp_dir, ignore_errors=True)
# ===== Counting handler =====
def count_cells_handler(img_path):
    """Counting-tab handler: run the density-map counting model.

    Fix: this handler was wired to the "运行计数" button but never defined,
    so the app crashed with NameError at startup.

    Args:
        img_path: filepath from gr.Image(type="filepath"), or None.

    Returns:
        (visualization filepath or None, status string).
    """
    if img_path is None:
        return None, "⚠️ 请先上传图像"
    if COUNT_MODEL is None:
        return None, "❌ 计数模型未加载"
    try:
        # NOTE(review): run_count's exact signature and return value live in
        # inference_count; assumed here to mirror run_seg/run_track and to
        # yield a count plus a visualization path — confirm against module.
        result = run_count(COUNT_MODEL, img_path, box=None, device=COUNT_DEVICE)
        if result is None:
            return None, "❌ 计数失败"
        if isinstance(result, dict):
            if 'error' in result:
                return None, f"❌ 计数失败: {result['error']}"
            vis = result.get('visualization') or result.get('output_path')
            count = result.get('count')
        else:
            # Fallback: treat a 2-tuple as (visualization, count).
            vis, count = result
        return vis, f"✅ 计数完成!检测到约 {count} 个细胞"
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"❌ 计数失败: {str(e)}"


# ===== Gradio UI =====
with gr.Blocks(title="Microscopy Analysis Suite", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# 🔬 显微图像分析工具套件
支持三种分析模式:
- 🎨 **分割 (Segmentation)**: 实例分割,每个细胞不同颜色
- 🔢 **计数 (Counting)**: 基于密度图的细胞计数
- 🎬 **跟踪 (Tracking)**: 视频序列中的细胞运动跟踪
"""
    )
    with gr.Tabs():
        # ===== Tab 1: Segmentation =====
        with gr.Tab("🎨 分割 (Segmentation)"):
            gr.Markdown("## 细胞实例分割 - 每个细胞一个颜色")
            with gr.Row():
                with gr.Column(scale=1):
                    annotator = BBoxAnnotator(
                        label="🖼️ 上传图像 (可选标注边界框)",
                        categories=["cell"]
                    )
                    with gr.Row():
                        use_box_radio = gr.Radio(
                            choices=["Yes", "No"],
                            value="No",
                            label="🔲 使用边界框?"
                        )
                        mode_radio = gr.Radio(
                            choices=["Overlay", "Instance Mask Only"],
                            value="Overlay",
                            label="🎨 显示模式"
                        )
                    run_seg_btn = gr.Button("▶️ 运行分割", variant="primary", size="lg")
                    gr.Markdown(
                        """
**使用说明:**
1. 上传图像
2. (可选) 标注边界框并选择 "Yes"
3. 选择显示模式
4. 点击 "运行分割"
"""
                    )
                with gr.Column(scale=2):
                    seg_output = gr.Image(
                        type="pil",
                        label="📸 分割结果",
                        height=500
                    )
                    seg_status = gr.Textbox(
                        label="📊 状态信息",
                        lines=2
                    )
            # Wire the segmentation button.
            run_seg_btn.click(
                fn=segment_with_choice,
                inputs=[use_box_radio, annotator, mode_radio],
                outputs=[seg_output, seg_status]
            )
        # ===== Tab 2: Counting =====
        with gr.Tab("🔢 计数 (Counting)"):
            gr.Markdown("## 细胞计数分析 - 基于密度图")
            with gr.Row():
                with gr.Column(scale=1):
                    count_input = gr.Image(
                        label="🖼️ 上传图像",
                        type="filepath"
                    )
                    count_btn = gr.Button("▶️ 运行计数", variant="primary", size="lg")
                    gr.Markdown(
                        """
**使用说明:**
1. 上传细胞图像
2. 点击 "运行计数"
3. 查看密度图和计数结果
**特点:**
- 基于 Stable Diffusion 特征
- 自动生成密度图
- 无需手动标注
"""
                    )
                with gr.Column(scale=2):
                    count_output = gr.Image(
                        label="📸 计数结果 (左: 原图 | 右: 密度图)",
                        type="filepath",
                        height=500
                    )
                    count_status = gr.Textbox(
                        label="📊 统计信息",
                        lines=2
                    )
            # Wire the counting button (handler defined above).
            count_btn.click(
                fn=count_cells_handler,
                inputs=count_input,
                outputs=[count_output, count_status]
            )
        # ===== Tab 3: Tracking =====
        with gr.Tab("🎬 跟踪 (Tracking)"):
            gr.Markdown("## 视频细胞跟踪 - 支持 ZIP 压缩包上传")
            with gr.Row():
                with gr.Column(scale=1):
                    track_zip_upload = gr.File(
                        label="📦 上传视频帧 ZIP 文件",
                        file_types=[".zip"]
                    )
                    track_btn = gr.Button("▶️ 运行跟踪", variant="primary", size="lg")
                    gr.Markdown(
                        """
**使用说明:**
1. 上传包含视频帧序列的压缩包 `.zip`
2. 压缩包应直接包含 `.tif` 格式图像,如 t000.tif, t001.tif, ...
3. 点击 "运行跟踪"
4. 结果将保存到 `tracked_results/` 目录
**压缩包示例结构:**
```
frames.zip
├── t000.tif
├── t001.tif
├── t002.tif
└── ...
```
**跟踪模式:** Greedy (快速)
"""
                    )
                with gr.Column(scale=2):
                    track_output = gr.Textbox(
                        label="📊 跟踪信息",
                        lines=12,
                        interactive=False
                    )
            # Fix: the original used `outputs=[None, track_output]`, which is
            # not a valid Gradio output spec. The handler returns
            # (None, text); forward only the text part to the Textbox.
            track_btn.click(
                fn=lambda f: track_video_handler(f)[1],
                inputs=track_zip_upload,
                outputs=track_output
            )
    gr.Markdown(
        """
---
### 💡 技术说明
**分割 (Segmentation)**
- 模型: 基于 Stable Diffusion 特征的实例分割
- 输出: 每个细胞一个唯一颜色的掩码
**计数 (Counting)**
- 模型: 密度图估计
- 输出: 密度热力图 + 总计数
**跟踪 (Tracking)**
- 模型: Trackastra 跟踪算法
- 输出: CTC 格式的轨迹文件
---
📧 问题反馈 | 🌟 GitHub
"""
    )
if __name__ == "__main__":
    # queue() enables request queuing for concurrent users; 0.0.0.0:7860 is
    # the standard Hugging Face Spaces binding. share=True requests a public
    # tunnel — presumably ignored when running inside a Space; confirm.
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True
    )