VisionLanguageGroup committed on
Commit
047ae7d
·
1 Parent(s): 471a816

update counting and segmentation, tracking-basic

Browse files
Files changed (7) hide show
  1. app.py +89 -31
  2. counting.py +0 -2
  3. inference_count.py +27 -15
  4. inference_seg.py +18 -3
  5. inference_track.py +4 -3
  6. segmentation.py +0 -1
  7. tracking_one.py +0 -1
app.py CHANGED
@@ -21,7 +21,7 @@ from inference_track import load_model as load_track_model, run as run_track
21
 
22
  # ===== 清理缓存目录 =====
23
  print("===== 清理缓存 =====")
24
- cache_path = os.path.expanduser("~/.cache")
25
  if os.path.exists(cache_path):
26
  try:
27
  shutil.rmtree(cache_path)
@@ -138,8 +138,8 @@ def colorize_mask(mask: np.ndarray, num_colors: int = 512) -> np.ndarray:
138
 
139
  # ===== 分割功能 =====
140
  def segment_with_choice(use_box_choice, annot_value):
141
- print(use_box_choice)
142
- print(annot_value)
143
  """分割主函数 - 每个实例不同颜色+轮廓"""
144
  if annot_value is None or len(annot_value) < 1:
145
  print("❌ No annotation input")
@@ -256,34 +256,86 @@ def count_cells_handler(use_box_choice, annot_value):
256
  return None, f"❌ 计数失败: {result['error']}"
257
 
258
  count = result['count']
 
 
 
 
 
259
 
260
  # 只提取密度图部分(假设visualized_path是拼接图,我们只要右半部分)
261
- viz_path = result.get('visualized_path')
262
 
263
  # 如果有density_map_path,直接使用
