| | |
| | |
| |
|
| | |
| | import os |
| | os.environ["OMP_NUM_THREADS"] = "1" |
| |
|
| | import io, json, glob |
| | import numpy as np |
| | import pandas as pd |
| | from PIL import Image |
| | import gradio as gr |
| | import torch, torch.nn as nn |
| | from torchvision import models, transforms, datasets |
| | import matplotlib |
| | matplotlib.use("Agg") |
| | import matplotlib.pyplot as plt |
| | import seaborn as sns |
| | from huggingface_hub import hf_hub_download |
| |
|
| | |
| | from cam_utils import grad_cam |
| |
|
| | |
# Runtime configuration. Everything except DEVICE and CLASS_NAMES can be
# overridden through an environment variable of the same name.
MODEL_NAME = os.environ.get("MODEL_NAME", "efficientnet_b0")
NUM_CLASSES = int(os.environ.get("NUM_CLASSES", "2"))
IMAGE_SIZE = int(os.environ.get("IMAGE_SIZE", "224"))
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Label for each output index of the classifier head; predict_enhanced treats
# index 0 as the malaria-positive class.
CLASS_NAMES = ["Parasitized", "Uninfected"]

# Optional Hugging Face Hub source for the checkpoint. HF_WEIGHTS is only read
# from the environment when a repo id is configured; otherwise it is "".
HF_REPO_ID = os.environ.get("HF_REPO_ID", "").strip()
if HF_REPO_ID:
    HF_WEIGHTS = os.environ.get("HF_WEIGHTS", "best.pt").strip()
else:
    HF_WEIGHTS = ""
# Local checkpoint tried when no Hub download succeeds.
WEIGHTS_PATH = os.environ.get("WEIGHTS_PATH", "checkpoints/best.pt")
| |
|
def resolve_weights() -> str:
    """Return the path of the checkpoint file the app should load.

    Search order:
      1. ``HF_WEIGHTS`` from the Hugging Face repo ``HF_REPO_ID`` (only when a
         repo id is configured; any download error is printed and skipped).
      2. ``WEIGHTS_PATH``.
      3. ``checkpoints/best.pt``, ``checkpoints/last.pt``, then every other
         ``checkpoints/*.pt`` in sorted order.

    Raises:
        FileNotFoundError: if none of the local candidates exists.
    """
    if HF_REPO_ID:
        try:
            return hf_hub_download(repo_id=HF_REPO_ID, filename=HF_WEIGHTS)
        except Exception as err:
            # Best effort: fall through to the local candidates below.
            print(f"[hub] failed to download {HF_REPO_ID}:{HF_WEIGHTS} → {err}")

    search_order = [WEIGHTS_PATH, "checkpoints/best.pt", "checkpoints/last.pt"]
    search_order += sorted(glob.glob("checkpoints/*.pt"))
    found = next((candidate for candidate in search_order if os.path.exists(candidate)), None)
    if found is None:
        raise FileNotFoundError("No checkpoint found. Upload a .pt file.")
    return found
| |
|
| | |
def build_model(name: str, num_classes: int):
    """Build an untrained torchvision backbone with a ``num_classes``-way head.

    Args:
        name: architecture name, case-insensitive: "efficientnet_b0" or "resnet50".
        num_classes: number of output logits of the replaced final Linear layer.

    Raises:
        ValueError: for any other architecture name (reported lower-cased).
    """
    arch = name.lower()
    if arch == "efficientnet_b0":
        net = models.efficientnet_b0(weights=None)
        old_head = net.classifier[1]
        net.classifier[1] = nn.Linear(old_head.in_features, num_classes)
        return net
    if arch == "resnet50":
        net = models.resnet50(weights=None)
        net.fc = nn.Linear(net.fc.in_features, num_classes)
        return net
    raise ValueError(f"Unsupported model: {arch}")
| |
|
# Load the serving model once at import time.
CHECKPOINT_FILE = resolve_weights()
_model = build_model(MODEL_NAME, NUM_CLASSES).to(DEVICE)
_state = torch.load(CHECKPOINT_FILE, map_location=DEVICE)
# A checkpoint is either a bare state_dict or a wrapper dict that nests the
# weights under "state_dict". Unwrap explicitly instead of retrying on any
# exception, so a genuine key/shape mismatch surfaces as its own error rather
# than as a misleading KeyError('state_dict').
if isinstance(_state, dict) and "state_dict" in _state:
    _model.load_state_dict(_state["state_dict"])
else:
    _model.load_state_dict(_state)
