"""Gradio app for browsing microscopy images and segmenting cells.

Images are listed from ``IMAGE_FOLDER``; optional per-image category labels
are read from ``CSV_FILE`` (columns ``Id`` and ``Target``, where ``Target`` is
a space-separated list of category indices). Segmentation runs with Cellpose
or an Ultralytics YOLO segmentation model when those packages are installed.
"""

import csv
import os
import tempfile
import warnings
from pathlib import Path

import gradio as gr
import matplotlib

# Figures are only rendered to images for gr.Plot from Gradio worker threads,
# so use the non-interactive backend instead of a GUI one.
matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
from matplotlib.patches import Rectangle  # noqa: E402
from PIL import Image  # noqa: E402
from skimage import color, io, measure, segmentation  # noqa: E402

try:
    from cellpose import models

    CELLPOSE_AVAILABLE = True
except ImportError:
    CELLPOSE_AVAILABLE = False

try:
    from ultralytics import YOLO

    YOLO_AVAILABLE = True
except ImportError:
    YOLO_AVAILABLE = False

# Configuration
IMAGE_FOLDER = "./imgs"
CSV_FILE = "train.csv"
IMAGE_EXTENSIONS = frozenset({'.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp'})
# Filename suffixes naming a colour channel; stripped to recover the image ID.
_COLOR_SUFFIXES = ('_blue', '_green', '_red', '_yellow')

# Category names mapping (0-27)
CATEGORY_NAMES = {
    0: "Nucleoplasm",
    1: "Nuclear membrane",
    2: "Nucleoli",
    3: "Nucleoli fibrillar center",
    4: "Nuclear speckles",
    5: "Nuclear bodies",
    6: "Endoplasmic reticulum",
    7: "Golgi apparatus",
    8: "Peroxisomes",
    9: "Endosomes",
    10: "Lysosomes",
    11: "Intermediate filaments",
    12: "Actin filaments",
    13: "Focal adhesion sites",
    14: "Microtubules",
    15: "Microtubule ends",
    16: "Cytokinetic bridge",
    17: "Mitotic spindle",
    18: "Microtubule organizing center",
    19: "Centrosome",
    20: "Lipid droplets",
    21: "Plasma membrane",
    22: "Cell junctions",
    23: "Mitochondria",
    24: "Aggresome",
    25: "Cytosol",
    26: "Cytoplasmic bodies",
    27: "Rods & rings",
}


class AppState:
    """Mutable state shared by every Gradio callback (one module-level instance)."""

    def __init__(self):
        self.image_files = []           # paths of images found in IMAGE_FOLDER
        self.selected_image = None      # path of the image picked in the gallery
        self.current_image = None       # loaded image, normalised to uint8 gray or RGB
        self.masks = None               # integer label image (0 = background) or None
        self.cell_properties = []       # skimage regionprops, one per label in masks
        self.cellpose_model = None
        self.yolo_model = None
        self.current_model_type = None  # identifier of the most recently loaded model
        self.cellpose_model_key = None  # (model_type, use_gpu) the Cellpose model was built with
        self.yolo_model_key = None      # weights path the YOLO model was loaded from
        self.selected_cell = None       # label of the highlighted cell, if any
        self.csv_data = None            # raw DataFrame read from CSV_FILE
        self.image_categories = {}      # image id -> {'indices': [...], 'names': [...]}


state = AppState()


def extract_image_id(filename):
    """Return the image ID of *filename*.

    The ID is the basename without extension and without one trailing colour
    channel suffix (``_blue``, ``_green``, ``_red`` or ``_yellow``).
    """
    stem = os.path.splitext(os.path.basename(filename))[0]
    for suffix in _COLOR_SUFFIXES:
        if stem.endswith(suffix):
            # Strip only the trailing suffix, not every occurrence in the name.
            return stem[:-len(suffix)]
    return stem


