# pcam_project/app.py — PCAM tumor-classifier Gradio Space (commit 46f0bee, "update documentation")
import torch
import torchvision
import numpy as np
from torch import nn
from torchvision import transforms
from torchvision.datasets import PCAM
import gradio as gr
from PIL import Image
import albumentations as A
from huggingface_hub import hf_hub_download
import os
# ---------------------------------
# 1. Load model
# ---------------------------------
# Seed CPU and (all) GPU RNGs for reproducible behavior across runs.
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
# Load the full serialized model object (not just a state_dict) onto CPU.
# NOTE(review): weights_only=False unpickles arbitrary Python objects — only
# safe because this checkpoint is produced by this project, never for
# untrusted files.
model = torch.load("results/pcam/20_06_2025_10_49_24/model_4.pt", map_location="cpu", weights_only=False)
# Inference mode: disables dropout and freezes batch-norm statistics.
model.eval()
# ---------------------------------
# 2. Define transform and dataset
# ---------------------------------
# Albumentations preprocessing applied to every sample before inference:
# resize the 96x96 PCAM patch to the model's 224x224 input, normalize each
# channel with per-image statistics, then convert HWC ndarray -> CHW tensor.
a_transform = A.Compose([
A.Resize(224, 224),
A.Normalize(normalization="image_per_channel", p=1.0),
A.ToTensorV2()
])
class AlbumentationsToPytorchTransform:
    """Adapter that lets an Albumentations pipeline act as a torchvision-style transform."""

    def __init__(self, albumentations_transform):
        # Wrapped Albumentations pipeline; invoked with the `image=` keyword.
        self.albumentations_transform = albumentations_transform

    def __call__(self, img):
        """Apply the wrapped pipeline to *img* and return a float32 tensor."""
        as_array = np.array(img)  # Albumentations works on NumPy arrays, not PIL images
        transformed = self.albumentations_transform(image=as_array)
        return transformed["image"].to(torch.float32)
# Wrap the Albumentations pipeline so torchvision's PCAM can call it on PIL images.
transform = AlbumentationsToPytorchTransform(a_transform)
def load_datasets(dataset_choice):
    """Return [transformed_dataset, raw_dataset] for the requested PCAM split.

    Tries the standard torchvision download first; if that fails (e.g. the
    public mirror is unreachable), fetches the .h5 files from a private
    Hugging Face dataset repo into data/pcam and retries.

    Args:
        dataset_choice: PCAM split name ("train", "val" or "test").

    Returns:
        Two-element list: the split with the model transform applied, and the
        same split untransformed (used for display in the UI).
    """
    def _make_pair():
        # One copy with the inference transform, one untouched for the UI.
        return [PCAM(root="data/", split=dataset_choice, download=True, transform=transform),
                PCAM(root="data/", split=dataset_choice, download=True)]

    try:
        return _make_pair()
    # Was a bare `except:` — that also swallows KeyboardInterrupt/SystemExit
    # and hides the failure cause. Catch Exception and log before falling back.
    except Exception as exc:
        print(f"Standard PCAM download failed ({exc}); falling back to Hugging Face Hub.")
        # ensure local folder exists
        os.makedirs("data/pcam", exist_ok=True)
        # set your token (in your Space, add it as a secret named HF_TOKEN)
        token = os.getenv("HF_TOKEN")
        # list of files in your dataset repository
        files = [
            "camelyonpatch_level_2_split_train_x.h5",
            "camelyonpatch_level_2_split_train_y.h5",
            "camelyonpatch_level_2_split_valid_x.h5",
            "camelyonpatch_level_2_split_valid_y.h5",
            "camelyonpatch_level_2_split_test_x.h5",
            "camelyonpatch_level_2_split_test_y.h5",
        ]
        for fname in files:
            local_path = hf_hub_download(
                repo_id="eloise54/pcam-private",
                filename=fname,
                repo_type="dataset",
                local_dir="data/pcam",
                token=token,
            )
            print(f"Downloaded: {local_path}")
        return _make_pair()
# ---------------------------------
# 3. Prepare choices for dropdown
# ---------------------------------
# Eagerly load all three splits at startup so browsing is instant later.
# Each value is [transformed_dataset, raw_dataset] as returned by load_datasets.
dataset_dict = {'train': load_datasets("train"),
'val': load_datasets("val"),
'test': load_datasets("test")}
# ---------------------------------
# 4. Prediction function
# ---------------------------------
# Decision threshold for the sigmoid output: probability at or above this
# value is classified as "Tumor". (Was an inline magic number.)
TUMOR_THRESHOLD = 0.4458489

