Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| import cv2 | |
| import os | |
| import tensorflow as tf | |
| import matplotlib.pyplot as plt | |
| from matplotlib.colors import LinearSegmentedColormap | |
| from skimage import io | |
| import tempfile | |
# --- Load the trained segmentation model once, at import time ---
# Adjust the path if the weights file lives elsewhere. compile=False skips
# re-creating the training configuration (optimizer/loss); this app only
# calls predict() on the model.
_MODEL_FILE = "resunet_brain_segmentation.h5"
model = tf.keras.models.load_model(_MODEL_FILE, compile=False)
# --- Grayscale conversion and contrast stretching ---
def to_grayscale_float(img):
    """Collapse the last (colour) axis of ``img`` into a single luminance plane.

    Channels 0, 1 and 2 are treated as R, G and B and combined with the
    weights 0.2989 / 0.5870 / 0.1140; any further channels (e.g. alpha) are
    ignored. No rescaling is applied, so the output keeps the input's value
    range. The result is a float array with the input's shape minus its
    last axis.
    """
    red, green, blue = img[..., 0], img[..., 1], img[..., 2]
    return 0.2989 * red + 0.5870 * green + 0.1140 * blue
def stretch_contrast(img, low=2, high=98):
    """Linearly map the [low, high] percentile window of ``img`` onto [0, 1].

    Values below the ``low`` percentile become 0 and values above the
    ``high`` percentile become 1 (clipped).

    Args:
        img: Numeric array of any shape.
        low: Lower percentile (0-100) mapped to 0.
        high: Upper percentile (0-100) mapped to 1.

    Returns:
        Float array with the same shape as ``img``. If both percentiles are
        equal (e.g. a flat, single-intensity image), an all-zero array is
        returned instead of the NaNs a 0/0 division would produce.
    """
    p_low, p_high = np.percentile(img, (low, high))
    spread = p_high - p_low
    if spread == 0:
        # Every pixel inside the window has the same value; avoid 0/0 -> NaN,
        # which would poison downstream display and thresholding.
        return np.zeros(np.shape(img), dtype=float)
    return np.clip((img - p_low) / spread, 0, 1)
# --- GYR colormap: low values green, mid yellow, high values red ---
# Used to colour the prediction-confidence overlay drawn over the MRI slice.
_GYR_STOPS = [(0, 'green'), (0.5, 'yellow'), (1, 'red')]
cmap_gyr = LinearSegmentedColormap.from_list("gyr", _GYR_STOPS)
# --- Preprocess a single image for prediction ---
def preprocess_single_image(image, img_h=256, img_w=256):
    """Turn one image into a standardised batch of size 1 for the model.

    The image is resized to ``img_h`` x ``img_w`` (cv2 expects the target as
    (width, height)), cast to float64, shifted to zero mean and divided by
    its standard deviation. The 1e-8 term keeps a flat image from dividing
    by zero. A leading batch axis is added to the result.
    """
    resized = cv2.resize(image, (img_w, img_h)).astype(np.float64)
    centered = resized - resized.mean()
    standardized = centered / (centered.std() + 1e-8)
    return standardized[np.newaxis, ...]
# --- Predict & overlay with confidence heatmap ---
def _resolve_input_image(image, filename):
    """Pick the image to segment and normalise it to 3 channels.

    A file ``colored/<basename(filename)>``, if present, takes precedence
    over ``image``. Grayscale input is replicated to 3 channels and extra
    channels (such as alpha) are dropped.

    Raises:
        ValueError: if there is neither an input image nor a colored file.
    """
    if filename:
        # basename() stops a client-supplied name (the hidden textbox is
        # settable through the Gradio API) from escaping the colored/ folder.
        color_path = os.path.join("colored", os.path.basename(filename))
        # isfile() rather than exists(): a bare directory path must not reach
        # io.imread.
        if os.path.isfile(color_path):
            image = io.imread(color_path)
    if image is None:
        raise ValueError("No input image was provided.")
    image = np.asarray(image)
    if image.ndim == 3 and image.shape[-1] in (1, 2):
        # Gray or gray+alpha stored with a channel axis: keep the gray plane.
        image = image[..., 0]
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[-1] > 3:
        # e.g. RGBA PNGs: the model is fed 3-channel input.
        image = image[..., :3]
    return image


def _severity_from_coverage(coverage):
    """Map a tumor-coverage percentage onto a coarse severity label."""
    if coverage > 25:
        return "Severe"
    if coverage > 10:
        return "Moderate"
    if coverage > 1:
        return "Mild"
    return "No Tumor Found"


def _render_overlay(gray_adj, vis_mask):
    """Draw ``vis_mask`` (GYR colormap) over ``gray_adj`` and return it as an image array."""
    # Figure() without pyplot avoids pyplot's process-global "current figure"
    # state, which concurrent Gradio requests could otherwise interleave.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(5, 5))
    ax = fig.subplots()
    ax.imshow(gray_adj, cmap='gray', vmin=0, vmax=1)
    ax.imshow(vis_mask, cmap=cmap_gyr, alpha=0.7, vmin=0, vmax=1)
    ax.axis('off')
    fig.tight_layout()
    # A private temporary directory replaces the deprecated, race-prone
    # tempfile.mktemp() and is deleted (with the PNG) once read back.
    with tempfile.TemporaryDirectory() as tmp_dir:
        png_path = os.path.join(tmp_dir, "overlay.png")
        fig.savefig(png_path, bbox_inches='tight', pad_inches=0)
        return io.imread(png_path)