_model.eval()
| |
|
# Inference preprocessing: upscale the short side to ~115% of IMAGE_SIZE, take
# the central IMAGE_SIZE x IMAGE_SIZE crop, and convert to a float tensor.
# NOTE(review): no Normalize step is applied; presumably this mirrors the
# training pipeline — confirm before changing either side.
_pre = transforms.Compose(
    [
        transforms.Resize(int(IMAGE_SIZE * 1.15)),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
    ]
)
| |
|
| | |
def predict_enhanced(image: Image.Image, show_cam: bool):
    """Classify one blood-cell image and build every Diagnosis-tab output.

    Args:
        image: uploaded PIL image, or None when nothing was uploaded.
        show_cam: when True, also run Grad-CAM and return its overlay.

    Returns:
        ``(result_html, overlay, prob_chart, recommendation)``: the results-card
        HTML, the Grad-CAM overlay as a PIL image (None when disabled or when
        Grad-CAM raised), a horizontal bar chart of class probabilities as an
        RGB PIL image, and the clinical-recommendation markdown. When ``image``
        is None the first element is a prompt and the other three are None.
    """
    if image is None:
        return "Please upload an image", None, None, None

    img = image.convert("RGB")
    x = _pre(img).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = _model(x).cpu().numpy().squeeze()
    # Softmax with the max subtracted first so np.exp cannot overflow.
    probs = np.exp(logits - logits.max())
    probs = probs / probs.sum()

    pred_idx = int(np.argmax(probs))
    pred_class = CLASS_NAMES[pred_idx]
    confidence = float(probs[pred_idx])

    # Index 0 is "Parasitized" in CLASS_NAMES; anything else is reported negative.
    if pred_idx == 0:
        color = "#FF4444"
        status = "MALARIA DETECTED"
        emoji = "🦠"
        recommendation = """
### Clinical Recommendation:
- **Immediate microscopic confirmation required**
- **Consult healthcare provider immediately**
- **Begin rapid diagnostic test (RDT)**
- **Consider antimalarial treatment if confirmed**
"""
    else:
        color = "#44FF88"
        status = "NO MALARIA DETECTED"
        emoji = "✅"
        recommendation = """
### Clinical Recommendation:
- **Negative result - low malaria likelihood**
- **Monitor for symptoms development**
- **Consult healthcare provider if symptoms persist**
- **Consider other differential diagnoses**
"""

    # `{color}15` / `{color}05` append a hex alpha channel to the base colour.
    result_html = f"""
<div style='padding: 30px; background: linear-gradient(135deg, {color}15 0%, {color}05 100%);
border-left: 6px solid {color}; border-radius: 12px; margin: 20px 0;'>
<div style='display: flex; align-items: center; margin-bottom: 20px;'>
<div style='font-size: 48px; margin-right: 20px;'>{emoji}</div>
<div>
<h1 style='color: {color}; margin: 0; font-size: 32px; font-weight: 700;'>{status}</h1>
<p style='margin: 5px 0 0 0; font-size: 16px; color: #666;'>AI-Powered Analysis Complete</p>
</div>
</div>

<div style='background: white; padding: 20px; border-radius: 8px; margin-bottom: 15px;'>
<h3 style='margin: 0 0 15px 0; color: #333;'>Diagnostic Results</h3>
<div style='display: grid; grid-template-columns: 1fr 1fr; gap: 15px;'>
<div>
<p style='margin: 0; color: #666; font-size: 14px;'>Prediction</p>
<p style='margin: 5px 0 0 0; font-size: 20px; font-weight: 600; color: {color};'>{pred_class}</p>
</div>
<div>
<p style='margin: 0; color: #666; font-size: 14px;'>Confidence</p>
<p style='margin: 5px 0 0 0; font-size: 20px; font-weight: 600; color: {color};'>{confidence*100:.2f}%</p>
</div>
</div>
</div>

<div style='background: white; padding: 20px; border-radius: 8px;'>
<h3 style='margin: 0 0 10px 0; color: #333;'>Class Probabilities</h3>
<div style='margin-bottom: 15px;'>
<div style='display: flex; justify-content: space-between; margin-bottom: 5px;'>
<span style='font-weight: 500;'>🦠 Parasitized</span>
<span style='font-weight: 600;'>{probs[0]*100:.2f}%</span>
</div>
<div style='height: 24px; background: #f0f0f0; border-radius: 12px; overflow: hidden;'>
<div style='height: 100%; width: {probs[0]*100}%; background: linear-gradient(90deg, #FF4444, #FF6666);'></div>
</div>
</div>
<div>
<div style='display: flex; justify-content: space-between; margin-bottom: 5px;'>
<span style='font-weight: 500;'>✅ Uninfected</span>
<span style='font-weight: 600;'>{probs[1]*100:.2f}%</span>
</div>
<div style='height: 24px; background: #f0f0f0; border-radius: 12px; overflow: hidden;'>
<div style='height: 100%; width: {probs[1]*100}%; background: linear-gradient(90deg, #44FF88, #66FFAA);'></div>
</div>
</div>
</div>
</div>

<div style='padding: 20px; background: #FFF8E1; border-left: 4px solid #FFC107; border-radius: 8px; margin-top: 20px;'>
<h3 style='margin: 0 0 10px 0; color: #F57C00;'>⚠️ Medical Disclaimer</h3>
<p style='margin: 0; font-size: 14px; line-height: 1.6;'>
This is a <strong>research tool only</strong> and NOT a medical diagnostic device.
Results must be confirmed by certified laboratory testing and qualified healthcare professionals.
Do not make medical decisions based solely on this AI analysis.
</p>
</div>
"""

    overlay = None
    if show_cam:
        try:
            cam = grad_cam(_model, img, img_size=IMAGE_SIZE, device=DEVICE)
            # Assumes cam["overlay"] is a float array in [0, 1] (hence the *255);
            # clip so values slightly outside that range cannot wrap around in uint8.
            overlay = Image.fromarray((np.clip(cam["overlay"], 0.0, 1.0) * 255).astype("uint8"))
        except Exception as e:
            # Best effort: a Grad-CAM failure must not block the prediction itself.
            print(f"Grad-CAM error: {e}")

    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        ax.barh(CLASS_NAMES, probs * 100, color=['#FF4444', '#44FF88'], alpha=0.8)
        ax.set_xlabel('Probability (%)', fontsize=12, fontweight='bold')
        ax.set_title('Prediction Confidence', fontsize=14, fontweight='bold', pad=20)
        ax.set_xlim(0, 100)
        ax.grid(axis='x', alpha=0.3)

        # Value label just to the right of each bar (bar i sits at y == i).
        for i, prob in enumerate(probs):
            ax.text(prob * 100 + 2, i, f'{prob*100:.1f}%', va='center', fontweight='bold')

        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
        buf.seek(0)
        prob_chart = Image.open(buf).convert('RGB')
    finally:
        # Close *this* figure: a bare plt.close() closes pyplot's "current"
        # figure, which under concurrent requests may belong to another call.
        plt.close(fig)

    return result_html, overlay, prob_chart, recommendation
