Anirudh Balaraman committed on
Commit
8bcc3f2
·
1 Parent(s): 906fcb9

fix inference

Browse files
.gitignore CHANGED
@@ -1,2 +1,4 @@
1
  logs/
2
  models/
 
 
 
1
  logs/
2
  models/
3
+ datatemp/
4
+ __pycache__/
app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import nrrd
6
+ import matplotlib.pyplot as plt
7
+ import tempfile
8
+ import os
9
+
10
+ # --- 1. IMPORT CUSTOM SCRIPTS ---
11
+ try:
12
+ from model_definition import MyModelClass
13
+ # Your preprocess function should now likely accept a LIST of arrays or a stacked array
14
+ from inference_utils import preprocess_multimodal
15
+ except ImportError:
16
+ st.warning("Could not import custom modules.")
17
+
18
+ # --- 2. CONFIGURATION ---
19
+ MODEL_PATH = 'saved_model.pth'
20
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
+
22
+ # --- 3. LOAD MODEL ---
23
@st.cache_resource
def load_trained_model():
    """Load the serialized model once and cache it across Streamlit reruns.

    Returns:
        The model moved to DEVICE and switched to eval mode, or None when
        loading fails (the error is surfaced in the UI instead of crashing).
    """
    try:
        # The checkpoint is a fully pickled model object, so it must be loaded
        # with weights_only=False. Since PyTorch 2.6 the default flipped to
        # weights_only=True, which rejects pickled model objects. Only load
        # checkpoints from trusted sources — unpickling executes arbitrary code.
        model = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
        model.to(DEVICE)
        model.eval()
        return model
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None
33
+
34
model = load_trained_model()

# --- 4. APP INTERFACE ---
st.title("Multi-Modal Medical Inference")
st.write("Upload exactly 3 NRRD files (e.g., T1, T2, FLAIR) to generate a prediction.")

# accept_multiple_files=True: the model consumes 3 modalities as 3 channels.
uploaded_files = st.file_uploader("Choose 3 NRRD files...", type=["nrrd"], accept_multiple_files=True)

