# Source page header (Spaces status at capture time): Runtime error
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import yaml
from PIL import Image
from sklearn.metrics import auc, precision_recall_curve, roc_curve
from torchvision import transforms

from models.vae import VAEModel  # Custom VAE implementation
# Load configuration
# NOTE(review): `yaml` is imported at the top of the file, which suggests this
# dict was meant to be loaded from a YAML file; it is inlined here instead.
config = {
    # Checkpoint path passed to VAEModel.load_from_checkpoint during model
    # initialization; without this key the module fails at import time with a
    # KeyError. Override with the MODEL_PATH environment variable.
    # TODO: the default below is a placeholder -- confirm the real location.
    'model_path': os.environ.get('MODEL_PATH', 'checkpoints/vae.ckpt'),
    # [height, width] handed to transforms.Resize in the preprocessing pipeline.
    'image_size': [256, 256],
    # Upper bounds (exclusive) of the recommendation tiers; a score at or
    # above 'high' falls into the most severe tier.
    'thresholds': {
        'low': 0.15,
        'medium': 0.35,
        'high': 0.65,
    },
}
# Initialize model: choose the CUDA device when torch reports one available,
# otherwise fall back to the CPU; then load the checkpoint named by
# config['model_path'], put it in evaluation mode and move it to that device.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = VAEModel.load_from_checkpoint(config['model_path'])
model.eval()
model.to(device)
# Mock performance data (replace with your actual metrics).
# Built from (name, value) pairs; dict insertion order is preserved, so the
# pairs appear in this order wherever the mapping is iterated.
performance_metrics = dict([
    ('AUC-ROC', 0.947),
    ('Sensitivity', 0.921),
    ('Specificity', 0.886),
    ('Precision', 0.893),
    ('F1-Score', 0.907),
    ('Inference Speed (ms)', 118),
])
# Mock ROC and PR curve data (replace with real validation-set results).
# Labels are 50 negatives followed by 50 positives so that the +0.8 score
# offset lands on the positive class. The previous alternating [0, 1]*50
# labels left the offset uncorrelated with the labels, which produced
# near-chance curves. One seeded score array feeds both curves, so they are
# reproducible and describe the same synthetic classifier.
_mock_rng = np.random.default_rng(0)
_mock_labels = np.array([0] * 50 + [1] * 50)
_mock_scores = _mock_rng.random(100) * 0.2 + np.where(_mock_labels == 1, 0.8, 0.0)
fpr, tpr, _ = roc_curve(_mock_labels, _mock_scores)
roc_auc = auc(fpr, tpr)
precision, recall, _ = precision_recall_curve(_mock_labels, _mock_scores)
# Trapezoidal area under the PR curve (recall on the x-axis).
pr_auc = auc(recall, precision)
# Preprocessing pipeline: grayscale -> resize -> tensor -> normalize.
preprocess = transforms.Compose([
    # Normalize below is given a single-channel mean/std, and the heatmap code
    # squeezes the tensor down to a 2-D array. Uploaded JPEG/PNG files are
    # often RGB or RGBA, so collapse them to one channel first.
    # NOTE(review): presumably the VAE was trained on single-channel scans --
    # confirm against the training pipeline.
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize(config['image_size']),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485], std=[0.229]),
])
def _as_image_array(array):
    """Return *array* in a layout PIL accepts: (H, W) or channels-last (H, W, C).

    A 3-D array is assumed to be channels-first (C, H, W), the layout produced
    by ``transforms.ToTensor`` -- TODO confirm for the model's output.
    """
    if array.ndim == 3:
        array = np.moveaxis(array, 0, -1)
    return array


def _scale_to_uint8(array, lo, hi):
    """Map values in [lo, hi] linearly onto 0..255, clipping anything outside."""
    span = hi - lo
    if span <= 0:
        return np.zeros(array.shape, dtype=np.uint8)
    unit = np.clip((array - lo) / span, 0.0, 1.0)
    return (unit * 255).astype(np.uint8)


