# app.py — Gradio demo: VAE-based MRI anomaly detection with performance metrics.
import gradio as gr
import numpy as np
import torch
import yaml
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from sklearn.metrics import roc_curve, auc, precision_recall_curve
from models.vae import VAEModel # Custom VAE implementation
# Load configuration.
config = {
    # Bug fix: this key was missing, so config['model_path'] below raised
    # KeyError at import time. Placeholder path — TODO: point at the real
    # checkpoint file for this deployment.
    'model_path': 'checkpoints/vae.ckpt',
    'image_size': [256, 256],
    # Anomaly-score cut-offs used by generate_recommendation():
    # score < low -> normal, < medium -> mild, < high -> moderate, else severe.
    'thresholds': {
        'low': 0.15,
        'medium': 0.35,
        'high': 0.65
    }
}

# Initialize model on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAEModel.load_from_checkpoint(config['model_path'])
model.eval()  # inference mode: freezes dropout / batch-norm statistics
model.to(device)
# Mock performance data (replace with your actual metrics).
performance_metrics = {
    'AUC-ROC': 0.947,
    'Sensitivity': 0.921,
    'Specificity': 0.886,
    'Precision': 0.893,
    'F1-Score': 0.907,
    'Inference Speed (ms)': 118
}

# Mock ROC and PR curve data.
# Bug fix: the labels were [0, 1]*50 (alternating) while the synthetic scores
# put all low scores first and all high scores last, so labels and scores were
# misaligned and the resulting curves were meaningless (AUC ~0.5). The labels
# now match the score grouping, and the RNG is seeded so the demo curves are
# reproducible across restarts.
_rng = np.random.default_rng(42)
_mock_labels = np.array([0] * 50 + [1] * 50)
# Negatives score in [0, 0.2), positives in [0.8, 1.0) — cleanly separable.
_mock_scores = _rng.random(100) * 0.2 + np.array([0.0] * 50 + [0.8] * 50)

fpr, tpr, _ = roc_curve(_mock_labels, _mock_scores)
roc_auc = auc(fpr, tpr)
precision, recall, _ = precision_recall_curve(_mock_labels, _mock_scores)
pr_auc = auc(recall, precision)
# Preprocessing pipeline: resize to the model's input resolution, convert the
# PIL image to a float tensor in [0, 1], then normalize with single-channel
# stats (a 1-element mean/std broadcasts across channels if the input has
# more than one — assumes the model was trained with these stats; TODO confirm).
_pipeline_steps = [
    transforms.Resize(config['image_size']),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485], std=[0.229]),
]
preprocess = transforms.Compose(_pipeline_steps)
def detect_anomalies(input_image):
    """Run one MRI scan through the VAE and report anomaly findings.

    Args:
        input_image: PIL image uploaded via the Gradio UI.

    Returns:
        dict with keys:
            anomaly_score (float), heatmap (PIL.Image), diagnosis (str),
            reconstructed (PIL.Image), roc_curve / pr_curve (matplotlib
            figures), metrics (pandas.DataFrame).

    Raises:
        gr.Error: wraps any processing failure for display in the UI.
    """
    try:
        # Preprocess and add a batch dimension for the model.
        image_tensor = preprocess(input_image).unsqueeze(0).to(device)

        # Reconstruction + anomaly score without tracking gradients.
        with torch.no_grad():
            reconstructed, mu, logvar = model(image_tensor)
            anomaly_score = model.compute_anomaly_score(image_tensor, reconstructed)

        # Per-pixel reconstruction error heatmap. The tensors live in
        # normalized space, so |diff| can exceed 1.0; clip before scaling,
        # otherwise the uint8 cast wraps around (e.g. 1.1 * 255 -> 24).
        diff = torch.abs(image_tensor - reconstructed).squeeze().cpu().numpy()
        heatmap = (np.clip(diff, 0.0, 1.0) * 255).astype(np.uint8)

        # Personalized recommendation from the configured thresholds.
        recommendation = generate_recommendation(anomaly_score, config['thresholds'])

        # Static performance visualizations (precomputed mock curves/metrics).
        roc_fig = plot_roc_curve(fpr, tpr, roc_auc)
        pr_fig = plot_pr_curve(precision, recall, pr_auc)
        metrics_table = create_metrics_table(performance_metrics)

        return {
            "anomaly_score": float(anomaly_score),
            "heatmap": Image.fromarray(heatmap),
            "diagnosis": recommendation,
            # Clamp to [0, 1] so ToPILImage does not overflow on
            # out-of-range normalized values.
            "reconstructed": transforms.ToPILImage()(
                reconstructed.squeeze().cpu().clamp(0.0, 1.0)
            ),
            "roc_curve": roc_fig,
            "pr_curve": pr_fig,
            "metrics": metrics_table,
        }
    except Exception as e:
        # Surface any failure as a Gradio error banner instead of a traceback.
        raise gr.Error(f"Processing failed: {str(e)}")
