| import os |
| import os.path |
| import sys |
| import json |
| import torch |
|
|
| |
| sys.path.append(os.path.dirname(os.path.abspath(__file__))) |
| sys.path.append(os.path.abspath("..")) |
| |
|
|
| from dicom_to_nii import convert_ct_dicom_to_nii |
| from nii_to_dicom import convert_nii_to_dicom |
|
|
|
|
| |
| from monai.networks.nets import UNet |
| from monai.networks.layers import Norm |
| from monai.inferers import sliding_window_inference |
| from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch, NibabelReader |
| from monai.utils import first |
| from monai.transforms import ( |
| EnsureChannelFirstd, |
| Compose, |
| CropForegroundd, |
| ScaleIntensityRanged, |
| Invertd, |
| AsDiscreted, |
| ThresholdIntensityd, |
| RemoveSmallObjectsd, |
| KeepLargestConnectedComponentd, |
| Activationsd |
| ) |
| |
| from preprocessing import LoadImaged |
| |
| from postprocessing import SaveImaged, add_contours_exist |
| import matplotlib.pyplot as plt |
| import numpy as np |
| from utils import * |
|
|
|
|
def _find_model_path(filename):
    """Return *filename* if it exists relative to the CWD, else the same name next to this script.

    The CWD location is tried first so existing deployments keep working; the
    script-directory fallback lets the script run from another working directory.
    """
    if os.path.exists(filename):
        return filename
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), filename)


def _build_pretransforms(patient_id):
    """Return the MONAI preprocessing pipeline applied to the CT NIfTI volume."""
    return Compose(
        [
            LoadImaged(keys=["image"], reader=NibabelReader(), patientname=patient_id),
            EnsureChannelFirstd(keys=["image"]),
            # With above=False, voxels not below 1560 are filled with cval=1560 (upper clamp).
            ThresholdIntensityd(keys=["image"], threshold=1560, above=False, cval=1560),
            # With above=True, voxels not above -50 are filled with cval=-1000.
            ThresholdIntensityd(keys=["image"], threshold=-50, above=True, cval=-1000),
            # Map the [-1000, 1560] window linearly onto [0, 1].
            ScaleIntensityRanged(
                keys=["image"], a_min=-1000, a_max=1560,
                b_min=0.0, b_max=1.0, clip=True,
            ),
            CropForegroundd(keys=["image"], source_key="image"),
        ]
    )


def _build_posttransforms(pretransforms, output_dir):
    """Return the per-sample post-processing pipeline that ends by saving the mask as NIfTI."""
    return Compose(
        [
            Activationsd(keys="pred", softmax=True),
            # Invert the preprocessing (e.g. the foreground crop) so the prediction
            # is expressed on the same grid as the original CT image.
            Invertd(
                keys="pred",
                transform=pretransforms,
                orig_keys="image",
                nearest_interp=False,
                to_tensor=True,
            ),
            AsDiscreted(keys="pred", argmax=True, to_onehot=2, threshold=0.5),
            KeepLargestConnectedComponentd(keys="pred", is_onehot=True),
            SaveImaged(keys="pred", output_postfix='', separate_folder=False,
                       output_dir=output_dir, resample=False),
        ]
    )


def _load_model(model_path, device):
    """Build the 3D UNet, load the trained weights onto *device* and switch to eval mode."""
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    )
    # Map the checkpoint straight onto the device the model will run on.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    return model