# Only proceed if exactly 3 files are present.
if uploaded_files:
    if len(uploaded_files) != 3:
        st.warning(f"Please upload exactly 3 files. You currently have {len(uploaded_files)}.")
    else:
        st.success("3 Files Uploaded. Processing...")

        # Sort files by name so the channel order fed to the model is
        # deterministic. CRITICAL if the model expects a fixed modality order.
        uploaded_files.sort(key=lambda x: x.name)

        scan_data_list = []
        temp_paths = []

        try:
            # --- A. Read all 3 files, previewing them side by side ---
            cols = st.columns(3)

            for idx, file in enumerate(uploaded_files):
                # nrrd.read needs a filesystem path, so spill the upload
                # to a temp file first (cleaned up in `finally`).
                with tempfile.NamedTemporaryFile(delete=False, suffix=".nrrd") as tmp:
                    tmp.write(file.getvalue())
                    tmp_path = tmp.name
                temp_paths.append(tmp_path)

                data, header = nrrd.read(tmp_path)
                scan_data_list.append(data)

                # Visualize the middle slice in the respective column.
                with cols[idx]:
                    st.caption(file.name)
                    fig, ax = plt.subplots()
                    if data.ndim == 3:
                        # Middle slice of an (H, W, D) volume.
                        ax.imshow(data[:, :, data.shape[2] // 2], cmap="gray")
                    else:
                        # 2D image: show it as-is. (Previously this indexed a
                        # third axis unconditionally, which raises IndexError
                        # for 2D data.)
                        ax.imshow(data, cmap="gray")
                    ax.axis("off")
                    st.pyplot(fig)

            # --- B. Combine/Stack Data ---
            if st.button("Run Prediction"):
                if model is None:
                    # Model failed to load at startup; fail with a clear
                    # message instead of "NoneType is not callable" below.
                    st.error("Model is not loaded; cannot run prediction.")
                else:
                    st.write("Merging channels and analyzing...")

                    # Treat the 3 files as 3 channels: each (H, W, D) array is
                    # stacked along a new axis 0 -> (3, H, W, D).
                    stacked_volume = np.stack(scan_data_list, axis=0)

                    # --- C. Preprocessing ---
                    input_tensor = preprocess_multimodal(stacked_volume)

                    # Ensure a batch dimension: (3, D, H, W) -> (1, 3, D, H, W).
                    if isinstance(input_tensor, torch.Tensor):
                        if input_tensor.ndim == 4:
                            input_tensor = input_tensor.unsqueeze(0)
                        input_tensor = input_tensor.to(DEVICE)

                    # --- D. Inference ---
                    with torch.no_grad():
                        output = model(input_tensor)
                        probabilities = F.softmax(output, dim=1)
                        confidence, predicted_class_idx = torch.max(probabilities, 1)

                    st.success("Done!")
                    st.metric("Prediction Class", predicted_class_idx.item())
                    st.metric("Confidence", f"{confidence.item()*100:.2f}%")

        except Exception as e:
            st.error(f"Error during processing: {e}")

        finally:
            # Always remove the temp NRRD files, even on error.
            for p in temp_paths:
                if os.path.exists(p):
                    os.remove(p)
config/config_preprocess.yaml CHANGED
@@ -1,8 +1,8 @@
1
- t2_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/data_temp/t2
2
- dwi_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/data_temp/dwi
3
- adc_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/data_temp/adc
4
- output_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/data_temp/processed
5
- project_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/git_updated/Prostate-Foundation/MIL/new_folder
6
 
7
 
8
 
 
1
+ t2_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/t2
2
+ dwi_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/dwi
3
+ adc_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/adc
4
+ output_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate/datatemp/processed
5
+ project_dir: /sc-scratch/sc-scratch-cc06-ag-ki-radiologie/prostate_foundation/WSAttention-Prostate
6
 
7
 
8
 
run_inference.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import shutil
4
+ import time
5
+ import yaml
6
+ import sys
7
+ import gdown
8
+ import numpy as np
9
+ import torch
10
+ import torch.distributed as dist
11
+ import torch.multiprocessing as mp
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from monai.config import KeysCollection
15
+ from monai.metrics import Cumulative, CumulativeAverage
16
+ from monai.networks.nets import milmodel, resnet, MILModel
17
+
18
+ from sklearn.metrics import cohen_kappa_score
19
+ from torch.cuda.amp import GradScaler, autocast
20
+ from torch.utils.data.dataloader import default_collate
21
+ from torchvision.models.resnet import ResNet50_Weights
22
+ import shutil
23
+ from pathlib import Path
24
+ from torch.utils.data.distributed import DistributedSampler
25
+ from torch.utils.tensorboard import SummaryWriter
26
+ from monai.utils import set_determinism
27
+ import matplotlib.pyplot as plt
28
+ import wandb
29
+ import math
30
+ import logging
31
+ from pathlib import Path
32
+
33
+
34
+ from src.model.MIL import MILModel_3D
35
+ from src.model.csPCa_model import csPCa_Model
36
+ from src.data.data_loader import get_dataloader
37
+ from src.utils import save_cspca_checkpoint, get_metrics, setup_logging, save_pirads_checkpoint
38
+ from src.train import train_cspca, train_pirads
39
+ import SimpleITK as sitk
40
+
41
+ import nrrd
42
+
43
+ from tqdm import tqdm
44
+ import pandas as pd
45
+ from picai_prep.preprocessing import PreprocessingSettings, Sample
46
+ import multiprocessing
47
+ import sys
48
+ from src.preprocessing.register_and_crop import register_files
49
+ from src.preprocessing.prostate_mask import get_segmask
50
+ from src.preprocessing.histogram_match import histmatch
51
+ from src.preprocessing.generate_heatmap import get_heatmap
52
+ import logging
53
+ from pathlib import Path
54
+ from src.utils import setup_logging
55
+ from src.utils import validate_steps
56
+ import argparse
57
+ import yaml
58
+ from src.data.data_loader import data_transform, list_data_collate
59
+ from monai.data import Dataset, load_decathlon_datalist, ITKReader, NumpyReader, PersistentDataset
60
+
61
def _str2bool(value):
    """argparse-friendly bool converter accepting true/false style strings.

    Plain ``type=bool`` is a classic pitfall: ``bool("False")`` is True
    because any non-empty string is truthy.
    """
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "true", "t", "1"):
        return True
    if value.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def parse_args():
    """Parse CLI arguments for the preprocessing/inference pipeline.

    If ``--config`` points to a YAML file, its keys are applied last and
    override any values given on the command line.

    Returns:
        argparse.Namespace with all options (possibly updated from YAML).
    """
    parser = argparse.ArgumentParser(description="File preprocessing")
    parser.add_argument("--config", type=str, help="Path to YAML config file")
    parser.add_argument("--t2_dir", default=None, help="Path to T2W files")
    parser.add_argument("--dwi_dir", default=None, help="Path to DWI files")
    parser.add_argument("--adc_dir", default=None, help="Path to ADC files")
    parser.add_argument("--seg_dir", default=None, help="Path to segmentation masks")
    parser.add_argument("--output_dir", default=None, help="Path to output folder")
    parser.add_argument("--margin", default=0.2, type=float, help="Margin to center crop the images")
    parser.add_argument("--num_classes", default=4, type=int)
    parser.add_argument("--mil_mode", default="att_trans", type=str)
    # Was type=bool, which made "--use_heatmap False" evaluate to True.
    parser.add_argument("--use_heatmap", default=True, type=_str2bool)
    parser.add_argument("--tile_size", default=64, type=int)
    parser.add_argument("--tile_count", default=24, type=int)
    parser.add_argument("--depth", default=3, type=int)
    parser.add_argument("--project_dir", default=None, help="Project directory")

    args = parser.parse_args()
    if args.config:
        with open(args.config, 'r') as config_file:
            config = yaml.safe_load(config_file)
        args.__dict__.update(config)
    return args
85
+
86
if __name__ == "__main__":
    args = parse_args()
    # Map step names to the preprocessing callables; each takes the args
    # namespace and returns it (possibly augmented with new attributes).
    FUNCTIONS = {
        "register_and_crop": register_files,
        "histogram_match": histmatch,
        "get_segmentation_mask": get_segmask,
        "get_heatmap": get_heatmap,
    }

    # Log into the output directory; assumes args.output_dir exists or is
    # created by setup_logging — TODO confirm.
    args.logfile = os.path.join(args.output_dir, f"inference.log")
    setup_logging(args.logfile)
    logging.info("Starting preprocessing")
    # Step order matters: registration/cropping runs before masking,
    # histogram matching and heatmap generation.
    steps = ["register_and_crop", "get_segmentation_mask", "histogram_match", "get_heatmap"]
    for step in steps:
        func = FUNCTIONS[step]
        # Each step may attach attributes consumed later; presumably
        # get_heatmap sets args.heatmapdir used below — verify, since no
        # --heatmapdir CLI option exists.
        args = func(args)

    logging.info("Preprocessing completed.")

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    logging.info("Loading PIRADS model")
    pirads_model = MILModel_3D(
        num_classes=args.num_classes,
        mil_mode=args.mil_mode
    )
    # Checkpoints are deserialized onto CPU first, then the models are moved
    # to the selected device.
    pirads_checkpoint = torch.load(os.path.join(args.project_dir, 'models', 'pirads.pt'), map_location="cpu")
    pirads_model.load_state_dict(pirads_checkpoint["state_dict"])
    pirads_model.to(args.device)
    logging.info("Loading csPCa model")
    # The csPCa head reuses the PIRADS model as its feature backbone.
    cspca_model = csPCa_Model(backbone=pirads_model).to(args.device)
    checkpt = torch.load(os.path.join(args.project_dir, 'models', 'cspca_model.pth'), map_location="cpu")
    cspca_model.load_state_dict(checkpt['state_dict'])
    cspca_model = cspca_model.to(args.device)

    transform = data_transform(args)
    # Build one record per case; all modalities are assumed to share the same
    # filename across their directories — TODO confirm against the
    # preprocessing outputs.
    files = os.listdir(args.t2_dir)
    data_list = []
    for file in files:
        temp = {}
        temp['image'] = os.path.join(args.t2_dir, file)
        temp['dwi'] = os.path.join(args.dwi_dir, file)
        temp['adc'] = os.path.join(args.adc_dir, file)
        temp['heatmap'] = os.path.join(args.heatmapdir, file)
        temp['mask'] = os.path.join(args.seg_dir, file)
        temp['label'] = 0  # dummy label (inference only; no ground truth)
        data_list.append(temp)

    dataset = Dataset(data=data_list, transform=transform)
    # batch_size=1 with shuffle=False keeps predictions index-aligned with
    # `files` for the final report loop.
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        multiprocessing_context= None,
        sampler=None,
        collate_fn=list_data_collate,
    )

    pirads_list = []
    pirads_model.eval()
    cspca_risk_list = []
    cspca_model.eval()
    with torch.no_grad():
        for idx, batch_data in enumerate(loader):
            data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
            # PIRADS: multi-class logits -> predicted class index.
            logits = pirads_model(data)
            pirads_score= torch.argmax(logits, dim=1)
            pirads_list.append(pirads_score.item())

            # csPCa: .item() implies a single scalar risk value per case.
            output = cspca_model(data)
            output = output.squeeze(1)
            cspca_risk_list.append(output.item())

    # Report per-file scores; relies on loader preserving `files` order.
    for i,j in enumerate(files):
        logging.info(f"File: {j}, PIRADS score: {pirads_list[i]}, csPCa risk score: {cspca_risk_list[i]:.4f}")
