import os

import albumentations as A
import gradio as gr
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from huggingface_hub import hf_hub_download
from torchvision.datasets import PCAM
# ---------------------------------
# 1. Load model
# ---------------------------------
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

# Full pickled checkpoint, so weights_only=False is required to unpickle it.
model = torch.load("results/pcam/20_06_2025_10_49_24/model_4.pt", map_location="cpu", weights_only=False)
model.eval()  # disable dropout and use running batch-norm statistics for inference
# ---------------------------------
# 2. Define transform and dataset
# ---------------------------------
a_transform = A.Compose([
    A.Resize(224, 224),
    # "image_per_channel" standardizes each channel with the image's own
    # mean/std instead of fixed dataset statistics.
    A.Normalize(normalization="image_per_channel", p=1.0),
    ToTensorV2(),
])
class AlbumentationsToPytorchTransform:
    """Adapter so an Albumentations pipeline can be used as a torchvision transform."""

    def __init__(self, albumentations_transform):
        self.albumentations_transform = albumentations_transform

    def __call__(self, img):
        img = np.array(img)  # convert PIL image to a NumPy array (HWC, uint8)
        transformed = self.albumentations_transform(image=img)
        return transformed["image"].to(torch.float32)  # CHW float tensor


transform = AlbumentationsToPytorchTransform(a_transform)
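# Sanity-check sketch (assumption: a random HWC uint8 array stands in for a
# PIL image here, since np.array() passes ndarrays through unchanged): the
# wrapper should yield a (3, 224, 224) float32 CHW tensor.
_probe = transform(np.random.randint(0, 256, size=(96, 96, 3), dtype=np.uint8))
assert _probe.shape == (3, 224, 224) and _probe.dtype == torch.float32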
def load_datasets(dataset_choice):
    """Return [transformed, untransformed] PCAM datasets for the given split.

    Falls back to fetching the HDF5 files from a private HF dataset repo
    when torchvision's own download fails.
    """
    try:
        return [PCAM(root="data/", split=dataset_choice, download=True, transform=transform),
                PCAM(root="data/", split=dataset_choice, download=True)]
    except Exception:
        # ensure local folder exists
        os.makedirs("data/pcam", exist_ok=True)
        # set your token (in your Space, add it as a secret named HF_TOKEN)
        token = os.getenv("HF_TOKEN")
        # files in the dataset repository
        files = [
            "camelyonpatch_level_2_split_train_x.h5",
            "camelyonpatch_level_2_split_train_y.h5",
            "camelyonpatch_level_2_split_valid_x.h5",
            "camelyonpatch_level_2_split_valid_y.h5",
            "camelyonpatch_level_2_split_test_x.h5",
            "camelyonpatch_level_2_split_test_y.h5",
        ]
        for fname in files:
            local_path = hf_hub_download(
                repo_id="eloise54/pcam-private",
                filename=fname,
                repo_type="dataset",
                local_dir="data/pcam",
                token=token,
            )
            print(f"Downloaded: {local_path}")
        return [PCAM(root="data/", split=dataset_choice, download=True, transform=transform),
                PCAM(root="data/", split=dataset_choice, download=True)]
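# Usage sketch: each call returns a [transformed, raw] pair, e.g.
#     t_test, raw_test = load_datasets("test")
# where t_test feeds the model and raw_test supplies the display image.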
# ---------------------------------
# 3. Prepare choices for dropdown
# ---------------------------------
dataset_dict = {"train": load_datasets("train"),
                "val": load_datasets("val"),
                "test": load_datasets("test")}
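# Note: all three splits are materialized up front, so the first launch has to
# fetch the full PCAM archive (several GB of HDF5 files) before the UI appears.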
# ---------------------------------
# 4. Prediction function
# ---------------------------------
def get_sample(index: int, dataset_choice: str):
    t_dataset, dataset = dataset_dict[dataset_choice]
    index = max(0, min(index, len(dataset) - 1))  # clamp index to a valid range
    image_tensor, ground_truth = t_dataset[index]
    image_pil, _ = dataset[index]  # untransformed image for display
    with torch.no_grad():
        output = model(image_tensor.unsqueeze(0)).squeeze()
        probability = torch.sigmoid(output)
    # Non-default decision threshold (presumably tuned during evaluation).
    predicted_label = "Tumor" if probability >= 0.4458489 else "No Tumor"
    true_label = "Tumor" if ground_truth == 1 else "No Tumor"
    error_label = "Error!" if predicted_label != true_label else ""
    return image_pil, predicted_label, probability.numpy(), true_label, index, error_label, index, dataset_choice
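# Example of the return shape (hypothetical call; the 8-tuple mirrors the
# Gradio `outputs` wiring below):
#     img, pred, prob, truth, idx, err, idx_box, split = get_sample(0, "train")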
# ---------------------------------
# 5. Navigation functions
# ---------------------------------
def next_sample(index: int, dataset_choice: str):
    return get_sample(index + 1, dataset_choice)

def prev_sample(index: int, dataset_choice: str):
    return get_sample(index - 1, dataset_choice)
# ---------------------------------
# 6. UI elements
# ---------------------------------
dataset_information = """
## 📊 Dataset Overview
https://github.com/basveeling/pcam
The **PatchCamelyon (PCam)** benchmark is a challenging image classification dataset designed for breast cancer detection tasks.
- 📦 **Total images**: 327,680 color patches
- 🖼️ **Image size**: 96 × 96 pixels
- 🧪 **Source**: Histopathologic scans of lymph node sections
- 🏷️ **Labels**: Binary. A positive (1) label indicates that the center 32x32px region of a patch contains at least one pixel of tumor tissue. Tumor tissue in the outer region of the patch does not influence the label.
```
B. S. Veeling, J. Linmans, J. Winkens, T. Cohen, M. Welling. "Rotation Equivariant CNNs for Digital Pathology". arXiv:1806.03962
```
```
Ehteshami Bejnordi et al. Diagnostic Assessment of Deep Learning Algorithms for Detection of Lymph Node Metastases in Women With Breast Cancer. JAMA: The Journal of the American Medical Association, 318(22), 2199–2210. doi:10.1001/jama.2017.14585
```
Released under the CC0 license.
"""
with gr.Blocks() as demo:
    gr.Markdown("## 🧬 PCAM Tumor Classifier")
    gr.Markdown("Use **Next** or **Previous** to browse samples and see model predictions vs ground truth.")
    state = gr.State(0)  # holds current index
    with gr.Row():
        dropdown = gr.Dropdown(["train", "val", "test"], label="Dataset to use from torchvision.datasets.PCAM", value="train")
        dataset_choice = gr.Text(label="Using Dataset")
    with gr.Row():
        prev_btn = gr.Button("⬅️ Prev")
        next_btn = gr.Button("Next ➡️")
    with gr.Row():
        image_output = gr.Image(label="Image")
        index = gr.Text(label="Image Number")
    with gr.Row():
        pred_label = gr.Text(label="Predicted")
        true_label = gr.Text(label="Ground Truth")
    with gr.Row():
        error_label = gr.Text(label="Prediction error")
        confidence = gr.Text(label="Probability")
    with gr.Row():
        gr.Markdown(dataset_information)

    # Connect navigation
    prev_btn.click(fn=prev_sample, inputs=[state, dropdown], outputs=[image_output, pred_label, confidence, true_label, state, error_label, index, dataset_choice])
    next_btn.click(fn=next_sample, inputs=[state, dropdown], outputs=[image_output, pred_label, confidence, true_label, state, error_label, index, dataset_choice])

    # Load initial image
    demo.load(fn=get_sample, inputs=[state, dropdown], outputs=[image_output, pred_label, confidence, true_label, state, error_label, index, dataset_choice])
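    # Gradio maps returned values to `outputs` positionally, so these lists
    # must stay in the same order as get_sample's return tuple.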
# ---------------------------------
# 7. Run
# ---------------------------------
if __name__ == "__main__":
    demo.launch()