# Hugging Face Spaces page header captured with the source (Space status: Sleeping).
import numbers
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import requests
from datasets import load_dataset
from matplotlib.figure import Figure
from PIL import Image
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, ConfusionMatrixDisplay
)
from transformers import pipeline
# Model identifier handed to the image-classification pipeline below.
MODEL_ID = "Thamer/resnet-fine_tuned"
# Built once at import time and shared by every prediction/evaluation call.
clf = pipeline(task="image-classification", model=MODEL_ID)
# Shown in the page header and returned with every single-image prediction.
# The warning sign is written with escapes (U+26A0 + U+FE0F) so it survives
# re-encoding of this file.
DISCLAIMER = "\u26a0\ufe0f Educational demo only. Not for clinical/diagnostic use."

# Label strings used for the axes of the full-evaluation confusion matrix.
LABEL_NAMES = ["Mild_Demented", "Moderate_Demented", "Non_Demented", "Very_Mild_Demented"]

# SilpaCS string labels (no underscores) -> model label format
SILPA_STR_MAP = {
    "MildDemented": "Mild_Demented",
    "ModerateDemented": "Moderate_Demented",
    "NonDemented": "Non_Demented",
    "VeryMildDemented": "Very_Mild_Demented",
}

# SilpaCS integer labels -> model label format
# (same class ordering as Falah: 0=Mild, 1=Moderate, 2=Non, 3=VeryMild)
SILPA_INT_MAP = {
    0: "Mild_Demented",
    1: "Moderate_Demented",
    2: "Non_Demented",
    3: "Very_Mild_Demented",
}

# Classes offered in the single-image tab. While this is None the tab uses a
# free-text label box and the session confusion matrix infers its labels from
# whatever has been logged.
CLASS_NAMES = None
| def _get_top_label(preds): | |
| return max(preds, key=lambda x: x["score"])["label"] if preds else None | |
def _resolve_silpa_label(raw):
    """Translate a raw SilpaCS label into the model's label format.

    Args:
        raw: Label as stored in the dataset, either an integer class index or
            a class-name string.

    Returns:
        The matching entry of ``SILPA_INT_MAP`` / ``SILPA_STR_MAP``. Values
        missing from the maps are returned unchanged as a string.
    """
    # numbers.Integral also matches NumPy integer scalars (e.g. numpy.int64
    # read from a Parquet column), which a plain isinstance(raw, int) check
    # rejects; those would otherwise be stringified to "2" and never match a
    # class name.
    if isinstance(raw, numbers.Integral):
        index = int(raw)
        return SILPA_INT_MAP.get(index, str(index))
    text = str(raw)
    return SILPA_STR_MAP.get(text, text)
def _plot_confusion_matrix(y_true, y_pred, class_names=None):
    """Draw a confusion matrix, or a placeholder when there is nothing to show.

    Args:
        y_true: Sequence of ground-truth labels.
        y_pred: Sequence of predicted labels, aligned with ``y_true``.
        class_names: Optional iterable fixing which labels appear on the axes
            and in what order. When ``None`` the sorted union of the labels in
            ``y_true`` and ``y_pred`` is used.

    Returns:
        A 6x6 inch Matplotlib figure containing either the matrix or a short
        "no labeled examples" message.
    """
    # Build the figure directly instead of via plt.subplots(): pyplot keeps a
    # reference to every figure it creates until it is closed, so a server that
    # plots on every request would accumulate figures. Only figure/axes methods
    # are used below, so no global "current figure" state is touched either.
    fig = Figure(figsize=(6, 6))
    ax = fig.subplots()

    if class_names is None:
        labels = sorted(set(y_true) | set(y_pred))
    else:
        labels = list(class_names)

    if not labels:
        ax.text(0.5, 0.5, "No labeled examples yet.\nUpload an image + select a true label.",
                ha="center", va="center")
        ax.axis("off")
        fig.tight_layout()
        return fig

    cm = confusion_matrix(y_true, y_pred, labels=labels)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
    disp.plot(ax=ax, cmap="Blues", colorbar=False, values_format="d")
    ax.set_title("Confusion Matrix")

    # Slant the x tick labels so long class names do not overlap.
    for tick in ax.get_xticklabels():
        tick.set_rotation(45)
        tick.set_horizontalalignment("right")
    for tick in ax.get_yticklabels():
        tick.set_rotation(0)

    fig.tight_layout()
    return fig