def _parse_target(target):
    """Parse a CSV ``Target`` cell into a list of integer category indices.

    Missing values give ``[]``. Tokens such as ``"5.0"`` (pandas reads a
    column as float when it has gaps) are accepted; non-integer tokens are skipped.
    """
    if pd.isna(target):
        return []
    indices = []
    for token in str(target).split():
        try:
            value = float(token)
        except ValueError:
            continue
        if value.is_integer():
            indices.append(int(value))
    return indices


def load_csv_data():
    """Load CSV_FILE (if present) into ``state.csv_data`` and ``state.image_categories``.

    Failures are reported on stdout; the app keeps running without categories.
    """
    if not os.path.exists(CSV_FILE):
        return
    try:
        state.csv_data = pd.read_csv(CSV_FILE)
        state.image_categories = {}
        for img_id, target in zip(state.csv_data['Id'], state.csv_data['Target']):
            indices = _parse_target(target)
            # Key by str so numeric IDs still match extract_image_id() output.
            state.image_categories[str(img_id)] = {
                'indices': indices,
                'names': [CATEGORY_NAMES.get(idx, f"Unknown-{idx}") for idx in indices],
            }
    except Exception as e:
        print(f"Could not load CSV: {e}")


def scan_folder():
    """List images in IMAGE_FOLDER into ``state.image_files``.

    Returns:
        A list of ``(path, basename)`` gallery items, or None when the folder
        is missing, empty of images, or unreadable.
    """
    folder = Path(IMAGE_FOLDER)
    if not folder.is_dir():
        return None
    try:
        state.image_files = [
            str(path) for path in sorted(folder.iterdir())
            if path.suffix.lower() in IMAGE_EXTENSIONS
        ]
    except OSError as e:
        print(f"Scan error: {e}")
        return None
    if not state.image_files:
        return None
    return [(path, os.path.basename(path)) for path in state.image_files]


def _to_uint8(image):
    """Rescale a non-uint8 image to uint8 by dividing by its maximum.

    uint8 input is returned unchanged. An all-zero (or non-finite) image becomes
    zeros instead of NaN garbage; values outside 0..255 are clipped.
    """
    if image.dtype == np.uint8:
        return image
    scaled = np.nan_to_num(image.astype(np.float64), nan=0.0, posinf=0.0, neginf=0.0)
    peak = scaled.max() if scaled.size else 0.0
    if peak > 0:
        scaled = scaled / peak * 255
    return np.clip(scaled, 0, 255).astype(np.uint8)


def _normalize_loaded_image(image):
    """Turn a freshly read image into uint8 grayscale (2-D) or RGB (3 channels)."""
    image = np.asarray(image)
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[:, :, 0]
    elif image.ndim == 3 and image.shape[2] == 4:
        # Drop alpha: rgb2gray and the YOLO path below handle gray/RGB only.
        image = image[:, :, :3]
    return _to_uint8(image)


def _to_gray_uint8(image):
    """Return *image* as a 2-D uint8 intensity image (RGB is converted via rgb2gray)."""
    if image.ndim == 2:
        return image
    return (color.rgb2gray(image) * 255).astype(np.uint8)


def _clear_segmentation():
    """Forget masks, measurements and the highlighted cell."""
    state.masks = None
    state.cell_properties = []
    state.selected_cell = None


def prepare_image_for_yolo(image):
    """Return *image* as a 3-channel array (gray is replicated, alpha is dropped)."""
    if image.ndim == 2:
        return np.stack([image, image, image], axis=-1)
    if image.ndim == 3 and image.shape[2] == 3:
        return image
    if image.ndim == 3 and image.shape[2] == 1:
        gray = image[:, :, 0]
        return np.stack([gray, gray, gray], axis=-1)
    if image.ndim == 3 and image.shape[2] == 4:
        return image[:, :, :3]
    return image