264
- if 'density_map_path' in result:
265
- density_path = result['density_map_path']
266
- elif viz_path and os.path.exists(viz_path):
267
- # 如果是拼接图,提取右半部分(密度图)
268
- try:
269
- viz_img = Image.open(viz_path)
270
- w, h = viz_img.size
271
- # 取右半部分
272
- density_img = viz_img.crop((w//2, 0, w, h))
273
- # 保存为新文件
274
- temp_density = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
275
- density_img.save(temp_density.name)
276
- density_path = temp_density.name
277
- except:
278
- density_path = viz_path
279
- else:
280
- density_path = viz_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
 
282
  result_text = f"✅ 检测到 {count:.1f} 个细胞"
283
 
284
  print(f"✅ Counting done - Count: {count:.1f}")
 
 
285
 
286
- return density_path, result_text
287
 
288
  except Exception as e:
289
  print(f"❌ Counting error: {e}")
@@ -295,6 +347,8 @@ def count_cells_handler(use_box_choice, annot_value):
295
def find_tif_dir(root_dir):
    """Recursively find the first directory containing a .tif file.

    Args:
        root_dir: Root directory to search under.

    Returns:
        Path of the first directory (in ``os.walk`` order) that contains at
        least one file ending in ``.tif`` (case-insensitive), or ``None`` if
        no such directory exists.
    """
    for dirpath, _, filenames in os.walk(root_dir):
        # Skip macOS zip metadata directories: their "._foo.tif" resource-fork
        # files also end in ".tif" and would be matched as false positives.
        if '__MACOSX' in dirpath:
            continue
        if any(f.lower().endswith('.tif') for f in filenames):
            return dirpath
    return None
@@ -328,20 +382,18 @@ def track_video_handler(zip_file_obj):
328
  if 'error' in result:
329
  return None, f"❌ 跟踪失败: {result['error']}"
330
 
331
- num_tracks = result['num_tracks']
332
  output_dir = result['output_dir']
333
 
334
  result_text = f"""✅ 跟踪完成!
335
 
336
- 🎯 跟踪轨迹数量: {num_tracks}
337
- 📁 结果保存在: {output_dir}
338
 
339
- 包含的文件:
340
- - res_track.txt (CTC格式轨迹)
341
- - 其他跟踪数据文件
342
- """
343
 
344
- print(f"✅ Tracking done - {num_tracks} tracks")
345
  return None, result_text
346
 
347
  except zipfile.BadZipFile:
@@ -575,12 +627,18 @@ with gr.Blocks(title="Microscopy Analysis Suite", theme=gr.themes.Soft()) as dem
575
  label="📊 统计信息",
576
  lines=2
577
  )
 
 
 
 
 
 
578
 
579
  # 绑定事件
580
  count_btn.click(
581
  fn=count_cells_handler,
582
  inputs=[count_use_box_radio, count_annotator],
583
- outputs=[count_output, count_status]
584
  )
585
 
586
  # 初始化Gallery显示
 
21
 
22
  # ===== 清理缓存目录 =====
23
  print("===== 清理缓存 =====")
24
+ cache_path = os.path.expanduser("~/.cache/huggingface/gradio")
25
  if os.path.exists(cache_path):
26
  try:
27
  shutil.rmtree(cache_path)
 
138
 
139
  # ===== 分割功能 =====
140
  def segment_with_choice(use_box_choice, annot_value):
141
+ print("边界框选择:", use_box_choice)
142
+ print("注释值:", annot_value)
143
  """分割主函数 - 每个实例不同颜色+轮廓"""
144
  if annot_value is None or len(annot_value) < 1:
145
  print("❌ No annotation input")
 
256
  return None, f"❌ 计数失败: {result['error']}"
257
 
258
  count = result['count']
259
+ density_map = result['density_map']
260
+ # save density map as temp file
261
+ temp_density_file = tempfile.NamedTemporaryFile(delete=False, suffix=".npy")
262
+ np.save(temp_density_file.name, density_map)
263
+ print(f"💾 Density map saved to {temp_density_file.name}")
264
 
265
  # 只提取密度图部分(假设visualized_path是拼接图,我们只要右半部分)
266
+ # viz_path = result.get('visualized_path')
267
 
268
  # 如果有density_map_path,直接使用
269
+ # if 'density_map_path' in result:
270
+ # density_path = result['density_map_path']
271
+ # elif viz_path and os.path.exists(viz_path):
272
+ # # 如果是拼接图,提取右半部分(密度图)
273
+ # try:
274
+ # viz_img = Image.open(viz_path)
275
+ # w, h = viz_img.size
276
+ # # 取右半部分
277
+ # density_img = viz_img
278
+ # # 保存为新文件
279
+ # temp_density = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
280
+ # density_img.save(temp_density.name)
281
+ # density_path = temp_density.name
282
+ # except:
283
+ # density_path = viz_path
284
+ # else:
285
+ # density_path = viz_path
286
+
287
+ # 读取原图
288
+ try:
289
+ img = Image.open(image_path)
290
+ print("📷 Image mode:", img.mode, "size:", img.size)
291
+ except Exception as e:
292
+ print(f"❌ Failed to open image: {e}")
293
+ return None, None
294
+
295
+ try:
296
+ img_rgb = img.convert("RGB").resize(density_map.shape[::-1], resample=Image.BILINEAR)
297
+ img_np = np.array(img_rgb, dtype=np.float32)
298
+ img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)
299
+ if img_np.max() > 1.5:
300
+ img_np = img_np / 255.0
301
+ except Exception as e:
302
+ print(f"❌ Error in image conversion/resizing: {e}")
303
+ return None, None
304
+
305
+
306
+ # Normalize density map to [0, 1]
307
+ density_normalized = density_map.copy()
308
+ if density_normalized.max() > 0:
309
+ density_normalized = (density_normalized - density_normalized.min()) / (density_normalized.max() - density_normalized.min())
310
+
311
+ # Apply colormap
312
+ cmap = cm.get_cmap("jet")
313
+ alpha = 0.3
314
+ density_colored = cmap(density_normalized)[:, :, :3] # RGB only, ignore alpha
315
+
316
+ # Create overlay
317
+ overlay = img_np.copy()
318
+
319
+ # Blend only where density is significant (optional: threshold)
320
+ threshold = 0.01 # Only overlay where density > 1% of max
321
+ significant_mask = density_normalized > threshold
322
+
323
+ overlay[significant_mask] = (1 - alpha) * overlay[significant_mask] + alpha * density_colored[significant_mask]
324
+
325
+ # Clip and convert to uint8
326
+ overlay = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)
327
+
328
+
329
+
330
+
331
 
332
  result_text = f"✅ 检测到 {count:.1f} 个细胞"
333
 
334
  print(f"✅ Counting done - Count: {count:.1f}")
335
+
336
+ return Image.fromarray(overlay), temp_density_file.name, result_text
337
 
338
+ # return density_path, result_text
339
 
340
  except Exception as e:
341
  print(f"❌ Counting error: {e}")
 
347
def find_tif_dir(root_dir):
    """Recursively locate the first directory holding a .tif file.

    Walks *root_dir* top-down, ignoring macOS ``__MACOSX`` zip-metadata
    directories, and returns the first directory containing a file whose
    name ends in ``.tif`` (case-insensitive), or ``None`` when none exists.
    """
    matches = (
        dirpath
        for dirpath, _, filenames in os.walk(root_dir)
        if '__MACOSX' not in dirpath
        and any(name.lower().endswith('.tif') for name in filenames)
    )
    return next(matches, None)
 
382
  if 'error' in result:
383
  return None, f"❌ 跟踪失败: {result['error']}"
384
 
 
385
  output_dir = result['output_dir']
386
 