def predict_and_update(img: Image.Image, true_label: str, y_true_state, y_pred_state):
    """Classify one image and, if a true label was supplied, log the pair.

    Args:
        img: Image to classify, or ``None`` when nothing was uploaded.
        true_label: Ground-truth label for the image. An empty value or the
            skip placeholder means the prediction is not logged.
        y_true_state: True labels logged so far in this session.
        y_pred_state: Predicted labels logged so far in this session.

    Returns:
        Tuple ``(scores, disclaimer, figure, y_true_state, y_pred_state)``:
        ``scores`` maps each predicted label to its float score (highest
        first; empty when ``img`` is ``None``), ``figure`` is the session
        confusion matrix, and the two lists are the possibly extended logs.
    """
    if img is None:
        fig = _plot_confusion_matrix(y_true_state, y_pred_state, CLASS_NAMES)
        return {}, DISCLAIMER, fig, y_true_state, y_pred_state

    preds = clf(img.convert("RGB"))
    ranked = sorted(preds, key=lambda p: p["score"], reverse=True)
    result = {p["label"]: float(p["score"]) for p in ranked}
    pred_label = _get_top_label(preds)

    # Trim stray whitespace so a typed " Non_Demented " is not logged as a
    # separate class and a padded skip placeholder is still recognised.
    label = true_label.strip() if isinstance(true_label, str) else true_label

    # NOTE(review): the placeholder below looks like a mis-encoded dash; it is
    # kept as-is because it must equal the default value of the true-label
    # input built by true_label_input_component().
    if label and label != "β" and pred_label is not None:
        # Build new lists rather than mutating the incoming state objects.
        y_true_state = [*y_true_state, label]
        y_pred_state = [*y_pred_state, pred_label]

    fig = _plot_confusion_matrix(y_true_state, y_pred_state, CLASS_NAMES)
    return result, DISCLAIMER, fig, y_true_state, y_pred_state
def reset_cm():
    """Discard the logged labels and redraw the confusion matrix from scratch.

    Returns:
        Tuple ``(figure, y_true, y_pred)`` where ``figure`` is the matrix drawn
        from no data and both label lists are new empty lists.
    """
    empty_true, empty_pred = [], []
    blank_fig = _plot_confusion_matrix([], [], CLASS_NAMES)
    return blank_fig, empty_true, empty_pred
