# testing / app.py - Hugging Face Space source (owner: LiangLabUMB)
# Last commit: "Update app.py" (17fdd27, verified); file size 23.7 kB.
import gradio as gr
import spaces
from cellpose import models
import numpy as np
import cv2
import matplotlib.pyplot as plt
import tempfile
from PIL import Image, ImageDraw
import io
from huggingface_hub import hf_hub_download
import base64
HF_REPO_ID = "myang4218/cellposemodel"

# Dropdown label -> weights filename inside HF_REPO_ID.
MODEL_OPTIONS = {
    "Hemocytometer Model": "hemocytometermodel.npy",
    "General Model": "generalmodel.npy",
}

# Instantiated Cellpose models, keyed by weights filename.
loaded_models = {}

# ---- mobile-safe size limits (aggressive for Safari) ----
MAX_SIDE = 1024                   # longest allowed image side, in pixels
MAX_PIXELS = MAX_SIDE * MAX_SIDE  # total pixel budget (1024 x 1024)
def safe_resize(image_np):
    """
    Shrink an image so it fits both MAX_SIDE and MAX_PIXELS.

    The aspect ratio is kept, and each output side is at least 1 pixel.
    Arrays already within both limits are returned as-is (the same object).
    Grayscale, RGB and RGBA arrays are accepted; downscaling uses
    cv2.INTER_AREA.
    """
    height, width = image_np.shape[:2]
    longest = max(height, width)
    area = height * width

    fits = longest <= MAX_SIDE and area <= MAX_PIXELS
    if fits:
        return image_np

    # The stricter of the two limits wins.
    factor = min(MAX_SIDE / longest, (MAX_PIXELS / area) ** 0.5)
    target_size = (max(1, int(width * factor)), max(1, int(height * factor)))
    return cv2.resize(image_np, target_size, interpolation=cv2.INTER_AREA)
def draw_exclusion_overlay(image_np, left_width_pct, top_width_pct):
    """
    Return a copy of ``image_np`` with the stereological exclusion zones drawn.

    The left ``left_width_pct`` percent of the width and the top
    ``top_width_pct`` percent of the height are shaded translucent red.
    Each zone is outlined by an opaque 3-pixel red line. Zone sizes are
    truncated to whole pixels with ``int(size * pct / 100)``.

    Args:
        image_np: Image array accepted by ``PIL.Image.fromarray``, e.g.
            (H, W) grayscale, (H, W, 3) RGB or (H, W, 4) RGBA uint8.
        left_width_pct: Left zone width, in percent of the image width.
        top_width_pct: Top zone height, in percent of the image height.

    Returns:
        (H, W, 3) uint8 RGB array with the overlay applied.
    """
    h, w = image_np.shape[:2]
    # An "RGBA" drawing context alpha-blends only onto an "RGB" image.
    # On grayscale input PIL raises "mode mismatch"; on RGBA input it
    # overwrites pixels instead of blending. Normalise to RGB first.
    img_pil = Image.fromarray(image_np).convert("RGB")
    draw = ImageDraw.Draw(img_pil, "RGBA")

    left_px = int(w * left_width_pct / 100)
    top_px = int(h * top_width_pct / 100)

    shade = (255, 0, 0, 80)     # semi-transparent red fill
    border = (255, 0, 0, 255)   # opaque red boundary line
    if left_px > 0:
        # Left exclusion zone
        draw.rectangle([(0, 0), (left_px, h)], fill=shade)
        draw.line([(left_px, 0), (left_px, h)], fill=border, width=3)
    if top_px > 0:
        # Top exclusion zone
        draw.rectangle([(0, 0), (w, top_px)], fill=shade)
        draw.line([(0, top_px), (w, top_px)], fill=border, width=3)
    return np.array(img_pil)