def select_image_from_gallery(evt: gr.SelectData):
    """Load the gallery image at ``evt.index`` and reset any segmentation.

    Returns:
        (figure, status text, category text, dropdown update).
    """
    if not state.image_files or evt.index >= len(state.image_files):
        return None, "Invalid selection", "", gr.update(choices=[], value=None)

    state.selected_image = state.image_files[evt.index]
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            raw = io.imread(state.selected_image)
        state.current_image = _normalize_loaded_image(raw)
        _clear_segmentation()

        categories_text = get_image_categories()
        fig = create_visualization(show_numbers=False)
        return (fig, f"Loaded: {os.path.basename(state.selected_image)}",
                categories_text, gr.update(choices=[], value=None))
    except Exception as e:
        # Do not keep segmenting the previous image under the new name.
        state.current_image = None
        _clear_segmentation()
        return None, f"Load failed: {str(e)}", "", gr.update(choices=[], value=None)


def get_image_categories():
    """Return a text listing of the selected image's CSV categories, or ""."""
    if not state.image_categories or not state.selected_image:
        return ""
    categories = state.image_categories.get(extract_image_id(state.selected_image))
    if not categories:
        return ""
    lines = ["Image Categories", "=" * 30]
    lines.extend(f"[{idx}] {name}"
                 for idx, name in zip(categories['indices'], categories['names']))
    return "\n".join(lines) + "\n"


def _parse_diameter(diameter):
    """Map the diameter textbox to a positive float, or None for "auto"/invalid."""
    text = "" if diameter is None else str(diameter).strip().lower()
    if text in ("", "auto"):
        return None
    try:
        value = float(text)
    except ValueError:
        return None
    if not np.isfinite(value) or value <= 0:
        return None
    return value


def _no_detections(message):
    """Discard stale results and show the plain image together with *message*."""
    _clear_segmentation()
    return create_visualization(show_numbers=False), message, gr.update(choices=[], value=None)


def run_cellpose_segmentation(model_type, diameter, use_gpu):
    """Segment ``state.current_image`` with Cellpose.

    The model is rebuilt only when *model_type* or *use_gpu* changes.

    Returns:
        (figure, status text, dropdown update).
    """
    if state.current_image is None:
        return None, "No image selected", gr.update(choices=[])
    if not CELLPOSE_AVAILABLE:
        return None, "Cellpose not installed", gr.update(choices=[])

    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model_key = (model_type, bool(use_gpu))
            if state.cellpose_model is None or state.cellpose_model_key != model_key:
                state.cellpose_model = models.CellposeModel(gpu=use_gpu, model_type=model_type)
                state.cellpose_model_key = model_key
                state.current_model_type = model_type

            masks, _flows, _styles = state.cellpose_model.eval(
                state.current_image, diameter=_parse_diameter(diameter), channels=[0, 0]
            )
            if masks is None or masks.max() == 0:
                return _no_detections("No cells detected")
            state.masks = masks
            return finalize_segmentation()
    except Exception as e:
        return None, f"Error: {str(e)}", gr.update(choices=[])


def run_yolo_segmentation(model_path, confidence, iou, use_gpu):
    """Segment ``state.current_image`` with an Ultralytics YOLO segmentation model.

    Returns:
        (figure, status text, dropdown update).
    """
    if state.current_image is None:
        return None, "No image selected", gr.update(choices=[])
    if not YOLO_AVAILABLE:
        return None, "YOLO not installed", gr.update(choices=[])

    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            if state.yolo_model is None or state.yolo_model_key != model_path:
                state.yolo_model = YOLO(model_path)
                state.yolo_model_key = model_path
                state.current_model_type = model_path

            # Ultralytics treats numpy input as BGR (OpenCV order); flip our RGB.
            yolo_image = np.ascontiguousarray(
                prepare_image_for_yolo(state.current_image)[..., ::-1]
            )
            results = state.yolo_model.predict(
                yolo_image,
                conf=confidence,
                iou=iou,
                device='cuda' if use_gpu else 'cpu',
                verbose=False,
                # Masks at original-image resolution, so no letterbox misalignment.
                retina_masks=True,
            )
            masks = yolo_results_to_masks(results[0])
            if masks is None or masks.max() == 0:
                return _no_detections("No objects detected")
            state.masks = masks
            return finalize_segmentation()
    except Exception as e:
        return None, f"Error: {str(e)}", gr.update(choices=[])


