| import gradio as gr |
| from gradio_bbox_annotator import BBoxAnnotator |
| from PIL import Image |
| import numpy as np |
| import torch |
| import os |
| import shutil |
| import time |
| import json |
| import uuid |
| from pathlib import Path |
| import tempfile |
| import zipfile |
| from skimage import measure |
| from matplotlib import cm |
| from glob import glob |
| from natsort import natsorted |
| from huggingface_hub import HfApi, upload_file |
| |
|
|
| from inference_seg import load_model as load_seg_model, run as run_seg |
| from inference_count import load_model as load_count_model, run as run_count |
| from inference_track import load_model as load_track_model, run as run_track |
|
|
# Hugging Face credentials / target dataset for user feedback uploads.
HF_TOKEN = os.getenv("HF_TOKEN")
DATASET_REPO = "phoebe777777/celltool_feedback"


print("===== clearing cache =====")

# Remove any Gradio cache left over from previous runs before the app starts.
cache_path = os.path.expanduser("~/.cache/huggingface/gradio")
if os.path.exists(cache_path):
    try:
        shutil.rmtree(cache_path)
        print("✅ Deleted ~/.cache/huggingface/gradio")
    except OSError as e:
        # Best-effort cleanup: a failure must not stop the app from starting,
        # but it should be visible in the logs instead of silently swallowed.
        print(f"⚠️ Failed to delete {cache_path}: {e}")


# Model handles and their devices; replaced by load_all_models().
SEG_MODEL = None
SEG_DEVICE = torch.device("cpu")

COUNT_MODEL = None
COUNT_DEVICE = torch.device("cpu")

TRACK_MODEL = None
TRACK_DEVICE = torch.device("cpu")
|
|
def load_all_models():
    """Load the segmentation, counting and tracking models into module globals.

    Each loader is called with ``use_box=False`` and its ``(model, device)``
    result replaces the matching ``*_MODEL`` / ``*_DEVICE`` global. A banner
    is printed before each load and once all loads have finished.
    """
    global SEG_MODEL, SEG_DEVICE
    global COUNT_MODEL, COUNT_DEVICE
    global TRACK_MODEL, TRACK_DEVICE

    def banner(title):
        rule = "=" * 60
        print("\n" + rule)
        print(title)
        print(rule)

    banner("📦 Loading Segmentation Model")
    SEG_MODEL, SEG_DEVICE = load_seg_model(use_box=False)

    banner("📦 Loading Counting Model")
    COUNT_MODEL, COUNT_DEVICE = load_count_model(use_box=False)

    banner("📦 Loading Tracking Model")
    TRACK_MODEL, TRACK_DEVICE = load_track_model(use_box=False)

    banner("✅ All Models Loaded Successfully")
|
|
# Load every model once, at import time, before the UI below is built.
load_all_models()

# Local fallback store for user feedback records (written by save_feedback).
DATASET_DIR = Path("solver_cache")
DATASET_DIR.mkdir(exist_ok=True, parents=True)
|
|
def save_feedback_to_hf(query_id, feedback_type, feedback_text=None, img_path=None, bboxes=None):
    """Save a feedback record to the Hugging Face dataset ``DATASET_REPO``.

    The record is serialized to a private temporary JSON file, uploaded as
    ``data/feedback_<query_id>_<unix_time>.json``, and the temporary file is
    always removed afterwards (also when the upload fails).

    Falls back to ``save_feedback`` (local storage) when ``HF_TOKEN`` is not
    set or when serialization/upload raises.

    Parameters
    ----------
    query_id : str
        Identifier of the user session / query.
    feedback_type : str
        Short label for the feedback, e.g. ``"score_5"``.
    feedback_text : str, optional
        Free-form comment.
    img_path : str, optional
        Path of the image the feedback refers to.
    bboxes : optional
        Bounding boxes; stored via ``str()`` in the uploaded record.
    """
    if not HF_TOKEN:
        print("⚠️ No HF_TOKEN found, using local storage")
        save_feedback(query_id, feedback_type, feedback_text, img_path, bboxes)
        return

    now = time.time()
    feedback_data = {
        "query_id": query_id,
        "feedback_type": feedback_type,
        "feedback_text": feedback_text,
        "image_path": img_path,
        "bboxes": str(bboxes),
        "datetime": time.strftime("%Y-%m-%d %H:%M:%S"),
        "timestamp": now,
    }
    # Name inside the dataset repo. The local copy lives in a private temp file
    # so concurrent requests never collide and nothing is left in the CWD.
    filename = f"feedback_{query_id}_{int(now)}.json"
    local_path = None

    try:
        fd, local_path = tempfile.mkstemp(suffix=".json")
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(feedback_data, f, indent=2, ensure_ascii=False)

        HfApi().upload_file(
            path_or_fileobj=local_path,
            path_in_repo=f"data/{filename}",
            repo_id=DATASET_REPO,
            repo_type="dataset",
            token=HF_TOKEN,
        )
        print(f"✅ Feedback saved to HF Dataset: {DATASET_REPO}")
    except Exception as e:
        print(f"⚠️ Failed to save to HF Dataset: {e}")
        save_feedback(query_id, feedback_type, feedback_text, img_path, bboxes)
    finally:
        # Previously the file was only removed on success, leaking one file
        # per failed upload.
        if local_path is not None and os.path.exists(local_path):
            try:
                os.remove(local_path)
            except OSError as e:
                print(f"⚠️ Failed to remove temporary feedback file {local_path}: {e}")
|
|
|
|
def save_feedback(query_id, feedback_type, feedback_text=None, img_path=None, bboxes=None):
    """Append a feedback record to ``DATASET_DIR/<query_id>/feedback.json``.

    The file holds a JSON list of records. A file that contains a single JSON
    object is wrapped into a one-element list before the new record is
    appended. Missing parent directories are created.

    Parameters
    ----------
    query_id : str
        Identifier of the user session; used as the sub-directory name.
    feedback_type : str
        Short label for the feedback, e.g. ``"score_5"``.
    feedback_text : str, optional
        Free-form comment.
    img_path : str, optional
        Path of the image the feedback refers to.
    bboxes : optional
        Bounding boxes; must be JSON-serializable.
    """
    record = {
        "query_id": query_id,
        "feedback_type": feedback_type,
        "feedback_text": feedback_text,
        "image": img_path,
        "bboxes": bboxes,
        "datetime": time.strftime("%Y%m%d_%H%M%S"),
    }
    feedback_file = DATASET_DIR / query_id / "feedback.json"
    feedback_file.parent.mkdir(parents=True, exist_ok=True)

    records = []
    if feedback_file.exists():
        # Explicit UTF-8 on both read and write: records are dumped with
        # ensure_ascii=False, so the platform default encoding (e.g. cp1252 on
        # Windows) would fail on non-ASCII feedback text.
        with feedback_file.open("r", encoding="utf-8") as f:
            existing = json.load(f)
        records = existing if isinstance(existing, list) else [existing]
    records.append(record)

    with feedback_file.open("w", encoding="utf-8") as f:
        json.dump(records, f, indent=4, ensure_ascii=False)
|
|
def parse_first_bbox(bboxes):
    """Return the first annotation box as an ``(x1, y1, x2, y2)`` float tuple.

    A dict entry is read as ``{"x", "y", "width", "height"}`` (missing keys
    count as 0) and converted to corner form. A list/tuple entry with at least
    four items is taken as corners directly (extra items ignored). Returns
    ``None`` for empty input or when the first entry is in neither format.
    """
    if not bboxes:
        return None
    first = bboxes[0]

    if isinstance(first, dict):
        left = float(first.get("x", 0))
        top = float(first.get("y", 0))
        right = left + float(first.get("width", 0))
        bottom = top + float(first.get("height", 0))
        return left, top, right, bottom

    if isinstance(first, (list, tuple)) and len(first) >= 4:
        return tuple(float(value) for value in first[:4])

    return None
|
|
def parse_bboxes(bboxes):
    """Convert every annotation box to an ``[x1, y1, x2, y2]`` float list.

    Dict entries are read as ``{"x", "y", "width", "height"}`` (missing keys
    count as 0); list/tuple entries with at least four items are taken as
    corners directly. Entries in neither format are silently dropped.
    Returns ``None`` for empty input, otherwise a (possibly empty) list.
    """
    if not bboxes:
        return None

    def _corners(entry):
        if isinstance(entry, dict):
            x0 = float(entry.get("x", 0))
            y0 = float(entry.get("y", 0))
            return [x0, y0, x0 + float(entry.get("width", 0)), y0 + float(entry.get("height", 0))]
        if isinstance(entry, (list, tuple)) and len(entry) >= 4:
            return [float(value) for value in entry[:4]]
        return None

    converted = (_corners(entry) for entry in bboxes)
    return [corners for corners in converted if corners is not None]
