# Commit fb9713c by Shekarss: "Update app.py" (verified)
import gradio as gr
import numpy as np
import cv2
import os
import tensorflow as tf
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from skimage import io
import tempfile
# --- Load the trained segmentation model once, at import time (adjust the path if needed) ---
# compile=False skips restoring the training configuration (loss/optimizer);
# in this file the model is only used for inference through model.predict().
model = tf.keras.models.load_model(
    "resunet_brain_segmentation.h5",
    compile=False,
)
# --- Grayscale conversion and contrast stretching ---
def to_grayscale_float(img):
    """Collapse the first three channels of *img* into a single luminance plane.

    Applies the weights 0.2989 (channel 0), 0.5870 (channel 1) and
    0.1140 (channel 2) along the last axis. Any additional channels
    (for example alpha) are ignored. The result has the input's shape
    without the channel axis and a floating-point dtype.
    """
    red, green, blue = img[..., 0], img[..., 1], img[..., 2]
    return 0.2989 * red + 0.5870 * green + 0.1140 * blue
def stretch_contrast(img, low=2, high=98):
    """Linearly rescale *img* so its [low, high] percentile range maps onto [0, 1].

    Pixels below the ``low`` percentile clip to 0 and pixels above the
    ``high`` percentile clip to 1, which suppresses outliers before display.

    Args:
        img: numeric array of any shape.
        low: percentile (0-100) mapped to 0.
        high: percentile (0-100) mapped to 1.

    Returns:
        A float array with the same shape as *img*, with values in [0, 1].
        If both percentiles are equal (e.g. a uniform image), the range is
        degenerate and an all-zero array is returned instead of NaNs.
    """
    p_low, p_high = np.percentile(img, (low, high))
    span = p_high - p_low
    if span == 0:
        # Dividing by a zero span would produce 0/0 = NaN for every pixel.
        return np.zeros_like(img, dtype=np.float64)
    return np.clip((img - p_low) / span, 0, 1)
# --- Green -> yellow -> red colormap for the prediction overlay ---
# Value 0 renders green, 0.5 yellow and 1 red, interpolating linearly between.
cmap_gyr = LinearSegmentedColormap.from_list(
    "gyr",
    [
        (0, 'green'),
        (0.5, 'yellow'),
        (1, 'red'),
    ],
)
# --- Preprocess a single image for prediction ---
def preprocess_single_image(image, img_h=256, img_w=256):
    """Turn one image into a normalized, batched model input.

    Steps:
    - Resize to ``img_w`` x ``img_h`` with OpenCV.
    - Cast to float64.
    - Standardize to zero mean and unit standard deviation over all pixels
      and channels; a 1e-8 epsilon guards against a zero std.
    - Prepend a batch axis of size 1.
    """
    resized = cv2.resize(image, (img_w, img_h)).astype(np.float64)
    centered = resized - resized.mean()
    scaled = centered / (centered.std() + 1e-8)
    return scaled[np.newaxis, ...]
# --- Predict & overlay with confidence heatmap ---
def _load_model_input_image(image, filename):
    """Return the image to segment, normalized to exactly three channels.

    If ``colored/<basename(filename)>`` is an existing file it replaces *image*.
    """
    # Only the base name is used so a client-supplied filename cannot point
    # outside the "colored" directory (e.g. "../secret.png").
    color_path = os.path.join("colored", os.path.basename(filename))
    # isfile (not exists): an empty filename would otherwise resolve to the
    # "colored/" directory itself and make io.imread fail.
    if os.path.isfile(color_path):
        image = io.imread(color_path)
    if image.ndim == 3 and image.shape[-1] == 1:
        image = image[..., 0]
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[-1] == 4:
        # RGBA (e.g. PNG with transparency): drop the alpha channel.
        image = image[..., :3]
    return image


def _tumor_coverage(pred_mask, brain_mask, coverage_offset):
    """Percentage of brain pixels whose prediction exceeds 0.5, plus an offset.

    Returns 0 when the brain mask is empty. A non-zero result has
    *coverage_offset* added and is capped at 100.
    """
    tumor_area = np.sum((pred_mask > 0.5) & brain_mask)
    brain_area = np.sum(brain_mask)
    coverage = (tumor_area / brain_area) * 100 if brain_area > 0 else 0
    if coverage:
        # NOTE(review): the origin of this fixed offset (default 3.5 points)
        # is undocumented. It inflates the reported value and lifts every
        # non-zero prediction to at least "Mild". Pass coverage_offset=0 to
        # report the raw measurement. The cap keeps it a valid percentage.
        coverage = min(coverage + coverage_offset, 100.0)
    return coverage


def _classify_severity(coverage):
    """Map a coverage percentage onto a severity label (strict thresholds 25 / 10 / 1)."""
    if coverage > 25:
        return "Severe"
    if coverage > 10:
        return "Moderate"
    if coverage > 1:
        return "Mild"
    return "No Tumor Found"


def _render_overlay(gray_adj, vis_mask):
    """Draw *vis_mask* with cmap_gyr over the grayscale slice and return it as an image array."""
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        ax.imshow(gray_adj, cmap='gray', vmin=0, vmax=1)
        ax.imshow(vis_mask, cmap=cmap_gyr, alpha=0.7, vmin=0, vmax=1)
        ax.axis('off')
        fig.tight_layout()
        # A private temporary directory, removed on exit, replaces the
        # deprecated, race-prone tempfile.mktemp() and stops one PNG per
        # request from accumulating on disk.
        with tempfile.TemporaryDirectory() as tmp_dir:
            png_path = os.path.join(tmp_dir, "overlay.png")
            fig.savefig(png_path, bbox_inches='tight', pad_inches=0)
            return io.imread(png_path)
    finally:
        plt.close(fig)


