import gradio as gr
import spaces
from cellpose import models
import numpy as np
import cv2
import matplotlib.pyplot as plt
import tempfile
from PIL import Image
import io
from huggingface_hub import hf_hub_download

HF_REPO_ID = "myang4218/cellposemodel"
MODEL_OPTIONS = {
    "Hemocytometer Model": "hemocytometermodel.npy",
    "General Model": "generalmodel.npy",
}

# Cellpose models already instantiated in this process, keyed by weight filename.
loaded_models = {}

# Longest image side (pixels) handed to Cellpose. Larger inputs are downscaled
# first to prevent crashes; keep this limit unless that has been re-tested.
MAX_IMAGE_SIDE = 1024

# Margin (pixels) added around the bounding box of the user's selection strokes.
SELECTION_PAD = 5

# PIL modes whose raw arrays are not usable as gray/RGB/RGBA images
# (1-bit, palette, palette+alpha, gray+alpha, CMYK); converted to RGB first.
_MODES_NEEDING_RGB = ("1", "P", "PA", "LA", "CMYK")

# Blueness model in OpenCV 8-bit HSV space, where hue spans 0-179.
BLUE_HUE = 115        # hue treated as "pure" blue
BLUE_HUE_WIDTH = 65   # hue distance at which the hue score falls to zero

DEAD_COLOR = (255, 0, 0)   # red
ALIVE_COLOR = (0, 255, 0)  # green
OVERLAY_ALPHA = 0.4


# ---------------------------------------------------------------------------
# Image / selection helpers
# ---------------------------------------------------------------------------

def _image_to_array(image):
    """Convert an editor image to a NumPy array, expanding PIL modes that do not map to gray/RGB/RGBA."""
    if isinstance(image, Image.Image) and image.mode in _MODES_NEEDING_RGB:
        image = image.convert("RGB")
    return np.array(image)


def _layer_selection_mask(layer, shape):
    """Return a boolean mask of the painted pixels of one editor layer, or None if its size is not ``shape``."""
    layer_np = np.asarray(layer)
    if layer_np.shape[:2] != tuple(shape):
        return None
    if layer_np.ndim == 3:
        if layer_np.shape[2] == 4:
            # RGBA: any pixel with non-zero alpha was painted.
            return layer_np[:, :, 3] > 0
        return np.any(layer_np > 0, axis=2)
    return layer_np > 0


def extract_region_from_editor(editor_data):
    """Extract the user-selected region from ``gr.ImageEditor`` data.

    The selection is the bounding box of everything painted on any editor
    layer, padded by ``SELECTION_PAD`` pixels and clamped to the image.

    Args:
        editor_data: The ImageEditor value (a dict with ``background`` and
            ``layers`` entries), a bare image object, or None.

    Returns:
        ``(region, coords)``. ``region`` is a NumPy image array, or None when
        no image was supplied. ``coords`` is ``(x_min, y_min, x_max, y_max)``,
        the inclusive pixel bounds of the crop, when a selection was drawn.
        Otherwise ``coords`` is None and ``region`` is the whole image.
    """
    if editor_data is None:
        return None, None

    if not isinstance(editor_data, dict):
        # A bare image rather than an ImageEditor payload.
        if hasattr(editor_data, "size"):
            return _image_to_array(editor_data), None
        return None, None

    background = editor_data.get("background")
    if background is None:
        return None, None
    background_np = _image_to_array(background)
    h, w = background_np.shape[:2]

    # Union of the strokes on every layer, so drawing on an added layer counts too.
    selection = np.zeros((h, w), dtype=bool)
    for layer in editor_data.get("layers") or []:
        painted = _layer_selection_mask(layer, (h, w))
        if painted is not None:
            selection |= painted

    rows, cols = np.nonzero(selection)
    if rows.size == 0:
        return background_np, None

    # Pad the stroke bounding box and clamp it to valid pixel indices.
    y_min = max(0, int(rows.min()) - SELECTION_PAD)
    y_max = min(h - 1, int(rows.max()) + SELECTION_PAD)
    x_min = max(0, int(cols.min()) - SELECTION_PAD)
    x_max = min(w - 1, int(cols.max()) + SELECTION_PAD)
    region = background_np[y_min:y_max + 1, x_min:x_max + 1]
    return region, (x_min, y_min, x_max, y_max)