def predict(tempPath, patient_id, ctSeriesInstanceUID, runInterpreter):
    """Segment the body contour of a CT series and upload it as a DICOM RTSTRUCT.

    Steps:

    1. Download the CT series from Orthanc.
    2. Convert it to NIfTI.
    3. Run the trained 3D UNet with sliding-window inference.
    4. Post-process and save the mask as NIfTI.
    5. Convert the mask to an RTSTRUCT with one 'BODY' structure.
    6. Upload that file to Orthanc.

    Args:
        tempPath: Base working directory. A ``<tempPath>/<patient_id>`` folder with
            ``ct_dicom``, ``ct_nii``, ``prediction_nii`` and ``prediction_dicom``
            sub-folders is created under it.
        patient_id: Patient identifier. It is used as the folder name and passed to
            ``LoadImaged`` as ``patientname``.
        ctSeriesInstanceUID: SeriesInstanceUID of the CT series to segment.
        runInterpreter: Interpreter description supplied by the caller. It is only
            printed.

    Raises:
        SystemExit: If ``patient_id`` or ``ctSeriesInstanceUID`` is empty/None, or
            if the trained model file ``best_metric_model.pth`` cannot be found.
    """
    if not patient_id:
        sys.exit("No Patient dataset loaded: Load the patient dataset in Study Management.")

    if not ctSeriesInstanceUID:
        sys.exit("No CT series instance UID to load the CT images. Check for CT data in your study")

    print("+++ tempath: ", tempPath)
    print("+++ patient_id: ", patient_id)
    print("+++ CT SeriesInstanceUID: ", ctSeriesInstanceUID)
    print("+++ runInterpreter", runInterpreter)

    # Per-patient working folders.
    dir_base = os.path.join(tempPath, patient_id)
    createdir(dir_base)
    dir_ct_dicom = os.path.join(dir_base, 'ct_dicom')
    createdir(dir_ct_dicom)
    dir_ct_nii = os.path.join(dir_base, "ct_nii")
    createdir(dir_ct_nii)
    dir_prediction_nii = os.path.join(dir_base, 'prediction_nii')
    createdir(dir_prediction_nii)
    dir_prediction_dicom = os.path.join(dir_base, 'prediction_dicom')
    createdir(dir_prediction_dicom)

    # NOTE(review): assumes the project's SaveImaged / add_contours_exist produce
    # RTStruct.nii.gz in dir_prediction_nii — confirm against postprocessing.py.
    predictedNiiFile = os.path.join(dir_prediction_nii, 'RTStruct.nii.gz')
    predictedDicomFile = os.path.join(dir_prediction_dicom, 'predicted_rtstruct.dcm')

    model_path = _find_model_path(r'best_metric_model.pth')
    if not os.path.exists(model_path):
        sys.exit("Not found the trained model")

    print('** Use python interpreter: ', runInterpreter)
    print('** Patient name: ', patient_id)
    print('** CT Serial instance UID: ', ctSeriesInstanceUID)

    downloadSeriesInstanceByModality(ctSeriesInstanceUID, dir_ct_dicom, "CT")
    print("Loading CT from Orthanc done")

    refCT = convert_ct_dicom_to_nii(dir_dicom=dir_ct_dicom, dir_nii=dir_ct_nii,
                                    outputname='ct.nii.gz', newvoxelsize=None)
    print("Conversion DICOM to nii done")

    test_Data = [{'image': os.path.join(dir_ct_nii, 'ct.nii.gz')}]
    test_pretransforms = _build_pretransforms(patient_id)
    test_posttransforms = _build_posttransforms(test_pretransforms, dir_prediction_nii)

    test_ds = CacheDataset(data=test_Data, transform=test_pretransforms)
    # Single-volume, deterministic inference: shuffling has no purpose here.
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=1)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = _load_model(model_path, device)

    d = first(test_loader)
    images = d["image"].to(device)
    # Inference only: disable autograd so no computation graph or activations are
    # retained across the sliding-window pass.
    with torch.no_grad():
        d['pred'] = sliding_window_inference(inputs=images, roi_size=(96, 96, 64),
                                             sw_batch_size=1, predictor=model)
        # Post-process each sample; SaveImaged (last transform) writes the NIfTI mask.
        d['pred'] = [test_posttransforms(i) for i in decollate_batch(d)]

    add_contours_exist(preddir=dir_prediction_nii, refCT=refCT)

    convert_nii_to_dicom(dicomctdir=dir_ct_dicom, predictedNiiFile=predictedNiiFile,
                         predictedDicomFile=predictedDicomFile, predicted_structures=['BODY'],
                         rtstruct_colors=[[255, 0, 0]], refCT=refCT)
    print("Conversion nii to DICOM done")

    uploadDicomToOrthanc(predictedDicomFile)
    print("Upload predicted result to Orthanc done")

    print("Body Segmentation prediction done")
|
|
'''
Prediction parameters provided by the server as positional command-line arguments (indices as in sys.argv). Select the parameters to be used for prediction:
[1] tempPath: The path where predict.py is stored; the per-patient working folders are created under it,
[2] patientname: The patient name/ID of the loaded dataset,
[3] ctSeriesInstanceUID: Series instance UID for the data set with modality = CT. To predict 'MR' modality data, retrieve the CT UID by code (see Precision Code),
[4] rtStructSeriesInstanceUID: Series instance UID for modality = RTSTRUCT,
[5] regSeriesInstanceUID: Series instance UID for modality = REG,
[6] runInterpreter: The Python interpreter/version of the Python environment,
[7] oarList: Only for dose prediction. For contour prediction, oarList = []
[8] tvList: Only for dose prediction. For contour prediction, tvList = []
'''
|
|
if __name__ == '__main__':
    # The server passes the positional parameters listed in the module notes;
    # this script uses argv[1] (tempPath), argv[2] (patient), argv[3] (CT UID)
    # and argv[6] (interpreter). Fail with a usage message instead of an IndexError.
    if len(sys.argv) < 7:
        sys.exit(
            f"Usage: {sys.argv[0]} tempPath patientname ctSeriesInstanceUID "
            "rtStructSeriesInstanceUID regSeriesInstanceUID runInterpreter [oarList tvList]"
        )
    predict(tempPath=sys.argv[1], patient_id=sys.argv[2],
            ctSeriesInstanceUID=sys.argv[3], runInterpreter=sys.argv[6])
|
|
|
|
|
|