Shengxiao0709 commited on
Commit
cbc523f
·
verified ·
1 Parent(s): 9bf9fb5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +270 -369
app.py CHANGED
@@ -281,72 +281,73 @@ import torch
281
  import os
282
  import shutil
283
  import subprocess
284
- import time, json, uuid
 
 
285
  from pathlib import Path
286
  import tempfile
287
- from inference import load_model, run
288
  from skimage import measure
289
- # === 图像处理依赖 ===
290
- from scipy.ndimage import label
291
  from matplotlib import cm
292
 
 
 
 
 
 
293
  # ===== 清理缓存目录 =====
294
- print("===== Space Usage =====")
295
- subprocess.run("du -sh *", shell=True)
296
- print("===== ~/.cache =====")
297
- subprocess.run("ls -lh ~/.cache", shell=True)
298
  cache_path = os.path.expanduser("~/.cache")
299
  if os.path.exists(cache_path):
300
- shutil.rmtree(cache_path)
301
- print("✅ Deleted ~/.cache to free space.")
302
-
303
- # ===== 模型初始化 =====
304
- MODEL = None
305
- DEVICE = torch.device("cpu")
306
- CUDA_READY = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
 
308
- # 用于counting和tracking的模型
309
- COUNTING_MODEL = None
310
- TRACKING_MODEL = None
311
 
312
- def load_model_cpu():
313
- global MODEL, DEVICE
314
- MODEL, DEVICE = load_model(use_box=False)
315
-
316
- load_model_cpu()
317
-
318
- def load_counting_model():
319
- """
320
- 加载计数模型
321
- 替换为你的计数模型加载代码
322
- """
323
- global COUNTING_MODEL
324
- # TODO: 替换为实际的计数模型
325
- # 例如: COUNTING_MODEL = torch.load("counting_model.pth")
326
- print("✅ Counting model loaded (placeholder)")
327
- pass
328
-
329
- def load_tracking_model():
330
- """
331
- 加载跟踪模型
332
- 替换为你的跟踪模型加载代码
333
- """
334
- global TRACKING_MODEL
335
- # TODO: 替换为实际的跟踪模型
336
- # 例如: TRACKING_MODEL = torch.load("tracking_model.pth")
337
- print("✅ Tracking model loaded (placeholder)")
338
- pass
339
-
340
- def prepare_cuda():
341
- global MODEL, DEVICE, CUDA_READY
342
- if torch.cuda.is_available() and not CUDA_READY:
343
- MODEL.to("cuda")
344
- DEVICE = torch.device("cuda")
345
- CUDA_READY = True
346
- _ = torch.zeros(1, device=DEVICE)
347
-
348
- # ===== BBox 解析 =====
349
  def parse_first_bbox(bboxes):
 
350
  if not bboxes:
351
  return None
352
  b = bboxes[0]
@@ -358,35 +359,8 @@ def parse_first_bbox(bboxes):
358
  return float(b[0]), float(b[1]), float(b[2]), float(b[3])
359
  return None
360
 
361
- # ===== 保存用户反馈 =====
362
- DATASET_DIR = Path("solver_cache")
363
- DATASET_DIR.mkdir(parents=True, exist_ok=True)
364
-
365
- def save_feedback(query_id, feedback_type, feedback_text=None, img_path=None, bboxes=None):
366
- feedback_data = {
367
- "query_id": query_id,
368
- "feedback_type": feedback_type,
369
- "feedback_text": feedback_text,
370
- "image": img_path,
371
- "bboxes": bboxes,
372
- "datetime": time.strftime("%Y%m%d_%H%M%S")
373
- }
374
- feedback_file = DATASET_DIR / query_id / "feedback.json"
375
- feedback_file.parent.mkdir(parents=True, exist_ok=True)
376
- if feedback_file.exists():
377
- with feedback_file.open("r") as f:
378
- existing = json.load(f)
379
- if not isinstance(existing, list):
380
- existing = [existing]
381
- existing.append(feedback_data)
382
- feedback_data = existing
383
- else:
384
- feedback_data = [feedback_data]
385
- with feedback_file.open("w") as f:
386
- json.dump(feedback_data, f, indent=4, ensure_ascii=False)
387
-
388
- # ===== 彩色 mask 可视化 =====
389
  def colorize_mask(mask: np.ndarray, num_colors: int = 512) -> np.ndarray:
 
390
  mask = mask.astype(np.int32)
391
 
392
  def hsv_to_rgb(hh, ss, vv):
@@ -413,17 +387,17 @@ def colorize_mask(mask: np.ndarray, num_colors: int = 512) -> np.ndarray:
413
  palette_arr = np.array(palette, dtype=np.uint8)
414
  return palette_arr[color_idx]
415
 