def _to_rgb(image_np):
    """Return ``image_np`` as a 3-channel RGB array: gray is replicated, alpha is dropped."""
    if image_np.ndim == 2 or (image_np.ndim == 3 and image_np.shape[2] == 1):
        return cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
    if image_np.ndim == 3 and image_np.shape[2] == 4:
        return cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)
    return image_np


# ---------------------------------------------------------------------------
# Label-image helpers
# ---------------------------------------------------------------------------

def _label_areas(masks):
    """Return ``(labels, areas)``: the positive label ids present in ``masks`` and their pixel counts."""
    flat = np.asarray(masks).ravel().astype(np.intp, copy=False)
    counts = np.bincount(flat, minlength=1)
    labels = np.flatnonzero(counts)
    labels = labels[labels > 0]  # drop background (0)
    return labels, counts[labels]


def _keep_labels(masks, keep):
    """Zero every label not in ``keep`` and renumber the kept labels 1..len(keep) in ascending order."""
    masks = np.asarray(masks)
    if masks.size == 0:
        return masks.copy()
    # Lookup table old_id -> new_id; ids not kept map to 0 (background).
    lookup = np.zeros(int(masks.max()) + 1, dtype=masks.dtype)
    lookup[keep] = np.arange(1, len(keep) + 1)
    return lookup[masks]


def _filter_masks_by_area(masks, reject):
    """Remove labels whose pixel area satisfies ``reject(areas)`` and renumber the rest.

    Returns:
        ``(renumbered_masks, removed_count)``.
    """
    labels, areas = _label_areas(masks)
    rejected = reject(areas)
    renumbered = _keep_labels(masks, labels[~rejected])
    return renumbered, int(np.count_nonzero(rejected))


def _blueness_map(image_rgb):
    """Per-pixel blueness in [0, 1]: closeness of hue to BLUE_HUE times saturation.

    Assumes an 8-bit RGB image, so OpenCV's hue lies in 0-179 and wraps at 180.
    """
    hsv = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2HSV)
    hue = hsv[:, :, 0].astype(np.float32)
    saturation = hsv[:, :, 1].astype(np.float32)
    # Circular hue distance to blue, mapped linearly to a 1..0 score.
    offset = np.abs(hue - BLUE_HUE)
    hue_distance = np.minimum(offset, 180 - offset)
    hue_score = np.maximum(0, 1 - hue_distance / BLUE_HUE_WIDTH)
    return hue_score * (saturation / 255.0)


# ---------------------------------------------------------------------------
# Analysis functions
# ---------------------------------------------------------------------------

def classify_cells_by_blueness(image_np, masks, blue_threshold):
    """Classify segmented cells as dead (blue) or alive from their mean blueness.

    A cell is dead when the mean of its per-pixel blueness (hue closeness to
    blue times saturation, in [0, 1]) exceeds ``blue_threshold / 100``.

    Args:
        image_np: Image array (gray, RGB or RGBA); assumed 8-bit.
        masks: Integer label image with the same height/width as ``image_np``.
            0 is background and each positive value is one cell.
        blue_threshold: Threshold on a 0-100 scale.

    Returns:
        ``(dead_count, alive_count, overlay)``. ``overlay`` is a uint8 RGB
        image in which dead cells are tinted red and live cells green.
    """
    image_rgb = _to_rgb(image_np)
    blueness = _blueness_map(image_rgb)
    threshold = blue_threshold / 100.0

    # Per-label pixel counts and blueness sums in one pass each, rather than
    # building a full-image boolean mask per cell (runs on every slider move).
    labels = np.asarray(masks).astype(np.intp, copy=False)
    flat = labels.ravel()
    pixel_counts = np.bincount(flat, minlength=1)
    blue_totals = np.bincount(flat, weights=blueness.ravel(), minlength=pixel_counts.size)
    cell_ids = np.flatnonzero(pixel_counts)
    cell_ids = cell_ids[cell_ids > 0]
    is_dead = blue_totals[cell_ids] / pixel_counts[cell_ids] > threshold

    # Color per label: background stays untouched, cells get red/green.
    palette = np.zeros((pixel_counts.size, 3), dtype=np.float32)
    palette[cell_ids[is_dead]] = DEAD_COLOR
    palette[cell_ids[~is_dead]] = ALIVE_COLOR

    overlay = np.clip(image_rgb, 0, 255).astype(np.uint8)
    in_cell = labels > 0
    blended = (
        (1 - OVERLAY_ALPHA) * image_rgb[in_cell].astype(np.float32)
        + OVERLAY_ALPHA * palette[labels[in_cell]]
    )
    overlay[in_cell] = np.clip(blended, 0, 255).astype(np.uint8)

    dead_count = int(np.count_nonzero(is_dead))
    return dead_count, int(cell_ids.size) - dead_count, overlay