|
|
def colorize_mask(mask: np.ndarray, num_colors: int = 512) -> np.ndarray:
    """Map an array of instance IDs to an RGB ``uint8`` image.

    Palette entry 0 is black. Entry ``k`` (``1 <= k < num_colors``) uses hue
    ``k / num_colors`` at saturation 1.0 and value 0.95. IDs are wrapped with
    ``mask % num_colors``, so IDs congruent modulo ``num_colors`` share a
    color. The result has shape ``mask.shape + (3,)``.
    """
    def _hue_to_rgb(hue, sat, val):
        scaled = hue * 6.0
        sector = int(scaled)
        frac = scaled - sector
        low = val * (1 - sat)
        falling = val * (1 - frac * sat)
        rising = val * (1 - (1 - frac) * sat)
        channels = (
            (val, rising, low),
            (falling, val, low),
            (low, val, rising),
            (low, falling, val),
            (rising, low, val),
            (val, low, falling),
        )[sector % 6]
        return tuple(int(c * 255) for c in channels)

    palette = [(0, 0, 0)] + [
        _hue_to_rgb(idx / float(num_colors), 1.0, 0.95) for idx in range(1, num_colors)
    ]
    lut = np.array(palette, dtype=np.uint8)
    return lut[mask % num_colors]
|
|
|
|
def render_seg_overlay(img_np, inst_mask, overlay_alpha):
    """Blend per-instance colors and yellow outlines onto an image.

    Parameters
    ----------
    img_np : np.ndarray or None
        RGB float image; it is multiplied by 255 on output, so values are
        expected in [0, 1].
    inst_mask : np.ndarray or None
        Instance-ID mask aligned with *img_np*; ID 0 is background and is
        left untouched.
    overlay_alpha : float
        Fill opacity, clipped to [0, 1].

    Returns
    -------
    PIL.Image.Image or None
        ``None`` when either input is ``None``.
    """
    if img_np is None or inst_mask is None:
        return None

    blend = float(np.clip(overlay_alpha, 0.0, 1.0))
    canvas = img_np.copy()
    last_row = canvas.shape[0] - 1
    last_col = canvas.shape[1] - 1
    outline = [1.0, 1.0, 0.0]

    labels = np.unique(inst_mask)
    for label in labels[labels != 0]:
        region = (inst_mask == label).astype(np.uint8)
        inside = region == 1
        tint = get_well_spaced_color(label)
        canvas[inside] = (1 - blend) * canvas[inside] + blend * tint

        # Outline each instance; contour points are clipped into the image.
        for contour in measure.find_contours(region, 0.5):
            pts = contour.astype(np.int32)
            rows = np.clip(pts[:, 0], 0, last_row)
            cols = np.clip(pts[:, 1], 0, last_col)
            canvas[rows, cols] = outline

    as_bytes = np.clip(canvas * 255.0, 0, 255).astype(np.uint8)
    return Image.fromarray(as_bytes)
|
|
|
|
def render_count_overlay(img_np, density_normalized, overlay_alpha):
    """Blend a jet-colored density heatmap over an image.

    Parameters
    ----------
    img_np : np.ndarray or None
        RGB float image; multiplied by 255 on output, so values are expected
        in [0, 1].
    density_normalized : np.ndarray or None
        Density map with the same height/width as *img_np*. Only pixels whose
        value is above 0.01 are blended; all others keep the image color.
    overlay_alpha : float
        Heatmap opacity, clipped to [0, 1].

    Returns
    -------
    PIL.Image.Image or None
        ``None`` when either input is ``None``.
    """
    if img_np is None or density_normalized is None:
        return None

    alpha = float(np.clip(overlay_alpha, 0.0, 1.0))
    try:
        # `matplotlib.cm.get_cmap` was deprecated in 3.7 and removed in 3.9;
        # the `matplotlib.colormaps` registry (3.5+) is the supported lookup.
        from matplotlib import colormaps
        cmap = colormaps["jet"]
    except ImportError:
        # Very old Matplotlib without the registry.
        cmap = cm.get_cmap("jet")
    density_colored = cmap(density_normalized)[:, :, :3]

    overlay = img_np.copy()
    threshold = 0.01
    significant_mask = density_normalized > threshold
    overlay[significant_mask] = (
        (1 - alpha) * overlay[significant_mask] + alpha * density_colored[significant_mask]
    )
    overlay = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)
    return Image.fromarray(overlay)
|
|
|
|
def update_seg_overlay_alpha(overlay_alpha, seg_vis_cache):
    """Re-render the segmentation overlay at a new opacity.

    Reuses the ``img_np`` / ``inst_mask`` cached by the last segmentation run,
    so no inference is repeated. Returns ``None`` when the cache is empty.
    """
    if seg_vis_cache:
        return render_seg_overlay(
            seg_vis_cache.get("img_np"),
            seg_vis_cache.get("inst_mask"),
            overlay_alpha,
        )
    return None
|
|
|
|
def update_count_overlay_alpha(overlay_alpha, count_vis_cache):
    """Re-render the counting heatmap overlay at a new opacity.

    Reuses the ``img_np`` / ``density_normalized`` cached by the last
    counting run, so no inference is repeated. Returns ``None`` when the
    cache is empty.
    """
    if count_vis_cache:
        return render_count_overlay(
            count_vis_cache.get("img_np"),
            count_vis_cache.get("density_normalized"),
            overlay_alpha,
        )
    return None
|
|
|
|
def update_tracking_overlay_alpha(overlay_alpha, track_vis_cache):
    """Rebuild the tracking visualization at a new opacity.

    Uses the input frames and tracking outputs recorded in *track_vis_cache*
    by the last tracking run. Returns ``None`` when the cache is empty or is
    missing ``tif_dir`` / ``output_dir`` / ``valid_tif_files``, or when the
    visualization cannot be created (the error is printed).
    """
    if not track_vis_cache:
        return None

    tif_dir, output_dir, valid_tif_files = (
        track_vis_cache.get(key) for key in ("tif_dir", "output_dir", "valid_tif_files")
    )
    if not (tif_dir and output_dir and valid_tif_files):
        return None

    try:
        return create_tracking_visualization(
            tif_dir=tif_dir,
            output_dir=output_dir,
            valid_tif_files=valid_tif_files,
            overlay_alpha=overlay_alpha,
        )
    except Exception as e:
        print(f"⚠️ Failed to update tracking opacity: {e}")
        return None
|
|
|
|
def cleanup_tracking_cache(track_vis_cache):
    """Delete the temporary directories recorded by a previous tracking run.

    Removes the directories stored under ``"input_temp_dir"`` and
    ``"output_dir"`` in *track_vis_cache* if they exist. This is best-effort:
    a directory that cannot be removed is reported and skipped rather than
    aborting the new run. Empty/``None`` caches are ignored.
    """
    if not track_vis_cache:
        return
    for key in ("input_temp_dir", "output_dir"):
        path = track_vis_cache.get(key)
        if path and os.path.isdir(path):
            try:
                shutil.rmtree(path)
            except OSError as e:
                # Previously swallowed silently; keep best-effort but make the
                # leftover directory visible in the logs.
                print(f"⚠️ Failed to remove cached {key} {path}: {e}")
|
|
|
|
| |
def segment_with_choice(use_box_choice, annot_value, overlay_alpha):
    """Run instance segmentation on the annotated image and build an overlay.

    Parameters
    ----------
    use_box_choice : str
        ``"Yes"`` to pass the drawn bounding boxes to the model; any other
        value ignores them.
    annot_value : tuple or list or None
        ``(image_path, bboxes)`` from the BBoxAnnotator component.
    overlay_alpha : float
        Overlay opacity, clipped to [0, 1] by ``render_seg_overlay``.

    Returns
    -------
    tuple
        ``(overlay_image, mask_tif_path, seg_vis_cache)`` on success.
        ``(None, None, {})`` on missing input, inference or image errors.
        When the mask has no instances, a solid red placeholder image with
        ``None`` and ``{}``.
    """
    if annot_value is None or len(annot_value) < 1:
        print("❌ No annotation input")
        return None, None, {}

    img_path = annot_value[0]
    bboxes = annot_value[1] if len(annot_value) > 1 else []

    print(f"🖼️ Image path: {img_path}")
    box_array = None
    if use_box_choice == "Yes" and bboxes:
        box = parse_bboxes(bboxes)
        if box:
            box_array = box
            print(f"📦 Using bounding boxes: {box_array}")

    try:
        mask = run_seg(SEG_MODEL, img_path, box=box_array, device=SEG_DEVICE)
        print("📏 mask shape:", mask.shape, "dtype:", mask.dtype)
    except Exception as e:
        print(f"❌ Inference failed: {str(e)}")
        return None, None, {}

    try:
        img = Image.open(img_path)
        print("📷 Image mode:", img.mode, "size:", img.size)
    except Exception as e:
        print(f"❌ Failed to open image: {e}")
        return None, None, {}

    try:
        # Close the source file once the resized RGB copy exists.
        with img:
            img_rgb = img.convert("RGB").resize(mask.shape[::-1], resample=Image.BILINEAR)
        # convert("RGB") always yields 8-bit channels, so always scale to [0, 1].
        # The old `max() > 1.5` heuristic left near-black images unscaled.
        img_np = np.array(img_rgb, dtype=np.float32) / 255.0
    except Exception as e:
        print(f"❌ Error in image conversion/resizing: {e}")
        return None, None, {}

    inst_mask = np.asarray(mask).astype(np.int32)
    if not inst_mask.any():
        print("⚠️ No instance found, returning dummy red image")
        return Image.new("RGB", mask.shape[::-1], (255, 0, 0)), None, {}

    # Write the raw instance-ID mask only once we know it will be returned, so
    # failed runs leave no orphaned temp files. Close the temp handle before PIL
    # writes to the path (an open handle leaks an fd and blocks reopening on
    # Windows).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".tif") as temp_mask_file:
        mask_path = temp_mask_file.name
    Image.fromarray(mask.astype(np.uint16)).save(mask_path)
    print(f"💾 Original mask saved to: {mask_path}")

    overlay_img = render_seg_overlay(img_np, inst_mask, overlay_alpha)
    seg_vis_cache = {"img_np": img_np, "inst_mask": inst_mask}
    return overlay_img, mask_path, seg_vis_cache