def apply_stereological_exclusion(masks, left_width_pct, top_width_pct):
    """
    Remove every labelled cell that touches the left or top exclusion zone.

    A cell is excluded when at least one of its pixels lies in the first
    ``int(W * left_width_pct / 100)`` columns or the first
    ``int(H * top_width_pct / 100)`` rows. All other cells are kept.

    Args:
        masks: 2-D label image (0 = background, k > 0 = cell k).
        left_width_pct: Left zone width, in percent of the image width.
        top_width_pct: Top zone height, in percent of the image height.

    Returns:
        (renumbered_masks, excluded_count, included_count).
        ``renumbered_masks`` keeps the dtype of ``masks``. Surviving cells
        are relabelled 1..included_count in ascending order of their
        original id. The input array is not modified.
    """
    masks = np.asarray(masks)
    h, w = masks.shape
    # Widths are truncated to whole pixels. A negative width means "no zone";
    # without the clamp, a negative slice bound would select almost the whole image.
    left_px = max(0, int(w * left_width_pct / 100))
    top_px = max(0, int(h * top_width_pct / 100))

    # Labels present in either zone strip. Scanning only the strips replaces
    # the former full-image comparison per cell.
    zone_ids = np.union1d(masks[:, :left_px], masks[:top_px, :])
    cell_ids = np.unique(masks)
    cell_ids = cell_ids[cell_ids > 0]
    excluded = np.isin(cell_ids, zone_ids)
    kept_ids = cell_ids[~excluded]

    # Lookup table old id -> new id. Background, non-positive and excluded
    # labels map to 0.
    lut = np.zeros(int(cell_ids.max(initial=0)) + 1, dtype=masks.dtype)
    lut[kept_ids] = np.arange(1, kept_ids.size + 1)
    renumbered_masks = lut[np.where(masks > 0, masks, 0).astype(np.intp)]
    return renumbered_masks, int(excluded.sum()), int(kept_ids.size)
def classify_cells_by_blueness(image_np, masks, blue_threshold):
    """
    Classify cells as dead (blue) or alive using a single blueness metric.

    Each pixel gets a blueness score in [0, 1]. The score is the closeness
    of its HSV hue to 115, falling linearly to 0 at 65 hue units away with
    wrap-around, multiplied by saturation / 255. A cell is "dead" when the
    mean score over its pixels is strictly greater than
    ``blue_threshold / 100``; otherwise it is "alive".

    Args:
        image_np: RGB image array. Grayscale and RGBA are converted to RGB.
        masks: Cellpose label image with the same height/width as
            ``image_np`` (0 = background, k > 0 = cell k).
        blue_threshold: Threshold in percent (0-100).

    Returns:
        (dead_count, alive_count, colored_overlay). ``colored_overlay`` is
        a uint8 RGB image: dead cells are blended 40 % towards red and live
        cells 40 % towards green. Background pixels are blended with
        themselves, so they keep their colour up to float rounding.
    """
    if image_np.ndim == 2:
        image_np = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
    elif image_np.ndim == 3 and image_np.shape[2] == 4:
        image_np = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)

    hsv = cv2.cvtColor(image_np, cv2.COLOR_RGB2HSV)
    hue = hsv[:, :, 0].astype(np.float32)
    saturation = hsv[:, :, 1].astype(np.float32)

    # Hue score peaks at 115 and drops to 0 at 65 hue units away. The min()
    # accounts for the circular hue axis (period 180, the OpenCV 8-bit range).
    hue_distance = np.minimum(np.abs(hue - 115), 180 - np.abs(hue - 115))
    hue_score = np.maximum(0, 1 - hue_distance / 65)
    blueness = hue_score * (saturation / 255.0)

    threshold = blue_threshold / 100.0

    # Per-label mean blueness in a single pass. The previous code did one
    # full-image comparison per cell (three times), i.e. O(cells x pixels)
    # on every slider move. Sums are accumulated in float64.
    labels = np.asarray(masks)
    labels = np.where(labels > 0, labels, 0).astype(np.intp)
    flat_labels = labels.ravel()
    pixel_counts = np.bincount(flat_labels, minlength=1)
    blue_sums = np.bincount(flat_labels, weights=blueness.ravel(), minlength=1)
    is_cell = pixel_counts > 0
    is_cell[0] = False  # label 0 is background
    mean_blueness = blue_sums / np.maximum(pixel_counts, 1)
    is_dead = is_cell & (mean_blueness > threshold)
    is_alive = is_cell & ~is_dead

    # Color dead cells red, alive cells green, then blend with the original.
    overlay = image_np.astype(np.float32)
    overlay[is_dead[labels]] = [255, 0, 0]
    overlay[is_alive[labels]] = [0, 255, 0]

    alpha = 0.4
    final_overlay = (1 - alpha) * image_np.astype(np.float32) + alpha * overlay
    final_overlay = np.clip(final_overlay, 0, 255).astype(np.uint8)
    return int(np.count_nonzero(is_dead)), int(np.count_nonzero(is_alive)), final_overlay