def yolo_results_to_masks(result):
    """Merge a YOLO result's instance masks into one int32 label image.

    Labels start at 1 in detection order; where instances overlap, the later
    one wins. Returns None when the result carries no masks.
    """
    if result.masks is None:
        return None
    h, w = state.current_image.shape[:2]
    combined_mask = np.zeros((h, w), dtype=np.int32)
    for label, mask in enumerate(result.masks.data.cpu().numpy(), start=1):
        if mask.shape != (h, w):
            mask = np.array(
                Image.fromarray(mask.astype(np.float32)).resize((w, h), Image.NEAREST)
            )
        combined_mask[mask > 0.5] = label
    return combined_mask


def finalize_segmentation():
    """Measure ``state.masks`` and build the overlay figure and cell list.

    Returns:
        (figure, status text, dropdown update listing the detected cells).
    """
    try:
        intensity = _to_gray_uint8(state.current_image)
        state.cell_properties = measure.regionprops(state.masks, intensity_image=intensity)
        # Labels from a previous run no longer refer to the same cells.
        state.selected_cell = None

        fig = create_visualization(show_numbers=False)
        cell_list = [f"Cell {prop.label} | Area: {prop.area}px²" for prop in state.cell_properties]
        # Count regions, not max label: overlapping YOLO masks can hide labels.
        return (fig, f"{len(state.cell_properties)} cells detected",
                gr.update(choices=cell_list, value=None))
    except Exception as e:
        return None, f"Error: {str(e)}", gr.update(choices=[])


def _render_figure(draw):
    """Create an 8x8 figure, let ``draw(ax)`` fill it, and return the Figure."""
    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        draw(ax)
        ax.axis('off')
        fig.tight_layout()
    finally:
        # gr.Plot renders the Figure object itself; closing only unregisters it
        # from pyplot so figures do not pile up across callbacks.
        plt.close(fig)
    return fig


def _draw_original(ax):
    """Draw the unsegmented current image on *ax*."""
    if state.current_image.ndim == 2:
        ax.imshow(state.current_image, cmap='gray')
    else:
        ax.imshow(state.current_image)
    ax.set_title('Original Image')


def _draw_overlay(ax, show_numbers, highlight_cell):
    """Draw label colours, red outlines, optional numbers and a highlighted cell."""
    display_img = _to_gray_uint8(state.current_image)
    ax.imshow(color.label2rgb(state.masks, display_img, bg_label=0, alpha=0.4))

    outline_img = np.zeros((*state.masks.shape, 4))
    outline_img[segmentation.find_boundaries(state.masks, mode='outer')] = [1, 0, 0, 1]
    ax.imshow(outline_img)

    if show_numbers and state.cell_properties:
        for prop in state.cell_properties:
            cy, cx = prop.centroid
            ax.text(cx, cy, str(prop.label), color='yellow', fontsize=8,
                    fontweight='bold', ha='center', va='center',
                    bbox=dict(boxstyle='round,pad=0.3', facecolor='black',
                              alpha=0.5, edgecolor='yellow', linewidth=1))

    if highlight_cell is not None:
        highlight_img = np.zeros((*state.masks.shape, 4))
        cell_outline = segmentation.find_boundaries(state.masks == highlight_cell, mode='outer')
        highlight_img[cell_outline] = [1, 1, 0, 1]
        ax.imshow(highlight_img)
        prop = next((p for p in state.cell_properties if p.label == highlight_cell), None)
        if prop is not None:
            minr, minc, maxr, maxc = prop.bbox
            ax.add_patch(Rectangle((minc, minr), maxc - minc, maxr - minr,
                                   fill=False, edgecolor='yellow', linewidth=2))

    ax.set_title(f'Segmentation Overlay ({len(state.cell_properties)} cells)')