|
|
|
|
| |
def count_cells_handler(use_box_choice, annot_value, overlay_alpha):
    """Run object counting on the annotated image and render a density overlay.

    Parameters
    ----------
    use_box_choice : str
        ``"Yes"`` to pass the drawn bounding boxes to the model; any other
        value ignores them.
    annot_value : tuple or list or None
        ``(image_path, bboxes)`` from the BBoxAnnotator component.
    overlay_alpha : float
        Heatmap opacity, clipped to [0, 1] by ``render_count_overlay``.

    Returns
    -------
    tuple
        ``(overlay_image, density_npy_path, status_text, count_vis_cache)``.
        On any failure: ``(None, None, error_text, {})``.
    """
    if annot_value is None or len(annot_value) < 1:
        return None, None, "⚠️ Please provide an image.", {}

    image_path = annot_value[0]
    bboxes = annot_value[1] if len(annot_value) > 1 else []

    print(f"🖼️ Image path: {image_path}")
    box_array = None
    if use_box_choice == "Yes" and bboxes:
        box = parse_bboxes(bboxes)
        if box:
            box_array = box
            print(f"📦 Using bounding boxes: {box_array}")

    try:
        print(f"🔢 Counting - Image: {image_path}")
        result = run_count(
            COUNT_MODEL,
            image_path,
            box=box_array,
            device=COUNT_DEVICE,
            visualize=True,
        )
        if 'error' in result:
            return None, None, f"❌ Counting failed: {result['error']}", {}

        count = result['count']
        density_map = result['density_map']

        try:
            img = Image.open(image_path)
            print("📷 Image mode:", img.mode, "size:", img.size)
        except Exception as e:
            print(f"❌ Failed to open image: {e}")
            return None, None, f"❌ Failed to open image: {str(e)}", {}

        try:
            # Close the source file once the resized RGB copy exists.
            with img:
                img_rgb = img.convert("RGB").resize(density_map.shape[::-1], resample=Image.BILINEAR)
            img_np = np.array(img_rgb, dtype=np.float32)
            # Contrast-stretch to [0, 1]. The result never exceeds 1, so the old
            # follow-up `max() > 1.5 -> /255` check was dead code and is dropped.
            img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)
        except Exception as e:
            print(f"❌ Error in image conversion/resizing: {e}")
            return None, None, f"❌ Error in image conversion/resizing: {str(e)}", {}

        # Min-max normalize positive density maps for display.
        density_normalized = density_map.copy()
        d_min, d_max = density_normalized.min(), density_normalized.max()
        if d_max > 0:
            d_range = d_max - d_min
            if d_range > 0:
                density_normalized = (density_normalized - d_min) / d_range
            else:
                # Constant positive map: (x - min) / (max - min) would be 0/0 = NaN.
                # A flat map has no hotspots, so display it as all-zero.
                density_normalized = np.zeros(density_normalized.shape, dtype=np.float32)

        overlay_img = render_count_overlay(img_np, density_normalized, overlay_alpha)
        result_text = f"✅ Detected {round(count)} objects"
        if use_box_choice == "Yes" and box_array:
            result_text += f"\n📦 Using bounding box: {box_array}"

        # Save the raw density map only after everything else succeeded, so
        # failed runs leave no orphaned temp files. Close the temp handle
        # before np.save reopens the path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".npy") as temp_density_file:
            density_path = temp_density_file.name
        np.save(density_path, density_map)
        print(f"💾 Density map saved to {density_path}")

        print(f"✅ Counting done - Count: {count:.1f}")

        count_vis_cache = {"img_np": img_np, "density_normalized": density_normalized}
        return overlay_img, density_path, result_text, count_vis_cache

    except Exception as e:
        print(f"❌ Counting error: {e}")
        import traceback
        traceback.print_exc()
        return None, None, f"❌ Counting failed: {str(e)}", {}
|
|
|
|
def find_tif_dir(root_dir):
    """Return the first directory under *root_dir* holding a ``.tif`` file.

    Directories are visited in ``os.walk`` order. Any path containing
    ``__MACOSX`` is skipped, and the extension check is case-insensitive.
    Returns ``None`` when no such directory exists.
    """
    def _has_tif(names):
        return any(name.lower().endswith('.tif') for name in names)

    matches = (
        dirpath
        for dirpath, _subdirs, filenames in os.walk(root_dir)
        if '__MACOSX' not in dirpath and _has_tif(filenames)
    )
    return next(matches, None)
|
|
def is_valid_tiff(filepath):
    """Return True if PIL can open *filepath* and ``verify()`` succeeds.

    Note that any image format PIL recognizes passes this check, not only
    TIFF; callers in this module filter by file extension first. Any
    exception (missing file, corrupt data, ...) yields False.
    """
    try:
        with Image.open(filepath) as candidate:
            candidate.verify()
    except Exception:
        return False
    return True
|
|
def find_valid_tif_dir(root_dir):
    """Return the first directory under *root_dir* with a readable TIF file.

    Directories are visited in ``os.walk`` order; paths containing
    ``__MACOSX`` are skipped. Candidates are files ending in ``.tif``/``.tiff``
    (case-insensitive) that do not start with ``._``, and each is checked
    with ``is_valid_tiff``. Returns ``None`` if no directory qualifies.
    """
    for current_dir, _subdirs, names in os.walk(root_dir):
        if '__MACOSX' in current_dir:
            continue

        candidates = [
            os.path.join(current_dir, name)
            for name in names
            if not name.startswith('._') and name.lower().endswith(('.tif', '.tiff'))
        ]
        readable = [path for path in candidates if is_valid_tiff(path)]

        if readable:
            print(f"✅ Found {len(readable)} valid TIFF files in: {current_dir}")
            return current_dir

    return None
|
|
def create_ctc_results_zip(output_dir):
    """
    Package the tracking results in *output_dir* into a downloadable ZIP.

    Every file below *output_dir* is stored under its path relative to
    *output_dir*, and a generated ``README.txt`` describing the contents is
    added. The archive is written into a fresh temporary directory.

    Parameters
    ----------
    output_dir : str
        Directory containing tracking results (res_track.txt, etc.)

    Returns
    -------
    zip_path : str
        Path to the created ZIP file, named
        ``tracking_results_<YYYYmmdd_HHMMSS>.zip``.
    """
    archive_dir = tempfile.mkdtemp()
    archive_name = f"tracking_results_{time.strftime('%Y%m%d_%H%M%S')}.zip"
    zip_path = os.path.join(archive_dir, archive_name)

    print(f"📦 Creating results ZIP: {zip_path}")

    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        for folder, _subdirs, names in os.walk(output_dir):
            for name in names:
                source = os.path.join(folder, name)
                member = os.path.relpath(source, output_dir)
                archive.write(source, member)
                print(f" 📄 Added: {member}")

        readme_text = f"""Tracking Results Summary
========================

Generated: {time.strftime('%Y-%m-%d %H:%M:%S')}

Files:
------
- res_track.txt: CTC format tracking data
Format: track_id start_frame end_frame parent_id

- Segmentation masks

For more information on CTC format:
http://celltrackingchallenge.net/
"""
        archive.writestr("README.txt", readme_text)

    size_kb = os.path.getsize(zip_path) / 1024
    print(f"✅ ZIP created: {zip_path} ({size_kb:.1f} KB)")
    return zip_path
|
|
|
|
def get_well_spaced_color(track_id, num_colors=256):
    """Return an RGB color (numpy array of floats in [0, 1]) for *track_id*.

    The hue advances by the golden-ratio conjugate per ID, so consecutive IDs
    get strongly contrasting colors. Saturation (0.9) and value (0.95) are
    fixed. ``num_colors`` is accepted for compatibility but not used.
    """
    import colorsys

    hue_step = 0.618033988749895
    hue = (track_id * hue_step) % 1.0
    red, green, blue = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
    return np.array((red, green, blue))
