Spaces:
Sleeping
Sleeping
| # app.py — Space-friendly UI: single-image predict, report, and batch evaluate with uploads | |
| from __future__ import annotations | |
| from pathlib import Path | |
| import os, re, csv, json, time, contextlib, warnings, tempfile, zipfile | |
| from typing import List, Tuple, Optional | |
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset, DataLoader | |
| from torchvision import transforms | |
| from PIL import Image, ImageOps | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from sklearn.metrics import classification_report, confusion_matrix, accuracy_score | |
| from model import build_model | |
| warnings.filterwarnings("ignore", category=UserWarning) | |
| # ---------------- Paths & constants ---------------- | |
# Directory containing this file; the paths below are resolved relative to it.
ROOT = Path(__file__).resolve().parent
CKPT = ROOT / "ckpt_final320" / "best.pt"  # checkpoint file read by load_model_and_tfms()
CLASSES_TXT = ROOT / "classes.txt"  # class names, one per line
REPORT_DIR = ROOT / "reports_final320"  # evaluation reports are read from / written to this folder
RES = 320  # side length of the center crop fed to the model
IMAGENET_MEAN = (0.485, 0.456, 0.406)  # per-channel mean passed to transforms.Normalize
IMAGENET_STD = (0.229, 0.224, 0.225)  # per-channel std passed to transforms.Normalize
# ---------------- Device / AMP helpers ----------------
# Pick the compute device once at import time: CUDA if available, otherwise
# MPS when torch.backends exposes it and it reports available, otherwise CPU.
DEVICE = (
    "cuda" if torch.cuda.is_available()
    else "mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
    else "cpu"
)
def autocast_ctx():
    """Return the mixed-precision context manager for the active ``DEVICE``.

    On ``"cuda"`` and ``"mps"`` this is a float16 ``torch.autocast`` context
    for that device type; on any other device it is a no-op
    ``contextlib.nullcontext``.
    """
    if DEVICE in ("cuda", "mps"):
        return torch.autocast(device_type=DEVICE, dtype=torch.float16)
    return contextlib.nullcontext()
# Process-wide PyTorch setting: select the "high" float32 matmul precision mode.
torch.set_float32_matmul_precision("high")
| # ---------------- Model & transforms ---------------- | |
| def _load_classes(p: Path) -> List[str]: | |
| if not p.exists(): | |
| raise FileNotFoundError(f"classes.txt not found at {p}") | |
| return [ln.strip() for ln in p.read_text().splitlines() if ln.strip()] | |
def load_model_and_tfms() -> tuple[torch.nn.Module, List[str], transforms.Compose]:
    """Build the classifier, restore its weights, and create the eval transform.

    Returns:
        ``(model, classes, tfm)``: the model in eval mode on ``DEVICE`` with
        channels-last memory format, the class names read from
        ``CLASSES_TXT``, and the resize / center-crop / normalize pipeline.

    Raises:
        FileNotFoundError: If ``CLASSES_TXT`` or ``CKPT`` is missing.
    """
    class_names = _load_classes(CLASSES_TXT)
    if not CKPT.exists():
        raise FileNotFoundError(f"Checkpoint not found at {CKPT}")

    net = build_model(len(class_names), pretrained=False)
    checkpoint = torch.load(CKPT, map_location="cpu")
    # The checkpoint is either a bare state_dict or a wrapper holding it under "model".
    state = checkpoint.get("model", checkpoint)
    net.load_state_dict(state, strict=True)
    net.eval()
    net.to(DEVICE)
    net.to(memory_format=torch.channels_last)

    # Resize the short side proportionally (256/224 of the crop), then center-crop to RES.
    resize_side = int(RES * 256 / 224)
    pipeline = transforms.Compose([
        transforms.Resize(resize_side),
        transforms.CenterCrop(RES),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])
    return net, class_names, pipeline