src/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/src/__pycache__/__init__.cpython-39.pyc and b/src/__pycache__/__init__.cpython-39.pyc differ
 
src/__pycache__/utils.cpython-39.pyc CHANGED
Binary files a/src/__pycache__/utils.cpython-39.pyc and b/src/__pycache__/utils.cpython-39.pyc differ
 
src/data/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/src/data/__pycache__/__init__.cpython-39.pyc and b/src/data/__pycache__/__init__.cpython-39.pyc differ
 
src/data/__pycache__/custom_transforms.cpython-39.pyc CHANGED
Binary files a/src/data/__pycache__/custom_transforms.cpython-39.pyc and b/src/data/__pycache__/custom_transforms.cpython-39.pyc differ
 
src/data/__pycache__/data_loader.cpython-39.pyc CHANGED
Binary files a/src/data/__pycache__/data_loader.cpython-39.pyc and b/src/data/__pycache__/data_loader.cpython-39.pyc differ
 
src/model/__pycache__/MIL.cpython-39.pyc CHANGED
Binary files a/src/model/__pycache__/MIL.cpython-39.pyc and b/src/model/__pycache__/MIL.cpython-39.pyc differ
 
src/model/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/src/model/__pycache__/__init__.cpython-39.pyc and b/src/model/__pycache__/__init__.cpython-39.pyc differ
 