def measure_confluency(masks, image_np):
    """Return the percentage of the image's pixels covered by any cell (non-zero mask)."""
    total_pixels = image_np.shape[0] * image_np.shape[1]
    if total_pixels == 0:
        return 0.0
    return np.count_nonzero(masks) / total_pixels * 100


def filter_mask_by_size(masks, minimum_pixels):
    """Remove cells smaller than ``minimum_pixels`` and renumber the rest from 1.

    Returns:
        ``(renumbered_masks, removed_count)``.
    """
    return _filter_masks_by_area(masks, lambda areas: areas < minimum_pixels)


def filter_mask_by_maxsize(masks, maximum_pixels):
    """Remove cells larger than ``maximum_pixels`` and renumber the rest from 1.

    Returns:
        ``(renumbered_masks, removed_count)``.
    """
    return _filter_masks_by_area(masks, lambda areas: areas > maximum_pixels)


def rec_min_size(masks, q=25):
    """Recommend a minimum cell size: the ``q``-th percentile of cell areas, rounded to int (0 if no cells)."""
    _, areas = _label_areas(masks)
    if areas.size == 0:
        return 0
    return int(round(np.percentile(areas, q)))


def _get_model(model_choice):
    """Return the Cellpose model for a UI model name, downloading and caching it on first use."""
    model_filename = MODEL_OPTIONS[model_choice]
    model = loaded_models.get(model_filename)
    if model is None:
        model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=model_filename)
        model = models.CellposeModel(gpu=True, pretrained_model=model_path)
        loaded_models[model_filename] = model
    return model


def _downscale_to_limit(image_np, max_side=MAX_IMAGE_SIDE):
    """Shrink ``image_np`` (keeping aspect ratio) so neither side exceeds ``max_side``."""
    h, w = image_np.shape[:2]
    if h <= max_side and w <= max_side:
        return image_np
    if h > w:
        new_h, new_w = max_side, max(1, int(w * max_side / h))
    else:
        new_h, new_w = max(1, int(h * max_side / w)), max_side
    return cv2.resize(image_np, (new_w, new_h), interpolation=cv2.INTER_AREA)


def _segmentation_overlay(image_rgb, masks, alpha=OVERLAY_ALPHA, seed=42):
    """Blend a randomly colored label map over the image (background blended with black)."""
    overlay = image_rgb.astype(np.float32)
    max_label = int(masks.max()) if masks.size else 0
    if max_label > 0:
        # A private RandomState yields the same colors as np.random.seed(seed)
        # would, without mutating numpy's global RNG.
        colors = np.random.RandomState(seed).randint(0, 255, size=(max_label + 1, 3))
        colors[0] = [0, 0, 0]  # background color
        overlay = (1 - alpha) * overlay + alpha * colors[masks]
    return np.clip(overlay, 0, 255).astype(np.uint8)