def detect_anomalies(input_image):
    """Run the VAE on one scan and build every output the UI displays.

    Args:
        input_image: PIL image from the upload widget, or ``None`` when the
            user clicks before uploading anything.

    Returns:
        A tuple in the same order as the ``gr.Examples`` outputs list:
        ``(anomaly_score, recommendation, [original, reconstruction],
        heatmap, roc_figure, pr_figure, metrics_dataframe)``. Gradio maps a
        returned tuple positionally onto the output components; a dict keyed
        by plain strings cannot be mapped.

    Raises:
        gr.Error: when no image is supplied or any processing step fails.
    """
    if input_image is None:
        raise gr.Error("Please upload an MRI scan first.")
    try:
        image_tensor = preprocess(input_image).unsqueeze(0).to(device)
        with torch.no_grad():
            reconstructed, _mu, _logvar = model(image_tensor)
            score = float(model.compute_anomaly_score(image_tensor, reconstructed))

        original_np = image_tensor.squeeze().cpu().numpy()
        recon_np = reconstructed.squeeze().cpu().numpy()

        # Per-pixel reconstruction error, mapped as before (error 1.0 -> 255),
        # but clipped so larger errors saturate instead of wrapping around in
        # the uint8 cast.
        diff = np.abs(original_np - recon_np)
        if diff.ndim == 3:
            diff = diff.mean(axis=0)  # collapse channels into a single error map
        heatmap = Image.fromarray(_scale_to_uint8(diff, 0.0, 1.0))

        # Both tensors are in normalized units (possibly negative or > 1), which
        # ToPILImage would wrap. Render them on the input's own value range so
        # the side-by-side comparison shares one intensity scale.
        lo, hi = float(original_np.min()), float(original_np.max())
        gallery = [
            Image.fromarray(_scale_to_uint8(_as_image_array(original_np), lo, hi)),
            Image.fromarray(_scale_to_uint8(_as_image_array(recon_np), lo, hi)),
        ]

        recommendation = generate_recommendation(score, config['thresholds'])
        roc_fig = plot_roc_curve(fpr, tpr, roc_auc)
        pr_fig = plot_pr_curve(precision, recall, pr_auc)
        metrics_table = create_metrics_table(performance_metrics)
    except Exception as e:
        raise gr.Error(f"Processing failed: {str(e)}") from e

    return score, recommendation, gallery, heatmap, roc_fig, pr_fig, metrics_table
def generate_recommendation(score, thresholds):
    """Translate an anomaly score into a recommendation message.

    ``thresholds['low']``, ``['medium']`` and ``['high']`` are checked in that
    order. The first one the score is strictly below selects the message; a
    score that is below none of them gets the most severe message.
    """
    tiers = (
        ('low', "Normal scan - No anomalies detected"),
        ('medium', "Mild anomaly detected - Recommend follow-up in 6 months"),
        ('high', "Moderate anomaly detected - Urgent specialist referral advised"),
    )
    for key, message in tiers:
        if score < thresholds[key]:
            return message
    return "Severe anomaly detected - Immediate medical intervention required"
def plot_roc_curve(fpr, tpr, roc_auc):
    """Draw a ROC curve, labelled with its AUC, plus the chance diagonal.

    Args:
        fpr: false positive rates (x-axis values).
        tpr: true positive rates (y-axis values), same length as ``fpr``.
        roc_auc: area under the curve, shown in the legend with two decimals.

    Returns:
        A ``matplotlib.figure.Figure`` (8x6 inches) holding one axes.
    """
    # Build the Figure directly instead of through pyplot. pyplot keeps a
    # reference to every figure it opens until plt.close() is called, and this
    # function is called for every analysed scan, so those figures would
    # accumulate in the long-running server process.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    ax.plot(fpr, tpr, color='darkorange', lw=2,
            label=f'ROC curve (AUC = {roc_auc:.2f})')
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')  # chance line
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic')
    ax.legend(loc="lower right")
    return fig
