"""Gradio demo: browse PCam samples and compare model predictions vs. ground truth."""

import os

import albumentations as A
import gradio as gr
import numpy as np
import torch
import torchvision  # noqa: F401  (kept: may be needed for torch.load unpickling)
from huggingface_hub import hf_hub_download
from PIL import Image  # noqa: F401
from torch import nn  # noqa: F401  (kept: model unpickling may reference nn layers)
from torchvision import transforms  # noqa: F401
from torchvision.datasets import PCAM

# ---------------------------------
# 1. Constants
# ---------------------------------
MODEL_PATH = "results/pcam/20_06_2025_10_49_24/model_4.pt"
# Decision threshold tuned on the validation set (not the default 0.5).
TUMOR_THRESHOLD = 0.4458489
HF_DATASET_REPO = "eloise54/pcam-private"
PCAM_H5_FILES = [
    "camelyonpatch_level_2_split_train_x.h5",
    "camelyonpatch_level_2_split_train_y.h5",
    "camelyonpatch_level_2_split_valid_x.h5",
    "camelyonpatch_level_2_split_valid_y.h5",
    "camelyonpatch_level_2_split_test_x.h5",
    "camelyonpatch_level_2_split_test_y.h5",
]

# ---------------------------------
# 2. Load model
# ---------------------------------
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

# SECURITY NOTE: weights_only=False unpickles arbitrary objects; only load
# model files you produced yourself (this one ships with the Space).
model = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
model.eval()

# ---------------------------------
# 3. Define transform and dataset
# ---------------------------------
a_transform = A.Compose([
    A.Resize(224, 224),
    A.Normalize(normalization="image_per_channel", p=1.0),
    A.ToTensorV2(),
])


class AlbumentationsToPytorchTransform:
    """Adapter so an albumentations pipeline can be used as a torchvision transform.

    torchvision datasets hand PIL images to `transform(img)`, while
    albumentations expects `transform(image=np.ndarray)` and returns a dict.
    """

    def __init__(self, albumentations_transform):
        # The wrapped albumentations pipeline.
        self.albumentations_transform = albumentations_transform

    def __call__(self, img):
        """Convert a PIL image to a float32 tensor via the wrapped pipeline."""
        img = np.array(img)  # PIL -> NumPy (HWC, uint8)
        transformed = self.albumentations_transform(image=img)
        return transformed["image"].to(torch.float32)


transform = AlbumentationsToPytorchTransform(a_transform)


def _build_datasets(split):
    """Return [transformed, raw] PCAM datasets for *split*.

    The first dataset feeds the model; the second keeps the original image
    for display in the UI.
    """
    return [
        PCAM(root="data/", split=split, download=True, transform=transform),
        PCAM(root="data/", split=split, download=True),
    ]


def load_datasets(dataset_choice):
    """Load the PCAM split named *dataset_choice* ('train', 'val' or 'test').

    First tries torchvision's own downloader; if that fails (e.g. the mirror
    is gated), falls back to fetching the raw .h5 files from a private
    Hugging Face dataset repo, then retries. Requires the HF_TOKEN secret
    for the fallback path.
    """
    try:
        return _build_datasets(dataset_choice)
    except Exception as exc:  # torchvision download/verification failed
        print(f"torchvision download failed ({exc!r}); falling back to HF Hub")
        # Ensure the local folder torchvision expects is present.
        os.makedirs("data/pcam", exist_ok=True)
        # Token must be configured as a Space secret named HF_TOKEN.
        token = os.getenv("HF_TOKEN")
        for fname in PCAM_H5_FILES:
            local_path = hf_hub_download(
                repo_id=HF_DATASET_REPO,
                filename=fname,
                repo_type="dataset",
                local_dir="data/pcam",
                token=token,
            )
            print(f"Downloaded: {local_path}")
        return _build_datasets(dataset_choice)


# ---------------------------------
# 4. Prepare choices for dropdown
# ---------------------------------
dataset_dict = {split: load_datasets(split) for split in ("train", "val", "test")}