def measure_confluency(masks, image_np):
    """
    Return the percentage (0-100) of image pixels covered by labelled objects.

    Only the first two dimensions of ``image_np`` (height, width) are used.
    Every non-zero entry of ``masks`` counts as covered.
    """
    height, width = image_np.shape[:2]
    covered = np.count_nonzero(masks)
    return covered / (height * width) * 100
def _filter_labels_by_size(masks, should_remove):
    """
    Remove labelled objects whose pixel count satisfies ``should_remove``.

    ``should_remove`` receives an integer array of pixel counts indexed by
    label id and returns a boolean array of the same length. Non-positive
    values are treated as background. Survivors are renumbered 1..K in
    ascending order of their original id, using the dtype of ``masks``.
    The input array is not modified.

    Returns:
        (renumbered_masks, removed_count)
    """
    masks = np.asarray(masks)
    labels = np.where(masks > 0, masks, 0).astype(np.intp)
    # One bincount gives every object's area. This replaces the former
    # full-image comparison per object (O(objects x pixels)).
    sizes = np.bincount(labels.ravel(), minlength=1)
    present = sizes > 0
    present[0] = False  # label 0 is background, never an object
    removed = present & np.asarray(should_remove(sizes), dtype=bool)
    kept = present & ~removed

    # Lookup table old id -> new id; removed labels and background map to 0.
    lut = np.zeros(sizes.shape[0], dtype=masks.dtype)
    lut[kept] = np.arange(1, np.count_nonzero(kept) + 1)
    return lut[labels], int(np.count_nonzero(removed))


def filter_mask_by_size(masks, minimum_pixels):
    """
    Drop objects with fewer than ``minimum_pixels`` pixels and renumber the rest.

    Returns:
        (renumbered_masks, removed_count)
    """
    return _filter_labels_by_size(masks, lambda sizes: sizes < minimum_pixels)


def filter_mask_by_maxsize(masks, maximum_pixels):
    """
    Drop objects with more than ``maximum_pixels`` pixels and renumber the rest.

    Returns:
        (renumbered_masks, removed_count)
    """
    return _filter_labels_by_size(masks, lambda sizes: sizes > maximum_pixels)
def rec_min_size(masks, q=25):
    """
    Recommend a minimum object size: the ``q``-th percentile of object areas.

    Args:
        masks: Label image (0 = background, k > 0 = object k).
        q: Percentile in [0, 100] passed to ``np.percentile``.

    Returns:
        The percentile rounded to an integer pixel count, or 0 when ``masks``
        contains no objects.
    """
    masks = np.asarray(masks)
    # One pass over the positive pixels gives every object's area, in the
    # same ascending-id order the per-id loop produced.
    _, sizes = np.unique(masks[masks > 0], return_counts=True)
    if sizes.size == 0:
        return 0
    return int(round(np.percentile(sizes, q)))
def toggle_stereological_mode(use_stereology):
    """
    Gradio callback for the stereology checkbox.

    Returns an update whose ``visible`` flag is the checkbox value, which
    shows or hides the stereological controls group.
    """
    controls_visible = use_stereology
    return gr.update(visible=controls_visible)
def update_exclusion_preview(image, left_width, top_width):
    """
    Render the exclusion-zone preview for the uploaded image.

    Returns a PIL image with the left/top exclusion zones drawn by
    ``draw_exclusion_overlay``, or None when no image is loaded.
    """
    if image is not None:
        shaded = draw_exclusion_overlay(np.array(image), left_width, top_width)
        return Image.fromarray(shaded)
    return None