# Load the model, class names and preprocessing pipeline once at import time;
# predict() and run_eval() read these module-level globals.
MODEL, CLASSES, TFM = load_model_and_tfms()
| # ---------------- Predict (single image) ---------------- | |
def plot_topk(probs: torch.Tensor, classes: list[str], k: int = 5):
    """Draw a horizontal bar chart of the *k* most probable classes.

    Args:
        probs: 1-D tensor of class probabilities, indexed like *classes*.
        classes: Class names.
        k: Number of bars; coerced to ``int`` and clamped to
            ``[1, len(classes)]``.

    Returns:
        A matplotlib ``Figure`` with the most probable class on the top row.
    """
    # Build the figure without pyplot: pyplot keeps every figure it creates in
    # a global registry until it is closed, which leaks memory in a
    # long-running server that never calls plt.close().
    from matplotlib.figure import Figure

    k = max(1, min(int(k), len(classes)))
    vals, idx = torch.topk(probs, k)
    vals = vals.detach().float().cpu().numpy()
    labels = [classes[i] for i in idx.tolist()]

    fig = Figure(figsize=(6, 3.4), dpi=140)
    ax = fig.add_subplot(111)
    rows = range(k)
    # topk returns values in descending order; plotting them in that order and
    # inverting the y axis puts the top-1 class on the first (top) row.
    # (The original reversed the data *and* inverted the axis, which cancelled
    # out and left the top-1 class at the bottom.)
    ax.barh(rows, vals)
    ax.set_yticks(rows)
    ax.set_yticklabels(labels, fontsize=9)
    ax.set_xlim(0, 1)
    ax.invert_yaxis()
    ax.set_xlabel("Probability")
    ax.grid(axis="x", alpha=0.25, linestyle="--")
    fig.tight_layout()
    return fig
def predict(img: Image.Image, topk: int):
    """Classify one image and build the outputs for the Predict tab.

    Args:
        img: Input PIL image, or ``None`` when nothing was provided.
        topk: Number of classes to show in the probability chart.

    Returns:
        ``(image, badge_markdown, figure)``; ``(None, "", None)`` when *img*
        is ``None``.
    """
    if img is None:
        return None, "", None

    with torch.inference_mode(), autocast_ctx():
        batch = TFM(img.convert("RGB")).unsqueeze(0)
        batch = batch.to(DEVICE, memory_format=torch.channels_last)
        scores = MODEL(batch)
        prob = F.softmax(scores, dim=1)[0].detach().cpu()

    best_p, best_i = torch.max(prob, dim=0)
    label = CLASSES[best_i.item()]
    badge = f"**Prediction:** {label} — **{float(best_p)*100:.2f}%**"
    chart = plot_topk(prob, CLASSES, k=topk)
    return img, badge, chart
| # ---------------- Report readers (optional on Space) ---------------- | |
def _find_history_file():
    """Return the first existing training-history file, or ``None``.

    Folders are searched in the order ``REPORT_DIR``, ``ROOT``,
    ``ROOT/ckpt_final320``; within each folder ``history.json`` is preferred
    over ``history.csv``.
    """
    folders = (REPORT_DIR, ROOT, ROOT / "ckpt_final320")
    filenames = ("history.json", "history.csv")
    for folder in folders:
        for filename in filenames:
            candidate = folder / filename
            if candidate.exists():
                return candidate
    return None
