diff --git "a/temp.ipynb" "b/temp.ipynb" deleted file mode 100644--- "a/temp.ipynb" +++ /dev/null @@ -1,939 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "cec738fb", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "If you have questions or suggestions, feel free to open an issue at https://github.com/DIAGNijmegen/picai_prep\n", - "\n" - ] - } - ], - "source": [ - "import argparse\n", - "import os\n", - "import shutil\n", - "import time\n", - "import yaml\n", - "import sys\n", - "import gdown\n", - "import numpy as np\n", - "import torch\n", - "import torch.distributed as dist\n", - "import torch.multiprocessing as mp\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "from monai.config import KeysCollection\n", - "from monai.metrics import Cumulative, CumulativeAverage\n", - "from monai.networks.nets import milmodel, resnet, MILModel\n", - "\n", - "from sklearn.metrics import cohen_kappa_score\n", - "from torch.cuda.amp import GradScaler, autocast\n", - "from torch.utils.data.dataloader import default_collate\n", - "from torchvision.models.resnet import ResNet50_Weights\n", - "import shutil\n", - "from pathlib import Path\n", - "from torch.utils.data.distributed import DistributedSampler\n", - "from torch.utils.tensorboard import SummaryWriter\n", - "from monai.utils import set_determinism\n", - "import matplotlib.pyplot as plt\n", - "import wandb\n", - "import math\n", - "import logging\n", - "from pathlib import Path\n", - "\n", - "\n", - "from src.model.MIL import MILModel_3D\n", - "from src.model.csPCa_model import csPCa_Model\n", - "from src.data.data_loader import get_dataloader\n", - "from src.utils import save_cspca_checkpoint, get_metrics, setup_logging, save_pirads_checkpoint\n", - "from src.train import train_cspca, train_pirads\n", - "import SimpleITK as sitk \n", - "\n", - "import nrrd\n", - "\n", - "from tqdm import tqdm\n", - "import pandas as pd\n", - "from picai_prep.preprocessing import PreprocessingSettings, Sample\n", - "import multiprocessing\n", - "import sys\n", - "from src.preprocessing.register_and_crop import register_files\n", - "from src.preprocessing.prostate_mask import get_segmask\n", - "from src.preprocessing.histogram_match import histmatch\n", - "from src.preprocessing.generate_heatmap import get_heatmap\n", - "import logging\n", - "from pathlib import Path\n", - "from src.utils import setup_logging\n", - "from src.utils import validate_steps\n", - "import argparse\n", - "import yaml \n", - "from src.data.data_loader import data_transform, list_data_collate\n", - "from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "bc433898", - "metadata": {}, - "outputs": [], - "source": [ - "import subprocess\n", - "import sys\n", - "from pathlib import Path\n", - "import torch\n", - "import pytest\n", - "import argparse\n", - "from src.train.train_pirads import get_attention_scores\n" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "f1c90aff", - "metadata": {}, - "outputs": [], - "source": [ - "batch_size = 2\n", - "num_patches = 4\n", - "\n", - "# Sample 0: Target = 3 (Cancer), Sample 1: Target = 0 (PI-RADS 2)\n", - "data = torch.randn(batch_size, num_patches, 1, 8, 8)\n", - "target = torch.tensor([3.0, 0.0])\n", - "\n", - "# Create heatmaps: Sample 0 has one \"hot\" patch\n", - "heatmap = torch.zeros(batch_size, num_patches, 1, 8, 8)\n", - "heatmap[0, 0] = 10.0 # High attention on patch 0 for the first sample\n", - "heatmap[0, 3] = 2.0 \n", - "heatmap[1, 2] = 5.0 # Should be overridden by PI-RADS 2 logic anyway\n" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "80cb444f", - "metadata": {}, - "outputs": [], - "source": [ - "def mock_args():\n", - " # Mocking argparse for the device\n", - " args = argparse.Namespace()\n", - " args.device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - " return args" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "id": "6528fd4d", - "metadata": {}, - "outputs": [ - { - "ename": "AssertionError", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[41], line 23\u001b[0m\n\u001b[1;32m 21\u001b[0m idx \u001b[38;5;241m=\u001b[39m (shuffled_images[\u001b[38;5;241m0\u001b[39m, :, \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m0\u001b[39m] \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m5.0\u001b[39m)\u001b[38;5;241m.\u001b[39mnonzero(as_tuple\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 22\u001b[0m \u001b[38;5;66;03m# The attention score at that same index should be the maximum\u001b[39;00m\n\u001b[0;32m---> 23\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m att_labels[\u001b[38;5;241m0\u001b[39m, idx] \u001b[38;5;241m==\u001b[39m att_labels[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mmedian()\n", - "\u001b[0;31mAssertionError\u001b[0m: " - ] - } - ], - "source": [ - "num_patches = 10\n", - "\n", - "# Distinct data per patch: [0, 1, 2, 3...]\n", - "data = torch.arange(num_patches).view(1, num_patches, 1, 1, 1).float()\n", - "target = torch.tensor([3.0])\n", - "\n", - "# Heatmap matches the data indices so we can track the \"label\"\n", - "heatmap = torch.arange(num_patches).view(1, num_patches, 1, 1, 1).float()\n", - "\n", - "att_labels, shuffled_images = get_attention_scores(data, target, heatmap, mock_args)\n", - "\n", - "\n", - "idx= (shuffled_images[0, :, 0, 0, 0] == 9.0).nonzero(as_tuple=True)[0]\n", - "# The attention score at that same index should be the maximum\n", - "assert att_labels[0, idx] == att_labels[0].max()\n", - "\n", - "idx = (shuffled_images[0, :, 0, 0, 0] == 0.0).nonzero(as_tuple=True)[0]\n", - "# The attention score at that same index should be the maximum\n", - "assert att_labels[0, idx] == att_labels[0].min()\n", - "\n", - "idx = (shuffled_images[0, :, 0, 0, 0] == 5.0).nonzero(as_tuple=True)[0]\n", - "# The attention score at that same index should be the maximum\n", - "assert att_labels[0, idx] == att_labels[0].median()" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "90f5acab", - "metadata": {}, - "outputs": [], - "source": [ - "import subprocess\n", - "import sys\n", - "from pathlib import Path\n", - "import torch\n", - "import pytest\n", - "import argparse\n", - "from src.train.train_pirads import get_attention_scores\n", - "import monai\n", - "from monai.transforms import Transform\n", - "from src.data.custom_transforms import NormalizeIntensity_custom" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "e3a2dc6c", - "metadata": {}, - "outputs": [], - "source": [ - "img = torch.zeros((2, 4, 4), dtype=torch.float32)\n", - "mask = torch.zeros((1, 4, 4), dtype=torch.float32)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "98a500df", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[[0., 0., 0., 0.],\n", - " [0., 0., 0., 0.],\n", - " [0., 0., 0., 0.],\n", - " [0., 0., 0., 0.]],\n", - "\n", - " [[0., 0., 0., 0.],\n", - " [0., 0., 0., 0.],\n", - " [0., 0., 0., 0.],\n", - " [0., 0., 0., 0.]]])" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "img" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "c9974f43", - "metadata": {}, - "outputs": [], - "source": [ - "img[0, :, :] = 100.0 # Background\n", - "img[0, 0, 0] = 10.0 # Masked pixel 1\n", - "img[0, 0, 1] = 20.0 # Masked pixel 2\n", - "\n", - "# --- Channel 1 Setup ---\n", - "# Inside mask: Values [2, 4]\n", - "# Outside mask: Value 50\n", - "img[1, :, :] = 50.0 # Background\n", - "img[1, 0, 0] = 2.0 # Masked pixel 1\n", - "img[1, 0, 1] = 4.0 # Masked pixel 2\n", - "\n", - "# --- Mask Setup ---\n", - "# Selects only the top-left two pixels (0,0) and (0,1)\n", - "mask[0, 0, 0] = 1\n", - "mask[0, 0, 1] = 1\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "eb910fda", - "metadata": {}, - "outputs": [], - "source": [ - "data = torch.rand(1, 10, 10)\n", - "mask = torch.randint(0, 2, (1, 10, 10)).float()\n", - "normalizer = NormalizeIntensity_custom(nonzero=False, channel_wise=True)\n", - "out = normalizer(data, mask)" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "923341a3", - "metadata": {}, - "outputs": [], - "source": [ - "masked = data[mask != 0]\n", - "mean_ = torch.mean(masked.float())\n", - "std_ = torch.std(masked.float(), unbiased=False)\n", - "\n", - "epsilon = 1e-8\n", - "normalized_data = (data - mean_) / (std_ + epsilon)" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "e844cde1", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([ 1.4106, -0.1975, 0.3907, 1.2870, -0.7974, -1.2061, 0.7028, 1.2778,\n", - " 0.4667, -0.3361, -0.7842, -1.6296, -1.2037, 1.3582, -0.5648, -0.3055,\n", - " -0.3313, 0.0328, -1.0675, 0.6328, -0.2215, -1.3372, 0.5165, 1.9302,\n", - " 0.8875, 0.6793, 0.5553, 0.4335, 0.6390, -1.3707, 1.6053, 1.8626,\n", - " -0.3923, 0.2319, 0.3911, -0.4683, -1.1255, -1.6464, -0.2123, -0.5415,\n", - " 0.1401, -0.2822, 1.5019, -0.5117, -1.6047, -0.2322, -1.3080, 0.0130,\n", - " 1.8028, 0.5602, -1.6317])" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "masked" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "a9a20f58", - "metadata": {}, - "outputs": [], - "source": [ - "torch.testing.assert_close(out, normalized_data)" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "cad4a637", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import shutil\n", - "import json\n", - "import random" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "bea10ddb", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'image': '10270_1000274.nrrd',\n", - " 'mask': '/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/prostate_seg_mask/10270_1000274.nrrd',\n", - " 'dwi': '/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/DWI_hist_matched/10270_1000274.nrrd',\n", - " 'adc': '/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/ADC_hist_matched/10270_1000274.nrrd',\n", - " 'heatmap': '/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/heatmap/10270_1000274.nrrd',\n", - " 'label': 0},\n", - " {'image': '11063_1001085.nrrd',\n", - " 'mask': '/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/prostate_seg_mask/11063_1001085.nrrd',\n", - " 'dwi': '/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/DWI_hist_matched/11063_1001085.nrrd',\n", - " 'adc': '/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/ADC_hist_matched/11063_1001085.nrrd',\n", - " 'heatmap': '/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/heatmap/11063_1001085.nrrd',\n", - " 'label': 1},\n", - " {'image': '11184_1001207.nrrd',\n", - " 'mask': '/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/prostate_seg_mask/11184_1001207.nrrd',\n", - " 'dwi': '/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/DWI_hist_matched/11184_1001207.nrrd',\n", - " 'adc': '/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/ADC_hist_matched/11184_1001207.nrrd',\n", - " 'heatmap': '/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/registered/heatmap/11184_1001207.nrrd',\n", - " 'label': 0}]" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "with open('/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/dataset/PICAI_cspca.json', 'r') as f:\n", - " data = json.load(f)\n", - "samples = random.sample(data['test'],3)\n", - "samples" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "aa932307", - "metadata": {}, - "outputs": [ - { - "ename": "IndexError", - "evalue": "list index out of range", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[15], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m sam \u001b[38;5;241m=\u001b[39m \u001b[43msamples\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m]\u001b[49m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mimage\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[1;32m 2\u001b[0m \u001b[38;5;124;03m'''\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;124;03mshutil.copy('/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/t2_images/'+sam, 'dataset/samples/sample3/t2.nrrd')\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;124;03mshutil.copy('/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/DWI_images/'+sam, 'dataset/samples/sample3/dwi.nrrd')\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;124;03mshutil.copy('/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/ADC_images/'+sam, 'dataset/samples/sample3/adc.nrrd')\u001b[39;00m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;124;03m'''\u001b[39;00m\n", - "\u001b[0;31mIndexError\u001b[0m: list index out of range" - ] - } - ], - "source": [ - "sam = samples[3]['image']\n", - "'''\n", - "shutil.copy('/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/t2_images/'+sam, 'dataset/samples/sample3/t2.nrrd')\n", - "shutil.copy('/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/DWI_images/'+sam, 'dataset/samples/sample3/dwi.nrrd')\n", - "shutil.copy('/sc-projects/sc-proj-cc06-ag-ki-radiologie/pirad_model_test_PICAI/ADC_images/'+sam, 'dataset/samples/sample3/adc.nrrd')\n", - "'''" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "c91a5802", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 0%| | 0/1 [00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "\n", - "# image: HxW (grayscale) or HxWx3 (RGB)\n", - "# mask: HxW (binary or label mask)\n", - "\n", - "plt.figure(figsize=(6, 6))\n", - "\n", - "plt.imshow(adc[:,:,15], cmap=\"gray\")\n", - "plt.imshow(seg[:,:,15], cmap=\"Reds\", alpha=0.4) # overlay mask\n", - "plt.axis(\"off\")\n", - "\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "8b5d382e", - "metadata": {}, - "outputs": [], - "source": [ - "args.num_classes = 4\n", - "args.mil_mode = \"att_trans\"\n", - "args.use_heatmap = True\n", - "args.tile_size = 64\n", - "args.tile_count = 24\n", - "args.depth = 3\n" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "4cf061ec", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)\n", - "You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n", - "You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n" - ] - } - ], - "source": [ - "args.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "\n", - "\n", - "pirads_model = MILModel_3D(\n", - " num_classes=args.num_classes, \n", - " mil_mode=args.mil_mode \n", - ")\n", - "pirads_checkpoint = torch.load(os.path.join(args.project_dir, 'models', 'pirads.pt'), map_location=\"cpu\")\n", - "pirads_model.load_state_dict(pirads_checkpoint[\"state_dict\"])\n", - "pirads_model.to(args.device)\n", - "\n", - "cspca_model = csPCa_Model(backbone=pirads_model).to(args.device)\n", - "checkpt = torch.load(os.path.join(args.project_dir, 'models', 'cspca_model.pth'), map_location=\"cpu\")\n", - "cspca_model.load_state_dict(checkpt['state_dict'])\n", - "cspca_model = cspca_model.to(args.device)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "fac15515", - "metadata": {}, - "outputs": [], - "source": [ - "transform = data_transform(args)\n", - "files = os.listdir(args.t2_dir)\n", - "data_list = []\n", - "for file in files:\n", - " temp = {}\n", - " temp['image'] = os.path.join(args.t2_dir, file)\n", - " temp['dwi'] = os.path.join(args.dwi_dir, file)\n", - " temp['adc'] = os.path.join(args.adc_dir, file)\n", - " temp['heatmap'] = os.path.join(args.heatmapdir, file)\n", - " temp['mask'] = os.path.join(args.seg_dir, file)\n", - " temp['label'] = 0 # dummy label\n", - " data_list.append(temp)\n", - "\n", - "dataset = Dataset(data=data_list, transform=transform)\n", - "loader = torch.utils.data.DataLoader(\n", - " dataset,\n", - " batch_size=1,\n", - " shuffle=False,\n", - " num_workers=0,\n", - " pin_memory=True,\n", - " multiprocessing_context= None,\n", - " sampler=None,\n", - " collate_fn=list_data_collate,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "eb80047b", - "metadata": {}, - "outputs": [], - "source": [ - "pirads_list = []\n", - "pirads_model.eval()\n", - "cspca_risk_list = []\n", - "cspca_model.eval()\n", - "top5_patches = []\n", - "with torch.no_grad():\n", - " for idx, batch_data in enumerate(loader):\n", - " data = batch_data[\"image\"].as_subclass(torch.Tensor).to(args.device)\n", - " logits = pirads_model(data)\n", - " pirads_score= torch.argmax(logits, dim=1)\n", - " pirads_list.append(pirads_score.item())\n", - "\n", - " output = cspca_model(data)\n", - " output = output.squeeze(1)\n", - " cspca_risk_list.append(output.item())\n", - "\n", - " sh = data.shape\n", - " x = data.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4], sh[5])\n", - " x = cspca_model.backbone.net(x)\n", - " x = x.reshape(sh[0], sh[1], -1)\n", - " x = x.permute(1, 0, 2)\n", - " x = cspca_model.backbone.transformer(x)\n", - " x = x.permute(1, 0, 2)\n", - " a = cspca_model.backbone.attention(x)\n", - " a = torch.softmax(a, dim=1)\n", - " a = a.view(-1)\n", - " top5_values, top5_indices = torch.topk(a, 5)\n", - "\n", - " patches_top_5 = []\n", - " for i in range(5):\n", - " patch_temp = data[0, top5_indices.cpu().numpy()[i]][0].cpu().numpy()\n", - " patches_top_5.append(patch_temp)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "dbcfc97f", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "list" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "type(patches_top_5)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "4edb20e7", - "metadata": {}, - "outputs": [], - "source": [ - "import argparse\n", - "import os\n", - "import shutil\n", - "import time\n", - "import yaml\n", - "import sys\n", - "import gdown\n", - "import numpy as np\n", - "import torch\n", - "import torch.distributed as dist\n", - "import torch.multiprocessing as mp\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "from monai.config import KeysCollection\n", - "from monai.metrics import Cumulative, CumulativeAverage\n", - "from monai.networks.nets import milmodel, resnet, MILModel\n", - "from monai.transforms import (\n", - " Compose,\n", - " GridPatchd,\n", - " LoadImaged,\n", - " MapTransform,\n", - " RandFlipd,\n", - " RandGridPatchd,\n", - " RandRotate90d,\n", - " ScaleIntensityRanged,\n", - " SplitDimd,\n", - " ToTensord,\n", - " ConcatItemsd, \n", - " SelectItemsd,\n", - " EnsureChannelFirstd,\n", - " RepeatChanneld,\n", - " DeleteItemsd,\n", - " EnsureTyped,\n", - " ClipIntensityPercentilesd,\n", - " MaskIntensityd,\n", - " HistogramNormalized,\n", - " RandBiasFieldd,\n", - " RandCropByPosNegLabeld,\n", - " NormalizeIntensityd,\n", - " SqueezeDimd,\n", - " CropForegroundd,\n", - " ScaleIntensityd,\n", - " SpatialPadd,\n", - " CenterSpatialCropd,\n", - " ScaleIntensityd,\n", - " Transposed,\n", - " RandWeightedCropd,\n", - ")\n", - "from sklearn.metrics import cohen_kappa_score\n", - "from torch.cuda.amp import GradScaler, autocast\n", - "from torch.utils.data.dataloader import default_collate\n", - "from torchvision.models.resnet import ResNet50_Weights\n", - "from src.data.custom_transforms import ClipMaskIntensityPercentilesd, NormalizeIntensity_customd\n", - "from torch.utils.data.distributed import DistributedSampler\n", - "from torch.utils.tensorboard import SummaryWriter\n", - "\n", - "import matplotlib.pyplot as plt\n", - "\n", - "import wandb\n", - "import math\n", - "from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset\n", - "\n", - "from src.model.MIL import MILModel_3D\n", - "from src.model.csPCa_model import csPCa_Model\n", - "\n", - "import logging\n", - "from pathlib import Path" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "e42cc132", - "metadata": {}, - "outputs": [], - "source": [ - "transform_image = Compose(\n", - " [\n", - " LoadImaged(keys=[\"image\", \"mask\"], reader=ITKReader(), ensure_channel_first=True, dtype=np.float32),\n", - " ClipMaskIntensityPercentilesd(keys=[\"image\"], lower=0, upper=99.5, mask_key=\"mask\"),\n", - " NormalizeIntensity_customd(keys=[\"image\"], mask_key=\"mask\", channel_wise=True),\n", - " EnsureTyped(keys=[\"label\"], dtype=torch.float32),\n", - " ToTensord(keys=[\"image\", \"label\"]),\n", - " ]\n", - ")\n", - "dataset_image = Dataset(data=data_list, transform=transform_image)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "bcdddd9e", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(270, 270, 28)" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dataset_image[0]['image'][0].numpy().shape" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "56072a2b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[{'image': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/t2_histmatched/1009449_11049598.nrrd',\n", - " 'dwi': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/DWI_histmatched/1009449_11049598.nrrd',\n", - " 'adc': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/ADC_histmatched/1009449_11049598.nrrd',\n", - " 'heatmap': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/heatmaps/1009449_11049598.nrrd',\n", - " 'mask': '/sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed/prostate_mask/1009449_11049598.nrrd',\n", - " 'label': 0}]" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data_list" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "db1163d2", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "foundation", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.21" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}