# ---------------------------------
# 5. Prediction function
# ---------------------------------
def get_sample(index: int, dataset_choice: str):
    """Run the model on sample *index* of the chosen split.

    Returns the tuple of values the Gradio outputs expect:
    (display image, predicted label, probability, true label,
     clamped index for state, error flag, index text, dataset name).
    """
    t_dataset, dataset = dataset_dict[dataset_choice]
    index = max(0, min(index, len(dataset) - 1))  # clamp index into range

    image_tensor, ground_truth = t_dataset[index]
    image_pil, _ = dataset[index]  # untransformed image for display

    with torch.no_grad():
        output = model(image_tensor.unsqueeze(0)).squeeze()
        probability = torch.sigmoid(output)

    predicted_label = "Tumor" if probability >= TUMOR_THRESHOLD else "No Tumor"
    true_label = "Tumor" if ground_truth == 1 else "No Tumor"
    error_label = "Error !" if predicted_label != true_label else ""

    return (image_pil, predicted_label, probability.numpy(), true_label,
            index, error_label, index, dataset_choice)


# ---------------------------------
# 6. Navigation functions
# ---------------------------------
def next_sample(index: int, dataset_choice: str):
    """Advance to the next sample (index is clamped by get_sample)."""
    return get_sample(index + 1, dataset_choice)


def prev_sample(index: int, dataset_choice: str):
    """Go back to the previous sample (index is clamped by get_sample)."""
    return get_sample(index - 1, dataset_choice)


# ---------------------------------
# 7. UI elements
# ---------------------------------
dataset_information = """
## 📊 Dataset Overview

https://github.com/basveeling/pcam

The **PatchCamelyon (PCam)** benchmark is a challenging image classification dataset designed for breast cancer detection tasks.

- 📦 **Total images**: 327,680 color patches
- 🖼️ **Image size**: 96 × 96 pixels
- 🧪 **Source**: Histopathologic scans of lymph node sections
- 🏷️ **Labels**: Binary — A positive (1) label indicates that the center 32x32px region of a patch contains at least one pixel of tumor tissue. Tumor tissue in the outer region of the patch does not influence the label.

```
B. S. Veeling, J. Linmans, J. Winkens, T. Cohen, M. Welling. "Rotation Equivariant CNNs for Digital Pathology". arXiv:1806.03962
```

```
Ehteshami Bejnordi et al. Diagnostic Assessment of Deep Learning Algorithms for Detection of Lymph Node Metastases in Women With Breast Cancer. JAMA: The Journal of the American Medical Association, 318(22), 2199–2210. doi:jama.2017.14585
```

Under CC0 License
"""

with gr.Blocks() as demo:
    gr.Markdown("## 🧬 PCAM Tumor Classifier")
    gr.Markdown("Use **Next** or **Previous** to browse samples and see model predictions vs ground truth.")

    state = gr.State(0)  # holds current index

    with gr.Row():
        dropdown = gr.Dropdown(
            ["train", "val", "test"],
            label="Dataset to use from torchvision.datasets.PCAM",
            value="train")
        dataset_choice = gr.Text(label="Using Dataset")
    with gr.Row():
        prev_btn = gr.Button("⬅️ Prev")
        next_btn = gr.Button("Next ➡️")
    with gr.Row():
        image_output = gr.Image(label="Image")
        index = gr.Text(label="Image Number")
    with gr.Row():
        pred_label = gr.Text(label="Predicted")
        true_label = gr.Text(label="Ground Truth")
    with gr.Row():
        error_label = gr.Text(label="Prediction error")
        confidence = gr.Text(label="Probability")
    with gr.Row():
        gr.Markdown(dataset_information)

    # All three callbacks share the same output signature (see get_sample).
    _outputs = [image_output, pred_label, confidence, true_label,
                state, error_label, index, dataset_choice]

    # Connect navigation
    prev_btn.click(fn=prev_sample, inputs=[state, dropdown], outputs=_outputs)
    next_btn.click(fn=next_sample, inputs=[state, dropdown], outputs=_outputs)

    # Load initial image
    demo.load(fn=get_sample, inputs=[state, dropdown], outputs=_outputs)

# ---------------------------------
# 8. Run
# ---------------------------------
if __name__ == "__main__":
    demo.launch()