Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import spaces | |
| from cellpose import models | |
| import numpy as np | |
| import cv2 | |
| import matplotlib.pyplot as plt | |
| import tempfile | |
| from PIL import Image | |
| import io | |
| from huggingface_hub import hf_hub_download | |
# Hugging Face Hub repository that hosts the custom Cellpose weight files.
HF_REPO_ID = "myang4218/cellposemodel"

# Dropdown label -> weight file name inside HF_REPO_ID. Insertion order is the
# order in which the labels appear in the model dropdown.
MODEL_OPTIONS = dict(
    [
        ("Hemocytometer Model", "hemocytometermodel.npy"),
        ("General Model", "generalmodel.npy"),
    ]
)

# In-process cache of instantiated Cellpose models, keyed by weight file name.
loaded_models = dict()
def _layer_selection_mask(layer):
    """Return a boolean mask of the pixels painted on one ImageEditor layer."""
    layer_np = np.array(layer)
    if layer_np.ndim == 3:
        if layer_np.shape[2] == 4:  # RGBA: painted wherever alpha is non-zero
            return layer_np[:, :, 3] > 0
        return np.any(layer_np > 0, axis=2)  # RGB: painted where any channel is set
    return layer_np > 0


def extract_region_from_editor(editor_data):
    """Extract the selected region from ImageEditor data.

    The selection is the union of everything painted on all editor layers.
    The returned region is the bounding box of that selection, padded by
    5 pixels on every side and clamped to the background image.

    Args:
        editor_data: Value of a ``gr.ImageEditor``. Either a dict with a
            ``'background'`` image and an optional ``'layers'`` list, or a bare
            image object exposing ``.size`` (e.g. a PIL image).

    Returns:
        ``(region, coords)``:
        - ``region`` is a numpy array.
        - ``coords`` is ``(x_min, y_min, x_max, y_max)``, inclusive pixel bounds
          inside the background.
        - When nothing was painted, the whole background is returned with
          ``coords=None``.
        - ``(None, None)`` is returned when there is no usable image.
    """
    if editor_data is None:
        return None, None

    if not isinstance(editor_data, dict):
        # Bare image (no editor layers): use it as-is.
        if hasattr(editor_data, "size"):
            return np.array(editor_data), None
        return None, None

    background = editor_data.get("background")
    if background is None:
        return None, None
    background_np = np.array(background)
    h, w = background_np.shape[:2]

    # Bounding box (y_min, y_max, x_min, x_max) of painted pixels over ALL
    # layers. It is computed per layer so that a layer whose size differs
    # from the others cannot break a combined boolean mask.
    bounds = None
    for layer in editor_data.get("layers") or []:
        if layer is None:
            continue
        ys, xs = np.nonzero(_layer_selection_mask(layer))
        if ys.size == 0:
            continue
        box = (int(ys.min()), int(ys.max()), int(xs.min()), int(xs.max()))
        if bounds is None:
            bounds = box
        else:
            bounds = (
                min(bounds[0], box[0]),
                max(bounds[1], box[1]),
                min(bounds[2], box[2]),
                max(bounds[3], box[3]),
            )

    if bounds is None:
        # Nothing drawn: analyze the whole image.
        return background_np, None

    pad = 5
    # Clamp to the last valid index (h-1 / w-1) so the reported inclusive
    # coordinates always lie inside the image.
    y_min = max(0, bounds[0] - pad)
    y_max = min(h - 1, bounds[1] + pad)
    x_min = max(0, bounds[2] - pad)
    x_max = min(w - 1, bounds[3] + pad)
    if y_min > y_max or x_min > x_max:
        # Selection lies entirely outside the background image.
        return background_np, None

    region = background_np[y_min:y_max + 1, x_min:x_max + 1]
    return region, (x_min, y_min, x_max, y_max)