def predict_and_overlay(image, filename):
    """Segment an MRI slice and build a tumor-confidence heatmap overlay.

    Args:
        image: Input image as a numpy array (H x W or H x W x C), or None.
        filename: Name used to look up an optional ``colored/<name>`` file
            that replaces ``image``; may be empty.

    Returns:
        (overlay_img, coverage_str, severity): the rendered overlay image,
        the share of the brain region predicted as tumor formatted like
        ``"12.34%"``, and a severity label derived from that share.

    Raises:
        ValueError: if no image is given and no colored file is found.
    """
    import skimage.filters

    image = _resolve_input_image(image, filename)

    pred = model.predict(preprocess_single_image(image))
    pred_mask = pred[0].squeeze()

    # Grayscale, contrast-stretched version of the slice for display.
    resized_img = cv2.resize(image, (256, 256))
    gray_img = to_grayscale_float(resized_img / 255.0)
    gray_adj = stretch_contrast(gray_img)

    # Otsu threshold keeps the brighter (brain) pixels, so coverage is
    # measured against the brain region instead of the whole frame.
    brain_mask = gray_img > skimage.filters.threshold_otsu(gray_img)

    # Hide low-confidence pixels; imshow draws NaN as transparent.
    vis_mask = np.copy(pred_mask)
    vis_mask[vis_mask < 0.2] = np.nan

    tumor_area = np.count_nonzero((pred_mask > 0.5) & brain_mask)
    brain_area = np.count_nonzero(brain_mask)
    # Report the measured ratio as-is: the previous "+ 3.5" fudge pushed any
    # non-zero detection (even a single pixel) into at least "Mild".
    coverage = 100.0 * tumor_area / brain_area if brain_area else 0.0

    severity = _severity_from_coverage(coverage)
    overlay_img = _render_overlay(gray_adj, vis_mask)
    return overlay_img, f"{coverage:.2f}%", severity
# --- Sample Gallery Setup ---
def load_gallery(grayscale_dir="grayscale",
                 extensions=(".tif", ".tiff", ".png", ".jpg", ".jpeg")):
    """Load the sample MRI gallery as contrast-stretched grayscale images.

    Args:
        grayscale_dir: Directory scanned (non-recursively) for sample images.
        extensions: Accepted file suffixes, matched case-insensitively.

    Returns:
        (gallery_images, filenames): parallel lists ordered by filename.
        Both are empty when ``grayscale_dir`` does not exist, so the app can
        still start without the sample folder.
    """
    gallery_images = []
    filenames = []
    if not os.path.isdir(grayscale_dir):
        return gallery_images, filenames
    suffixes = tuple(ext.lower() for ext in extensions)
    for fname in sorted(os.listdir(grayscale_dir)):
        path = os.path.join(grayscale_dir, fname)
        # Skip non-images and directories that merely carry an image suffix.
        if not fname.lower().endswith(suffixes) or not os.path.isfile(path):
            continue
        img = io.imread(path)
        if img.ndim == 3:
            # RGB(A) -> luminance; gray(+alpha) keeps its first channel.
            img = to_grayscale_float(img) if img.shape[-1] >= 3 else img[..., 0]
        gallery_images.append(stretch_contrast(img))
        filenames.append(fname)
    return gallery_images, filenames
gallery_imgs, gallery_filenames = load_gallery()

# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 Brain Tumor Segmentation - MRI Viewer")
    gr.Markdown("### Sample MRIs (Drag and Drop Below to Predict)")
    with gr.Row():
        for sample in gallery_imgs:
            gr.Image(value=sample, image_mode="L", label="", show_label=False, show_download_button=False)
    gr.Markdown("### Upload an MRI to Detect Tumor")
    with gr.Row():
        input_img = gr.Image(label="Upload or Drag Sample MRI", type="numpy")
        output_img = gr.Image(label="Tumor Heatmap Output")
    with gr.Row():
        output_coverage = gr.Textbox(label="Tumor Coverage")
        output_severity = gr.Textbox(label="Severity")
    # Hidden field carrying the name passed to predict_and_overlay.
    # NOTE(review): capture_filename below only generates random names, so a
    # matching colored/<name> file is found only if a client sets this value
    # explicitly - confirm this is the intended behaviour.
    filename_box = gr.Textbox(visible=False)

    def wrapper(img, filename):
        """Run segmentation for the UI, filling in a filename when none is set."""
        if img is None:
            raise gr.Error("Please upload an MRI image first.")
        # A hidden Textbox yields "" (not None) before anything is written to
        # it, so test for falsiness rather than `is None`.
        if not filename:
            filename = f"uploaded_{np.random.randint(10000)}.png"
        return predict_and_overlay(img, filename)

    submit_btn = gr.Button("Run Tumor Segmentation")
    submit_btn.click(fn=wrapper, inputs=[input_img, filename_box], outputs=[output_img, output_coverage, output_severity])

    def capture_filename(img):
        """Give every fresh upload a random name (``img`` is unused)."""
        return f"upload_{np.random.randint(10000)}.png"

    input_img.upload(capture_filename, inputs=input_img, outputs=filename_box)

# --- Launch ---
demo.launch()