| def _load_history_from_path(hp: Path | None): | |
| if hp is None or not hp.exists(): return None | |
| try: | |
| if hp.suffix == ".json": | |
| h = json.loads(hp.read_text()) | |
| return {"train_acc": h.get("train_acc") or h.get("acc") or [], | |
| "val_acc": h.get("val_acc") or h.get("val") or [], | |
| "train_loss":h.get("train_loss")or [], | |
| "val_loss": h.get("val_loss") or []} | |
| rows = list(csv.DictReader(hp.read_text().splitlines())) | |
| return {"train_acc":[float(r["train_acc"]) for r in rows if r.get("train_acc")], | |
| "val_acc": [float(r["val_acc"]) for r in rows if r.get("val_acc")], | |
| "train_loss":[float(r["train_loss"]) for r in rows if r.get("train_loss")], | |
| "val_loss": [float(r["val_loss"]) for r in rows if r.get("val_loss")]} | |
| except Exception: | |
| return None | |
def plot_training_curves(history: dict | None):
    """Plot accuracy and loss curves side by side.

    Args:
        history: Dict with list values under ``train_acc``, ``val_acc``,
            ``train_loss`` and ``val_loss``; missing keys are treated as
            empty series.

    Returns:
        A matplotlib ``Figure`` with two subplots, or ``None`` when *history*
        is empty or has no ``train_acc`` values.
    """
    if not history or not history.get("train_acc"):
        return None

    # Build the figure without pyplot so it is not retained in pyplot's global
    # figure registry (a leak in a server that never closes figures).
    from matplotlib.figure import Figure

    keys = ("train_acc", "val_acc", "train_loss", "val_loss")
    series = [list(history.get(key) or []) for key in keys]
    n = max(len(values) for values in series)
    epochs = list(range(n))

    def pad(values):
        # Bring every series to length n: repeat the last value for short
        # series, use None placeholders for empty ones.
        if not values:
            return [None] * n
        return values + [values[-1]] * (n - len(values))

    ta, va, tl, vl = (pad(values) for values in series)

    fig = Figure(figsize=(10, 3.6), dpi=140)

    ax1 = fig.add_subplot(1, 2, 1)
    ax1.plot(epochs, ta, label="Training Accuracy")
    ax1.plot(epochs, va, label="Validation Accuracy")
    ax1.set_title("Model Accuracy")
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Accuracy")
    ax1.grid(alpha=.25, linestyle="--")
    ax1.legend(loc="lower right", fontsize=8)

    ax2 = fig.add_subplot(1, 2, 2)
    ax2.plot(epochs, tl, label="Training Loss")
    ax2.plot(epochs, vl, label="Validation Loss")
    ax2.set_title("Model Loss")
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("Loss")
    ax2.grid(alpha=.25, linestyle="--")
    ax2.legend(loc="upper right", fontsize=8)

    fig.tight_layout()
    return fig
| def _parse_report_text(txt: str) -> tuple[Optional[float], Optional[float]]: | |
| m_acc = re.search(r"accuracy\s+([0-9]*\.?[0-9]+)", txt) | |
| m_macro = re.search(r"macro avg\s+([0-9]*\.?[0-9]+)\s+([0-9]*\.?[0-9]+)\s+([0-9]*\.?[0-9]+)", txt) | |
| top1 = float(m_acc.group(1)) if m_acc else None | |
| macro_recall = float(m_macro.group(2)) if m_macro else None | |
| return top1, macro_recall | |
def load_metrics_and_curves(rpt_upload=None, hist_upload=None):
    """Collect headline metrics and training curves for the Report tab.

    An uploaded report takes precedence over
    ``REPORT_DIR/classification_report.txt``; an uploaded history file takes
    precedence over the file located by ``_find_history_file()``.

    Args:
        rpt_upload: Uploaded classification report (an object exposing
            ``.name`` or a plain path string), or ``None``.
        hist_upload: Uploaded history file in the same form, or ``None``.

    Returns:
        ``(top1_markdown, macro_markdown, figure_or_None, note_markdown)``.
    """

    def _upload_path(upload):
        # Accept both an upload object carrying the temp path in `.name` and a
        # plain path string, the same way run_eval treats its custom list.
        return Path(upload.name if hasattr(upload, "name") else upload)

    rpt_txt = None
    if rpt_upload is not None:
        try:
            rpt_txt = _upload_path(rpt_upload).read_text()
        except (OSError, ValueError, TypeError):
            # Unreadable upload: fall back to the saved report below.
            rpt_txt = None
    if rpt_txt is None:
        rpt = REPORT_DIR / "classification_report.txt"
        if rpt.exists():
            rpt_txt = rpt.read_text()

    if rpt_txt:
        top1, macro_recall = _parse_report_text(rpt_txt)
        msg = ""
    else:
        top1 = macro_recall = None
        msg = "Report file not found."

    hp = _upload_path(hist_upload) if hist_upload is not None else _find_history_file()
    hist = _load_history_from_path(hp)
    fig = plot_training_curves(hist) if hist else None

    top1_md = f"**Top-1 Accuracy (overall):** {top1:.4f}" if top1 is not None else "Top-1 Accuracy: —"
    macro_md = f"**Average Accuracy per Class (macro recall):** {macro_recall:.4f}" if macro_recall is not None else "Avg per class: —"
    note = msg or ("" if fig else "_No training history found — add `reports_final320/history.json|csv` or upload it above._")
    return top1_md, macro_md, fig, note