def _load_cellpose_model(model_choice):
    """
    Return the CellposeModel for dropdown label ``model_choice``, loading it once.

    Models are cached in ``loaded_models`` by weights filename. The Hub
    download and model construction therefore happen only on the first
    request for a given model. Raises KeyError for a label missing from
    MODEL_OPTIONS.
    """
    model_filename = MODEL_OPTIONS[model_choice]
    model = loaded_models.get(model_filename)
    if model is None:
        model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=model_filename)
        model = models.CellposeModel(gpu=True, pretrained_model=model_path)
        loaded_models[model_filename] = model
    return model


def _as_rgb(image_np):
    """Convert a (H, W) grayscale or (H, W, 4) RGBA array to RGB; return anything else unchanged."""
    if image_np.ndim == 2:
        return cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
    if image_np.ndim == 3 and image_np.shape[2] == 4:
        return cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)
    return image_np


def _segmentation_failure(message):
    """Build the run_segmentation result tuple for a run that produced no result."""
    return (
        0,
        None,
        message,
        gr.update(visible=False),
        None,
        None,
        0.0,
        gr.update(),
    )


@spaces.GPU
def run_segmentation(image, model_choice, min_cell_size, max_cell_size,
                     use_stereology, left_exclusion, top_exclusion):
    """
    Segment ``image`` with Cellpose, filter the masks and summarise the result.

    Args:
        image: PIL image from the input component, or None when nothing is uploaded.
        model_choice: Key of MODEL_OPTIONS selecting the weights file.
        min_cell_size: Minimum object area in pixels. 0 means "use
            rec_min_size() computed on the raw masks".
        max_cell_size: Maximum object area in pixels. 0 disables this filter.
        use_stereology: Whether to apply the left/top exclusion rule.
        left_exclusion: Left exclusion zone width, in percent.
        top_exclusion: Top exclusion zone height, in percent.

    Returns:
        An 8-tuple matching the click handler outputs: (cell_count, overlay
        PIL image, info text, viability-section update, packed masks, packed
        RGB image, confluency %, min-size slider update). On failure the
        count is 0, images and states are None, and the info text carries
        the reason.
    """
    if image is None:
        return _segmentation_failure("Please upload an image first.")
    try:
        # Downscale first (mobile/Safari limits), then normalise to RGB.
        processed_image_np = _as_rgb(safe_resize(np.array(image)))
        model = _load_cellpose_model(model_choice)

        # Run Cellpose segmentation
        masks_raw, _flows, _styles = model.eval(
            processed_image_np, diameter=None, channels=[0, 0]
        )

        # Object-size statistics for the server log (single pass over labels).
        _, sizes = np.unique(masks_raw[masks_raw > 0], return_counts=True)
        has_cells = sizes.size > 0
        print("num_cells:", sizes.size)
        print("mean:", sizes.mean() if has_cells else 0)
        print("median:", np.median(sizes) if has_cells else 0)
        print("p90:", np.percentile(sizes, 90) if has_cells else 0)
        print("max:", sizes.max() if has_cells else 0)

        # The recommendation comes from the RAW masks; a slider value of 0 means "use it".
        recommend_min = rec_min_size(masks_raw)
        min_used = recommend_min if min_cell_size == 0 else int(min_cell_size)

        # Apply filters
        masks = masks_raw.copy()
        removed_small = 0
        removed_large = 0
        if min_used > 0:
            masks, removed_small = filter_mask_by_size(masks, min_used)
        if max_cell_size > 0:
            masks, removed_large = filter_mask_by_maxsize(masks, int(max_cell_size))

        # Apply stereological exclusion if enabled
        excluded_count = 0
        if use_stereology:
            masks, excluded_count, _included_count = apply_stereological_exclusion(
                masks, left_exclusion, top_exclusion
            )

        filter_msg = ""
        if removed_small:
            filter_msg += f"Removed {removed_small} small objects (< {min_used} pixels).\n"
        if removed_large:
            filter_msg += f"Removed {removed_large} large objects (> {int(max_cell_size)} pixels).\n"
        if use_stereology and excluded_count > 0:
            filter_msg += f"Stereological exclusion: {excluded_count} cells excluded (touching left/top zones).\n"

        # Count distinct positive labels. `len(unique) - 1` was off by one
        # whenever no background pixel remained.
        cell_count = int(np.count_nonzero(np.unique(masks) > 0))
        confluency = measure_confluency(masks, processed_image_np)

        # Basic segmentation overlay (without viability)
        segmentation_overlay = processed_image_np.copy().astype(np.float32)
        if masks.max() > 0:
            # A private RandomState(42) yields the same colours as the former
            # np.random.seed(42) + np.random.randint, without reseeding the
            # process-wide global RNG.
            colors = np.random.RandomState(42).randint(0, 255, size=(masks.max() + 1, 3))
            colors[0] = [0, 0, 0]
            colored_mask = colors[masks]
            alpha = 0.4
            segmentation_overlay = (1 - alpha) * segmentation_overlay + alpha * colored_mask
        segmentation_overlay = np.clip(segmentation_overlay, 0, 255).astype(np.uint8)

        # Add exclusion zone overlay if stereology is enabled
        if use_stereology:
            segmentation_overlay = draw_exclusion_overlay(segmentation_overlay, left_exclusion, top_exclusion)

        info_msg = filter_msg
        info_msg += f"Segmentation complete! Found {cell_count} cells.\n"
        info_msg += f"Confluency: {confluency:.1f}%\n"
        if use_stereology:
            info_msg += f"Stereological counting enabled (Left: {left_exclusion}%, Top: {top_exclusion}%)\n"
        info_msg += "Now adjust the Blue Threshold for viability assessment."

        return (
            cell_count,
            Image.fromarray(segmentation_overlay),
            info_msg,
            gr.update(visible=True),
            pack_array(masks),
            pack_array(processed_image_np),
            confluency,
            gr.update(value=recommend_min),  # show the recommendation on the slider
        )
    except Exception as e:
        # Top-level UI boundary: log the traceback and report the error in the UI.
        import traceback
        traceback.print_exc()
        return _segmentation_failure(f"Error during segmentation: {str(e)}")
