phoebe777777 commited on
Commit
ad05714
·
verified ·
1 Parent(s): 7799648

Delete app_cn.py

Browse files
Files changed (1) hide show
  1. app_cn.py +0 -1515
app_cn.py DELETED
@@ -1,1515 +0,0 @@
1
- import gradio as gr
2
- from gradio_bbox_annotator import BBoxAnnotator
3
- from PIL import Image
4
- import numpy as np
5
- import torch
6
- import os
7
- import shutil
8
- import time
9
- import json
10
- import uuid
11
- from pathlib import Path
12
- import tempfile
13
- import zipfile
14
- from skimage import measure
15
- from matplotlib import cm
16
- from glob import glob
17
- from natsort import natsorted
18
-
19
- # ===== 导入三个推理模块 =====
20
- from inference_seg import load_model as load_seg_model, run as run_seg
21
- from inference_count import load_model as load_count_model, run as run_count
22
- from inference_track import load_model as load_track_model, run as run_track
23
-
24
- # ===== 清理缓存目录 =====
25
- print("===== 清理缓存 =====")
26
- # cache_path = os.path.expanduser("~/.cache/")
27
- cache_path = os.path.expanduser("~/.cache/huggingface/gradio")
28
- if os.path.exists(cache_path):
29
- try:
30
- shutil.rmtree(cache_path)
31
- # print("✅ Deleted ~/.cache/")
32
- print("✅ Deleted ~/.cache/huggingface/gradio")
33
- except:
34
- pass
35
-
36
- # ===== 全局模型变量 =====
37
- SEG_MODEL = None
38
- SEG_DEVICE = torch.device("cpu")
39
-
40
- COUNT_MODEL = None
41
- COUNT_DEVICE = torch.device("cpu")
42
-
43
- TRACK_MODEL = None
44
- TRACK_DEVICE = torch.device("cpu")
45
-
46
- def load_all_models():
47
- """启动时加载所有模型"""
48
- global SEG_MODEL, SEG_DEVICE
49
- global COUNT_MODEL, COUNT_DEVICE
50
- global TRACK_MODEL, TRACK_DEVICE
51
-
52
- print("\n" + "="*60)
53
- print("📦 Loading Segmentation Model")
54
- print("="*60)
55
- SEG_MODEL, SEG_DEVICE = load_seg_model(use_box=False)
56
-
57
- print("\n" + "="*60)
58
- print("📦 Loading Counting Model")
59
- print("="*60)
60
- COUNT_MODEL, COUNT_DEVICE = load_count_model(use_box=False)
61
-
62
- print("\n" + "="*60)
63
- print("📦 Loading Tracking Model")
64
- print("="*60)
65
- TRACK_MODEL, TRACK_DEVICE = load_track_model(use_box=False)
66
-
67
- print("\n" + "="*60)
68
- print("✅ All Models Loaded Successfully")
69
- print("="*60)
70
-
71
- load_all_models()
72
-
73
- # ===== 保存用户反馈 =====
74
- DATASET_DIR = Path("solver_cache")
75
- DATASET_DIR.mkdir(parents=True, exist_ok=True)
76
-
77
- def save_feedback(query_id, feedback_type, feedback_text=None, img_path=None, bboxes=None):
78
- """保存用户反馈到JSON文件"""
79
- feedback_data = {
80
- "query_id": query_id,
81
- "feedback_type": feedback_type,
82
- "feedback_text": feedback_text,
83
- "image": img_path,
84
- "bboxes": bboxes,
85
- "datetime": time.strftime("%Y%m%d_%H%M%S")
86
- }
87
- feedback_file = DATASET_DIR / query_id / "feedback.json"
88
- feedback_file.parent.mkdir(parents=True, exist_ok=True)
89
-
90
- if feedback_file.exists():
91
- with feedback_file.open("r") as f:
92
- existing = json.load(f)
93
- if not isinstance(existing, list):
94
- existing = [existing]
95
- existing.append(feedback_data)
96
- feedback_data = existing
97
- else:
98
- feedback_data = [feedback_data]
99
-
100
- with feedback_file.open("w") as f:
101
- json.dump(feedback_data, f, indent=4, ensure_ascii=False)
102
-
103
- # ===== 辅助函数 =====
104
- def parse_first_bbox(bboxes):
105
- """解析第一个边界框"""
106
- if not bboxes:
107
- return None
108
- b = bboxes[0]
109
- if isinstance(b, dict):
110
- x, y = float(b.get("x", 0)), float(b.get("y", 0))
111
- w, h = float(b.get("width", 0)), float(b.get("height", 0))
112
- return x, y, x + w, y + h
113
- if isinstance(b, (list, tuple)) and len(b) >= 4:
114
- return float(b[0]), float(b[1]), float(b[2]), float(b[3])
115
- return None
116
-
117
- def colorize_mask(mask: np.ndarray, num_colors: int = 512) -> np.ndarray:
118
- """将实例掩码转换为彩色图像"""
119
- def hsv_to_rgb(h, s, v):
120
- i = int(h * 6.0)
121
- f = h * 6.0 - i
122
- i = i % 6
123
- p = v * (1 - s)
124
- q = v * (1 - f * s)
125
- t = v * (1 - (1 - f) * s)
126
- if i == 0: r, g, b = v, t, p
127
- elif i == 1: r, g, b = q, v, p
128
- elif i == 2: r, g, b = p, v, t
129
- elif i == 3: r, g, b = p, q, v
130
- elif i == 4: r, g, b = t, p, v
131
- else: r, g, b = v, p, q
132
- return int(r * 255), int(g * 255), int(b * 255)
133
-
134
- palette = [(0, 0, 0)]
135
- for i in range(1, num_colors):
136
- h = (i % num_colors) / float(num_colors)
137
- palette.append(hsv_to_rgb(h, 1.0, 0.95))
138
-
139
- palette_arr = np.array(palette, dtype=np.uint8)
140
- color_idx = mask % num_colors
141
- return palette_arr[color_idx]
142
-
143
- # ===== 分割功能 =====
144
- def segment_with_choice(use_box_choice, annot_value):
145
- print("边界框选择:", use_box_choice)
146
- print("注释值:", annot_value)
147
- """分割主函数 - 每个实例不同颜色+轮廓"""
148
- if annot_value is None or len(annot_value) < 1:
149
- print("❌ No annotation input")
150
- return None, None
151
-
152
- img_path = annot_value[0]
153
- bboxes = annot_value[1] if len(annot_value) > 1 else []
154
-
155
- print(f"🖼️ 图像路径: {img_path}")
156
- box_array = None
157
- if use_box_choice == "Yes" and bboxes:
158
- box = parse_first_bbox(bboxes)
159
- if box:
160
- xmin, ymin, xmax, ymax = map(int, box)
161
- box_array = [[xmin, ymin, xmax, ymax]]
162
- print(f"📦 使用边界框: {box_array}")
163
-
164
- # 运行分割模型
165
- try:
166
- mask = run_seg(SEG_MODEL, img_path, box=box_array, device=SEG_DEVICE)
167
- print("📏 mask shape:", mask.shape, "dtype:", mask.dtype, "unique:", np.unique(mask))
168
- except Exception as e:
169
- print(f"❌ 推理失败: {str(e)}")
170
- return None, None
171
-
172
- # 保存原始mask为TIF文件
173
- temp_mask_file = tempfile.NamedTemporaryFile(delete=False, suffix=".tif")
174
- mask_img = Image.fromarray(mask.astype(np.uint16))
175
- mask_img.save(temp_mask_file.name)
176
- print(f"💾 原始mask保存到: {temp_mask_file.name}")
177
-
178
- # 读取原图
179
- try:
180
- img = Image.open(img_path)
181
- print("📷 Image mode:", img.mode, "size:", img.size)
182
- except Exception as e:
183
- print(f"❌ Failed to open image: {e}")
184
- return None, None
185
-
186
- try:
187
- img_rgb = img.convert("RGB").resize(mask.shape[::-1], resample=Image.BILINEAR)
188
- img_np = np.array(img_rgb, dtype=np.float32)
189
- if img_np.max() > 1.5:
190
- img_np = img_np / 255.0
191
- except Exception as e:
192
- print(f"❌ Error in image conversion/resizing: {e}")
193
- return None, None
194
-
195
- mask_np = np.array(mask)
196
- inst_mask = mask_np.astype(np.int32)
197
- unique_ids = np.unique(inst_mask)
198
- num_instances = len(unique_ids[unique_ids != 0])
199
- print(f"✅ Instance IDs found: {unique_ids}, Total instances: {num_instances}")
200
-
201
- if num_instances == 0:
202
- print("⚠️ No instance found, returning dummy red image")
203
- return Image.new("RGB", mask.shape[::-1], (255, 0, 0)), None
204
-
205
- # ==== Color Overlay (每个实例一个颜色) ====
206
- overlay = img_np.copy()
207
- alpha = 0.5
208
- # cmap = cm.get_cmap("hsv", num_instances + 1)
209
-
210
- for inst_id in np.unique(inst_mask):
211
- if inst_id == 0:
212
- continue
213
- binary_mask = (inst_mask == inst_id).astype(np.uint8)
214
- # color = np.array(cmap(inst_id / (num_instances + 1))[:3]) # RGB only, ignore alpha
215
- color = get_well_spaced_color(inst_id)
216
- overlay[binary_mask == 1] = (1 - alpha) * overlay[binary_mask == 1] + alpha * color
217
-
218
- # 绘制轮廓
219
- contours = measure.find_contours(binary_mask, 0.5)
220
- for contour in contours:
221
- contour = contour.astype(np.int32)
222
- # 确保坐标在范围内
223
- valid_y = np.clip(contour[:, 0], 0, overlay.shape[0] - 1)
224
- valid_x = np.clip(contour[:, 1], 0, overlay.shape[1] - 1)
225
- overlay[valid_y, valid_x] = [1.0, 1.0, 0.0] # 黄色轮廓
226
-
227
- overlay = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)
228
-
229
- return Image.fromarray(overlay), temp_mask_file.name
230
-
231
- # ===== 计数功能 =====
232
- def count_cells_handler(use_box_choice, annot_value):
233
- """计数处理函数 - 支持边界框,只返回密度图"""
234
- if annot_value is None or len(annot_value) < 1:
235
- return None, "⚠️ 请先上传图像"
236
-
237
- image_path = annot_value[0]
238
- bboxes = annot_value[1] if len(annot_value) > 1 else []
239
-
240
- print(f"🖼️ 图像路径: {image_path}")
241
- box_array = None
242
- if use_box_choice == "Yes" and bboxes:
243
- box = parse_first_bbox(bboxes)
244
- if box:
245
- xmin, ymin, xmax, ymax = map(int, box)
246
- box_array = [[xmin, ymin, xmax, ymax]]
247
- print(f"📦 使用边界框: {box_array}")
248
-
249
- try:
250
- print(f"🔢 Counting - Image: {image_path}")
251
-
252
- result = run_count(
253
- COUNT_MODEL,
254
- image_path,
255
- box=box_array,
256
- device=COUNT_DEVICE,
257
- visualize=True
258
- )
259
-
260
- if 'error' in result:
261
- return None, f"❌ 计数失败: {result['error']}"
262
-
263
- count = result['count']
264
- density_map = result['density_map']
265
- # save density map as temp file
266
- temp_density_file = tempfile.NamedTemporaryFile(delete=False, suffix=".npy")
267
- np.save(temp_density_file.name, density_map)
268
- print(f"💾 Density map saved to {temp_density_file.name}")
269
-
270
- # 只提取密度图部分(假设visualized_path是拼接图,我们只要右半部分)
271
- # viz_path = result.get('visualized_path')
272
-
273
- # 如果有density_map_path,直接使用
274
- # if 'density_map_path' in result:
275
- # density_path = result['density_map_path']
276
- # elif viz_path and os.path.exists(viz_path):
277
- # # 如果是拼接图,提取右半部分(密度图)
278
- # try:
279
- # viz_img = Image.open(viz_path)
280
- # w, h = viz_img.size
281
- # # 取右半部分
282
- # density_img = viz_img
283
- # # 保存为新文件
284
- # temp_density = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
285
- # density_img.save(temp_density.name)
286
- # density_path = temp_density.name
287
- # except:
288
- # density_path = viz_path
289
- # else:
290
- # density_path = viz_path
291
-
292
- # 读取原图
293
- try:
294
- img = Image.open(image_path)
295
- print("📷 Image mode:", img.mode, "size:", img.size)
296
- except Exception as e:
297
- print(f"❌ Failed to open image: {e}")
298
- return None, None
299
-
300
- try:
301
- img_rgb = img.convert("RGB").resize(density_map.shape[::-1], resample=Image.BILINEAR)
302
- img_np = np.array(img_rgb, dtype=np.float32)
303
- img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)
304
- if img_np.max() > 1.5:
305
- img_np = img_np / 255.0
306
- except Exception as e:
307
- print(f"❌ Error in image conversion/resizing: {e}")
308
- return None, None
309
-
310
-
311
- # Normalize density map to [0, 1]
312
- density_normalized = density_map.copy()
313
- if density_normalized.max() > 0:
314
- density_normalized = (density_normalized - density_normalized.min()) / (density_normalized.max() - density_normalized.min())
315
-
316
- # Apply colormap
317
- cmap = cm.get_cmap("jet")
318
- alpha = 0.3
319
- density_colored = cmap(density_normalized)[:, :, :3] # RGB only, ignore alpha
320
-
321
- # Create overlay
322
- overlay = img_np.copy()
323
-
324
- # Blend only where density is significant (optional: threshold)
325
- threshold = 0.01 # Only overlay where density > 1% of max
326
- significant_mask = density_normalized > threshold
327
-
328
- overlay[significant_mask] = (1 - alpha) * overlay[significant_mask] + alpha * density_colored[significant_mask]
329
-
330
- # Clip and convert to uint8
331
- overlay = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)
332
-
333
-
334
-
335
-
336
-
337
- result_text = f"✅ 检测到 {round(count)} 个细胞"
338
-
339
- print(f"✅ Counting done - Count: {count:.1f}")
340
-
341
- return Image.fromarray(overlay), temp_density_file.name, result_text
342
-
343
- # return density_path, result_text
344
-
345
- except Exception as e:
346
- print(f"❌ Counting error: {e}")
347
- import traceback
348
- traceback.print_exc()
349
- return None, f"❌ 计数失败: {str(e)}"
350
-
351
- # ===== 跟踪功能 =====
352
- def find_tif_dir(root_dir):
353
- """递归查找第一个包含 .tif 文件的目录"""
354
- for dirpath, _, filenames in os.walk(root_dir):
355
- if '__MACOSX' in dirpath:
356
- continue
357
- if any(f.lower().endswith('.tif') for f in filenames):
358
- return dirpath
359
- return None
360
-
361
- def is_valid_tiff(filepath):
362
- """Check if a file is a valid TIFF image"""
363
- try:
364
- with Image.open(filepath) as img:
365
- img.verify()
366
- return True
367
- except Exception as e:
368
- return False
369
-
370
- def find_valid_tif_dir(root_dir):
371
- """递归查找第一个包含有效 .tif 文件的目录"""
372
- for dirpath, dirnames, filenames in os.walk(root_dir):
373
- if '__MACOSX' in dirpath:
374
- continue
375
-
376
- potential_tifs = [
377
- os.path.join(dirpath, f)
378
- for f in filenames
379
- if f.lower().endswith(('.tif', '.tiff')) and not f.startswith('._')
380
- ]
381
-
382
- if not potential_tifs:
383
- continue
384
-
385
- valid_tifs = [f for f in potential_tifs if is_valid_tiff(f)]
386
-
387
- if valid_tifs:
388
- print(f"✅ Found {len(valid_tifs)} valid TIFF files in: {dirpath}")
389
- return dirpath
390
-
391
- return None
392
-
393
- def create_ctc_results_zip(output_dir):
394
- """
395
- Create a ZIP file with CTC format results
396
-
397
- Parameters:
398
- -----------
399
- output_dir : str
400
- Directory containing tracking results (res_track.txt, etc.)
401
-
402
- Returns:
403
- --------
404
- zip_path : str
405
- Path to created ZIP file
406
- """
407
- # Create temp directory for ZIP
408
- temp_zip_dir = tempfile.mkdtemp()
409
- zip_filename = f"tracking_results_{time.strftime('%Y%m%d_%H%M%S')}.zip"
410
- zip_path = os.path.join(temp_zip_dir, zip_filename)
411
-
412
- print(f"📦 Creating results ZIP: {zip_path}")
413
-
414
- # Create ZIP with all tracking results
415
- with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
416
- # Add all files from output directory
417
- for root, dirs, files in os.walk(output_dir):
418
- for file in files:
419
- file_path = os.path.join(root, file)
420
- arcname = os.path.relpath(file_path, output_dir)
421
- zipf.write(file_path, arcname)
422
- print(f" 📄 Added: {arcname}")
423
-
424
- # Add a README with summary
425
- readme_content = f"""Tracking Results Summary
426
- ========================
427
-
428
- Generated: {time.strftime('%Y-%m-%d %H:%M:%S')}
429
-
430
- Files:
431
- ------
432
- - res_track.txt: CTC format tracking data
433
- Format: track_id start_frame end_frame parent_id
434
-
435
- - Segmentation masks
436
-
437
- For more information on CTC format:
438
- http://celltrackingchallenge.net/
439
- """
440
- zipf.writestr("README.txt", readme_content)
441
-
442
- print(f"✅ ZIP created: {zip_path} ({os.path.getsize(zip_path) / 1024:.1f} KB)")
443
- return zip_path
444
-
445
- # 使用更智能的颜色分配 - 让相邻的ID颜色差异更大
446
- def get_well_spaced_color(track_id, num_colors=256):
447
- """生成间隔良好的颜色,相邻ID使用对比色"""
448
- # 使用质数跳跃来分散颜色
449
- golden_ratio = 0.618033988749895
450
- hue = (track_id * golden_ratio) % 1.0
451
-
452
- # 使用高饱和度和明度
453
- import colorsys
454
- rgb = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
455
- return np.array(rgb)
456
-
457
-
458
- def extract_first_frame(tif_dir):
459
- """
460
- Extract the first frame from a directory of TIF files
461
-
462
- Returns:
463
- --------
464
- first_frame_path : str
465
- Path to the first TIF frame
466
- """
467
- tif_files = natsorted(glob(os.path.join(tif_dir, "*.tif")) +
468
- glob(os.path.join(tif_dir, "*.tiff")))
469
- valid_tif_files = [f for f in tif_files
470
- if not os.path.basename(f).startswith('._') and is_valid_tiff(f)]
471
-
472
- if valid_tif_files:
473
- return valid_tif_files[0]
474
- return None
475
-
476
- def create_tracking_visualization(tif_dir, output_dir, valid_tif_files):
477
- """
478
- Create an animated GIF/video showing tracked objects with consistent colors
479
-
480
- Parameters:
481
- -----------
482
- tif_dir : str
483
- Directory containing input TIF frames
484
- output_dir : str
485
- Directory containing tracking results (masks)
486
- valid_tif_files : list
487
- List of valid TIF file paths
488
-
489
- Returns:
490
- --------
491
- video_path : str
492
- Path to generated visualization (GIF or first frame)
493
- """
494
- import numpy as np
495
- from matplotlib import colormaps
496
- from skimage import measure
497
- import tifffile
498
-
499
- # Look for tracking mask files in output directory
500
- # Common CTC formats: man_track*.tif, mask*.tif, or numbered masks
501
- mask_files = natsorted(glob(os.path.join(output_dir, "mask*.tif")) +
502
- glob(os.path.join(output_dir, "man_track*.tif")) +
503
- glob(os.path.join(output_dir, "*.tif")))
504
-
505
- if not mask_files:
506
- print("⚠️ No mask files found in output directory")
507
- # Return first frame as fallback
508
- return valid_tif_files[0]
509
-
510
- print(f"📊 Found {len(mask_files)} mask files")
511
-
512
- # Create color map for consistent track IDs
513
- # Use a colormap with many distinct colors
514
- # try:
515
- # cmap = colormaps.get_cmap("hsv")
516
- # except:
517
- # from matplotlib import cm
518
- # cmap = cm.get_cmap("hsv")
519
-
520
- frames = []
521
- alpha = 0.3 # Transparency for overlay
522
-
523
- # Process each frame
524
- num_frames = min(len(valid_tif_files), len(mask_files))
525
- for i in range(num_frames):
526
- try:
527
- # Load original image using tifffile (handles ZSTD compression)
528
- try:
529
- img_np = tifffile.imread(valid_tif_files[i])
530
-
531
- # Normalize to [0, 1] range based on actual data type and values
532
- if img_np.dtype == np.uint8:
533
- img_np = img_np.astype(np.float32) / 255.0
534
- elif img_np.dtype == np.uint16:
535
- # Normalize uint16 to [0, 1] using actual min/max
536
- img_min, img_max = img_np.min(), img_np.max()
537
- if img_max > img_min:
538
- img_np = (img_np.astype(np.float32) - img_min) / (img_max - img_min)
539
- else:
540
- img_np = img_np.astype(np.float32) / 65535.0
541
- else:
542
- # For float or other types, normalize based on actual range
543
- img_np = img_np.astype(np.float32)
544
- img_min, img_max = img_np.min(), img_np.max()
545
- if img_max > img_min:
546
- img_np = (img_np - img_min) / (img_max - img_min)
547
- else:
548
- img_np = np.clip(img_np, 0, 1)
549
-
550
- # Convert to RGB if grayscale
551
- if img_np.ndim == 2:
552
- img_np = np.stack([img_np]*3, axis=-1)
553
- img_np = img_np.astype(np.float32)
554
- if img_np.max() > 1.5:
555
- img_np = img_np / 255.0
556
- except Exception as e:
557
- print(f"⚠️ Error loading image frame {i}: {e}")
558
- # Fallback to PIL
559
- img = Image.open(valid_tif_files[i]).convert("RGB")
560
- img_np = np.array(img, dtype=np.float32) / 255.0
561
-
562
- # Load tracking mask using tifffile (handles ZSTD compression)
563
- try:
564
- mask = tifffile.imread(mask_files[i])
565
- except Exception as e:
566
- print(f"⚠️ Error loading mask frame {i}: {e}")
567
- # Fallback to PIL
568
- mask = np.array(Image.open(mask_files[i]))
569
-
570
- # Resize mask to match image if needed
571
- if mask.shape[:2] != img_np.shape[:2]:
572
- from scipy.ndimage import zoom
573
- zoom_factors = [img_np.shape[0] / mask.shape[0], img_np.shape[1] / mask.shape[1]]
574
- mask = zoom(mask, zoom_factors, order=0).astype(mask.dtype)
575
-
576
- # Create overlay
577
- overlay = img_np.copy()
578
-
579
- # Get unique track IDs (excluding background 0)
580
- track_ids = np.unique(mask)
581
- track_ids = track_ids[track_ids != 0]
582
-
583
- # Color each tracked object
584
- for track_id in track_ids:
585
- # Create binary mask for this track
586
- binary_mask = (mask == track_id)
587
-
588
- # Get consistent color for this track ID
589
- # color = np.array(cmap(int(track_id) % 256)[:3])
590
- color = get_well_spaced_color(int(track_id))
591
-
592
- # Blend color onto image
593
- overlay[binary_mask] = (1 - alpha) * overlay[binary_mask] + alpha * color
594
-
595
- # Draw contours (optional, adds yellow boundaries)
596
- try:
597
- contours = measure.find_contours(binary_mask.astype(np.uint8), 0.5)
598
- for contour in contours:
599
- contour = contour.astype(np.int32)
600
- valid_y = np.clip(contour[:, 0], 0, overlay.shape[0] - 1)
601
- valid_x = np.clip(contour[:, 1], 0, overlay.shape[1] - 1)
602
- overlay[valid_y, valid_x] = [1.0, 1.0, 0.0] # Yellow contour
603
- except:
604
- pass # Skip contours if they fail
605
-
606
- # Convert to uint8
607
- overlay_uint8 = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)
608
- frames.append(Image.fromarray(overlay_uint8))
609
-
610
- if i % 10 == 0 or i == num_frames - 1:
611
- print(f" 📸 Processed frame {i+1}/{num_frames}")
612
-
613
- except Exception as e:
614
- print(f"⚠️ Error processing frame {i}: {e}")
615
- import traceback
616
- traceback.print_exc()
617
- continue
618
-
619
- if not frames:
620
- print("⚠️ No frames were processed successfully")
621
- return valid_tif_files[0]
622
-
623
- # Save as animated GIF
624
- try:
625
- temp_gif = tempfile.NamedTemporaryFile(delete=False, suffix=".gif")
626
- frames[0].save(
627
- temp_gif.name,
628
- save_all=True,
629
- append_images=frames[1:],
630
- duration=200, # 200ms per frame = 5fps
631
- loop=0
632
- )
633
- temp_gif.close() # Close the file handle
634
- print(f"✅ Created tracking visualization GIF: {temp_gif.name}")
635
- print(f" Size: {os.path.getsize(temp_gif.name)} bytes, Frames: {len(frames)}")
636
- return temp_gif.name
637
- except Exception as e:
638
- print(f"⚠️ Failed to create GIF: {e}")
639
- import traceback
640
- traceback.print_exc()
641
- # Return first frame as static image fallback
642
- try:
643
- temp_img = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
644
- frames[0].save(temp_img.name)
645
- temp_img.close()
646
- return temp_img.name
647
- except:
648
- return valid_tif_files[0]
649
-
650
-
651
- def track_video_handler(use_box_choice, first_frame_annot, zip_file_obj):
652
- """
653
- 支持 ZIP 压缩包上传的 Tracking 处理函数 - 支持首帧边界框
654
-
655
- Parameters:
656
- -----------
657
- use_box_choice : str
658
- "Yes" or "No" - 是否使用边界框
659
- first_frame_annot : tuple or None
660
- (image_path, bboxes) from BBoxAnnotator, only used if user annotated first frame
661
- zip_file_obj : File
662
- Uploaded ZIP file containing TIF sequence
663
- """
664
- if zip_file_obj is None:
665
- return None, "⚠️ 请上传包含视频帧的压缩包 (.zip)", None, None
666
-
667
- temp_dir = None
668
- output_temp_dir = None
669
-
670
- try:
671
- # Parse bounding box if provided
672
- box_array = None
673
- if use_box_choice == "Yes" and first_frame_annot is not None:
674
- if isinstance(first_frame_annot, (list, tuple)) and len(first_frame_annot) > 1:
675
- bboxes = first_frame_annot[1]
676
- if bboxes:
677
- box = parse_first_bbox(bboxes)
678
- if box:
679
- xmin, ymin, xmax, ymax = map(int, box)
680
- box_array = [[xmin, ymin, xmax, ymax]]
681
- print(f"📦 使用边界框: {box_array}")
682
-
683
- # Extract input ZIP
684
- temp_dir = tempfile.mkdtemp()
685
- print(f"\n📦 解压到临时目录: {temp_dir}")
686
-
687
- with zipfile.ZipFile(zip_file_obj.name, 'r') as zip_ref:
688
- extracted_count = 0
689
- skipped_count = 0
690
-
691
- for member in zip_ref.namelist():
692
- basename = os.path.basename(member)
693
-
694
- if ('__MACOSX' in member or
695
- basename.startswith('._') or
696
- basename.startswith('.DS_Store') or
697
- member.endswith('/')):
698
- skipped_count += 1
699
- continue
700
-
701
- try:
702
- zip_ref.extract(member, temp_dir)
703
- extracted_count += 1
704
- if basename.lower().endswith(('.tif', '.tiff')):
705
- print(f"📄 Extracted TIFF: {basename}")
706
- except Exception as e:
707
- print(f"⚠️ Failed to extract {member}: {e}")
708
-
709
- print(f"\n📊 提取: {extracted_count} 文件, 跳过: {skipped_count} 文件")
710
-
711
- # Find valid TIFF directory
712
- tif_dir = find_valid_tif_dir(temp_dir)
713
-
714
- if tif_dir is None:
715
- return None, "❌ 未找到有效的TIFF文件", None, None
716
-
717
- # Validate TIFF files
718
- tif_files = sorted(glob(os.path.join(tif_dir, "*.tif")) +
719
- glob(os.path.join(tif_dir, "*.tiff")))
720
- valid_tif_files = [f for f in tif_files
721
- if not os.path.basename(f).startswith('._') and is_valid_tiff(f)]
722
-
723
- if len(valid_tif_files) == 0:
724
- return None, "❌ 没有有效的TIFF文件", None, None
725
-
726
- print(f"📈 使用 {len(valid_tif_files)} 个TIFF文件")
727
-
728
- # Store paths for later visualization
729
- first_frame_path = valid_tif_files[0]
730
-
731
- # Create temporary output directory for CTC results
732
- output_temp_dir = tempfile.mkdtemp()
733
- print(f"💾 CTC结果将保存到: {output_temp_dir}")
734
-
735
- # Run tracking with optional bounding box
736
- result = run_track(
737
- TRACK_MODEL,
738
- video_dir=tif_dir,
739
- box=box_array, # Pass bounding box if specified
740
- device=TRACK_DEVICE,
741
- output_dir=output_temp_dir
742
- )
743
-
744
- if 'error' in result:
745
- return None, f"❌ 跟踪失败: {result['error']}", None, None
746
-
747
- # Create visualization video of tracked objects
748
- print("\n🎬 Creating tracking visualization...")
749
- try:
750
- tracking_video = create_tracking_visualization(
751
- tif_dir,
752
- output_temp_dir,
753
- valid_tif_files
754
- )
755
- except Exception as e:
756
- print(f"⚠️ Failed to create visualization: {e}")
757
- import traceback
758
- traceback.print_exc()
759
- # Fallback to first frame if visualization fails
760
- try:
761
- tracking_video = Image.open(first_frame_path)
762
- except:
763
- tracking_video = None
764
-
765
- # Create downloadable ZIP with results
766
- try:
767
- results_zip = create_ctc_results_zip(output_temp_dir)
768
- except Exception as e:
769
- print(f"⚠️ Failed to create ZIP: {e}")
770
- results_zip = None
771
-
772
- bbox_info = ""
773
- if box_array:
774
- bbox_info = f"\n🔲 使用边界框: [{box_array[0][0]}, {box_array[0][1]}, {box_array[0][2]}, {box_array[0][3]}]"
775
-
776
- result_text = f"""✅ 跟踪完成!
777
-
778
- 🖼️ 处理帧数: {len(valid_tif_files)}{bbox_info}
779
-
780
- 📥 点击下方按钮下载CTC格式结果
781
- 结果包含:
782
- - res_track.txt (CTC格式轨迹数据)
783
- - 其他跟踪相关文件
784
- - README.txt (结果说明)
785
- """
786
-
787
- print(f"\n✅ Tracking完成")
788
-
789
- # Clean up input temp directory (keep output temp for download)
790
- if temp_dir:
791
- try:
792
- shutil.rmtree(temp_dir)
793
- print(f"🗑️ 清理输入临时目录")
794
- except:
795
- pass
796
-
797
- return results_zip, result_text, gr.update(visible=True), tracking_video
798
-
799
- except zipfile.BadZipFile:
800
- return None, "❌ 不是有效的ZIP文件", None, None
801
- except Exception as e:
802
- import traceback
803
- traceback.print_exc()
804
-
805
- # Clean up on error
806
- for d in [temp_dir, output_temp_dir]:
807
- if d:
808
- try:
809
- shutil.rmtree(d)
810
- except:
811
- pass
812
-
813
- return None, f"❌ 跟踪失败: {str(e)}", None, None
814
-
815
-
816
-
817
- # ===== 示例图像 =====
818
- example_images_seg = [f for f in glob("example_imgs/seg/*")]
819
- # ["example_imgs/seg/003_img.png", "example_imgs/seg/1977_Well_F-5_Field_1.png"]
820
- example_images_cnt = [f for f in glob("example_imgs/cnt/*")]
821
- example_tracking_zips = [f for f in glob("example_imgs/tra/*.zip")]
822
-
823
- # ===== Gradio UI =====
824
- with gr.Blocks(
825
- title="Microscopy Analysis Suite",
826
- theme=gr.themes.Soft(),
827
- css="""
828
- .tabs button {
829
- font-size: 20px !important;
830
- font-weight: 600 !important;
831
- padding: 12px 20px !important;
832
- }
833
- """
834
- ) as demo:
835
- gr.Markdown(
836
- """
837
- # 🔬 显微图像分析工具套件
838
-
839
- 支持三种分析模式:
840
- - 🎨 **分割 (Segmentation)**: 实例分割显微镜物体
841
- - 🔢 **计数 (Counting)**: 基于密度图的显微镜物体计数
842
- - 🎬 **跟踪 (Tracking)**: 视频序列中的显微镜物体跟踪
843
- """
844
- )
845
-
846
- # 全局状态
847
- current_query_id = gr.State(str(uuid.uuid4()))
848
- user_uploaded_examples = gr.State(example_images_seg.copy()) # 初始化时包含原始示例
849
-
850
- with gr.Tabs():
851
- # ===== Tab 1: Segmentation =====
852
- with gr.Tab("🎨 分割 (Segmentation)"):
853
- gr.Markdown("## 显微镜物体实例分割")
854
- gr.Markdown(
855
- """
856
- **使用说明:**
857
- 1. 上传图像或选择示例图片(支持多种格式: .png, .jpg, .tif)
858
- 2. (可选) 标注一个目标物体的边界框并选择 "Yes",或直接点击 "运行分割"
859
- 3. 点击 "运行分割"
860
- 4. 查看分割结果,下载原始预测mask (.tif格式);如果需要,点击 "清空重选" 选择新图像运行
861
- 5. 评分并提交反馈以帮助我们改进模型!
862
- """
863
- )
864
-
865
- with gr.Row():
866
- with gr.Column(scale=1):
867
- annotator = BBoxAnnotator(
868
- label="🖼️ 上传图像 (可选标注边界框)",
869
- categories=["cell"]
870
- )
871
-
872
- # 示例图片Gallery
873
- example_gallery = gr.Gallery(
874
- label="📁 示例图片",
875
- columns=len(example_images_seg),
876
- rows=1,
877
- height=120,
878
- object_fit="cover",
879
- show_download_button=False
880
- )
881
-
882
-
883
- with gr.Row():
884
- use_box_radio = gr.Radio(
885
- choices=["Yes", "No"],
886
- value="No",
887
- label="🔲 使用边界框?"
888
- )
889
- with gr.Row():
890
- run_seg_btn = gr.Button("▶️ 运行分割", variant="primary", size="lg")
891
- clear_btn = gr.Button("🔄 清空重选", variant="secondary")
892
-
893
- # 上传示例图片
894
- image_uploader = gr.Image(
895
- label="➕ 上传新示例到Gallery",
896
- type="filepath"
897
- )
898
-
899
-
900
- with gr.Column(scale=2):
901
- seg_output = gr.Image(
902
- type="pil",
903
- label="📸 分割结果",
904
- height=400
905
- )
906
-
907
- # 下载原始预测结果
908
- download_mask_btn = gr.File(
909
- label="📥 下载原始预测 (.tif 格式)",
910
- visible=True,
911
- height=40,
912
- )
913
-
914
- # 满意度评分
915
- score_slider = gr.Slider(
916
- minimum=1,
917
- maximum=5,
918
- step=1,
919
- value=5,
920
- label="🌟 满意度评分 (1-5)"
921
- )
922
-
923
- # 反馈文本框
924
- feedback_box = gr.Textbox(
925
- placeholder="请输入您的反馈意见...",
926
- lines=2,
927
- label="💬 反馈意见"
928
- )
929
-
930
- # 提交按钮
931
- submit_feedback_btn = gr.Button("💾 提交反馈", variant="secondary")
932
-
933
- feedback_status = gr.Textbox(
934
- label="✅ 提交状态",
935
- lines=1,
936
- visible=False
937
- )
938
-
939
- # 绑定事件: 运行分割
940
- run_seg_btn.click(
941
- fn=segment_with_choice,
942
- inputs=[use_box_radio, annotator],
943
- outputs=[seg_output, download_mask_btn]
944
- )
945
-
946
- # 清空按钮事件
947
- clear_btn.click(
948
- fn=lambda: None,
949
- inputs=None,
950
- outputs=annotator
951
- )
952
-
953
- # 初始化Gallery显示
954
- demo.load(
955
- fn=lambda: example_images_seg.copy(),
956
- outputs=example_gallery
957
- )
958
-
959
- # 绑定事件: 上传示例图片
960
- def add_to_gallery(img_path, current_imgs):
961
- if not img_path:
962
- return current_imgs
963
- try:
964
- if img_path not in current_imgs:
965
- current_imgs.append(img_path)
966
- return current_imgs
967
- except:
968
- return current_imgs
969
-
970
- image_uploader.change(
971
- fn=add_to_gallery,
972
- inputs=[image_uploader, user_uploaded_examples],
973
- outputs=user_uploaded_examples
974
- ).then(
975
- fn=lambda imgs: imgs,
976
- inputs=user_uploaded_examples,
977
- outputs=example_gallery
978
- )
979
-
980
- # 绑定事件: 点击Gallery加载
981
- def load_from_gallery(evt: gr.SelectData, all_imgs):
982
- if evt.index is not None and evt.index < len(all_imgs):
983
- return all_imgs[evt.index]
984
- return None
985
-
986
- example_gallery.select(
987
- fn=load_from_gallery,
988
- inputs=user_uploaded_examples,
989
- outputs=annotator
990
- )
991
-
992
- # 绑定事件: 提交反馈
993
- def submit_user_feedback(query_id, score, comment, annot_val):
994
- try:
995
- img_path = annot_val[0] if annot_val and len(annot_val) > 0 else None
996
- bboxes = annot_val[1] if annot_val and len(annot_val) > 1 else []
997
-
998
- save_feedback(
999
- query_id=query_id,
1000
- feedback_type=f"score_{int(score)}",
1001
- feedback_text=comment,
1002
- img_path=img_path,
1003
- bboxes=bboxes
1004
- )
1005
- return "✅ 反馈已提交,感谢您的评价!", gr.update(visible=True)
1006
- except Exception as e:
1007
- return f"❌ 提交失败: {str(e)}", gr.update(visible=True)
1008
-
1009
- submit_feedback_btn.click(
1010
- fn=submit_user_feedback,
1011
- inputs=[current_query_id, score_slider, feedback_box, annotator],
1012
- outputs=[feedback_status, feedback_status]
1013
- )
1014
-
1015
- # ===== Tab 2: Counting =====
1016
- with gr.Tab("🔢 计数 (Counting)"):
1017
- gr.Markdown("## 显微镜物体计数分析")
1018
- gr.Markdown(
1019
- """
1020
- **使用说明:**
1021
- 1. 上传图像或选择示例图片(支持多种格式: .png, .jpg, .tif)
1022
- 2. (可选) 标注边界框并选择 "Yes",或直接点击 "运行计数"
1023
- 3. 点击 "运行计数"
1024
- 4. 查看密度图,下载原始预测 (.npy格式);如果需要,点击 "清空重选" 选择新图像运行
1025
- 5. 评分并提交反馈以帮助我们改进模型!
1026
- """
1027
- )
1028
-
1029
- with gr.Row():
1030
- with gr.Column(scale=1):
1031
- count_annotator = BBoxAnnotator(
1032
- label="🖼️ 上传图像 (可选标注边界框)",
1033
- categories=["cell"]
1034
- )
1035
-
1036
- # Example gallery with "add" functionality
1037
- with gr.Row():
1038
- count_example_gallery = gr.Gallery(
1039
- label="📁 示例图片",
1040
- columns=len(example_images_cnt),
1041
- rows=1,
1042
- object_fit="cover",
1043
- height=120,
1044
- value=example_images_cnt.copy(), # Initialize with examples
1045
- show_download_button=False
1046
- )
1047
-
1048
-
1049
- with gr.Row():
1050
- count_use_box_radio = gr.Radio(
1051
- choices=["Yes", "No"],
1052
- value="No",
1053
- label="🔲 使用边界框?"
1054
- )
1055
-
1056
- with gr.Row():
1057
- count_btn = gr.Button("▶️ 运行计数", variant="primary", size="lg")
1058
- clear_btn = gr.Button("🔄 清空重选", variant="secondary")
1059
-
1060
- # Add button to upload new examples
1061
- with gr.Row():
1062
- count_image_uploader = gr.File(
1063
- label="➕ 添加示例图片",
1064
- file_types=["image"],
1065
- type="filepath"
1066
- )
1067
-
1068
-
1069
- with gr.Column(scale=2):
1070
- count_output = gr.Image(
1071
- label="📸 密度图",
1072
- type="filepath",
1073
- height=400
1074
- )
1075
- count_status = gr.Textbox(
1076
- label="📊 统计信息",
1077
- lines=2
1078
- )
1079
- download_density_btn = gr.File(
1080
- label="📥 下载原始预测 (.npy 格式)",
1081
- visible=True
1082
- )
1083
-
1084
- # 满意度评分
1085
- score_slider = gr.Slider(
1086
- minimum=1,
1087
- maximum=5,
1088
- step=1,
1089
- value=5,
1090
- label="🌟 满意度评分 (1-5)"
1091
- )
1092
-
1093
- # 反馈文本框
1094
- feedback_box = gr.Textbox(
1095
- placeholder="请输入您的反馈意见...",
1096
- lines=2,
1097
- label="💬 反馈意见"
1098
- )
1099
-
1100
- # 提交按钮
1101
- submit_feedback_btn = gr.Button("💾 提交反馈", variant="secondary")
1102
-
1103
- feedback_status = gr.Textbox(
1104
- label="✅ 提交状态",
1105
- lines=1,
1106
- visible=False
1107
- )
1108
-
1109
- # State for managing gallery images
1110
- count_user_examples = gr.State(example_images_cnt.copy())
1111
-
1112
- # Function to add image to gallery
1113
- def add_to_count_gallery(new_img_file, current_imgs):
1114
- """Add uploaded image to gallery"""
1115
- if new_img_file is None:
1116
- return current_imgs, current_imgs
1117
-
1118
- try:
1119
- # Add new image path to list
1120
- if new_img_file not in current_imgs:
1121
- current_imgs.append(new_img_file)
1122
- print(f"✅ Added image to gallery: {new_img_file}")
1123
- except Exception as e:
1124
- print(f"⚠️ Failed to add image: {e}")
1125
-
1126
- return current_imgs, current_imgs
1127
-
1128
- # When user uploads a new image file
1129
- count_image_uploader.upload(
1130
- fn=add_to_count_gallery,
1131
- inputs=[count_image_uploader, count_user_examples],
1132
- outputs=[count_user_examples, count_example_gallery]
1133
- )
1134
-
1135
- # When user selects from gallery, load into annotator
1136
- def load_from_count_gallery(evt: gr.SelectData, all_imgs):
1137
- """Load selected image from gallery into annotator"""
1138
- if evt.index is not None and evt.index < len(all_imgs):
1139
- selected_img = all_imgs[evt.index]
1140
- print(f"📸 Loading image from gallery: {selected_img}")
1141
- return selected_img
1142
- return None
1143
-
1144
- count_example_gallery.select(
1145
- fn=load_from_count_gallery,
1146
- inputs=count_user_examples,
1147
- outputs=count_annotator
1148
- )
1149
-
1150
- # Run counting
1151
- count_btn.click(
1152
- fn=count_cells_handler,
1153
- inputs=[count_use_box_radio, count_annotator],
1154
- outputs=[count_output, download_density_btn, count_status]
1155
- )
1156
-
1157
- # 清空按钮事件
1158
- clear_btn.click(
1159
- fn=lambda: None,
1160
- inputs=None,
1161
- outputs=count_annotator
1162
- )
1163
-
1164
- # 绑定事件: 提交反馈
1165
- def submit_user_feedback(query_id, score, comment, annot_val):
1166
- try:
1167
- img_path = annot_val[0] if annot_val and len(annot_val) > 0 else None
1168
- bboxes = annot_val[1] if annot_val and len(annot_val) > 1 else []
1169
-
1170
- save_feedback(
1171
- query_id=query_id,
1172
- feedback_type=f"score_{int(score)}",
1173
- feedback_text=comment,
1174
- img_path=img_path,
1175
- bboxes=bboxes
1176
- )
1177
- return "✅ 反馈已提交,感谢您的评价!", gr.update(visible=True)
1178
- except Exception as e:
1179
- return f"❌ 提交失败: {str(e)}", gr.update(visible=True)
1180
-
1181
- submit_feedback_btn.click(
1182
- fn=submit_user_feedback,
1183
- inputs=[current_query_id, score_slider, feedback_box, annotator],
1184
- outputs=[feedback_status, feedback_status]
1185
- )
1186
-
1187
- # ===== Tab 3: Tracking =====
1188
- with gr.Tab("🎬 跟踪 (Tracking)"):
1189
- gr.Markdown("## 显微镜物体视频跟踪 - 支持 ZIP 压缩包上传")
1190
- gr.Markdown(
1191
- """
1192
- **使用说明:**
1193
- 1. 上传ZIP文件或从示例库选择,ZIP内应包含按时间顺序命名的TIF图像序列 (如: t000.tif, t001.tif...)
1194
- 2. (可选) 在首帧上���注边界框并选择 "Yes"
1195
- 3. 点击 "运行跟踪"
1196
- 4. 下载CTC格式结果;如果需要,点击 "清空重选" 选择新ZIP文件运行
1197
- 5. 评分并提交反馈以帮助我们改进模型!
1198
-
1199
- """
1200
- )
1201
-
1202
- with gr.Row():
1203
- with gr.Column(scale=1):
1204
- track_zip_upload = gr.File(
1205
- label="📦 上传视频帧 ZIP 文件",
1206
- file_types=[".zip"]
1207
- )
1208
-
1209
- # First frame annotation for bounding box
1210
- track_first_frame_annotator = BBoxAnnotator(
1211
- label="🖼️ 首帧边界框标注 (可选)",
1212
- categories=["cell"],
1213
- visible=False # Hidden initially
1214
- )
1215
-
1216
- # Example ZIP gallery
1217
- track_example_gallery = gr.Gallery(
1218
- label="📁 示例视频库 (点击选择)",
1219
- columns=10,
1220
- rows=1,
1221
- height=120,
1222
- object_fit="contain",
1223
- show_download_button=False
1224
- )
1225
-
1226
- with gr.Row():
1227
- track_use_box_radio = gr.Radio(
1228
- choices=["Yes", "No"],
1229
- value="No",
1230
- label="🔲 使用边界框?"
1231
- )
1232
-
1233
- with gr.Row():
1234
- track_btn = gr.Button("▶️ 运行跟踪", variant="primary", size="lg")
1235
- clear_btn = gr.Button("🔄 清空重选", variant="secondary")
1236
-
1237
- # Add to gallery button
1238
- track_gallery_upload = gr.File(
1239
- label="➕ 添加ZIP到示例库",
1240
- file_types=[".zip"],
1241
- type="filepath"
1242
- )
1243
-
1244
- with gr.Column(scale=2):
1245
- track_first_frame_preview = gr.Image(
1246
- label="📸 跟踪可视化 (动画预览)",
1247
- type="filepath",
1248
- height=400,
1249
- interactive=False
1250
- )
1251
-
1252
- track_output = gr.Textbox(
1253
- label="📊 跟踪信息",
1254
- lines=8,
1255
- interactive=False
1256
- )
1257
-
1258
- track_download = gr.File(
1259
- label="📥 下载跟踪结果 (CTC格式)",
1260
- visible=False
1261
- )
1262
-
1263
- # 满意度评分
1264
- score_slider = gr.Slider(
1265
- minimum=1,
1266
- maximum=5,
1267
- step=1,
1268
- value=5,
1269
- label="🌟 满意度评分 (1-5)"
1270
- )
1271
-
1272
- # 反馈文本框
1273
- feedback_box = gr.Textbox(
1274
- placeholder="请输入您的反馈意见...",
1275
- lines=2,
1276
- label="💬 反馈意见"
1277
- )
1278
-
1279
- # 提交按钮
1280
- submit_feedback_btn = gr.Button("💾 提交反馈", variant="secondary")
1281
-
1282
- feedback_status = gr.Textbox(
1283
- label="✅ 提交状态",
1284
- lines=1,
1285
- visible=False
1286
- )
1287
-
1288
- # State for tracking examples
1289
- track_user_examples = gr.State(example_tracking_zips.copy())
1290
-
1291
- # Function to get preview image from ZIP
1292
- def get_zip_preview(zip_path):
1293
- """Extract first frame from ZIP for gallery preview"""
1294
- try:
1295
- temp_dir = tempfile.mkdtemp()
1296
- with zipfile.ZipFile(zip_path, 'r') as zip_ref:
1297
- for member in zip_ref.namelist():
1298
- basename = os.path.basename(member)
1299
- if ('__MACOSX' not in member and
1300
- not basename.startswith('._') and
1301
- basename.lower().endswith(('.tif', '.tiff', '.png', '.jpg'))):
1302
- zip_ref.extract(member, temp_dir)
1303
- extracted_path = os.path.join(temp_dir, member)
1304
-
1305
- # Load and normalize for preview
1306
- import tifffile
1307
- import numpy as np
1308
-
1309
- img_np = tifffile.imread(extracted_path)
1310
- if img_np.dtype == np.uint16:
1311
- img_min, img_max = img_np.min(), img_np.max()
1312
- if img_max > img_min:
1313
- img_np = ((img_np.astype(np.float32) - img_min) / (img_max - img_min) * 255).astype(np.uint8)
1314
-
1315
- if img_np.ndim == 2:
1316
- img_np = np.stack([img_np]*3, axis=-1)
1317
-
1318
- # Save preview
1319
- preview_path = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
1320
- Image.fromarray(img_np).save(preview_path.name)
1321
- return preview_path.name
1322
- except:
1323
- pass
1324
- return None
1325
-
1326
- # Initialize gallery with previews
1327
- def init_tracking_gallery():
1328
- """Create preview images for ZIP examples"""
1329
- previews = []
1330
- for zip_path in example_tracking_zips:
1331
- if os.path.exists(zip_path):
1332
- preview = get_zip_preview(zip_path)
1333
- if preview:
1334
- previews.append(preview)
1335
- return previews
1336
-
1337
- # Load gallery on startup
1338
- demo.load(
1339
- fn=init_tracking_gallery,
1340
- outputs=track_example_gallery
1341
- )
1342
-
1343
- # Add ZIP to gallery
1344
- def add_zip_to_gallery(zip_path, current_zips):
1345
- if not zip_path:
1346
- return current_zips, track_example_gallery
1347
- try:
1348
- if zip_path not in current_zips:
1349
- current_zips.append(zip_path)
1350
- print(f"✅ Added ZIP to gallery: {zip_path}")
1351
- # Regenerate previews
1352
- previews = []
1353
- for zp in current_zips:
1354
- preview = get_zip_preview(zp)
1355
- if preview:
1356
- previews.append(preview)
1357
- return current_zips, previews
1358
- except Exception as e:
1359
- print(f"⚠️ Error: {e}")
1360
- return current_zips, []
1361
-
1362
- track_gallery_upload.upload(
1363
- fn=add_zip_to_gallery,
1364
- inputs=[track_gallery_upload, track_user_examples],
1365
- outputs=[track_user_examples, track_example_gallery]
1366
- )
1367
-
1368
- # Select ZIP from gallery
1369
- def load_zip_from_gallery(evt: gr.SelectData, all_zips):
1370
- if evt.index is not None and evt.index < len(all_zips):
1371
- selected_zip = all_zips[evt.index]
1372
- print(f"📁 Selected ZIP from gallery: {selected_zip}")
1373
- return selected_zip
1374
- return None
1375
-
1376
- track_example_gallery.select(
1377
- fn=load_zip_from_gallery,
1378
- inputs=track_user_examples,
1379
- outputs=track_zip_upload
1380
- )
1381
-
1382
- # Load first frame when ZIP is uploaded
1383
- def load_first_frame_for_annotation(zip_file_obj):
1384
- '''Load and normalize first frame from ZIP for annotation'''
1385
- if zip_file_obj is None:
1386
- return None, gr.update(visible=False)
1387
-
1388
- import tifffile
1389
- import numpy as np
1390
-
1391
- try:
1392
- temp_dir = tempfile.mkdtemp()
1393
- with zipfile.ZipFile(zip_file_obj.name, 'r') as zip_ref:
1394
- for member in zip_ref.namelist():
1395
- basename = os.path.basename(member)
1396
- if ('__MACOSX' not in member and
1397
- not basename.startswith('._') and
1398
- basename.lower().endswith(('.tif', '.tiff'))):
1399
- zip_ref.extract(member, temp_dir)
1400
-
1401
- tif_dir = find_valid_tif_dir(temp_dir)
1402
- if tif_dir:
1403
- first_frame = extract_first_frame(tif_dir)
1404
- if first_frame:
1405
- # Load and normalize the first frame
1406
- try:
1407
- img_np = tifffile.imread(first_frame)
1408
-
1409
- # Normalize to [0, 255] uint8 range for display
1410
- if img_np.dtype == np.uint8:
1411
- pass # Already uint8
1412
- elif img_np.dtype == np.uint16:
1413
- # Normalize uint16 using actual min/max
1414
- img_min, img_max = img_np.min(), img_np.max()
1415
- if img_max > img_min:
1416
- img_np = ((img_np.astype(np.float32) - img_min) / (img_max - img_min) * 255).astype(np.uint8)
1417
- else:
1418
- img_np = (img_np.astype(np.float32) / 65535.0 * 255).astype(np.uint8)
1419
- else:
1420
- # Float or other types
1421
- img_np = img_np.astype(np.float32)
1422
- img_min, img_max = img_np.min(), img_np.max()
1423
- if img_max > img_min:
1424
- img_np = ((img_np - img_min) / (img_max - img_min) * 255).astype(np.uint8)
1425
- else:
1426
- img_np = np.clip(img_np * 255, 0, 255).astype(np.uint8)
1427
-
1428
- # Convert to RGB if grayscale
1429
- if img_np.ndim == 2:
1430
- img_np = np.stack([img_np]*3, axis=-1)
1431
- elif img_np.ndim == 3 and img_np.shape[2] > 3:
1432
- img_np = img_np[:, :, :3]
1433
-
1434
- # Save normalized image to temp file
1435
- temp_img = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
1436
- Image.fromarray(img_np).save(temp_img.name)
1437
-
1438
- print(f"✅ Loaded and normalized first frame: {first_frame}")
1439
- print(f" Original dtype: {tifffile.imread(first_frame).dtype}")
1440
- print(f" Normalized to uint8 RGB for annotation")
1441
-
1442
- return temp_img.name, gr.update(visible=True)
1443
- except Exception as e:
1444
- print(f"⚠️ Error normalizing first frame: {e}")
1445
- import traceback
1446
- traceback.print_exc()
1447
- # Fallback to original file
1448
- return first_frame, gr.update(visible=True)
1449
- except Exception as e:
1450
- print(f"⚠️ Error loading first frame: {e}")
1451
- import traceback
1452
- traceback.print_exc()
1453
- return None, gr.update(visible=False)
1454
-
1455
- # Load first frame when ZIP is uploaded
1456
- track_zip_upload.change(
1457
- fn=load_first_frame_for_annotation,
1458
- inputs=track_zip_upload,
1459
- outputs=[track_first_frame_annotator, track_first_frame_annotator]
1460
- )
1461
-
1462
- # Run tracking
1463
- track_btn.click(
1464
- fn=track_video_handler,
1465
- inputs=[track_use_box_radio, track_first_frame_annotator, track_zip_upload],
1466
- outputs=[track_download, track_output, track_download, track_first_frame_preview]
1467
- )
1468
-
1469
- # 清空按钮事件
1470
- clear_btn.click(
1471
- fn=lambda: None,
1472
- inputs=None,
1473
- outputs=track_first_frame_annotator
1474
- )
1475
-
1476
- # 绑定事件: 提交反馈
1477
- def submit_user_feedback(query_id, score, comment, annot_val):
1478
- try:
1479
- img_path = annot_val[0] if annot_val and len(annot_val) > 0 else None
1480
- bboxes = annot_val[1] if annot_val and len(annot_val) > 1 else []
1481
-
1482
- save_feedback(
1483
- query_id=query_id,
1484
- feedback_type=f"score_{int(score)}",
1485
- feedback_text=comment,
1486
- img_path=img_path,
1487
- bboxes=bboxes
1488
- )
1489
- return "✅ 反馈已提交,感谢您的评价!", gr.update(visible=True)
1490
- except Exception as e:
1491
- return f"❌ 提交失败: {str(e)}", gr.update(visible=True)
1492
-
1493
- submit_feedback_btn.click(
1494
- fn=submit_user_feedback,
1495
- inputs=[current_query_id, score_slider, feedback_box, annotator],
1496
- outputs=[feedback_status, feedback_status]
1497
- )
1498
-
1499
- gr.Markdown(
1500
- """
1501
- ---
1502
- ### 💡 技术说明
1503
-
1504
- **MicroscopyMatching** - 基于 Stable Diffusion 的显微图像分析工具套件
1505
- """
1506
- )
1507
-
1508
- if __name__ == "__main__":
1509
- demo.queue().launch(
1510
- server_name="0.0.0.0",
1511
- server_port=7862,
1512
- share=False,
1513
- ssr_mode=False,
1514
- show_error=True,
1515
- )