|
|
|
|
def extract_first_frame(tif_dir):
    """
    Return the first valid TIF frame in *tif_dir*.

    Frames are files whose name ends in ``.tif``/``.tiff`` (case-insensitive)
    and does not start with ``.`` (hidden files and AppleDouble ``._*``
    files). They are naturally sorted, and the first one accepted by
    ``is_valid_tiff`` is returned.

    Returns
    -------
    first_frame_path : str or None
        Path to the first TIF frame, or ``None`` if there is none or the
        directory cannot be listed.
    """
    try:
        names = os.listdir(tif_dir)
    except OSError:
        return None

    # os.listdir + a lower-cased suffix check instead of glob("*.tif"): glob is
    # case-sensitive on most filesystems (it misses "*.TIF", which
    # find_valid_tif_dir accepts), and it would interpret '[' or '*' in a
    # user-chosen folder name as pattern syntax.
    candidates = natsorted(
        os.path.join(tif_dir, name)
        for name in names
        if name.lower().endswith(('.tif', '.tiff')) and not name.startswith('.')
    )
    return next((path for path in candidates if is_valid_tiff(path)), None)
|
|
def _find_tracking_mask_files(output_dir):
    """Return the naturally sorted tracking masks in *output_dir* for ONE naming scheme.

    Tries ``mask*.tif``, then ``man_track*.tif``, then any ``*.tif``, and
    returns the first non-empty match (``[]`` if none match).
    """
    for pattern in ("mask*.tif", "man_track*.tif", "*.tif"):
        matches = natsorted(glob(os.path.join(output_dir, pattern)))
        if matches:
            return matches
    return []


def create_tracking_visualization(tif_dir, output_dir, valid_tif_files, overlay_alpha=0.3):
    """
    Create an animated GIF showing tracked objects with consistent colors.

    Frame ``i`` of *valid_tif_files* is paired with the ``i``-th tracking
    mask found in *output_dir*. Every non-zero track ID is tinted with
    ``get_well_spaced_color(track_id)`` and outlined in yellow. Frames that
    fail to render are skipped with a logged traceback.

    Parameters
    ----------
    tif_dir : str
        Directory containing input TIF frames. Not read here; the frames come
        from *valid_tif_files*. Kept for API compatibility.
    output_dir : str
        Directory containing tracking results (masks).
    valid_tif_files : list of str
        Ordered input frame paths; must be non-empty.
    overlay_alpha : float
        Tint opacity, clipped to [0, 1].

    Returns
    -------
    video_path : str
        Path to the generated GIF; a PNG of the first overlay frame if GIF
        encoding fails; or ``valid_tif_files[0]`` when no masks are found or
        no frame could be rendered.
    """
    import traceback

    import tifffile

    # The old code concatenated globs for "mask*.tif", "man_track*.tif" and
    # "*.tif". "*.tif" also matches the other two, so every mask was listed
    # twice and frame i was paired with the wrong mask. Use one scheme.
    mask_files = _find_tracking_mask_files(output_dir)

    if not mask_files:
        print("⚠️ No mask files found in output directory")
        return valid_tif_files[0]

    print(f"📊 Found {len(mask_files)} mask files")

    frames = []
    alpha = float(np.clip(overlay_alpha, 0.0, 1.0))
    outline_rgb = [1.0, 1.0, 0.0]

    num_frames = min(len(valid_tif_files), len(mask_files))
    for i in range(num_frames):
        try:
            # --- Load the input frame as float RGB in [0, 1] ---
            try:
                img_np = tifffile.imread(valid_tif_files[i])
                if img_np.dtype == np.uint8:
                    img_np = img_np.astype(np.float32) / 255.0
                elif img_np.dtype == np.uint16:
                    # Contrast-stretch 16-bit data; fall back to full-range scaling
                    # for constant frames.
                    img_min, img_max = img_np.min(), img_np.max()
                    if img_max > img_min:
                        img_np = (img_np.astype(np.float32) - img_min) / (img_max - img_min)
                    else:
                        img_np = img_np.astype(np.float32) / 65535.0
                else:
                    img_np = img_np.astype(np.float32)
                    img_min, img_max = img_np.min(), img_np.max()
                    if img_max > img_min:
                        img_np = (img_np - img_min) / (img_max - img_min)
                    else:
                        img_np = np.clip(img_np, 0, 1)

                # Every branch above already yields values in [0, 1], so no
                # further /255 rescaling is needed.
                if img_np.ndim == 2:
                    img_np = np.stack([img_np] * 3, axis=-1)
                img_np = img_np.astype(np.float32)
            except Exception as e:
                print(f"⚠️ Error loading image frame {i}: {e}")
                with Image.open(valid_tif_files[i]) as img:
                    img_np = np.array(img.convert("RGB"), dtype=np.float32) / 255.0

            # --- Load the matching tracking mask ---
            try:
                mask = tifffile.imread(mask_files[i])
            except Exception as e:
                print(f"⚠️ Error loading mask frame {i}: {e}")
                with Image.open(mask_files[i]) as mask_img:
                    mask = np.array(mask_img)

            if mask.shape[:2] != img_np.shape[:2]:
                # Nearest-neighbour resize keeps track IDs intact.
                from scipy.ndimage import zoom
                zoom_factors = [img_np.shape[0] / mask.shape[0], img_np.shape[1] / mask.shape[1]]
                mask = zoom(mask, zoom_factors, order=0).astype(mask.dtype)

            # --- Tint and outline every track ---
            overlay = img_np.copy()
            track_ids = np.unique(mask)
            for track_id in track_ids[track_ids != 0]:
                binary_mask = mask == track_id
                color = get_well_spaced_color(int(track_id))
                overlay[binary_mask] = (1 - alpha) * overlay[binary_mask] + alpha * color

                try:
                    for contour in measure.find_contours(binary_mask.astype(np.uint8), 0.5):
                        contour = contour.astype(np.int32)
                        valid_y = np.clip(contour[:, 0], 0, overlay.shape[0] - 1)
                        valid_x = np.clip(contour[:, 1], 0, overlay.shape[1] - 1)
                        overlay[valid_y, valid_x] = outline_rgb
                except Exception:
                    # Outlines are cosmetic; a failure here must not drop the frame.
                    pass

            overlay_uint8 = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)
            frames.append(Image.fromarray(overlay_uint8))

            if i % 10 == 0 or i == num_frames - 1:
                print(f" 📸 Processed frame {i+1}/{num_frames}")

        except Exception as e:
            print(f"⚠️ Error processing frame {i}: {e}")
            traceback.print_exc()
            continue

    if not frames:
        print("⚠️ No frames were processed successfully")
        return valid_tif_files[0]

    try:
        # Close the temp handle before PIL writes to the path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".gif") as temp_gif:
            gif_path = temp_gif.name
        frames[0].save(
            gif_path,
            save_all=True,
            append_images=frames[1:],
            duration=200,
            loop=0,
        )
        print(f"✅ Created tracking visualization GIF: {gif_path}")
        print(f" Size: {os.path.getsize(gif_path)} bytes, Frames: {len(frames)}")
        return gif_path
    except Exception as e:
        print(f"⚠️ Failed to create GIF: {e}")
        traceback.print_exc()

    # GIF encoding failed: fall back to a still PNG of the first frame.
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_img:
            png_path = temp_img.name
        frames[0].save(png_path)
        return png_path
    except Exception:
        return valid_tif_files[0]