def update_viability_realtime(blue_threshold, stored_masks, stored_image_np):
    """
    Re-classify the stored segmentation for a new blue threshold.

    Args:
        blue_threshold: Blue threshold in percent (0-100), from the slider.
        stored_masks: Label image packed by ``pack_array``, or None.
        stored_image_np: RGB image packed by ``pack_array``, or None.

    Returns:
        (overlay_image, alive_count, dead_count, viability_percent, info_msg).
        If the state is missing or anything fails, the image is None, the
        numbers are 0 and ``info_msg`` explains why.
    """
    # avoid unpacking None (e.g. slider moved before segmentation)
    if stored_masks is None or stored_image_np is None:
        return None, 0, 0, 0.0, "Please run segmentation first."
    try:
        # Decoding happens inside the try so a corrupt or stale state is
        # reported in the UI instead of escaping the event handler.
        masks = unpack_array(stored_masks)
        image_np = unpack_array(stored_image_np)
        dead_count, alive_count, viability_overlay_np = classify_cells_by_blueness(
            image_np, masks, blue_threshold
        )
        total_count = alive_count + dead_count
        viability_percent = (alive_count / total_count * 100) if total_count > 0 else 0.0
        confluency = measure_confluency(masks, image_np)
        overlay_image = Image.fromarray(viability_overlay_np)
        info_msg = f"Total cells: {total_count}\nLive (green): {alive_count}\nDead (red): {dead_count}\n"
        info_msg += f"Viability: {viability_percent:.1f}%\nConfluency: {confluency:.1f}%\nBlue threshold: {blue_threshold}%"
        return overlay_image, alive_count, dead_count, viability_percent, info_msg
    except Exception as e:
        return None, 0, 0, 0.0, f"Error updating viability: {str(e)}"
# Signature of data written by the previous PNG-based pack_array.
_PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def pack_array(arr):
    """
    Serialise a NumPy array to bytes, losslessly.

    Used to carry the label masks and RGB image through gr.State. The
    previous PNG encoding cast to uint8, which wrapped labels above 255:
    cell 256 became background and cell 257 merged with cell 1. The
    compressed ``.npz`` format keeps dtype, shape and values exactly.
    """
    buf = io.BytesIO()
    np.savez_compressed(buf, arr=np.asarray(arr))
    return buf.getvalue()