416
- # ===== 推理 + 实例彩色可视化 (Segmentation) =====
417
  def segment_with_choice(use_box_choice, annot_value, mode="Overlay"):
418
- prepare_cuda()
419
  if annot_value is None or len(annot_value) < 1:
420
- print("❌ No annotation input")
421
- return None
422
 
423
  img_path = annot_value[0]
424
  bboxes = annot_value[1] if len(annot_value) > 1 else []
425
 
426
- print(f"🖼️ Image path: {img_path}")
 
427
  box_array = None
428
  if use_box_choice == "Yes" and bboxes:
429
  box = parse_first_bbox(bboxes)
@@ -433,284 +407,221 @@ def segment_with_choice(use_box_choice, annot_value, mode="Overlay"):
433
  print(f"📦 Using box: {box_array}")
434
 
435
  try:
436
- mask = run(MODEL, img_path, box=box_array, device=DEVICE)
437
- print("📏 Mask shape:", mask.shape, "dtype:", mask.dtype, "unique:", np.unique(mask))
438
- except Exception as e:
439
- print(f"❌ Error during inference: {e}")
440
- return None
441
-
442
- try:
443
- img = Image.open(img_path)
444
- print("📷 Image mode:", img.mode, "size:", img.size)
445
  except Exception as e:
446
- print(f"❌ Failed to open image: {e}")
447
- return None
448
 
449
  try:
450
- img_rgb = img.convert("RGB").resize(mask.shape[::-1], resample=Image.BILINEAR)
 
 
451
  img_np = np.array(img_rgb, dtype=np.float32)
452
  if img_np.max() > 1.5:
453
  img_np = img_np / 255.0
454
  except Exception as e:
455
- print(f"❌ Error in image conversion/resizing: {e}")
456
- return None
457
 
 
458
  mask_np = np.array(mask)
459
  inst_mask = mask_np.astype(np.int32)
460
  unique_ids = np.unique(inst_mask)
461
  num_instances = len(unique_ids[unique_ids != 0])
462
- print(f"✅ Instance IDs found: {unique_ids}, Total instances: {num_instances}")
463
-
464
  if num_instances == 0:
465
- print("⚠️ No instance found, returning dummy red image")
466
- return Image.new("RGB", mask.shape[::-1], (255, 0, 0))
467
 
468
- # ==== Color Overlay (每个实例一个颜色) ====
469
  overlay = img_np.copy()
470
  alpha = 0.5
471
- cmap = cm.get_cmap("nipy_spectral", num_instances + 1)
472
 
473
- for inst_id in np.unique(inst_mask):
474
  if inst_id == 0:
475
  continue
476
  binary_mask = (inst_mask == inst_id).astype(np.uint8)
477
- color = np.array(cmap(inst_id / (num_instances + 1))[:3]) # RGB only, ignore alpha
478
  overlay[binary_mask == 1] = (1 - alpha) * overlay[binary_mask == 1] + alpha * color
479
 
480
- # 可选:绘制轮廓
481
  contours = measure.find_contours(binary_mask, 0.5)
482
  for contour in contours:
483
  contour = contour.astype(np.int32)
484
- overlay[contour[:, 0], contour[:, 1]] = [1.0, 1.0, 0.0] # 黄色轮廓
485
 
486
  overlay = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)
 
487
 
488
  if mode == "Instance Mask Only":
489
- return Image.fromarray(colorize_mask(inst_mask, num_colors=512))
490
 
491
- return Image.fromarray(overlay)
492
 
493
- # ===== Counting 功能 =====
494
- def count_cells(image_path):
495
- """
496
- 计数功能
497
- TODO: 替换为你的计数模型推理代码
498
- """
499
  if image_path is None:
500
- return None, "请先上传图像"
 
 
 
501
 
502
  try:
503
- img = Image.open(image_path)
504
- img_np = np.array(img)
505
-
506
- # TODO: 替换为实际的计数模型推理
507
- # 示例代码:
508
- # results = COUNTING_MODEL(img_np)
509
- # count = len(results)
510
-
511
- # 临时使用��单的计数方法作为演示
512
- from skimage import filters, morphology
513
- gray = np.array(img.convert('L'))
514
- thresh = filters.threshold_otsu(gray)
515
- binary = gray > thresh
516
- labeled = morphology.label(binary)
517
- count = labeled.max()
518
 
519
- # 可视化
520
- import matplotlib.pyplot as plt
521
- from matplotlib import cm
 
 
 
 
522
 
523
- fig, ax = plt.subplots(1, 1, figsize=(10, 10))
524
- ax.imshow(img)
525
 
526
- # 标注每个对象
527
- for region_id in range(1, count + 1):
528
- region_mask = labeled == region_id
529
- coords = np.argwhere(region_mask)
530
- if len(coords) > 0:
531
- y, x = coords.mean(axis=0)
532
- ax.text(x, y, str(region_id), color='red',
533
- fontsize=12, fontweight='bold',
534
- bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.7))
535
 