| |
|
| | |
def export_onnx(precision: str):
    """Export the serving checkpoint to an in-memory ONNX graph.

    A separate copy of the network is rebuilt from CHECKPOINT_FILE because the
    "fp16" path casts the module in place with ``.half()``; doing that to the
    live ``_model`` would break inference.

    Args:
        precision: "fp32" or "fp16".

    Returns:
        ``(fname, buf)``: suggested file name
        ``model_<MODEL_NAME>_<precision>_<IMAGE_SIZE>.onnx`` and an
        ``io.BytesIO`` holding the ONNX graph, rewound to position 0.

    Raises:
        ValueError: if ``precision`` is neither "fp32" nor "fp16" (previously
            such values silently exported fp32 under a misleading file name).
    """
    if precision not in ("fp32", "fp16"):
        raise ValueError(f"Unsupported precision {precision!r}; expected 'fp32' or 'fp16'.")

    m = build_model(MODEL_NAME, NUM_CLASSES).to(DEVICE)
    state = torch.load(CHECKPOINT_FILE, map_location=DEVICE)
    # Accept a bare state_dict or a wrapper dict nesting it under "state_dict";
    # checking explicitly keeps a real load error from being masked.
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    m.load_state_dict(state)
    m.eval()

    dummy = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE)
    if precision == "fp16":
        # NOTE(review): tracing a Half model on CPU requires the installed torch
        # to have Half CPU kernels for every op used — confirm on CPU-only hosts.
        m = m.half()
        dummy = dummy.half()

    buf = io.BytesIO()
    torch.onnx.export(
        m, dummy, buf,
        input_names=["input"], output_names=["output"],
        # Leave the batch dimension symbolic so the graph accepts any batch size.
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
        opset_version=17, do_constant_folding=True,
    )
    buf.seek(0)
    return f"model_{MODEL_NAME}_{precision}_{IMAGE_SIZE}.onnx", buf
| |
|
| | |
def _find_dataset_root(root: str, class_names=()) -> str:
    """Return the directory under *root* that directly holds the class folders.

    ZIPs are often created with one wrapping folder (``val/Parasitized/...``),
    and macOS adds a ``__MACOSX`` metadata tree; ImageFolder would treat either
    as a class. Drop ``__MACOSX`` and descend through single wrapper folders
    whose name is not itself one of ``class_names``.
    """
    import shutil

    current = root
    for _ in range(8):  # bounded: only a few wrapper levels are plausible
        shutil.rmtree(os.path.join(current, "__MACOSX"), ignore_errors=True)
        subdirs = [e.name for e in os.scandir(current) if e.is_dir() and not e.name.startswith(".")]
        if len(subdirs) != 1 or subdirs[0] in class_names:
            break
        current = os.path.join(current, subdirs[0])
    return current