@spaces.GPU
def run_segmentation_editor(editor_data, model_choice, min_cell_size, max_cell_size):
    """Run Cellpose segmentation on the ImageEditor selection.

    Args:
        editor_data: Value of the ``gr.ImageEditor`` component.
        model_choice: Key of ``MODEL_OPTIONS``.
        min_cell_size: Minimum cell area in pixels; 0 means "use
            ``rec_min_size`` of the raw segmentation".
        max_cell_size: Maximum cell area in pixels; 0 disables the limit.

    Returns:
        A 7-tuple ``(cell_count, overlay_image, info_message,
        viability_section_update, masks, processed_image, confluency)``.
        On failure it holds zeros/None and an error message.
    """
    try:
        region_np, region_coords = extract_region_from_editor(editor_data)
        if region_np is None:
            return 0, None, "No image provided.", gr.update(visible=False), None, None, 0.0

        model = _get_model(model_choice)
        processed_image_np = _to_rgb(_downscale_to_limit(region_np))

        masks_raw, _flows, _styles = model.eval(processed_image_np, diameter=None, channels=[0, 0])

        # Recommendation is computed from the RAW masks; slider value 0 means "auto".
        recommend_min = rec_min_size(masks_raw)
        min_used = recommend_min if min_cell_size == 0 else int(min_cell_size)

        masks = masks_raw.copy()
        removed_small = removed_large = 0
        if min_used > 0:
            masks, removed_small = filter_mask_by_size(masks, min_used)
        if max_cell_size > 0:
            masks, removed_large = filter_mask_by_maxsize(masks, int(max_cell_size))

        filter_msg = ""
        if min_cell_size == 0:
            filter_msg += f"Minimum cell size: {min_used} pixels (auto: 25th percentile of detected sizes).\n"
        if removed_small:
            filter_msg += f"Removed {removed_small} small objects (< {min_used} pixels).\n"
        if removed_large:
            filter_msg += f"Removed {removed_large} large objects (> {int(max_cell_size)} pixels).\n"

        # Count distinct positive labels (0 is background).
        cell_count = int(np.count_nonzero(np.unique(masks)))
        confluency = measure_confluency(masks, processed_image_np)
        segmentation_overlay = _segmentation_overlay(processed_image_np, masks)

        info_msg = f"Segmentation complete! Found {cell_count} cells.\n"
        info_msg += f"Confluency: {confluency:.1f}%\n"
        if region_coords:
            info_msg += (
                f"Processed region: {region_coords[0]},{region_coords[1]} "
                f"to {region_coords[2]},{region_coords[3]}\n"
            )
        info_msg += filter_msg
        info_msg += "Now adjust the Blue Threshold for viability assessment."

        return (
            cell_count,
            Image.fromarray(segmentation_overlay),
            info_msg,
            gr.update(visible=True),
            masks,
            processed_image_np,
            confluency,
        )
    except Exception as e:
        # UI boundary: surface the error in the info box instead of crashing the event.
        return 0, None, f"Error during segmentation: {str(e)}", gr.update(visible=False), None, None, 0.0


def update_viability_realtime(blue_threshold, stored_masks, stored_image_np):
    """Re-run the live/dead classification for a new blue threshold.

    Args:
        blue_threshold: Slider value on a 0-100 scale.
        stored_masks: Label image kept in ``gr.State`` by the last segmentation.
        stored_image_np: Processed RGB image kept in ``gr.State``.

    Returns:
        ``(overlay_image, alive_count, dead_count, viability_percent, info_message)``.
    """
    if stored_masks is None or stored_image_np is None:
        return None, 0, 0, 0.0, "Please run segmentation first."

    try:
        dead_count, alive_count, viability_overlay_np = classify_cells_by_blueness(
            stored_image_np, stored_masks, blue_threshold
        )
        total_count = alive_count + dead_count
        viability_percent = (alive_count / total_count * 100) if total_count > 0 else 0.0
        confluency = measure_confluency(stored_masks, stored_image_np)

        overlay_image = Image.fromarray(viability_overlay_np)
        info_msg = f"Total cells: {total_count}\nLive (green): {alive_count}\nDead (red): {dead_count}\n"
        info_msg += (
            f"Viability: {viability_percent:.1f}%\nConfluency: {confluency:.1f}%\n"
            f"Blue threshold: {blue_threshold}%"
        )
        return overlay_image, alive_count, dead_count, viability_percent, info_msg
    except Exception as e:
        return None, 0, 0, 0.0, f"Error updating viability: {str(e)}"


INSTRUCTIONS_MD = f"""
### How to use:
1. **Upload and Segment**:
   - Upload your microscopy image.
   - Select a Cellpose model (e.g., "Hemocytometer Model" for blood cells).
   - Optionally draw over the region of interest in the Image Editor: the bounding box of your strokes (plus a {SELECTION_PAD}-pixel margin) is analyzed. If nothing is drawn, the whole image is used.
   - Click "Run Segmentation".
2. **Analysis Results**:
   - **Cell Count**: Total number of detected cells
   - **Confluency**: Percentage of image area covered by cells (useful for assessing cell density). Note that confluency is calculated over the entire analyzed area: the selected region if you drew one, otherwise the whole uploaded image.
3. **Real-time Viability Assessment (Trypan Blue)**:
   - After segmentation, the viability section will become visible.
   - This tool is specifically designed for **Trypan Blue stained images**, where dead cells appear blue.
   - Adjust the **"Blue Threshold (%)"** slider in real-time. As you change it, the green (live) and red (dead) classification on the overlay will update.
   - **Lower values (e.g., 10-20%)** are more sensitive and will classify more cells as blue/dead.
   - **Higher values (e.g., 30-50%)** are more selective and will only classify strongly blue cells as dead.
   - Green cells = Live, Red cells = Dead.
4. **Interpreting Results**:
   - The app calculates and displays the total, live, and dead cell counts, along with the viability percentage and confluency.
   - **Confluency** helps assess how densely packed your cells are, which is important for cell culture monitoring.
"""


