# NOTE: web-viewer residue was fused onto this line by the page scrape (Spaces status
# "Runtime error", file size 7,186 bytes, commit-hash gutter, line-number gutter);
# none of it is part of the program.
import argparse
import json
import logging
import os
from pathlib import Path
import torch
import yaml
from monai.data import Dataset
from src.data.data_loader import data_transform, list_data_collate
from src.model.cspca_model import CSPCAModel
from src.model.mil import MILModel3D
from src.preprocessing.generate_heatmap import get_heatmap
from src.preprocessing.histogram_match import histmatch
from src.preprocessing.prostate_mask import get_segmask
from src.preprocessing.register_and_crop import register_files
from src.utils import get_parent_image, get_patch_coordinate, setup_logging
import streamlit as st
@st.cache_resource
def load_pirads_model(num_classes, mil_mode, project_dir, device):
    """Create the PI-RADS ``MILModel3D`` and restore its saved weights.

    The function is wrapped in ``st.cache_resource``, so Streamlit hands back
    the previously built object when it is called again with the same
    arguments.

    Args:
        num_classes: Forwarded to ``MILModel3D`` as ``num_classes``.
        mil_mode: Forwarded to ``MILModel3D`` as ``mil_mode``.
        project_dir: Folder that contains ``models/pirads.pt``.
        device: Device the network is moved to after the weights are loaded.

    Returns:
        The network with restored weights, on ``device``, in eval mode.
    """
    network = MILModel3D(num_classes=num_classes, mil_mode=mil_mode)

    # Deserialize on the CPU first; the move to the target device happens below.
    weights_path = os.path.join(project_dir, "models", "pirads.pt")
    saved = torch.load(weights_path, map_location="cpu")
    network.load_state_dict(saved["state_dict"])

    network.to(device)
    network.eval()
    return network
@st.cache_resource
def load_cspca_model(_pirads_model, project_dir, device):
    """Wrap the PI-RADS network in a ``CSPCAModel`` and restore its weights.

    The function is wrapped in ``st.cache_resource``. The leading underscore
    in ``_pirads_model`` is Streamlit's convention for an argument that should
    be left out of the cache key.

    Args:
        _pirads_model: Network passed to ``CSPCAModel`` as its ``backbone``.
        project_dir: Folder that contains ``models/cspca_model.pth``.
        device: Device the network is placed on.

    Returns:
        The ``CSPCAModel`` with restored weights, on ``device``, in eval mode.
    """
    network = CSPCAModel(backbone=_pirads_model).to(device)

    # Deserialize on the CPU; load_state_dict copies the values into the
    # parameters that already live on ``device``.
    weights_path = os.path.join(project_dir, "models", "cspca_model.pth")
    saved = torch.load(weights_path, map_location="cpu")
    network.load_state_dict(saved["state_dict"])

    network = network.to(device)
    network.eval()
    return network