|
|
| |
def track_video_handler(use_box_choice, first_frame_annot, zip_file_obj, overlay_alpha, prev_track_vis_cache):
    """
    Tracking handler: process a ZIP of TIF frames (optionally with a bounding
    box) and return a visualization plus a ZIP of CTC-format results.

    Parameters
    ----------
    use_box_choice : str
        "Yes" or "No" - whether to use bounding box annotation for tracking
    first_frame_annot : tuple or None
        (image_path, bboxes) from BBoxAnnotator, only used if user annotated first frame
    zip_file_obj : file-like with ``.name``, or str
        Uploaded ZIP file containing a TIF sequence. A plain path string is
        also accepted.
    overlay_alpha : float
        Opacity passed to ``create_tracking_visualization``.
    prev_track_vis_cache : dict or None
        Cache from the previous run; its temp directories are deleted first.

    Returns
    -------
    tuple
        ``(results_zip, status_text, download_update, visualization, track_vis_cache)``.
        On any failure, ``(None, error_text, None, None, {})``, and every temp
        directory created by this call is removed.
    """
    import traceback

    if zip_file_obj is None:
        return None, "⚠️ Please upload a ZIP file containing video frames (.zip)", None, None, {}

    cleanup_tracking_cache(prev_track_vis_cache)
    temp_dir = None
    output_temp_dir = None
    succeeded = False

    try:
        box_array = None
        if (use_box_choice == "Yes"
                and isinstance(first_frame_annot, (list, tuple))
                and len(first_frame_annot) > 1
                and first_frame_annot[1]):
            box_array = parse_bboxes(first_frame_annot[1]) or None
            if box_array:
                print(f"📦 Using bounding boxes: {box_array}")

        temp_dir = tempfile.mkdtemp()
        print(f"\n📦 Extracting to temporary directory: {temp_dir}")

        # Gradio may pass a tempfile wrapper (with .name) or a plain path.
        zip_path = getattr(zip_file_obj, "name", zip_file_obj)
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            extracted_count = 0
            skipped_count = 0

            for member in zip_ref.namelist():
                basename = os.path.basename(member)
                # Skip macOS metadata and directory entries.
                if ('__MACOSX' in member or
                        basename.startswith('._') or
                        basename.startswith('.DS_Store') or
                        member.endswith('/')):
                    skipped_count += 1
                    continue

                try:
                    zip_ref.extract(member, temp_dir)
                    extracted_count += 1
                    if basename.lower().endswith(('.tif', '.tiff')):
                        print(f"📄 Extracted TIFF: {basename}")
                except Exception as e:
                    print(f"⚠️ Failed to extract {member}: {e}")

            print(f"\n📊 Extracted: {extracted_count} files, Skipped: {skipped_count} files")

        tif_dir = find_valid_tif_dir(temp_dir)
        if tif_dir is None:
            return None, "❌ Did not find valid TIF directory", None, None, {}

        # List frames with the same case-insensitive extension rule that
        # find_valid_tif_dir used to pick the directory. glob("*.tif") is
        # case-sensitive on most filesystems and treats '[' in user folder
        # names as pattern syntax.
        frame_paths = natsorted(
            os.path.join(tif_dir, name)
            for name in os.listdir(tif_dir)
            if name.lower().endswith(('.tif', '.tiff')) and not name.startswith('.')
        )
        valid_tif_files = [path for path in frame_paths if is_valid_tiff(path)]
        if not valid_tif_files:
            return None, "❌ Did not find valid TIF files", None, None, {}

        print(f"📈 Using {len(valid_tif_files)} TIF files")
        first_frame_path = valid_tif_files[0]

        output_temp_dir = tempfile.mkdtemp()
        print(f"💾 CTC-format results will be saved to: {output_temp_dir}")

        result = run_track(
            TRACK_MODEL,
            video_dir=tif_dir,
            box=box_array,
            device=TRACK_DEVICE,
            output_dir=output_temp_dir,
        )
        if 'error' in result:
            return None, f"❌ Tracking failed: {result['error']}", None, None, {}

        print("\n🎬 Creating tracking visualization...")
        try:
            tracking_video = create_tracking_visualization(
                tif_dir,
                output_temp_dir,
                valid_tif_files,
                overlay_alpha=overlay_alpha,
            )
        except Exception as e:
            print(f"⚠️ Failed to create visualization: {e}")
            traceback.print_exc()
            try:
                tracking_video = Image.open(first_frame_path)
            except Exception:
                tracking_video = None

        try:
            results_zip = create_ctc_results_zip(output_temp_dir)
        except Exception as e:
            print(f"⚠️ Failed to create ZIP: {e}")
            results_zip = None

        # Report the boxes once. They used to be appended a second time below
        # the summary.
        bbox_info = ""
        if box_array:
            boxes_str = "; ".join(f"[{b[0]}, {b[1]}, {b[2]}, {b[3]}]" for b in box_array)
            bbox_info = f"\n🔲 Using bounding box: {boxes_str}"

        result_text = f"""✅ Tracking completed!

🖼️ Processed frames: {len(valid_tif_files)}{bbox_info}

📥 Click the button below to download CTC-format results
The results include:
- res_track.txt (CTC-format tracking data)
- Other tracking-related files
- README.txt (Results description)
"""

        print("\n✅ Tracking completed")

        # Keep the temp dirs alive: the opacity slider re-renders from them,
        # and the next run's cleanup_tracking_cache() deletes them.
        track_vis_cache = {
            "tif_dir": tif_dir,
            "valid_tif_files": valid_tif_files,
            "output_dir": output_temp_dir,
            "input_temp_dir": temp_dir,
        }
        outputs = (results_zip, result_text, gr.update(visible=True), tracking_video, track_vis_cache)
        succeeded = True
        return outputs

    except zipfile.BadZipFile:
        return None, "❌ Not a valid ZIP file", None, None, {}
    except Exception as e:
        traceback.print_exc()
        return None, f"❌ Tracking failed: {str(e)}", None, None, {}
    finally:
        # Every failure path (early return or exception) used to leak these
        # directories except the generic-exception branch. Remove them unless
        # the run succeeded.
        if not succeeded:
            for d in (temp_dir, output_temp_dir):
                if d:
                    shutil.rmtree(d, ignore_errors=True)
|
|
|
|
|
|
| |
# Example assets shown in each tab. glob() order is filesystem-dependent, so
# sort naturally to keep the galleries stable across restarts and platforms.
example_images_seg = natsorted(glob("example_imgs/seg/*"))
example_images_cnt = natsorted(glob("example_imgs/cnt/*"))
example_tracking_zips = natsorted(glob("example_imgs/tra/*.zip"))
|
|
| |
| with gr.Blocks( |
| title="Microscopy Analysis Suite", |
| theme=gr.themes.Soft(), |
| css=""" |
| .tabs button { |
| font-size: 18px !important; |
| font-weight: 600 !important; |
| padding: 12px 20px !important; |
| } |
| .uniform-height { |
| height: 500px !important; |
| display: flex !important; |
| align-items: center !important; |
| justify-content: center !important; |
| } |
| |
| .uniform-height img, |
| .uniform-height canvas { |
| max-height: 500px !important; |
| object-fit: contain !important; |
| } |
| |
| #density_map_output { |
| height: 500px !important; |
| } |
| |
| #density_map_output .image-container { |
| height: 500px !important; |
| } |
| |
| #density_map_output img { |
| height: 480px !important; |
| width: auto !important; |
| max-width: 90% !important; |
| object-fit: contain !important; |
| } |
| """ |
| ) as demo: |
| gr.Markdown( |
| """ |
| # 🔬 Microscopy Image Analysis Suite |
| |
| Supporting three key tasks: |
| - 🎨 **Segmentation**: Instance segmentation of microscopic objects |
| - 🔢 **Counting**: Counting microscopic objects based on density maps |
| - 🎬 **Tracking**: Tracking microscopic objects in video sequences |
| """ |
| ) |
| |
| |
| current_query_id = gr.State(str(uuid.uuid4())) |
| user_uploaded_examples = gr.State(example_images_seg.copy()) |
| seg_vis_state = gr.State({}) |
| count_vis_state = gr.State({}) |
| track_vis_state = gr.State({}) |
| |
| with gr.Tabs(): |
| |
| with gr.Tab("🎨 Segmentation"): |
| gr.Markdown("## Instance Segmentation of Microscopic Objects") |
| gr.Markdown( |
| """ |
| **Instructions:** |
| 1. Upload an image or select an example image (supports various formats: .png, .jpg, .tif) |
| 2. (Optional) Specify a target object with a bounding box and select "Yes", or click "Run Segmentation" directly |
| 3. Click "Run Segmentation" |
| 4. View the segmentation results, download the original predicted mask (.tif format); if needed, click "Clear Selection" to choose a new image |
| |
| 🤘 Rate and submit feedback to help us improve the model! |
| """ |
| ) |
| |
| with gr.Row(): |
| with gr.Column(scale=1): |
| annotator = BBoxAnnotator( |
| label="🖼️ Upload Image (Optional: Provide a Bounding Box)", |
| categories=["cell"], |
| ) |
| |
| |
| example_gallery = gr.Gallery( |
| label="📁 Example Image Gallery", |
| columns=len(example_images_seg), |
| rows=1, |
| height=120, |
| object_fit="cover", |
| show_download_button=False |
| ) |
| |
| |
| with gr.Row(): |
| use_box_radio = gr.Radio( |
| choices=["Yes", "No"], |
| value="No", |
| label="🔲 Specify Bounding Box?" |
| ) |
| with gr.Row(): |
| run_seg_btn = gr.Button("▶️ Run Segmentation", variant="primary", size="lg") |
| clear_btn = gr.Button("🔄 Clear Selection", variant="secondary") |
|
|
| |
| image_uploader = gr.Image( |
| label="➕ Upload New Example Image to Gallery", |
| type="filepath" |
| ) |
|
|
|
|
| with gr.Column(scale=2): |
| seg_output = gr.Image( |
| type="pil", |
| label="📸 Segmentation Result", |
| elem_classes="uniform-height" |
| ) |
| seg_alpha_slider = gr.Slider( |
| minimum=0.0, |
| maximum=1.0, |
| step=0.05, |
| value=0.5, |
| label="🎚️ Overlay Opacity" |
| ) |
| |
| |
| download_mask_btn = gr.File( |
| label="📥 Download Original Prediction (.tif format)", |
| visible=True, |
| height=40, |
| ) |
|
|
| |
| score_slider = gr.Slider( |
| minimum=1, |
| maximum=5, |
| step=1, |
| value=5, |
| label="🌟 Satisfaction Rating (1-5)" |
| ) |
|
|
| |
| feedback_box = gr.Textbox( |
| placeholder="Please enter your feedback...", |
| lines=2, |
| label="💬 Feedback" |
| ) |
|
|
| |
| submit_feedback_btn = gr.Button("💾 Submit Feedback", variant="secondary") |
|
|
| feedback_status = gr.Textbox( |
| label="✅ Submission Status", |
| lines=1, |
| visible=False |
| ) |
| |
| |
| run_seg_btn.click( |
| fn=segment_with_choice, |
| inputs=[use_box_radio, annotator, seg_alpha_slider], |
| outputs=[seg_output, download_mask_btn, seg_vis_state] |
| ) |
| seg_alpha_slider.input( |
| fn=update_seg_overlay_alpha, |
| inputs=[seg_alpha_slider, seg_vis_state], |
| outputs=seg_output |
| ) |
|
|
| |
| clear_btn.click( |
| fn=lambda: (None, {}), |
| inputs=None, |
| outputs=[annotator, seg_vis_state] |
| ) |
| |
| |
| demo.load( |
| fn=lambda: example_images_seg.copy(), |
| outputs=example_gallery |
| ) |
| |
| |
| def add_to_gallery(img_path, current_imgs): |
| if not img_path: |
| return current_imgs |
| try: |
| if img_path not in current_imgs: |
| current_imgs.append(img_path) |
| return current_imgs |
| except: |
| return current_imgs |
| |
| image_uploader.change( |
| fn=add_to_gallery, |
| inputs=[image_uploader, user_uploaded_examples], |
| outputs=user_uploaded_examples |
| ).then( |
| fn=lambda imgs: imgs, |
| inputs=user_uploaded_examples, |
| outputs=example_gallery |
| ) |
| |
| |
| def load_from_gallery(evt: gr.SelectData, all_imgs): |
| if evt.index is not None and evt.index < len(all_imgs): |
| return all_imgs[evt.index] |
| return None |
| |
| example_gallery.select( |
| fn=load_from_gallery, |
| inputs=user_uploaded_examples, |
| outputs=annotator |
| ) |
| |
| |
def submit_user_feedback(query_id, score, comment, annot_val):
    """Persist a rating and comment for the current query and report the outcome.

    Args:
        query_id: identifier forwarded unchanged to ``save_feedback_to_hf``.
        score: slider value; converted with ``int()`` into a ``score_<n>`` tag.
        comment: free-text feedback.
        annot_val: annotator value; element 0 (if present) is sent as the image
            path and element 1 (if present) as the bounding boxes.

    Returns:
        (status message, ``gr.update`` that makes the status box visible).
        Any exception is reported as a failure message instead of propagating.
    """
    try:
        image_path = None
        boxes = []
        if annot_val:
            if len(annot_val) > 0:
                image_path = annot_val[0]
            if len(annot_val) > 1:
                boxes = annot_val[1]

        rating_tag = f"score_{int(score)}"
        save_feedback_to_hf(
            query_id=query_id,
            feedback_type=rating_tag,
            feedback_text=comment,
            img_path=image_path,
            bboxes=boxes,
        )
        return "✅ Feedback submitted, thank you!", gr.update(visible=True)
    except Exception as err:
        return f"❌ Submission failed: {err}", gr.update(visible=True)