| # ---------------- Confusion matrix plotting ---------------- | |
def plot_confusion_matrix(cm: np.ndarray, normalize: bool):
    """Render a confusion matrix as a heat map.

    Args:
        cm: Square count matrix with true classes on rows and predicted
            classes on columns.
        normalize: When true, divide each row by its sum; all-zero rows are
            left as zeros.

    Returns:
        A matplotlib ``Figure``, or ``None`` when *cm* is ``None`` or empty.
    """
    if cm is None or cm.size == 0:
        return None

    # Build the figure without pyplot so it is not retained in pyplot's global
    # figure registry (a leak in a server that never closes figures).
    from matplotlib.figure import Figure

    M = cm.astype(float)
    if normalize:
        row_sums = M.sum(axis=1, keepdims=True)
        row_sums[row_sums == 0] = 1.0  # avoid 0/0 for classes with no true samples
        M = M / row_sums

    fig = Figure(figsize=(6.5, 6), dpi=140)
    ax = fig.add_subplot(111)
    im = ax.imshow(M, aspect="auto")
    ax.set_title("Confusion Matrix" + (" (Normalized)" if normalize else ""))
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    fig.tight_layout()
    return fig
| # ---------------- Evaluate helpers ---------------- | |
| def _extract_zip_to_tmp(zip_file) -> Optional[Path]: | |
| if zip_file is None: return None | |
| tmpdir = Path(tempfile.mkdtemp(prefix="imgs_")) | |
| with zipfile.ZipFile(zip_file.name, "r") as zf: | |
| zf.extractall(tmpdir) | |
| return tmpdir | |
def _as_dir(p: str | Path) -> Path:
    """Turn a user-provided folder input into a path anchored at ROOT.

    An empty input yields ``ROOT`` itself, an absolute path is returned
    unchanged, and a relative path is joined onto ``ROOT``.
    """
    if not p:
        return ROOT
    folder = Path(p)
    return folder if folder.is_absolute() else ROOT / folder
| def _resolve_path(name: str, images_dir: Path) -> Optional[Path]: | |
| """Resolve image path using several common patterns.""" | |
| cand = Path(name) | |
| if cand.exists(): return cand | |
| base = Path(name).name | |
| # try relative to provided dir, and common subdirs | |
| for p in [images_dir / name, images_dir / base, images_dir / "Test" / base, images_dir / "Train" / base]: | |
| if p.exists(): return p | |
| # try one level deep | |
| try: | |
| for sub in images_dir.iterdir(): | |
| pp = sub / base | |
| if pp.exists(): return pp | |
| except Exception: | |
| pass | |
| return None | |
| def _read_list(list_path: Path) -> List[Tuple[str, int]]: | |
| pairs: List[Tuple[str, int]] = [] | |
| for ln in list_path.read_text().splitlines(): | |
| ln = ln.strip() | |
| if not ln: continue | |
| parts = ln.split() | |
| if len(parts) < 2: | |
| continue | |
| a, b = parts[0], parts[1] | |
| try: | |
| lab = int(b) | |
| except Exception: | |
| continue | |
| pairs.append((a, lab)) | |
| return pairs | |
class ListDataset(Dataset):
    """Dataset over ``(path, label)`` records that loads images on access."""

    def __init__(self, records, tfm):
        """Store the ``(path, label)`` records and the transform to apply."""
        self.records = records
        self.tfm = tfm

    def __len__(self):
        """Return the number of records."""
        return len(self.records)

    def __getitem__(self, i):
        """Return ``(transformed_image, label, path_as_str)`` for record *i*."""
        path, label = self.records[i]
        rgb = Image.open(path).convert("RGB")
        sample = self.tfm(rgb)
        return sample, label, str(path)