def validate(zip_file):
    """Evaluate the loaded model on an uploaded ZIP of class-labelled images.

    The ZIP must contain one folder per class, named after entries of
    CLASS_NAMES (optionally inside one wrapper folder). The archive is
    extracted into a temporary directory that is deleted afterwards.

    Args:
        zip_file: Gradio upload (object with ``.name``), a plain path, or None.

    Returns:
        ``(report_md, cm_img)``: a markdown metrics report and the confusion
        matrix as an RGB PIL image. On missing or invalid input the first
        element is an explanatory message and ``cm_img`` is None.
    """
    import tempfile, zipfile

    if zip_file is None:
        # Two values to match the two Gradio outputs (rep_out, cm_img).
        return "Please upload a validation dataset ZIP file.", None

    zip_path = zip_file if isinstance(zip_file, (str, os.PathLike)) else zip_file.name

    ys, ps = [], []
    with tempfile.TemporaryDirectory() as tmp:
        try:
            with zipfile.ZipFile(zip_path, 'r') as zf:
                zf.extractall(tmp)
        except zipfile.BadZipFile as e:
            return f"Invalid ZIP file: {e}", None

        data_root = _find_dataset_root(tmp, CLASS_NAMES)
        try:
            ds = datasets.ImageFolder(data_root, transform=_pre)
        except FileNotFoundError as e:
            return f"Could not read dataset: {e}", None

        unknown = [c for c in ds.classes if c not in CLASS_NAMES]
        if unknown:
            return f"Unexpected class folders {unknown}; expected folders named {CLASS_NAMES}.", None

        # ImageFolder numbers folders alphabetically among those present; map
        # each to the model's output index by name so partial sets stay correct.
        to_model_idx = [CLASS_NAMES.index(c) for c in ds.classes]

        dl = torch.utils.data.DataLoader(ds, batch_size=64, shuffle=False, num_workers=2)
        with torch.no_grad():
            for xb, yb in dl:
                preds = _model(xb.to(DEVICE)).argmax(1).cpu().numpy()
                ys.extend(to_model_idx[int(y)] for y in yb)
                ps.extend(int(p) for p in preds)

    import sklearn.metrics as sk
    # Fix the label set so classes absent from this split still get a row/column.
    labels = list(range(len(CLASS_NAMES)))
    rep = sk.classification_report(ys, ps, labels=labels, target_names=CLASS_NAMES,
                                   output_dict=True, zero_division=0)
    cm = sk.confusion_matrix(ys, ps, labels=labels)
    acc = sk.accuracy_score(ys, ps)

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        sns.heatmap(cm, annot=True, fmt="d", cmap="RdYlGn_r",
                    xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
                    ax=ax, cbar_kws={'label': 'Count'}, linewidths=2, linecolor='white')
        ax.set_xlabel('Predicted Label', fontsize=12, fontweight='bold')
        ax.set_ylabel('True Label', fontsize=12, fontweight='bold')
        ax.set_title('Confusion Matrix - Validation Results', fontsize=14, fontweight='bold', pad=20)

        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=160)
        buf.seek(0)
        cm_img = Image.open(buf).convert("RGB")
    finally:
        # Close this figure explicitly; pyplot's "current figure" is shared global state.
        plt.close(fig)

    rows = "\n".join(
        f"| {name} | {rep[name]['precision']:.3f} | {rep[name]['recall']:.3f} | "
        f"{rep[name]['f1-score']:.3f} | {rep[name]['support']:.0f} |"
        for name in CLASS_NAMES
    )
    macro = rep['macro avg']
    report_md = f"""
### Validation Results

**Overall Accuracy:** {acc*100:.2f}%

#### Per-Class Metrics:

| Class | Precision | Recall | F1-Score | Support |
|-------|-----------|--------|----------|---------|
{rows}

**Macro Avg:** Precision={macro['precision']:.3f}, Recall={macro['recall']:.3f}, F1={macro['f1-score']:.3f}
"""

    return report_md, cm_img
| |
|
| | |
# Metrics CSV shown on the Training Dashboard at startup.
METRICS_DEFAULT = "checkpoints/metrics.csv"


def _plot_to_pil(df: pd.DataFrame, ycol: str, title: str, ylabel: str, color='#2196F3'):
    """Plot ``df[ycol]`` against epoch as a filled line chart.

    If ``df`` has no ``epoch`` column, rows are numbered 1..N in file order.

    Returns:
        An RGB PIL image, or None when ``ycol`` is absent or has no non-NaN rows.
    """
    if ycol not in df.columns:
        return None
    if "epoch" not in df.columns:
        df = df.assign(epoch=range(1, len(df) + 1))
    s = df[["epoch", ycol]].dropna()
    if len(s) == 0:
        return None

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        ax.plot(s["epoch"], s[ycol], marker="o", linewidth=2.5, markersize=8, color=color)
        ax.fill_between(s["epoch"], s[ycol], alpha=0.3, color=color)
        ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
        ax.set_xlabel("Epoch", fontsize=12, fontweight='bold')
        ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
        ax.grid(True, alpha=0.3, linestyle='--')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=160, bbox_inches="tight", facecolor='white')
        buf.seek(0)
        return Image.open(buf).convert("RGB")
    finally:
        # Close this figure explicitly: a bare plt.close() targets pyplot's
        # shared "current" figure, which concurrent requests can change.
        plt.close(fig)