def create_visualization(show_numbers=False, highlight_cell=None):
    """Return a figure of the current image, with the segmentation overlay if any.

    Args:
        show_numbers: draw each cell's label at its centroid.
        highlight_cell: label of a cell to outline in yellow with its bbox.

    Returns:
        A matplotlib Figure, or None when no image is loaded or drawing fails.
    """
    if state.current_image is None:
        return None
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            if state.masks is None:
                return _render_figure(_draw_original)
            return _render_figure(lambda ax: _draw_overlay(ax, show_numbers, highlight_cell))
    except Exception as e:
        print(f"Visualization error: {e}")
        return None


def toggle_view(view_type, show_numbers):
    """Show the plain image ("Original") or the segmentation overlay."""
    if view_type == "Original" and state.masks is not None:
        return _render_figure(_draw_original)
    return create_visualization(show_numbers=show_numbers, highlight_cell=state.selected_cell)


def toggle_cell_numbers(show_numbers, view_type="Overlay"):
    """Redraw after the "Show Cell Numbers" checkbox changes, honouring the view mode.

    Before segmentation this shows the plain image instead of blanking the plot.
    """
    return toggle_view(view_type, show_numbers)


def select_cell(cell_choice, show_numbers=False):
    """Highlight the cell chosen in the dropdown and describe its measurements.

    An empty choice (e.g. the dropdown being reset programmatically) leaves the
    figure and details untouched.

    Returns:
        (figure or gr.update(), details text or gr.update()).
    """
    if not cell_choice or not state.cell_properties:
        return gr.update(), gr.update()
    try:
        # Choice strings look like "Cell X | Area: Ypx²".
        cell_id = int(cell_choice.split('|', 1)[0].replace('Cell', '').strip())
        prop = next((p for p in state.cell_properties if p.label == cell_id), None)
        if prop is None:
            return gr.update(), "Cell not found"
        state.selected_cell = cell_id

        details = f"Cell {cell_id}\n"
        details += "=" * 25 + "\n"
        details += f"Area: {prop.area}px²\n"
        details += f"Centroid: ({prop.centroid[1]:.0f}, {prop.centroid[0]:.0f})\n"
        details += f"Eccentricity: {prop.eccentricity:.3f}\n"
        details += f"Solidity: {prop.solidity:.3f}\n"
        details += f"Intensity: {prop.mean_intensity:.1f}\n"
        categories = get_image_categories()
        if categories:
            details += "\n" + categories

        fig = create_visualization(show_numbers=show_numbers, highlight_cell=cell_id)
        return fig, details
    except Exception as e:
        return gr.update(), f"Error: {str(e)}"


def run_segmentation(method, cp_model, diameter, yolo_model, confidence, iou, use_gpu):
    """Dispatch to Cellpose or YOLO segmentation depending on *method*."""
    if method == "Cellpose":
        return run_cellpose_segmentation(cp_model, diameter, use_gpu)
    return run_yolo_segmentation(yolo_model, confidence, iou, use_gpu)


def save_results():
    """Write the label masks (.npy) and per-cell measurements (.csv) to a temp dir.

    Returns:
        ([mask_path, csv_path], status) on success, otherwise (None, message).
    """
    if state.masks is None:
        return None, "No results to save"
    try:
        temp_dir = tempfile.mkdtemp()
        base_name = Path(state.selected_image).stem if state.selected_image else "segmentation"

        mask_path = os.path.join(temp_dir, f"{base_name}_masks.npy")
        np.save(mask_path, state.masks)

        csv_path = os.path.join(temp_dir, f"{base_name}_measurements.csv")
        with open(csv_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f, lineterminator='\n')
            writer.writerow(["ID", "Area", "Centroid_X", "Centroid_Y",
                             "Eccentricity", "Solidity", "Mean_Intensity"])
            for prop in state.cell_properties:
                writer.writerow([
                    prop.label,
                    prop.area,
                    f"{prop.centroid[1]:.1f}",
                    f"{prop.centroid[0]:.1f}",
                    f"{prop.eccentricity:.3f}",
                    f"{prop.solidity:.3f}",
                    f"{prop.mean_intensity:.1f}",
                ])
        return [mask_path, csv_path], "Results saved"
    except Exception as e:
        return None, f"Error: {str(e)}"