| # ---------------- Evaluate (fast) ---------------- | |
def run_eval(list_choice: str, custom_list, classes_file, images_folder: str, images_zip, batch_size: int,
             max_items: int, normalize_cm: bool, save_reports: bool, top_err_n: int,
             progress=gr.Progress()):
    """Run batch evaluation over a list file and build the Evaluate tab outputs.

    Args:
        list_choice: ``"test.txt (Test/)"``, ``"train.txt (Train/)"``, or any
            other value to use *custom_list*.
        custom_list: Custom list file (object exposing ``.name`` or a path
            string); required when *list_choice* is not one of the presets.
        classes_file: Optional classes file override (object exposing
            ``.name`` or a path string); defaults to ``CLASSES_TXT``.
        images_folder: Folder holding the images, resolved with ``_as_dir``.
        images_zip: Optional ZIP upload; when given it is extracted and used
            as the image root instead of *images_folder*.
        batch_size: Evaluation batch size (values below 1 are raised to 1).
        max_items: Evaluate only the first N list entries; 0 means all.
        normalize_cm: Row-normalize the confusion-matrix plot.
        save_reports: Write the text report, confusion-matrix CSV and
            ``metrics.json`` into ``REPORT_DIR``.
        top_err_n: Number of misclassifications to show (capped at 24).
        progress: Gradio progress tracker.

    Returns:
        ``(top1_markdown, macro_markdown, confusion_matrix_figure, gallery,
        note_markdown)``. On a validation failure or an unexpected error the
        first four entries are ``"", "", None, []`` and the note explains why.
    """
    import shutil  # local import: only used to remove the extracted ZIP directory

    tmpdir = None  # set once a ZIP has been extracted; removed in `finally`
    try:
        start = time.time()

        # UI sliders may deliver floats; slicing and DataLoader need ints.
        batch_size = int(batch_size)
        max_items = int(max_items) if max_items else 0
        top_err_n = int(top_err_n)

        # Resolve list path + images directory
        if list_choice == "test.txt (Test/)":
            list_path = ROOT / "test.txt"
            images_dir = _as_dir(images_folder) if images_folder else ROOT / "Test"
        elif list_choice == "train.txt (Train/)":
            list_path = ROOT / "train.txt"
            images_dir = _as_dir(images_folder) if images_folder else ROOT / "Train"
        else:
            if not custom_list:
                return "", "", None, [], "_Please provide a custom list file._"
            list_path = Path(custom_list.name if hasattr(custom_list, "name") else custom_list)
            images_dir = _as_dir(images_folder)

        # If a ZIP is provided, extract and use that as the root
        tmpdir = _extract_zip_to_tmp(images_zip)
        if tmpdir is not None:
            images_dir = tmpdir

        if not list_path.exists():
            return "", "", None, [], f"_List file not found at {list_path}_"

        # Classes (optional override)
        if classes_file and hasattr(classes_file, "name"):
            classes_path = Path(classes_file.name)
        elif isinstance(classes_file, str) and classes_file:
            classes_path = Path(classes_file)
        else:
            classes_path = CLASSES_TXT
        if not classes_path.exists():
            return "", "", None, [], f"_Classes file not found: {classes_path}_"
        classes = _load_classes(classes_path)

        # Pairs -> records with existing files
        try:
            pairs = _read_list(list_path)
        except Exception as e:
            return "", "", None, [], f"_Could not read list file: {e}_"
        if max_items > 0:
            pairs = pairs[:max_items]

        records = []
        missing = 0
        for name, lab in pairs:
            p = _resolve_path(name, images_dir)
            if p is None:
                missing += 1
                continue
            if lab < 0 or lab >= len(classes):
                # skip labels outside the class list
                continue
            records.append((p, lab))
        if not records:
            return "", "", None, [], "_No valid images found for evaluation._"

        # DataLoader — single-process loading, no pinned memory
        loader = DataLoader(
            ListDataset(records, TFM),
            batch_size=max(1, batch_size),
            shuffle=False,
            num_workers=0,
            pin_memory=False,
            persistent_workers=False,
        )

        # Inference
        MODEL.eval()
        y_true, y_pred, y_conf, paths = [], [], [], []
        total = len(loader)
        with torch.inference_mode(), autocast_ctx():
            for i, (xb, yb, pb) in enumerate(loader):
                progress((i + 1) / max(1, total), desc=f"Evaluating {i+1}/{total}")
                xb = xb.to(DEVICE, memory_format=torch.channels_last)
                logits = MODEL(xb)
                probs = F.softmax(logits, dim=1)
                conf, pred = torch.max(probs, dim=1)
                y_pred.extend(pred.cpu().tolist())
                y_true.extend([int(v) for v in yb])
                y_conf.extend(conf.detach().cpu().tolist())
                paths.extend(list(pb))

        # Restrict report/CM to labels that actually occur, so the label list
        # and target names always agree in length.
        present_labels = sorted(set(y_true) | set(y_pred))
        target_names = [classes[i] if 0 <= i < len(classes) else f"class_{i}" for i in present_labels]

        # Metrics
        top1 = accuracy_score(y_true, y_pred)
        rpt_txt = classification_report(
            y_true, y_pred,
            labels=present_labels,
            target_names=target_names,
            digits=4, zero_division=0,
        )
        rpt = classification_report(
            y_true, y_pred,
            labels=present_labels,
            output_dict=True, zero_division=0,
        )
        # NOTE(review): present_labels also contains classes that were only
        # predicted; with zero_division=0 those presumably enter the macro
        # average as recall 0 — confirm this is the intended definition of
        # "average accuracy per class".
        macro_recall = float(rpt["macro avg"]["recall"])
        cm = confusion_matrix(y_true, y_pred, labels=present_labels)

        # Persist artifacts
        if save_reports:
            REPORT_DIR.mkdir(parents=True, exist_ok=True)
            (REPORT_DIR / "classification_report.txt").write_text(rpt_txt)
            np.savetxt(REPORT_DIR / "confusion_matrix.csv", cm, fmt="%d", delimiter=",")
            with open(REPORT_DIR / "metrics.json", "w") as f:
                json.dump({"top1": float(top1), "avg_per_class": float(macro_recall)}, f)

        # Misclassifications: most confident wrong predictions
        errors = [
            (p, t, pr, cf)
            for p, t, pr, cf in zip(paths, y_true, y_pred, y_conf)
            if pr != t
        ]
        errors.sort(key=lambda x: x[3], reverse=True)
        keep = errors[:max(1, min(top_err_n, 24))]

        gallery: List[tuple] = []
        for p, t, pr, cf in keep:
            try:
                # convert() returns a fully loaded copy, so the thumbnail stays
                # valid after the temporary ZIP directory is removed below.
                im = Image.open(p).convert("RGB")
                im = ImageOps.fit(im, (256, 256))
                t_name = classes[t] if 0 <= t < len(classes) else f"class_{t}"
                p_name = classes[pr] if 0 <= pr < len(classes) else f"class_{pr}"
                caption = f"{p_name} → {t_name} (p={cf:.2f})"
                gallery.append((im, caption))
            except Exception:
                # Deliberately best-effort: a thumbnail that cannot be built
                # must not discard the metrics that were already computed.
                continue

        secs = time.time() - start
        header = f"_Evaluated {len(records)}/{len(pairs)} items. Skipped {missing} missing files. Time: {secs:.1f}s on {DEVICE.upper()}._"

        # Outputs
        top1_md = f"**Top-1 Accuracy:** {top1:.4f}"
        macro_md = f"**Average Accuracy per Class (macro recall):** {macro_recall:.4f}"
        cm_fig = plot_confusion_matrix(cm, normalize=normalize_cm)
        return top1_md, macro_md, cm_fig, gallery, header
    except Exception as e:
        # Surface any unexpected errors in the UI instead of generic "Error" cards
        msg = f"_Evaluation crashed: {type(e).__name__}: {e}_"
        return "", "", None, [], msg
    finally:
        # Every run with a ZIP upload extracts into a new temp directory;
        # remove it on all exit paths so repeated runs do not fill the disk.
        if tmpdir is not None:
            shutil.rmtree(tmpdir, ignore_errors=True)