387
  result_text = f"""✅ 跟踪完成!
388
 
389
+ 📁 结果保存在: {output_dir}
 
390
 
391
+ 包含的文件:
392
+ - res_track.txt (CTC格式轨迹)
393
+ - 其他跟踪数据文件
394
+ """
395
 
396
+ print(f"✅ Tracking done")
397
  return None, result_text
398
 
399
  except zipfile.BadZipFile:
 
627
  label="📊 统计信息",
628
  lines=2
629
  )
630
+
631
+ # 下载原始预测结果
632
+ download_density_btn = gr.File(
633
+ label="📥 下载原始预测 (.npy 格式)",
634
+ visible=True
635
+ )
636
 
637
  # 绑定事件
638
  count_btn.click(
639
  fn=count_cells_handler,
640
  inputs=[count_use_box_radio, count_annotator],
641
+ outputs=[count_output, download_density_btn, count_status]
642
  )
643
 
644
  # 初始化Gallery显示
counting.py CHANGED
@@ -1,7 +1,5 @@
1
  # stable diffusion x loca
2
  import os
3
- # os.system("source /etc/network_turbo")
4
- os.environ["CUDA_VISIBLE_DEVICES"] = "2"
5
  import pprint
6
  from typing import Any, List, Optional
7
  import argparse
 
1
  # stable diffusion x loca
2
  import os
 
 
3
  import pprint
4
  from typing import Any, List, Optional
5
  import argparse
inference_count.py CHANGED
@@ -49,7 +49,14 @@ def load_model(use_box=False):
49
  )
50
  MODEL.eval()
51
 
52
- DEVICE = torch.device("cpu")
 
 
 
 
 
 
 
53
 
54
  print("✅ Counting model loaded successfully")
55
  return MODEL, DEVICE
@@ -80,6 +87,15 @@ def run(model, img_path, box=None, device="cpu", visualize=True):
80
  'visualized_path': str (如果 visualize=True)
81
  }
82
  """
 
 
 
 
 
 
 
 
 
83
  if model is None:
84
  return {
85
  'density_map': None,
@@ -92,7 +108,8 @@ def run(model, img_path, box=None, device="cpu", visualize=True):
92
  print(f"🔄 Running counting inference on {img_path}")
93
 
94
  # 运行推理 (调用你的模型的 forward 方法)
95
- density_map, count = model(img_path, box)
 
96
 
97
  print(f"✅ Counting result: {count:.1f} objects")
98
 
@@ -103,9 +120,9 @@ def run(model, img_path, box=None, device="cpu", visualize=True):
103
  }
104
 
105
  # 可视化
106
- if visualize:
107
- viz_path = visualize_result(img_path, density_map, count)
108
- result['visualized_path'] = viz_path
109
 
110
  return result
111
 
@@ -153,18 +170,13 @@ def visualize_result(image_path, density_map, count):
153
  img_show = (img_show - np.min(img_show)) / (np.max(img_show) - np.min(img_show) + 1e-8)
154
 
155
  # 创建可视化 (与你原来的代码一致)
156
- fig, ax = plt.subplots(1, 2, figsize=(12, 6))
157
-
158
- # 左图: 原始图像
159
- ax[0].imshow(img_show)
160
- ax[0].axis('off')
161
- ax[0].set_title(f"Input image")
162
 
163
  # 右图: 密度图叠加
164
- ax[1].imshow(img_show)
165
- ax[1].imshow(density_map_show, cmap='jet', alpha=0.5)
166
- ax[1].axis('off')
167
- ax[1].set_title(f"Predicted density map, count: {count:.1f}")
168
 
169
  plt.tight_layout()
170
 
 
49
  )
50
  MODEL.eval()
51
 
52
+ if torch.cuda.is_available():
53
+ DEVICE = torch.device("cuda")
54
+ MODEL.move_to_device(DEVICE)
55
+ print("✅ Model moved to CUDA")
56
+ else:
57
+ DEVICE = torch.device("cpu")
58
+ MODEL.move_to_device(DEVICE)
59
+ print("✅ Model on CPU")
60
 
61
  print("✅ Counting model loaded successfully")
62
  return MODEL, DEVICE
 
87
  'visualized_path': str (如果 visualize=True)
88
  }
89
  """
90
+ print("DEVICE:", device)
91
+ model.move_to_device(device)
92
+ model.eval()
93
+ if box is not None:
94
+ use_box = True
95
+ else:
96
+ use_box = False
97
+ model.use_box = use_box
98
+
99
  if model is None:
100
  return {
101
  'density_map': None,
 
108
  print(f"🔄 Running counting inference on {img_path}")
109
 
110
  # 运行推理 (调用你的模型的 forward 方法)
111
+ with torch.no_grad():
112
+ density_map, count = model(img_path, box)
113
 
114
  print(f"✅ Counting result: {count:.1f} objects")
115
 
 
120
  }
121
 
122
  # 可视化