def unpack_array(data):
    """
    Inverse of ``pack_array``. Also accepts legacy PNG-encoded bytes.
    """
    if data[:8] == _PNG_SIGNATURE:
        return np.array(Image.open(io.BytesIO(data)))
    with np.load(io.BytesIO(data), allow_pickle=False) as archive:
        return archive["arr"]
# Gradio interface
# NOTE(review): the original indentation of this section was lost; the nesting
# below is a reconstruction. Confirm where the viability outputs, the event
# handlers and the Instructions accordion belong in the layout.
with gr.Blocks(
    title="CellposeCellCounter",
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown("# CellposeCellCounter")
    gr.Markdown("For accurate cell confluency, crop the image to display only desired area. Note that some image file types are not yet supported. PNG and JPEG are preferred.")
    # Define State components to store masks and image data across function calls.
    # run_segmentation fills them with pack_array() bytes; update_viability_realtime
    # decodes them with unpack_array().
    masks_state = gr.State(value=None)
    image_state = gr.State(value=None)
    with gr.Tab("Cell Quantification"):
        gr.Markdown("Run segmentation")
        with gr.Row():
            # Left column: image input and segmentation settings.
            with gr.Column():
                img_input = gr.Image(
                    type="pil",
                    label="Microscopy image",
                    image_mode="RGB",
                    height=512
                )
                model_dropdown1 = gr.Dropdown(
                    choices=list(MODEL_OPTIONS.keys()),
                    label="Select Model",
                    value="Hemocytometer Model"
                )
                # A value of 0 makes run_segmentation use rec_min_size() on the raw
                # masks; run_segmentation then writes that recommendation back here.
                min_size_slider1 = gr.Slider(
                    minimum=0,
                    maximum=500,
                    value=0,
                    step=10,
                    label="Minimum Cell Size (pixels). Leave at zero for automated recommendation",
                )
                # run_segmentation skips the maximum-size filter when this is 0.
                max_size_slider1 = gr.Slider(
                    minimum=0,
                    maximum=1000,
                    value=1000,
                    step=10,
                    label="Maximum Cell Size (pixels)",
                )
                # Stereological counting option
                gr.Markdown("### Stereological Counting")
                use_stereology_checkbox = gr.Checkbox(
                    label="Enable Stereological Counting",
                    value=False,
                    info="Use unbiased stereological rules for cell counting"
                )
                # Stereological controls (initially hidden; toggled by the checkbox)
                with gr.Group(visible=False) as stereology_controls:
                    gr.Markdown("""
**Stereological Counting Rules:**
- Cells touching LEFT or TOP exclusion zones are EXCLUDED
- Cells touching RIGHT or BOTTOM edges are INCLUDED
- This provides unbiased counting for quantification
""")
                    exclusion_preview = gr.Image(
                        type="pil",
                        label="Exclusion Zone Preview (Red = Excluded)",
                        height=300
                    )
                    # Both widths are percentages of the image width/height.
                    left_exclusion_slider = gr.Slider(
                        minimum=0,
                        maximum=50,
                        value=10,
                        step=1,
                        label="Left Exclusion Width (%)",
                        info="Width of left exclusion zone"
                    )
                    top_exclusion_slider = gr.Slider(
                        minimum=0,
                        maximum=50,
                        value=10,
                        step=1,
                        label="Top Exclusion Width (%)",
                        info="Width of top exclusion zone"
                    )
                segment_btn1 = gr.Button("🔬 Run Segmentation", variant="primary", size="lg")
            # Right column: segmentation outputs.
            with gr.Column():
                cell_count_output1 = gr.Number(label="Total Cells Detected", precision=0)
                confluency_output1 = gr.Number(label="Confluency (%)", precision=1)
                overlay_output1 = gr.Image(type="pil", label="Segmentation Result")
                info_output1 = gr.Textbox(label="Processing Info", lines=4)
        # Viability Assessment Section
        # Hidden until run_segmentation returns gr.update(visible=True) for it.
        with gr.Group(visible=False) as viability_section1:
            gr.Markdown("### Viability Assessment (Trypan Blue)")
            gr.Markdown("Adjust the threshold to classify cells as live (green) or dead (red).")
            with gr.Row():
                with gr.Column():
                    # Passed to classify_cells_by_blueness as a percentage (divided by 100 there).
                    blue_threshold1 = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=25,
                        step=1,
                        label="Blue Threshold (%)",
                        info="Higher values = more selective for blue cells"
                    )
                with gr.Column():
                    live_count_output1 = gr.Number(label="Live Cells (Green)", precision=0)
                    dead_count_output1 = gr.Number(label="Dead Cells (Red)", precision=0)
            viability_overlay1 = gr.Image(type="pil", label="Viability Assessment (Green=Live, Red=Dead)")
            viability_percent_output1 = gr.Number(label="Viability (%)", precision=1)
            viability_info1 = gr.Textbox(label="Analysis Results", lines=5)
    # Event handlers
    # Toggle stereological controls visibility
    use_stereology_checkbox.change(
        fn=toggle_stereological_mode,
        inputs=[use_stereology_checkbox],
        outputs=[stereology_controls]
    )
    # Update exclusion preview when image is uploaded or sliders change
    img_input.change(
        fn=update_exclusion_preview,
        inputs=[img_input, left_exclusion_slider, top_exclusion_slider],
        outputs=[exclusion_preview]
    )
    left_exclusion_slider.change(
        fn=update_exclusion_preview,
        inputs=[img_input, left_exclusion_slider, top_exclusion_slider],
        outputs=[exclusion_preview]
    )
    top_exclusion_slider.change(
        fn=update_exclusion_preview,
        inputs=[img_input, left_exclusion_slider, top_exclusion_slider],
        outputs=[exclusion_preview]
    )
    # Run segmentation. The outputs list must stay in the same order as the
    # 8-tuple returned by run_segmentation.
    segment_btn1.click(
        fn=run_segmentation,
        inputs=[
            img_input,
            model_dropdown1,
            min_size_slider1,
            max_size_slider1,
            use_stereology_checkbox,
            left_exclusion_slider,
            top_exclusion_slider
        ],
        outputs=[
            cell_count_output1,
            overlay_output1,
            info_output1,
            viability_section1,
            masks_state,
            image_state,
            confluency_output1,
            min_size_slider1
        ]
    ).then(  # Chain the initial viability assessment after segmentation
        fn=update_viability_realtime,
        inputs=[blue_threshold1, masks_state, image_state],
        outputs=[viability_overlay1, live_count_output1, dead_count_output1, viability_percent_output1, viability_info1]
    )
    # Slider changes update viability in real-time
    blue_threshold1.change(
        fn=update_viability_realtime,
        inputs=[blue_threshold1, masks_state, image_state],
        outputs=[viability_overlay1, live_count_output1, dead_count_output1, viability_percent_output1, viability_info1]
    )
    # Instructions
    with gr.Accordion("Instructions", open=False):
        gr.Markdown("""
### How to use:
1. **Upload and Segment**:
- Upload your microscopy image.
- Select a Cellpose model (e.g., "Hemocytometer Model" for suspension culture).
- **(Optional)** Enable Stereological Counting for unbiased quantification.
- Click "Run Segmentation".
2. **Stereological Counting** (Optional):
- Check "Enable Stereological Counting" to use unbiased counting rules.
- Adjust the Left and Top exclusion zone widths using the sliders.
- Preview shows excluded areas in red.
- **Counting Rules**:
- Cells touching LEFT or TOP exclusion zones are EXCLUDED
- Cells touching RIGHT or BOTTOM edges are INCLUDED
- This ensures unbiased, systematic counting
3. **Analysis Results**:
- **Cell Count**: Total number of detected cells (after exclusions if using stereology)
- **Confluency**: Percentage of image area covered by cells
4. **Real-time Viability Assessment (Trypan Blue)**:
- After segmentation, the viability section will become visible.
- Adjust the **"Blue Threshold (%)"** slider in real-time.
- **Lower values (10-20%)** are more sensitive.
- **Higher values (30-50%)** are more selective.
- Green cells = Live, Red cells = Dead.
5. **Interpreting Results**:
- The app displays total, live, and dead cell counts, viability percentage, and confluency.
- If stereological counting is enabled, excluded cells are noted in the processing info.
""")
# Start the Gradio server only when executed as a script.
if __name__ == "__main__":
    demo.launch()