def classify_cells_by_blueness(image_np, masks, blue_threshold):
    """
    Classify cells as dead (blue) or alive based on a single blueness metric.

    Per pixel, blueness = hue_score * saturation / 255. hue_score falls
    linearly from 1 at OpenCV hue 115 to 0 at a circular hue distance of 65.
    A cell is dead when the mean blueness over its pixels is strictly greater
    than blue_threshold / 100; otherwise it is alive.

    Args:
        image_np: RGB, RGBA or 2-D grayscale image array. The hue math assumes
            OpenCV's 8-bit hue range (0-179), i.e. presumably a uint8 image --
            TODO confirm for other dtypes.
        masks: Cellpose label image with the same height/width as image_np;
            0 (and any non-positive value) is background.
        blue_threshold: Threshold in percent (0-100) for blueness detection.

    Returns:
        dead_count, alive_count, colored_overlay. In the uint8 overlay, dead
        cells are tinted red and alive cells green (alpha 0.4), and background
        pixels keep (essentially) their original color.
    """
    # Ensure image_np is 3-channel RGB so the HSV conversion below is valid.
    if image_np.ndim == 2:
        image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
    elif image_np.ndim == 3 and image_np.shape[2] == 4:
        image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)

    hsv = cv2.cvtColor(image_np, cv2.COLOR_RGB2HSV)
    hue = hsv[:, :, 0].astype(np.float32)
    saturation = hsv[:, :, 1].astype(np.float32)

    # Circular distance to the blue hue 115 (hue wraps at 180 in 8-bit OpenCV).
    hue_offset = np.abs(hue - 115)
    hue_distance = np.minimum(hue_offset, 180 - hue_offset)
    hue_score = np.maximum(0, 1 - hue_distance / 65)  # 65 gives good blue range
    blueness = hue_score * (saturation / 255.0)

    threshold = blue_threshold / 100.0

    # Per-cell mean blueness in one pass. This replaces a full-image
    # `masks == id` comparison per cell; the function runs on every change
    # of the threshold slider.
    labels = np.asarray(masks)
    labels = np.where(labels > 0, labels, 0).astype(np.intp)
    n_labels = int(labels.max()) + 1 if labels.size else 1
    flat_labels = labels.ravel()
    pixel_counts = np.bincount(flat_labels, minlength=n_labels)
    blue_sums = np.bincount(flat_labels, weights=blueness.ravel(), minlength=n_labels)
    is_cell = pixel_counts > 0
    is_cell[0] = False  # label 0 is background
    mean_blueness = np.divide(
        blue_sums, pixel_counts, out=np.zeros(n_labels), where=is_cell
    )
    is_dead = is_cell & (mean_blueness > threshold)
    is_alive = is_cell & ~is_dead

    # Color cells via a per-label palette: red for dead, green for alive.
    palette = np.zeros((n_labels, 3), dtype=np.float32)
    palette[is_dead] = (255, 0, 0)
    palette[is_alive] = (0, 255, 0)
    base = image_np.astype(np.float32)
    overlay = base.copy()
    in_cell = labels > 0
    overlay[in_cell] = palette[labels[in_cell]]

    # Blend with the original image.
    alpha = 0.4
    final_overlay = (1 - alpha) * base + alpha * overlay
    final_overlay = np.clip(final_overlay, 0, 255).astype(np.uint8)
    return int(np.count_nonzero(is_dead)), int(np.count_nonzero(is_alive)), final_overlay
def measure_confluency(masks, image_np):
    """Calculate the percentage of image area covered by cells.

    Args:
        masks: Label image; every non-zero pixel counts as a cell pixel.
        image_np: Image the masks were computed on. Only its height and width
            (first two dimensions) are used as the reference area.

    Returns:
        Confluency in percent (0-100 when masks match the image size).
        Returns 0.0 for a zero-area image instead of raising ZeroDivisionError.
    """
    tot_pixels = image_np.shape[0] * image_np.shape[1]
    if tot_pixels == 0:
        return 0.0
    cell_pixels = np.count_nonzero(masks)
    return cell_pixels / tot_pixels * 100