| # ---------------- UI ---------------- | |
# Page-level CSS: cap the container width and hide the footer.
CSS = """
.gradio-container { max-width: 980px !important; }
footer { visibility: hidden }
"""
# NOTE(review): the nesting of the layout below was reconstructed from a
# flattened copy of this file — confirm which components sit inside each
# Row/Column before relying on the exact layout.
with gr.Blocks(title="Bird Species Classifier — ResNet50", css=CSS, theme=gr.themes.Soft()) as demo:
    gr.Markdown("### Bird Species Classifier — ResNet50\nA formal interface for inference, reporting, and **fast batch evaluation**.")
    with gr.Tabs():
        # --------- Predict ----------
        # Single-image inference: image + Top-K slider in, preview + badge + bar chart out.
        with gr.Tab("Predict"):
            with gr.Row():
                with gr.Column(scale=1):
                    in_img = gr.Image(type="pil", label="Image", height=340)
                    topk = gr.Slider(1, 10, value=5, step=1, label="Top-K")
                    btn = gr.Button("Predict", variant="primary")
                with gr.Column(scale=1):
                    out_img = gr.Image(type="pil", label="Preview", height=340)
                    out_badge = gr.Markdown("")
                    out_plot = gr.Plot(label="Top-K Probabilities")
            # The three outputs match predict()'s (image, badge, figure) return tuple.
            btn.click(predict, inputs=[in_img, topk], outputs=[out_img, out_badge, out_plot], show_progress="full")
        # --------- Report (read saved or uploaded) ----------
        with gr.Tab("Report"):
            gr.Markdown("Load metrics and training curves from **reports_final320/** or upload below.")
            with gr.Row():
                rpt_upload = gr.File(label="Upload classification_report.txt (optional)", file_types=[".txt"])
                hist_upload = gr.File(label="Upload history .csv or .json (optional)", file_types=[".csv", ".json"])
            m_btn = gr.Button("Load Metrics", variant="primary")
            m_top1 = gr.Markdown("")
            m_macro = gr.Markdown("")
            m_plot = gr.Plot(label="Training & Validation Curves")
            m_note = gr.Markdown("")
            # The four outputs match load_metrics_and_curves()'s return tuple.
            m_btn.click(load_metrics_and_curves, inputs=[rpt_upload, hist_upload], outputs=[m_top1, m_macro, m_plot, m_note])
        # --------- Evaluate (Space-friendly) ----------
        with gr.Tab("Evaluate"):
            gr.Markdown("Run evaluation from a list file (`filename label`). Upload a **.zip of images** or point to a folder that exists in the Space.")
            # The choice strings are compared verbatim inside run_eval().
            list_choice = gr.Radio(
                ["test.txt (Test/)", "train.txt (Train/)", "Custom"],
                value="test.txt (Test/)",
                label="List Source"
            )
            custom_list = gr.File(file_count="single", label="Custom list file (.txt)", file_types=[".txt"])
            classes_file = gr.File(file_count="single", label="Custom classes.txt (optional)", file_types=[".txt"])
            images_zip = gr.File(file_count="single", label="Optional: images .zip (we will extract server-side)", file_types=[".zip"])
            images_folder = gr.Textbox(value="Test", label="Images folder (leave empty if you upload a .zip)")
            def _sync_images_folder(choice: str) -> str:
                """Return the default images folder for the selected list source ("" for Custom)."""
                return "Test" if choice.startswith("test.txt") else ("Train" if choice.startswith("train.txt") else "")
            # Keep the folder textbox in step with the selected list source.
            list_choice.change(_sync_images_folder, inputs=[list_choice], outputs=[images_folder])
            with gr.Row():
                batch_size = gr.Slider(1, 128, value=32, step=1, label="Batch size")
                max_items = gr.Slider(0, 5000, value=0, step=50, label="Max items (0 = all)")
            with gr.Row():
                normalize_cm = gr.Checkbox(value=True, label="Normalize Confusion Matrix")
                save_reports = gr.Checkbox(value=True, label="Save reports to reports_final320/")
                top_err_n = gr.Slider(4, 24, value=12, step=1, label="Show Top-N Misclassifications")
            eval_btn = gr.Button("Run Evaluation", variant="primary")
            e_top1 = gr.Markdown("")
            e_macro = gr.Markdown("")
            e_cm = gr.Plot(label="Confusion Matrix")
            e_gallery = gr.Gallery(label="Misclassifications (most confident wrong predictions)", columns=4, height=360)
            e_note = gr.Markdown("")
            # Input order follows run_eval()'s positional parameters; the five
            # outputs match its return tuple.
            eval_btn.click(
                run_eval,
                inputs=[list_choice, custom_list, classes_file, images_folder, images_zip, batch_size, max_items, normalize_cm, save_reports, top_err_n],
                outputs=[e_top1, e_macro, e_cm, e_gallery, e_note],
                show_progress="full",
            )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()