| |
|
def _training_summary(df: pd.DataFrame) -> str:
    """Build the markdown summary shown at the top of the Training Dashboard.

    Missing, non-numeric or all-NaN metric columns render as "N/A" instead of
    raising. The best-accuracy epoch is read from the ``epoch`` column when
    present (falling back to the 1-based row position).
    """
    def numeric(col):
        # Non-numeric cells become NaN and are dropped.
        if col not in df.columns:
            return pd.Series(dtype=float)
        return pd.to_numeric(df[col], errors="coerce").dropna()

    acc = numeric("val_acc")
    if len(acc):
        best_row = acc.idxmax()
        epochs = numeric("epoch")
        if best_row in epochs.index:
            epoch = epochs.loc[best_row]
            epoch = int(epoch) if float(epoch).is_integer() else epoch
        else:
            epoch = df.index.get_loc(best_row) + 1
        best_acc = f"{acc.max()*100:.2f}% (Epoch {epoch})"
    else:
        best_acc = "N/A"

    loss = numeric("train_loss")
    final_loss = f"{loss.iloc[-1]:.4f}" if len(loss) else "N/A"

    savings = numeric("save_rate")
    avg_savings = f"{savings.mean()*100:.1f}%" if len(savings) else "N/A"

    # DataFrame.to_markdown needs the optional `tabulate` package.
    try:
        last_5_table = df.tail(5).to_markdown(index=False)
    except ImportError:
        last_5_table = "```\n" + df.tail(5).to_string(index=False) + "\n```"

    return f"""
### Training Summary

**Total Epochs:** {len(df)}
**Best Validation Accuracy:** {best_acc}
**Final Training Loss:** {final_loss}
**Average Energy Savings:** {avg_savings}

#### Last 5 Epochs:
{last_5_table}
"""


def load_metrics(path: str):
    """Load a training-metrics CSV and build the Training Dashboard outputs.

    Args:
        path: CSV path; None or "" is treated as a missing file.

    Returns:
        A 7-tuple ``(summary_md, loss_img, acc_img, act_img, save_img, thr_img,
        csv_path)``. Plot slots are None when their column is absent.
        ``csv_path`` is a temporary copy of the column-normalised CSV on disk
        (``gr.DownloadButton`` needs a file path, not raw bytes). When the file
        is missing or unreadable, the first element is an error message and
        the other six are None.
    """
    if not path or not os.path.exists(path):
        return ("Metrics file not found. Upload your training metrics CSV.",) + (None,) * 6

    try:
        df = pd.read_csv(path)
    except Exception as e:
        return (f"Error reading CSV: {e}",) + (None,) * 6

    # Accept the older column names as aliases for the ones plotted below.
    col_map = {'activation_rate': 'act_rate', 'energy_savings': 'save_rate'}
    for old_col, new_col in col_map.items():
        if old_col in df.columns and new_col not in df.columns:
            df[new_col] = df[old_col]

    summary = _training_summary(df)

    fig_loss = _plot_to_pil(df, "train_loss", "Training Loss Over Time", "Loss", color='#FF5722')
    fig_acc = _plot_to_pil(df, "val_acc", "Validation Accuracy Over Time", "Accuracy", color='#4CAF50')
    fig_act = _plot_to_pil(df, "act_rate", "Activation Rate (AST)", "Activation Rate", color='#2196F3')
    fig_save = _plot_to_pil(df, "save_rate", "Energy Savings (AST)", "Savings Fraction", color='#9C27B0')
    fig_thr = _plot_to_pil(df, "threshold", "Activation Threshold (AST)", "Threshold", color='#FF9800')

    import tempfile
    # A fresh directory per call keeps concurrent reloads from overwriting each
    # other while the download file keeps the name "metrics.csv".
    csv_path = os.path.join(tempfile.mkdtemp(prefix="metrics_"), "metrics.csv")
    df.to_csv(csv_path, index=False)

    return summary, fig_loss, fig_acc, fig_act, fig_save, fig_thr, csv_path