def _str_to_bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` passes the raw string to ``type``; ``bool("False")`` is
    ``True`` because every non-empty string is truthy, so ``type=bool`` can
    never switch an option off.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    token = str(value).strip().lower()
    if token in {"true", "t", "yes", "y", "1", "on"}:
        return True
    # The empty string stays falsy, as it was with ``type=bool``.
    if token in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse the command line and merge in an optional YAML config file.

    When ``--config`` is given, the file is read with ``yaml.safe_load`` and
    its top-level keys are written onto the namespace: a key that matches an
    option overwrites the command-line value, any other key becomes a new
    attribute.

    Returns:
        argparse.Namespace: The parsed (and possibly config-overridden) options.
    """
    parser = argparse.ArgumentParser(description="File preprocessing")
    parser.add_argument("--config", type=str, help="Path to YAML config file")
    parser.add_argument("--t2_dir", default=None, help="Path to T2W files")
    parser.add_argument("--dwi_dir", default=None, help="Path to DWI files")
    parser.add_argument("--adc_dir", default=None, help="Path to ADC files")
    parser.add_argument("--seg_dir", default=None, help="Path to segmentation masks")
    parser.add_argument("--output_dir", default=None, help="Path to output folder")
    parser.add_argument(
        "--margin", default=0.2, type=float, help="Margin to center crop the images"
    )
    parser.add_argument("--num_classes", default=4, type=int)
    parser.add_argument("--mil_mode", default="att_trans", type=str)
    # type=bool would turn "--use_heatmap False" into True; parse the text instead.
    parser.add_argument("--use_heatmap", default=True, type=_str_to_bool)
    parser.add_argument("--tile_size", default=64, type=int)
    parser.add_argument("--tile_count", default=24, type=int)
    parser.add_argument("--depth", default=3, type=int)
    parser.add_argument("--project_dir", default=None, help="Project directory")
    args = parser.parse_args()

    if args.config:
        with open(args.config, encoding="utf-8") as config_file:
            # An empty YAML document loads as None; treat it as "no overrides"
            # instead of letting dict.update(None) raise a TypeError.
            config = yaml.safe_load(config_file) or {}
        vars(args).update(config)
    return args
if __name__ == "__main__":
    # Script entry point: preprocess the input scans, score every case with
    # the PI-RADS and csPCa models, locate the most-attended patches, and
    # write everything to <output_dir>/results.json.
    args = parse_args()
    if args.project_dir is None:
        # Default to the folder that contains this file.
        args.project_dir = Path(__file__).resolve().parent
    if args.output_dir is None:
        # Otherwise os.path.join(None, ...) below fails with an opaque TypeError.
        raise SystemExit("--output_dir is required (command line or --config file)")
    # The log path handed to setup_logging lives inside output_dir, so make
    # sure the folder exists before logging is configured.
    os.makedirs(args.output_dir, exist_ok=True)

    FUNCTIONS = {
        "register_and_crop": register_files,
        "histogram_match": histmatch,
        "get_segmentation_mask": get_segmask,
        "get_heatmap": get_heatmap,
    }

    args.logfile = os.path.join(args.output_dir, "inference.log")
    setup_logging(args.logfile)
    logging.info("Starting preprocessing")
    steps = ["register_and_crop", "get_segmentation_mask", "histogram_match", "get_heatmap"]
    for step in steps:
        # Each step takes the namespace and returns the one the next step uses.
        args = FUNCTIONS[step](args)
    logging.info("Preprocessing completed.")

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info("Loading PIRADS model")
    pirads_model = load_pirads_model(args.num_classes, args.mil_mode, args.project_dir, args.device)
    logging.info("Loading csPCa model")
    cspca_model = load_cspca_model(pirads_model, args.project_dir, args.device)

    transform = data_transform(args)
    # os.listdir returns entries in arbitrary order; sort so the order of
    # cases in the log and in results.json is reproducible.
    files = sorted(os.listdir(args.t2_dir))
    # Every modality is looked up under the same file name as the T2 image.
    args.data_list = [
        {
            "image": os.path.join(args.t2_dir, file),
            "dwi": os.path.join(args.dwi_dir, file),
            "adc": os.path.join(args.adc_dir, file),
            "heatmap": os.path.join(args.heatmapdir, file),
            "mask": os.path.join(args.seg_dir, file),
            "label": 0,  # dummy label
        }
        for file in files
    ]
    dataset = Dataset(data=args.data_list, transform=transform)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        # Pinned host memory only benefits CUDA transfers; requesting it on a
        # CPU-only machine just produces a warning.
        pin_memory=args.device.type == "cuda",
        multiprocessing_context=None,
        sampler=None,
        collate_fn=list_data_collate,
    )

    pirads_list = []
    cspca_risk_list = []
    patches_top_5_list = []
    pirads_model.eval()
    cspca_model.eval()
    with torch.no_grad():
        for batch_data in loader:
            data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)

            logits = pirads_model(data)
            pirads_list.append(torch.argmax(logits, dim=1).item())

            output = cspca_model(data).squeeze(1)
            cspca_risk_list.append(output.item())

            # Run the backbone stage by stage to obtain the attention weights.
            # The first two of data's six axes are merged for ``net`` and
            # split again afterwards (presumably batch and patch axes --
            # TODO confirm against the data transform).
            sh = data.shape
            x = data.reshape(sh[0] * sh[1], sh[2], sh[3], sh[4], sh[5])
            x = cspca_model.backbone.net(x)
            x = x.reshape(sh[0], sh[1], -1)
            x = x.permute(1, 0, 2)
            x = cspca_model.backbone.transformer(x)
            x = x.permute(1, 0, 2)
            a = cspca_model.backbone.attention(x)
            a = torch.softmax(a, dim=1)
            a = a.view(-1)

            # torch.topk raises when k exceeds the number of elements, so
            # clamp k for runs that produce fewer than five patches.
            k = min(5, a.numel())
            _, top_indices = torch.topk(a, k)
            # Convert the indices once rather than once per selected patch.
            patches_top_5 = [
                data[0, patch_index][0].cpu().numpy() for patch_index in top_indices.tolist()
            ]
            patches_top_5_list.append(patches_top_5)

    coords_list = []
    for entry, patches in zip(args.data_list, patches_top_5_list):
        parent_image = get_parent_image([entry], args)
        coords_list.append(get_patch_coordinate(patches, parent_image))

    output_dict = {}
    for name, pirads_class, risk, coords in zip(files, pirads_list, cspca_risk_list, coords_list):
        # The predicted class index is offset by 2 before reporting (with the
        # default of 4 classes the reported score ranges from 2.0 to 5.0).
        pirads_score = pirads_class + 2.0
        logging.info(
            f"File: {name}, PIRADS score: {pirads_score}, csPCa risk score: {risk:.4f}"
        )
        output_dict[name] = {
            "Predicted PIRAD Score": pirads_score,
            "csPCa risk": risk,
            "Top left coordinate of top 5 patches(x,y,z)": coords,
        }

    with open(os.path.join(args.output_dir, "results.json"), "w", encoding="utf-8") as f:
        json.dump(output_dict, f, indent=4)
|