# Create the Gradio interface
with gr.Blocks(
    title="CellposeCellCounter",
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown("# CellposeCellCounter")
    gr.Markdown("For accurate cell confluency, crop the image to display only desired area.")

    # Per-session storage of the latest masks and processed image, so the
    # viability slider can re-classify without re-running Cellpose.
    masks_state = gr.State(value=None)
    image_state = gr.State(value=None)

    with gr.Tab("Image Editor (Draw Selection)"):
        gr.Markdown("### Draw selection and run segmentation")
        with gr.Row():
            with gr.Column():
                image_editor = gr.ImageEditor(
                    label="Draw selection on image",
                    type="pil",
                    brush=gr.Brush(colors=["#ff0000"], color_mode="fixed", default_size=20),
                    eraser=gr.Eraser(default_size=20),
                )
                model_dropdown1 = gr.Dropdown(
                    choices=list(MODEL_OPTIONS.keys()),
                    label="Select Model",
                    value="Hemocytometer Model",
                )
                min_size_slider1 = gr.Slider(
                    minimum=0,
                    maximum=500,
                    value=50,
                    step=10,
                    label="Minimum Cell Size (pixels)",
                    info="0 = automatic (25th percentile of detected object sizes)",
                )
                max_size_slider1 = gr.Slider(
                    minimum=0,
                    maximum=10000,
                    value=10000,
                    step=10,
                    label="Maximum Cell Size (pixels)",
                    info="0 = no upper limit",
                )
                segment_btn1 = gr.Button("🔬 Run Segmentation", variant="primary", size="lg")

            with gr.Column():
                cell_count_output1 = gr.Number(label="Total Cells Detected", precision=0)
                confluency_output1 = gr.Number(label="Confluency (%)", precision=1)
                overlay_output1 = gr.Image(type="pil", label="Segmentation Result")
                info_output1 = gr.Textbox(label="Processing Info", lines=4)

        # Viability Assessment Section
        with gr.Group(visible=False) as viability_section1:
            gr.Markdown("### Viability Assessment (Trypan Blue)")
            gr.Markdown("Adjust the threshold to classify cells as live (green) or dead (red).")
            with gr.Row():
                with gr.Column():
                    blue_threshold1 = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=25,
                        step=1,
                        label="Blue Threshold (%)",
                        info="Higher values = more selective for blue cells",
                    )
                with gr.Column():
                    live_count_output1 = gr.Number(label="Live Cells (Green)", precision=0)
                    dead_count_output1 = gr.Number(label="Dead Cells (Red)", precision=0)
            viability_overlay1 = gr.Image(type="pil", label="Viability Assessment (Green=Live, Red=Dead)")
            viability_percent_output1 = gr.Number(label="Viability (%)", precision=1)
            viability_info1 = gr.Textbox(label="Analysis Results", lines=5)

        viability_outputs = [
            viability_overlay1,
            live_count_output1,
            dead_count_output1,
            viability_percent_output1,
            viability_info1,
        ]

        # Event handlers.
        # The outputs list must match the 7 values returned by
        # run_segmentation_editor; Gradio rejects handlers that return fewer
        # values than there are outputs.
        segment_btn1.click(
            fn=run_segmentation_editor,
            inputs=[image_editor, model_dropdown1, min_size_slider1, max_size_slider1],
            outputs=[
                cell_count_output1,
                overlay_output1,
                info_output1,
                viability_section1,
                masks_state,
                image_state,
                confluency_output1,
            ],
        ).then(
            # Chain the initial viability assessment after segmentation.
            fn=update_viability_realtime,
            inputs=[blue_threshold1, masks_state, image_state],
            outputs=viability_outputs,
        )

        # Slider changes update viability in real-time.
        blue_threshold1.change(
            fn=update_viability_realtime,
            inputs=[blue_threshold1, masks_state, image_state],
            outputs=viability_outputs,
        )

    # Instructions
    with gr.Accordion("Instructions", open=False):
        gr.Markdown(INSTRUCTIONS_MD)


if __name__ == "__main__":
    demo.launch()