# Initialize: Load CSV and scan folder
load_csv_data()
initial_gallery = scan_folder()

# Create Gradio interface
with gr.Blocks(title="Cell Segmentation Tool", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Cell Segmentation Application")

    with gr.Row():
        # LEFT COLUMN - Image Gallery
        with gr.Column(scale=1):
            gr.Markdown("### Image Gallery")
            image_gallery = gr.Gallery(
                value=initial_gallery,
                label=f"{len(state.image_files)} images" if state.image_files else "No images",
                show_label=True,
                elem_id="gallery",
                columns=1,
                rows=None,
                height=600,
                object_fit="contain",
            )
            status_text = gr.Textbox(label="Status", interactive=False)

        # CENTER COLUMN - Image View
        with gr.Column(scale=2):
            gr.Markdown("### Image View")
            with gr.Row():
                view_mode = gr.Radio(["Original", "Overlay"], value="Overlay", label="View Mode")
                show_numbers = gr.Checkbox(label="Show Cell Numbers", value=False)
            image_display = gr.Plot(label="")

        # RIGHT COLUMN - Controls & Results
        with gr.Column(scale=1):
            gr.Markdown("### Segmentation Settings")
            method = gr.Radio(["Cellpose", "YOLO"], label="Method", value="Cellpose")

            # Cellpose controls
            with gr.Group(visible=True) as cellpose_group:
                cp_model = gr.Dropdown(
                    ["nuclei", "cyto", "cyto2", "cyto3"],
                    label="Cellpose Model",
                    value="nuclei",
                )
                diameter = gr.Textbox(label="Diameter", value="auto")

            # YOLO controls
            with gr.Group(visible=False) as yolo_group:
                yolo_model = gr.Textbox(label="YOLO Model", value="yolov8n-seg.pt")
                confidence = gr.Slider(0, 1, value=0.25, label="Confidence")
                iou = gr.Slider(0, 1, value=0.45, label="IoU")

            use_gpu = gr.Checkbox(label="Use GPU", value=False)
            run_button = gr.Button("Run Segmentation", variant="primary", size="lg")

            gr.Markdown("### Detected Cells")
            cell_dropdown = gr.Dropdown(label="Select Cell", choices=[], interactive=True)

            gr.Markdown("### Cell Details")
            cell_details = gr.Textbox(label="", lines=12, interactive=False)

            save_button = gr.Button("Save Results", variant="secondary")
            output_files = gr.File(label="Download", file_count="multiple")

    # Event handlers
    def toggle_method(method_choice):
        """Show only the settings group belonging to the chosen method."""
        return (
            gr.update(visible=method_choice == "Cellpose"),
            gr.update(visible=method_choice == "YOLO"),
        )

    method.change(toggle_method, inputs=[method], outputs=[cellpose_group, yolo_group])

    image_gallery.select(
        select_image_from_gallery,
        outputs=[image_display, status_text, cell_details, cell_dropdown],
    )

    view_mode.change(toggle_view, inputs=[view_mode, show_numbers], outputs=[image_display])

    show_numbers.change(
        toggle_cell_numbers,
        inputs=[show_numbers, view_mode],
        outputs=[image_display],
    )

    run_button.click(
        run_segmentation,
        inputs=[method, cp_model, diameter, yolo_model, confidence, iou, use_gpu],
        outputs=[image_display, status_text, cell_dropdown],
    )

    cell_dropdown.change(
        select_cell,
        inputs=[cell_dropdown, show_numbers],
        outputs=[image_display, cell_details],
    )

    save_button.click(save_results, outputs=[output_files, status_text])

if __name__ == "__main__":
    demo.launch(share=False)