def filter_mask_by_size(masks, minimum_pixels):
    """Remove labelled objects smaller than ``minimum_pixels`` and renumber the rest.

    Args:
        masks: Label image; 0 (and any non-positive value) is background.
        minimum_pixels: Objects with fewer pixels than this are removed.

    Returns:
        ``(renumbered_masks, removed_count)``. Surviving objects are renumbered
        1..K in ascending order of their original ids, the array keeps the
        input dtype, and ``removed_count`` is the number of objects dropped.
    """
    labels = np.asarray(masks)
    # One sort-based pass yields every label, its pixel count and a per-pixel
    # index into `ids`. This replaces a full-image comparison per label.
    ids, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
    is_object = ids > 0
    keep = is_object & (counts >= minimum_pixels)
    removed_count = int(np.count_nonzero(is_object & ~keep))
    # Lookup table old-id -> new-id; background and removed objects map to 0.
    new_ids = np.zeros(ids.shape, dtype=labels.dtype)
    new_ids[keep] = np.arange(1, np.count_nonzero(keep) + 1)
    renumbered_masks = new_ids[inverse].reshape(labels.shape)
    return renumbered_masks, removed_count
def filter_mask_by_maxsize(masks, maximum_pixels):
    """Remove labelled objects larger than ``maximum_pixels`` and renumber the rest.

    Args:
        masks: Label image; 0 (and any non-positive value) is background.
        maximum_pixels: Objects with more pixels than this are removed.

    Returns:
        ``(renumbered_masks, removed_count)``. Surviving objects are renumbered
        1..K in ascending order of their original ids, the array keeps the
        input dtype, and ``removed_count`` is the number of objects dropped.
    """
    labels = np.asarray(masks)
    # One sort-based pass yields every label, its pixel count and a per-pixel
    # index into `ids`. This replaces a full-image comparison per label.
    ids, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
    is_object = ids > 0
    keep = is_object & (counts <= maximum_pixels)
    removed_count = int(np.count_nonzero(is_object & ~keep))
    # Lookup table old-id -> new-id; background and removed objects map to 0.
    new_ids = np.zeros(ids.shape, dtype=labels.dtype)
    new_ids[keep] = np.arange(1, np.count_nonzero(keep) + 1)
    renumbered_masks = new_ids[inverse].reshape(labels.shape)
    return renumbered_masks, removed_count
def rec_min_size(masks, q=25):
    """Recommend a minimum object size as the q-th percentile of object areas.

    Args:
        masks: Label image; 0 is background and every positive value is one object.
        q: Percentile (0-100) of the per-object pixel counts to use.

    Returns:
        The percentile rounded with Python ``round`` to an int, or 0 when the
        masks contain no objects.
    """
    # Pixel count of every label in one pass (no per-label full-image scan).
    ids, counts = np.unique(masks, return_counts=True)
    sizes = counts[ids > 0]
    if sizes.size == 0:
        return 0
    return int(round(np.percentile(sizes, q)))
def _get_cellpose_model(model_choice):
    """Return the CellposeModel for a dropdown label, downloading/loading it once.

    Instances are cached in ``loaded_models`` keyed by weight file name. The
    Hub is only contacted on a cache miss.

    Raises:
        KeyError: if ``model_choice`` is not a key of ``MODEL_OPTIONS``.
    """
    model_filename = MODEL_OPTIONS[model_choice]
    model = loaded_models.get(model_filename)
    if model is None:
        model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=model_filename)
        model = models.CellposeModel(gpu=True, pretrained_model=model_path)
        loaded_models[model_filename] = model
    return model


def _prepare_rgb(region_np, max_size=1024):
    """Downscale so neither side exceeds ``max_size``, then convert to 3-channel RGB."""
    # Large images are resized to prevent crashes; change max_size with care.
    h, w = region_np.shape[:2]
    if h > max_size or w > max_size:
        # Keep aspect ratio; the longer side becomes max_size (never 0 px).
        if h > w:
            new_h, new_w = max_size, max(1, int(w * max_size / h))
        else:
            new_h, new_w = max(1, int(h * max_size / w)), max_size
        region_np = cv2.resize(region_np, (new_w, new_h), interpolation=cv2.INTER_AREA)
    if region_np.ndim == 2:
        return cv2.cvtColor(region_np, cv2.COLOR_GRAY2RGB)
    if region_np.ndim == 3 and region_np.shape[2] == 4:
        return cv2.cvtColor(region_np, cv2.COLOR_RGBA2RGB)
    return region_np


