# Page residue from the hosting site's file viewer (not part of the program):
#   noah34's picture
#   Update app.py
#   f19cdfb verified
import gradio as gr
from transformers import pipeline
from PIL import Image
from datasets import load_dataset
from sklearn.metrics import (
accuracy_score, precision_score, recall_score, f1_score,
confusion_matrix, ConfusionMatrixDisplay
)
import matplotlib.pyplot as plt
import pandas as pd
import requests
from io import BytesIO
# Model repository served through the image-classification pipeline.
MODEL_ID = "Thamer/resnet-fine_tuned"
clf = pipeline("image-classification", model=MODEL_ID)

# Shown in the page header and returned alongside every prediction.
DISCLAIMER = "⚠️ Educational demo only. Not for clinical/diagnostic use."

# Labels in the model's format. List position doubles as the integer class id
# (0=Mild, 1=Moderate, 2=Non, 3=VeryMild).
LABEL_NAMES = ["Mild_Demented", "Moderate_Demented", "Non_Demented", "Very_Mild_Demented"]

# SilpaCS string labels (no underscores) -> model label format.
SILPA_STR_MAP = {name.replace("_", ""): name for name in LABEL_NAMES}

# SilpaCS integer labels -> model label format
# (same class ordering as Falah: 0=Mild, 1=Moderate, 2=Non, 3=VeryMild).
SILPA_INT_MAP = dict(enumerate(LABEL_NAMES))

# Fixed class list for the per-session confusion matrix; None means the labels
# are inferred from whatever has been logged so far.
CLASS_NAMES = None
def _get_top_label(preds):
return max(preds, key=lambda x: x["score"])["label"] if preds else None
def _resolve_silpa_label(raw):
    """Map a raw SilpaCS label to the model's label format.

    Integer-like labels (Python ``int`` as well as NumPy integer scalars such
    as ``numpy.int64``, which are not ``int`` subclasses) are looked up in
    ``SILPA_INT_MAP``; anything else is stringified and looked up in
    ``SILPA_STR_MAP``.

    Returns:
        The mapped label, or ``str(raw)`` when no mapping exists.
    """
    import numbers  # local import: only this helper needs it

    # numbers.Integral also matches NumPy integer scalars, which a plain
    # isinstance(raw, int) check would send down the string path.
    if isinstance(raw, numbers.Integral):
        return SILPA_INT_MAP.get(int(raw), str(raw))
    text = str(raw)
    return SILPA_STR_MAP.get(text, text)