def plot_pr_curve(precision, recall, pr_auc):
    """Draw a precision-recall curve labelled with its area.

    Args:
        precision: precision values (y-axis).
        recall: recall values (x-axis), same length as ``precision``.
        pr_auc: area under the curve, shown in the legend with two decimals.

    Returns:
        A ``matplotlib.figure.Figure`` (8x6 inches) holding one axes.
    """
    # Build the Figure directly instead of through pyplot. pyplot keeps a
    # reference to every figure it opens until plt.close() is called, and this
    # function is called for every analysed scan, so those figures would
    # accumulate in the long-running server process.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    ax.plot(recall, precision, color='blue', lw=2,
            label=f'PR curve (AUC = {pr_auc:.2f})')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title('Precision-Recall Curve')
    ax.legend(loc="lower left")
    return fig
def create_metrics_table(metrics):
    """Return the metrics as a two-column table for display.

    Args:
        metrics: mapping of metric name -> value.

    Returns:
        A pandas DataFrame with columns ``Metric`` and ``Value``, one row per
        item in the mapping's iteration order. Numeric values (Python or NumPy
        ints/floats) are rendered as strings with three decimals; any other
        value is passed through unchanged.
    """
    # The UI declares headers=["Metric", "Value"] (two columns). Keeping the
    # names in the index would leave the frame with a single data column, so
    # the names become a real column here.
    numeric = (int, float, np.integer, np.floating)
    rows = [
        (name, f"{value:.3f}" if isinstance(value, numeric) else value)
        for name, value in metrics.items()
    ]
    return pd.DataFrame(rows, columns=['Metric', 'Value'])
# Gradio UI with performance metrics
with gr.Blocks(title="MRI Anomaly Detection with Performance Metrics", theme="soft") as demo:
    gr.Markdown("""
# 🧠 MRI Anomaly Detection System
*Principal ML Engineer Demonstration - Generative AI with Comprehensive Metrics*
""")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload MRI Scan", type="pil")
            submit_btn = gr.Button("Analyze Scan", variant="primary")
        with gr.Column():
            anomaly_score = gr.Number(label="Anomaly Score", precision=3)
            diagnosis = gr.Textbox(label="Clinical Recommendation")
            with gr.Tab("Original vs Reconstructed"):
                gr.Markdown("**Left: Original | Right: Reconstructed**")
                compare = gr.Gallery(columns=2, height="auto")
            with gr.Tab("Anomaly Heatmap"):
                heatmap = gr.Image(label="Anomaly Regions", interactive=False)

    # Performance Metrics Section
    with gr.Accordion("Model Performance Metrics", open=False):
        with gr.Row():
            # Named *_plot: this `with` block runs at module scope, so a
            # variable called `roc_curve` would rebind the sklearn function
            # imported at the top of the file.
            with gr.Column():
                roc_plot = gr.Plot(label="ROC Curve")
            with gr.Column():
                pr_plot = gr.Plot(label="Precision-Recall Curve")
        metrics_table = gr.Dataframe(
            label="Quantitative Metrics",
            headers=["Metric", "Value"],
            datatype=["str", "str"],
        )

    # Event `outputs` must be a component or a list of components (a dict
    # keyed by plain strings is not accepted). The examples and the button
    # share one ordered list so both bind the handler's return values
    # identically.
    analysis_outputs = [
        anomaly_score, diagnosis, compare, heatmap, roc_plot, pr_plot, metrics_table,
    ]

    # Examples
    gr.Examples(
        examples=["examples/normal.jpg", "examples/abnormal.jpg"],
        inputs=input_image,
        outputs=analysis_outputs,
        fn=detect_anomalies,
        cache_examples=True,
    )
    submit_btn.click(
        fn=detect_anomalies,
        inputs=input_image,
        outputs=analysis_outputs,
    )
def _main():
    """Start the Gradio server for ``demo``.

    launch() is called with no arguments, so its default host and port are
    used; pass e.g. server_name="0.0.0.0", server_port=7860 to bind explicitly.
    """
    demo.launch()


if __name__ == "__main__":
    _main()