def predict_and_overlay(image, filename, coverage_offset=3.5):
    """Segment a tumor in *image* and summarize the result.

    Args:
        image: image array from the UI, 2-D or H x W x C. It is replaced by
            ``colored/<basename(filename)>`` when that file exists.
        filename: name used to look up a pre-colored version of the image.
        coverage_offset: percentage points added to any non-zero coverage
            (default 3.5). The final value is capped at 100.

    Returns:
        ``(overlay_img, coverage_text, severity)``:
        - *overlay_img* is the rendered PNG read back as an array.
        - *coverage_text* is formatted like ``"12.34%"``.
        - *severity* is "Severe", "Moderate", "Mild" or "No Tumor Found".
    """
    import skimage.filters

    image = _load_model_input_image(image, filename)

    img_input = preprocess_single_image(image)
    pred = model.predict(img_input)
    # First (only) batch item. Presumably a 256x256 map of probabilities in
    # [0, 1] after squeeze; TODO confirm against the model's output layer.
    pred_mask = pred[0].squeeze()

    # Grayscale view at the model's working resolution, contrast-stretched for display.
    resized_img = cv2.resize(image, (256, 256))
    gray_img = to_grayscale_float(resized_img / 255.0)
    gray_adj = stretch_contrast(gray_img)

    # Brain region = pixels brighter than the Otsu threshold (assumes a dark
    # background around the head).
    threshold = skimage.filters.threshold_otsu(gray_img)
    brain_mask = gray_img > threshold

    # Hide low-confidence predictions (< 0.2) so they are transparent in the overlay.
    vis_mask = np.copy(pred_mask)
    vis_mask[vis_mask < 0.2] = np.nan

    coverage = _tumor_coverage(pred_mask, brain_mask, coverage_offset)
    severity = _classify_severity(coverage)
    overlay_img = _render_overlay(gray_adj, vis_mask)
    return overlay_img, f"{coverage:.2f}%", severity
# --- Sample Gallery Setup ---
def load_gallery(grayscale_dir="grayscale"):
    """Load the sample MRI slices displayed in the UI gallery.

    Files in *grayscale_dir* are read in sorted filename order if their
    extension is .tif, .tiff, .png, .jpg or .jpeg (case-insensitive).
    Multi-channel images are collapsed to grayscale, then every image is
    contrast-stretched to [0, 1].

    Args:
        grayscale_dir: directory to scan. The default ``"grayscale"`` is
            relative to the current working directory.

    Returns:
        ``(images, filenames)``: two parallel lists. Both are empty (and a
        warning is issued) when the directory does not exist, so the app can
        still start without sample images.
    """
    image_extensions = ('.tif', '.tiff', '.png', '.jpg', '.jpeg')
    gallery_images = []
    filenames = []
    try:
        entries = sorted(os.listdir(grayscale_dir))
    except FileNotFoundError:
        import warnings
        warnings.warn(f"Sample directory {grayscale_dir!r} not found; gallery will be empty.")
        return gallery_images, filenames
    for fname in entries:
        if not fname.lower().endswith(image_extensions):
            continue
        img = io.imread(os.path.join(grayscale_dir, fname))
        if img.ndim == 3:
            img = to_grayscale_float(img)
        gallery_images.append(stretch_contrast(img))
        filenames.append(fname)
    return gallery_images, filenames
gallery_imgs, gallery_filenames = load_gallery()

# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 Brain Tumor Segmentation - MRI Viewer")
    gr.Markdown("### Sample MRIs (Drag and Drop Below to Predict)")
    with gr.Row():
        # One read-only thumbnail per sample slice (file names are not shown).
        for img in gallery_imgs:
            gr.Image(value=img, image_mode="L", label="", show_label=False, show_download_button=False)
    gr.Markdown("### Upload an MRI to Detect Tumor")
    with gr.Row():
        input_img = gr.Image(label="Upload or Drag Sample MRI", type="numpy")
        output_img = gr.Image(label="Tumor Heatmap Output")
    with gr.Row():
        output_coverage = gr.Textbox(label="Tumor Coverage")
        output_severity = gr.Textbox(label="Severity")
    # Hidden field holding the filename passed to predict_and_overlay, which
    # uses it to look for a pre-colored copy in the "colored" directory.
    filename_box = gr.Textbox(visible=False)

    def wrapper(img, filename):
        """Validate the UI inputs and run the segmentation pipeline."""
        if img is None:
            raise gr.Error("Please upload an MRI image before running segmentation.")
        # The hidden Textbox can deliver "" rather than None (e.g. before any
        # upload event has filled it), so treat every empty value as missing.
        if not filename:
            filename = f"uploaded_{np.random.randint(10000)}.png"
        return predict_and_overlay(img, filename)

    submit_btn = gr.Button("Run Tumor Segmentation")
    submit_btn.click(fn=wrapper, inputs=[input_img, filename_box], outputs=[output_img, output_coverage, output_severity])

    def capture_filename(img):
        """Return a random placeholder filename for a freshly uploaded image."""
        # NOTE(review): this handler only receives the pixel array
        # (type="numpy"), not the uploaded file's name. The generated name
        # therefore only matches a file in "colored" by coincidence.
        return f"upload_{np.random.randint(10000)}.png"

    input_img.upload(capture_filename, inputs=input_img, outputs=filename_box)
# --- Launch ---
# Start the web server only when this file is executed as a script, so that
# importing the module (tests, tooling, Gradio reload mode) does not launch it.
if __name__ == "__main__":
    demo.launch()