def _plot_confusion_matrix(y_true, y_pred, class_names=None):
    """Build a confusion-matrix figure for the given label sequences.

    Args:
        y_true: Sequence of true labels.
        y_pred: Sequence of predicted labels, aligned with ``y_true``.
        class_names: Optional fixed label order for the matrix axes. When
            ``None``, the sorted union of labels seen in ``y_true`` and
            ``y_pred`` is used.

    Returns:
        A Matplotlib figure. If there are no labels to show, the figure
        contains a placeholder message instead of a matrix.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    # Unregister the figure from pyplot immediately. pyplot otherwise keeps a
    # reference to every figure it creates, so a long-running app would
    # accumulate one per call. The Figure object itself stays usable.
    plt.close(fig)

    if class_names is None:
        labels = sorted(set(y_true) | set(y_pred))
    else:
        labels = list(class_names)

    if not labels:
        ax.text(0.5, 0.5, "No labeled examples yet.\nUpload an image + select a true label.",
                ha="center", va="center")
        ax.axis("off")
        fig.tight_layout()
        return fig

    cm = confusion_matrix(y_true, y_pred, labels=labels)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
    disp.plot(ax=ax, cmap="Blues", colorbar=False, values_format="d")
    ax.set_title("Confusion Matrix")
    # Style this Axes' tick labels directly rather than through pyplot's
    # "current axes", which is global state and no longer points here.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    plt.setp(ax.get_yticklabels(), rotation=0)
    fig.tight_layout()
    return fig
def predict_and_update(img: Image.Image, true_label: str, y_true_state, y_pred_state):
    """Classify one image and optionally log it to the session confusion matrix.

    Args:
        img: Uploaded image, or ``None`` when nothing was uploaded.
        true_label: Ground-truth label from the UI; an empty value or the skip
            placeholder means "do not log this prediction".
        y_true_state: True labels logged so far in this session.
        y_pred_state: Predicted labels logged so far in this session.

    Returns:
        ``(scores, DISCLAIMER, figure, y_true_state, y_pred_state)`` where
        ``scores`` maps each label to its score (empty when ``img`` is
        ``None``) and the states include the new entry if one was logged.
    """
    scores = {}
    if img is not None:
        preds = clf(img.convert("RGB"))
        ranked = sorted(preds, key=lambda p: p["score"], reverse=True)
        scores = {p["label"]: float(p["score"]) for p in ranked}
        pred_label = _get_top_label(preds)

        should_log = bool(true_label) and true_label != "β€”" and pred_label is not None
        if should_log:
            # Build new lists instead of mutating the incoming state objects.
            y_true_state = list(y_true_state) + [true_label]
            y_pred_state = list(y_pred_state) + [pred_label]

    fig = _plot_confusion_matrix(y_true_state, y_pred_state, CLASS_NAMES)
    return scores, DISCLAIMER, fig, y_true_state, y_pred_state
def reset_cm():
    """Return an empty confusion-matrix figure and two fresh, empty label lists."""
    empty_plot = _plot_confusion_matrix([], [], CLASS_NAMES)
    cleared_true, cleared_pred = [], []
    return empty_plot, cleared_true, cleared_pred
def load_silpa_safe():
    """
    Download the SilpaCS/Alzheimer data as a pandas DataFrame.

    The dataset's own builder is reported broken, so the auto-converted
    Parquet file is fetched directly over HTTP instead. Only the single file
    ``default/train/0000.parquet`` is requested.

    Raises whatever ``requests`` raises on network/HTTP errors (via
    ``raise_for_status``) and whatever ``pandas.read_parquet`` raises on
    unreadable content.
    """
    repo_url = "https://huggingface.co/datasets/SilpaCS/Alzheimer"
    parquet_path = "/resolve/refs%2Fconvert%2Fparquet/default/train/0000.parquet"

    resp = requests.get(repo_url + parquet_path, timeout=120)
    resp.raise_for_status()
    frame = pd.read_parquet(BytesIO(resp.content))

    # Debug output: show how the label column is actually encoded.
    label_column = frame["label"]
    print(f"SilpaCS label column dtype: {label_column.dtype}")
    print(f"SilpaCS unique labels (first 10): {label_column.unique()[:10]}")
    return frame
def run_full_evaluation(progress=gr.Progress()):
    """
    Run the classifier over the evaluation images and report macro metrics.

    Evaluates on:
    - Falah/Alzheimer_MRI test split (1,280 images expected): held-out set
    - SilpaCS/Alzheimer (6,400 images expected): independent source

    If SilpaCS cannot be fetched, evaluation continues with Falah only.
    SilpaCS rows that fail to decode or classify are skipped, and the summary
    reports the number of images actually evaluated (plus the number skipped)
    rather than the number of rows loaded.

    Args:
        progress: Gradio progress tracker, called with a 0-1 fraction and a
            description as evaluation advances.

    Returns:
        ``(metrics_md, fig)``: a Markdown summary with accuracy and
        macro-averaged precision/recall/F1, and the confusion-matrix figure
        over ``LABEL_NAMES``.
    """
    progress(0, desc="Loading Falah/Alzheimer_MRI test split...")
    falah = load_dataset("Falah/Alzheimer_MRI", split="test")
    falah_label_names = falah.features["label"].names

    progress(0.05, desc="Fetching SilpaCS/Alzheimer via Parquet...")
    try:
        silpa_df = load_silpa_safe()
    except Exception as e:
        # Deliberate best-effort: the second source is optional.
        silpa_df = None
        print(f"Warning: SilpaCS failed to load ({e}), running Falah-only.")

    total = len(falah) + (len(silpa_df) if silpa_df is not None else 0)
    y_true, y_pred = [], []
    i = 0

    # --- Falah test split ---
    for example in falah:
        progress(i / total, desc=f"Evaluating image {i+1}/{total}...")
        img = example["image"].convert("RGB")
        top = _get_top_label(clf(img))
        y_true.append(falah_label_names[example["label"]])
        y_pred.append(top)
        i += 1
    n_falah = len(y_true)

    # --- SilpaCS (raw Parquet DataFrame) ---
    n_skipped = 0
    if silpa_df is not None:
        for _, row in silpa_df.iterrows():
            progress(i / total, desc=f"Evaluating image {i+1}/{total}...")
            try:
                img_bytes = row["image"]["bytes"]
                img = Image.open(BytesIO(img_bytes)).convert("RGB")
                top = _get_top_label(clf(img))
                true = _resolve_silpa_label(row["label"])
            except Exception as e:
                # One bad row should not abort a multi-minute run.
                n_skipped += 1
                print(f"Skipping row {i}: {e}")
            else:
                y_true.append(true)
                y_pred.append(top)
            i += 1
    # Count what was evaluated, not what was loaded.
    n_silpa = len(y_true) - n_falah

    progress(1.0, desc="Done!")

    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, average="macro", zero_division=0)
    rec = recall_score(y_true, y_pred, average="macro", zero_division=0)
    f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)

    if n_silpa > 0:
        source_note = (
            f"Falah/Alzheimer_MRI test split ({n_falah} images) + SilpaCS/Alzheimer ({n_silpa} images)"
        )
        if n_skipped:
            source_note += f"; {n_skipped} SilpaCS rows skipped"
    elif silpa_df is None:
        source_note = f"Falah/Alzheimer_MRI test split only ({n_falah} images) — SilpaCS failed to load"
    else:
        source_note = (
            f"Falah/Alzheimer_MRI test split only ({n_falah} images) — "
            f"no SilpaCS rows could be evaluated ({n_skipped} skipped)"
        )

    # Blank lines separate heading, note and table so the table renders
    # regardless of the Markdown flavour.
    metrics_md = f"""