| |
|
def compare_runs(files):
    """Overlay the metrics of several training runs on shared charts.

    Args:
        files: uploaded CSVs, as Gradio file objects exposing ``.name`` or as
            plain path strings. None or an empty list yields a prompt.

    Returns:
        ``(message, loss, acc, act, save, thr)``. Each chart is an RGB PIL image,
        or None when no run has data for that column. On empty input or a read
        failure, ``message`` explains the problem and all charts are None.
    """
    no_charts = (None,) * 5
    if not files:
        return ("Upload 2 or more metrics.csv files to compare training runs.",) + no_charts

    col_map = {'activation_rate': 'act_rate', 'energy_savings': 'save_rate'}
    runs = []
    for f in files:
        path = f if isinstance(f, (str, os.PathLike)) else f.name
        try:
            df = pd.read_csv(path)
        except Exception as e:
            return (f"Error reading {path}: {e}",) + no_charts
        # Accept the older column names as aliases.
        for old_col, new_col in col_map.items():
            if old_col in df.columns and new_col not in df.columns:
                df[new_col] = df[old_col]
        # Without an epoch column, number rows 1..N so runs can still be overlaid.
        if "epoch" not in df.columns:
            df = df.assign(epoch=range(1, len(df) + 1))
        runs.append((os.path.basename(path), df))

    def _overlay_plot(runs, ycol, title, ylabel):
        """Draw every run that has data in ``ycol`` on one axis; None if none do."""
        colors = ['#2196F3', '#FF5722', '#4CAF50', '#9C27B0', '#FF9800', '#00BCD4']
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            found = False
            for i, (name, df) in enumerate(runs):
                if ycol not in df.columns:
                    continue
                s = df[["epoch", ycol]].dropna()
                if len(s):
                    ax.plot(s["epoch"], s[ycol], marker="o", linewidth=2,
                            markersize=6, label=name, color=colors[i % len(colors)], alpha=0.8)
                    found = True

            if not found:
                return None

            ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
            ax.set_xlabel("Epoch", fontsize=12, fontweight='bold')
            ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
            ax.grid(True, alpha=0.3, linestyle='--')
            ax.legend(fontsize=10, loc='best')
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)

            fig.tight_layout()
            buf = io.BytesIO()
            fig.savefig(buf, format="png", dpi=160, bbox_inches="tight", facecolor='white')
            buf.seek(0)
            return Image.open(buf).convert("RGB")
        finally:
            # Always release this figure (also on the early return above);
            # a bare plt.close() would target pyplot's shared current figure.
            plt.close(fig)

    msg = f"Successfully compared {len(runs)} training runs."
    p_loss = _overlay_plot(runs, "train_loss", "Training Loss Comparison", "Loss")
    p_acc = _overlay_plot(runs, "val_acc", "Validation Accuracy Comparison", "Accuracy")
    p_act = _overlay_plot(runs, "act_rate", "Activation Rate Comparison", "Activation Rate")
    p_save = _overlay_plot(runs, "save_rate", "Energy Savings Comparison", "Savings Fraction")
    p_thr = _overlay_plot(runs, "threshold", "Threshold Comparison", "Threshold")

    return msg, p_loss, p_acc, p_act, p_save, p_thr
| |
|
| | |
# Stylesheet passed to gr.Blocks: the gradient page header (#header and its
# h1/p children) and the translucent pill-shaped .badge labels inside it.
custom_css = (
    "\n"
    "#header {\n"
    "background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);\n"
    "padding: 30px;\n"
    "border-radius: 15px;\n"
    "margin-bottom: 30px;\n"
    "text-align: center;\n"
    "color: white;\n"
    "}\n"
    "#header h1 {\n"
    "margin: 0;\n"
    "font-size: 42px;\n"
    "font-weight: 800;\n"
    "}\n"
    "#header p {\n"
    "margin: 10px 0 0 0;\n"
    "font-size: 18px;\n"
    "opacity: 0.95;\n"
    "}\n"
    ".badge {\n"
    "display: inline-block;\n"
    "padding: 8px 16px;\n"
    "margin: 5px;\n"
    "background: rgba(255,255,255,0.2);\n"
    "border-radius: 20px;\n"
    "font-weight: 600;\n"
    "}\n"
)
| |
|
| | |
def _to_download_path(payload):
    """Turn an in-memory ``(filename, data)`` download payload into a file path.

    ``gr.File`` and ``gr.DownloadButton`` outputs need a path on disk, but
    export_onnx returns ``(name, BytesIO)`` and load_metrics may return
    ``(name, bytes)``. Such pairs are written to a fresh temporary directory
    under their base name. Anything else (a path string, or None) is returned
    unchanged.
    """
    if not (isinstance(payload, tuple) and len(payload) == 2):
        return payload
    import tempfile

    fname, data = payload
    if hasattr(data, "getvalue"):
        data = data.getvalue()
    elif isinstance(data, str):
        data = data.encode("utf-8")
    out_path = os.path.join(tempfile.mkdtemp(prefix="download_"), os.path.basename(fname))
    with open(out_path, "wb") as fh:
        fh.write(data)
    return out_path


def _reload_metrics(metrics_file):
    """Reload the dashboard from an uploaded CSV, or from METRICS_DEFAULT if none was given."""
    if not metrics_file:
        path = METRICS_DEFAULT
    elif isinstance(metrics_file, (str, os.PathLike)):
        path = metrics_file
    else:
        path = metrics_file.name
    *outputs, csv_payload = load_metrics(path)
    return (*outputs, _to_download_path(csv_payload))