def generate_recommendation(score, thresholds):
    """Map an anomaly score to a clinical recommendation string.

    `thresholds` must provide 'low', 'medium' and 'high' cut-offs; the first
    threshold the score falls strictly below determines the message, and
    anything at or above 'high' is reported as severe.
    """
    tiers = (
        ('low', "Normal scan - No anomalies detected"),
        ('medium', "Mild anomaly detected - Recommend follow-up in 6 months"),
        ('high', "Moderate anomaly detected - Urgent specialist referral advised"),
    )
    for key, message in tiers:
        if score < thresholds[key]:
            return message
    return "Severe anomaly detected - Immediate medical intervention required"
def plot_roc_curve(fpr, tpr, roc_auc):
    """Return a matplotlib figure of the ROC curve with its AUC in the legend."""
    fig, axes = plt.subplots(figsize=(8, 6))
    axes.plot(fpr, tpr, color='darkorange', lw=2,
              label=f'ROC curve (AUC = {roc_auc:.2f})')
    # Diagonal reference: the performance of a random classifier.
    axes.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    axes.set(xlim=[0.0, 1.0], ylim=[0.0, 1.05],
             xlabel='False Positive Rate', ylabel='True Positive Rate',
             title='Receiver Operating Characteristic')
    axes.legend(loc="lower right")
    return fig
def plot_pr_curve(precision, recall, pr_auc):
    """Return a matplotlib figure of the precision-recall curve with its AUC."""
    fig, axes = plt.subplots(figsize=(8, 6))
    axes.plot(recall, precision, color='blue', lw=2,
              label=f'PR curve (AUC = {pr_auc:.2f})')
    axes.set(xlim=[0.0, 1.0], ylim=[0.0, 1.05],
             xlabel='Recall', ylabel='Precision',
             title='Precision-Recall Curve')
    axes.legend(loc="lower left")
    return fig
def create_metrics_table(metrics):
    """Render a metrics dict as a one-column DataFrame of display strings.

    Floats are formatted with three decimals. Bug fix: the original applied
    ``f"{x:.3f}"`` to ints as well, so 'Inference Speed (ms)': 118 rendered
    as "118.000" — integers now keep their exact value. Non-numeric values
    pass through unchanged.
    """
    def _fmt(value):
        # bool is a subclass of int; leave it (and any non-number) as-is.
        if isinstance(value, bool):
            return value
        if isinstance(value, float):
            return f"{value:.3f}"
        if isinstance(value, int):
            return str(value)
        return value

    df = pd.DataFrame.from_dict(metrics, orient='index', columns=['Value'])
    df['Value'] = df['Value'].apply(_fmt)
    return df
# Gradio UI with performance metrics.
with gr.Blocks(title="MRI Anomaly Detection with Performance Metrics", theme="soft") as demo:
    gr.Markdown("""
    # 🧠 MRI Anomaly Detection System
    *Principal ML Engineer Demonstration - Generative AI with Comprehensive Metrics*
    """)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload MRI Scan", type="pil")
            submit_btn = gr.Button("Analyze Scan", variant="primary")
        with gr.Column():
            anomaly_score = gr.Number(label="Anomaly Score", precision=3)
            diagnosis = gr.Textbox(label="Clinical Recommendation")

    with gr.Tab("Original vs Reconstructed"):
        gr.Markdown("**Left: Original | Right: Reconstructed**")
        compare = gr.Gallery(columns=2, height="auto")
    with gr.Tab("Anomaly Heatmap"):
        heatmap_view = gr.Image(label="Anomaly Regions", interactive=False)

    # Performance Metrics Section. The plot components are named *_plot so
    # they do not shadow sklearn's roc_curve / precision_recall_curve
    # functions imported at the top of the file (the original assignment
    # `roc_curve = gr.Plot(...)` clobbered the sklearn name).
    with gr.Accordion("Model Performance Metrics", open=False):
        with gr.Row():
            with gr.Column():
                roc_plot = gr.Plot(label="ROC Curve")
            with gr.Column():
                pr_plot = gr.Plot(label="Precision-Recall Curve")
        metrics_table = gr.Dataframe(
            label="Quantitative Metrics",
            headers=["Metric", "Value"],
            datatype=["str", "str"]
        )

    def _run_analysis(image):
        """Adapter: unpack detect_anomalies()' dict into the positional order
        of `output_components` below. Gradio maps return values to output
        components by position — the original code passed a {str: component}
        dict as `outputs`, which Gradio does not accept."""
        result = detect_anomalies(image)
        return (
            result["anomaly_score"],
            result["diagnosis"],
            # Gallery pair: uploaded original next to the VAE reconstruction.
            [image, result["reconstructed"]],
            result["heatmap"],
            result["roc_curve"],
            result["pr_curve"],
            result["metrics"],
        )

    # One canonical output order, shared by the button and the examples so
    # they cannot drift apart.
    output_components = [anomaly_score, diagnosis, compare, heatmap_view,
                         roc_plot, pr_plot, metrics_table]

    # Examples
    gr.Examples(
        examples=["examples/normal.jpg", "examples/abnormal.jpg"],
        inputs=input_image,
        outputs=output_components,
        fn=_run_analysis,
        cache_examples=True
    )

    submit_btn.click(
        fn=_run_analysis,
        inputs=input_image,
        outputs=output_components
    )
if __name__ == "__main__":
    # Launch with Gradio defaults (127.0.0.1:7860); pass
    # server_name="0.0.0.0" / server_port=... here to expose externally.
    demo.launch()