|
|
| submit_feedback_btn.click( |
| fn=submit_user_feedback, |
| inputs=[current_query_id, score_slider, feedback_box, annotator], |
| outputs=[feedback_status, feedback_status] |
| ) |
| |
| |
| with gr.Tab("🔢 Counting"): |
| gr.Markdown("## Microscopy Object Counting Analysis") |
| gr.Markdown( |
| """ |
| **Usage Instructions:** |
| 1. Upload an image or select an example image (supports multiple formats: .png, .jpg, .tif) |
| 2. (Optional) Specify a target object with a bounding box and select "Yes", or click "Run Counting" directly |
| 3. Click "Run Counting" |
| 4. View the density map, download the original prediction (.npy format); if needed, click "Clear Selection" to choose a new image to run |
| |
| 🤘 Rate and submit feedback to help us improve the model! |
| """ |
| ) |
| |
| with gr.Row(): |
| with gr.Column(scale=1): |
| count_annotator = BBoxAnnotator( |
| label="🖼️ Upload Image (Optional: Provide a Bounding Box)", |
| categories=["cell"], |
| ) |
| |
| |
| with gr.Row(): |
| count_example_gallery = gr.Gallery( |
| label="📁 Example Image Gallery", |
| columns=len(example_images_cnt), |
| rows=1, |
| object_fit="cover", |
| height=120, |
| value=example_images_cnt.copy(), |
| show_download_button=False |
| ) |
| |
| |
| with gr.Row(): |
| count_use_box_radio = gr.Radio( |
| choices=["Yes", "No"], |
| value="No", |
| label="🔲 Specify Bounding Box?" |
| ) |
|
|
| with gr.Row(): |
| count_btn = gr.Button("▶️ Run Counting", variant="primary", size="lg") |
| clear_btn = gr.Button("🔄 Clear Selection", variant="secondary") |
| |
| |
| with gr.Row(): |
| count_image_uploader = gr.File( |
| label="➕ Add Example Image to Gallery", |
| file_types=["image"], |
| type="filepath" |
| ) |
|
|
| |
| with gr.Column(scale=2): |
| count_output = gr.Image( |
| label="📸 Density Map", |
| type="filepath", |
| elem_id="density_map_output" |
| |
| ) |
| count_alpha_slider = gr.Slider( |
| minimum=0.0, |
| maximum=1.0, |
| step=0.05, |
| value=0.3, |
| label="🎚️ Heatmap Opacity" |
| ) |
| count_status = gr.Textbox( |
| label="📊 Statistics", |
| lines=2 |
| ) |
| download_density_btn = gr.File( |
| label="📥 Download Original Prediction (.npy format)", |
| visible=True |
| ) |
|
|
| |
| score_slider = gr.Slider( |
| minimum=1, |
| maximum=5, |
| step=1, |
| value=5, |
| label="🌟 Satisfaction Rating (1-5)" |
| ) |
|
|
| |
| feedback_box = gr.Textbox( |
| placeholder="Please enter your feedback...", |
| lines=2, |
| label="💬 Feedback" |
| ) |
|
|
| |
| submit_feedback_btn = gr.Button("💾 Submit Feedback", variant="secondary") |
|
|
| feedback_status = gr.Textbox( |
| label="✅ Submission Status", |
| lines=1, |
| visible=False |
| ) |
| |
| |
| count_user_examples = gr.State(example_images_cnt.copy()) |
| |
| |
def add_to_count_gallery(new_img_file, current_imgs):
    """Add an uploaded image to the counting gallery list.

    The list is extended in place (duplicates are ignored). Failures are
    logged rather than raised.

    Returns:
        The same list object twice: once for the State, once for the Gallery.
    """
    if new_img_file is not None:
        try:
            already_listed = new_img_file in current_imgs
            if not already_listed:
                current_imgs.append(new_img_file)
                print(f"✅ Added image to gallery: {new_img_file}")
        except Exception as err:
            print(f"⚠️ Failed to add image: {err}")
    return current_imgs, current_imgs
| |
| |
| count_image_uploader.upload( |
| fn=add_to_count_gallery, |
| inputs=[count_image_uploader, count_user_examples], |
| outputs=[count_user_examples, count_example_gallery] |
| ) |
| |
| |
def load_from_count_gallery(evt: gr.SelectData, all_imgs):
    """Return the image path for the clicked counting-gallery tile.

    Logs the chosen path. Returns None when the event has no index or the
    index is beyond the end of ``all_imgs``.
    """
    position = evt.index
    if position is None or not position < len(all_imgs):
        return None
    chosen = all_imgs[position]
    print(f"📸 Loading image from gallery: {chosen}")
    return chosen
| |
| count_example_gallery.select( |
| fn=load_from_count_gallery, |
| inputs=count_user_examples, |
| outputs=count_annotator |
| ) |
| |
| |
| count_btn.click( |
| fn=count_cells_handler, |
| inputs=[count_use_box_radio, count_annotator, count_alpha_slider], |
| outputs=[count_output, download_density_btn, count_status, count_vis_state] |
| ) |
| count_alpha_slider.input( |
| fn=update_count_overlay_alpha, |
| inputs=[count_alpha_slider, count_vis_state], |
| outputs=count_output |
| ) |
|
|
| |
| clear_btn.click( |
| fn=lambda: (None, {}), |
| inputs=None, |
| outputs=[count_annotator, count_vis_state] |
| ) |
|
|
| |
def submit_user_feedback(query_id, score, comment, annot_val):
    """Persist a rating and comment for the current query and report the outcome.

    Args:
        query_id: identifier forwarded unchanged to ``save_feedback_to_hf``.
        score: slider value; converted with ``int()`` into a ``score_<n>`` tag.
        comment: free-text feedback.
        annot_val: annotator value; element 0 (if present) is sent as the image
            path and element 1 (if present) as the bounding boxes.

    Returns:
        (status message, ``gr.update`` that makes the status box visible).
        Any exception is reported as a failure message instead of propagating.
    """
    try:
        image_path = None
        boxes = []
        if annot_val:
            if len(annot_val) > 0:
                image_path = annot_val[0]
            if len(annot_val) > 1:
                boxes = annot_val[1]

        rating_tag = f"score_{int(score)}"
        save_feedback_to_hf(
            query_id=query_id,
            feedback_type=rating_tag,
            feedback_text=comment,
            img_path=image_path,
            bboxes=boxes,
        )
        return "✅ Feedback submitted successfully, thank you!", gr.update(visible=True)
    except Exception as err:
        return f"❌ Submission failed: {err}", gr.update(visible=True)