536
- ax.axis('off')
537
 
538
- # 保存到临时文件
539
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
540
- plt.savefig(temp_file.name, bbox_inches='tight', dpi=150)
541
- plt.close()
542
-
543
- result_text = f"🔢 检测到 {count} 个细胞"
544
- print(f"✅ Counting result: {count} cells")
545
-
546
- return temp_file.name, result_text
547
 
548
  except Exception as e:
549
  print(f"❌ Counting error: {e}")
550
- return None, f"计数失败: {str(e)}"
551
-
552
- # ===== Tracking 功能 =====
553
- def track_video(video_path, progress=gr.Progress()):
554
- """
555
- 视频跟踪功能
556
- TODO: 替换为你的跟踪模型推理代码
557
- """
558
- if video_path is None:
559
- return None, "请先上传视频"
 
 
 
 
 
 
560
 
561
  try:
562
- import cv2
563
 
564
- # 读取视频
565
- cap = cv2.VideoCapture(video_path)
566
- fps = int(cap.get(cv2.CAP_PROP_FPS))
567
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
568
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
569
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
570
 
571
- # 创建输出视频
572
- output_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
573
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
574
- out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
575
 
576
- print(f"📹 Processing video: {total_frames} frames, {fps} fps")
 
577
 
578
- # TODO: 初始化跟踪器
579
- # tracker = initialize_your_tracker()
580
-
581
- frame_count = 0
582
- while cap.isOpened():
583
- ret, frame = cap.read()
584
- if not ret:
585
- break
586
-
587
- # TODO: 替换为实际的跟踪模型推理
588
- # tracked_frame, tracks = TRACKING_MODEL.update(frame)
589
-
590
- # 临时演示: 在帧上添加文字
591
- tracked_frame = frame.copy()
592
- cv2.putText(tracked_frame, f"Frame {frame_count}/{total_frames}",
593
- (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
594
-
595
- out.write(tracked_frame)
596
- frame_count += 1
597
-
598
- # 更新进度条
599
- if frame_count % 10 == 0:
600
- progress((frame_count / total_frames, f"处理中: {frame_count}/{total_frames}"))
601
-
602
- cap.release()
603
- out.release()
604
 
605
- result_text = f"✅ 跟踪完成! 处理了 {frame_count} "
606
- print(result_text)
607
 
608
- return output_path, result_text
609
 
610
  except Exception as e:
611
  print(f"❌ Tracking error: {e}")
612
- return None, f"跟踪失败: {str(e)}"
613
-
614
- # ===== 示例图像 =====
615
- example_data = [
616
- ("003_img.png", [(50, 60, 120, 150, "cell")]),
617
- ("1977_Well_F-5_Field_1.png", [(30, 40, 100, 130, "cell")]),
618
- ]
619
- gallery_images = [p for p, _ in example_data]
620
 
621
  # ===== Gradio UI =====
622
  with gr.Blocks(title="Microscopy Analysis Suite", theme=gr.themes.Soft()) as demo:
623
  gr.Markdown(
624
  """
625
  # 🔬 显微图像分析工具套件
626
- 支持三种分析模式: 分割 (Segmentation) | 计数 (Counting) | 跟踪 (Tracking)
 
 
 
 
627
  """
628
  )
629
 
630
  with gr.Tabs():
631
  # ===== Tab 1: Segmentation =====
632
  with gr.Tab("🎨 分割 (Segmentation)"):
633
- gr.Markdown("## 🧬 细胞分割 每个细胞一个颜色")
634
 
635
  with gr.Row():
636
  with gr.Column(scale=1):
637
- annotator = BBoxAnnotator(label="🖼️ 上传 & 标注", categories=["cell"])
638
-
639
- example_gallery = gr.Gallery(
640
- value=gallery_images,
641
- label="📁 示例图像",
642
- columns=[3], object_fit="cover", height=128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
643
  )
644
-
645
- image_uploader = gr.Image(label="➕ 上传图像", type="filepath")
646
-
647
- run_btn = gr.Button("▶️ 运行分割", variant="primary")
648
- use_box_radio = gr.Radio(choices=["Yes", "No"], label="🔲 使用边界框?", visible=False)
649
- confirm_btn = gr.Button("✅ 确认", visible=False)
650
- mode_radio = gr.Radio(choices=["Overlay", "Instance Mask Only"], value="Overlay",
651
- label="🎨 显示模式")
652
 
653
  with gr.Column(scale=2):
654
- image_output = gr.Image(type="pil", label="📸 分割结果", height=400)
655
- score = gr.Slider(1, 5, step=1, value=3, label="🌟 满意度 (1–5)")
656
- comment_box = gr.Textbox(placeholder="输入您的反馈...", lines=2, label="💬 反馈")
657
- submit_score = gr.Button("💾 提交评分")
658
-
659
- user_uploaded_images = gr.State([])
660
-
661
- def add_uploaded_image(img_path, current_gallery):
662
- if not img_path:
663
- return current_gallery
664
- try:
665
- img = Image.open(img_path)
666
- img.thumbnail((128, 128))
667
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
668
- img.save(temp_file.name, format="PNG")
669
- thumb_path = temp_file.name
670
- if thumb_path not in current_gallery:
671
- current_gallery.append(thumb_path)
672
- except Exception as e:
673
- print(f"❌ Failed image: {e}")
674
- return current_gallery
675
-
676
- image_uploader.upload(add_uploaded_image, [image_uploader, user_uploaded_images],
677
- [example_gallery, user_uploaded_images])
678
-
679
- def on_gallery_select(evt: gr.SelectData, gallery_images):
680
- index = evt.index
681
- if index < len(example_data):
682
- selected_path, selected_boxes = example_data[index]
683
- return selected_path, selected_boxes
684
- else:
685
- selected_path = gallery_images[index]
686
- return selected_path, []
687
-
688
- example_gallery.select(on_gallery_select, inputs=[user_uploaded_images], outputs=[annotator])
689
-
690
- def show_radio():
691
- return gr.update(visible=True), gr.update(visible=True)
692
-
693
- run_btn.click(fn=show_radio, outputs=[use_box_radio, confirm_btn])
694
- confirm_btn.click(fn=segment_with_choice,
695
- inputs=[use_box_radio, annotator, mode_radio],
696
- outputs=image_output)
697
-
698
- def handle_comment(comment, annot_value):
699
- save_feedback(time.strftime("%Y%m%d_%H%M%S") + "_" + str(uuid.uuid4())[:8],
700
- "comment", comment, annot_value[0], annot_value[1])
701
- return ""
702
-
703
- def handle_rating(score, annot_value):
704
- save_feedback(time.strftime("%Y%m%d_%H%M%S") + "_" + str(uuid.uuid4())[:8],
705
- "rating", f"Satisfaction Score: {score}", annot_value[0], annot_value[1])
706
- return 3
707
-
708
- comment_box.submit(fn=handle_comment, inputs=[comment_box, annotator], outputs=[comment_box])
709
- submit_score.click(fn=handle_rating, inputs=[score, annotator], outputs=[score])
710
 
711
  # ===== Tab 2: Counting =====
712
  with gr.Tab("🔢 计数 (Counting)"):
713
- gr.Markdown("## 细胞计数分析")
714
 
715
  with gr.Row():
716
  with gr.Column(scale=1):
@@ -718,127 +629,117 @@ with gr.Blocks(title="Microscopy Analysis Suite", theme=gr.themes.Soft()) as dem
718
  label="🖼️ 上传图像",
719
  type="filepath"
720
  )
721
- count_btn = gr.Button("▶️ 运行计数", variant="primary")
722
 
723
  gr.Markdown(
724
  """
725
- **说明:**
726
- - 自动检测并计数图像中的细胞
727
- - 结果会在图像上标注编号
 
 
 
 
 
 
728
  """
729
  )
730
 
731
  with gr.Column(scale=2):
732
- count_output_img = gr.Image(
733
- label="📸 计数结果",
734
- type="filepath"
 
735
  )
736
- count_output_text = gr.Textbox(
737
- label="🔢 统计信息",
738
  lines=2
739
  )
740
-
741
- count_score = gr.Slider(1, 5, step=1, value=3, label="🌟 满意度 (1–5)")
742
- count_comment = gr.Textbox(placeholder="输入反馈...", lines=2, label="💬 反馈")
743
- count_submit = gr.Button("💾 提交评分")
744
 
745
  # 绑定事件
746
  count_btn.click(
747
- fn=count_cells,
748
  inputs=count_input,
749
- outputs=[count_output_img, count_output_text]
750
- )
751
-
752
- def handle_count_feedback(score, comment, img_path):
753
- if img_path:
754
- save_feedback(
755
- time.strftime("%Y%m%d_%H%M%S") + "_count_" + str(uuid.uuid4())[:8],
756
- "counting",
757
- f"Score: {score}, Comment: {comment}",
758
- img_path,
759
- None
760
- )
761
- return 3, ""
762
-
763
- count_submit.click(
764
- fn=handle_count_feedback,
765
- inputs=[count_score, count_comment, count_input],
766
- outputs=[count_score, count_comment]
767
  )
768
 
769
  # ===== Tab 3: Tracking =====
770
  with gr.Tab("🎬 跟踪 (Tracking)"):
771
- gr.Markdown("## 视频细胞跟踪")
772
 
773
  with gr.Row():
774
  with gr.Column(scale=1):
775
- track_input = gr.Video(
776
- label="📹 上传视频"
 
 
777
  )
778
- track_btn = gr.Button("▶️ 运行跟踪", variant="primary")
779
 
780
  gr.Markdown(
781
  """
782
- **说明:**
783
- - 支持格式: MP4, AVI, MOV
784
- - 自动跟踪视频中的细胞运动
785
- - 处理时间取决于视频长度
 
 
 
 
 
 
 
 
 
 
 
 
786
  """
787
  )
788
 
789
  with gr.Column(scale=2):
790
- track_output_video = gr.Video(
791
- label="📸 跟踪结果"
 
792
  )
793
- track_output_text = gr.Textbox(
794
- label="📊 处理状态",
795
- lines=2
796
  )
797
-
798
- track_score = gr.Slider(1, 5, step=1, value=3, label="🌟 满意度 (1–5)")
799
- track_comment = gr.Textbox(placeholder="输入反馈...", lines=2, label="💬 反馈")
800
- track_submit = gr.Button("💾 提交评分")
801
 
802
  # 绑定事件
803
  track_btn.click(
804
- fn=track_video,
805
  inputs=track_input,
806
- outputs=[track_output_video, track_output_text]
807
- )
808
-
809
- def handle_track_feedback(score, comment, video_path):
810
- if video_path:
811
- save_feedback(
812
- time.strftime("%Y%m%d_%H%M%S") + "_track_" + str(uuid.uuid4())[:8],
813
- "tracking",
814
- f"Score: {score}, Comment: {comment}",
815
- video_path,
816
- None
817
- )
818
- return 3, ""
819
-
820
- track_submit.click(
821
- fn=handle_track_feedback,
822
- inputs=[track_score, track_comment, track_input],
823
- outputs=[track_score, track_comment]
824
  )
825
 
826
- # ===== 页脚 =====
827
  gr.Markdown(
828
  """
829
  ---
830
- ### 💡 功能说明
831
- - **Segmentation**: 分割并可视化图像中的每个细胞
832
- - **Counting**: 自动计数图像中的细胞数量
833
- - **Tracking**: 跟踪视频中细胞的运动轨迹
 
 
 
 
 
 
 
 
 
 
 
 
834
  """
835
  )
836
 
837
  if __name__ == "__main__":
838
  demo.queue().launch(
839
- server_name="0.0.0.0",
840
- server_port=7860,
841
- share=True,
842
  show_error=True
843
  )
844
-
 
281
  import os
282
  import shutil
283
  import subprocess
284
+ import time
285
+ import json
286
+ import uuid
287
  from pathlib import Path
288
  import tempfile
 
289
  from skimage import measure
 
 
290
  from matplotlib import cm
291
 
292
+ # ===== 导入三个推理模块 =====
293
+ from inference_seg import load_model as load_seg_model, run as run_seg
294
+ from inference_count import load_model as load_count_model, run as run_count
295
+ from inference_track import load_model as load_track_model, run as run_track
296
+
297
  # ===== 清理缓存目录 =====
298
+ print("===== Cleaning Cache =====")
 
 
 
299
  cache_path = os.path.expanduser("~/.cache")
300
  if os.path.exists(cache_path):
301
+ try:
302
+ shutil.rmtree(cache_path)
303
+ print("✅ Deleted ~/.cache to free space.")
304
+ except:
305
+ print("⚠️ Could not delete cache")
306
+
307
+ # ===== 全局模型变量 =====
308
+ SEG_MODEL = None
309
+ SEG_DEVICE = torch.device("cpu")
310
+
311
+ COUNT_MODEL = None
312
+ COUNT_DEVICE = torch.device("cpu")
313
+
314
+ TRACK_MODEL = None
315
+ TRACK_DEVICE = torch.device("cpu")
316
+
317
+ def load_all_models():
318
+ """启动时加载所有模型"""
319
+ global SEG_MODEL, SEG_DEVICE
320
+ global COUNT_MODEL, COUNT_DEVICE
321
+ global TRACK_MODEL, TRACK_DEVICE
322
+
323
+ # 加载分割模型
324
+ print("\n" + "="*60)
325
+ print("📦 Loading Segmentation Model")
326
+ print("="*60)
327
+ SEG_MODEL, SEG_DEVICE = load_seg_model(use_box=False)
328
+
329
+ # 加载计数模型
330
+ print("\n" + "="*60)
331
+ print("📦 Loading Counting Model")
332
+ print("="*60)
333
+ COUNT_MODEL, COUNT_DEVICE = load_count_model(use_box=False)
334
+
335
+ # 加载跟踪模型
336
+ print("\n" + "="*60)
337
+ print("📦 Loading Tracking Model")
338
+ print("="*60)
339
+ TRACK_MODEL, TRACK_DEVICE = load_track_model(use_box=False)
340
+
341
+ print("\n" + "="*60)
342
+ print("✅ All Models Loaded Successfully")
343
+ print("="*60)
344
 
345
+ # 启动时加载所有模型
346
+ load_all_models()
 
347
 
348
+ # ===== 辅助函数 =====
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
349
  def parse_first_bbox(bboxes):
350
+ """解析第一个边界框"""
351
  if not bboxes:
352
  return None
353
  b = bboxes[0]
 
359
  return float(b[0]), float(b[1]), float(b[2]), float(b[3])
360
  return None
361
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
  def colorize_mask(mask: np.ndarray, num_colors: int = 512) -> np.ndarray:
363
+ """将实例掩码转换为彩色图像"""
364
  mask = mask.astype(np.int32)
365
 
366
  def hsv_to_rgb(hh, ss, vv):
 
387
  palette_arr = np.array(palette, dtype=np.uint8)
388
  return palette_arr[color_idx]
389
 
390
+ # ===== 分割功能 =====
391
  def segment_with_choice(use_box_choice, annot_value, mode="Overlay"):
392
+ """分割处理函数"""
393
  if annot_value is None or len(annot_value) < 1:
394
+ return None, "⚠️ 请先上传图像"
 
395
 
396
  img_path = annot_value[0]
397
  bboxes = annot_value[1] if len(annot_value) > 1 else []
398
 
399
+ print(f"🖼️ Segmentation - Image: {img_path}")
400
+
401
  box_array = None
402
  if use_box_choice == "Yes" and bboxes:
403
  box = parse_first_bbox(bboxes)
 
407
  print(f"📦 Using box: {box_array}")
408
 
409
  try:
410
+ # 运行分割
411
+ mask = run_seg(SEG_MODEL, img_path, box=box_array, device=SEG_DEVICE)
412
+
413
+ if mask is None:
414
+ return None, "❌ 分割失败"
415
+
416
+ print(f"✅ Segmentation done - Mask shape: {mask.shape}")
 
 
417
  except Exception as e:
418
+ print(f"❌ Segmentation error: {e}")
419
+ return None, f"分割失败: {str(e)}"
420
 
421
  try:
422
+ # 读取原图
423
+ img = Image.open(img_path).convert("RGB")
424
+ img_rgb = img.resize(mask.shape[::-1], resample=Image.BILINEAR)
425
  img_np = np.array(img_rgb, dtype=np.float32)
426
  if img_np.max() > 1.5:
427
  img_np = img_np / 255.0
428
  except Exception as e:
429
+ print(f"❌ Image processing error: {e}")
430
+ return None, f"图像处理失败: {str(e)}"
431
 
432
+ # 生成可视化
433
  mask_np = np.array(mask)
434
  inst_mask = mask_np.astype(np.int32)
435
  unique_ids = np.unique(inst_mask)
436
  num_instances = len(unique_ids[unique_ids != 0])
437
+
 
438
  if num_instances == 0:
439
+ result_text = "⚠️ 未检测到细胞"
440
+ return Image.new("RGB", mask.shape[::-1], (255, 200, 200)), result_text
441
 
442
+ # 创建叠加图
443
  overlay = img_np.copy()
444
  alpha = 0.5
445
+ cmap_vis = cm.get_cmap("nipy_spectral", num_instances + 1)
446
 
447
+ for inst_id in unique_ids:
448
  if inst_id == 0:
449
  continue
450
  binary_mask = (inst_mask == inst_id).astype(np.uint8)
451
+ color = np.array(cmap_vis(inst_id / (num_instances + 1))[:3])
452
  overlay[binary_mask == 1] = (1 - alpha) * overlay[binary_mask == 1] + alpha * color
453
 
454
+ # 绘制轮廓
455
  contours = measure.find_contours(binary_mask, 0.5)
456
  for contour in contours:
457
  contour = contour.astype(np.int32)
458
+ overlay[contour[:, 0], contour[:, 1]] = [1.0, 1.0, 0.0]
459
 
460
  overlay = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)