def get_sample(index: int, dataset_choice: str):
    """Run the model on one dataset sample and package everything the UI shows.

    Args:
        index: Requested sample index; clamped into the dataset's valid range.
        dataset_choice: Which split to read ("train", "val" or "test").

    Returns:
        Tuple of (display image, predicted label, tumor probability,
        ground-truth label, clamped index for state, error marker,
        clamped index for display, dataset name).
    """
    [t_dataset, dataset] = dataset_dict[dataset_choice]
    index = max(0, min(index, len(dataset) - 1))  # clamp index into range
    image_tensor, ground_truth = t_dataset[index]
    image_pil, _ = dataset[index]  # untransformed image for display
    with torch.no_grad():
        output = model(image_tensor.unsqueeze(0)).squeeze()
        probability = torch.sigmoid(output)
    predicted_label = "Tumor" if probability >= TUMOR_THRESHOLD else "No Tumor"
    true_label = "Tumor" if ground_truth == 1 else "No Tumor"
    # Flag samples where the model disagrees with the ground truth.
    error_label = "Error !" if predicted_label != true_label else ""
    return image_pil, predicted_label, probability.numpy(), true_label, index, error_label, index, dataset_choice
# ---------------------------------
# 4. Navigation functions
# ---------------------------------
def next_sample(index: int, dataset_choice: str):
    """Show the sample after the current one (get_sample clamps at the end)."""
    return get_sample(index + 1, dataset_choice)
def prev_sample(index: int, dataset_choice: str):
    """Show the sample before the current one (get_sample clamps at zero)."""
    return get_sample(index - 1, dataset_choice)
# ---------------------------------
# 5. UI elements
# ---------------------------------
dataset_information = """
## πŸ“Š Dataset Overview
https://github.com/basveeling/pcam
The **PatchCamelyon (PCam)** benchmark is a challenging image classification dataset designed for breast cancer detection tasks.
- πŸ“¦ **Total images**: 327,680 color patches
- πŸ–ΌοΈ **Image size**: 96 Γ— 96 pixels
- πŸ§ͺ **Source**: Histopathologic scans of lymph node sections
- 🏷️ **Labels**: Binary β€” A positive (1) label indicates that the center 32x32px region of a patch contains at least one pixel of tumor tissue. Tumor tissue in the outer region of the patch does not influence the label.
```
B. S. Veeling, J. Linmans, J. Winkens, T. Cohen, M. Welling. "Rotation Equivariant CNNs for Digital Pathology". arXiv:1806.03962
```
```
Ehteshami Bejnordi et al. Diagnostic Assessment of Deep Learning Algorithms for Detection of Lymph Node Metastases in Women With Breast Cancer. JAMA: The Journal of the American Medical Association, 318(22), 2199–2210. doi:jama.2017.14585
```
Under CC0 License
"""
with gr.Blocks() as demo:
    # --- Header ---
    gr.Markdown("## 🧬 PCAM Tumor Classifier")
    gr.Markdown("Use **Next** or **Previous** to browse samples and see model predictions vs ground truth.")

    state = gr.State(0)  # holds current index

    # --- Controls and display widgets ---
    with gr.Row():
        dropdown = gr.Dropdown( ["train", "val", "test"], label="Dataset to use from torchvision.datasets.PCAM", value="train")
        dataset_choice = gr.Text(label="Using Dataset")
    with gr.Row():
        prev_btn = gr.Button("⬅️ Prev")
        next_btn = gr.Button("Next ➡️")
    with gr.Row():
        image_output = gr.Image(label="Image")
        index = gr.Text(label="Image Number")
    with gr.Row():
        pred_label = gr.Text(label="Predicted")
        true_label = gr.Text(label="Ground Truth")
    with gr.Row():
        error_label = gr.Text(label="Prediction error")
        confidence = gr.Text(label="Probability")
    with gr.Row():
        gr.Markdown(dataset_information)

    # --- Wiring ---
    # All three callbacks return the same tuple, so they share one output list.
    nav_outputs = [image_output, pred_label, confidence, true_label, state, error_label, index, dataset_choice]
    prev_btn.click(fn=prev_sample, inputs=[state, dropdown], outputs=nav_outputs)
    next_btn.click(fn=next_sample, inputs=[state, dropdown], outputs=nav_outputs)
    # Populate the UI with the first sample on page load.
    demo.load(fn=get_sample, inputs=[state, dropdown], outputs=nav_outputs)
# ---------------------------------
# 6. Run
# ---------------------------------
if __name__ == "__main__":
    # Start the Gradio server only when executed directly (not on import).
    demo.launch()