|
|
# Counting-tab feedback must capture the image/boxes from this tab's own
# annotator (count_annotator), not the Segmentation tab's `annotator`.
submit_feedback_btn.click(
    fn=submit_user_feedback,
    inputs=[current_query_id, score_slider, feedback_box, count_annotator],
    outputs=[feedback_status, feedback_status]
)
| |
| |
| with gr.Tab("🎬 Tracking"): |
| gr.Markdown("## Microscopy Object Video Tracking - Supports ZIP Upload") |
| gr.Markdown( |
| """ |
| **Instructions:** |
| 1. Upload a ZIP file or select from the example library. The ZIP should contain a sequence of TIF images named in chronological order (e.g., t000.tif, t001.tif...) |
| 2. (Optional) Specify a target object with a bounding box on the first frame and select "Yes", or click "Run Tracking" directly |
| 3. Click "Run Tracking" |
| 4. Download the CTC format results; if needed, click "Clear Selection" to choose a new ZIP file to run |
| |
| 🤘 Rate and submit feedback to help us improve the model! |
| |
| """ |
| ) |
| |
| with gr.Row(): |
| with gr.Column(scale=1): |
| track_zip_upload = gr.File( |
| label="📦 Upload Image Sequence in ZIP File", |
| file_types=[".zip"] |
| ) |
|
|
| |
| track_first_frame_annotator = BBoxAnnotator( |
| label="🖼️ (Optional) First Frame Bounding Box Annotation", |
| categories=["cell"], |
| visible=False, |
| ) |
|
|
| |
| track_example_gallery = gr.Gallery( |
| label="📁 Example Video Gallery (Click to Select)", |
| columns=10, |
| rows=1, |
| height=120, |
| object_fit="contain", |
| show_download_button=False |
| ) |
| |
| with gr.Row(): |
| track_use_box_radio = gr.Radio( |
| choices=["Yes", "No"], |
| value="No", |
| label="🔲 Specify Bounding Box?" |
| ) |
|
|
| with gr.Row(): |
| track_btn = gr.Button("▶️ Run Tracking", variant="primary", size="lg") |
| clear_btn = gr.Button("🔄 Clear Selection", variant="secondary") |
| |
| |
| track_gallery_upload = gr.File( |
| label="➕ Add ZIP to Example Gallery", |
| file_types=[".zip"], |
| type="filepath" |
| ) |
| |
| with gr.Column(scale=2): |
| track_first_frame_preview = gr.Image( |
| label="📸 Tracking Visualization", |
| type="filepath", |
| |
| elem_classes="uniform-height", |
| interactive=False |
| ) |
| track_alpha_slider = gr.Slider( |
| minimum=0.0, |
| maximum=1.0, |
| step=0.05, |
| value=0.3, |
| label="🎚️ Overlay Opacity" |
| ) |
| |
| track_output = gr.Textbox( |
| label="📊 Tracking Information", |
| lines=8, |
| interactive=False |
| ) |
| |
| track_download = gr.File( |
| label="📥 Download Tracking Results (CTC Format)", |
| visible=False |
| ) |
|
|
| |
| score_slider = gr.Slider( |
| minimum=1, |
| maximum=5, |
| step=1, |
| value=5, |
| label="🌟 Satisfaction Rating (1-5)" |
| ) |
|
|
| |
| feedback_box = gr.Textbox( |
| placeholder="Please enter your feedback...", |
| lines=2, |
| label="💬 Feedback" |
| ) |
|
|
| |
| submit_feedback_btn = gr.Button("💾 Submit Feedback", variant="secondary") |
|
|
| feedback_status = gr.Textbox( |
| label="✅ Submission Status", |
| lines=1, |
| visible=False |
| ) |
| |
| |
| track_user_examples = gr.State(example_tracking_zips.copy()) |
| |
| |
def _preview_to_uint8_rgb(img_np):
    """Convert a decoded frame to a uint8 RGB(-like) array that PIL can save as PNG.

    Conversion rules:
    - uint8 data is kept as-is.
    - uint16 is min-max scaled to 0-255, or divided by 65535 when the frame is
      constant.
    - Other dtypes are min-max scaled, or clipped after multiplying by 255 when
      constant.
    - 2-D and single-channel 3-D frames are replicated to three channels.
    - Channels beyond the third are dropped.
    """
    if img_np.dtype == np.uint16:
        scaled = img_np.astype(np.float32)
        img_min, img_max = scaled.min(), scaled.max()
        if img_max > img_min:
            scaled = (scaled - img_min) / (img_max - img_min) * 255
        else:
            scaled = scaled / 65535.0 * 255
        img_np = scaled.astype(np.uint8)
    elif img_np.dtype != np.uint8:
        scaled = img_np.astype(np.float32)
        img_min, img_max = scaled.min(), scaled.max()
        if img_max > img_min:
            scaled = (scaled - img_min) / (img_max - img_min) * 255
        else:
            scaled = np.clip(scaled * 255, 0, 255)
        img_np = scaled.astype(np.uint8)

    if img_np.ndim == 2:
        img_np = np.stack([img_np] * 3, axis=-1)
    elif img_np.ndim == 3 and img_np.shape[2] == 1:
        img_np = np.repeat(img_np, 3, axis=2)
    elif img_np.ndim == 3 and img_np.shape[2] > 3:
        img_np = img_np[:, :, :3]
    return img_np


def get_zip_preview(zip_path):
    """Render the first image of a ZIP archive as a PNG for the example gallery.

    Member selection:
    - Candidates are .tif/.tiff/.png/.jpg members.
    - macOS metadata entries (``__MACOSX`` paths and ``._*`` files) are skipped.
    - The first candidate in natural-sort order is used.

    That member is decoded straight from the archive (nothing is extracted to
    disk), normalized to uint8 RGB and written to a new temporary PNG.

    Args:
        zip_path: filesystem path of the ZIP archive.

    Returns:
        Path of the written PNG, or None when the archive contains no usable
        image or cannot be read. Best-effort: failures are logged, not raised.
    """
    import io

    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            image_members = [
                member for member in zip_ref.namelist()
                if '__MACOSX' not in member
                and not os.path.basename(member).startswith('._')
                and os.path.basename(member).lower().endswith(('.tif', '.tiff', '.png', '.jpg'))
            ]
            if not image_members:
                return None
            # Archive order is arbitrary; natural sort puts t2 before t10.
            first_member = natsorted(image_members)[0]
            payload = zip_ref.read(first_member)

        if first_member.lower().endswith(('.tif', '.tiff')):
            import tifffile
            img_np = tifffile.imread(io.BytesIO(payload))
        else:
            # tifffile cannot decode PNG/JPEG, so let PIL handle those.
            with Image.open(io.BytesIO(payload)) as pil_img:
                if pil_img.mode in ("P", "CMYK"):
                    pil_img = pil_img.convert("RGB")
                img_np = np.asarray(pil_img)

        img_np = _preview_to_uint8_rgb(img_np)
        # Close the handle before PIL writes to the path (needed on Windows).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
            preview_file = tmp.name
        Image.fromarray(img_np).save(preview_file)
        return preview_file
    except Exception as e:
        print(f"⚠️ Could not build preview for {zip_path}: {e}")
        return None
| |
| |
def init_tracking_gallery():
    """Build gallery preview images for the bundled example tracking ZIPs.

    Archives that do not exist on disk, or whose preview cannot be rendered,
    are skipped.

    NOTE(review): skipping entries means gallery positions can drift from
    ``track_user_examples`` (initialised with every example path), so a click
    could select the wrong ZIP. Confirm that all examples exist and preview
    cleanly.
    """
    rendered = (
        get_zip_preview(zip_file)
        for zip_file in example_tracking_zips
        if os.path.exists(zip_file)
    )
    return [preview for preview in rendered if preview]