def _label_color_overlay(image_rgb, masks, alpha=0.4, seed=42):
    """Blend a reproducible random color per label onto ``image_rgb``.

    Label 0 is mapped to black, so background pixels are darkened by
    ``alpha``. A private RNG is used so the global NumPy random state is not
    reseeded.
    """
    blended = image_rgb.astype(np.float32)
    max_label = int(masks.max()) if masks.size else 0
    if max_label > 0:
        rng = np.random.default_rng(seed)
        colors = rng.integers(0, 255, size=(max_label + 1, 3))
        colors[0] = 0  # Background color
        blended = (1 - alpha) * blended + alpha * colors[masks]
    return np.clip(blended, 0, 255).astype(np.uint8)


def run_segmentation_editor(editor_data, model_choice, min_cell_size, max_cell_size):
    """
    Run Cellpose segmentation on the region selected in the ImageEditor.

    Args:
        editor_data: ``gr.ImageEditor`` value (see extract_region_from_editor).
        model_choice: Key of ``MODEL_OPTIONS`` selecting the weights.
        min_cell_size: Objects smaller than this many pixels are removed.
            0 means "use rec_min_size() of the raw segmentation".
        max_cell_size: Objects larger than this many pixels are removed.
            0 disables the upper limit.

    Returns:
        An 8-tuple matching the segment button's outputs:
        (cell_count, overlay PIL image or None, info text,
         viability-section visibility update, masks or None,
         RGB image or None, confluency %, min-size slider update).
        The last element is always ``gr.update()`` (slider left unchanged).
        Errors are reported in the info text rather than raised.
    """
    # NOTE(review): `spaces` is imported by this file but no @spaces.GPU
    # decorator is applied here. If this Space runs on ZeroGPU, confirm that
    # gpu=True actually gets a GPU.
    keep_slider = gr.update()
    try:
        region_np, region_coords = extract_region_from_editor(editor_data)
        if region_np is None:
            return 0, None, "No image provided.", gr.update(visible=False), None, None, 0.0, keep_slider

        model = _get_cellpose_model(model_choice)
        processed_image_np = _prepare_rgb(region_np)

        # Run Cellpose segmentation
        masks_raw, _flows, _styles = model.eval(processed_image_np, diameter=None, channels=[0, 0])

        # A slider value of 0 selects the recommendation computed from the RAW masks.
        min_used = rec_min_size(masks_raw) if min_cell_size == 0 else int(min_cell_size)

        # Apply filters
        masks = masks_raw
        removed_small = 0
        removed_large = 0
        if min_used > 0:
            masks, removed_small = filter_mask_by_size(masks, min_used)
        if max_cell_size > 0:
            masks, removed_large = filter_mask_by_maxsize(masks, int(max_cell_size))

        filter_msg = ""
        if removed_small:
            filter_msg += f"Removed {removed_small} small objects (< {min_used} pixels).\n"
        if removed_large:
            filter_msg += f"Removed {removed_large} large objects (> {int(max_cell_size)} pixels).\n"

        # Count distinct non-zero labels. Unlike `len(unique) - 1`, this stays
        # correct when no background pixel remains.
        cell_count = int(np.count_nonzero(np.unique(masks)))
        confluency = measure_confluency(masks, processed_image_np)

        # Basic segmentation overlay (without viability)
        segmentation_overlay = _label_color_overlay(processed_image_np, masks)

        info_msg = f"Segmentation complete! Found {cell_count} cells.\n"
        info_msg += f"Confluency: {confluency:.1f}%\n"
        info_msg += filter_msg
        if region_coords is not None:
            info_msg += f"Processed region: {region_coords[0]},{region_coords[1]} to {region_coords[2]},{region_coords[3]}\n"
        info_msg += "Now adjust the Blue Threshold for viability assessment."

        # Initial segmentation display plus the state consumed by the viability callbacks.
        return (
            cell_count,
            Image.fromarray(segmentation_overlay),
            info_msg,
            gr.update(visible=True),
            masks,
            processed_image_np,
            confluency,
            keep_slider,
        )
    except Exception as e:
        # Top-level UI boundary: surface the error in the info box.
        return 0, None, f"Error during segmentation: {str(e)}", gr.update(visible=False), None, None, 0.0, keep_slider