461
+ result_text = f"✅ 检测到 {num_instances} 个细胞"
462
 
463
  if mode == "Instance Mask Only":
464
+ return Image.fromarray(colorize_mask(inst_mask, num_colors=512)), result_text
465
 
466
+ return Image.fromarray(overlay), result_text
467
 
468
+ # ===== 计数功能 =====
469
+ def count_cells_handler(image_path):
470
+ """计数处理函数"""
 
 
 
471
  if image_path is None:
472
+ return None, "⚠️ 请先上传图像"
473
+
474
+ if COUNT_MODEL is None:
475
+ return None, "❌ 计数模型未加载"
476
 
477
  try:
478
+ print(f"🔢 Counting - Image: {image_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
479
 
480
+ result = run_count(
481
+ COUNT_MODEL,
482
+ image_path,
483
+ box=None,
484
+ device=COUNT_DEVICE,
485
+ visualize=True
486
+ )
487
 
488
+ if 'error' in result:
489
+ return None, f"❌ 计数失败: {result['error']}"
490
 
491
+ count = result['count']
492
+ viz_path = result['visualized_path']
493
+ result_text = f"✅ 检测到 {count:.1f} 个细胞"
 
 
 
 
 
 
494
 
495
+ print(f"✅ Counting done - Count: {count:.1f}")
496
 
497
+ return viz_path, result_text
 
 
 
 
 
 
 
 
498
 
499
  except Exception as e:
500
  print(f"❌ Counting error: {e}")
501
+ import traceback
502
+ traceback.print_exc()
503
+ return None, f"❌ 计数失败: {str(e)}"
504
+
505
+ # ===== 跟踪功能 =====
506
+ def track_video_handler(video_dir_input):
507
+ """跟踪处理函数"""
508
+ if video_dir_input is None or video_dir_input.strip() == "":
509
+ return None, "⚠️ 请输入视频帧目录路径"
510
+
511
+ if TRACK_MODEL is None:
512
+ return None, "❌ 跟踪模型未加载"
513
+
514
+ # 检查目录是否存在
515
+ if not os.path.exists(video_dir_input):
516
+ return None, f"❌ 目录不存在: {video_dir_input}"
517
 
518
  try:
519
+ print(f"🎬 Tracking - Video dir: {video_dir_input}")
520
 
521
+ result = run_track(
522
+ TRACK_MODEL,
523
+ video_dir=video_dir_input,
524
+ box=None,
525
+ device=TRACK_DEVICE,
526
+ output_dir="tracked_results"
527
+ )
528
 
529
+ if 'error' in result:
530
+ return None, f"❌ 跟踪失败: {result['error']}"
 
 
531
 
532
+ num_tracks = result['num_tracks']
533
+ output_dir = result['output_dir']
534
 
535
+ result_text = f"""✅ 跟踪完成!
536
+
537
+ 🎯 跟踪轨迹数量: {num_tracks}
538
+ 📁 结果保存在: {output_dir}
539
+
540
+ 包含的文件:
541
+ - res_track.txt (CTC格式轨迹)
542
+ - 其他跟踪数据文件
543
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
544
 
545
+ print(f"✅ Tracking done - {num_tracks} tracks")
 
546
 
547
+ return None, result_text
548
 
549
  except Exception as e:
550
  print(f"❌ Tracking error: {e}")
551
+ import traceback
552
+ traceback.print_exc()
553
+ return None, f"❌ 跟踪失败: {str(e)}"
 
 
 
 
 
554
 
555
  # ===== Gradio UI =====
556
  with gr.Blocks(title="Microscopy Analysis Suite", theme=gr.themes.Soft()) as demo:
557
  gr.Markdown(
558
  """
559
  # 🔬 显微图像分析工具套件
560
+
561
+ 支持三种分析模式:
562
+ - 🎨 **分割 (Segmentation)**: 实例分割,每个细胞不同颜色
563
+ - 🔢 **计数 (Counting)**: 基于密度图的细胞计数
564
+ - 🎬 **跟踪 (Tracking)**: 视频序列中的细胞运动跟踪
565
  """
566
  )
567
 
568
  with gr.Tabs():
569
  # ===== Tab 1: Segmentation =====
570
  with gr.Tab("🎨 分割 (Segmentation)"):
571
+ gr.Markdown("## 细胞实例分割 - 每个细胞一个颜色")
572
 
573
  with gr.Row():
574
  with gr.Column(scale=1):
575
+ annotator = BBoxAnnotator(
576
+ label="🖼️ 上传图像 (可选标注边界框)",
577
+ categories=["cell"]
578
+ )
579
+
580
+ with gr.Row():
581
+ use_box_radio = gr.Radio(
582
+ choices=["Yes", "No"],
583
+ value="No",
584
+ label="🔲 使用边界框?"
585
+ )
586
+ mode_radio = gr.Radio(
587
+ choices=["Overlay", "Instance Mask Only"],
588
+ value="Overlay",
589
+ label="🎨 显示模式"
590
+ )
591
+
592
+ run_seg_btn = gr.Button("▶️ 运行分割", variant="primary", size="lg")
593
+
594
+ gr.Markdown(
595
+ """
596
+ **使用说明:**
597
+ 1. 上传图像
598
+ 2. (可选) 标注边界框并选择 "Yes"
599
+ 3. 选择显示模式
600
+ 4. 点击 "运行分割"
601
+ """
602
  )
 
 
 
 
 
 
 
 
603
 
604
  with gr.Column(scale=2):
605
+ seg_output = gr.Image(
606
+ type="pil",
607
+ label="📸 分割结果",
608
+ height=500
609
+ )
610
+ seg_status = gr.Textbox(
611
+ label="📊 状态信息",
612
+ lines=2
613
+ )
614
+
615
+ # 绑定事件
616
+ run_seg_btn.click(
617
+ fn=segment_with_choice,
618
+ inputs=[use_box_radio, annotator, mode_radio],
619
+ outputs=[seg_output, seg_status]
620
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
621
 
622
  # ===== Tab 2: Counting =====
623
  with gr.Tab("🔢 计数 (Counting)"):
624
+ gr.Markdown("## 细胞计数分析 - 基于密度图")
625
 
626
  with gr.Row():
627
  with gr.Column(scale=1):
 
629
  label="🖼️ 上传图像",
630
  type="filepath"
631
  )
632
+ count_btn = gr.Button("▶️ 运行计数", variant="primary", size="lg")
633
 
634
  gr.Markdown(
635
  """
636
+ **使用说明:**
637
+ 1. 上传细胞图像
638
+ 2. 点击 "运行计数"
639
+ 3. 查看密度图和计数结果
640
+
641
+ **特点:**
642
+ - 基于 Stable Diffusion 特征
643
+ - 自动生成密度图
644
+ - 无需手动标注
645
  """
646
  )
647
 
648
  with gr.Column(scale=2):
649
+ count_output = gr.Image(
650
+ label="📸 计数结果 (左: 原图 | 右: 密度图)",
651
+ type="filepath",
652
+ height=500
653
  )
654
+ count_status = gr.Textbox(
655
+ label="📊 统计信息",
656
  lines=2
657
  )
 
 
 
 
658
 
659
  # 绑定事件
660
  count_btn.click(
661
+ fn=count_cells_handler,
662
  inputs=count_input,
663
+ outputs=[count_output, count_status]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
664
  )
665
 
666
  # ===== Tab 3: Tracking =====
667
  with gr.Tab("🎬 跟踪 (Tracking)"):
668
+ gr.Markdown("## 视频细胞跟踪 - 时间序列分析")
669
 
670
  with gr.Row():
671
  with gr.Column(scale=1):
672
+ track_input = gr.Textbox(
673
+ label="📁 视频帧目录路径",
674
+ placeholder="例如: example_imgs/2D+Time/Fluo-N2DL-HeLa/train/02/",
675
+ lines=2
676
  )
677
+ track_btn = gr.Button("▶️ 运行跟踪", variant="primary", size="lg")
678
 
679
  gr.Markdown(
680
  """
681
+ **使用说明:**
682
+ 1. 输入包含视频帧序列的目录路径
683
+ 2. 目录应包含: t000.tif, t001.tif, ...
684
+ 3. 点击 "运行跟踪"
685
+ 4. 结果将保存到 `tracked_results/` 目录
686
+
687
+ **输入格式:**
688
+ ```
689
+ video_dir/
690
+ ├── t000.tif
691
+ ├── t001.tif
692
+ ├── t002.tif
693
+ └── ...
694
+ ```
695
+
696
+ **跟踪模式:** Greedy (快速)
697
  """
698
  )
699
 
700
  with gr.Column(scale=2):
701
+ track_output = gr.Video(
702
+ label="📹 跟踪结果视频 (暂不支持)",
703
+ visible=False
704
  )
705
+ track_status = gr.Textbox(
706
+ label="📊 跟踪信息",
707
+ lines=10
708
  )
 
 
 
 
709
 
710
  # 绑定事件
711
  track_btn.click(
712
+ fn=track_video_handler,
713
  inputs=track_input,
714
+ outputs=[track_output, track_status]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
715
  )
716
 
 
717
  gr.Markdown(
718
  """
719
  ---
720
+ ### 💡 技术说明
721
+
722
+ **分割 (Segmentation)**
723
+ - 模型: 基于 Stable Diffusion 特征的实例分割
724
+ - 输出: 每个细胞一个唯一颜色的掩码
725
+
726
+ **计数 (Counting)**
727
+ - 模型: 密度图估计
728
+ - 输出: 密度热力图 + 总计数
729
+
730
+ **跟踪 (Tracking)**
731
+ - 模型: Trackastra 跟踪算法
732
+ - 输出: CTC 格式的轨迹文件
733
+
734
+ ---
735
+ 📧 问题反馈 | 🌟 GitHub
736
  """
737
  )
738
 
739
  if __name__ == "__main__":
740
  demo.queue().launch(
741
+ server_name="0.0.0.0",
742
+ server_port=7860,
743
+ share=True,
744
  show_error=True
745
  )