| |
| |
| demo.load( |
| fn=init_tracking_gallery, |
| outputs=track_example_gallery |
| ) |
| |
| |
def add_zip_to_gallery(zip_path, current_zips):
    """Add an uploaded ZIP to the tracking examples and rebuild the gallery.

    Args:
        zip_path: path of the uploaded ZIP; falsy when the upload was cleared.
        current_zips: list of ZIP paths held in ``track_user_examples``
            (None is treated as empty).

    Returns:
        (ZIP list for the State, preview list for the Gallery).
        - The two lists are index-aligned, so ``load_zip_from_gallery`` maps a
          clicked tile back to the right archive.
        - ZIPs whose preview cannot be rendered are dropped from both lists.
        - On an empty upload or an unexpected error the State is left as-is
          and the gallery is left unchanged via ``gr.update()``.
    """
    if not zip_path:
        # Nothing to add: a no-op update, not the component object itself.
        return current_zips, gr.update()
    try:
        zips = list(current_zips) if current_zips else []
        if zip_path not in zips:
            zips.append(zip_path)
            print(f"✅ Added ZIP to gallery: {zip_path}")

        kept_zips, previews = [], []
        for zp in zips:
            preview = get_zip_preview(zp)
            if preview:
                kept_zips.append(zp)
                previews.append(preview)
            else:
                print(f"⚠️ No preview for {zp}; removing it from the gallery")
        return kept_zips, previews
    except Exception as e:
        print(f"⚠️ Error: {e}")
        return current_zips, gr.update()
| |
| track_gallery_upload.upload( |
| fn=add_zip_to_gallery, |
| inputs=[track_gallery_upload, track_user_examples], |
| outputs=[track_user_examples, track_example_gallery] |
| ) |
| |
| |
def load_zip_from_gallery(evt: gr.SelectData, all_zips):
    """Map a clicked tracking-gallery tile back to its ZIP path.

    Logs and returns ``all_zips[evt.index]``. Returns None when the event has
    no index or the index is out of range.
    """
    position = evt.index
    if position is None or not position < len(all_zips):
        return None
    chosen_zip = all_zips[position]
    print(f"📁 Selected ZIP from gallery: {chosen_zip}")
    return chosen_zip
| |
| track_example_gallery.select( |
| fn=load_zip_from_gallery, |
| inputs=track_user_examples, |
| outputs=track_zip_upload |
| ) |
|
|
| |
def _frame_to_uint8_rgb(img_np):
    """Normalize a frame array to uint8 RGB for display in the box annotator.

    Conversion rules:
    - uint8 input is kept as-is.
    - uint16 is min-max scaled to 0-255, or divided by 65535 when constant.
    - Other dtypes are min-max scaled, or clipped after multiplying by 255 when
      constant.
    - Grayscale frames are replicated to 3 channels.
    - Channels beyond the third are dropped.
    """
    if img_np.dtype == np.uint16:
        img_min, img_max = img_np.min(), img_np.max()
        if img_max > img_min:
            img_np = ((img_np.astype(np.float32) - img_min) / (img_max - img_min) * 255).astype(np.uint8)
        else:
            img_np = (img_np.astype(np.float32) / 65535.0 * 255).astype(np.uint8)
    elif img_np.dtype != np.uint8:
        img_np = img_np.astype(np.float32)
        img_min, img_max = img_np.min(), img_np.max()
        if img_max > img_min:
            img_np = ((img_np - img_min) / (img_max - img_min) * 255).astype(np.uint8)
        else:
            img_np = np.clip(img_np * 255, 0, 255).astype(np.uint8)

    if img_np.ndim == 2:
        img_np = np.stack([img_np] * 3, axis=-1)
    elif img_np.ndim == 3 and img_np.shape[2] > 3:
        img_np = img_np[:, :, :3]
    return img_np


def load_first_frame_for_annotation(zip_file_obj):
    """Extract and normalize the first TIF frame of an uploaded ZIP for box annotation.

    Processing steps:
    - TIF/TIFF members are extracted to a fresh temporary directory, ignoring
      ``__MACOSX`` paths and ``._*`` files.
    - The helpers ``find_valid_tif_dir`` and ``extract_first_frame`` (defined
      elsewhere) pick the frame.
    - The frame is normalized to uint8 RGB and saved as a temporary PNG.

    Args:
        zip_file_obj: value of the ZIP ``gr.File``. Either a path string or an
            object exposing the path as ``.name``; None when cleared.

    Returns:
        (image for the annotator, visibility update for the annotator):
        - (png_path, visible) on success;
        - (raw frame path, visible) if normalization fails;
        - (None, hidden) when nothing is uploaded, no frame is found, or an
          error occurs.
    """
    if zip_file_obj is None:
        return None, gr.update(visible=False)

    import tifffile
    import traceback

    # Accept plain path strings as well as wrappers that carry the path in .name.
    zip_path = getattr(zip_file_obj, "name", zip_file_obj)

    try:
        temp_dir = tempfile.mkdtemp()
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            for member in zip_ref.namelist():
                basename = os.path.basename(member)
                if ('__MACOSX' not in member
                        and not basename.startswith('._')
                        and basename.lower().endswith(('.tif', '.tiff'))):
                    zip_ref.extract(member, temp_dir)

        tif_dir = find_valid_tif_dir(temp_dir)
        first_frame = extract_first_frame(tif_dir) if tif_dir else None
        if not first_frame:
            # Always return both outputs this handler is wired to; hide the
            # annotator when there is nothing to annotate.
            print(f"⚠️ No TIF frame found in ZIP: {zip_path}")
            return None, gr.update(visible=False)

        try:
            raw_frame = tifffile.imread(first_frame)
            img_np = _frame_to_uint8_rgb(raw_frame)

            # Close the handle before PIL writes to the path.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_img:
                png_path = temp_img.name
            Image.fromarray(img_np).save(png_path)

            print(f"✅ Loaded and normalized first frame: {first_frame}")
            print(f" Original dtype: {raw_frame.dtype}")
            print(" Normalized to uint8 RGB for annotation")
            return png_path, gr.update(visible=True)
        except Exception as e:
            print(f"⚠️ Error normalizing first frame: {e}")
            traceback.print_exc()
            # Fall back to the raw frame path; the annotator may still render it.
            return first_frame, gr.update(visible=True)
    except Exception as e:
        print(f"⚠️ Error loading first frame: {e}")
        traceback.print_exc()
        return None, gr.update(visible=False)
| |
| |
| track_zip_upload.change( |
| fn=load_first_frame_for_annotation, |
| inputs=track_zip_upload, |
| outputs=[track_first_frame_annotator, track_first_frame_annotator] |
| ) |
| |
| |
| track_btn.click( |
| fn=track_video_handler, |
| inputs=[track_use_box_radio, track_first_frame_annotator, track_zip_upload, track_alpha_slider, track_vis_state], |
| outputs=[track_download, track_output, track_download, track_first_frame_preview, track_vis_state] |
| ) |
| track_alpha_slider.change( |
| fn=update_tracking_overlay_alpha, |
| inputs=[track_alpha_slider, track_vis_state], |
| outputs=track_first_frame_preview |
| ) |
|
|
| |
| clear_btn.click( |
| fn=lambda: (None, {}), |
| inputs=None, |
| outputs=[track_first_frame_annotator, track_vis_state] |
| ) |
|
|
| |
def submit_user_feedback(query_id, score, comment, annot_val):
    """Persist a rating and comment for the current query and report the outcome.

    Args:
        query_id: identifier forwarded unchanged to ``save_feedback_to_hf``.
        score: slider value; converted with ``int()`` into a ``score_<n>`` tag.
        comment: free-text feedback.
        annot_val: annotator value; element 0 (if present) is sent as the image
            path and element 1 (if present) as the bounding boxes.

    Returns:
        (status message, ``gr.update`` that makes the status box visible).
        Any exception is reported as a failure message instead of propagating.
    """
    try:
        image_path = None
        boxes = []
        if annot_val:
            if len(annot_val) > 0:
                image_path = annot_val[0]
            if len(annot_val) > 1:
                boxes = annot_val[1]

        rating_tag = f"score_{int(score)}"
        save_feedback_to_hf(
            query_id=query_id,
            feedback_type=rating_tag,
            feedback_text=comment,
            img_path=image_path,
            bboxes=boxes,
        )
        return "✅ Feedback submitted successfully, thank you!", gr.update(visible=True)
    except Exception as err:
        return f"❌ Submission failed: {err}", gr.update(visible=True)
|
|
# Tracking-tab feedback must capture the first-frame image/boxes from this
# tab's annotator, not the Segmentation tab's `annotator`.
submit_feedback_btn.click(
    fn=submit_user_feedback,
    inputs=[current_query_id, score_slider, feedback_box, track_first_frame_annotator],
    outputs=[feedback_status, feedback_status]
)
| |
| gr.Markdown( |
| """ |
| --- |
| ### 📒 Note: |
| |
| This project is currently available with usage limits for research trial use and feedback collection. We plan to release a free public version in the future. We are actively improving the toolkit and greatly appreciate your feedback! |
| |
| |
| |
| ### 💡 Technical Details |
| |
| **MicroscopyMatching** - A general-purpose microscopy image analysis toolkit based on Stable Diffusion |
| |
| """ |
| ) |
|
|
if __name__ == "__main__":
    # Enable request queuing, then serve on all network interfaces, port 7860,
    # without a public share link, SSR disabled, and errors shown in the UI.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        ssr_mode=False,
        show_error=True,
    )
    queued_app = demo.queue()
    queued_app.launch(**launch_options)
|
|