def update_viability_realtime(blue_threshold, stored_masks, stored_image_np):
    """
    Re-run the live/dead classification for a new Blue Threshold value.

    Works from the masks and RGB image kept in gr.State by the segmentation
    step, so segmentation itself is not repeated.

    Args:
        blue_threshold: Slider value in percent (0-100).
        stored_masks: Label masks from the last segmentation, or None.
        stored_image_np: RGB image those masks belong to, or None.

    Returns:
        (overlay PIL image or None, live count, dead count, viability %,
        info text). Any failure is reported in the info text.
    """
    if stored_masks is None or stored_image_np is None:
        return None, 0, 0, 0.0, "Please run segmentation first."
    try:
        n_dead, n_alive, overlay_np = classify_cells_by_blueness(
            stored_image_np, stored_masks, blue_threshold
        )
        n_total = n_alive + n_dead
        viability = 0.0
        if n_total > 0:
            viability = n_alive / n_total * 100
        confluency = measure_confluency(stored_masks, stored_image_np)
        overlay = Image.fromarray(overlay_np)
        report = "\n".join(
            [
                f"Total cells: {n_total}",
                f"Live (green): {n_alive}",
                f"Dead (red): {n_dead}",
                f"Viability: {viability:.1f}%",
                f"Confluency: {confluency:.1f}%",
                f"Blue threshold: {blue_threshold}%",
            ]
        )
        return overlay, n_alive, n_dead, viability, report
    except Exception as err:
        return None, 0, 0, 0.0, f"Error updating viability: {str(err)}"