with gr.Blocks(title="Malaria Detection AI - Advanced Diagnostics", css=custom_css, theme=gr.themes.Soft()) as demo:

    # Page header with headline badges (styled by custom_css).
    gr.HTML("""
<div id="header">
<h1>🔬 AI-Powered Malaria Detection System</h1>
<p>Advanced Deep Learning for Rapid Malaria Diagnosis</p>
<div style="margin-top: 15px;">
<span class="badge">93.94% Accuracy</span>
<span class="badge">88% Energy Savings</span>
<span class="badge">EfficientNet-B0</span>
<span class="badge">Adaptive Sparse Training (Sundew)</span>
</div>
</div>
""")

    gr.Markdown(f"""
### System Information
**Model:** `{MODEL_NAME}` | **Weights:** `{os.path.basename(CHECKPOINT_FILE)}` | **Device:** `{DEVICE}` | **Classes:** {NUM_CLASSES}
""")

    # ---- Single-image diagnosis ----
    with gr.Tab("🔍 Diagnosis"):
        gr.Markdown("""
### Upload Blood Smear Image for Analysis
Upload a microscopy image of a blood cell to detect malaria parasites using AI.
""")

        with gr.Row():
            with gr.Column(scale=1):
                img_in = gr.Image(type="pil", label="Upload Blood Cell Image", height=400)
                show_cam = gr.Checkbox(value=True, label="Show Grad-CAM Visualization (Explainable AI)")
                btn_pred = gr.Button("🔬 Analyze for Malaria", variant="primary", size="lg")

            with gr.Column(scale=1):
                result_out = gr.HTML(label="Diagnostic Results")

        with gr.Row():
            with gr.Column():
                cam_out = gr.Image(type="pil", label="Grad-CAM Heat Map (Where AI Looks)")
            with gr.Column():
                chart_out = gr.Image(type="pil", label="Confidence Distribution")

        recommendation_out = gr.Markdown(label="Clinical Recommendations")

        btn_pred.click(
            fn=predict_enhanced,
            inputs=[img_in, show_cam],
            outputs=[result_out, cam_out, chart_out, recommendation_out]
        )

    # ---- Batch validation on an uploaded ZIP ----
    with gr.Tab("✅ Model Validation"):
        gr.Markdown("""
### Validate Model Performance
Upload a ZIP file containing a validation dataset (with Parasitized/ and Uninfected/ folders).
""")

        val_zip = gr.File(label="Upload Validation Dataset (.zip)", file_types=[".zip"])
        btn_eval = gr.Button("📊 Run Validation", variant="primary")

        with gr.Row():
            with gr.Column():
                rep_out = gr.Markdown(label="Classification Report")
            with gr.Column():
                cm_img = gr.Image(type="pil", label="Confusion Matrix")

        btn_eval.click(fn=validate, inputs=[val_zip], outputs=[rep_out, cm_img])

    # ---- Training metrics dashboard ----
    with gr.Tab("📈 Training Dashboard"):
        gr.Markdown("""
### Visualize Training Metrics
View training progress, validation accuracy, and energy savings from Adaptive Sparse Training (AST).

**Metrics are automatically loaded from checkpoints/metrics.csv. Upload a different file if needed.**
""")

        # Pre-render the default metrics so the tab is populated on first view.
        initial_summary, initial_loss, initial_acc, initial_act, initial_save, initial_thr, initial_csv = load_metrics(METRICS_DEFAULT)
        # Materialise the CSV payload once so every page load reuses one file path.
        initial_csv = _to_download_path(initial_csv)

        with gr.Row():
            single_metrics_file = gr.File(label="Upload different metrics.csv (optional)", file_types=[".csv"])
            btn_upload = gr.Button("📊 Load Metrics", variant="secondary")

        summary_md = gr.Markdown(value=initial_summary)

        with gr.Row():
            plot_loss = gr.Image(label="Training Loss", value=initial_loss)
            plot_acc = gr.Image(label="Validation Accuracy", value=initial_acc)

        with gr.Row():
            plot_act = gr.Image(label="Activation Rate (AST)", value=initial_act)
            plot_save = gr.Image(label="Energy Savings (AST)", value=initial_save)

        plot_thr = gr.Image(label="Activation Threshold (AST)", value=initial_thr)
        dl_btn = gr.DownloadButton(label="⬇️ Download Metrics CSV")

        demo.load(
            fn=lambda: initial_csv,
            inputs=[],
            outputs=[dl_btn]
        )

        btn_upload.click(
            fn=_reload_metrics,
            inputs=[single_metrics_file],
            outputs=[summary_md, plot_loss, plot_acc, plot_act, plot_save, plot_thr, dl_btn]
        )

        gr.Markdown("---")
        gr.Markdown("### Compare Multiple Training Runs")
        mult = gr.Files(label="Upload Multiple metrics.csv Files", file_types=[".csv"])
        cmp_msg = gr.Markdown()

        with gr.Row():
            p_loss = gr.Image(label="Loss Comparison")
            p_acc = gr.Image(label="Accuracy Comparison")

        with gr.Row():
            p_act = gr.Image(label="Activation Comparison")
            p_save = gr.Image(label="Savings Comparison")

        p_thr = gr.Image(label="Threshold Comparison")

        mult.upload(
            fn=compare_runs,
            inputs=[mult],
            outputs=[cmp_msg, p_loss, p_acc, p_act, p_save, p_thr]
        )

    # ---- ONNX export ----
    with gr.Tab("📦 Model Export"):
        gr.Markdown("""
### Export Model to ONNX Format
Convert the PyTorch model to ONNX format for production deployment and cross-platform compatibility.
""")

        with gr.Row():
            with gr.Column():
                onnx_precision = gr.Radio(
                    choices=["fp32", "fp16"],
                    value="fp32",
                    label="Precision",
                    info="FP16 for faster inference, FP32 for maximum accuracy"
                )
                btn_onnx = gr.Button("🚀 Export to ONNX", variant="primary")

            with gr.Column():
                onnx_file = gr.File(label="Download ONNX Model", interactive=False)

        def onnx_wrap(prec):
            """Export, then hand gr.File a real path (it cannot serve a (name, BytesIO) tuple)."""
            return _to_download_path(export_onnx(prec))

        btn_onnx.click(fn=onnx_wrap, inputs=[onnx_precision], outputs=[onnx_file])

    # ---- About ----
    with gr.Tab("ℹ️ About"):
        gr.Markdown("""
## About This System

### Technology Stack
- **Deep Learning Framework:** PyTorch
- **Model Architecture:** EfficientNet-B0
- **Training Method:** Adaptive Sparse Training (AST) with Sundew Algorithm
- **Explainable AI:** Grad-CAM (Gradient-weighted Class Activation Mapping)
- **Dataset:** NIH Malaria Cell Images (27,558 samples)

### Performance Metrics
- **Validation Accuracy:** 93.94% (final epoch), 94.63% (best epoch)
- **Energy Savings:** 88% reduction in training cost vs. traditional methods
- **Inference Speed:** <1 second per image
- **Model Size:** ~16MB
- **Training:** 30 epochs on NIH Malaria Dataset

### Key Features
1. **Real-time Diagnosis:** Upload blood smear images for instant analysis
2. **Explainable AI:** Grad-CAM shows exactly where the model detects parasites
3. **Energy Efficient:** Trained using Adaptive Sparse Training with Sundew algorithm for 88% energy savings
4. **Clinical Recommendations:** Actionable advice based on predictions
5. **Model Validation:** Built-in tools for performance evaluation
6. **ONNX Export:** Deploy anywhere with standard model format

### Use Cases
- **Research:** Academic studies on malaria detection
- **Education:** Teaching AI applications in healthcare
- **Triage:** Rapid pre-screening in resource-limited settings
- **Model Comparison:** Benchmark against other approaches

### Important Disclaimers

⚠️ **This is a research prototype and NOT a medical device.**

- Results must be confirmed by certified laboratory testing
- Do not use for clinical diagnosis without professional validation
- Always consult qualified healthcare providers
- This tool is for research and educational purposes only

### Developer
**Oluwafemi Idiakhoa**

### Citation
If you use this system in your research, please cite:
```
@software{malaria_ast_detection,
author = {Idiakhoa, Oluwafemi},
title = {Malaria Detection using Adaptive Sparse Training},
year = {2025},
url = {https://github.com/oluwafemidiakhoa/Malaria}
}
```

### Resources
- [GitHub Repository](https://github.com/oluwafemidiakhoa/Malaria)
- AST Library (PyPI): pypi.org/project/adaptive-sparse-training
- [NIH Malaria Dataset](https://lhncbc.nlm.nih.gov/LHC-research/LHC-projects/image-processing/malaria-datasheet.html)

---

Built with EfficientNet-B0 + Adaptive Sparse Training | Powered by PyTorch & Gradio
""")

    # Page footer shown under every tab.
    gr.Markdown("""
---
<div style='text-align: center; padding: 20px; color: #666;'>
<p><strong>Malaria Detection AI</strong> | Advanced Deep Learning for Global Health</p>
<p>Developer: Oluwafemi Idiakhoa | <a href="https://github.com/oluwafemidiakhoa/Malaria">GitHub</a></p>
<p style='font-size: 12px; margin-top: 10px;'>This is a research tool. Not for clinical use.</p>
</div>
""")

if __name__ == "__main__":
    demo.launch()
| |
|