def load_silpa_safe():
    """Download the SilpaCS/Alzheimer train Parquet file into a DataFrame.

    The dataset's own builder is reported to be broken, so the auto-converted
    Parquet file is fetched directly over HTTP instead of going through the
    dataset loader.

    Returns:
        The DataFrame read from the downloaded Parquet bytes.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    url = (
        "https://huggingface.co/datasets/SilpaCS/Alzheimer"
        "/resolve/refs%2Fconvert%2Fparquet/default/train/0000.parquet"
    )
    reply = requests.get(url, timeout=120)
    reply.raise_for_status()
    frame = pd.read_parquet(BytesIO(reply.content))

    # Debug output: show how the label column is actually encoded.
    label_column = frame["label"]
    print(f"SilpaCS label column dtype: {label_column.dtype}")
    print(f"SilpaCS unique labels (first 10): {label_column.unique()[:10]}")
    return frame
def run_full_evaluation(progress=gr.Progress()):
    """Score the classifier on both evaluation sets and summarise the results.

    Evaluates on:
      - the Falah/Alzheimer_MRI ``test`` split (held-out set), and
      - SilpaCS/Alzheimer, fetched as a raw Parquet file (independent source).

    If SilpaCS cannot be loaded the evaluation continues on Falah only.
    Images for which the model returns no prediction, and SilpaCS rows that
    fail to decode, are skipped and left out of the reported counts.

    Args:
        progress: Gradio progress tracker used to report per-image progress.

    Returns:
        Tuple ``(metrics_md, fig)``: a Markdown report with accuracy and
        macro-averaged precision/recall/F1, and the confusion-matrix figure
        drawn over ``LABEL_NAMES``.
    """
    progress(0, desc="Loading Falah/Alzheimer_MRI test split...")
    falah = load_dataset("Falah/Alzheimer_MRI", split="test")
    falah_label_names = falah.features["label"].names

    progress(0.05, desc="Fetching SilpaCS/Alzheimer via Parquet...")
    try:
        silpa_df = load_silpa_safe()
    except Exception as e:
        # Deliberate best-effort: the second dataset is optional.
        silpa_df = None
        print(f"Warning: SilpaCS failed to load ({e}), running Falah-only.")

    total = len(falah) + (len(silpa_df) if silpa_df is not None else 0)
    denom = max(total, 1)  # avoids dividing by zero if both sources are empty
    y_true, y_pred = [], []
    n_falah = 0  # images actually scored, per source
    n_silpa = 0
    i = 0

    # --- Falah test split ---
    for example in falah:
        progress(i / denom, desc=f"Evaluating image {i+1}/{total}...")
        img = example["image"].convert("RGB")
        top = _get_top_label(clf(img))
        if top is None:
            # A missing prediction would put None into y_pred and break the
            # metric functions below.
            print(f"Skipping row {i}: model returned no prediction")
        else:
            y_true.append(falah_label_names[example["label"]])
            y_pred.append(top)
            n_falah += 1
        i += 1

    # --- SilpaCS (raw Parquet DataFrame) ---
    if silpa_df is not None:
        for _, row in silpa_df.iterrows():
            progress(i / denom, desc=f"Evaluating image {i+1}/{total}...")
            try:
                img_bytes = row["image"]["bytes"]
                img = Image.open(BytesIO(img_bytes)).convert("RGB")
                top = _get_top_label(clf(img))
                if top is None:
                    raise ValueError("model returned no prediction")
                true = _resolve_silpa_label(row["label"])
                y_true.append(true)
                y_pred.append(top)
                n_silpa += 1
            except Exception as e:
                # Best-effort per row: one bad image must not abort the run.
                print(f"Skipping row {i}: {e}")
            i += 1

    progress(1.0, desc="Done!")

    if not y_true:
        empty_md = "\n## Evaluation Results — ResNet-34\n\nNo images could be evaluated.\n"
        return empty_md, _plot_confusion_matrix([], [], LABEL_NAMES)

    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, average="macro", zero_division=0)
    rec = recall_score(y_true, y_pred, average="macro", zero_division=0)
    f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)

    source_note = (
        f"Falah/Alzheimer_MRI test split ({n_falah} images) + SilpaCS/Alzheimer ({n_silpa} images)"
        if n_silpa > 0
        else f"Falah/Alzheimer_MRI test split only ({n_falah} images) — SilpaCS failed to load"
    )

    # Assembled line by line so the Markdown never depends on how this source
    # file is indented; the blank lines keep heading, note and table as
    # separate blocks.
    metrics_md = "\n".join([
        "",
        "## Evaluation Results — ResNet-34",
        "",
        f"*{source_note}*",
        "",
        "| Metric    | Score      |",
        "|-----------|------------|",
        f"| Accuracy  | {acc:.2%} |",
        f"| Precision | {prec:.2%} |",
        f"| Recall    | {rec:.2%} |",
        f"| F1        | {f1:.2%} |",
        "",
    ])

    fig = _plot_confusion_matrix(y_true, y_pred, LABEL_NAMES)
    return metrics_md, fig
def true_label_input_component():
    """Create the input widget through which the user supplies a true label.

    Returns:
        A dropdown listing the skip placeholder followed by ``CLASS_NAMES``
        when that list is set and non-empty, otherwise a free-text box. Both
        start on the skip placeholder.
    """
    if not CLASS_NAMES:
        return gr.Textbox(value="β", label="True Label (type label, or β to skip logging)")
    options = ["β"] + CLASS_NAMES
    return gr.Dropdown(choices=options, value="β", label="True Label (for confusion matrix)")
# Text of the full-evaluation tab, assembled line by line so the Markdown does
# not depend on source indentation. The blank entries separate the heading,
# the paragraph, the list and the closing warning into distinct blocks.
_FULL_EVAL_INTRO = "\n".join([
    "### Combined Evaluation — Falah/Alzheimer_MRI test split + SilpaCS/Alzheimer",
    "",
    "Evaluates across **7,680 total MRI images** from two independent sources:",
    "",
    "- **Falah/Alzheimer_MRI** (1,280 images) — held-out test split of the model's training dataset",
    "- **SilpaCS/Alzheimer** (6,400 images) — fully independent dataset not used during training",
    "",
    "\u26a0\ufe0f This will take **several minutes** to complete.",
])

with gr.Blocks(title="Alzheimer's MRI Classification (4-class) — Demo") as demo:
    gr.Markdown(f"# Alzheimer's MRI Classification (4-class) — Demo\n\n{DISCLAIMER}")

    # Per-session logs of (true, predicted) labels behind the live matrix.
    y_true_state = gr.State([])
    y_pred_state = gr.State([])

    # --- Single image prediction tab ---
    with gr.Tab("Single Image Prediction"):
        with gr.Row():
            img_in = gr.Image(type="pil", label="Upload MRI Image (jpg/png)")
            true_in = true_label_input_component()
        with gr.Row():
            pred_out = gr.Label(num_top_classes=4, label="Predictions")
            cm_plot = gr.Plot(label="Confusion Matrix (this session)")
        with gr.Row():
            run_btn = gr.Button("Predict (and log if True Label provided)", variant="primary")
            reset_btn = gr.Button("Reset Confusion Matrix")
        # Receives the disclaimer that predict_and_update returns as its
        # second output.
        disclaimer_out = gr.Markdown()

        # Draw the empty-state matrix as soon as the page loads.
        demo.load(fn=lambda: _plot_confusion_matrix([], [], CLASS_NAMES), inputs=None, outputs=cm_plot)
        run_btn.click(
            fn=predict_and_update,
            inputs=[img_in, true_in, y_true_state, y_pred_state],
            outputs=[pred_out, disclaimer_out, cm_plot, y_true_state, y_pred_state],
        )
        reset_btn.click(
            fn=reset_cm,
            inputs=None,
            outputs=[cm_plot, y_true_state, y_pred_state],
        )

    # --- Full evaluation tab ---
    with gr.Tab("Full Evaluation (7,680 images)"):
        gr.Markdown(_FULL_EVAL_INTRO)
        eval_btn = gr.Button("Run Full Evaluation", variant="primary")
        metrics_out = gr.Markdown()
        eval_cm_plot = gr.Plot(label="Confusion Matrix — Combined Test Set")
        eval_btn.click(
            fn=run_full_evaluation,
            inputs=None,
            outputs=[metrics_out, eval_cm_plot],
        )

if __name__ == "__main__":
    demo.launch()