123
+ # if visualize:
124
+ # viz_path = visualize_result(img_path, density_map, count)
125
+ # result['visualized_path'] = viz_path
126
 
127
  return result
128
 
 
170
  img_show = (img_show - np.min(img_show)) / (np.max(img_show) - np.min(img_show) + 1e-8)
171
 
172
  # 创建可视化 (与你原来的代码一致)
173
+ fig, ax = plt.subplots(figsize=(8, 6))
 
 
 
 
 
174
 
175
  # 右图: 密度图叠加
176
+ ax.imshow(img_show)
177
+ ax.imshow(density_map_show, cmap='jet', alpha=0.5)
178
+ ax.axis('off')
179
+ # ax.set_title(f"Predicted density map, count: {count:.1f}")
180
 
181
  plt.tight_layout()
182
 
inference_seg.py CHANGED
@@ -18,15 +18,30 @@ def load_model(use_box=False):
18
  )
19
  MODEL.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
20
  MODEL.eval()
21
- DEVICE = torch.device("cpu")
 
 
 
 
 
 
 
22
  return MODEL, DEVICE
23
 
24
 
25
@torch.no_grad()
def run(model, img_path, box=None, device="cpu"):
    """Run the segmentation model on one image and binarize its output.

    Args:
        model: Callable segmentation model taking ``(img_path, box=...)``.
        img_path: Path to the input image.
        box: Optional box prompt forwarded to the model.
        device: Unused here; kept for interface compatibility.

    Returns:
        A ``uint8`` binary mask (1 where the raw output is positive).
    """
    raw_output = model(img_path, box=box)
    binary_mask = (raw_output > 0).astype(np.uint8)
    return binary_mask
31
  # import os
32
  # import torch
 
18
  )
19
  MODEL.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
20
  MODEL.eval()
21
+ if torch.cuda.is_available():
22
+ DEVICE = torch.device("cuda")
23
+ MODEL.move_to_device(DEVICE)
24
+ print("✅ Model moved to CUDA")
25
+ else:
26
+ DEVICE = torch.device("cpu")
27
+ MODEL.move_to_device(DEVICE)
28
+ print("✅ Model on CPU")
29
  return MODEL, DEVICE
30
 
31
 
32
@torch.no_grad()
def run(model, img_path, box=None, device="cpu"):
    """Run segmentation inference on one image.

    Args:
        model: Segmentation model exposing ``move_to_device()``, ``eval()``,
            a writable ``use_box`` attribute, and a callable interface taking
            ``(img_path, box=...)``.
        img_path: Path to the input image.
        box: Optional box prompt; its presence toggles box-conditioned mode.
        device: Target device the model is moved to before inference.

    Returns:
        The raw model output (segmentation mask).
    """
    print("DEVICE:", device)
    model.move_to_device(device)
    model.eval()
    # The presence of a box prompt selects box-conditioned inference.
    model.use_box = box is not None
    # Gradients are already disabled by the @torch.no_grad() decorator, so
    # the original inner `with torch.no_grad():` context was redundant.
    mask = model(img_path, box=box)
    return mask
46
  # import os
47
  # import torch
inference_track.py CHANGED
@@ -120,17 +120,18 @@ def run(model, video_dir, box=None, device="cpu", output_dir="tracked_results"):
120
  masks,
121
  outdir=output_dir,
122
  )
 
123
 
124
- num_tracks = len(track_graph.tracks())
125
 
126
- print(f"✅ Tracking completed: {num_tracks} tracks found")
127
 
128
  result = {
129
  'track_graph': track_graph,
130
  'masks': masks,
131
  'masks_tracked': masks_tracked,
132
  'output_dir': output_dir,
133
- 'num_tracks': num_tracks
134
  }
135
 
136
  return result
 
120
  masks,
121
  outdir=output_dir,
122
  )
123
+ print(f"✅ CTC results saved to {output_dir}")
124
 
125
+ # num_tracks = len(track_graph.tracks())
126
 
127
+ print(f"✅ Tracking completed")
128
 
129
  result = {
130
  'track_graph': track_graph,
131
  'masks': masks,
132
  'masks_tracked': masks_tracked,
133
  'output_dir': output_dir,
134
+ # 'num_tracks': num_tracks
135
  }
136
 
137
  return result
segmentation.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- os.environ["CUDA_VISIBLE_DEVICES"] = "2"
3
  import pprint
4
  from typing import Any, List, Optional
5
  import argparse
 
1
  import os
 
2
  import pprint
3
  from typing import Any, List, Optional
4
  import argparse
tracking_one.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- os.environ["CUDA_VISIBLE_DEVICES"] = "0"
3
  import pprint
4
  from typing import Any, List, Optional
5
  import argparse
 
1
  import os
 
2
  import pprint
3
  from typing import Any, List, Optional
4
  import argparse