## Evaluation Results — ResNet-34

*{source_note}*

| Metric | Score |
|-----------|------------|
| Accuracy | {acc:.2%} |
| Precision | {prec:.2%} |
| Recall | {rec:.2%} |
| F1 | {f1:.2%} |
"""
    fig = _plot_confusion_matrix(y_true, y_pred, LABEL_NAMES)
    return metrics_md, fig
def true_label_input_component():
    """Create the "true label" input widget for the prediction tab.

    Returns a free-text box when ``CLASS_NAMES`` is empty or ``None``;
    otherwise a dropdown listing the skip placeholder followed by the
    configured class names. Both default to the skip placeholder.
    """
    if not CLASS_NAMES:
        return gr.Textbox(value="β€”", label="True Label (type label, or β€” to skip logging)")

    options = ["β€”"] + CLASS_NAMES
    return gr.Dropdown(choices=options, value="β€”", label="True Label (for confusion matrix)")
# Gradio UI: one tab for single-image prediction with a per-session confusion
# matrix, one tab for the full dataset evaluation.
with gr.Blocks(title="Alzheimer's MRI Classification (4-class) — Demo") as demo:
    gr.Markdown(f"# Alzheimer's MRI Classification (4-class) — Demo\n\n{DISCLAIMER}")

    # Labels logged so far; passed into and returned from the prediction
    # callback so each click sees the updated history.
    y_true_state = gr.State([])
    y_pred_state = gr.State([])

    # --- Single image prediction tab ---
    with gr.Tab("Single Image Prediction"):
        with gr.Row():
            img_in = gr.Image(type="pil", label="Upload MRI Image (jpg/png)")
            true_in = true_label_input_component()
        with gr.Row():
            pred_out = gr.Label(num_top_classes=4, label="Predictions")
            cm_plot = gr.Plot(label="Confusion Matrix (this session)")
        with gr.Row():
            run_btn = gr.Button("Predict (and log if True Label provided)", variant="primary")
            reset_btn = gr.Button("Reset Confusion Matrix")

        # Show the empty-matrix placeholder as soon as the page loads.
        demo.load(fn=lambda: _plot_confusion_matrix([], [], CLASS_NAMES), inputs=None, outputs=cm_plot)

        # Receives the disclaimer text returned with every prediction.
        # Created here, below the button row, so its position in the layout
        # is unchanged.
        disclaimer_out = gr.Markdown()
        run_btn.click(
            fn=predict_and_update,
            inputs=[img_in, true_in, y_true_state, y_pred_state],
            outputs=[pred_out, disclaimer_out, cm_plot, y_true_state, y_pred_state],
        )
        reset_btn.click(
            fn=reset_cm,
            inputs=None,
            outputs=[cm_plot, y_true_state, y_pred_state],
        )

    # --- Full evaluation tab ---
    with gr.Tab("Full Evaluation (7,680 images)"):
        gr.Markdown("""
### Combined Evaluation — Falah/Alzheimer_MRI test split + SilpaCS/Alzheimer

Evaluates across **7,680 total MRI images** from two independent sources:

- **Falah/Alzheimer_MRI** (1,280 images) — held-out test split of the model's training dataset
- **SilpaCS/Alzheimer** (6,400 images) — fully independent dataset not used during training

⚠️ This will take **several minutes** to complete.
""")
        eval_btn = gr.Button("Run Full Evaluation", variant="primary")
        metrics_out = gr.Markdown()
        eval_cm_plot = gr.Plot(label="Confusion Matrix — Combined Test Set")
        eval_btn.click(
            fn=run_full_evaluation,
            inputs=None,
            outputs=[metrics_out, eval_cm_plot],
        )

if __name__ == "__main__":
    demo.launch()