src/model/__pycache__/csPCa_model.cpython-39.pyc CHANGED
Binary files a/src/model/__pycache__/csPCa_model.cpython-39.pyc and b/src/model/__pycache__/csPCa_model.cpython-39.pyc differ
 
src/preprocessing/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/src/preprocessing/__pycache__/__init__.cpython-39.pyc and b/src/preprocessing/__pycache__/__init__.cpython-39.pyc differ
 
src/preprocessing/__pycache__/center_crop.cpython-39.pyc CHANGED
Binary files a/src/preprocessing/__pycache__/center_crop.cpython-39.pyc and b/src/preprocessing/__pycache__/center_crop.cpython-39.pyc differ
 
src/preprocessing/__pycache__/generate_heatmap.cpython-39.pyc CHANGED
Binary files a/src/preprocessing/__pycache__/generate_heatmap.cpython-39.pyc and b/src/preprocessing/__pycache__/generate_heatmap.cpython-39.pyc differ
 
src/preprocessing/__pycache__/histogram_match.cpython-39.pyc CHANGED
Binary files a/src/preprocessing/__pycache__/histogram_match.cpython-39.pyc and b/src/preprocessing/__pycache__/histogram_match.cpython-39.pyc differ
 
src/preprocessing/__pycache__/prostate_mask.cpython-39.pyc CHANGED
Binary files a/src/preprocessing/__pycache__/prostate_mask.cpython-39.pyc and b/src/preprocessing/__pycache__/prostate_mask.cpython-39.pyc differ
 
src/preprocessing/__pycache__/register_and_crop.cpython-39.pyc CHANGED
Binary files a/src/preprocessing/__pycache__/register_and_crop.cpython-39.pyc and b/src/preprocessing/__pycache__/register_and_crop.cpython-39.pyc differ
 
src/train/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/src/train/__pycache__/__init__.cpython-39.pyc and b/src/train/__pycache__/__init__.cpython-39.pyc differ
 
src/train/__pycache__/train_cspca.cpython-39.pyc CHANGED
Binary files a/src/train/__pycache__/train_cspca.cpython-39.pyc and b/src/train/__pycache__/train_cspca.cpython-39.pyc differ
 
src/train/__pycache__/train_pirads.cpython-39.pyc CHANGED
Binary files a/src/train/__pycache__/train_pirads.cpython-39.pyc and b/src/train/__pycache__/train_pirads.cpython-39.pyc differ
 
temp.ipynb CHANGED
The diff for this file is too large to render. See raw diff