# Create the Gradio interface
with gr.Blocks(
    title="CellposeCellCounter",
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown("# CellposeCellCounter")
    gr.Markdown("For accurate cell confluency, crop the image to display only desired area.")
    # gr.State holders for the label masks and RGB image produced by the
    # segmentation handler. The viability callbacks read them, so moving the
    # threshold slider does not re-run segmentation.
    masks_state = gr.State(value=None)
    image_state = gr.State(value=None)

    with gr.Tab("Image Editor (Draw Selection)"):
        gr.Markdown("### Draw selection and run segmentation")
        with gr.Row():
            with gr.Column():
                image_editor = gr.ImageEditor(
                    label="Draw selection on image",
                    type="pil",
                    brush=gr.Brush(colors=["#ff0000"], color_mode="fixed", default_size=20),
                    eraser=gr.Eraser(default_size=20),
                )
                model_dropdown1 = gr.Dropdown(
                    choices=list(MODEL_OPTIONS.keys()),
                    label="Select Model",
                    value="Hemocytometer Model",
                )
                # 0 = use the automatic recommendation computed from the raw masks.
                min_size_slider1 = gr.Slider(
                    minimum=0,
                    maximum=500,
                    value=50,
                    step=10,
                    label="Minimum Cell Size (pixels)",
                )
                # 0 = no upper size limit.
                max_size_slider1 = gr.Slider(
                    minimum=0,
                    maximum=10000,
                    value=10000,
                    step=10,
                    label="Maximum Cell Size (pixels)",
                )
                segment_btn1 = gr.Button("🔬 Run Segmentation", variant="primary", size="lg")
            with gr.Column():
                cell_count_output1 = gr.Number(label="Total Cells Detected", precision=0)
                confluency_output1 = gr.Number(label="Confluency (%)", precision=1)
                overlay_output1 = gr.Image(type="pil", label="Segmentation Result")
                info_output1 = gr.Textbox(label="Processing Info", lines=4)

        # Viability Assessment Section (hidden until a segmentation succeeds)
        with gr.Group(visible=False) as viability_section1:
            gr.Markdown("### Viability Assessment (Trypan Blue)")
            gr.Markdown("Adjust the threshold to classify cells as live (green) or dead (red).")
            with gr.Row():
                with gr.Column():
                    blue_threshold1 = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=25,
                        step=1,
                        label="Blue Threshold (%)",
                        info="Higher values = more selective for blue cells",
                    )
                with gr.Column():
                    live_count_output1 = gr.Number(label="Live Cells (Green)", precision=0)
                    dead_count_output1 = gr.Number(label="Dead Cells (Red)", precision=0)
                    viability_overlay1 = gr.Image(type="pil", label="Viability Assessment (Green=Live, Red=Dead)")
                    viability_percent_output1 = gr.Number(label="Viability (%)", precision=1)
                    viability_info1 = gr.Textbox(label="Analysis Results", lines=5)

        # Event handlers
        # run_segmentation_editor also returns the masks and RGB image, which
        # are stored in masks_state / image_state. NOTE: `outputs` lists 8
        # components, so the handler must return 8 values (the last one
        # updates min_size_slider1).
        segment_btn1.click(
            fn=run_segmentation_editor,
            inputs=[image_editor, model_dropdown1, min_size_slider1, max_size_slider1],
            outputs=[cell_count_output1, overlay_output1, info_output1, viability_section1, masks_state, image_state, confluency_output1, min_size_slider1],
        ).then(  # Chain the initial viability assessment after segmentation
            fn=update_viability_realtime,
            inputs=[blue_threshold1, masks_state, image_state],  # Pass stored state as inputs
            outputs=[viability_overlay1, live_count_output1, dead_count_output1, viability_percent_output1, viability_info1],
        )
        # Slider changes update viability in real-time
        blue_threshold1.change(
            fn=update_viability_realtime,
            inputs=[blue_threshold1, masks_state, image_state],
            outputs=[viability_overlay1, live_count_output1, dead_count_output1, viability_percent_output1, viability_info1],
        )

    # Instructions
    with gr.Accordion("Instructions", open=False):
        gr.Markdown("""
        ### How to use:
        1. **Upload and Segment**:
           - Upload your microscopy image.
           - Select a Cellpose model (e.g., "Hemocytometer Model" for blood cells).
           - Draw a selection region using the Image Editor. Only the bounding box of what you draw (plus a small margin) is analyzed; if nothing is drawn, the whole image is used.
           - Optionally adjust **Minimum/Maximum Cell Size (pixels)**. A minimum of 0 uses an automatic value (the 25th percentile of detected object sizes); a maximum of 0 disables the upper limit.
           - Click "Run Segmentation".
        2. **Analysis Results**:
           - **Cell Count**: Total number of detected cells
           - **Confluency**: Percentage of image area covered by cells (useful for assessing cell density). Note that confluency is calculated over the entire analyzed area: the selection's bounding box, or the whole image if no selection is drawn.
        3. **Real-time Viability Assessment (Trypan Blue)**:
           - After segmentation, the viability section will become visible.
           - This tool is specifically designed for **Trypan Blue stained images**, where dead cells appear blue.
           - Adjust the **"Blue Threshold (%)"** slider in real-time. As you change it, the green (live) and red (dead) classification on the overlay will update.
           - **Lower values (e.g., 10-20%)** are more sensitive and will classify more cells as blue/dead.
           - **Higher values (e.g., 30-50%)** are more selective and will only classify strongly blue cells as dead.
           - Green cells = Live, Red cells = Dead.
        4. **Interpreting Results**:
           - The app calculates and displays the total, live, and dead cell counts, along with the viability percentage and confluency.
           - **Confluency** helps assess how densely packed your cells are, which is important for cell culture monitoring.
        """)

if __name__ == "__main__":
    demo.launch()