diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..39a1c60da9ad95424b031242144c2627f57cd4b2 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +model/dynamic-network-architectures-main/imgs/Logos/HI_Logo.png filter=lfs diff=lfs merge=lfs -text diff --git a/CVPR25_TextSegFMData_with_class.json b/CVPR25_TextSegFMData_with_class.json new file mode 100644 index 0000000000000000000000000000000000000000..2599d6792bdadae68d5599009e264035b54ff657 --- /dev/null +++ b/CVPR25_TextSegFMData_with_class.json @@ -0,0 +1,4044 @@ +{ + "CT_Abdomen1K": { + "1": [ + "Liver", + "Liver in abdominal CT", + "CT imaging of the liver in the abdomen", + "Abdominal CT showing liver structures", + "Liver detected in abdominal CT scans", + "Visualization of the liver in abdominal CT imaging", + "Abdominal CT revealing liver anatomy", + "CT scan of the liver in the abdominal region", + "Presence of liver tissue in abdominal CT scans" + ], + "2": [ + "Right kidney", + "Right kidney in abdominal CT", + "CT imaging of the right kidney in the abdomen", + "Abdominal CT showing right kidney structures", + "Right kidney detected in abdominal CT scans", + "Visualization of the right kidney in abdominal CT imaging", + "CT scan of the right kidney in the abdominal region", + "Presence of right kidney tissue in abdominal CT images", + "Abdominal CT revealing right kidney anatomy" + ], + "3": [ + "Spleen", + "Spleen in abdominal CT", + "CT imaging of the spleen in the abdomen", + "Abdominal CT showing spleen structures", + "Spleen observed in abdominal CT scans", + "Visualization of the spleen in abdominal CT imaging", + "CT scan of the spleen in the abdominal region", + "Abdominal CT revealing spleen anatomy", + "Spleen segmentation in abdominal CT imaging" + ], + "4": 
[ + "Pancreas", + "Pancreas in abdominal CT", + "CT imaging of the pancreas in the abdomen", + "Abdominal CT showing pancreatic structures", + "Pancreas detected in abdominal CT scans", + "Visualization of the pancreas in abdominal CT imaging", + "CT scan of the pancreas in the abdominal region", + "Presence of pancreatic tissue in abdominal CT images", + "Abdominal CT revealing pancreas anatomy" + ], + "5": [ + "Left kidney", + "Left kidney in abdominal CT", + "CT imaging of the left kidney in the abdomen", + "Abdominal CT showing left kidney structures", + "Left kidney detected in abdominal CT scans", + "Visualization of the left kidney in abdominal CT imaging", + "CT scan of the left kidney in the abdominal region", + "Presence of left kidney tissue in abdominal CT images", + "Abdominal CT revealing left kidney anatomy" + ], + "instance_label": 0 + }, + "CT_AdrenalTumor": { + "1": [ + "Adrenocortical carcinoma", + "Adrenocortical carcinoma in abdominal CT imaging", + "Abdominal CT depiction of adrenocortical carcinoma", + "CT imaging of adrenocortical carcinoma in the abdomen", + "Adrenocortical carcinoma as seen in abdominal CT scans", + "CT visualization of adrenocortical carcinoma in the abdomen", + "Adrenocortical carcinoma manifestation in abdominal CT imaging", + "Abdominal CT representation of adrenocortical carcinoma", + "Adrenocortical carcinoma observed in abdominal CT imaging", + "CT-based detection of adrenocortical carcinoma in the abdomen", + "Adrenocortical carcinoma features in abdominal CT studies" + ], + "instance_label": 1 + }, + "CT_AMOS": { + "1": [ + "Spleen", + "Abdominal CT revealing spleen structures", + "Spleen detected in abdominal CT scans", + "CT imaging of the spleen within the abdomen", + "Spleen anatomy visualized in abdominal CT", + "Presence of spleen tissue observed in abdominal CT imaging", + "CT scan showing spleen in the abdominal cavity", + "Abdominal CT assessment of the spleen", + "Spleen observed in CT imaging of the 
abdomen" + ], + "2": [ + "Right kidney", + "Right kidney in abdominal CT", + "CT imaging of the right kidney in the abdomen", + "Abdominal CT showing the right kidney", + "Right kidney observed in abdominal CT scans", + "CT scan of the right kidney in the abdominal region", + "Presence of the right kidney detected in abdominal CT images", + "Abdominal CT revealing the right kidney anatomy" + ], + "3": [ + "Left kidney", + "Left kidney in abdominal CT", + "CT imaging of the left kidney in the abdomen", + "Abdominal CT showing the left kidney", + "Left kidney observed in abdominal CT scans", + "Visualization of the left kidney in abdominal CT imaging", + "CT scan of the left kidney in the abdominal region", + "Presence of the left kidney detected in abdominal CT images", + "Abdominal CT revealing the left kidney anatomy" + ], + "4": [ + "Gallbladder", + "Gallbladder in abdominal CT", + "CT imaging of the gallbladder in the abdomen", + "Abdominal CT showing gallbladder structures", + "Gallbladder observed in abdominal CT scans", + "Visualization of the gallbladder in abdominal CT imaging", + "CT scan of the gallbladder in the abdominal region", + "Presence of the gallbladder detected in abdominal CT images", + "Abdominal CT revealing gallbladder anatomy" + ], + "5": [ + "Esophagus", + "Esophagus in abdominal CT", + "CT imaging of the esophagus in the abdomen", + "Abdominal CT showing esophagus structures", + "Esophagus observed in abdominal CT scans", + "Visualization of the esophagus in abdominal CT imaging", + "CT scan of the esophagus in the abdominal region", + "Presence of the esophagus detected in abdominal CT images", + "Abdominal CT revealing esophagus anatomy" + ], + "6": [ + "Liver", + "Liver in abdominal CT", + "CT imaging of the liver in the abdomen", + "Abdominal CT showing liver structures", + "Liver detected in abdominal CT scans", + "Visualization of the liver in abdominal CT imaging", + "CT scan of the liver in the abdominal region", + "Presence of 
liver tissue in abdominal CT images", + "Abdominal CT revealing liver anatomy" + ], + "7": [ + "Stomach", + "Stomach in abdominal CT", + "CT imaging of the stomach in the abdomen", + "Abdominal CT showing stomach structures", + "Stomach observed in abdominal CT scans", + "Visualization of the stomach in abdominal CT imaging", + "CT scan of the stomach in the abdominal region", + "Presence of the stomach detected in abdominal CT images", + "Abdominal CT revealing stomach anatomy" + ], + "8": [ + "Aorta", + "Aorta in abdominal CT", + "CT imaging of the aorta in the abdomen", + "Abdominal CT showing aortic structures", + "Aorta observed in abdominal CT scans", + "Visualization of the aorta in abdominal CT imaging", + "CT scan of the aorta in the abdominal region", + "Presence of the aorta detected in abdominal CT images", + "Abdominal CT revealing aorta anatomy" + ], + "9": [ + "Inferior vena cava", + "Inferior vena cava in abdominal CT", + "CT imaging of the inferior vena cava in the abdomen", + "Abdominal CT showing inferior vena cava structures", + "Inferior vena cava observed in abdominal CT scan", + "Visualization of the inferior vena cava in abdominal CT imaging", + "CT scan of the inferior vena cava in the abdominal region", + "Presence of the inferior vena cava detected in abdominal CT images", + "Abdominal CT revealing inferior vena cava anatomy" + ], + "10": [ + "Pancreas", + "Pancreas in abdominal CT", + "CT imaging of the pancreas in the abdomen", + "Abdominal CT showing pancreatic structures", + "Pancreas detected in abdominal CT scans", + "Visualization of the pancreas in abdominal CT imaging", + "CT scan of the pancreas in the abdominal region", + "Presence of pancreatic tissue in abdominal CT images", + "Abdominal CT revealing pancreas anatomy" + ], + "11": [ + "Right adrenal gland", + "Right adrenal gland in abdominal CT", + "CT imaging of the right adrenal gland in the abdomen", + "Abdominal CT showing right adrenal gland structures", + "Right
adrenal gland observed in abdominal CT scans", + "Visualization of the right adrenal gland in abdominal CT imaging", + "CT scan of the right adrenal gland in the abdominal region", + "Presence of the right adrenal gland detected in abdominal CT images", + "Abdominal CT revealing right adrenal gland anatomy" + ], + "12": [ + "Left adrenal gland", + "Left adrenal gland in abdominal CT", + "CT imaging of the left adrenal gland in the abdomen", + "Abdominal CT showing left adrenal gland structures", + "Left adrenal gland observed in abdominal CT scans", + "Visualization of the left adrenal gland in abdominal CT imaging", + "CT scan of the left adrenal gland in the abdominal region", + "Presence of the left adrenal gland detected in abdominal CT images", + "Abdominal CT revealing left adrenal gland anatomy" + ], + "13": [ + "Duodenum", + "Duodenum in abdominal CT", + "CT imaging of the duodenum in the abdomen", + "Abdominal CT showing duodenal structures", + "Duodenum observed in abdominal CT scans", + "Visualization of the duodenum in abdominal CT imaging", + "CT scan of the duodenum in the abdominal region", + "Presence of the duodenum detected in abdominal CT images", + "Abdominal CT revealing duodenum anatomy" + ], + "14": [ + "Bladder", + "Bladder in abdominal CT", + "CT imaging of the bladder in the abdomen", + "Abdominal CT showing bladder structures", + "Bladder observed in abdominal CT scans", + "Visualization of the bladder in abdominal CT imaging", + "CT scan of the bladder in the abdominal region", + "Presence of the bladder detected in abdominal CT images", + "Abdominal CT revealing bladder anatomy" + ], + "15": [ + "Prostate/uterus", + "Prostate/uterus in abdominal CT", + "CT imaging of the prostate/uterus in the abdomen", + "Abdominal CT showing prostate/uterus structures", + "Prostate/uterus observed in abdominal CT scans", + "Visualization of the prostate/uterus in abdominal CT imaging", + "CT scan of the prostate/uterus in the abdominal region", + 
"Presence of the prostate/uterus detected in abdominal CT images", + "Abdominal CT revealing prostate/uterus anatomy" + ], + "instance_label": 0 + }, + "CT_Aorta": { + "1": [ + "Aortic vessel trees", + "Aortic vessel trees in whole-body CTA", + "CTA imaging of aortic vessel trees throughout the body", + "Whole-body CTA showing aortic vessel tree structures", + "Aortic vessel trees observed in whole-body CTA scans", + "Visualization of aortic vessel trees in CTA of the entire body", + "CTA scan revealing aortic vessel trees in the whole body", + "Aortic vessel tree assessment using whole-body CTA", + "Presence of aortic vessel trees detected in whole-body CTA imaging", + "Aortic vessel tree anatomy shown in whole-body CTA" + ], + "instance_label": 0 + }, + "CT_ColonTumor": { + "1": [ + "Colon cancer primaries", + "Colon cancer primaries in abdominal CT", + "CT imaging of colon cancer primaries in the abdomen", + "Abdominal CT showing colon cancer primary tumors", + "Colon cancer primaries observed in abdominal CT scans", + "Visualization of colon cancer primaries in abdominal CT imaging", + "CT scan of colon cancer primaries in the abdominal region", + "Presence of colon cancer primaries detected in abdominal CT images", + "Colon cancer primary tumors assessed using abdominal CT", + "Abdominal CT revealing colon cancer primaries" + ], + "instance_label": 1 + }, + "CT_COVID19-Infection": { + "1": [ + "COVID-19 infection", + "COVID-19 infection detected in chest CT scans", + "Thoracic CT imaging revealing COVID-19 involvement", + "Pulmonary manifestations of COVID-19 in CT imaging", + "COVID-19-associated parenchymal abnormalities in chest CT", + "COVID-19-related pathology characterized in chest tomographic scans", + "Radiologic evidence of COVID-19 infection in pulmonary CT studies", + "Diffuse COVID-19 lung involvement documented in CT imaging", + "Tomographic evidence of COVID-19 infection in the chest", + "Lung involvement by COVID-19 on computed tomography" + ], 
+ "instance_label": 1 + }, + "CT_HaN-Seg": { + "1": [ + "Left carotid artery", + "Left carotid artery in head and neck CT", + "CT imaging of the left carotid artery in the head and neck region", + "Head and neck CT demonstrating the left carotid artery", + "Visualization of the left carotid artery in head and neck CT imaging", + "CT scan of the head and neck showing the left carotid artery", + "Left carotid artery segmentation in head and neck CT images", + "Assessment of the left carotid artery using head and neck CT" + ], + "10": [ + "Cricopharyngeal inlet", + "Cricopharyngeal inlet in head and neck CT", + "CT imaging of the cricopharyngeal inlet in the head and neck region", + "Head and neck CT showing the cricopharyngeal inlet", + "Cricopharyngeal inlet observed in head and neck CT scans", + "Visualization of the cricopharyngeal inlet in head and neck CT imaging", + "CT scan of the cricopharyngeal inlet in the head and neck", + "Presence of the cricopharyngeal inlet detected in head and neck CT images", + "Cricopharyngeal inlet segmentation in head and neck CT imaging" + ], + "11": [ + "Cervical esophagus", + "Cervical esophagus in Head and Neck CT imaging", + "Cervical esophagus delineated in Head and Neck CT scans", + "Cervical esophagus segmentation target in Head and Neck CT studies", + "Cervical esophagus anatomical structure visualized via Head and Neck CT", + "Cervical esophagus depicted on Head and Neck CT imaging", + "Cervical esophagus localization within Head and Neck CT examinations", + "Cervical esophagus imaged in Head and Neck CT scans", + "Cervical esophagus boundaries assessed in Head and Neck CT" + ], + "12": [ + "Left anterior segment of the eyeball", + "Left anterior segment of the eyeball in head and neck CT", + "CT imaging of the left anterior segment of the eyeball in the head and neck region", + "Head and neck CT showing left anterior eyeball segment structures", + "Visualization of the left anterior segment of the eyeball in head and 
neck CT scans", + "Left anterior eyeball segment observed in head and neck CT imaging", + "CT scan of the head and neck depicting the left anterior segment of the eyeball", + "Left anterior eyeball segment segmentation in head and neck CT", + "Left anterior eyeball segment delineation in head and neck CT imaging", + "CT-based localization of the left anterior segment of the eyeball in the head and neck" + ], + "13": [ + "Right anterior segment of the eyeball", + "Right anterior segment of the eyeball in head and neck CT", + "CT imaging of the right anterior segment of the eyeball in the head and neck region", + "Head and neck CT showing right anterior eyeball segment structures", + "Visualization of the right anterior segment of the eyeball in head and neck CT scans", + "Right anterior eyeball segment observed in head and neck CT imaging", + "CT scan of the head and neck highlighting the right anterior segment of the eyeball", + "Right anterior eyeball segment segmentation in head and neck CT", + "Right anterior eyeball segment delineation in head and neck CT imaging", + "CT-based identification of the right anterior segment of the eyeball in the head and neck" + ], + "14": [ + "Left posterior segment of the eyeball", + "Left posterior segment of the eyeball in head and neck CT", + "CT imaging of the left posterior segment of the eyeball in the head and neck region", + "Head and neck CT showing left posterior eyeball segment structures", + "Visualization of the left posterior segment of the eyeball in head and neck CT scans", + "Left posterior eyeball segment observed in head and neck CT imaging", + "CT scan of the head and neck depicting the left posterior segment of the eyeball", + "Left posterior eyeball segment segmentation in head and neck CT", + "Left posterior eyeball segment delineation in head and neck CT imaging", + "CT-based localization of the left posterior segment of the eyeball in the head and neck" + ], + "15": [ + "Right posterior segment of the 
eyeball", + "Right posterior segment of the eyeball in head and neck CT", + "CT imaging of the right posterior segment of the eyeball in the head and neck region", + "Head and neck CT showing right posterior eyeball segment structures", + "Visualization of the right posterior segment of the eyeball in head and neck CT scans", + "Right posterior eyeball segment observed in head and neck CT imaging", + "CT scan of the head and neck highlighting the right posterior segment of the eyeball", + "Right posterior eyeball segment segmentation in head and neck CT", + "Right posterior eyeball segment delineation in head and neck CT imaging", + "CT-based identification of the right posterior segment of the eyeball in the head and neck" + ], + "16": [ + "Left lacrimal gland", + "Left lacrimal gland in Head and Neck CT imaging", + "Left lacrimal gland anatomical segmentation in Head and Neck CT scans", + "Left lacrimal gland localization in Head and Neck CT studies", + "Left lacrimal gland boundaries in Head and Neck CT examinations", + "Left lacrimal gland visualization in Head and Neck CT datasets", + "Head and neck CT showing the left lacrimal gland", + "Left lacrimal gland observed in head and neck CT scans", + "Visualization of the left lacrimal gland in head and neck CT imaging", + "CT scan of the left lacrimal gland in the head and neck", + "Presence of the left lacrimal gland detected in head and neck CT images" + ], + "17": [ + "Right lacrimal gland", + "Right lacrimal gland in Head and Neck CT imaging", + "Right lacrimal gland anatomical segmentation in Head and Neck CT scans", + "Right lacrimal gland localization in Head and Neck CT studies", + "Right lacrimal gland boundaries in Head and Neck CT examinations", + "Right lacrimal gland visualization in Head and Neck CT datasets", + "Head and neck CT showing the right lacrimal gland", + "Right lacrimal gland observed in head and neck CT scans", + "Visualization of the right lacrimal gland in head and neck CT imaging", + 
"CT scan of the right lacrimal gland in the head and neck", + "Presence of the right lacrimal gland detected in head and neck CT images" + ], + "18": [ + "Left submandibular gland", + "Left submandibular gland in head and neck CT", + "CT imaging of the left submandibular gland in the head and neck region", + "Head and neck CT showing left submandibular gland structures", + "Visualization of the left submandibular gland in head and neck CT scans", + "Left submandibular gland observed in head and neck CT imaging", + "CT scan of the head and neck depicting the left submandibular gland", + "Left submandibular gland segmentation in head and neck CT", + "Left submandibular gland delineation in head and neck CT imaging", + "CT-based localization of the left submandibular gland in the head and neck" + ], + "19": [ + "Right submandibular gland", + "Right submandibular gland in head and neck CT", + "CT imaging of the right submandibular gland in the head and neck region", + "Head and neck CT showing right submandibular gland structures", + "Visualization of the right submandibular gland in head and neck CT scans", + "Right submandibular gland observed in head and neck CT imaging", + "CT scan of the head and neck highlighting the right submandibular gland", + "Right submandibular gland segmentation in head and neck CT", + "Right submandibular gland delineation in head and neck CT imaging", + "CT-based identification of the right submandibular gland in the head and neck" + ], + "2": [ + "Right carotid artery", + "Right carotid artery in Head and Neck CT imaging", + "Right carotid artery visualized in Head and Neck CT scans", + "Right carotid artery delineation via Head and Neck CT", + "Right carotid artery as a segmentation target in Head and Neck CT studies", + "Right carotid artery anatomic structure in Head and Neck CT", + "Right carotid artery depicted on Head and Neck CT imaging", + "Right carotid artery imaged within the Head and Neck via CT", + "Right carotid artery 
localization in Head and Neck CT examinations", + "Right carotid artery observed in Head and Neck CT scans", + "Right carotid artery visualization in Head and Neck CT" + ], + "20": [ + "Thyroid", + "Thyroid in head and neck CT", + "CT imaging of the thyroid in the head and neck region", + "Head and neck CT showing thyroid structures", + "Visualization of the thyroid in head and neck CT scans", + "Thyroid observed in head and neck CT imaging", + "CT scan of the head and neck depicting the thyroid", + "Thyroid segmentation in head and neck CT", + "Thyroid delineation in head and neck CT imaging", + "CT-based localization of the thyroid in the head and neck" + ], + "21": [ + "Larynx-glottis", + "Larynx-glottis in Head and Neck CT imaging", + "Larynx-glottis delineation in Head and Neck CT scans", + "Larynx-glottis segmentation target in Head and Neck CT studies", + "Larynx-glottis anatomical boundaries in Head and Neck CT", + "Larynx-glottis spatial localization in Head and Neck CT", + "Larynx-glottis visualized in Head and Neck CT datasets", + "CT imaging of the larynx-glottis in the head and neck region", + "Head and neck CT demonstrating the larynx-glottis" + ], + "22": [ + "Larynx-supraglottic", + "Larynx-supraglottic in Head and Neck CT imaging", + "Larynx-supraglottic anatomical demarcation in Head and Neck CT scans", + "Larynx-supraglottic segmentation in Head and Neck CT", + "Larynx-supraglottic spatial localization in Head and Neck CT", + "Larynx-supraglottic morphological delineation in Head and Neck CT", + "Larynx-supraglottic depicted in Head and Neck CT datasets", + "CT imaging of the larynx-supraglottic in the head and neck region", + "Head and neck CT showing the larynx-supraglottic structures" + ], + "23": [ + "Lips", + "Lips in Head and Neck CT imaging", + "Lips anatomical delineation in Head and Neck CT scans", + "Lips segmentation in Head and Neck CT", + "Lips visualization in Head and Neck CT", + "CT imaging of the lips in the head and neck 
region", + "Head and neck CT showing lip structures", + "Visualization of the lips in head and neck CT scans", + "Lips observed in head and neck CT imaging", + "CT scan of the head and neck showing the lips" + ], + "24": [ + "Optic chiasm", + "Optic chiasm in Head and Neck CT imaging", + "Optic chiasm anatomical demarcation in Head and Neck CT scans", + "Optic chiasm segmentation in Head and Neck CT", + "Optic chiasm morphological boundaries in Head and Neck CT imaging", + "Optic chiasm depicted in Head and Neck CT datasets", + "Optic chiasm structural delineation in Head and Neck CT", + "CT imaging of the optic chiasm in the head and neck region", + "Head and neck CT showing optic chiasm structures" + ], + "25": [ + "Left optic nerve", + "Left optic nerve in Head and Neck CT imaging", + "Left optic nerve anatomical course in Head and Neck CT scans", + "Left optic nerve segmentation in Head and Neck CT", + "Left optic nerve spatial localization in Head and Neck CT", + "Left optic nerve morphological boundaries in Head and Neck CT imaging", + "Left optic nerve structural delineation in Head and Neck CT", + "Left optic nerve depicted in Head and Neck CT datasets", + "Left optic nerve visualization in Head and Neck CT" + ], + "26": [ + "Right optic nerve", + "Right optic nerve in Head and Neck CT imaging", + "Right optic nerve anatomical course in Head and Neck CT scans", + "Right optic nerve segmentation focus in Head and Neck CT studies", + "Right optic nerve spatial localization in Head and Neck CT", + "Right optic nerve volumetric assessment via Head and Neck CT", + "Right optic nerve morphological boundaries in Head and Neck CT imaging", + "Right optic nerve structural delineation in Head and Neck CT", + "Right optic nerve depicted in Head and Neck CT datasets", + "Right optic nerve visualization in Head and Neck CT", + "Right optic nerve 3D reconstruction from Head and Neck CT" + ], + "27": [ + "Left parotid gland", + "Left parotid gland in Head and Neck CT", + 
"Left parotid gland segmentation in CT", + "Left parotid gland identification in CT imaging", + "Left parotid gland boundaries in CT", + "Left parotid gland localization in Head and Neck CT", + "Left parotid gland demarcation in CT", + "Left parotid gland structure in CT", + "Left parotid gland visualization in Head and Neck CT" + ], + "28": [ + "Right parotid gland", + "Right parotid gland in Head and Neck CT", + "Right parotid gland segmentation in CT", + "Right parotid gland identification in CT imaging", + "Right parotid gland boundaries in CT", + "Right parotid gland localization in Head and Neck CT", + "Right parotid gland demarcation in CT", + "Right parotid gland structure in CT", + "Right parotid gland visualization in Head and Neck CT" + ], + "29": [ + "Pituitary gland", + "Pituitary gland in head and neck CT", + "CT imaging of the pituitary gland in the head and neck region", + "Head and neck CT showing pituitary gland structures", + "Visualization of the pituitary gland in head and neck CT scans", + "Pituitary gland observed in head and neck CT imaging", + "CT scan of the head and neck showing the pituitary gland", + "Pituitary gland segmentation in head and neck CT" + ], + "3": [ + "Arytenoids delineation", + "Arytenoids delineation in head and neck CT imaging", + "Segmentation of arytenoids on head and neck computed tomography scans", + "Detection of arytenoids within head and neck CT acquisitions", + "Arytenoids localization in head and neck CT studies", + "Visualization of arytenoids via head and neck CT protocols", + "Boundary demarcation of arytenoids on head and neck CT examinations", + "Characterization of arytenoids using head and neck CT sequences", + "Identification of arytenoids in head and neck CT datasets" + ], + "30": [ + "Spinal cord", + "Spinal cord in head and neck CT", + "CT imaging of the spinal cord in the head and neck region", + "Head and neck CT showing spinal cord structures", + "Visualization of the spinal cord in head and neck 
CT scans", + "Spinal cord observed in head and neck CT imaging", + "CT scan of the head and neck showing the spinal cord", + "Spinal cord segmentation in head and neck CT" + ], + "5": [ + "Brainstem", + "Brainstem in head and neck CT", + "CT imaging of the brainstem in the head and neck region", + "Head and neck CT scan visualizing the brainstem", + "Brainstem segmentation in head and neck CT images", + "Structural analysis of the brainstem in head and neck CT", + "Brainstem morphology observed in head and neck CT imaging" + ], + "6": [ + "Buccal mucosa", + "Buccal mucosa in head and neck CT", + "CT imaging of the buccal mucosa in the head and neck region", + "Head and neck CT showing buccal mucosa structures", + "Visualization of the buccal mucosa in head and neck CT scans", + "Buccal mucosa observed in head and neck CT imaging", + "CT scan of the head and neck highlighting the buccal mucosa", + "Buccal mucosa segmentation in head and neck CT", + "Buccal mucosa delineation in head and neck CT imaging", + "CT-based identification of the buccal mucosa in the head and neck" + ], + "7": [ + "Oral cavity delineation", + "Oral cavity delineation in Head and Neck CT", + "Anatomical boundaries of the oral cavity in CT imaging", + "Segmentation of the oral cavity in Head and Neck CT", + "CT-based demarcation of the oral cavity", + "Oral cavity segmentation in CT imaging", + "Oral cavity identification in Head and Neck CT", + "Oral cavity boundaries in CT", + "Oral cavity localization in CT", + "Oral cavity demarcation in Head and Neck CT", + "Oral cavity visualization in CT", + "Oral cavity structure in Head and Neck CT" + ], + "8": [ + "Left cochlea", + "Left cochlea in head and neck CT", + "CT imaging of the left cochlea in the head and neck region", + "Head and neck CT showing the left cochlea", + "Left cochlea observed in head and neck CT scans", + "Visualization of the left cochlea in head and neck CT imaging", + "CT scan of the left cochlea in the head and neck", + 
"Presence of the left cochlea detected in head and neck CT images", + "Left cochlea segmentation in head and neck CT imaging" + ], + "9": [ + "Right cochlea", + "Right cochlea in head and neck CT", + "CT imaging of the right cochlea in the head and neck region", + "Head and neck CT showing the right cochlea", + "Right cochlea observed in head and neck CT scans", + "Visualization of the right cochlea in head and neck CT imaging", + "CT scan of the right cochlea in the head and neck", + "Presence of the right cochlea detected in head and neck CT images", + "Right cochlea segmentation in head and neck CT imaging" + ], + "instance_label": 0 + }, + "CT_KidneyTumor": { + "1": [ + "Kidney lesions", + "Kidney lesions detected in abdominal CT imaging", + "Renal lesions identified via computed tomography of the abdomen", + "Focal kidney abnormalities visualized in cross-sectional abdominal scans", + "Lesions within the renal parenchyma characterized by abdominal CT", + "Kidney lesions mapped through tomographic imaging of the abdomen", + "Renal structural abnormalities documented in CT studies of the abdominal region", + "Kidney lesions demonstrating radiographic features on abdominal computed tomography", + "Intra-abdominal renal lesions assessed with contrast-enhanced CT protocols", + "Kidney lesions observed in multiphase abdominal CT scans", + "Pathological renal foci localized in abdominal tomographic imaging" + ], + "instance_label": 1 + }, + "CT_LiverTumor": { + "1": [ + "Liver tumors", + "Liver tumors detected in abdominal CT imaging", + "Hepatic neoplasms identified via computed tomography of the abdomen", + "Focal liver lesions visualized in cross-sectional abdominal scans", + "Liver tumors characterized by contrast-enhanced CT of the abdomen", + "Hepatic masses mapped through tomographic imaging of the abdominal cavity", + "Liver lesions documented in CT studies of the abdominal region", + "Liver tumors demonstrating radiographic features on abdominal computed 
tomography", + "Intra-abdominal hepatic neoplasms assessed with multiphase CT protocols", + "Liver lesions observed in volumetric abdominal CT acquisitions", + "Pathological hepatic foci localized in abdominal tomographic imaging" + ], + "instance_label": 1 + }, + "CT_LungLesion": { + "1": [ + "Lung lesions", + "Lung lesions identified in abdominal CT imaging", + "Lung lesions detected via computed tomography of the abdomen", + "Pulmonary lesions visualized in cross-sectional abdominal scans", + "Lung lesions characterized within the abdominal region through tomographic imaging", + "Lung lesions documented in CT studies of the abdomen", + "Lung lesions localized in abdominal computed tomography", + "Lung lesion observed in multiphase abdominal CT scans", + "Lung lesion demonstrating radiographic features on abdominal CT imaging", + "Pulmonary lesion mapped through contrast-enhanced abdominal CT", + "Lung lesion assessed in tomographic scans of the abdominal cavity" + ], + "instance_label": 1 + }, + "CT_Lungs": { + "1": [ + "Left lung", + "Left lung segmented in chest CT imaging", + "Left pulmonary structure identified via thoracic computed tomography", + "Left lung visualized in cross-sectional chest scans", + "Left-sided pulmonary anatomy characterized by chest CT", + "Left lung mapped through tomographic imaging of the thoracic cavity", + "Left lung documented in CT studies of the chest", + "Left lung demonstrating anatomical boundaries on thoracic computed tomography", + "Left pulmonary field assessed with high-resolution CT protocols", + "Left lung morphology observed in volumetric chest CT acquisitions", + "Left-sided intrathoracic lung structure localized in thoracic tomographic imaging" + ], + "2": [ + "Right lung", + "Right lung segmented in chest CT imaging", + "Right pulmonary structure identified via thoracic computed tomography", + "Right lung visualized in cross-sectional chest scans", + "Right-sided pulmonary anatomy characterized by chest CT", + 
"Right lung mapped through tomographic imaging of the thoracic cavity", + "Right lung documented in CT studies of the chest", + "Right lung demonstrating anatomical boundaries on thoracic computed tomography", + "Right pulmonary field assessed with high-resolution CT protocols", + "Right lung morphology observed in volumetric chest CT acquisitions", + "Right-sided intrathoracic lung structure localized in thoracic tomographic imaging" + ], + "instance_label": 0 + }, + "CT_LymphNode": { + "1": [ + "Lymph node", + "Lymph node detection in chest and abdomen CT", + "Lymph node identification within thoracoabdominal CT imaging", + "Lymph node localization in cross-sectional chest and abdominal scans", + "Mediastinal lymph node visualization in thoracic CT", + "Abdominal lymph node delineation in abdominal computed tomography", + "Pathological lymph node assessment in thoracoabdominal computed tomography", + "Mediastinal lymph node characterization in chest CT", + "Abdominal lymph node evaluation in abdominal cross-sectional imaging", + "Lymph node mapping in chest and abdomen CT" + ], + "instance_label": 1 + }, + "CT_PancreasTumor": { + "1": [ + "Pancreas tumors", + "Pancreas tumors detected in abdominal CT imaging", + "Pancreatic neoplasms identified via computed tomography of the abdomen", + "Focal pancreatic lesions visualized in cross-sectional abdominal scans", + "Pancreas tumors characterized by contrast-enhanced CT of the abdomen", + "Pancreatic masses mapped through tomographic imaging of the abdominal cavity", + "Pancreatic lesions documented in CT studies of the abdominal region", + "Pancreas tumors demonstrating radiographic features on abdominal computed tomography", + "Intra-abdominal pancreatic neoplasms assessed with multiphase CT protocols", + "Pancreatic lesions observed in volumetric abdominal CT acquisitions", + "Pathological pancreatic foci localized in abdominal tomographic imaging" + ], + "instance_label": 1 + }, + "CT_SegRap_HeadNeckTumor": { + 
"1": [ + "Head-neck cancer", + "Head-neck cancer in head and neck CT", + "CT imaging of head-neck cancer in the head and neck region", + "Head and neck CT showing head-neck cancer structures", + "Visualization of head-neck cancer in head and neck CT scans", + "Head-neck cancer observed in head and neck CT imaging", + "CT scan of the head and neck showing head-neck cancer", + "Head-neck cancer segmentation in head and neck CT", + "Abnormalities related to head-neck cancer detected by head and neck CT" + ], + "instance_label": 1 + }, + "CT_ThoracicOrgans-TCIA-LCTSC": { + "1": [ + "Esophagus", + "Esophagus segmented in chest CT imaging", + "Esophageal structure identified via thoracic computed tomography", + "Esophagus visualized in cross-sectional chest scans", + "Esophageal lumen characterized by contrast-enhanced CT of the chest", + "Esophagus mapped through tomographic imaging of the thoracic cavity", + "Esophageal anatomy documented in CT studies of the chest", + "Intrathoracic esophageal structure assessed with high-resolution CT protocols", + "Esophagus observed in volumetric chest CT acquisitions", + "Mediastinal esophageal segment localized in thoracic tomographic imaging" + ], + "2": [ + "Heart", + "Heart segmented in chest CT imaging", + "Cardiac structure identified via thoracic computed tomography", + "Heart visualized in cross-sectional chest scans", + "Myocardial anatomy characterized by contrast-enhanced CT of the chest", + "Heart mapped through tomographic imaging of the thoracic cavity", + "Cardiac anatomy documented in CT studies of the chest", + "Heart demonstrating anatomical boundaries on thoracic computed tomography", + "Intrathoracic cardiac structure assessed with high-resolution CT protocols", + "Heart observed in volumetric chest CT acquisitions", + "Mediastinal cardiac component localized in thoracic tomographic imaging" + ], + "3": [ + "Left lung", + "Left lung segmented in chest CT imaging", + "Left pulmonary structure identified via 
thoracic computed tomography", + "Left lung visualized in cross-sectional chest scans", + "Left-sided pulmonary anatomy characterized by chest CT", + "Left lung mapped through tomographic imaging of the thoracic cavity", + "Left lung documented in CT studies of the chest", + "Left lung demonstrating anatomical boundaries on thoracic computed tomography", + "Left pulmonary field assessed with high-resolution CT protocols", + "Left lung morphology observed in volumetric chest CT acquisitions", + "Left-sided intrathoracic lung structure localized in thoracic tomographic imaging" + ], + "4": [ + "Right lung", + "Right lung segmented in chest CT imaging", + "Right pulmonary structure identified via thoracic computed tomography", + "Right lung visualized in cross-sectional chest scans", + "Right-sided pulmonary anatomy characterized by chest CT", + "Right lung mapped through tomographic imaging of the thoracic cavity", + "Right lung documented in CT studies of the chest", + "Right lung demonstrating anatomical boundaries on thoracic computed tomography", + "Right pulmonary field assessed with high-resolution CT protocols", + "Right lung morphology observed in volumetric chest CT acquisitions", + "Right-sided intrathoracic lung structure localized in thoracic tomographic imaging" + ], + "5": [ + "Spinal cord", + "Spinal cord segmented in chest CT imaging", + "Spinal cord structure identified via thoracic computed tomography", + "Spinal cord visualized in cross-sectional chest scans", + "Thoracic spinal cord characterized by contrast-enhanced CT of the chest", + "Spinal cord mapped through tomographic imaging of the thoracic cavity", + "Spinal cord anatomy documented in CT studies of the chest region", + "Spinal cord demonstrating anatomical margins on thoracic computed tomography", + "Intrathoracic spinal cord assessed with high-resolution CT protocols", + "Spinal cord observed in volumetric chest CT acquisitions" + ], + "instance_label": 0 + }, + "CT_TotalSeg-vertebrae": 
{ + "1": [ + "Sacrum", + "Sacrum in whole body CT", + "CT imaging of the sacrum in the whole body region", + "Whole body CT showing sacrum structures", + "Visualization of the sacrum in whole body CT scans", + "Sacrum observed in whole body CT imaging", + "CT scan of the whole body depicting the sacrum", + "Sacrum segmentation in whole body CT", + "Sacrum delineation in whole body CT imaging", + "CT-based localization of the sacrum in the whole body" + ], + "2": [ + "Vertebrae S1", + "Vertebrae S1 in whole body CT", + "CT imaging of vertebrae S1 in the whole body region", + "Whole body CT showing vertebrae S1 structures", + "Visualization of vertebrae S1 in whole body CT scans", + "Vertebrae S1 observed in whole body CT imaging", + "CT scan of the whole body highlighting vertebrae S1", + "Vertebrae S1 segmentation in whole body CT", + "Vertebrae S1 delineation in whole body CT imaging", + "CT-based identification of vertebrae S1 in the whole body" + ], + "3": [ + "Vertebrae L5", + "Vertebrae L5 in whole body CT", + "CT imaging of vertebrae L5 in the whole body region", + "Whole body CT showing vertebrae L5 structures", + "Visualization of vertebrae L5 in whole body CT scans", + "Vertebrae L5 observed in whole body CT imaging", + "CT scan of the whole body depicting vertebrae L5", + "Vertebrae L5 segmentation in whole body CT", + "Vertebrae L5 delineation in whole body CT imaging", + "CT-based localization of vertebrae L5 in the whole body" + ], + "4": [ + "Vertebrae L4", + "Vertebrae L4 in whole body CT", + "CT imaging of vertebrae L4 in the whole body region", + "Whole body CT showing vertebrae L4 structures", + "Visualization of vertebrae L4 in whole body CT scans", + "Vertebrae L4 observed in whole body CT imaging", + "CT scan of the whole body highlighting vertebrae L4", + "Vertebrae L4 segmentation in whole body CT", + "Vertebrae L4 delineation in whole body CT imaging", + "CT-based identification of vertebrae L4 in the whole body" + ], + "5": [ + "Vertebrae 
L3", + "Vertebrae L3 in whole body CT", + "CT imaging of vertebrae L3 in the whole body region", + "Whole body CT showing vertebrae L3 structures", + "Visualization of vertebrae L3 in whole body CT scans", + "Vertebrae L3 observed in whole body CT imaging", + "CT scan of the whole body depicting vertebrae L3", + "Vertebrae L3 segmentation in whole body CT", + "Vertebrae L3 delineation in whole body CT imaging", + "CT-based localization of vertebrae L3 in the whole body" + ], + "6": [ + "Vertebrae L2", + "Vertebrae L2 in whole body CT", + "CT imaging of vertebrae L2 in the whole body region", + "Whole body CT showing vertebrae L2 structures", + "Visualization of vertebrae L2 in whole body CT scans", + "Vertebrae L2 observed in whole body CT imaging", + "CT scan of the whole body highlighting vertebrae L2", + "Vertebrae L2 segmentation in whole body CT", + "Vertebrae L2 delineation in whole body CT imaging", + "CT-based identification of vertebrae L2 in the whole body" + ], + "7": [ + "Vertebrae L1", + "Vertebrae L1 in whole body CT", + "CT imaging of vertebrae L1 in the whole body region", + "Whole body CT showing vertebrae L1 structures", + "Visualization of vertebrae L1 in whole body CT scans", + "Vertebrae L1 observed in whole body CT imaging", + "CT scan of the whole body depicting vertebrae L1", + "Vertebrae L1 segmentation in whole body CT", + "Vertebrae L1 delineation in whole body CT imaging", + "CT-based localization of vertebrae L1 in the whole body" + ], + "8": [ + "Vertebrae T12", + "Vertebrae T12 in whole body CT", + "CT imaging of vertebrae T12 in the whole body region", + "Whole body CT showing vertebrae T12 structures", + "Visualization of vertebrae T12 in whole body CT scans", + "Vertebrae T12 observed in whole body CT imaging", + "CT scan of the whole body highlighting vertebrae T12", + "Vertebrae T12 segmentation in whole body CT", + "Vertebrae T12 delineation in whole body CT imaging", + "CT-based identification of vertebrae T12 in the whole 
body" + ], + "9": [ + "Vertebrae T11", + "Vertebrae T11 in whole body CT", + "CT imaging of vertebrae T11 in the whole body region", + "Whole body CT showing vertebrae T11 structures", + "Visualization of vertebrae T11 in whole body CT scans", + "Vertebrae T11 observed in whole body CT imaging", + "CT scan of the whole body depicting vertebrae T11", + "Vertebrae T11 segmentation in whole body CT", + "Vertebrae T11 delineation in whole body CT imaging", + "CT-based localization of vertebrae T11 in the whole body" + ], + "10": [ + "Vertebrae T10", + "Vertebrae T10 in whole body CT", + "CT imaging of vertebrae T10 in the whole body region", + "Whole body CT showing vertebrae T10 structures", + "Visualization of vertebrae T10 in whole body CT scans", + "Vertebrae T10 observed in whole body CT imaging", + "CT scan of the whole body highlighting vertebrae T10", + "Vertebrae T10 segmentation in whole body CT", + "Vertebrae T10 delineation in whole body CT imaging", + "CT-based identification of vertebrae T10 in the whole body" + ], + "11": [ + "Vertebrae T9", + "Vertebrae T9 in whole body CT", + "CT imaging of vertebrae T9 in the whole body region", + "Whole body CT showing vertebrae T9 structures", + "Visualization of vertebrae T9 in whole body CT scans", + "Vertebrae T9 observed in whole body CT imaging", + "CT scan of the whole body depicting vertebrae T9", + "Vertebrae T9 segmentation in whole body CT", + "Vertebrae T9 delineation in whole body CT imaging", + "CT-based localization of vertebrae T9 in the whole body" + ], + "12": [ + "Vertebrae T8", + "Vertebrae T8 in whole body CT", + "CT imaging of vertebrae T8 in the whole body region", + "Whole body CT showing vertebrae T8 structures", + "Visualization of vertebrae T8 in whole body CT scans", + "Vertebrae T8 observed in whole body CT imaging", + "CT scan of the whole body highlighting vertebrae T8", + "Vertebrae T8 segmentation in whole body CT", + "Vertebrae T8 delineation in whole body CT imaging", + "CT-based 
identification of vertebrae T8 in the whole body" + ], + "13": [ + "Vertebrae T7", + "Vertebrae T7 in whole body CT", + "CT imaging of vertebrae T7 in the whole body region", + "Whole body CT showing vertebrae T7 structures", + "Visualization of vertebrae T7 in whole body CT scans", + "Vertebrae T7 observed in whole body CT imaging", + "CT scan of the whole body depicting vertebrae T7", + "Vertebrae T7 segmentation in whole body CT", + "Vertebrae T7 delineation in whole body CT imaging", + "CT-based localization of vertebrae T7 in the whole body" + ], + "14": [ + "Vertebrae T6", + "Vertebrae T6 in whole body CT", + "CT imaging of vertebrae T6 in the whole body region", + "Whole body CT showing vertebrae T6 structures", + "Visualization of vertebrae T6 in whole body CT scans", + "Vertebrae T6 observed in whole body CT imaging", + "CT scan of the whole body highlighting vertebrae T6", + "Vertebrae T6 segmentation in whole body CT", + "Vertebrae T6 delineation in whole body CT imaging", + "CT-based identification of vertebrae T6 in the whole body" + ], + "15": [ + "Vertebrae T5", + "Vertebrae T5 in whole body CT", + "CT imaging of vertebrae T5 in the whole body region", + "Whole body CT showing vertebrae T5 structures", + "Visualization of vertebrae T5 in whole body CT scans", + "Vertebrae T5 observed in whole body CT imaging", + "CT scan of the whole body depicting vertebrae T5", + "Vertebrae T5 segmentation in whole body CT", + "Vertebrae T5 delineation in whole body CT imaging", + "CT-based localization of vertebrae T5 in the whole body" + ], + "16": [ + "Vertebrae T4", + "Vertebrae T4 in whole body CT", + "CT imaging of vertebrae T4 in the whole body region", + "Whole body CT showing vertebrae T4 structures", + "Visualization of vertebrae T4 in whole body CT scans", + "Vertebrae T4 observed in whole body CT imaging", + "CT scan of the whole body highlighting vertebrae T4", + "Vertebrae T4 segmentation in whole body CT", + "Vertebrae T4 delineation in whole body CT 
imaging", + "CT-based identification of vertebrae T4 in the whole body" + ], + "17": [ + "Vertebrae T3", + "Vertebrae T3 in whole body CT", + "CT imaging of vertebrae T3 in the whole body region", + "Whole body CT showing vertebrae T3 structures", + "Visualization of vertebrae T3 in whole body CT scans", + "Vertebrae T3 observed in whole body CT imaging", + "CT scan of the whole body depicting vertebrae T3", + "Vertebrae T3 segmentation in whole body CT", + "Vertebrae T3 delineation in whole body CT imaging", + "CT-based localization of vertebrae T3 in the whole body" + ], + "18": [ + "Vertebrae T2", + "Vertebrae T2 in whole body CT", + "CT imaging of vertebrae T2 in the whole body region", + "Whole body CT showing vertebrae T2 structures", + "Visualization of vertebrae T2 in whole body CT scans", + "Vertebrae T2 observed in whole body CT imaging", + "CT scan of the whole body highlighting vertebrae T2", + "Vertebrae T2 segmentation in whole body CT", + "Vertebrae T2 delineation in whole body CT imaging", + "CT-based identification of vertebrae T2 in the whole body" + ], + "19": [ + "Vertebrae T1", + "Vertebrae T1 in whole body CT", + "CT imaging of vertebrae T1 in the whole body region", + "Whole body CT showing vertebrae T1 structures", + "Visualization of vertebrae T1 in whole body CT scans", + "Vertebrae T1 observed in whole body CT imaging", + "CT scan of the whole body depicting vertebrae T1", + "Vertebrae T1 segmentation in whole body CT", + "Vertebrae T1 delineation in whole body CT imaging", + "CT-based localization of vertebrae T1 in the whole body" + ], + "20": [ + "Vertebrae C7", + "Vertebrae C7 in whole body CT", + "CT imaging of vertebrae C7 in the whole body region", + "Whole body CT showing vertebrae C7 structures", + "Visualization of vertebrae C7 in whole body CT scans", + "Vertebrae C7 observed in whole body CT imaging", + "CT scan of the whole body highlighting vertebrae C7", + "Vertebrae C7 segmentation in whole body CT", + "Vertebrae C7 
delineation in whole body CT imaging", + "CT-based identification of vertebrae C7 in the whole body" + ], + "21": [ + "Vertebrae C6", + "Vertebrae C6 in whole body CT", + "CT imaging of vertebrae C6 in the whole body region", + "Whole body CT showing vertebrae C6 structures", + "Visualization of vertebrae C6 in whole body CT scans", + "Vertebrae C6 observed in whole body CT imaging", + "CT scan of the whole body depicting vertebrae C6", + "Vertebrae C6 segmentation in whole body CT", + "Vertebrae C6 delineation in whole body CT imaging", + "CT-based localization of vertebrae C6 in the whole body" + ], + "22": [ + "Vertebrae C5", + "Vertebrae C5 in whole body CT", + "CT imaging of vertebrae C5 in the whole body region", + "Whole body CT showing vertebrae C5 structures", + "Visualization of vertebrae C5 in whole body CT scans", + "Vertebrae C5 observed in whole body CT imaging", + "CT scan of the whole body highlighting vertebrae C5", + "Vertebrae C5 segmentation in whole body CT", + "Vertebrae C5 delineation in whole body CT imaging", + "CT-based identification of vertebrae C5 in the whole body" + ], + "23": [ + "Vertebrae C4", + "Vertebrae C4 in whole body CT", + "CT imaging of vertebrae C4 in the whole body region", + "Whole body CT showing vertebrae C4 structures", + "Visualization of vertebrae C4 in whole body CT scans", + "Vertebrae C4 observed in whole body CT imaging", + "CT scan of the whole body depicting vertebrae C4", + "Vertebrae C4 segmentation in whole body CT", + "Vertebrae C4 delineation in whole body CT imaging", + "CT-based localization of vertebrae C4 in the whole body" + ], + "24": [ + "Vertebrae C3", + "Vertebrae C3 in whole body CT", + "CT imaging of vertebrae C3 in the whole body region", + "Whole body CT showing vertebrae C3 structures", + "Visualization of vertebrae C3 in whole body CT scans", + "Vertebrae C3 observed in whole body CT imaging", + "CT scan of the whole body highlighting vertebrae C3", + "Vertebrae C3 segmentation in whole 
body CT", + "Vertebrae C3 delineation in whole body CT imaging", + "CT-based identification of vertebrae C3 in the whole body" + ], + "25": [ + "Vertebrae C2", + "Vertebrae C2 in whole body CT", + "CT imaging of vertebrae C2 in the whole body region", + "Whole body CT showing vertebrae C2 structures", + "Visualization of vertebrae C2 in whole body CT scans", + "Vertebrae C2 observed in whole body CT imaging", + "CT scan of the whole body depicting vertebrae C2", + "Vertebrae C2 segmentation in whole body CT", + "Vertebrae C2 delineation in whole body CT imaging", + "CT-based localization of vertebrae C2 in the whole body" + ], + "26": [ + "Vertebrae C1", + "Vertebrae C1 in whole body CT", + "CT imaging of vertebrae C1 in the whole body region", + "Whole body CT showing vertebrae C1 structures", + "Visualization of vertebrae C1 in whole body CT scans", + "Vertebrae C1 observed in whole body CT imaging", + "CT scan of the whole body highlighting vertebrae C1", + "Vertebrae C1 segmentation in whole body CT", + "Vertebrae C1 delineation in whole body CT imaging", + "CT-based identification of vertebrae C1 in the whole body" + ], + "instance_label": 0 + }, + "CT_TotalSeg_organs": { + "1": [ + "Spleen", + "Spleen in whole body CT", + "CT imaging of the spleen in the whole body region", + "Whole body CT showing spleen structures", + "Visualization of the spleen in whole body CT scans", + "Spleen observed in whole body CT imaging", + "CT scan of the whole body depicting the spleen", + "Spleen segmentation in whole body CT", + "Spleen delineation in whole body CT imaging", + "CT-based localization of the spleen in the whole body" + ], + "2": [ + "Right kidney", + "Right kidney in whole body CT", + "CT imaging of the right kidney in the whole body region", + "Whole body CT showing right kidney structures", + "Visualization of the right kidney in whole body CT scans", + "Right kidney observed in whole body CT imaging", + "CT scan of the whole body highlighting the right 
kidney", + "Right kidney segmentation in whole body CT", + "Right kidney delineation in whole body CT imaging", + "CT-based identification of the right kidney in the whole body" + ], + "3": [ + "Left kidney", + "Left kidney in whole body CT", + "CT imaging of the left kidney in the whole body region", + "Whole body CT showing left kidney structures", + "Visualization of the left kidney in whole body CT scans", + "Left kidney observed in whole body CT imaging", + "CT scan of the whole body depicting the left kidney", + "Left kidney segmentation in whole body CT", + "Left kidney delineation in whole body CT imaging", + "CT-based localization of the left kidney in the whole body" + ], + "4": [ + "Gallbladder", + "Gallbladder in whole body CT", + "CT imaging of the gallbladder in the whole body region", + "Whole body CT showing gallbladder structures", + "Visualization of the gallbladder in whole body CT scans", + "Gallbladder observed in whole body CT imaging", + "CT scan of the whole body highlighting the gallbladder", + "Gallbladder segmentation in whole body CT", + "Gallbladder delineation in whole body CT imaging", + "CT-based identification of the gallbladder in the whole body" + ], + "5": [ + "Liver", + "Liver in whole body CT", + "CT imaging of the liver in the whole body region", + "Whole body CT showing liver structures", + "Liver observed in whole body CT imaging", + "CT scan of the whole body depicting the liver", + "Liver segmentation in whole body CT", + "Liver delineation in whole body CT imaging", + "CT-based localization of the liver in the whole body" + ], + "6": [ + "Stomach", + "Stomach in whole body CT", + "CT imaging of the stomach in the whole body region", + "Whole body CT showing stomach structures", + "Visualization of the stomach in whole body CT scans", + "Stomach observed in whole body CT imaging", + "CT scan of the whole body highlighting the stomach", + "Stomach segmentation in whole body CT", + "Stomach delineation in whole body CT 
imaging", + "CT-based identification of the stomach in the whole body" + ], + "7": [ + "Pancreas", + "Pancreas in whole body CT", + "CT imaging of the pancreas in the whole body region", + "Whole body CT showing pancreas structures", + "Visualization of the pancreas in whole body CT scans", + "Pancreas observed in whole body CT imaging", + "CT scan of the whole body depicting the pancreas", + "Pancreas segmentation in whole body CT", + "Pancreas delineation in whole body CT imaging", + "CT-based localization of the pancreas in the whole body" + ], + "8": [ + "Right adrenal gland", + "Right adrenal gland in whole body CT", + "CT imaging of the right adrenal gland in the whole body region", + "Whole body CT showing right adrenal gland structures", + "Visualization of the right adrenal gland in whole body CT scans", + "Right adrenal gland observed in whole body CT imaging", + "CT scan of the whole body highlighting the right adrenal gland", + "Right adrenal gland segmentation in whole body CT", + "Right adrenal gland delineation in whole body CT imaging", + "CT-based identification of the right adrenal gland in the whole body" + ], + "9": [ + "Left adrenal gland", + "Left adrenal gland in whole body CT", + "CT imaging of the left adrenal gland in the whole body region", + "Whole body CT showing left adrenal gland structures", + "Visualization of the left adrenal gland in whole body CT scans", + "Left adrenal gland observed in whole body CT imaging", + "CT scan of the whole body depicting the left adrenal gland", + "Left adrenal gland segmentation in whole body CT", + "Left adrenal gland delineation in whole body CT imaging", + "CT-based localization of the left adrenal gland in the whole body" + ], + "10": [ + "Lung upper lobe left", + "Lung upper lobe left in whole body CT", + "CT imaging of the left upper lung lobe in the whole body region", + "Whole body CT showing left upper lung lobe structures", + "Visualization of the left upper lung lobe in whole body CT 
scans", + "Left upper lung lobe observed in whole body CT imaging", + "CT scan of the whole body highlighting the left upper lung lobe", + "Left upper lung lobe segmentation in whole body CT", + "Left upper lung lobe delineation in whole body CT imaging", + "CT-based identification of the left upper lung lobe in the whole body" + ], + "11": [ + "Lung lower lobe left", + "Lung lower lobe left in whole body CT", + "CT imaging of the left lower lung lobe in the whole body region", + "Whole body CT showing left lower lung lobe structures", + "Visualization of the left lower lung lobe in whole body CT scans", + "Left lower lung lobe observed in whole body CT imaging", + "CT scan of the whole body depicting the left lower lung lobe", + "Left lower lung lobe segmentation in whole body CT", + "Left lower lung lobe delineation in whole body CT imaging", + "CT-based localization of the left lower lung lobe in the whole body" + ], + "12": [ + "Lung upper lobe right", + "Lung upper lobe right in whole body CT", + "CT imaging of the right upper lung lobe in the whole body region", + "Whole body CT showing right upper lung lobe structures", + "Visualization of the right upper lung lobe in whole body CT scans", + "Right upper lung lobe observed in whole body CT imaging", + "CT scan of the whole body highlighting the right upper lung lobe", + "Right upper lung lobe segmentation in whole body CT", + "Right upper lung lobe delineation in whole body CT imaging", + "CT-based identification of the right upper lung lobe in the whole body" + ], + "13": [ + "Lung middle lobe right", + "Lung middle lobe right in whole body CT", + "CT imaging of the right middle lung lobe in the whole body region", + "Whole body CT showing right middle lung lobe structures", + "Visualization of the right middle lung lobe in whole body CT scans", + "Right middle lung lobe observed in whole body CT imaging", + "CT scan of the whole body depicting the right middle lung lobe", + "Right middle lung lobe 
segmentation in whole body CT", + "Right middle lung lobe delineation in whole body CT imaging", + "CT-based localization of the right middle lung lobe in the whole body" + ], + "14": [ + "Lung lower lobe right", + "Lung lower lobe right in whole body CT", + "CT imaging of the right lower lung lobe in the whole body region", + "Whole body CT showing right lower lung lobe structures", + "Visualization of the right lower lung lobe in whole body CT scans", + "Right lower lung lobe observed in whole body CT imaging", + "CT scan of the whole body highlighting the right lower lung lobe", + "Right lower lung lobe segmentation in whole body CT", + "Right lower lung lobe delineation in whole body CT imaging", + "CT-based identification of the right lower lung lobe in the whole body" + ], + "15": [ + "Esophagus", + "Esophagus in whole body CT", + "CT imaging of the esophagus in the whole body region", + "Whole body CT showing esophageal structures", + "Visualization of the esophagus in whole body CT scans", + "Esophagus observed in whole body CT imaging", + "CT scan of the whole body depicting the esophagus", + "Esophagus segmentation in whole body CT", + "Esophagus delineation in whole body CT imaging", + "CT-based localization of the esophagus in the whole body" + ], + "16": [ + "Trachea", + "Trachea in whole body CT", + "CT imaging of the trachea in the whole body region", + "Whole body CT showing tracheal structures", + "Visualization of the trachea in whole body CT scans", + "Trachea observed in whole body CT imaging", + "CT scan of the whole body highlighting the trachea", + "Trachea segmentation in whole body CT", + "Trachea delineation in whole body CT imaging", + "CT-based identification of the trachea in the whole body" + ], + "17": [ + "Thyroid gland", + "Thyroid gland in whole body CT", + "CT imaging of the thyroid gland in the whole body region", + "Whole body CT showing thyroid gland structures", + "Visualization of the thyroid gland in whole body CT scans", + 
"Thyroid gland observed in whole body CT imaging", + "CT scan of the whole body depicting the thyroid gland", + "Thyroid gland segmentation in whole body CT", + "Thyroid gland delineation in whole body CT imaging", + "CT-based localization of the thyroid gland in the whole body" + ], + "18": [ + "Small bowel", + "Small bowel in whole body CT", + "CT imaging of the small bowel in the whole body region", + "Whole body CT showing small bowel structures", + "Visualization of the small bowel in whole body CT scans", + "Small bowel observed in whole body CT imaging", + "CT scan of the whole body highlighting the small bowel", + "Small bowel segmentation in whole body CT", + "Small bowel delineation in whole body CT imaging", + "CT-based identification of the small bowel in the whole body" + ], + "19": [ + "Duodenum", + "Duodenum in whole body CT", + "CT imaging of the duodenum in the whole body region", + "Whole body CT showing duodenal structures", + "Visualization of the duodenum in whole body CT scans", + "Duodenum observed in whole body CT imaging", + "CT scan of the whole body depicting the duodenum", + "Duodenum segmentation in whole body CT", + "Duodenum delineation in whole body CT imaging", + "CT-based localization of the duodenum in the whole body" + ], + "20": [ + "Colon", + "Colon in whole body CT", + "CT imaging of the colon in the whole body region", + "Whole body CT showing colonic structures", + "Visualization of the colon in whole body CT scans", + "Colon observed in whole body CT imaging", + "CT scan of the whole body highlighting the colon", + "Colon segmentation in whole body CT", + "Colon delineation in whole body CT imaging", + "CT-based identification of the colon in the whole body" + ], + "21": [ + "Urinary_bladder", + "Urinary_bladder in whole body CT", + "CT imaging of the urinary bladder in the whole body region", + "Whole body CT showing urinary bladder structures", + "Visualization of the urinary bladder in whole body CT scans", + "Urinary 
bladder observed in whole body CT imaging", + "CT scan of the whole body depicting the urinary bladder", + "Urinary bladder segmentation in whole body CT", + "Urinary bladder delineation in whole body CT imaging", + "CT-based localization of the urinary bladder in the whole body" + ], + "22": [ + "Prostate", + "Prostate in whole body CT", + "CT imaging of the prostate in the whole body region", + "Whole body CT showing prostatic structures", + "Visualization of the prostate in whole body CT scans", + "Prostate observed in whole body CT imaging", + "CT scan of the whole body highlighting the prostate", + "Prostate segmentation in whole body CT", + "Prostate delineation in whole body CT imaging", + "CT-based identification of the prostate in the whole body" + ], + "23": [ + "Left kidney cyst", + "Left kidney cyst in whole body CT", + "CT imaging of the left kidney cyst in the whole body region", + "Whole body CT showing left kidney cyst structures", + "Visualization of the left kidney cyst in whole body CT scans", + "Left kidney cyst observed in whole body CT imaging", + "CT scan of the whole body depicting the left kidney cyst", + "Left kidney cyst segmentation in whole body CT", + "Left kidney cyst delineation in whole body CT imaging", + "CT-based localization of the left kidney cyst in the whole body" + ], + "24": [ + "Right kidney cyst", + "Right kidney cyst in whole body CT", + "CT imaging of the right kidney cyst in the whole body region", + "Whole body CT showing right kidney cyst structures", + "Visualization of the right kidney cyst in whole body CT scans", + "Right kidney cyst observed in whole body CT imaging", + "CT scan of the whole body highlighting the right kidney cyst", + "Right kidney cyst segmentation in whole body CT", + "Right kidney cyst delineation in whole body CT imaging", + "CT-based identification of the right kidney cyst in the whole body" + ], + "instance_label": 0 + }, + "CT_TotalSeg_muscles": { + "1": [ + "Left humerus", + "Left 
humerus in whole body CT", + "Left humeral bone in whole body CT imaging", + "CT visualization of the left humeral shaft", + "Whole body CT showing left proximal humerus", + "Left humeral diaphysis in CT", + "CT-based segmentation of the left humerus", + "CT scan depicting left humeral anatomy", + "Left humeral head localization in CT" + ], + "2": [ + "Right humerus", + "Right humerus in whole body CT", + "Right humeral bone in whole body CT imaging", + "CT visualization of the right humeral shaft", + "Whole body CT showing right proximal humerus", + "Right humeral diaphysis in CT", + "CT-based segmentation of the right humerus", + "CT scan depicting right humeral anatomy", + "Right humeral head localization in CT" + ], + "3": [ + "Left scapula", + "Left scapula in whole body CT", + "Left shoulder blade in CT imaging", + "CT visualization of the left scapular body", + "CT-based segmentation of the left scapula", + "CT scan depicting left scapular anatomy" + ], + "4": [ + "Right scapula", + "Right scapula in whole body CT", + "Right shoulder blade in CT imaging", + "CT visualization of the right scapular body", + "CT-based segmentation of the right scapula", + "CT scan depicting right scapular anatomy" + ], + "5": [ + "Left clavicula", + "Left clavicula in whole body CT", + "CT imaging of the left clavicula in the whole body region", + "Whole body CT showing left clavicular structures", + "Visualization of the left clavicula in whole body CT scans", + "Left clavicula observed in whole body CT imaging", + "CT scan of the whole body depicting the left clavicula", + "Left clavicula segmentation in whole body CT", + "Left clavicula delineation in whole body CT imaging", + "CT-based localization of the left clavicula in the whole body", + "Left clavicular anatomy in whole body CT" + ], + "6": [ + "Right clavicula", + "Right clavicula in whole body CT", + "CT imaging of the right clavicula in the whole body region", + "Whole body CT showing right clavicular structures", + 
"Visualization of the right clavicula in whole body CT scans", + "Right clavicula observed in whole body CT imaging", + "CT scan of the whole body highlighting the right clavicula", + "Right clavicula segmentation in whole body CT", + "Right clavicula delineation in whole body CT imaging", + "CT-based identification of the right clavicula in the whole body", + "Right clavicular anatomy in whole body CT" + ], + "7": [ + "Left femur", + "Left femur in whole body CT", + "CT imaging of the left femur in the whole body region", + "Whole body CT showing left femoral structures", + "Visualization of the left femur in whole body CT scans", + "Left femur observed in whole body CT imaging", + "CT scan of the whole body depicting the left femur", + "Left femur segmentation in whole body CT", + "Left femur delineation in whole body CT imaging", + "CT-based localization of the left femur in the whole body", + "Left femoral anatomy in whole body CT" + ], + "8": [ + "Right femur", + "Right femur in whole body CT", + "CT imaging of the right femur in the whole body region", + "Whole body CT showing right femoral structures", + "Visualization of the right femur in whole body CT scans", + "Right femur observed in whole body CT imaging", + "CT scan of the whole body highlighting the right femur", + "Right femur segmentation in whole body CT", + "Right femur delineation in whole body CT imaging", + "CT-based identification of the right femur in the whole body", + "Right femoral anatomy in whole body CT" + ], + "9": [ + "Left hip", + "Left hip in whole body CT", + "CT imaging of the left hip in the whole body region", + "Whole body CT showing left hip structures", + "Visualization of the left hip in whole body CT scans", + "Left hip observed in whole body CT imaging", + "CT scan of the whole body depicting the left hip", + "Left hip segmentation in whole body CT", + "Left hip delineation in whole body CT imaging", + "CT-based localization of the left hip in the whole body", + "Left hip 
anatomy in whole body CT" + ], + "10": [ + "Right hip", + "Right hip in whole body CT", + "CT imaging of the right hip in the whole body region", + "Whole body CT showing right hip structures", + "Visualization of the right hip in whole body CT scans", + "Right hip observed in whole body CT imaging", + "CT scan of the whole body highlighting the right hip", + "Right hip segmentation in whole body CT", + "Right hip delineation in whole body CT imaging", + "CT-based identification of the right hip in the whole body", + "Right hip anatomy in whole body CT" + ], + "11": [ + "Spinal cord", + "Spinal cord in whole body CT", + "CT imaging of the spinal cord in the whole body region", + "Whole body CT showing spinal cord structures", + "Visualization of the spinal cord in whole body CT scans", + "Spinal cord observed in whole body CT imaging", + "CT scan of the whole body depicting the spinal cord", + "Spinal cord segmentation in whole body CT", + "Spinal cord delineation in whole body CT imaging", + "CT-based localization of the spinal cord in the whole body", + "Spinal cord anatomy in whole body CT" + ], + "12": [ + "Left gluteus Maximus", + "Left gluteus Maximus in whole body CT", + "CT imaging of the left gluteus Maximus in the whole body region", + "Whole body CT showing left gluteus Maximus structures", + "Visualization of the left gluteus Maximus in whole body CT scans", + "Left gluteus Maximus observed in whole body CT imaging", + "CT scan of the whole body depicting the left gluteus Maximus", + "Left gluteus Maximus segmentation in whole body CT", + "Left gluteus Maximus delineation in whole body CT imaging", + "CT-based localization of the left gluteus Maximus in the whole body", + "Left gluteus Maximus anatomy in whole body CT" + ], + "13": [ + "Right gluteus maximus", + "Right gluteus maximus in whole body CT", + "CT imaging of the right gluteus maximus in the whole body region", + "Whole body CT showing right gluteus maximus structures", + "Visualization of 
the right gluteus maximus in whole body CT scans", + "Right gluteus maximus observed in whole body CT imaging", + "CT scan of the whole body highlighting the right gluteus maximus", + "Right gluteus maximus segmentation in whole body CT", + "Right gluteus maximus delineation in whole body CT imaging", + "CT-based identification of the right gluteus maximus in the whole body", + "Right gluteus maximus anatomy in whole body CT" + ], + "14": [ + "Left gluteus medius", + "Left gluteus medius in whole body CT", + "CT imaging of the left gluteus medius in the whole body region", + "Whole body CT showing left gluteus medius structures", + "Visualization of the left gluteus medius in whole body CT scans", + "Left gluteus medius observed in whole body CT imaging", + "CT scan of the whole body depicting the left gluteus medius", + "Left gluteus medius segmentation in whole body CT", + "Left gluteus medius delineation in whole body CT imaging", + "CT-based localization of the left gluteus medius in the whole body", + "Left gluteus medius anatomy in whole body CT" + ], + "15": [ + "Right gluteus medius", + "Right gluteus medius in whole body CT", + "CT imaging of the right gluteus medius in the whole body region", + "Whole body CT showing right gluteus medius structures", + "Visualization of the right gluteus medius in whole body CT scans", + "Right gluteus medius observed in whole body CT imaging", + "CT scan of the whole body highlighting the right gluteus medius", + "Right gluteus medius segmentation in whole body CT", + "Right gluteus medius delineation in whole body CT imaging", + "CT-based identification of the right gluteus medius in the whole body", + "Right gluteus medius anatomy in whole body CT" + ], + "16": [ + "Left gluteus minimus", + "Left gluteus minimus in whole body CT", + "CT imaging of the left gluteus minimus in the whole body region", + "Whole body CT showing left gluteus minimus structures", + "Visualization of the left gluteus minimus in whole body CT 
scans", + "Left gluteus minimus observed in whole body CT imaging", + "CT scan of the whole body depicting the left gluteus minimus", + "Left gluteus minimus segmentation in whole body CT", + "Left gluteus minimus delineation in whole body CT imaging", + "CT-based localization of the left gluteus minimus in the whole body", + "Left gluteus minimus anatomy in whole body CT" + ], + "17": [ + "Right gluteus minimus", + "Right gluteus minimus in whole body CT", + "CT imaging of the right gluteus minimus in the whole body region", + "Whole body CT showing right gluteus minimus structures", + "Visualization of the right gluteus minimus in whole body CT scans", + "Right gluteus minimus observed in whole body CT imaging", + "CT scan of the whole body highlighting the right gluteus minimus", + "Right gluteus minimus segmentation in whole body CT", + "Right gluteus minimus delineation in whole body CT imaging", + "CT-based identification of the right gluteus minimus in the whole body", + "Right gluteus minimus anatomy in whole body CT" + ], + "18": [ + "Left autochthon", + "Left autochthon in whole body CT", + "CT imaging of the left autochthon in the whole body region", + "Whole body CT showing left autochthon structures", + "Visualization of the left autochthon in whole body CT scans", + "Left autochthon observed in whole body CT imaging", + "CT scan of the whole body depicting the left autochthon", + "Left autochthon segmentation in whole body CT", + "Left autochthon delineation in whole body CT imaging", + "CT-based localization of the left autochthon in the whole body", + "Left autochthon anatomy in whole body CT" + ], + "19": [ + "Right autochthon", + "Right autochthon in whole body CT", + "CT imaging of the right autochthon in the whole body region", + "Whole body CT showing right autochthon structures", + "Visualization of the right autochthon in whole body CT scans", + "Right autochthon observed in whole body CT imaging", + "CT scan of the whole body highlighting 
the right autochthon", + "Right autochthon segmentation in whole body CT", + "Right autochthon delineation in whole body CT imaging", + "CT-based identification of the right autochthon in the whole body", + "Right autochthon anatomy in whole body CT" + ], + "20": [ + "Left iliopsoas", + "Left iliopsoas in whole body CT", + "CT imaging of the left iliopsoas in the whole body region", + "Whole body CT showing left iliopsoas structures", + "Visualization of the left iliopsoas in whole body CT scans", + "Left iliopsoas observed in whole body CT imaging", + "CT scan of the whole body depicting the left iliopsoas", + "Left iliopsoas segmentation in whole body CT", + "Left iliopsoas delineation in whole body CT imaging", + "CT-based localization of the left iliopsoas in the whole body", + "Left iliopsoas anatomy in whole body CT" + ], + "21": [ + "Right iliopsoas", + "Right iliopsoas in whole body CT", + "CT imaging of the right iliopsoas in the whole body region", + "Whole body CT showing right iliopsoas structures", + "Visualization of the right iliopsoas in whole body CT scans", + "Right iliopsoas observed in whole body CT imaging", + "CT scan of the whole body highlighting the right iliopsoas", + "Right iliopsoas segmentation in whole body CT", + "Right iliopsoas delineation in whole body CT imaging", + "CT-based identification of the right iliopsoas in the whole body", + "Right iliopsoas anatomy in whole body CT" + ], + "22": [ + "Brain", + "Brain in whole body CT", + "CT imaging of the brain in the whole body region", + "Whole body CT showing brain structures", + "Visualization of the brain in whole body CT scans", + "Brain observed in whole body CT imaging", + "CT scan of the whole body depicting the brain", + "Brain segmentation in whole body CT", + "Brain delineation in whole body CT imaging", + "CT-based localization of the brain in the whole body", + "Brain anatomy in whole body CT" + ], + "23": [ + "Skull", + "Skull in whole body CT", + "CT imaging of the 
skull in the whole body region", + "Whole body CT showing skull structures", + "Visualization of the skull in whole body CT scans", + "Skull observed in whole body CT imaging", + "CT scan of the whole body highlighting the skull", + "Skull segmentation in whole body CT", + "Skull delineation in whole body CT imaging", + "CT-based identification of the skull in the whole body", + "Skull anatomy in whole body CT" + ], + "instance_label": 0 + }, + "CT_TotalSeg_cardiac": { + "1": [ + "Heart", + "Heart in whole body CT", + "CT imaging of the heart in the whole body region", + "Whole body CT showing cardiac structures", + "Visualization of the heart in whole body CT scans", + "Heart observed in whole body CT imaging", + "CT scan of the whole body depicting the heart", + "Heart segmentation in whole body CT", + "Heart delineation in whole body CT imaging", + "CT-based localization of the heart in the whole body", + "Cardiac anatomy in whole body CT" + ], + "2": [ + "Aorta", + "Aorta in whole body CT", + "CT imaging of the aorta in the whole body region", + "Whole body CT showing aortic structures", + "Visualization of the aorta in whole body CT scans", + "Aorta observed in whole body CT imaging", + "CT scan of the whole body highlighting the aorta", + "Aorta segmentation in whole body CT", + "Aorta delineation in whole body CT imaging", + "CT-based identification of the aorta in the whole body" + ], + "3": [ + "Pulmonary vein", + "Pulmonary vein in whole body CT", + "CT imaging of the pulmonary vein in the whole body region", + "Whole body CT showing pulmonary venous structures", + "Visualization of the pulmonary vein in whole body CT scans", + "Pulmonary vein observed in whole body CT imaging", + "CT scan of the whole body depicting the pulmonary vein", + "Pulmonary vein segmentation in whole body CT", + "Pulmonary vein delineation in whole body CT imaging", + "CT-based localization of the pulmonary vein in the whole body", + "Pulmonary venous anatomy in whole body CT" + 
], + "4": [ + "Brachiocephalic trunk", + "Brachiocephalic trunk in whole body CT", + "CT imaging of the brachiocephalic trunk in the whole body region", + "Whole body CT showing brachiocephalic trunk structures", + "Visualization of the brachiocephalic trunk in whole body CT scans", + "Brachiocephalic trunk observed in whole body CT imaging", + "CT scan of the whole body highlighting the brachiocephalic trunk", + "Brachiocephalic trunk segmentation in whole body CT", + "Brachiocephalic trunk delineation in whole body CT imaging", + "CT-based identification of the brachiocephalic trunk in the whole body", + "Brachiocephalic trunk anatomy in whole body CT" + ], + "5": [ + "Right subclavian artery", + "Right subclavian artery in whole body CT", + "CT imaging of the right subclavian artery in the whole body region", + "Whole body CT showing right subclavian arterial structures", + "Visualization of the right subclavian artery in whole body CT scans", + "Right subclavian artery observed in whole body CT imaging", + "CT scan of the whole body depicting the right subclavian artery", + "Right subclavian artery segmentation in whole body CT", + "Right subclavian artery delineation in whole body CT imaging", + "CT-based localization of the right subclavian artery in the whole body", + "Right subclavian artery anatomy in whole body CT" + ], + "6": [ + "Left subclavian artery", + "Left subclavian artery in whole body CT", + "CT imaging of the left subclavian artery in the whole body region", + "Whole body CT showing left subclavian arterial structures", + "Visualization of the left subclavian artery in whole body CT scans", + "Left subclavian artery observed in whole body CT imaging", + "CT scan of the whole body highlighting the left subclavian artery", + "Left subclavian artery segmentation in whole body CT", + "Left subclavian artery delineation in whole body CT imaging", + "CT-based identification of the left subclavian artery in the whole body", + "Left subclavian artery 
anatomy in whole body CT" + ], + "7": [ + "Right common carotid artery", + "Right common carotid artery in whole body CT", + "CT imaging of the right common carotid artery in the whole body region", + "Whole body CT showing right common carotid arterial structures", + "Visualization of the right common carotid artery in whole body CT scans", + "Right common carotid artery observed in whole body CT imaging", + "CT scan of the whole body depicting the right common carotid artery", + "Right common carotid artery segmentation in whole body CT", + "Right common carotid artery delineation in whole body CT imaging", + "CT-based localization of the right common carotid artery in the whole body", + "Right common carotid artery anatomy in whole body CT" + ], + "8": [ + "Left common carotid artery", + "Left common carotid artery in whole body CT", + "CT imaging of the left common carotid artery in the whole body region", + "Whole body CT showing left common carotid arterial structures", + "Visualization of the left common carotid artery in whole body CT scans", + "Left common carotid artery observed in whole body CT imaging", + "CT scan of the whole body highlighting the left common carotid artery", + "Left common carotid artery segmentation in whole body CT", + "Left common carotid artery delineation in whole body CT imaging", + "CT-based identification of the left common carotid artery in the whole body", + "Left common carotid artery anatomy in whole body CT" + ], + "9": [ + "Left brachiocephalic vein", + "Left brachiocephalic vein in whole body CT", + "CT imaging of the left brachiocephalic vein in the whole body region", + "Whole body CT showing left brachiocephalic venous structures", + "Visualization of the left brachiocephalic vein in whole body CT scans", + "Left brachiocephalic vein observed in whole body CT imaging", + "CT scan of the whole body depicting the left brachiocephalic vein", + "Left brachiocephalic vein segmentation in whole body CT", + "Left 
brachiocephalic vein delineation in whole body CT imaging", + "CT-based localization of the left brachiocephalic vein in the whole body", + "Left brachiocephalic venous anatomy in whole body CT" + ], + "10": [ + "Right brachiocephalic vein", + "Right brachiocephalic vein in whole body CT", + "CT imaging of the right brachiocephalic vein in the whole body region", + "Whole body CT showing right brachiocephalic venous structures", + "Visualization of the right brachiocephalic vein in whole body CT scans", + "Right brachiocephalic vein observed in whole body CT imaging", + "CT scan of the whole body highlighting the right brachiocephalic vein", + "Right brachiocephalic vein segmentation in whole body CT", + "Right brachiocephalic vein delineation in whole body CT imaging", + "CT-based identification of the right brachiocephalic vein in the whole body", + "Right brachiocephalic venous anatomy in whole body CT" + ], + "11": [ + "Left atrial appendage", + "Left atrial appendage in whole body CT", + "CT imaging of the left atrial appendage in the whole body region", + "Whole body CT showing left atrial appendage structures", + "Visualization of the left atrial appendage in whole body CT scans", + "Left atrial appendage observed in whole body CT imaging", + "CT scan of the whole body depicting the left atrial appendage", + "Left atrial appendage segmentation in whole body CT", + "Left atrial appendage delineation in whole body CT imaging", + "CT-based localization of the left atrial appendage in the whole body", + "Left atrial appendage anatomy in whole body CT" + ], + "12": [ + "Superior vena cava", + "Superior vena cava in whole body CT", + "CT imaging of the superior vena cava in the whole body region", + "Whole body CT showing superior vena caval structures", + "Visualization of the superior vena cava in whole body CT scans", + "Superior vena cava observed in whole body CT imaging", + "CT scan of the whole body highlighting the superior vena cava", + "Superior vena 
cava segmentation in whole body CT", + "Superior vena cava delineation in whole body CT imaging", + "CT-based identification of the superior vena cava in the whole body", + "Superior vena caval anatomy in whole body CT" + ], + "13": [ + "Inferior vena cava", + "Inferior vena cava in whole body CT", + "CT imaging of the inferior vena cava in the whole body region", + "Whole body CT showing inferior vena caval structures", + "Visualization of the inferior vena cava in whole body CT scans", + "Inferior vena cava observed in whole body CT imaging", + "CT scan of the whole body depicting the inferior vena cava", + "Inferior vena cava segmentation in whole body CT", + "Inferior vena cava delineation in whole body CT imaging", + "CT-based localization of the inferior vena cava in the whole body", + "Inferior vena caval anatomy in whole body CT" + ], + "14": [ + "Portal vein and splenic vein", + "Portal vein and splenic vein in whole body CT", + "CT imaging of the portal vein and splenic vein in the whole body region", + "Whole body CT showing portal and splenic venous structures", + "Visualization of the portal vein and splenic vein in whole body CT scans", + "Portal vein and splenic vein observed in whole body CT imaging", + "CT scan of the whole body highlighting the portal vein and splenic vein", + "Portal vein and splenic vein segmentation in whole body CT", + "Portal vein and splenic vein delineation in whole body CT imaging", + "CT-based identification of the portal vein and splenic vein in the whole body", + "Portal and splenic venous anatomy in whole body CT" + ], + "15": [ + "Left iliac artery", + "Left iliac artery in whole body CT", + "CT imaging of the left iliac artery in the whole body region", + "Whole body CT showing left iliac arterial structures", + "Visualization of the left iliac artery in whole body CT scans", + "Left iliac artery observed in whole body CT imaging", + "CT scan of the whole body depicting the left iliac artery", + "Left iliac artery 
segmentation in whole body CT", + "Left iliac artery delineation in whole body CT imaging", + "CT-based localization of the left iliac artery in the whole body", + "Left iliac arterial anatomy in whole body CT" + ], + "16": [ + "Right iliac artery", + "Right iliac artery in whole body CT", + "CT imaging of the right iliac artery in the whole body region", + "Whole body CT showing right iliac arterial structures", + "Visualization of the right iliac artery in whole body CT scans", + "Right iliac artery observed in whole body CT imaging", + "CT scan of the whole body highlighting the right iliac artery", + "Right iliac artery segmentation in whole body CT", + "Right iliac artery delineation in whole body CT imaging", + "CT-based identification of the right iliac artery in the whole body", + "Right iliac arterial anatomy in whole body CT" + ], + "17": [ + "Left iliac vena", + "Left iliac vena in whole body CT", + "CT imaging of the left iliac vena in the whole body region", + "Whole body CT showing left iliac venous structures", + "Visualization of the left iliac vena in whole body CT scans", + "Left iliac vena observed in whole body CT imaging", + "CT scan of the whole body depicting the left iliac vena", + "Left iliac vena segmentation in whole body CT", + "Left iliac vena delineation in whole body CT imaging", + "CT-based localization of the left iliac vena in the whole body", + "Left iliac venous anatomy in whole body CT" + ], + "18": [ + "Right iliac vena", + "Right iliac vena in whole body CT", + "CT imaging of the right iliac vena in the whole body region", + "Whole body CT showing right iliac venous structures", + "Visualization of the right iliac vena in whole body CT scans", + "Right iliac vena observed in whole body CT imaging", + "CT scan of the whole body highlighting the right iliac vena", + "Right iliac vena segmentation in whole body CT", + "Right iliac vena delineation in whole body CT imaging", + "CT-based identification of the right iliac vena in 
the whole body", + "Right iliac venous anatomy in whole body CT" + ], + "instance_label": 0 + }, + "CT_AbdomenAtlas": { + "1": [ + "Aorta", + "Aorta in abdominal CT", + "CT imaging of the aorta in the abdomen", + "Abdominal CT showing aortic structures", + "Aorta observed in abdominal CT scans", + "Visualization of the aorta in abdominal CT imaging", + "CT scan of the aorta in the abdominal region", + "Presence of the aorta detected in abdominal CT images", + "Abdominal CT revealing aorta anatomy" + ], + "2": [ + "Gallbladder", + "Gallbladder in abdominal CT", + "CT imaging of the gallbladder in the abdomen", + "Abdominal CT showing gallbladder structures", + "Gallbladder observed in abdominal CT scans", + "Visualization of the gallbladder in abdominal CT imaging", + "CT scan of the gallbladder in the abdominal region", + "Presence of the gallbladder detected in abdominal CT images", + "Abdominal CT revealing gallbladder anatomy" + ], + "3": [ + "Left kidney", + "Left kidney in abdominal CT", + "CT imaging of the left kidney in the abdomen", + "Abdominal CT showing the left kidney", + "Left kidney observed in abdominal CT scans", + "Visualization of the left kidney in abdominal CT imaging", + "CT scan of the left kidney in the abdominal region", + "Presence of the left kidney detected in abdominal CT images", + "Abdominal CT revealing the left kidney anatomy" + ], + "4": [ + "Right kidney", + "Right kidney in abdominal CT", + "CT imaging of the right kidney in the abdomen", + "Abdominal CT showing the right kidney", + "Right kidney observed in abdominal CT scans", + "CT scan of the right kidney in the abdominal region", + "Presence of the right kidney detected in abdominal CT images", + "Abdominal CT revealing the right kidney anatomy" + ], + "5": [ + "Liver", + "Liver in abdominal CT", + "CT imaging of the liver in the abdomen", + "Abdominal CT showing liver structures", + "Liver detected in abdominal CT scans", + "Visualization of the liver in abdominal CT 
imaging", + "CT scan of the liver in the abdominal region", + "Presence of liver tissue in abdominal CT images", + "Abdominal CT revealing liver anatomy" + ], + "6": [ + "Pancreas", + "Pancreas in abdominal CT", + "CT imaging of the pancreas in the abdomen", + "Abdominal CT showing pancreatic structures", + "Pancreas detected in abdominal CT scans", + "Visualization of the pancreas in abdominal CT imaging", + "CT scan of the pancreas in the abdominal region", + "Presence of pancreatic tissue in abdominal CT images", + "Abdominal CT revealing pancreas anatomy" + ], + "7": [ + "Postcava", + "Postcava in Abdomen CT", + "CT imaging of the postcava in the abdominal region", + "Visualization of the postcava in abdominal CT scans", + "Postcava observed in abdominal CT imaging", + "CT scan of the abdomen depicting the postcava", + "Postcava segmentation in Abdomen CT", + "Postcava delineation in abdominal CT imaging", + "CT-based localization of the postcava in the abdomen" + ], + "8": [ + "Spleen", + "Abdominal CT revealing spleen structures", + "Spleen detected in abdominal CT scans", + "CT imaging of the spleen within the abdomen", + "Spleen anatomy visualized in abdominal CT", + "Presence of spleen tissue observed in abdominal CT imaging", + "CT scan showing spleen in the abdominal cavity", + "Abdominal CT assessment of the spleen", + "Spleen observed in CT imaging of the abdomen" + ], + "9": [ + "Stomach", + "Stomach in abdominal CT", + "CT imaging of the stomach in the abdomen", + "Abdominal CT showing stomach structures", + "Stomach observed in abdominal CT scans", + "Visualization of the stomach in abdominal CT imaging", + "CT scan of the stomach in the abdominal region", + "Presence of the stomach detected in abdominal CT images", + "Abdominal CT revealing stomach anatomy" + ], + "10": [ + "Left adrenal gland", + "Left adrenal gland in abdominal CT", + "CT imaging of the left adrenal gland in the abdomen", + "Abdominal CT showing left adrenal gland structures", + 
"Left adrenal gland observed in abdominal CT scans", + "Visualization of the left adrenal gland in abdominal CT imaging", + "CT scan of the left adrenal gland in the abdominal region", + "Presence of the left adrenal gland detected in abdominal CT images", + "Abdominal CT revealing left adrenal gland anatomy" + ], + "11": [ + "Right adrenal gland", + "Right adrenal gland in abdominal CT", + "CT imaging of the right adrenal gland in the abdomen", + "Abdominal CT showing right adrenal gland structures", + "Right adrenal gland observed in abdominal CT scans", + "Visualization of the right adrenal gland in abdominal CT imaging", + "CT scan of the right adrenal gland in the abdominal region", + "Presence of the right adrenal gland detected in abdominal CT images", + "Abdominal CT revealing right adrenal gland anatomy" + ], + "12": [ + "Bladder", + "Bladder in abdominal CT", + "CT imaging of the bladder in the abdomen", + "Abdominal CT showing bladder structures", + "Bladder observed in abdominal CT scans", + "Visualization of the bladder in abdominal CT imaging", + "CT scan of the bladder in the abdominal region", + "Presence of the bladder detected in abdominal CT images", + "Abdominal CT revealing bladder anatomy" + ], + "13": [ + "Esophagus", + "Esophagus in abdominal CT", + "CT imaging of the esophagus in the abdomen", + "Abdominal CT showing esophagus structures", + "Esophagus observed in abdominal CT scans", + "Visualization of the esophagus in abdominal CT imaging", + "CT scan of the esophagus in the abdominal region", + "Presence of the esophagus detected in abdominal CT images", + "Abdominal CT revealing esophagus anatomy" + ], + "instance_label": 0 + }, + "CT_WholeBodyTumor": { + "1": [ + "Lesion", + "Lesion in whole body CT", + "CT imaging of the lesion in the whole body region", + "Visualization of the lesion in whole body CT scans", + "Lesion observed in whole body CT imaging", + "Lesion segmentation in whole body CT", + "Lesion delineation in whole body CT 
imaging", + "CT-based localization of the lesion in the whole body", + "Detection of lesion via whole body CT", + "Lesion identification in whole body CT imaging" + ], + "instance_label": 1 + }, + "CT_AirwayTree": { + "2": [ + "Airway Tree", + "Airway tree in chest CT", + "CT imaging of the airway tree in the chest region", + "Visualization of the airway tree in chest CT scans", + "Airway tree segmentation in chest CT imaging", + "Delineation of bronchial branches within the chest CT" + ], + "instance_label": 0 + }, + "CT_TopCoW24": { + "1": [ + "Circle of Willis", + "Circle of Willis in Brain CTA", + "CTA imaging of Circle of Willis in Brain", + "Visualization of Circle of Willis in Brain Computed Tomography Angiography", + "Circle of Willis observed in Brain CTA", + "Segmentation of Circle of Willis in Brain Computed Tomography Angiography", + "Delineation of Circle of Willis vasculature in Brain CTA imaging", + "Identification of Circle of Willis in Brain Computed Tomography Angiography", + "Localization of Circle of Willis in Brain CTA", + "Characterization of Circle of Willis vascular architecture in Brain CTA" + ], + "instance_label": 1 + }, + "CT_Aeropath": { + "1": [ + "Airways", + "Airway segmented in chest computed tomography", + "Airway visualized in chest CT datasets", + "Airway identified in thoracic CT scans", + "Airway resolved in chest using computed tomography", + "Airway observed in thoracic cavity under CT", + "Airway localized in thoracic region with computed tomography", + "Airway appearing in chest CT volumetric scans" + ], + "instance_label": 0 + }, + "MR_AMOS": { + "1": [ + "Spleen", + "Splenic structure identified via abdominal magnetic resonance", + "Spleen visualized in multiplanar abdominal MRI scans", + "Splenic parenchyma characterized by abdominal MRI", + "Spleen mapped through volumetric MR imaging of the abdomen", + "Splenic morphology documented in abdominal MR studies", + "Spleen demonstrating signal intensity on abdominal 
magnetic resonance", + "Intra-abdominal splenic architecture assessed with contrast-enhanced MRI", + "Splenic volume quantified in abdominal MR volumetry" + ], + "2": [ + "Right kidney", + "Right renal structure identified via abdominal magnetic resonance", + "Right kidney visualized in abdominal MRI scans", + "Right kidney mapped through MR volumetry of the abdomen", + "Right renal morphology documented in abdominal MR protocols", + "Right renal parenchyma assessed with functional MR sequences", + "Right nephric architecture localized in abdominal MR tomography" + ], + "3": [ + "Left kidney", + "Left renal structure identified via abdominal magnetic resonance", + "Left kidney visualized in axial abdominal MRI scans", + "Left renal anatomy characterized by abdominal MRI", + "Left renal morphology documented in abdominal MR studies", + "Left renal parenchyma assessed with MR spectroscopy", + "Left nephric architecture localized in abdominal MR volumetry" + ], + "4": [ + "Gallbladder", + "Gallbladder identified via abdominal magnetic resonance", + "Gallbladder anatomy characterized by abdominal MRI", + "Gallbladder mapped through volumetric MR imaging of the abdomen", + "Gallbladder morphology documented in abdominal MR studies", + "Intra-abdominal gallbladder assessed with MR", + "Gallbladder content observed in abdominal MRI" + ], + "5": [ + "Esophagus", + "Esophagus segmented in abdominal MRI", + "Esophageal structure identified via abdominal magnetic resonance imaging", + "Esophagus visualized in cross-sectional abdominal MR scans", + "Esophageal anatomy characterized by abdominal MR protocols", + "Esophagus mapped through magnetic resonance imaging of the abdomen", + "Esophageal morphology documented in abdominal MR studies", + "Esophageal architecture localized in abdominal MR tomographic imaging", + "Esophageal tissue delineated in multiparametric abdominal MR sequences" + ], + "6": [ + "Liver", + "Hepatic 
structure identified via abdominal magnetic resonance imaging", + "Liver visualized in cross-sectional abdominal MR scans", + "Hepatic parenchyma characterized by abdominal MR protocols", + "Liver mapped through magnetic resonance imaging of the abdomen", + "Hepatic morphology documented in abdominal MR studies", + "Liver demonstrating anatomical lobar boundaries on abdominal magnetic resonance", + "Liver architecture localized in abdominal MR tomographic imaging", + "Hepatic tissue delineated in abdominal MR sequences" + ], + "7": [ + "Stomach", + "Stomach identified via abdominal magnetic resonance imaging", + "Stomach visualized in cross-sectional abdominal MR scans", + "Stomach anatomy characterized by abdominal MR protocols", + "Stomach mapped through magnetic resonance imaging of the abdomen", + "Stomach morphology documented in abdominal MR studies", + "Stomach demonstrating anatomical curvature on abdominal magnetic resonance", + "Gastric architecture localized in abdominal MR tomographic imaging", + "Stomach tissue delineated in multiparametric abdominal MR sequences" + ], + "8": [ + "Aorta", + "Aortic structure identified via abdominal magnetic resonance imaging", + "Aorta visualized in cross-sectional abdominal MR scans", + "Aortic lumen characterized by abdominal MR protocols", + "Aorta mapped through magnetic resonance imaging of the abdomen", + "Aortic morphology documented in abdominal MR studies", + "Aortic architecture localized in abdominal MR tomographic imaging", + "Aortic tissue delineated in multiparametric abdominal MR sequences" + ], + "9": [ + "Inferior vena cava", + "IVC structure identified via abdominal magnetic resonance imaging", + "Inferior vena cava visualized in cross-sectional abdominal MR scans", + "IVC lumen characterized by abdominal MR protocols", + "Inferior vena cava mapped through magnetic resonance imaging of the abdomen", + "IVC morphology documented in abdominal MR studies", + "Inferior vena cava demonstrating anatomical 
course on abdominal magnetic resonance", + "IVC architecture localized in abdominal MR tomographic imaging", + "IVC tissue delineated in multiparametric abdominal MR sequences" + ], + "10": [ + "Pancreas", + "Pancreatic structure identified via abdominal magnetic resonance imaging", + "Pancreas visualized in cross-sectional abdominal MR scans", + "Pancreatic parenchyma characterized by abdominal MR protocols", + "Pancreas mapped through magnetic resonance imaging of the abdomen", + "Pancreatic morphology documented in abdominal MR studies", + "Pancreas demonstrating anatomical lobulation on abdominal magnetic resonance", + "Pancreatic architecture localized in abdominal MR tomographic imaging", + "Pancreatic tissue delineated in multiparametric abdominal MR sequences" + ], + "11": [ + "Right adrenal gland", + "Right adrenal structure identified via abdominal magnetic resonance imaging", + "Right adrenal gland visualized in cross-sectional abdominal MR scans", + "Right adrenal parenchyma characterized by abdominal MR protocols", + "Right adrenal gland mapped through magnetic resonance imaging of the abdomen", + "Right adrenal morphology documented in abdominal MR studies", + "Right adrenal gland demonstrating anatomical margins on abdominal magnetic resonance", + "Right adrenal architecture localized in abdominal MR tomographic imaging", + "Right adrenal tissue delineated in multiparametric abdominal MR sequences" + ], + "12": [ + "Left adrenal gland", + "Left adrenal structure identified via abdominal magnetic resonance imaging", + "Left adrenal gland visualized in cross-sectional abdominal MR scans", + "Left adrenal parenchyma characterized by abdominal MR protocols", + "Left adrenal gland mapped through magnetic resonance imaging of the abdomen", + "Left adrenal morphology documented in abdominal MR studies", + "Left adrenal gland demonstrating anatomical margins on abdominal magnetic resonance", + "Left adrenal architecture localized in abdominal MR tomographic 
imaging", + "Left adrenal tissue delineated in multiparametric abdominal MR sequences" + ], + "13": [ + "Duodenum", + "Duodenal structure identified via abdominal magnetic resonance imaging", + "Duodenum visualized in cross-sectional abdominal MR scans", + "Duodenal anatomy characterized by abdominal MR protocols", + "Duodenum mapped through magnetic resonance imaging of the abdomen", + "Duodenal morphology documented in abdominal MR studies", + "Duodenum demonstrating anatomical folds on abdominal magnetic resonance", + "Duodenal architecture localized in abdominal MR tomographic imaging", + "Duodenal tissue delineated in multiparametric abdominal MR sequences" + ], + "14": [ + "Bladder", + "Vesical structure identified via abdominal magnetic resonance imaging", + "Bladder visualized in cross-sectional abdominal MR scans", + "Bladder anatomy characterized by abdominal MR protocols", + "Bladder mapped through magnetic resonance imaging of the abdomen", + "Vesical morphology documented in abdominal MR studies", + "Bladder demonstrating anatomical capacity on abdominal magnetic resonance", + "Vesical architecture localized in abdominal MR tomographic imaging", + "Bladder tissue delineated in multiparametric abdominal MR sequences" + ], + "15": [ + "Prostate/uterus", + "Prostatic/uterine structure identified via abdominal magnetic resonance imaging", + "Prostate/uterus visualized in cross-sectional abdominal MR scans", + "Prostatic/uterine anatomy characterized by abdominal MR protocols", + "Prostate/uterus mapped through magnetic resonance imaging of the abdomen", + "Prostatic/uterine morphology documented in abdominal MR studies", + "Prostate/uterus demonstrating anatomical zonation/endometrial layers on abdominal magnetic resonance", + "Prostatic/uterine architecture localized in abdominal MR tomographic imaging", + "Prostate/uterus tissue delineated in multiparametric abdominal MR sequences" + ], + "instance_label": 0 + }, + "MR_CervicalCancer": { + "1": [ + 
"Cervical cancer tumor", + "Cervical cancer tumor segmented in pelvic MRI", + "Cervical malignancy identified via pelvic magnetic resonance imaging", + "Neoplastic cervical lesion visualized in cross-sectional pelvic MR scans", + "Cervical cancer tumor characterized by pelvic MR protocols", + "Cervical malignancy mapped through magnetic resonance imaging of the pelvis", + "Cervical cancer tumor documented in pelvic MR studies", + "Cervical malignancy demonstrating imaging features on pelvic magnetic resonance", + "Cervical cancer tumor assessed with multiparametric pelvic MR sequences", + "Neoplastic cervical lesion localized in pelvic tomographic imaging" + ], + "instance_label": 1 + }, + "MR_CHAOS-T1": { + "1": [ + "Liver", + "Hepatic structure identified via abdominal magnetic resonance T1 imaging", + "Liver visualized in cross-sectional abdominal MR T1 scans", + "Hepatic parenchyma characterized by abdominal MR T1 protocols", + "Liver mapped through magnetic resonance T1 imaging of the abdomen", + "Hepatic morphology documented in abdominal MR T1 studies", + "Liver demonstrating anatomical lobar boundaries on abdominal magnetic resonance T1", + "Liver architecture localized in abdominal MR T1 tomographic imaging", + "Hepatic tissue delineated in abdominal MR T1 sequences" + ], + "2": [ + "Right kidney", + "Right renal structure identified via abdominal magnetic resonance T1 imaging", + "Right kidney visualized in abdominal MR T1 scans", + "Right kidney mapped through MR T1 volumetry of the abdomen", + "Right renal morphology documented in abdominal MR T1 protocols", + "Right renal parenchyma assessed with functional MR T1 sequences", + "Right nephric architecture localized in abdominal MR T1 tomography" + ], + "3": [ + "Left kidney", + "Left renal structure identified via abdominal magnetic resonance T1 imaging", + "Left kidney visualized in axial abdominal MR T1 scans", + "Left renal anatomy characterized by abdominal MR T1 imaging", + "Left renal morphology 
documented in abdominal MR T1 studies", + "Left renal parenchyma assessed with MR T1 spectroscopy", + "Left nephric architecture localized in abdominal MR T1 volumetry" + ], + "4": [ + "Spleen", + "Splenic structure identified via abdominal magnetic resonance T1 imaging", + "Spleen visualized in multiplanar abdominal MR T1 scans", + "Splenic parenchyma characterized by abdominal MR T1 imaging", + "Spleen mapped through volumetric MR T1 imaging of the abdomen", + "Splenic morphology documented in abdominal MR T1 studies", + "Spleen demonstrating signal intensity on abdominal magnetic resonance T1", + "Intra-abdominal splenic architecture assessed with contrast-enhanced MR T1", + "Splenic volume quantified in abdominal MR T1 volumetry" + ], + "instance_label": 0 + }, + "MR_CHAOS-T2": { + "1": [ + "Liver", + "Hepatic structure identified via abdominal magnetic resonance imaging", + "Liver visualized in cross-sectional abdominal MR T2 scans", + "Hepatic parenchyma characterized by abdominal MR T2 protocols", + "Liver mapped through magnetic resonance imaging of the abdomen", + "Hepatic morphology documented in abdominal MR T2 studies", + "Liver demonstrating anatomical lobar boundaries on abdominal magnetic resonance", + "Liver architecture localized in abdominal MR T2 tomographic imaging", + "Hepatic tissue delineated in abdominal MR T2 sequences" + ], + "2": [ + "Right kidney", + "Right renal structure identified via abdominal magnetic resonance", + "Right kidney visualized in abdominal MRI T2 scans", + "Right kidney mapped through MR T2 volumetry of the abdomen", + "Right renal morphology documented in abdominal MR T2 protocols", + "Right renal parenchyma assessed with functional MR T2 sequences", + "Right nephric architecture localized in abdominal MR T2 tomography" + ], + "3": [ + "Left kidney", + "Left renal structure identified via abdominal magnetic resonance", + "Left kidney visualized in axial abdominal MRI T2 scans", + "Left renal anatomy characterized by 
abdominal MRI T2", + "Left renal morphology documented in abdominal MR T2 studies", + "Left renal parenchyma assessed with MR T2 spectroscopy", + "Left nephric architecture localized in abdominal MR T2 volumetry" + ], + "4": [ + "Spleen", + "Splenic structure identified via abdominal magnetic resonance", + "Spleen visualized in multiplanar abdominal MRI T2 scans", + "Splenic parenchyma characterized by abdominal MRI T2", + "Spleen mapped through volumetric MR T2 imaging of the abdomen", + "Splenic morphology documented in abdominal MR T2 studies", + "Spleen demonstrating signal intensity on abdominal magnetic resonance", + "Intra-abdominal splenic architecture assessed with contrast-enhanced MRI T2", + "Splenic volume quantified in abdominal MR T2 volumetry" + ], + "instance_label": 0 + }, + "MR_Heart_ACDC": { + "1": [ + "Right ventricle cavity", + "Right ventricle cavity segmentation in chest MR", + "Right ventricle cavity delineation in chest magnetic resonance", + "Right ventricle cavity identification in chest MR imaging", + "Right ventricle cavity visualization in chest MR scans", + "Right ventricle cavity mapping in chest MR acquisitions", + "Right ventricle cavity detection in chest magnetic resonance", + "Right ventricle cavity outlining in chest MR imaging" + ], + "2": [ + "Myocardium", + "Myocardium segmentation in chest MR", + "Myocardium delineation in chest magnetic resonance", + "Myocardium identification in chest MR imaging", + "Myocardium visualization in chest MR scans", + "Myocardium mapping in chest MR acquisitions", + "Myocardium detection in chest magnetic resonance", + "Myocardium outlining in chest MR imaging" + ], + "3": [ + "Left ventricle cavity", + "Left ventricle cavity segmentation in chest MR", + "Left ventricle cavity anatomical boundaries in chest MR imaging", + "Left ventricle cavity structural delineation in chest MR scans", + "Left ventricle cavity endocardial border identification in chest MR acquisitions", + "Left ventricle 
cavity outlining in chest MR imaging", + "Left ventricle cavity detection in chest MR imaging" + ], + "instance_label": 0 + }, + "MR_HNTS-MRG_HeadTumor": { + "1": [ + "GTVp and GTVn tumor", + "GTVp and GTVn tumor demarcation in head MR imaging", + "GTVp and GTVn tumor segmentation within cranial MR scans", + "Volumetric analysis of GTVp and GTVn tumors in head magnetic resonance", + "Spatial localization of GTVp and GTVn tumors on head MR acquisitions", + "GTVp and GTVn tumor volume definition via head MR protocols", + "Anatomical delineation of combined GTVp and GTVn tumors in head MRI", + "GTVp and GTVn tumor boundary identification in head MR studies", + "GTVp and GTVn in cranial magnetic resonance imaging", + "GTVp and GTVn tumor extent mapping using head MR sequences" + ], + "instance_label": 1 + }, + "MR_ISLES_DWI": { + "1": [ + "Stroke lesion", + "Acute ischemic focus identified via diffusion-weighted brain MRI", + "Stroke lesion visualized in cerebral DWI sequences", + "Diffusion-restricted lesion characterized by brain MR DWI protocols", + "Stroke lesion mapped through diffusion-weighted imaging of the brain parenchyma", + "Acute infarcted tissue documented in brain DWI-MR studies", + "Stroke lesion demonstrating hyperintensity on brain diffusion-weighted MRI", + "Ischemic brain lesion assessed with ADC map correlation in MR DWI", + "Stroke lesion localized in axial brain DWI acquisitions", + "Diffusion abnormality quantified in volumetric brain MR DWI scans" + ], + "instance_label": 1 + }, + "MR_ISLES_ADC": { + "1": [ + "Stroke lesion", + "Stroke lesion segmented in brain MR ADC imaging", + "Acute ischemic lesion identified via ADC map analysis in brain MRI", + "Stroke lesion visualized in cerebral ADC sequences", + "ADC hypointense lesion characterized in brain MR ADC protocols", + "Stroke lesion mapped through apparent diffusion coefficient imaging of the brain parenchyma", + "Ischemic lesion documented in brain ADC-MR studies", + "Stroke lesion 
demonstrating ADC signal reduction on brain MR ADC", + "Hypointense stroke lesion assessed with ADC parametric mapping in brain MRI", + "Stroke lesion localized in axial brain ADC acquisitions", + "Diffusion coefficient abnormalities quantified in volumetric brain MR ADC scans" + ], + "instance_label": 1 + }, + "MR_LeftAtrium": { + "1": [ + "Left atrium", + "Left atrium delineation in thoracic MR imaging", + "Left atrial anatomical boundaries on chest MR scans", + "Morphological segmentation of the left atrium in thoracic cavity MRI", + "Structural evaluation of the left atrium via chest magnetic resonance", + "Left atrial chamber dimensions quantified in chest MR studies", + "Left atrial endocardial borders identified in chest MR acquisitions", + "Left atrial chamber visualization in chest MR studies", + "Left atrial endocardial contours in chest MR acquisitions" + ], + "instance_label": 0 + }, + "MR_ProstateADC": { + "1": [ + "Transition zone", + "Transition zone delineation in pelvic MR ADC imaging", + "Segmentation of the transition zone on ADC-weighted pelvic MRI", + "Zonal characterization of the transition zone using pelvis ADC MR sequences", + "Transition zone mapping in ADC-based pelvic magnetic resonance", + "Anatomical localization of the transition zone in ADC-weighted pelvis MRI", + "Transition zone identification in ADC sequences of pelvic magnetic resonance", + "Transition zone analysis in ADC-weighted pelvis MR scans", + "Transition zone detection in ADC-based pelvic MRI acquisitions" + ], + "instance_label": 0 + }, + "MR_ProstateT2": { + "1": [ + "Prostate", + "Prostate delineation in pelvic MR T2 imaging", + "Segmentation of the prostate on T2-weighted pelvic MRI", + "Zonal characterization of the prostate using pelvis T2 MRI protocols", + "Prostate mapping in T2-based pelvic magnetic resonance", + "Localization of the prostate in T2 pelvic MRI", + "Visualization of the prostate via pelvic T2 magnetic resonance", + "Prostate detection in T2-based 
pelvic MRI acquisitions", + "Prostate identification within T2 MR imaging of the pelvis", + "Prostate boundary definition in pelvic T2-weighted MRI sequences" + ], + "instance_label": 0 + }, + "MR_QIN-PROSTATE-Lesion": { + "1": [ + "Prostate lesion", + "Prostate lesion delineation in pelvic MR imaging", + "Segmentation of prostate lesions on pelvic MRI scans", + "Detection of prostate lesions within pelvic magnetic resonance", + "Prostate lesion localization in pelvic MR acquisitions", + "Visualization of prostate lesions via pelvic MRI protocols", + "Boundary demarcation of prostate lesions on pelvic MR studies", + "Mapping of prostate lesions in pelvic magnetic resonance imaging", + "Characterization of prostate lesions using pelvic MRI sequences", + "Identification of prostate lesions in pelvic MR datasets" + ], + "instance_label": 1 + }, + "MR_TotalSeg": { + "1": [ + "Spleen", + "Spleen in whole body MR", + "MR imaging of the spleen in the whole body region", + "Whole body MR showing spleen structures", + "Visualization of the spleen in whole body MR scans", + "Spleen observed in whole body MR imaging", + "MR scan of the whole body depicting the spleen", + "Spleen segmentation in whole body MR", + "Spleen delineation in whole body MR imaging", + "MR-based localization of the spleen in the whole body", + "Spleen anatomy in whole body MR" + ], + "2": [ + "Right kidney", + "Right kidney in whole body MR", + "MR imaging of the right kidney in the whole body region", + "Whole body MR showing right kidney structures", + "Visualization of the right kidney in whole body MR scans", + "Right kidney observed in whole body MR imaging", + "MR scan of the whole body highlighting the right kidney", + "Right kidney segmentation in whole body MR", + "Right kidney delineation in whole body MR imaging", + "MR-based identification of the right kidney in the whole body", + "Right kidney anatomy in whole body MR" + ], + "3": [ + "Left kidney", + "Left kidney in whole body MR", + "MR 
imaging of the left kidney in the whole body region", + "Whole body MR showing left kidney structures", + "Visualization of the left kidney in whole body MR scans", + "Left kidney observed in whole body MR imaging", + "MR scan of the whole body depicting the left kidney", + "Left kidney segmentation in whole body MR", + "Left kidney delineation in whole body MR imaging", + "MR-based localization of the left kidney in the whole body", + "Left kidney anatomy in whole body MR" + ], + "4": [ + "Gallbladder", + "Gallbladder in whole body MR", + "MR imaging of the gallbladder in the whole body region", + "Whole body MR showing gallbladder structures", + "Visualization of the gallbladder in whole body MR scans", + "Gallbladder observed in whole body MR imaging", + "MR scan of the whole body highlighting the gallbladder", + "Gallbladder segmentation in whole body MR", + "Gallbladder delineation in whole body MR imaging", + "MR-based identification of the gallbladder in the whole body", + "Gallbladder anatomy in whole body MR" + ], + "5": [ + "Liver", + "Liver in whole body MR", + "MR imaging of the liver in the whole body region", + "Whole body MR showing hepatic structures", + "Visualization of the liver in whole body MR scans", + "Liver observed in whole body MR imaging", + "MR scan of the whole body depicting the liver", + "Liver segmentation in whole body MR", + "Liver delineation in whole body MR imaging", + "MR-based localization of the liver in the whole body", + "Hepatic anatomy in whole body MR" + ], + "6": [ + "Stomach", + "Stomach in whole body MR", + "MR imaging of the stomach in the whole body region", + "Whole body MR showing gastric structures", + "Visualization of the stomach in whole body MR scans", + "Stomach observed in whole body MR imaging", + "MR scan of the whole body highlighting the stomach", + "Stomach segmentation in whole body MR", + "Stomach delineation in whole body MR imaging", + "MR-based identification of the stomach in the whole body", + 
"Gastric anatomy in whole body MR" + ], + "7": [ + "Pancreas", + "Pancreas in whole body MR", + "MR imaging of the pancreas in the whole body region", + "Whole body MR showing pancreatic structures", + "Visualization of the pancreas in whole body MR scans", + "Pancreas observed in whole body MR imaging", + "MR scan of the whole body depicting the pancreas", + "Pancreas segmentation in whole body MR", + "Pancreas delineation in whole body MR imaging", + "MR-based localization of the pancreas in the whole body", + "Pancreatic anatomy in whole body MR" + ], + "8": [ + "Right adrenal gland", + "Right adrenal gland in whole body MR", + "MR imaging of the right adrenal gland in the whole body region", + "Whole body MR showing right adrenal structures", + "Visualization of the right adrenal gland in whole body MR scans", + "Right adrenal gland observed in whole body MR imaging", + "MR scan of the whole body highlighting the right adrenal gland", + "Right adrenal gland segmentation in whole body MR", + "Right adrenal gland delineation in whole body MR imaging", + "MR-based identification of the right adrenal gland in the whole body", + "Right adrenal anatomy in whole body MR" + ], + "9": [ + "Left adrenal gland", + "Left adrenal gland in whole body MR", + "MR imaging of the left adrenal gland in the whole body region", + "Whole body MR showing left adrenal structures", + "Visualization of the left adrenal gland in whole body MR scans", + "Left adrenal gland observed in whole body MR imaging", + "MR scan of the whole body depicting the left adrenal gland", + "Left adrenal gland segmentation in whole body MR", + "Left adrenal gland delineation in whole body MR imaging", + "MR-based localization of the left adrenal gland in the whole body", + "Left adrenal anatomy in whole body MR" + ], + "10": [ + "Left lung", + "Left lung in whole body MR", + "MR imaging of the left lung in the whole body region", + "Whole body MR showing left pulmonary structures", + "Visualization of the 
left lung in whole body MR scans", + "Left lung observed in whole body MR imaging", + "MR scan of the whole body highlighting the left lung", + "Left lung segmentation in whole body MR", + "Left lung delineation in whole body MR imaging", + "MR-based identification of the left lung in the whole body", + "Left pulmonary anatomy in whole body MR" + ], + "11": [ + "Right lung", + "Right lung in whole body MR", + "MR imaging of the right lung in the whole body region", + "Whole body MR showing right pulmonary structures", + "Visualization of the right lung in whole body MR scans", + "Right lung observed in whole body MR imaging", + "MR scan of the whole body depicting the right lung", + "Right lung segmentation in whole body MR", + "Right lung delineation in whole body MR imaging", + "MR-based localization of the right lung in the whole body", + "Right pulmonary anatomy in whole body MR" + ], + "12": [ + "Esophagus", + "Esophagus in whole body MR", + "MR imaging of the esophagus in the whole body region", + "Whole body MR showing esophageal structures", + "Visualization of the esophagus in whole body MR scans", + "Esophagus observed in whole body MR imaging", + "MR scan of the whole body highlighting the esophagus", + "Esophagus segmentation in whole body MR", + "Esophagus delineation in whole body MR imaging", + "MR-based identification of the esophagus in the whole body", + "Esophageal anatomy in whole body MR" + ], + "13": [ + "Small bowel", + "Small bowel in whole body MR", + "MR imaging of the small bowel in the whole body region", + "Whole body MR showing small bowel structures", + "Visualization of the small bowel in whole body MR scans", + "Small bowel observed in whole body MR imaging", + "MR scan of the whole body depicting the small bowel", + "Small bowel segmentation in whole body MR", + "Small bowel delineation in whole body MR imaging", + "MR-based localization of the small bowel in the whole body", + "Small bowel anatomy in whole body MR" + ], + "14": [ 
+ "Duodenum", + "Duodenum in whole body MR", + "MR imaging of the duodenum in the whole body region", + "Whole body MR showing duodenal structures", + "Visualization of the duodenum in whole body MR scans", + "Duodenum observed in whole body MR imaging", + "MR scan of the whole body highlighting the duodenum", + "Duodenum segmentation in whole body MR", + "Duodenum delineation in whole body MR imaging", + "MR-based identification of the duodenum in the whole body", + "Duodenal anatomy in whole body MR" + ], + "15": [ + "Colon", + "Colon in whole body MR", + "MR imaging of the colon in the whole body region", + "Whole body MR showing colonic structures", + "Visualization of the colon in whole body MR scans", + "Colon observed in whole body MR imaging", + "MR scan of the whole body depicting the colon", + "Colon segmentation in whole body MR", + "Colon delineation in whole body MR imaging", + "MR-based localization of the colon in the whole body", + "Colonic anatomy in whole body MR" + ], + "16": [ + "Urinary bladder", + "Urinary bladder in whole body MR", + "MR imaging of the urinary bladder in the whole body region", + "Whole body MR showing bladder structures", + "Visualization of the urinary bladder in whole body MR scans", + "Urinary bladder observed in whole body MR imaging", + "MR scan of the whole body highlighting the urinary bladder", + "Urinary bladder segmentation in whole body MR", + "Urinary bladder delineation in whole body MR imaging", + "MR-based identification of the urinary bladder in the whole body", + "Bladder anatomy in whole body MR" + ], + "17": [ + "Prostate", + "Prostate in whole body MR", + "MR imaging of the prostate in the whole body region", + "Whole body MR showing prostatic structures", + "Visualization of the prostate in whole body MR scans", + "Prostate observed in whole body MR imaging", + "MR scan of the whole body depicting the prostate", + "Prostate segmentation in whole body MR", + "Prostate delineation in whole body MR 
imaging", + "MR-based localization of the prostate in the whole body", + "Prostatic anatomy in whole body MR" + ], + "18": [ + "Sacrum", + "Sacrum in whole body MR", + "MR imaging of the sacrum in the whole body region", + "Whole body MR showing sacral structures", + "Visualization of the sacrum in whole body MR scans", + "Sacrum observed in whole body MR imaging", + "MR scan of the whole body highlighting the sacrum", + "Sacrum segmentation in whole body MR", + "Sacrum delineation in whole body MR imaging", + "MR-based identification of the sacrum in the whole body", + "Sacral anatomy in whole body MR" + ], + "19": [ + "Vertebrae", + "Vertebrae in whole body MR", + "MR imaging of the vertebrae in the whole body region", + "Whole body MR showing vertebral structures", + "Visualization of the vertebrae in whole body MR scans", + "Vertebrae observed in whole body MR imaging", + "MR scan of the whole body depicting the vertebrae", + "Vertebrae segmentation in whole body MR", + "Vertebrae delineation in whole body MR imaging", + "MR-based localization of the vertebrae in the whole body", + "Vertebral anatomy in whole body MR" + ], + "20": [ + "Intervertebral discs", + "Intervertebral discs in whole body MR", + "MR imaging of the intervertebral discs in the whole body region", + "Whole body MR showing disc structures", + "Visualization of the intervertebral discs in whole body MR scans", + "Intervertebral discs observed in whole body MR imaging", + "MR scan of the whole body highlighting the intervertebral discs", + "Intervertebral discs segmentation in whole body MR", + "Intervertebral discs delineation in whole body MR imaging", + "MR-based identification of the intervertebral discs in the whole body", + "Disc anatomy in whole body MR" + ], + "21": [ + "Spinal cord", + "Spinal cord in whole body MR", + "MR imaging of the spinal cord in the whole body region", + "Whole body MR showing spinal cord structures", + "Visualization of the spinal cord in whole body MR scans", 
+ "Spinal cord observed in whole body MR imaging", + "MR scan of the whole body depicting the spinal cord", + "Spinal cord segmentation in whole body MR", + "Spinal cord delineation in whole body MR imaging", + "MR-based localization of the spinal cord in the whole body", + "Spinal cord anatomy in whole body MR" + ], + "22": [ + "Heart", + "Heart in whole body MR", + "MR imaging of the heart in the whole body region", + "Whole body MR showing cardiac structures", + "Visualization of the heart in whole body MR scans", + "Heart observed in whole body MR imaging", + "MR scan of the whole body highlighting the heart", + "Heart segmentation in whole body MR", + "Heart delineation in whole body MR imaging", + "MR-based identification of the heart in the whole body" + ], + "23": [ + "Aorta", + "Aorta in whole body MR", + "MR imaging of the aorta in the whole body region", + "Whole body MR showing aortic structures", + "Visualization of the aorta in whole body MR scans", + "Aorta observed in whole body MR imaging", + "MR scan of the whole body depicting the aorta", + "Aorta segmentation in whole body MR", + "Aorta delineation in whole body MR imaging", + "MR-based localization of the aorta in the whole body", + "Aortic anatomy in whole body MR" + ], + "24": [ + "Inferior vena cava", + "Inferior vena cava in whole body MR", + "MR imaging of the inferior vena cava in the whole body region", + "Whole body MR showing IVC structures", + "Visualization of the inferior vena cava in whole body MR scans", + "Inferior vena cava observed in whole body MR imaging", + "MR scan of the whole body highlighting the inferior vena cava", + "Inferior vena cava segmentation in whole body MR", + "Inferior vena cava delineation in whole body MR imaging", + "MR-based identification of the inferior vena cava in the whole body", + "IVC anatomy in whole body MR" + ], + "25": [ + "Portal vein and splenic vein", + "Portal vein and splenic vein in whole body MR", + "MR imaging of the portal vein and 
splenic vein in the whole body region", + "Whole body MR showing portal and splenic venous structures", + "Visualization of the portal vein and splenic vein in whole body MR scans", + "Portal vein and splenic vein observed in whole body MR imaging", + "MR scan of the whole body depicting the portal vein and splenic vein", + "Portal vein and splenic vein segmentation in whole body MR", + "Portal vein and splenic vein delineation in whole body MR imaging", + "MR-based localization of the portal vein and splenic vein in the whole body", + "Portal and splenic venous anatomy in whole body MR" + ], + "26": [ + "Left iliac artery", + "Left iliac artery in whole body MR", + "MR imaging of the left iliac artery in the whole body region", + "Whole body MR showing left iliac arterial structures", + "Visualization of the left iliac artery in whole body MR scans", + "Left iliac artery observed in whole body MR imaging", + "MR scan of the whole body highlighting the left iliac artery", + "Left iliac artery segmentation in whole body MR", + "Left iliac artery delineation in whole body MR imaging", + "MR-based identification of the left iliac artery in the whole body", + "Left iliac arterial anatomy in whole body MR" + ], + "27": [ + "Right iliac artery", + "Right iliac artery in whole body MR", + "MR imaging of the right iliac artery in the whole body region", + "Whole body MR showing right iliac arterial structures", + "Visualization of the right iliac artery in whole body MR scans", + "Right iliac artery observed in whole body MR imaging", + "MR scan of the whole body depicting the right iliac artery", + "Right iliac artery segmentation in whole body MR", + "Right iliac artery delineation in whole body MR imaging", + "MR-based localization of the right iliac artery in the whole body", + "Right iliac arterial anatomy in whole body MR" + ], + "28": [ + "Left iliac vena", + "Left iliac vena in whole body MR", + "MR imaging of the left iliac vena in the whole body region", + "Whole 
body MR showing left iliac venous structures", + "Visualization of the left iliac vena in whole body MR scans", + "Left iliac vena observed in whole body MR imaging", + "MR scan of the whole body highlighting the left iliac vena", + "Left iliac vena segmentation in whole body MR", + "Left iliac vena delineation in whole body MR imaging", + "MR-based identification of the left iliac vena in the whole body", + "Left iliac venous anatomy in whole body MR" + ], + "29": [ + "Right iliac vena", + "Right iliac vena in whole body MR", + "MR imaging of the right iliac vena in the whole body region", + "Whole body MR showing right iliac venous structures", + "Visualization of the right iliac vena in whole body MR scans", + "Right iliac vena observed in whole body MR imaging", + "MR scan of the whole body depicting the right iliac vena", + "Right iliac vena segmentation in whole body MR", + "Right iliac vena delineation in whole body MR imaging", + "MR-based localization of the right iliac vena in the whole body", + "Right iliac venous anatomy in whole body MR" + ], + "30": [ + "Left humerus", + "Left humerus in whole body MR", + "MR imaging of the left humerus in the whole body region", + "Whole body MR showing left humeral structures", + "Visualization of the left humerus in whole body MR scans", + "Left humerus observed in whole body MR imaging", + "MR scan of the whole body highlighting the left humerus", + "Left humerus segmentation in whole body MR", + "Left humerus delineation in whole body MR imaging", + "MR-based identification of the left humerus in the whole body", + "Left humeral anatomy in whole body MR" + ], + "31": [ + "Right humerus", + "Right humerus in whole body MR", + "MR imaging of the right humerus in the whole body region", + "Whole body MR showing right humeral structures", + "Visualization of the right humerus in whole body MR scans", + "Right humerus observed in whole body MR imaging", + "MR scan of the whole body depicting the right humerus", + 
"Right humerus segmentation in whole body MR", + "Right humerus delineation in whole body MR imaging", + "MR-based localization of the right humerus in the whole body", + "Right humeral anatomy in whole body MR" + ], + "34": [ + "Left femur", + "Left femur in whole body MR", + "MR imaging of the left femur in the whole body region", + "Whole body MR showing left femoral structures", + "Visualization of the left femur in whole body MR scans", + "Left femur observed in whole body MR imaging", + "MR scan of the whole body highlighting the left femur", + "Left femur segmentation in whole body MR", + "Left femur delineation in whole body MR imaging", + "MR-based identification of the left femur in the whole body", + "Left femoral anatomy in whole body MR" + ], + "35": [ + "Right femur", + "Right femur in whole body MR", + "MR imaging of the right femur in the whole body region", + "Whole body MR showing right femoral structures", + "Visualization of the right femur in whole body MR scans", + "Right femur observed in whole body MR imaging", + "MR scan of the whole body depicting the right femur", + "Right femur segmentation in whole body MR", + "Right femur delineation in whole body MR imaging", + "MR-based localization of the right femur in the whole body", + "Right femoral anatomy in whole body MR" + ], + "36": [ + "Left hip", + "Left hip in whole body MR", + "MR imaging of the left hip in the whole body region", + "Whole body MR showing left hip structures", + "Visualization of the left hip in whole body MR scans", + "Left hip observed in whole body MR imaging", + "MR scan of the whole body depicting the left hip", + "Left hip segmentation in whole body MR", + "Left hip delineation in whole body MR imaging", + "MR-based localization of the left hip in the whole body", + "Left hip anatomy in whole body MR" + ], + "37": [ + "Right hip", + "Right hip in whole body MR", + "MR imaging of the right hip in the whole body region", + "Whole body MR showing right hip 
structures", + "Visualization of the right hip in whole body MR scans", + "Right hip observed in whole body MR imaging", + "MR scan of the whole body highlighting the right hip", + "Right hip segmentation in whole body MR", + "Right hip delineation in whole body MR imaging", + "MR-based identification of the right hip in the whole body", + "Right hip anatomy in whole body MR" + ], + "38": [ + "Left gluteus maximus", + "Left gluteus maximus in whole body MR", + "MR imaging of the left gluteus maximus in the whole body region", + "Whole body MR showing left gluteus maximus structures", + "Visualization of the left gluteus maximus in whole body MR scans", + "Left gluteus maximus observed in whole body MR imaging", + "MR scan of the whole body depicting the left gluteus maximus", + "Left gluteus maximus segmentation in whole body MR", + "Left gluteus maximus delineation in whole body MR imaging", + "MR-based localization of the left gluteus maximus in the whole body", + "Left gluteus maximus anatomy in whole body MR" + ], + "39": [ + "Right gluteus maximus", + "Right gluteus maximus in whole body MR", + "MR imaging of the right gluteus maximus in the whole body region", + "Whole body MR showing right gluteus maximus structures", + "Visualization of the right gluteus maximus in whole body MR scans", + "Right gluteus maximus observed in whole body MR imaging", + "MR scan of the whole body highlighting the right gluteus maximus", + "Right gluteus maximus segmentation in whole body MR", + "Right gluteus maximus delineation in whole body MR imaging", + "MR-based identification of the right gluteus maximus in the whole body", + "Right gluteus maximus anatomy in whole body MR" + ], + "40": [ + "Left gluteus medius", + "Left gluteus medius in whole body MR", + "MR imaging of the left gluteus medius in the whole body region", + "Whole body MR showing left gluteus medius structures", + "Visualization of the left gluteus medius in whole body MR scans", + "Left gluteus medius 
observed in whole body MR imaging", + "MR scan of the whole body depicting the left gluteus medius", + "Left gluteus medius segmentation in whole body MR", + "Left gluteus medius delineation in whole body MR imaging", + "MR-based localization of the left gluteus medius in the whole body", + "Left gluteus medius anatomy in whole body MR" + ], + "41": [ + "Right gluteus medius", + "Right gluteus medius in whole body MR", + "MR imaging of the right gluteus medius in the whole body region", + "Whole body MR showing right gluteus medius structures", + "Visualization of the right gluteus medius in whole body MR scans", + "Right gluteus medius observed in whole body MR imaging", + "MR scan of the whole body highlighting the right gluteus medius", + "Right gluteus medius segmentation in whole body MR", + "Right gluteus medius delineation in whole body MR imaging", + "MR-based identification of the right gluteus medius in the whole body", + "Right gluteus medius anatomy in whole body MR" + ], + "42": [ + "Left gluteus minimus", + "Left gluteus minimus in whole body MR", + "MR imaging of the left gluteus minimus in the whole body region", + "Whole body MR showing left gluteus minimus structures", + "Visualization of the left gluteus minimus in whole body MR scans", + "Left gluteus minimus observed in whole body MR imaging", + "MR scan of the whole body depicting the left gluteus minimus", + "Left gluteus minimus segmentation in whole body MR", + "Left gluteus minimus delineation in whole body MR imaging", + "MR-based localization of the left gluteus minimus in the whole body", + "Left gluteus minimus anatomy in whole body MR" + ], + "43": [ + "Right gluteus minimus", + "Right gluteus minimus in whole body MR", + "MR imaging of the right gluteus minimus in the whole body region", + "Whole body MR showing right gluteus minimus structures", + "Visualization of the right gluteus minimus in whole body MR scans", + "Right gluteus minimus observed in whole body MR imaging", + "MR 
scan of the whole body highlighting the right gluteus minimus", + "Right gluteus minimus segmentation in whole body MR", + "Right gluteus minimus delineation in whole body MR imaging", + "MR-based identification of the right gluteus minimus in the whole body", + "Right gluteus minimus anatomy in whole body MR" + ], + "44": [ + "Left autochthon", + "Left autochthon in whole body MR", + "MR imaging of the left autochthon in the whole body region", + "Whole body MR showing left autochthon structures", + "Visualization of the left autochthon in whole body MR scans", + "Left autochthon observed in whole body MR imaging", + "MR scan of the whole body depicting the left autochthon", + "Left autochthon segmentation in whole body MR", + "Left autochthon delineation in whole body MR imaging", + "MR-based localization of the left autochthon in the whole body", + "Left autochthon anatomy in whole body MR" + ], + "45": [ + "Right autochthon", + "Right autochthon in whole body MR", + "MR imaging of the right autochthon in the whole body region", + "Whole body MR showing right autochthon structures", + "Visualization of the right autochthon in whole body MR scans", + "Right autochthon observed in whole body MR imaging", + "MR scan of the whole body highlighting the right autochthon", + "Right autochthon segmentation in whole body MR", + "Right autochthon delineation in whole body MR imaging", + "MR-based identification of the right autochthon in the whole body", + "Right autochthon anatomy in whole body MR" + ], + "46": [ + "Left iliopsoas", + "Left iliopsoas in whole body MR", + "MR imaging of the left iliopsoas in the whole body region", + "Whole body MR showing left iliopsoas structures", + "Visualization of the left iliopsoas in whole body MR scans", + "Left iliopsoas observed in whole body MR imaging", + "MR scan of the whole body depicting the left iliopsoas", + "Left iliopsoas segmentation in whole body MR", + "Left iliopsoas delineation in whole body MR imaging", + 
"MR-based localization of the left iliopsoas in the whole body", + "Left iliopsoas anatomy in whole body MR" + ], + "47": [ + "Right iliopsoas", + "Right iliopsoas in whole body MR", + "MR imaging of the right iliopsoas in the whole body region", + "Whole body MR showing right iliopsoas structures", + "Visualization of the right iliopsoas in whole body MR scans", + "Right iliopsoas observed in whole body MR imaging", + "MR scan of the whole body highlighting the right iliopsoas", + "Right iliopsoas segmentation in whole body MR", + "Right iliopsoas delineation in whole body MR imaging", + "MR-based identification of the right iliopsoas in the whole body", + "Right iliopsoas anatomy in whole body MR" + ], + "56": [ + "Brain", + "Brain in whole body MR", + "MR imaging of the brain in the whole body region", + "Whole body MR showing brain structures", + "Visualization of the brain in whole body MR scans", + "Brain observed in whole body MR imaging", + "MR scan of the whole body depicting the brain", + "Brain segmentation in whole body MR", + "Brain delineation in whole body MR imaging", + "MR-based localization of the brain in the whole body", + "Brain anatomy in whole body MR" + ], + "instance_label": 0 + }, + "MR_WMH_FLAIR": { + "1": [ + "White matter hyperintensities", + "White matter hyperintensities delineation in brain MR FLAIR imaging", + "Segmentation of white matter hyperintensities on FLAIR brain MRI", + "Detection of white matter hyperintensities within cerebral FLAIR magnetic resonance", + "White matter hyperintensities localization in brain FLAIR MR acquisitions", + "Visualization of white matter hyperintensities via FLAIR MRI protocols in the brain", + "Boundary demarcation of white matter hyperintensities on brain FLAIR MR studies", + "Mapping of white matter hyperintensities in cerebral FLAIR MRI", + "Characterization of white matter hyperintensities using brain FLAIR MR sequences", + "Identification of white matter hyperintensities in brain MR FLAIR 
datasets" + ], + "instance_label": 1 + }, + "MR_WMH_T1": { + "1": [ + "White matter hyperintensities", + "White matter hyperintensities delineation in brain MR T1 imaging", + "Segmentation of white matter hyperintensities on T1-weighted brain MRI", + "Detection of white matter hyperintensities within cerebral T1 magnetic resonance", + "White matter hyperintensities localization in brain T1 MR acquisitions", + "Visualization of white matter hyperintensities via T1 MRI protocols in the brain", + "Boundary demarcation of white matter hyperintensities on brain T1 MR studies", + "Mapping of white matter hyperintensities in cerebral T1-weighted MRI", + "Characterization of white matter hyperintensities using brain T1 MR sequences", + "Identification of white matter hyperintensities in brain MR T1 datasets" + ], + "instance_label": 1 + }, + "MR_Spider_Vertebrae": { + "1": [ + "Vertebrae", + "Vertebrae in Spine MR", + "MR imaging of the vertebrae in the spinal region", + "Spine MR showing vertebral structures", + "Visualization of the vertebrae in spinal MR scans", + "Vertebrae observed in spinal MR imaging", + "MR scan of the spine depicting the vertebrae", + "Vertebrae segmentation in Spine MR", + "Vertebrae delineation in spinal MR imaging", + "MR-based localization of the vertebrae in the spinal column", + "Vertebrae morphology in spinal MR" + ], + "instance_label": 1 + }, + "MR_Spider_IVD": { + "1": [ + "Intervertebral discs", + "Intervertebral discs in Spine MR", + "MR imaging of the intervertebral discs in the spinal region", + "Spine MR showing IVD structures", + "Visualization of the intervertebral discs in spinal MR scans", + "Intervertebral discs observed in spinal MR imaging", + "MR scan of the spine depicting the intervertebral discs", + "Intervertebral disc segmentation in Spine MR", + "IVD delineation in spinal MR imaging", + "MR-based localization of intervertebral discs in the spinal column" + ], + "instance_label": 1 + }, + "MR_Spider_Spine": { + "1": [ + 
"Spinal canal", + "Spinal canal in Spine MR", + "MR imaging of the spinal canal in the spinal region", + "Spine MR showing spinal canal structures", + "Visualization of the spinal canal in spinal MR scans", + "Spinal canal observed in spinal MR imaging", + "MR scan of the spine depicting the spinal canal", + "Spinal canal segmentation in Spine MR", + "Canal delineation in spinal MR imaging", + "MR-based localization of the spinal canal in the spinal column" + ], + "instance_label": 1 + }, + "MR_T1c_crossMoDA_Tumor_Cochlea": { + "1": [ + "Intra-meatal region of vestibular schwannoma", + "Intra-meatal region of vestibular schwannoma in brain MR", + "MR imaging of intra-meatal region of vestibular schwannoma in brain", + "Visualization of intra-meatal region of vestibular schwannoma in brain MR", + "Intra-meatal region of vestibular schwannoma observed in brain MR", + "Segmentation of intra-meatal region of vestibular schwannoma in brain MR", + "Delineation of intra-meatal region of vestibular schwannoma in brain MR imaging", + "Identification of intra-meatal region of vestibular schwannoma in brain MR", + "Localization of intra-meatal region of vestibular schwannoma in brain MR", + "Characterization of intra-meatal region of vestibular schwannoma in brain MR" + ], + "2": [ + "Extra-meatal region of vestibular schwannoma", + "Extra-meatal region of vestibular schwannoma in brain MR", + "MR imaging of extra-meatal region of vestibular schwannoma in brain", + "Visualization of extra-meatal region of vestibular schwannoma in brain MR", + "Extra-meatal region of vestibular schwannoma observed in brain MR", + "Segmentation of extra-meatal region of vestibular schwannoma in brain MR", + "Delineation of extra-meatal region of vestibular schwannoma in brain MR imaging", + "Identification of extra-meatal region of vestibular schwannoma in brain MR", + "Localization of extra-meatal region of vestibular schwannoma in brain MR", + "Characterization of extra-meatal region of 
vestibular schwannoma in brain MR" + ], + "4": [ + "Right cochlea", + "Right cochlea in brain MR", + "MR imaging of right cochlea in brain", + "Visualization of right cochlea in brain MR", + "Right cochlea observed in brain MR", + "Segmentation of right cochlea in brain MR", + "Delineation of right cochlea in brain MR imaging", + "Identification of right cochlea in brain MR", + "Localization of right cochlea in brain MR" + ], + "5": [ + "Left cochlea", + "Left cochlea in brain MR", + "MR imaging of left cochlea in brain", + "Visualization of left cochlea in brain MR", + "Left cochlea observed in brain MR", + "Segmentation of left cochlea in brain MR", + "Delineation of left cochlea in brain MR imaging", + "Identification of left cochlea in brain MR", + "Localization of left cochlea in brain MR" + ], + "instance_label": 0 + }, + "MR_BraTS-T1n": { + "1": [ + "Non-enhancing tumor core", + "Non-enhancing tumor core in head MR naive T1", + "MR naive T1 imaging of non-enhancing tumor core in head", + "Visualization of non-enhancing tumor core in head MR naive T1", + "Non-enhancing tumor core observed in head MR naive T1", + "Segmentation of non-enhancing tumor core in head MR naive T1", + "Delineation of non-enhancing tumor core in head MR naive T1 imaging", + "Identification of non-enhancing tumor core in head MR naive T1", + "Localization of non-enhancing tumor core in head MR naive T1", + "Characterization of non-enhancing tumor core in head MR naive T1" + ], + "2": [ + "Surrounding non-enhancing FLAIR hyperintensity", + "Surrounding non-enhancing FLAIR hyperintensity in head MR naive T1", + "MR naive T1 imaging of surrounding non-enhancing FLAIR hyperintensity in head", + "Visualization of surrounding non-enhancing FLAIR hyperintensity in head MR naive T1", + "Surrounding non-enhancing FLAIR hyperintensity observed in head MR naive T1", + "Segmentation of surrounding non-enhancing FLAIR hyperintensity in head MR naive T1", + "Delineation of surrounding non-enhancing 
FLAIR hyperintensity in head MR naive T1 imaging", + "Identification of surrounding non-enhancing FLAIR hyperintensity in head MR naive T1", + "Localization of surrounding non-enhancing FLAIR hyperintensity in head MR naive T1", + "Characterization of surrounding non-enhancing FLAIR hyperintensity in head MR naive T1" + ], + "3": [ + "Enhancing tissue", + "Enhancing tissue in head MR naive T1", + "MR naive T1 imaging of enhancing tissue in head", + "Visualization of enhancing tissue in head MR naive T1", + "Enhancing tissue observed in head MR naive T1", + "Segmentation of enhancing tissue in head MR naive T1", + "Delineation of enhancing tissue in head MR naive T1 imaging", + "Identification of enhancing tissue in head MR naive T1", + "Localization of enhancing tissue in head MR naive T1", + "Characterization of enhancing tissue in head MR naive T1" + ], + "4": [ + "Resection cavity", + "Resection cavity in head MR naive T1", + "MR naive T1 imaging of resection cavity in head", + "Visualization of resection cavity in head MR naive T1", + "Resection cavity observed in head MR naive T1", + "Segmentation of resection cavity in head MR naive T1", + "Delineation of resection cavity in head MR naive T1 imaging", + "Identification of resection cavity in head MR naive T1", + "Localization of resection cavity in head MR naive T1" + ], + "instance_label": 0 + }, + "MR_BraTS-T1c": { + "1": [ + "Non-enhancing tumor core", + "Non-enhancing tumor core in head post-contrast T1-weighted MR", + "Post-contrast T1-weighted MR imaging of non-enhancing tumor core in head", + "Visualization of non-enhancing tumor core in head post-contrast T1-weighted MR", + "Non-enhancing tumor core observed in head post-contrast T1-weighted MR", + "Segmentation of non-enhancing tumor core in head post-contrast T1-weighted MR", + "Delineation of non-enhancing tumor core in head post-contrast T1-weighted MR imaging", + "Identification of non-enhancing tumor core in head post-contrast T1-weighted MR", + 
"Localization of non-enhancing tumor core in head post-contrast T1-weighted MR", + "Characterization of non-enhancing tumor core in head post-contrast T1-weighted MR" + ], + "2": [ + "Surrounding non-enhancing FLAIR hyperintensity", + "Surrounding non-enhancing FLAIR hyperintensity in head post-contrast T1-weighted MR", + "Post-contrast T1-weighted MR imaging of surrounding non-enhancing FLAIR hyperintensity in head", + "Visualization of surrounding non-enhancing FLAIR hyperintensity in head post-contrast T1-weighted MR", + "Surrounding non-enhancing FLAIR hyperintensity observed in head post-contrast T1-weighted MR", + "Segmentation of surrounding non-enhancing FLAIR hyperintensity in head post-contrast T1-weighted MR", + "Delineation of surrounding non-enhancing FLAIR hyperintensity in head post-contrast T1-weighted MR imaging", + "Identification of surrounding non-enhancing FLAIR hyperintensity in head post-contrast T1-weighted MR", + "Localization of surrounding non-enhancing FLAIR hyperintensity in head post-contrast T1-weighted MR", + "Characterization of surrounding non-enhancing FLAIR hyperintensity in head post-contrast T1-weighted MR" + ], + "3": [ + "Enhancing tissue", + "Enhancing tissue in head post-contrast T1-weighted MR", + "Post-contrast T1-weighted MR imaging of enhancing tissue in head", + "Visualization of enhancing tissue in head post-contrast T1-weighted MR", + "Enhancing tissue observed in head post-contrast T1-weighted MR", + "Segmentation of enhancing tissue in head post-contrast T1-weighted MR", + "Delineation of enhancing tissue in head post-contrast T1-weighted MR imaging", + "Identification of enhancing tissue in head post-contrast T1-weighted MR", + "Localization of enhancing tissue in head post-contrast T1-weighted MR", + "Characterization of enhancing tissue in head post-contrast T1-weighted MR" + ], + "4": [ + "Resection cavity", + "Resection cavity in head post-contrast T1-weighted MR", + "Post-contrast T1-weighted MR imaging of 
resection cavity in head", + "Visualization of resection cavity in head post-contrast T1-weighted MR", + "Resection cavity observed in head post-contrast T1-weighted MR", + "Segmentation of resection cavity in head post-contrast T1-weighted MR", + "Delineation of resection cavity in head post-contrast T1-weighted MR imaging", + "Identification of resection cavity in head post-contrast T1-weighted MR", + "Localization of resection cavity in head post-contrast T1-weighted MR" + ], + "instance_label": 0 + }, + "MR_BraTS-T2f": { + "1": [ + "Non-enhancing tumor core", + "Non-enhancing tumor core in head MR T2 FLAIR", + "MR T2 FLAIR imaging of non-enhancing tumor core in head", + "Visualization of non-enhancing tumor core in head MR T2 FLAIR", + "Non-enhancing tumor core observed in head MR T2 FLAIR", + "Segmentation of non-enhancing tumor core in head MR T2 FLAIR", + "Delineation of non-enhancing tumor core in head MR T2 FLAIR imaging", + "Identification of non-enhancing tumor core in head MR T2 FLAIR", + "Localization of non-enhancing tumor core in head MR T2 FLAIR", + "Characterization of non-enhancing tumor core in head MR T2 FLAIR" + ], + "2": [ + "Surrounding non-enhancing FLAIR hyperintensity", + "Surrounding non-enhancing FLAIR hyperintensity in head MR T2 FLAIR", + "MR T2 FLAIR imaging of surrounding non-enhancing FLAIR hyperintensity in head", + "Visualization of surrounding non-enhancing FLAIR hyperintensity in head MR T2 FLAIR", + "Surrounding non-enhancing FLAIR hyperintensity observed in head MR T2 FLAIR", + "Segmentation of surrounding non-enhancing FLAIR hyperintensity in head MR T2 FLAIR", + "Delineation of surrounding non-enhancing FLAIR hyperintensity in head MR T2 FLAIR imaging", + "Identification of surrounding non-enhancing FLAIR hyperintensity in head MR T2 FLAIR", + "Localization of surrounding non-enhancing FLAIR hyperintensity in head MR T2 FLAIR", + "Characterization of surrounding non-enhancing FLAIR hyperintensity in head MR T2 FLAIR" + ], + 
"3": [ + "Enhancing tissue", + "Enhancing tissue in head MR T2 FLAIR", + "MR T2 FLAIR imaging of enhancing tissue in head", + "Visualization of enhancing tissue in head MR T2 FLAIR", + "Enhancing tissue observed in head MR T2 FLAIR", + "Segmentation of enhancing tissue in head MR T2 FLAIR", + "Delineation of enhancing tissue in head MR T2 FLAIR imaging", + "Identification of enhancing tissue in head MR T2 FLAIR", + "Localization of enhancing tissue in head MR T2 FLAIR", + "Characterization of enhancing tissue in head MR T2 FLAIR" + ], + "4": [ + "Resection cavity", + "Resection cavity in head MR T2 FLAIR", + "MR T2 FLAIR imaging of resection cavity in head", + "Visualization of resection cavity in head MR T2 FLAIR", + "Resection cavity observed in head MR T2 FLAIR", + "Segmentation of resection cavity in head MR T2 FLAIR", + "Delineation of resection cavity in head MR T2 FLAIR imaging", + "Identification of resection cavity in head MR T2 FLAIR", + "Localization of resection cavity in head MR T2 FLAIR" + ], + "instance_label": 0 + }, + "MR_BraTS-T2w": { + "1": [ + "Non-enhancing tumor core", + "Non-enhancing tumor core in head T2 weighted MR", + "T2 weighted MR imaging of non-enhancing tumor core in head", + "Visualization of non-enhancing tumor core in head T2 weighted MR", + "Non-enhancing tumor core observed in head T2 weighted MR", + "Segmentation of non-enhancing tumor core in head T2 weighted MR", + "Delineation of non-enhancing tumor core in head T2 weighted MR imaging", + "Identification of non-enhancing tumor core in head T2 weighted MR", + "Localization of non-enhancing tumor core in head T2 weighted MR", + "Characterization of non-enhancing tumor core in head T2 weighted MR" + ], + "2": [ + "Surrounding non-enhancing FLAIR hyperintensity", + "Surrounding non-enhancing FLAIR hyperintensity in head T2 weighted MR", + "T2 weighted MR imaging of surrounding non-enhancing FLAIR hyperintensity in head", + "Visualization of surrounding non-enhancing FLAIR 
hyperintensity in head T2 weighted MR", + "Surrounding non-enhancing FLAIR hyperintensity observed in head T2 weighted MR", + "Segmentation of surrounding non-enhancing FLAIR hyperintensity in head T2 weighted MR", + "Delineation of surrounding non-enhancing FLAIR hyperintensity in head T2 weighted MR imaging", + "Identification of surrounding non-enhancing FLAIR hyperintensity in head T2 weighted MR", + "Localization of surrounding non-enhancing FLAIR hyperintensity in head T2 weighted MR", + "Characterization of surrounding non-enhancing FLAIR hyperintensity in head T2 weighted MR" + ], + "3": [ + "Enhancing tissue", + "Enhancing tissue in head T2 weighted MR", + "T2 weighted MR imaging of enhancing tissue in head", + "Visualization of enhancing tissue in head T2 weighted MR", + "Enhancing tissue observed in head T2 weighted MR", + "Segmentation of enhancing tissue in head T2 weighted MR", + "Delineation of enhancing tissue in head T2 weighted MR imaging", + "Identification of enhancing tissue in head T2 weighted MR", + "Localization of enhancing tissue in head T2 weighted MR", + "Characterization of enhancing tissue in head T2 weighted MR" + ], + "4": [ + "Resection cavity", + "Resection cavity in head T2 weighted MR", + "T2 weighted MR imaging of resection cavity in head", + "Visualization of resection cavity in head T2 weighted MR", + "Resection cavity observed in head T2 weighted MR", + "Segmentation of resection cavity in head T2 weighted MR", + "Delineation of resection cavity in head T2 weighted MR imaging", + "Identification of resection cavity in head T2 weighted MR", + "Localization of resection cavity in head T2 weighted MR" + ], + "instance_label": 0 + }, + "MR_TopCoW24": { + "1": [ + "Circle of Willis", + "Circle of Willis in Brain Magnetic Resonance Angiography", + "Magnetic Resonance Angiography imaging of Circle of Willis in Brain", + "Visualization of Circle of Willis in Brain Magnetic Resonance Angiography", + "Circle of Willis observed in Brain 
Magnetic Resonance Angiography", + "Segmentation of Circle of Willis in Brain Magnetic Resonance Angiography", + "Delineation of Circle of Willis in Brain Magnetic Resonance Angiography imaging", + "Identification of Circle of Willis vascular architecture in Brain Magnetic Resonance Angiography", + "Localization of Circle of Willis branches in Brain Magnetic Resonance Angiography", + "Characterization of Circle of Willis flow dynamics in Brain Magnetic Resonance Angiography" + ], + "instance_label": 1 + }, + "MR_SegThy": { + "1": [ + "Thyroid", + "Thyroid in Neck MR", + "MR imaging of thyroid in Neck", + "Visualization of thyroid in Neck MR", + "Thyroid observed in Neck MR", + "Segmentation of thyroid in Neck MR", + "Delineation of thyroid in Neck MR imaging", + "Identification of thyroid in Neck MR", + "Localization of thyroid in Neck MR", + "Characterization of thyroid in Neck MR" + ], + "2": [ + "Left carotid artery", + "Left carotid artery in Neck MR", + "MR imaging of left carotid artery in Neck", + "Visualization of left carotid artery in Neck MR", + "Left carotid artery observed in Neck MR", + "Segmentation of left carotid artery in Neck MR", + "Delineation of left carotid artery in Neck MR imaging", + "Identification of left carotid artery in Neck MR", + "Localization of left carotid artery in Neck MR" + ], + "3": [ + "Left jugular vein", + "Left jugular vein in Neck MR", + "MR imaging of left jugular vein in Neck", + "Visualization of left jugular vein in Neck MR", + "Left jugular vein observed in Neck MR", + "Segmentation of left jugular vein in Neck MR", + "Delineation of left jugular vein in Neck MR imaging", + "Identification of left jugular vein in Neck MR", + "Localization of left jugular vein in Neck MR" + ], + "4": [ + "Right carotid artery", + "Right carotid artery in Neck MR", + "MR imaging of right carotid artery in Neck", + "Visualization of right carotid artery in Neck MR", + "Right carotid artery observed in Neck MR", + "Segmentation of right 
carotid artery in Neck MR", + "Delineation of right carotid artery in Neck MR imaging", + "Identification of right carotid artery in Neck MR", + "Localization of right carotid artery in Neck MR", + "Vascular architecture of right carotid artery in Neck MR" + ], + "5": [ + "Right jugular vein", + "Right jugular vein in Neck MR", + "MR imaging of right jugular vein in Neck", + "Visualization of right jugular vein in Neck MR", + "Right jugular vein observed in Neck MR", + "Segmentation of right jugular vein in Neck MR", + "Delineation of right jugular vein in Neck MR imaging", + "Identification of right jugular vein in Neck MR", + "Localization of right jugular vein in Neck MR", + "Venous structure of right jugular vein in Neck MR" + ], + "instance_label": 0 + }, + "MR_HVSMR": { + "1": [ + "Left Ventricle", + "Left ventricle in Heart MR", + "MR imaging of left ventricle in Heart", + "Visualization of left ventricle in Heart MR", + "Left ventricle observed in Heart MR", + "Segmentation of left ventricle in Heart MR", + "Delineation of left ventricle in Heart MR imaging", + "Identification of left ventricle in Heart MR", + "Localization of left ventricle in Heart MR" + ], + "2": [ + "Right Ventricle", + "Right ventricle in Heart MR", + "MR imaging of right ventricle in Heart", + "Visualization of right ventricle in Heart MR", + "Right ventricle observed in Heart MR", + "Segmentation of right ventricle in Heart MR", + "Delineation of right ventricle in Heart MR imaging", + "Identification of right ventricle in Heart MR", + "Localization of right ventricle in Heart MR" + ], + "3": [ + "Left Atrium", + "Left atrium in Heart MR", + "MR imaging of left atrium in Heart", + "Visualization of left atrium in Heart MR", + "Left atrium observed in Heart MR", + "Segmentation of left atrium in Heart MR", + "Delineation of left atrium in Heart MR imaging", + "Identification of left atrium in Heart MR", + "Localization of left atrium in Heart MR" + ], + "4": [ + "Right Atrium", + 
"Right atrium in Heart MR", + "MR imaging of right atrium in Heart", + "Visualization of right atrium in Heart MR", + "Right atrium observed in Heart MR", + "Segmentation of right atrium in Heart MR", + "Delineation of right atrium in Heart MR imaging", + "Identification of right atrium in Heart MR", + "Localization of right atrium in Heart MR", + "Right atrial chamber in Heart MR" + ], + "5": [ + "Aorta", + "Aorta in Heart MR", + "MR imaging of aorta in Heart", + "Visualization of aorta in Heart MR", + "Aorta observed in Heart MR", + "Segmentation of aorta in Heart MR", + "Delineation of aorta in Heart MR imaging", + "Identification of aortic root in Heart MR", + "Localization of aorta in Heart MR" + ], + "6": [ + "Pulmonary Artery", + "Pulmonary artery in Heart MR", + "MR imaging of pulmonary artery in Heart", + "Visualization of pulmonary artery in Heart MR", + "Pulmonary artery observed in Heart MR", + "Segmentation of pulmonary artery in Heart MR", + "Delineation of pulmonary artery in Heart MR imaging", + "Identification of pulmonary artery in Heart MR", + "Localization of pulmonary artery in Heart MR" + ], + "7": [ + "Superior vena cava", + "Superior vena cava in Heart MR", + "MR imaging of superior vena cava in Heart", + "Visualization of superior vena cava in Heart MR", + "Superior vena cava observed in Heart MR", + "Segmentation of superior vena cava in Heart MR", + "Delineation of superior vena cava in Heart MR imaging", + "Identification of superior vena cava in Heart MR", + "Localization of superior vena cava in Heart MR" + ], + "8": [ + "Inferior vena cava", + "Inferior vena cava in Heart MR", + "MR imaging of inferior vena cava in Heart", + "Visualization of inferior vena cava in Heart MR", + "Inferior vena cava observed in Heart MR", + "Segmentation of inferior vena cava in Heart MR", + "Delineation of inferior vena cava in Heart MR imaging", + "Identification of inferior vena cava in Heart MR", + "Localization of inferior vena cava in Heart MR", + 
"Inferior vena cava hepatic segment in Heart MR" + ], + "instance_label": 0 + }, + "MR_IBD": { + "1": [ + "Stomach", + "Stomach in abdomen Magnetic Resonance Enterography", + "Magnetic Resonance Enterography imaging of stomach in abdomen", + "Visualization of stomach in abdomen Magnetic Resonance Enterography", + "Stomach observed in abdomen Magnetic Resonance Enterography", + "Segmentation of stomach in abdomen Magnetic Resonance Enterography", + "Delineation of stomach in abdomen Magnetic Resonance Enterography imaging", + "Identification of stomach in abdomen Magnetic Resonance Enterography", + "Localization of stomach in abdomen Magnetic Resonance Enterography" + ], + "2": [ + "Duodenum", + "Duodenum in abdomen Magnetic Resonance Enterography", + "Magnetic Resonance Enterography imaging of duodenum in abdomen", + "Visualization of duodenum in abdomen Magnetic Resonance Enterography", + "Duodenum observed in abdomen Magnetic Resonance Enterography", + "Segmentation of duodenum in abdomen Magnetic Resonance Enterography", + "Delineation of duodenal folds in abdomen Magnetic Resonance Enterography imaging", + "Identification of duodenum in abdomen Magnetic Resonance Enterography" + ], + "3": [ + "Small Intestine", + "Small intestine in abdomen Magnetic Resonance Enterography", + "Magnetic Resonance Enterography imaging of small intestine in abdomen", + "Visualization of small intestine in abdomen Magnetic Resonance Enterography", + "Small intestine observed in abdomen Magnetic Resonance Enterography", + "Segmentation of small intestine in abdomen Magnetic Resonance Enterography", + "Delineation of small bowel loops in abdomen Magnetic Resonance Enterography imaging", + "Identification of small intestine in abdomen Magnetic Resonance Enterography" + ], + "4": [ + "Appendix", + "Appendix in abdomen Magnetic Resonance Enterography", + "Magnetic Resonance Enterography imaging of appendix in abdomen", + "Visualization of appendix in abdomen Magnetic Resonance 
Enterography", + "Appendix observed in abdomen Magnetic Resonance Enterography", + "Segmentation of appendix in abdomen Magnetic Resonance Enterography", + "Delineation of appendiceal lumen in abdomen Magnetic Resonance Enterography imaging", + "Identification of appendix in abdomen Magnetic Resonance Enterography", + "Localization of the appendix in abdomen Magnetic Resonance Enterography" + ], + "5": [ + "Cecum", + "Cecum in abdomen Magnetic Resonance Enterography", + "Magnetic Resonance Enterography imaging of cecum in abdomen", + "Visualization of cecum in abdomen Magnetic Resonance Enterography", + "Cecum observed in abdomen Magnetic Resonance Enterography", + "Segmentation of cecum in abdomen Magnetic Resonance Enterography", + "Delineation of cecal folds in abdomen Magnetic Resonance Enterography imaging", + "Identification of cecum in abdomen Magnetic Resonance Enterography" + ], + "6": [ + "Ascending colon", + "Ascending colon in abdomen Magnetic Resonance Enterography", + "Magnetic Resonance Enterography imaging of ascending colon in abdomen", + "Visualization of ascending colon in abdomen Magnetic Resonance Enterography", + "Ascending colon observed in abdomen Magnetic Resonance Enterography", + "Segmentation of ascending colon in abdomen Magnetic Resonance Enterography", + "Delineation of ascending colon haustra in abdomen Magnetic Resonance Enterography imaging", + "Identification of ascending colon in abdomen Magnetic Resonance Enterography" + ], + "7": [ + "Transverse colon", + "Transverse colon in abdomen Magnetic Resonance Enterography", + "Magnetic Resonance Enterography imaging of transverse colon in abdomen", + "Visualization of transverse colon in abdomen Magnetic Resonance Enterography", + "Transverse colon observed in abdomen Magnetic Resonance Enterography", + "Segmentation of transverse colon in abdomen Magnetic Resonance Enterography", + "Delineation of transverse colon taenia coli in abdomen Magnetic Resonance Enterography imaging", + "Identification of transverse 
colon in abdomen Magnetic Resonance Enterography", + "Localization of transverse colon in abdomen Magnetic Resonance Enterography", + "Morphology of the transverse colon in abdomen Magnetic Resonance Enterography" + ], + "8": [ + "Descending colon", + "Descending colon in abdomen Magnetic Resonance Enterography", + "Magnetic Resonance Enterography imaging of descending colon in abdomen", + "Visualization of descending colon in abdomen Magnetic Resonance Enterography", + "Descending colon observed in abdomen Magnetic Resonance Enterography", + "Segmentation of descending colon in abdomen Magnetic Resonance Enterography", + "Delineation of descending colon in abdomen Magnetic Resonance Enterography imaging", + "Identification of descending colon in abdomen Magnetic Resonance Enterography", + "Localization of descending colon in abdomen Magnetic Resonance Enterography" + ], + "9": [ + "Sigmoid colon", + "Sigmoid colon in abdomen Magnetic Resonance Enterography", + "Magnetic Resonance Enterography imaging of sigmoid colon in abdomen", + "Visualization of sigmoid colon in abdomen Magnetic Resonance Enterography", + "Sigmoid colon observed in abdomen Magnetic Resonance Enterography", + "Segmentation of sigmoid colon in abdomen Magnetic Resonance Enterography", + "Delineation of sigmoid colon tortuosity in abdomen Magnetic Resonance Enterography imaging", + "Identification of sigmoid colon in abdomen Magnetic Resonance Enterography" + ], + "10": [ + "Rectum", + "Rectum in abdomen Magnetic Resonance Enterography", + "Magnetic Resonance Enterography imaging of rectum in abdomen", + "Visualization of rectum in abdomen Magnetic Resonance Enterography", + "Rectum observed in abdomen Magnetic Resonance Enterography", + "Segmentation of rectum in abdomen Magnetic Resonance Enterography", + "Delineation of rectum in abdomen Magnetic Resonance Enterography imaging", + "Identification of rectum in abdomen Magnetic Resonance Enterography" + ], + "instance_label": 0 + }, + "MR_RESECT": { + "1": [ + 
"Brain tumor", + "Brain tumor segmented in brain MR imaging", + "Brain tumor detected within cerebral tissue via MRI", + "Brain tumor visualized in intracranial magnetic resonance imaging", + "Brain tumor identified in brain MRI", + "Brain tumor resolved in cerebral MR scans", + "Brain tumor localized in brain using MR", + "Brain tumor appearing in brain magnetic resonance imaging", + "Brain tumor delineated in cerebral tissue with MR imaging", + "Brain tumor observed in brain via MRI" + ], + "instance_label": 1 + }, + "US_Low-limb-Leg": { + "1": [ + "Soleus", + "Soleus in low-limb leg Ultrasound", + "Ultrasound imaging of the Soleus in the low-limb leg region", + "Low-limb leg Ultrasound showing Soleus structures", + "Visualization of the Soleus in low-limb leg Ultrasound scans", + "Soleus observed in low-limb leg Ultrasound imaging", + "Ultrasound scan of the low-limb leg depicting the Soleus", + "Soleus segmentation in low-limb leg Ultrasound", + "Soleus delineation in low-limb leg Ultrasound imaging", + "Ultrasound-based localization of the Soleus in the low-limb leg" + ], + "2": [ + "Gastrocnemius Medialis", + "Gastrocnemius Medialis in low-limb leg Ultrasound", + "Ultrasound imaging of the Gastrocnemius Medialis in the low-limb leg region", + "Low-limb leg Ultrasound showing Gastrocnemius Medialis structures", + "Visualization of the Gastrocnemius Medialis in low-limb leg Ultrasound scans", + "Gastrocnemius Medialis observed in low-limb leg Ultrasound imaging", + "Ultrasound scan of the low-limb leg depicting the Gastrocnemius Medialis", + "Gastrocnemius Medialis segmentation in low-limb leg Ultrasound", + "Gastrocnemius Medialis delineation in low-limb leg Ultrasound imaging", + "Ultrasound-based localization of the Gastrocnemius Medialis in the low-limb leg" + ], + "3": [ + "Gastrocnemius Lateralis", + "Gastrocnemius Lateralis in low-limb leg Ultrasound", + "Ultrasound imaging of the Gastrocnemius Lateralis in the low-limb leg region", + "Low-limb leg 
Ultrasound showing Gastrocnemius Lateralis structures", + "Visualization of the Gastrocnemius Lateralis in low-limb leg Ultrasound scans", + "Gastrocnemius Lateralis observed in low-limb leg Ultrasound imaging", + "Ultrasound scan of the low-limb leg highlighting the Gastrocnemius Lateralis", + "Gastrocnemius Lateralis segmentation in low-limb leg Ultrasound", + "Gastrocnemius Lateralis delineation in low-limb leg Ultrasound imaging", + "Ultrasound-based identification of the Gastrocnemius Lateralis in the low-limb leg" + ], + "instance_label": 0 + }, + "US_Cardiac": { + "1": [ + "Left Ventricle", + "Left ventricle in heart Ultrasound", + "Echocardiography imaging of the left ventricle in the cardiac region", + "Visualization of the left ventricle in cardiac Ultrasound scans", + "Left ventricle observed in heart echocardiography", + "Echocardiographic segmentation of the left ventricle", + "Cardiac Ultrasound depicting left ventricular chamber" + ], + "2": [ + "Myocardium", + "Myocardium in heart Ultrasound", + "Echocardiography imaging of the myocardial tissue in the cardiac region", + "Visualization of the myocardium in cardiac Ultrasound scans", + "Myocardium observed in heart echocardiography", + "Echocardiographic delineation of myocardium" + ], + "3": [ + "Left Atrium", + "Left atrium in heart Ultrasound", + "Echocardiography imaging of the left atrium in the cardiac region", + "Visualization of the left atrium in cardiac Ultrasound scans", + "Left atrial chamber observed in heart echocardiography", + "Echocardiographic segmentation of the left atrial cavity", + "Ultrasound depicting left atrium" + ], + "instance_label": 0 + }, + "US_SegThy": { + "1": [ + "Thyroid", + "Thyroid in Neck Ultrasound", + "Ultrasound imaging of thyroid in Neck", + "Visualization of thyroid in Neck Ultrasound", + "thyroid observed in Neck Ultrasound", + "Segmentation of thyroid in Neck Ultrasound", + "Delineation of thyroid in Neck Ultrasound imaging", + "Identification of thyroid 
in Neck Ultrasound", + "Localization of thyroid in Neck Ultrasound", + "Morphology of thyroid in Neck Ultrasound" + ], + "2": [ + "Carotid artery", + "Carotid artery in Neck Ultrasound", + "Ultrasound imaging of carotid artery in Neck", + "Visualization of carotid artery in Neck Ultrasound", + "Carotid artery observed in Neck Ultrasound", + "Segmentation of carotid artery in Neck Ultrasound", + "Delineation of carotid artery in Neck Ultrasound imaging", + "Identification of carotid artery in Neck Ultrasound", + "Localization of carotid artery in Neck Ultrasound", + "Vascular architecture of carotid artery in Neck Ultrasound" + ], + "3": [ + "Jugular vein", + "Jugular vein in Neck Ultrasound", + "Ultrasound imaging of jugular vein in Neck", + "Visualization of jugular vein in Neck Ultrasound", + "Jugular vein observed in Neck Ultrasound", + "Segmentation of jugular vein in Neck Ultrasound", + "Delineation of jugular vein in Neck Ultrasound imaging", + "Identification of jugular vein in Neck Ultrasound", + "Localization of jugular vein in Neck Ultrasound", + "Venous structure of jugular vein in Neck Ultrasound" + ], + "instance_label": 0 + }, + "US_RESECT": { + "1": [ + "Brain tumor", + "Brain tumor detected in brain via transcranial ultrasound", + "Brain tumor visualized in cerebral tissue using ultrasound imaging", + "Brain tumor identified in intracranial region through ultrasound", + "Brain tumor localized in brain under ultrasound examination", + "Brain tumor imaged in brain through transcranial ultrasonic imaging", + "Brain tumor observed in cerebral ultrasound", + "Brain tumor appearing in intracranial ultrasound scans", + "Brain tumor delineated in cerebral tissue with transcranial ultrasound", + "Brain tumor captured in brain via ultrasound" + ], + "instance_label": 1 + }, + "autoPET": { + "1": [ + "Lesion", + "Lesion in whole body PET", + "PET imaging of the lesion in the whole body region", + "Visualization of the lesion in whole body PET scans", + "Lesion 
observed in whole body PET imaging", + "Lesion segmentation in whole body PET", + "Lesion delineation in whole body PET imaging", + "PET-based localization of the lesion in the whole body", + "Detection of lesion metabolic activity via whole body PET", + "Quantitative analysis of lesion FDG uptake in whole body PET" + ], + "instance_label": 1 + }, + "Microscopy_urocell_Endolysosomes": { + "1": [ + "Endolysosomes", + "Endolysosomes in Urinary bladders Electron Microscopy", + "Electron Microscopy imaging of endolysosomes in urinary bladder tissue", + "Visualization of endolysosomal structures in Urinary bladders via Electron Microscopy", + "Endolysosomes observed in Urinary bladder ultrastructural analysis", + "Electron Microscopy-based segmentation of endolysosomes in Urinary bladders", + "Delineation of endolysosomal in Urinary bladders using Electron Microscopy" + ], + "instance_label": 1 + }, + "Microscopy_urocell_Mitochondria": { + "1": [ + "Mitochondria", + "Mitochondria in Urinary bladders Electron microscopy", + "Electron microscopy imaging of Mitochondria in Urinary bladders", + "Visualization of Mitochondria in Urinary bladders Electron microscopy", + "Mitochondria observed in Urinary bladders Electron microscopy", + "Mitochondria segmentation in Urinary bladders Electron microscopy", + "Mitochondria delineation in Urinary bladders Electron microscopy imaging", + "Mitochondria identification in Urinary bladders Electron microscopy", + "Mitochondria localization in Urinary bladders Electron microscopy" + ], + "instance_label": 1 + }, + "Microscopy_cremi": { + "1": [ + "Synaptic clefts", + "Synaptic clefts in Drosophila brain Electron microscopy", + "Electron microscopy imaging of Synaptic clefts in Drosophila brain", + "Visualization of Synaptic clefts in Drosophila brain Electron microscopy", + "Synaptic clefts observed in Drosophila brain Electron microscopy", + "Synaptic clefts segmentation in Drosophila brain Electron microscopy", + "Synaptic clefts 
delineation in Drosophila brain Electron microscopy imaging", + "Synaptic clefts identification in Drosophila brain Electron microscopy", + "Synaptic clefts localization in Drosophila brain Electron microscopy", + "Synaptic clefts ultrastructure in Drosophila brain Electron microscopy", + "Synaptic clefts architecture in Drosophila brain Electron microscopy" + ], + "instance_label": 1 + }, + "Microscopy_nucmm": { + "1": [ + "Nuclei", + "Nuclei in zebrafish brain Electron microscopy", + "Electron microscopy imaging of Nuclei in zebrafish brain", + "Visualization of Nuclei in zebrafish brain Electron microscopy", + "Nuclei observed in zebrafish brain Electron microscopy", + "Nuclei segmentation in zebrafish brain Electron microscopy", + "Nuclei delineation in zebrafish brain Electron microscopy imaging", + "Nuclei identification in zebrafish brain Electron microscopy", + "Nuclei localization in zebrafish brain Electron microscopy" + ], + "instance_label": 1 + }, + "Microscopy_SELMA3D_neural_activity_marker": { + "1": [ + "Neural activity", + "Neural activity in brain Light-sheet Microscopy", + "Light-sheet Microscopy imaging of Neural activity in brain", + "Visualization of Neural activity in brain Light-sheet Microscopy", + "Neural activity observed in brain Light-sheet Microscopy", + "Neural activity delineation in brain Light-sheet Microscopy imaging", + "Neural activity identification in brain Light-sheet Microscopy", + "Neural activity localization in brain Light-sheet Microscopy" + ], + "instance_label": 1 + }, + "Microscopy_SELMA3D_ADplaques": { + "1": [ + "Alzheimer's disease plaque", + "Alzheimer's disease plaque in brain Light-sheet Microscopy", + "Light-sheet Microscopy imaging of Alzheimer's disease plaque in brain", + "Visualization of Alzheimer's disease plaque in brain Light-sheet Microscopy", + "Alzheimer's disease plaque observed in brain Light-sheet Microscopy", + "Alzheimer's disease plaque delineation in brain Light-sheet Microscopy imaging", + 
"Alzheimer's disease plaque identification in brain Light-sheet Microscopy", + "Alzheimer's disease plaque localization in brain Light-sheet Microscopy" + ], + "instance_label": 1 + }, + "Microscopy_SELMA3D_nuceus": { + "1": [ + "Nucleus", + "Nucleus in brain Light-sheet Microscopy", + "Light-sheet Microscopy imaging of Nucleus in brain", + "Visualization of Nucleus in brain Light-sheet Microscopy", + "Nucleus observed in brain Light-sheet Microscopy", + "Nucleus segmentation in brain Light-sheet Microscopy", + "Nucleus delineation in brain Light-sheet Microscopy imaging", + "Nucleus identification in brain Light-sheet Microscopy", + "Nucleus localization in brain Light-sheet Microscopy" + ], + "instance_label": 1 + }, + "Microscopy_SELMA3D_vessel": { + "1": [ + "Vessel", + "Vessel in brain Light-sheet Microscopy", + "Light-sheet Microscopy imaging of Vessel in brain", + "Visualization of Vessel in brain Light-sheet Microscopy", + "Vessel observed in brain Light-sheet Microscopy", + "Vessel segmentation in brain Light-sheet Microscopy", + "Vessel delineation in brain Light-sheet Microscopy imaging", + "Vessel identification in brain Light-sheet Microscopy", + "Vessel localization in brain Light-sheet Microscopy" + ], + "instance_label": 1 + }, + "Microscopy_MUCIC-HL60": { + "1": [ + "Cell nuclei", + "Microscopy imaging of Cell nuclei in HL60 cell line", + "Visualization of Cell nuclei in HL60 cell line microscopic images", + "Cell nuclei observed in HL60 cell line microscopic images", + "Segmentation of Cell nuclei in HL60 cell line microscopic images", + "Delineation of Cell nuclei in HL60 cell line Microscopy imaging", + "Identification of Cell nuclei in HL60 cell line microscopic images", + "Cell nuclei observed in HL60 cell line microscopy", + "Microscopic imaging of nuclei in the HL60 cell system", + "Cell nuclei imaged in HL60 cell line through microscopy", + "Cell nuclei identified in HL60 cells during microscopic examination" + ], + "instance_label": 1 + }, 
+ "Microscopy_AxonEM": { + "1": [ + "Axon", + "Axon observed in cortical tissue via electron microscopy", + "Axon visualized in cortex using EM imaging", + "Axon detected within cortical samples under electron microscopy", + "Axon imaged in cortex through ultrastructural EM analysis", + "Axon identified in cortical tissue with electron microscopy", + "Axon present in cortex under electron microscopic observation" + ], + "instance_label": 1 + }, + "Microscopy_NIS3D": { + "1": [ + "Nuclei", + "Nuclei imaged in embryonic cells via light sheet microscopy", + "Nuclei visualized in embryonic tissue using microscopy", + "Nuclei detected in embryonic specimens with light sheet imaging", + "Nuclei observed in embryonic cells under LSM (Light Sheet Microscopy)", + "Nuclei appearing in embryonic cells during light sheet volumetric imaging", + "Nuclei captured in embryonic specimens using light sheet microscopy", + "Nuclei identified in embryonic tissue through microscopy" + ], + "instance_label": 1 + } +} \ No newline at end of file diff --git a/config_CT.json b/config_CT.json new file mode 100644 index 0000000000000000000000000000000000000000..f8b61bf5248aae5898667e96da0c9de8a76d40e6 --- /dev/null +++ b/config_CT.json @@ -0,0 +1,93 @@ +{ + "texts_soft_tissue": [ + "Aorta in whole body CT", + "gallbladder in whole body CT", + "left kidney in whole body CT", + "right kidney in whole body CT", + "liver in whole body CT", + "Pancreas in whole body CT", + "Spleen in whole body CT", + "stomach in whole body CT", + "Left adrenal gland in whole body CT", + "right adrenal gland in whole body CT", + "Bladder in whole body CT", + "Esophagus in whole body CT", + "Heart in whole body CT", + "Pulmonary vein in whole body CT", + "Brachiocephalic trunk in whole body CT", + "Right subclavian artery in whole body CT", + "Left subclavian artery in whole body CT", + "Right common carotid artery in whole body CT", + "Left common carotid artery in whole body CT", + "Left brachiocephalic vein in 
whole body CT", + "Right brachiocephalic vein in whole body CT", + "Left atrial appendage in whole body CT", + "Superior vena cava in whole body CT", + "Inferior vena cava in whole body CT", + "Portal vein and splenic vein in whole body CT", + "Left iliac artery in whole body CT", + "Right iliac artery in whole body CT", + "Left iliac vena in whole body CT", + "Right iliac vena in whole body CT", + "Spinal cord in whole body CT", + "Left gluteus Maximus in whole body CT", + "Right gluteus Maximus in whole body CT", + "Left gluteus Medius in whole body CT", + "Right gluteus Medius in whole body CT", + "Left gluteus Minimus in whole body CT", + "Right gluteus Minimus in whole body CT", + "Left autochthon in whole body CT", + "Right autochthon in whole body CT", + "Left iliopsoas in whole body CT", + "Right iliopsoas in whole body CT" + ], + "texts_bone": [ + "Vertebrae C7 in whole body CT", + "Vertebrae C6 in whole body CT", + "Vertebrae C5 in whole body CT", + "Vertebrae C4 in whole body CT", + "Vertebrae C3 in whole body CT", + "Vertebrae C2 in whole body CT", + "Vertebrae C1 in whole body CT", + "Vertebrae T12 in whole body CT", + "Vertebrae T11 in whole body CT", + "Vertebrae T10 in whole body CT", + "Vertebrae T9 in whole body CT", + "Vertebrae T8 in whole body CT", + "Vertebrae T7 in whole body CT", + "Vertebrae T6 in whole body CT", + "Vertebrae T5 in whole body CT", + "Vertebrae T4 in whole body CT", + "Vertebrae T3 in whole body CT", + "Vertebrae T2 in whole body CT", + "Vertebrae T1 in whole body CT", + "Left humerus in whole body CT", + "Right humerus in whole body CT", + "Left clavicula in whole body CT", + "Right clavicula in whole body CT", + "Left femur in whole body CT", + "Right femur in whole body CT", + "Left hip in whole body CT", + "Right hip in whole body CT" + ], + "texts_lung": [ + "Left lung in whole body CT", + "Right lung in whole body CT" + ], + "window_settings": { + "soft_tissue": { + "window_level": 40, + "window_width": 400 + }, + 
"bone": { + "window_level": 500, + "window_width": 1500 + }, + "lung": { + "window_level": -600, + "window_width": 1500 + } + }, + "modality": "CT", + "instance_label": 0 +} diff --git a/config_nonCT.json b/config_nonCT.json new file mode 100644 index 0000000000000000000000000000000000000000..54f5eca3963dd9283221f704dd01ece60fb2cc18 --- /dev/null +++ b/config_nonCT.json @@ -0,0 +1,13 @@ +{ + "texts": [ + "Spleen in MRI" + ], + "normalization_settings": { + "percentile_lower": 0.5, + "percentile_upper": 99.5, + "preserve_zero": true + }, + "modality": "MRI", + "instance_label": 0 +} + diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/data/default_resampling.py b/data/default_resampling.py new file mode 100644 index 0000000000000000000000000000000000000000..057c11f8b1b24dd874c7ab4a39ff5bfe92fcd5a5 --- /dev/null +++ b/data/default_resampling.py @@ -0,0 +1,208 @@ +from collections import OrderedDict +from copy import deepcopy +from typing import Union, Tuple, List + +import numpy as np +import pandas as pd +import sklearn +import torch +from batchgenerators.augmentations.utils import resize_segmentation +from scipy.ndimage import map_coordinates +from skimage.transform import resize + +ANISO_THRESHOLD = 3 # determines when a sample is considered anisotropic (3 means that the spacing in the low + # resolution axis must be 3x as large as the next largest spacing) + +def get_do_separate_z(spacing: Union[Tuple[float, ...], List[float], np.ndarray], anisotropy_threshold=ANISO_THRESHOLD): + do_separate_z = (np.max(spacing) / np.min(spacing)) > anisotropy_threshold + return do_separate_z + + +def get_lowres_axis(new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]): + axis = np.where(max(new_spacing) / np.array(new_spacing) == 1)[0] # find which axis is anisotropic + return axis + + +def compute_new_shape(old_shape: Union[Tuple[int, ...], 
List[int], np.ndarray], + old_spacing: Union[Tuple[float, ...], List[float], np.ndarray], + new_spacing: Union[Tuple[float, ...], List[float], np.ndarray]) -> np.ndarray: + assert len(old_spacing) == len(old_shape) + assert len(old_shape) == len(new_spacing) + new_shape = np.array([int(round(i / j * k)) for i, j, k in zip(old_spacing, new_spacing, old_shape)]) + return new_shape + + + +def determine_do_sep_z_and_axis( + force_separate_z: bool, + current_spacing, + new_spacing, + separate_z_anisotropy_threshold: float = ANISO_THRESHOLD) -> Tuple[bool, Union[int, None]]: + if force_separate_z is not None: + do_separate_z = force_separate_z + if force_separate_z: + axis = get_lowres_axis(current_spacing) + else: + axis = None + else: + if get_do_separate_z(current_spacing, separate_z_anisotropy_threshold): + do_separate_z = True + axis = get_lowres_axis(current_spacing) + elif get_do_separate_z(new_spacing, separate_z_anisotropy_threshold): + do_separate_z = True + axis = get_lowres_axis(new_spacing) + else: + do_separate_z = False + axis = None + + if axis is not None: + if len(axis) == 3: + do_separate_z = False + axis = None + elif len(axis) == 2: + # this happens for spacings like (0.24, 1.25, 1.25) for example. 
In that case we do not want to resample + # separately in the out of plane axis + do_separate_z = False + axis = None + else: + axis = axis[0] + return do_separate_z, axis + + +def resample_data_or_seg_to_spacing(data: np.ndarray, + current_spacing: Union[Tuple[float, ...], List[float], np.ndarray], + new_spacing: Union[Tuple[float, ...], List[float], np.ndarray], + is_seg: bool = False, + order: int = 3, order_z: int = 0, + force_separate_z: Union[bool, None] = False, + separate_z_anisotropy_threshold: float = ANISO_THRESHOLD): + do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing, + separate_z_anisotropy_threshold) + + if data is not None: + assert data.ndim == 4, "data must be c x y z" + + shape = np.array(data.shape) + new_shape = compute_new_shape(shape[1:], current_spacing, new_spacing) + + data_reshaped = resample_data_or_seg(data, new_shape, is_seg, axis, order, do_separate_z, order_z=order_z) + return data_reshaped + + +def resample_data_or_seg_to_shape(data: Union[torch.Tensor, np.ndarray], + new_shape: Union[Tuple[int, ...], List[int], np.ndarray], + current_spacing: Union[Tuple[float, ...], List[float], np.ndarray], + new_spacing: Union[Tuple[float, ...], List[float], np.ndarray], + is_seg: bool = False, + order: int = 3, order_z: int = 0, + force_separate_z: Union[bool, None] = False, + separate_z_anisotropy_threshold: float = ANISO_THRESHOLD): + """ + needed for segmentation export. 
Stupid, I know + """ + if isinstance(data, torch.Tensor): + data = data.numpy() + + do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing, + separate_z_anisotropy_threshold) + + if data is not None: + assert data.ndim == 4, "data must be c x y z" + + data_reshaped = resample_data_or_seg(data, new_shape, is_seg, axis, order, do_separate_z, order_z=order_z) + return data_reshaped + + +def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], List[float], np.ndarray], + is_seg: bool = False, axis: Union[None, int] = None, order: int = 3, + do_separate_z: bool = False, order_z: int = 0, dtype_out = None): + """ + separate_z=True will resample with order 0 along z + :param data: + :param new_shape: + :param is_seg: + :param axis: + :param order: + :param do_separate_z: + :param order_z: only applies if do_separate_z is True + :return: + """ + assert data.ndim == 4, "data must be (c, x, y, z)" + assert len(new_shape) == data.ndim - 1 + + if is_seg: + resize_fn = resize_segmentation + kwargs = OrderedDict() + else: + resize_fn = resize + kwargs = {'mode': 'edge', 'anti_aliasing': False} + shape = np.array(data[0].shape) + new_shape = np.array(new_shape) + if dtype_out is None: + dtype_out = data.dtype + reshaped_final = np.zeros((data.shape[0], *new_shape), dtype=dtype_out) + if np.any(shape != new_shape): + data = data.astype(float, copy=False) + if do_separate_z: + # print("separate z, order in z is", order_z, "order inplane is", order) + assert axis is not None, 'If do_separate_z, we need to know what axis is anisotropic' + if axis == 0: + new_shape_2d = new_shape[1:] + elif axis == 1: + new_shape_2d = new_shape[[0, 2]] + else: + new_shape_2d = new_shape[:-1] + + for c in range(data.shape[0]): + tmp = deepcopy(new_shape) + tmp[axis] = shape[axis] + reshaped_here = np.zeros(tmp) + for slice_id in range(shape[axis]): + if axis == 0: + reshaped_here[slice_id] = resize_fn(data[c, slice_id], new_shape_2d, 
order, **kwargs) + elif axis == 1: + reshaped_here[:, slice_id] = resize_fn(data[c, :, slice_id], new_shape_2d, order, **kwargs) + else: + reshaped_here[:, :, slice_id] = resize_fn(data[c, :, :, slice_id], new_shape_2d, order, **kwargs) + if shape[axis] != new_shape[axis]: + + # The following few lines are blatantly copied and modified from sklearn's resize() + rows, cols, dim = new_shape[0], new_shape[1], new_shape[2] + orig_rows, orig_cols, orig_dim = reshaped_here.shape + + # align_corners=False + row_scale = float(orig_rows) / rows + col_scale = float(orig_cols) / cols + dim_scale = float(orig_dim) / dim + + map_rows, map_cols, map_dims = np.mgrid[:rows, :cols, :dim] + map_rows = row_scale * (map_rows + 0.5) - 0.5 + map_cols = col_scale * (map_cols + 0.5) - 0.5 + map_dims = dim_scale * (map_dims + 0.5) - 0.5 + + coord_map = np.array([map_rows, map_cols, map_dims]) + if not is_seg or order_z == 0: + reshaped_final[c] = map_coordinates(reshaped_here, coord_map, order=order_z, mode='nearest')[None] + else: + unique_labels = np.sort(pd.unique(reshaped_here.ravel())) # np.unique(reshaped_data) + for i, cl in enumerate(unique_labels): + reshaped_final[c][np.round( + map_coordinates((reshaped_here == cl).astype(float), coord_map, order=order_z, + mode='nearest')) > 0.5] = cl + else: + reshaped_final[c] = reshaped_here + else: + # print("no separate z, order", order) + for c in range(data.shape[0]): + reshaped_final[c] = resize_fn(data[c], new_shape, order, **kwargs) + return reshaped_final + else: + # print("no resampling necessary") + return data + + +if __name__ == '__main__': + input_array = np.random.random((1, 42, 231, 142)) + output_shape = (52, 256, 256) + out = resample_data_or_seg(input_array, output_shape, is_seg=False, axis=3, order=1, order_z=0, do_separate_z=True) + print(out.shape, input_array.shape) \ No newline at end of file diff --git a/data/resample_torch.py b/data/resample_torch.py new file mode 100644 index 
0000000000000000000000000000000000000000..a7fd804c2e31fded4ffb5e05c8ff6ae2953c8ec8 --- /dev/null +++ b/data/resample_torch.py @@ -0,0 +1,162 @@ +from copy import deepcopy +from typing import Union, Tuple, List + +import numpy as np +import torch +from einops import rearrange +from torch.nn import functional as F + +from data.default_resampling import determine_do_sep_z_and_axis + +ANISO_THRESHOLD = 3 # determines when a sample is considered anisotropic (3 means that the spacing in the low + # resolution axis must be 3x as large as the next largest spacing) + +def resample_torch_simple( + data: Union[torch.Tensor, np.ndarray], + new_shape: Union[Tuple[int, ...], List[int], np.ndarray], + is_seg: bool = False, + num_threads: int = 4, + device: torch.device = torch.device('cpu'), + memefficient_seg_resampling: bool = False, + mode='linear' +): + if mode == 'linear': + if data.ndim == 4: + torch_mode = 'trilinear' + elif data.ndim == 3: + torch_mode = 'bilinear' + else: + raise RuntimeError + else: + torch_mode = mode + + if isinstance(new_shape, np.ndarray): + new_shape = [int(i) for i in new_shape] + + if all([i == j for i, j in zip(new_shape, data.shape[1:])]): + return data + else: + n_threads = torch.get_num_threads() + torch.set_num_threads(num_threads) + new_shape = tuple(new_shape) + with torch.no_grad(): + + input_was_numpy = isinstance(data, np.ndarray) + if input_was_numpy: + data = torch.from_numpy(data).to(device) + else: + orig_device = deepcopy(data.device) + data = data.to(device) + + if is_seg: + unique_values = torch.unique(data) + result_dtype = torch.int8 if max(unique_values) < 127 else torch.int16 + result = torch.zeros((data.shape[0], *new_shape), dtype=result_dtype, device=device) + if not memefficient_seg_resampling: + # believe it or not, the implementation below is 3x as fast (at least on Liver CT and on CPU) + # Why? Because argmax is slow. 
The implementation below immediately sets most locations and only lets the + # uncertain ones be determined by argmax + + # unique_values = torch.unique(data) + # result = torch.zeros((len(unique_values), data.shape[0], *new_shape), dtype=torch.float16) + # for i, u in enumerate(unique_values): + # result[i] = F.interpolate((data[None] == u).float() * 1000, new_shape, mode='trilinear', antialias=False)[0] + # result = unique_values[result.argmax(0)] + + result_tmp = torch.zeros((len(unique_values), data.shape[0], *new_shape), dtype=torch.float16, + device=device) + scale_factor = 1000 + done_mask = torch.zeros_like(result, dtype=torch.bool, device=device) + for i, u in enumerate(unique_values): + result_tmp[i] = \ + F.interpolate((data[None] == u).float() * scale_factor, new_shape, mode=torch_mode, + antialias=False)[0] + mask = result_tmp[i] > (0.7 * scale_factor) + result[mask] = u.item() + done_mask |= mask + if not torch.all(done_mask): + # print('resolving argmax', torch.sum(~done_mask), "voxels to go") + result[~done_mask] = unique_values[result_tmp[:, ~done_mask].argmax(0)].to(result_dtype) + else: + for i, u in enumerate(unique_values): + if u == 0: + pass + result[F.interpolate((data[None] == u).float(), new_shape, mode=torch_mode, antialias=False)[ + 0] > 0.5] = u + else: + result = F.interpolate(data[None].float(), new_shape, mode=torch_mode, antialias=False)[0] + if input_was_numpy: + result = result.cpu().numpy() + else: + result = result.to(orig_device) + torch.set_num_threads(n_threads) + return result + + +def resample_torch_fornnunet( + data: Union[torch.Tensor, np.ndarray], + new_shape: Union[Tuple[int, ...], List[int], np.ndarray], + current_spacing: Union[Tuple[float, ...], List[float], np.ndarray], + new_spacing: Union[Tuple[float, ...], List[float], np.ndarray], + is_seg: bool = False, + num_threads: int = 4, + device: torch.device = torch.device('cpu'), + memefficient_seg_resampling: bool = False, + force_separate_z: Union[bool, None] = None, 
+ separate_z_anisotropy_threshold: float = ANISO_THRESHOLD, + mode='linear', + aniso_axis_mode='nearest-exact' +): + """ + data must be c, x, y, z + """ + assert data.ndim == 4, "data must be c, x, y, z" + new_shape = [int(i) for i in new_shape] + orig_shape = data.shape + + do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing, + separate_z_anisotropy_threshold) + # print('shape', data.shape, 'current_spacing', current_spacing, 'new_spacing', new_spacing, 'do_separate_z', do_separate_z, 'axis', axis) + + if do_separate_z: + was_numpy = isinstance(data, np.ndarray) + if was_numpy: + data = torch.from_numpy(data) + + if isinstance(axis, list): + assert len(axis) == 1 + axis = axis[0] + else: + pass + + tmp = "xyz" + axis_letter = tmp[axis] + others_int = [i for i in range(3) if i != axis] + others = [tmp[i] for i in others_int] + + # reshape by overloading c channel + data = rearrange(data, f"c x y z -> (c {axis_letter}) {others[0]} {others[1]}") + + # reshape in-plane + tmp_new_shape = [new_shape[i] for i in others_int] + data = resample_torch_simple(data, tmp_new_shape, is_seg=is_seg, num_threads=num_threads, device=device, + memefficient_seg_resampling=memefficient_seg_resampling, mode=mode) + data = rearrange(data, f"(c {axis_letter}) {others[0]} {others[1]} -> c x y z", + **{ + axis_letter: orig_shape[axis + 1], + others[0]: tmp_new_shape[0], + others[1]: tmp_new_shape[1] + } + ) + # reshape out of plane w/ nearest + data = resample_torch_simple(data, new_shape, is_seg=is_seg, num_threads=num_threads, device=device, + memefficient_seg_resampling=memefficient_seg_resampling, mode=aniso_axis_mode) + if was_numpy: + data = data.numpy() + return data + else: + return resample_torch_simple(data, new_shape, is_seg, num_threads, device, memefficient_seg_resampling) + + +if __name__ == '__main__': + torch.set_num_threads(16) diff --git a/data/resampling_test.py b/data/resampling_test.py new file mode 100644 index 
0000000000000000000000000000000000000000..26c688318e0726cf0ab8520fd5a52c2e6f417e58 --- /dev/null +++ b/data/resampling_test.py @@ -0,0 +1,593 @@ +from typing import Union, Tuple, List +import numpy as np +import torch +import torch.nn.functional as F +from einops import rearrange +import time +from copy import deepcopy +from default_resampling import determine_do_sep_z_and_axis +import psutil +import nibabel as nib +import os +from pathlib import Path + +ANISO_THRESHOLD = 3 + +def compute_new_shape(current_shape: Union[Tuple[int, ...], List[int], np.ndarray], + current_spacing: Union[Tuple[float, ...], List[float], np.ndarray], + target_spacing: Union[Tuple[float, ...], List[float], np.ndarray]) -> List[int]: + """Compute new shape based on spacing ratios.""" + current_shape = np.array(current_shape) + current_spacing = np.array(current_spacing) + target_spacing = np.array(target_spacing) + return [int(round(s * (cs / ts))) for s, cs, ts in zip(current_shape, current_spacing, target_spacing)] + +def optimized_3d_resample( + data: Union[torch.Tensor, np.ndarray], + current_spacing: Union[Tuple[float, ...], List[float], np.ndarray], + target_spacing: Union[Tuple[float, ...], List[float], np.ndarray], + is_seg: bool = False, + device: torch.device = torch.device('cpu'), + num_threads: int = 8, + chunk_size: int = 64, + force_separate_z: Union[bool, None] = None, + separate_z_anisotropy_threshold: float = ANISO_THRESHOLD, + preserve_range: bool = True +) -> Union[torch.Tensor, np.ndarray]: + """ + Optimized 3D image resampling with adaptive interpolation and chunked processing. 
    Args:
        data: Input 3D volume [C, D, H, W] or [D, H, W]
        current_spacing: Current voxel spacing (z, y, x)
        target_spacing: Target voxel spacing (z, y, x)
        is_seg: Whether the input is a segmentation mask
        device: Torch device for computation
        num_threads: Number of threads for CPU operations
        chunk_size: Size of chunks for large volume processing
        force_separate_z: Force separate z resampling
        separate_z_anisotropy_threshold: Threshold for anisotropic resampling
        preserve_range: Preserve original value range for non-segmentation data

    Returns:
        Resampled 3D volume
    """
    print(f"\nStarting optimized_3d_resample with input shape: {data.shape}, is_seg: {is_seg}")
    # Remember the input container type so the output can mirror it (numpy in -> numpy out).
    input_was_numpy = isinstance(data, np.ndarray)
    if input_was_numpy:
        data = torch.from_numpy(data).to(device)
    else:
        data = data.to(device)
    print(f"Input converted to tensor on {device}, shape: {data.shape}")

    # Promote 3D input to 4D by adding a channel axis.
    # NOTE(review): the "no resampling needed" return below happens AFTER this
    # unsqueeze, so a 3D input comes back with 4 dims in that path — confirm intended.
    if data.ndim == 3:
        data = data.unsqueeze(0)
    assert data.ndim == 4, "Data must be 3D or 4D (C, D, H, W)"

    new_shape = compute_new_shape(data.shape[1:], current_spacing, target_spacing)
    print(f"Computed new shape: {new_shape} from current_spacing: {current_spacing}, target_spacing: {target_spacing}")

    if all(i == j for i, j in zip(new_shape, data.shape[1:])):
        print("No resampling needed, shapes identical.")
        return data.cpu().numpy() if input_was_numpy else data

    # `mode` is used for the in-plane / single-pass interpolation,
    # `aniso_axis_mode` for the out-of-plane pass of the separate-z branch.
    mode = 'nearest' if is_seg else 'trilinear'
    aniso_axis_mode = 'nearest-exact' if is_seg else 'linear'
    # NOTE(review): F.interpolate only accepts 'linear' for 3D (N, C, L) input,
    # while _chunked_resample builds 5D tensors ('trilinear' territory) — verify
    # the non-seg separate-z path runs without raising.
    print(f"Interpolation mode: {mode}, Anisotropic axis mode: {aniso_axis_mode}")

    do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing,
                                                      target_spacing, separate_z_anisotropy_threshold)
    print(f"Do separate Z: {do_separate_z}, Axis: {axis}")

    # Capture the value range up front so interpolation overshoot can be clamped
    # later. v_min/v_max are only defined — and only used — under this condition.
    if preserve_range and not is_seg:
        v_min, v_max = data.min(), data.max()
        print(f"Preserving range for non-segmentation data: min={v_min.item():.4f}, max={v_max.item():.4f}")

    torch.set_num_threads(num_threads)
print(f"Set number of threads to {num_threads}") + + start_time = time.time() + if do_separate_z: + tmp = "xyz" + axis_letter = tmp[axis] + others_int = [i for i in range(3) if i != axis] + others = [tmp[i] for i in others_int] + print(f"Separate Z resampling along axis {axis_letter}, others: {others}") + + tmp_new_shape = [new_shape[i] for i in others_int] + print(f"First pass: Resampling to shape {tmp_new_shape} for axes {others}") + data = rearrange(data, f"c x y z -> (c {axis_letter}) {others[0]} {others[1]}") + print(f"Rearranged data shape: {data.shape}") + data = _chunked_resample(data, tmp_new_shape, mode, chunk_size, device, is_seg) + print(f"After first pass resampling, shape: {data.shape}") + + data = rearrange(data, f"(c {axis_letter}) {others[0]} {others[1]} -> c x y z", + **{axis_letter: data.shape[1], others[0]: tmp_new_shape[0], others[1]: tmp_new_shape[1]}) + print(f"Rearranged back to shape: {data.shape}") + data = _chunked_resample(data, new_shape, aniso_axis_mode, chunk_size, device, is_seg) + print(f"After second pass resampling, final shape: {data.shape}") + else: + print(f"Direct resampling to shape: {new_shape}") + data = _chunked_resample(data, new_shape, mode, chunk_size, device, is_seg) + print(f"After direct resampling, final shape: {data.shape}") + resample_time = time.time() - start_time + print(f"Resampling completed in {resample_time:.3f}s") + + if is_seg: + unique_values = torch.unique(data) + result_dtype = torch.int8 if max(unique_values) < 127 else torch.int16 + data = data.round().to(result_dtype) + print(f"Segmentation data rounded and converted to {result_dtype}, unique values: {unique_values.tolist()}") + + if preserve_range and not is_seg: + data = torch.clamp(data, v_min, v_max) + print(f"Clamped data to original range: min={v_min.item():.4f}, max={v_max.item():.4f}") + + output = data.cpu().numpy() if input_was_numpy else data + print(f"Output shape: {output.shape}, type: {type(output)}") + return output + +def 
_chunked_resample(
    volume: torch.Tensor,
    target_shape: Tuple[int, ...],
    mode: str,
    chunk_size: int,
    device: torch.device,
    is_seg: bool
) -> torch.Tensor:
    """Chunked resampling for large volumes with adaptive chunk sizing.

    Resamples `volume` (C, D, H, W) to (C, *target_shape): one F.interpolate
    call for small volumes, otherwise block-by-block to bound peak memory.
    """
    print(f"\nStarting _chunked_resample with input shape: {volume.shape}, target shape: {target_shape}")
    C, D, H, W = volume.shape
    tD, tH, tW = target_shape

    # Adaptive chunk size based on available memory
    if device.type == 'cpu':
        available_memory = psutil.virtual_memory().available / 1024**2  # in MB
    else:
        total_memory = torch.cuda.get_device_properties(device).total_memory / 1024**2  # in MB
        allocated_memory = torch.cuda.memory_allocated(device) / 1024**2
        available_memory = total_memory - allocated_memory

    # NOTE(review): nelement() == numel(), so this reduces to element_size()
    # (bytes per voxel); the trailing division is a no-op.
    mem_per_voxel = volume.element_size() * volume.nelement() / volume.numel()
    target_voxel_count = C * tD * tH * tW  # NOTE(review): unused below
    chunk_mem_ratio = 0.5 if device.type == 'cpu' else 0.3
    # Cube root: chunks are roughly cubic, so side length scales with budget^(1/3).
    adaptive_chunk_size = max(
        32,
        min(chunk_size, int((available_memory * chunk_mem_ratio / mem_per_voxel / C) ** (1/3)))
    )

    # Early return for small volumes
    if D * H * W <= 128**3:
        # NOTE(review): torch.cuda.amp.autocast is entered even when device is
        # CPU; harmless there but deprecated — consider torch.amp.autocast.
        with torch.cuda.amp.autocast(enabled=not is_seg):
            start_time = time.time()
            # Cast to float for interpolation if is_seg and mode is nearest
            input_tensor = volume.float() if is_seg and mode == 'nearest' else volume
            result = F.interpolate(
                input_tensor.unsqueeze(0),
                size=target_shape,
                mode=mode,
                align_corners=False if mode != 'nearest' else None
            ).squeeze(0)
            # Convert back to original dtype for segmentation
            if is_seg:
                result = result.round().to(volume.dtype)
            # print(f"Direct interpolation completed in {time.time() - start_time:.3f}s, output shape: {result.shape}")
            return result

    result = torch.zeros((C, tD, tH, tW), device=device, dtype=volume.dtype)

    # Output-space chunk edge, scaled so the matching input region stays near
    # adaptive_chunk_size voxels per side.
    out_chunk_size = max(1, int(adaptive_chunk_size * min(tD/D, tH/H, tW/W)))

    for c in range(C):
        for z in range(0, tD, out_chunk_size):
            z_end = min(z + 
out_chunk_size, tD)
            for y in range(0, tH, out_chunk_size):
                y_end = min(y + out_chunk_size, tH)
                for x in range(0, tW, out_chunk_size):
                    x_end = min(x + out_chunk_size, tW)

                    # Map the output chunk back to the input volume, with a
                    # 1-2 voxel halo for interpolation support.
                    in_z = max(0, int(z * D / tD) - 1)
                    in_z_end = min(D, int(z_end * D / tD) + 2)
                    in_y = max(0, int(y * H / tH) - 1)
                    in_y_end = min(H, int(y_end * H / tH) + 2)
                    in_x = max(0, int(x * W / tW) - 1)
                    in_x_end = min(W, int(x_end * W / tW) + 2)

                    chunk = volume[c:c+1, in_z:in_z_end, in_y:in_y_end, in_x:in_x_end]
                    chunk_target = (z_end - z, y_end - y, x_end - x)
                    # NOTE(review): the haloed input region is resized straight
                    # to the un-haloed output size, so the per-chunk sampling
                    # grid differs from a single global interpolation — expect
                    # small seams at chunk borders; verify against the reference.

                    with torch.cuda.amp.autocast(enabled=not is_seg):
                        start_time = time.time()
                        # Cast to float for interpolation if is_seg and mode is nearest
                        input_chunk = chunk.float() if is_seg and mode == 'nearest' else chunk
                        resampled_chunk = F.interpolate(
                            input_chunk.unsqueeze(0),
                            size=chunk_target,
                            mode=mode,
                            align_corners=False if mode != 'nearest' else None
                        ).squeeze(0)
                        # Convert back to original dtype for segmentation
                        if is_seg:
                            resampled_chunk = resampled_chunk.round().to(volume.dtype)
                        # print(f"Chunk interpolation completed in {time.time() - start_time:.3f}s, shape: {resampled_chunk.shape}")

                    result[c, z:z_end, y:y_end, x:x_end] = resampled_chunk
                    del chunk, resampled_chunk
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()

    return result

def resample_torch_simple(
    data: Union[torch.Tensor, np.ndarray],
    new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
    is_seg: bool = False,
    num_threads: int = 4,
    device: torch.device = torch.device('cpu'),
    memefficient_seg_resampling: bool = False,
    mode: str = 'linear'
) -> Union[torch.Tensor, np.ndarray]:
    """Single-pass resampling of (C, spatial...) data via F.interpolate.

    'linear' is a convenience alias translated to tri-/bilinear by rank;
    segmentations are resampled label-by-label (see body below).
    """
    if mode == 'linear':
        torch_mode = 'trilinear' if data.ndim == 4 else 'bilinear'
    else:
        torch_mode = mode

    if isinstance(new_shape, np.ndarray):
        new_shape = [int(i) for i in new_shape]

    # Shortcut: already the target shape — the input object itself is returned,
    # with no dtype/device conversion.
    if all([i == j for i, j in zip(new_shape, data.shape[1:])]):
        return data

    # Save the caller's thread count so it can be restored afterwards.
    n_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads)
    new_shape = tuple(new_shape)
    with torch.no_grad():
        input_was_numpy = isinstance(data, np.ndarray)
        if input_was_numpy:
            data = torch.from_numpy(data).to(device)
        else:
            # Only defined on the tensor branch; the numpy branch returns numpy below.
            orig_device = deepcopy(data.device)
            data = data.to(device)

        if is_seg:
            # Interpolating integer labels directly would blend neighbouring
            # classes, so each label is resampled as an indicator (one-hot) map.
            unique_values = torch.unique(data)
            result_dtype = torch.int8 if max(unique_values) < 127 else torch.int16
            result = torch.zeros((data.shape[0], *new_shape), dtype=result_dtype, device=device)
            if not memefficient_seg_resampling:
                # Keep every label's interpolated indicator so undecided voxels
                # can be resolved by argmax afterwards. Stored as fp16; indicators
                # are scaled by 1000 — presumably to retain precision in half floats.
                result_tmp = torch.zeros((len(unique_values), data.shape[0], *new_shape), dtype=torch.float16,
                                         device=device)
                scale_factor = 1000
                done_mask = torch.zeros_like(result, dtype=torch.bool, device=device)
                for i, u in enumerate(unique_values):
                    result_tmp[i] = F.interpolate((data[None] == u).float() * scale_factor, new_shape, mode=torch_mode,
                                                  antialias=False)[0]
                    # A voxel is claimed outright when a label's interpolated
                    # indicator exceeds 0.7.
                    mask = result_tmp[i] > (0.7 * scale_factor)
                    result[mask] = u.item()
                    done_mask |= mask
                # Voxels no label claimed confidently: assign the argmax label.
                if not torch.all(done_mask):
                    result[~done_mask] = unique_values[result_tmp[:, ~done_mask].argmax(0)].to(result_dtype)
            else:
                # Memory-efficient variant: one label at a time with a 0.5
                # threshold; label 0 (background) is implicit via the
                # zero-initialised result.
                for i, u in enumerate(unique_values):
                    if u == 0:
                        continue
                    result[F.interpolate((data[None] == u).float(), new_shape, mode=torch_mode, antialias=False)[0] > 0.5] = u
        else:
            result = F.interpolate(data[None].float(), new_shape, mode=torch_mode, antialias=False)[0]

        # Mirror the input container: numpy in -> numpy out; tensor in -> tensor
        # back on its original device.
        if input_was_numpy:
            result = result.cpu().numpy()
        else:
            result = result.to(orig_device)

    # Restore the caller's thread setting.
    torch.set_num_threads(n_threads)
    return result

def resample_torch_fornnunet(
    data: Union[torch.Tensor, np.ndarray],
    new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
    current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
    new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
    is_seg: bool = False,
    num_threads: int = 4,
    device: torch.device = torch.device('cpu'),
    memefficient_seg_resampling: bool = False,
    force_separate_z: Union[bool, None] = None,
+ separate_z_anisotropy_threshold: float = ANISO_THRESHOLD, + mode: str = 'linear', + aniso_axis_mode: str = 'nearest-exact' +) -> Union[torch.Tensor, np.ndarray]: + assert data.ndim == 4, "data must be c, x, y, z" + new_shape = [int(i) for i in new_shape] + orig_shape = data.shape + + do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing, + separate_z_anisotropy_threshold) + + if do_separate_z: + was_numpy = isinstance(data, np.ndarray) + if was_numpy: + data = torch.from_numpy(data) + + if isinstance(axis, list): + axis = axis[0] + + tmp = "xyz" + axis_letter = tmp[axis] + others_int = [i for i in range(3) if i != axis] + others = [tmp[i] for i in others_int] + + data = rearrange(data, f"c x y z -> (c {axis_letter}) {others[0]} {others[1]}") + tmp_new_shape = [new_shape[i] for i in others_int] + data = resample_torch_simple(data, tmp_new_shape, is_seg=is_seg, num_threads=num_threads, device=device, + memefficient_seg_resampling=memefficient_seg_resampling, mode=mode) + data = rearrange(data, f"(c {axis_letter}) {others[0]} {others[1]} -> c x y z", + **{axis_letter: orig_shape[axis + 1], others[0]: tmp_new_shape[0], others[1]: tmp_new_shape[1]}) + data = resample_torch_simple(data, new_shape, is_seg=is_seg, num_threads=num_threads, device=device, + memefficient_seg_resampling=memefficient_seg_resampling, mode=aniso_axis_mode) + if was_numpy: + data = data.numpy() + return data + else: + return resample_torch_simple(data, new_shape, is_seg, num_threads, device, memefficient_seg_resampling) + +def dice_score(pred: np.ndarray, true: np.ndarray) -> float: + """Compute Dice score for segmentation masks.""" + pred = pred.flatten() + true = true.flatten() + intersection = np.sum(pred * true) + return (2. 
* intersection) / (np.sum(pred) + np.sum(true) + 1e-8) + +# Placeholder for compute_new_shape if not provided +def compute_new_shape(original_shape, current_spacing, target_spacing): + """ + Compute the new shape based on the spacing ratio. + original_shape: (z, y, x) + current_spacing: (z, y, x) + target_spacing: (z, y, x) + """ + zoom_factors = [c / t for c, t in zip(current_spacing, target_spacing)] + new_shape = [int(round(s * z)) for s, z in zip(original_shape, zoom_factors)] + return tuple(new_shape) + +# Function to save as NIfTI +def save_nii(array, spacing, output_path, is_seg=False): + """ + Save numpy array as NIfTI file with specified spacing. + is_seg: If True, convert to int32 for segmentation masks. + """ + # Convert torch tensor to numpy if necessary + if isinstance(array, torch.Tensor): + array = array.cpu().numpy() + + # Convert data type for NIfTI compatibility + if is_seg: + array = array.astype(np.int32) # Convert segmentation to int32 + else: + array = array.astype(np.float32) # Ensure image is float32 + + # Transpose to (X, Y, Z, C) for NIfTI + if array.ndim == 4: + array = array.transpose(2, 3, 1, 0) # From (C, Z, Y, X) to (X, Y, Z, C) + else: + array = array.transpose(2, 3, 1) # From (Z, Y, X) to (X, Y, Z) + + # Create NIfTI image with affine based on spacing + affine = np.diag(list(spacing) + [1.0]) + nii_img = nib.Nifti1Image(array, affine=affine) + nib.save(nii_img, output_path) + print(f"Saved: {output_path}") + +# Main resampling function +def main(): + torch.set_num_threads(4) + device = torch.device('cuda') #torch.device('cpu') # Force CPU as per provided code + print(f"\nRunning tests on device: {device}") + + # Define paths + npz_file_path = "/media/shipc/hhd_8T/spc/code/CVPR2025_Text_guided_seg_submission/inputs/Microscopy_cremi_000_sc.npz" + gt_path = "/media/shipc/hhd_8T/spc/code/CVPR2025_Text_guided_seg_submission/gts/Microscopy_cremi_000_sc.npz" + output_dir = 
"/media/shipc/hhd_8T/spc/code/CVPR2025_Text_guided_seg_submission/workspace_teamx/outputs_test_resample" + + # Ensure output directory exists + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Load input data + data = np.load(npz_file_path, allow_pickle=True) + img_array = data['imgs'] # Shape: (C, Z, Y, X) or (Z, Y, X) + img_spacing = data['spacing'] # (z, y, x) + img_spacing = [1.0, 1.0, 1.0] # Override as per provided code + gt_data = np.load(gt_path, allow_pickle=True) + gt_array = gt_data['gts'] # Shape: (C, Z, Y, X) or (Z, Y, X) + + # Convert data types to PyTorch-compatible types + img_array = img_array.astype(np.float32) # Convert image to float32 + gt_array = gt_array.astype(np.int32) # Convert segmentation mask to int32 + + # Ensure img_array and gt_array have channel dimension + if img_array.ndim == 3: + img_array = img_array[np.newaxis, ...] # Add channel dimension: (1, Z, Y, X) + if gt_array.ndim == 3: + gt_array = gt_array[np.newaxis, ...] # Add channel dimension: (1, Z, Y, X) + + # Define target spacings to test + target_spacings = [ + (1.2, 1.2, 1.2), + (1.5, 1.5, 1.5), + (2.0, 2.0, 2.0), + ] + + # Original shape and spacing + original_shape = img_array.shape[1:] # (Z, Y, X) + current_spacing = img_spacing + print(f"\nOriginal image shape: {original_shape}, Current spacing (z,y,x): {current_spacing}") + + for target_spacing in target_spacings: + print(f"\n=== Resampling to Target Spacing: {target_spacing} ===") + + # Compute new shape + new_shape = compute_new_shape(original_shape, current_spacing, target_spacing) + print(f"Computed target shape: {new_shape}") + + # === Image Resampling === + print("\nResampling image...") + + # Ground truth resampling + print("Computing ground truth with resample_torch_simple...") + start_time = time.time() + if device.type == 'cuda': + torch.cuda.synchronize() # Ensure GPU operations are complete + gt_img = resample_torch_simple( + img_array, + new_shape=new_shape, + is_seg=False, + 
num_threads=4, + device=device + ) + if device.type == 'cuda': + torch.cuda.synchronize() # Ensure GPU operations are complete + gt_time = time.time() - start_time + output_path = os.path.join(output_dir, f"img_gt_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz") + print(f"Ground truth image shape: {gt_img.shape}, Time: {gt_time:.3f}s") + save_nii(gt_img, target_spacing, output_path, is_seg=False) + + # Optimized resampling + print("Running optimized_3d_resample...") + start_time = time.time() + if device.type == 'cuda': + torch.cuda.synchronize() + mem_before = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2 + resampled_img_opt = optimized_3d_resample( + img_array, + current_spacing, + target_spacing, + is_seg=False, + device=device, + num_threads=4, + chunk_size=64 + ) + if device.type == 'cuda': + torch.cuda.synchronize() + + opt_time = time.time() - start_time + mem_after = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2 + opt_mae = np.mean(np.abs(resampled_img_opt - gt_img)) + output_path = os.path.join(output_dir, f"img_opt_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz") + print(f"Optimized image shape: {resampled_img_opt.shape}, Time: {opt_time:.3f}s, " + f"Memory used: {mem_after - mem_before:.2f} MB, MAE: {opt_mae:.6f}") + save_nii(resampled_img_opt, target_spacing, output_path, is_seg=False) + + # Original resampling + print("Running resample_torch_fornnunet...") + start_time = time.time() + if device.type == 'cuda': + torch.cuda.synchronize() + mem_before = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2 + resampled_img_orig = resample_torch_fornnunet( + img_array, + new_shape, + current_spacing, + target_spacing, + is_seg=False, + num_threads=4, + device=device + ) + if device.type == 'cuda': + 
torch.cuda.synchronize() + orig_time = time.time() - start_time + mem_after = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2 + orig_mae = np.mean(np.abs(resampled_img_orig - gt_img)) + output_path = os.path.join(output_dir, f"img_orig_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz") + print(f"Original image shape: {resampled_img_orig.shape}, Time: {orig_time:.3f}s, " + f"Memory used: {mem_after - mem_before:.2f} MB, MAE: {orig_mae:.6f}") + save_nii(resampled_img_orig, target_spacing, output_path, is_seg=False) + + # === Segmentation Mask Resampling === + print("\nResampling segmentation mask...") + + # Ground truth resampling + print("Computing ground truth with resample_torch_simple...") + start_time = time.time() + if device.type == 'cuda': + torch.cuda.synchronize() + gt_seg = resample_torch_simple( + gt_array, + new_shape=new_shape, + is_seg=True, + num_threads=4, + device=device + ) + if device.type == 'cuda': + torch.cuda.synchronize() + gt_seg_time = time.time() - start_time + output_path = os.path.join(output_dir, f"seg_gt_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz") + print(f"Ground truth segmentation shape: {gt_seg.shape}, Time: {gt_seg_time:.3f}s") + save_nii(gt_seg, target_spacing, output_path, is_seg=True) + + # Optimized resampling + print("Running optimized_3d_resample for segmentation...") + start_time = time.time() + if device.type == 'cuda': + torch.cuda.synchronize() + mem_before = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2 + resampled_seg_opt = optimized_3d_resample( + gt_array, + current_spacing, + target_spacing, + is_seg=True, + device=device, + num_threads=4, + chunk_size=64 + ) + if device.type == 'cuda': + torch.cuda.synchronize() + + opt_seg_time = time.time() - start_time + mem_after = psutil.virtual_memory().used / 1024**2 if device.type == 
'cpu' else torch.cuda.memory_allocated(device) / 1024**2 + opt_dice = dice_score(resampled_seg_opt, gt_seg) + output_path = os.path.join(output_dir, f"seg_opt_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz") + print(f"Optimized segmentation shape: {resampled_seg_opt.shape}, Time: {opt_seg_time:.3f}s, " + f"Memory used: {mem_after - mem_before:.2f} MB, Dice: {opt_dice:.6f}") + save_nii(resampled_seg_opt, target_spacing, output_path, is_seg=True) + + # Original resampling + print("Running resample_torch_fornnunet for segmentation...") + start_time = time.time() + if device.type == 'cuda': + torch.cuda.synchronize() + mem_before = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2 + resampled_seg_orig = resample_torch_fornnunet( + gt_array, + new_shape, + current_spacing, + target_spacing, + is_seg=True, + num_threads=4, + device=device + ) + if device.type == 'cuda': + torch.cuda.synchronize() + + orig_seg_time = time.time() - start_time + mem_after = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2 + orig_dice = dice_score(resampled_seg_orig, gt_seg) + output_path = os.path.join(output_dir, f"seg_orig_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz") + print(f"Original segmentation shape: {resampled_seg_orig.shape}, Time: {orig_seg_time:.3f}s, " + f"Memory used: {mem_after - mem_before:.2f} MB, Dice: {orig_dice:.6f}") + save_nii(resampled_seg_orig, target_spacing, output_path, is_seg=True) + + # Summary + print(f"\n=== Summary for Target Spacing: {target_spacing} ===") + print("Image Resampling Metrics:") + print(f"Optimized - Shape: {resampled_img_opt.shape}, Time: {opt_time:.3f}s, MAE: {opt_mae:.6f}") + print(f"Original - Shape: {resampled_img_orig.shape}, Time: {orig_time:.3f}s, MAE: {orig_mae:.6f}") + print(f"Time Improvement: {(orig_time - opt_time) / orig_time * 100:.2f}%") + 
print(f"MAE Improvement: {(orig_mae - opt_mae) / orig_mae * 100:.2f}%") + print("Segmentation Mask Resampling Metrics:") + print(f"Optimized - Shape: {resampled_seg_opt.shape}, Time: {opt_seg_time:.3f}s, Dice: {opt_dice:.6f}") + print(f"Original - Shape: {resampled_seg_orig.shape}, Time: {orig_seg_time:.3f}s, Dice: {orig_dice:.6f}") + print(f"Time Improvement: {(orig_seg_time - opt_seg_time) / orig_seg_time * 100:.2f}%") + print(f"Dice Improvement: {(opt_dice - orig_dice) / orig_dice * 100:.2f}%") + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..6a5d83afe54509a810470ff61e1794f051e3f2cb --- /dev/null +++ b/environment.yml @@ -0,0 +1,211 @@ +name: medals_local_test +channels: + - pytorch + - nvidia + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - aom=3.12.1=h7934f7d_0 + - blas=1.0=mkl + - brotlicffi=1.2.0.0=py310h7354ed3_0 + - bzip2=1.0.8=h5eee18b_6 + - ca-certificates=2025.12.2=h06a4308_0 + - cairo=1.18.4=h44eff21_0 + - certifi=2025.11.12=py310h06a4308_0 + - cffi=2.0.0=py310h4eded50_1 + - charset-normalizer=3.4.4=py310h06a4308_0 + - cuda-cudart=12.1.105=0 + - cuda-cupti=12.1.105=0 + - cuda-libraries=12.1.0=0 + - cuda-nvrtc=12.1.105=0 + - cuda-nvtx=12.1.105=0 + - cuda-opencl=12.9.19=0 + - cuda-runtime=12.1.0=0 + - cuda-version=12.9=3 + - dav1d=1.2.1=h5eee18b_0 + - expat=2.7.3=h7354ed3_4 + - ffmpeg=6.1.1=hecf7045_5 + - filelock=3.20.0=py310h06a4308_0 + - fontconfig=2.15.0=h2c49b7f_0 + - freetype=2.13.3=h4a9f257_0 + - fribidi=1.0.10=h7b6447c_0 + - giflib=5.2.2=h5eee18b_0 + - gmp=6.3.0=h6a678d5_0 + - gmpy2=2.2.2=py310ha78e65c_0 + - graphite2=1.3.14=h295c915_1 + - harfbuzz=10.2.0=hdfddeaa_1 + - icu=73.1=h6a678d5_0 + - idna=3.11=py310h06a4308_0 + - intel-openmp=2025.0.0=h06a4308_1171 + - jinja2=3.1.6=py310h06a4308_0 + - jpeg=9f=h5ce9db8_0 + - lame=3.100=h7b6447c_0 + - lcms2=2.17=heab6991_0 + - 
ld_impl_linux-64=2.44=h153f514_2 + - leptonica=1.82.0=hfdeec58_3 + - lerc=4.0.0=h6a678d5_0 + - libarchive=3.8.2=h3ec8f01_0 + - libavif=1.3.0=h3539ee5_0 + - libcublas=12.1.0.26=0 + - libcufft=11.0.2.4=0 + - libcufile=1.14.1.1=4 + - libcurand=10.3.10.19=0 + - libcusolver=11.4.4.55=0 + - libcusparse=12.0.2.55=0 + - libdeflate=1.22=h5eee18b_0 + - libexpat=2.7.3=h7354ed3_4 + - libffi=3.4.4=h6a678d5_1 + - libgcc=15.2.0=h69a1729_7 + - libgcc-ng=15.2.0=h166f726_7 + - libglib=2.84.4=h77a78f3_0 + - libgomp=15.2.0=h4751f2c_7 + - libhwloc=2.12.1=default_hf1bbc79_1000 + - libiconv=1.16=h5eee18b_3 + - libjpeg-turbo=2.0.0=h9bf148f_0 + - libnpp=12.0.2.50=0 + - libnsl=2.0.0=h5eee18b_0 + - libnvjitlink=12.1.105=0 + - libnvjpeg=12.1.1.14=0 + - libogg=1.3.5=h27cfd23_1 + - libopenjpeg=2.5.4=hee96239_1 + - libopus=1.3.1=h5eee18b_1 + - libpng=1.6.50=h2ed474d_0 + - libstdcxx=15.2.0=h39759b7_7 + - libstdcxx-ng=15.2.0=hc03a8fd_7 + - libtheora=1.2.0=h32ad74f_1 + - libtiff=4.7.1=h029b1ac_0 + - libuuid=1.41.5=h5eee18b_0 + - libvorbis=1.3.7=h7b6447c_0 + - libvpx=1.15.2=h4cb591d_0 + - libwebp=1.6.0=h089d785_0 + - libwebp-base=1.6.0=hb7bb969_0 + - libxcb=1.17.0=h9b100fa_0 + - libxml2=2.13.9=h2c43086_0 + - libzlib=1.3.1=hb25bd0a_0 + - llvm-openmp=14.0.6=h9e868ea_0 + - lz4-c=1.9.4=h6a678d5_1 + - markupsafe=3.0.2=py310h5eee18b_0 + - mkl=2025.0.0=hacee8c2_941 + - mkl-service=2.5.2=py310hacdc0fc_0 + - mkl_fft=2.1.1=py310h8fe796d_0 + - mkl_random=1.3.0=py310h505adc9_0 + - mpc=1.3.1=h5eee18b_0 + - mpfr=4.2.1=h5eee18b_0 + - mpmath=1.3.0=py310h06a4308_0 + - ncurses=6.5=h7934f7d_0 + - networkx=3.4.2=py310h06a4308_0 + - ocl-icd=2.3.3=h47b2149_0 + - opencl-headers=2025.07.22=hfb20e49_0 + - openh264=2.6.0=he621ea3_0 + - openjpeg=2.5.4=h4e0627c_1 + - openssl=3.0.18=hd6dcaed_0 + - pcre2=10.46=hf426167_0 + - pillow=12.0.0=py310h3b88751_1 + - pip=25.3=pyhc872135_0 + - pixman=0.46.4=h7934f7d_0 + - pthread-stubs=0.3=h0ce48e5_1 + - pycparser=2.23=py310h06a4308_0 + - pysocks=1.7.1=py310h06a4308_1 + - 
python=3.10.19=h6fa692b_0 + - pytorch-cuda=12.1=ha16c6d3_6 + - pytorch-mutex=1.0=cuda + - pyyaml=6.0.3=py310h591646f_0 + - readline=8.3=hc2a1206_0 + - requests=2.32.5=py310h06a4308_1 + - setuptools=80.9.0=py310h06a4308_0 + - sqlite=3.51.0=h2a70700_0 + - sympy=1.14.0=py310h06a4308_1 + - tbb=2022.3.0=h698db13_0 + - tbb-devel=2022.3.0=h698db13_0 + - tesseract=5.2.0=hb0d2e87_3 + - tk=8.6.15=h54e0aa7_0 + - typing_extensions=4.15.0=py310h06a4308_0 + - urllib3=2.6.1=py310h06a4308_0 + - wheel=0.45.1=py310h06a4308_0 + - xorg-libx11=1.8.12=h9b100fa_1 + - xorg-libxau=1.0.12=h9b100fa_0 + - xorg-libxdmcp=1.1.5=h9b100fa_0 + - xorg-libxext=1.3.6=h9b100fa_0 + - xorg-libxrender=0.9.12=h9b100fa_0 + - xorg-xorgproto=2024.1=h5eee18b_1 + - xz=5.6.4=h5eee18b_1 + - yaml=0.2.5=h7b6447c_0 + - zlib=1.3.1=hb25bd0a_0 + - zstd=1.5.7=h11fc155_0 + - pip: + - acvl-utils==0.2.5 + - argparse==1.4.0 + - batchgenerators==0.25.1 + - blosc2==3.12.2 + - connected-components-3d==3.26.1 + - contourpy==1.3.2 + - cycler==0.12.1 + - dicom2nifti==2.6.2 + - dynamic-network-architectures==0.2 + - einops==0.8.1 + - fonttools==4.61.1 + - fsspec==2025.12.0 + - future==1.0.0 + - hf-xet==1.2.0 + - huggingface-hub==0.36.0 + - imagecodecs==2025.3.30 + - imageio==2.37.2 + - importlib-resources==6.5.2 + - joblib==1.5.3 + - kiwisolver==1.4.9 + - lazy-loader==0.4 + - linecache2==1.0.0 + - matplotlib==3.10.8 + - monai==1.4.0 + - msgpack==1.1.2 + - ndindex==1.10.1 + - nibabel==5.3.2 + - nnunetv2==2.4.1 + - numexpr==2.14.1 + - numpy==1.26.4 + - nvidia-cublas-cu12==12.1.3.1 + - nvidia-cuda-cupti-cu12==12.1.105 + - nvidia-cuda-nvrtc-cu12==12.1.105 + - nvidia-cuda-runtime-cu12==12.1.105 + - nvidia-cudnn-cu12==8.9.2.26 + - nvidia-cufft-cu12==11.0.2.54 + - nvidia-curand-cu12==10.3.2.106 + - nvidia-cusolver-cu12==11.4.5.107 + - nvidia-cusparse-cu12==12.1.0.106 + - nvidia-nccl-cu12==2.19.3 + - nvidia-nvjitlink-cu12==12.9.86 + - nvidia-nvtx-cu12==12.1.105 + - packaging==25.0 + - pandas==2.3.3 + - platformdirs==4.5.1 + - 
positional-encodings==6.0.3 + - py-cpuinfo==9.0.0 + - pydicom==3.0.1 + - pyparsing==3.2.5 + - python-dateutil==2.9.0.post0 + - python-gdcm==3.2.2 + - python-graphviz==0.21 + - pytz==2025.2 + - regex==2025.11.3 + - safetensors==0.7.0 + - scikit-image==0.25.2 + - scikit-learn==1.7.2 + - scipy==1.15.3 + - seaborn==0.13.2 + - simpleitk==2.5.3 + - six==1.17.0 + - threadpoolctl==3.6.0 + - tifffile==2025.5.10 + - tokenizers==0.21.4 + - torch==2.2.0+cu121 + - torchaudio==2.2.0+cu121 + - torchvision==0.17.0+cu121 + - tqdm==4.67.1 + - traceback2==1.4.0 + - transformers==4.51.3 + - triton==2.2.0 + - tzdata==2025.3 + - unittest2==1.1.0 + - yacs==0.1.8 +prefix: /yinghepool/shipengcheng/.conda/envs/medals_local_test diff --git a/evaluate/SurfaceDice.py b/evaluate/SurfaceDice.py new file mode 100644 index 0000000000000000000000000000000000000000..4cc48e269223d1c1352e9ac31d0445788aefb018 --- /dev/null +++ b/evaluate/SurfaceDice.py @@ -0,0 +1,492 @@ +import numpy as np +import scipy.ndimage + +# neighbour_code_to_normals is a lookup table. +# For every binary neighbour code +# (2x2x2 neighbourhood = 8 neighbours = 8 bits = 256 codes) +# it contains the surface normals of the triangles (called "surfel" for +# "surface element" in the following). The length of the normal +# vector encodes the surfel area. +# +# created by compute_surface_area_lookup_table.ipynb using the +# marching_cube algorithm, see e.g. 
https://en.wikipedia.org/wiki/Marching_cubes +# credit to: http://medicaldecathlon.com/files/Surface_distance_based_measures.ipynb +neighbour_code_to_normals = [ + [[0,0,0]], + [[0.125,0.125,0.125]], + [[-0.125,-0.125,0.125]], + [[-0.25,-0.25,0.0],[0.25,0.25,-0.0]], + [[0.125,-0.125,0.125]], + [[-0.25,-0.0,-0.25],[0.25,0.0,0.25]], + [[0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[0.5,0.0,-0.0],[0.25,0.25,0.25],[0.125,0.125,0.125]], + [[-0.125,0.125,0.125]], + [[0.125,0.125,0.125],[-0.125,0.125,0.125]], + [[-0.25,0.0,0.25],[-0.25,0.0,0.25]], + [[0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]], + [[0.25,-0.25,0.0],[0.25,-0.25,0.0]], + [[0.5,0.0,0.0],[0.25,-0.25,0.25],[-0.125,0.125,-0.125]], + [[-0.5,0.0,0.0],[-0.25,0.25,0.25],[-0.125,0.125,0.125]], + [[0.5,0.0,0.0],[0.5,0.0,0.0]], + [[0.125,-0.125,-0.125]], + [[0.0,-0.25,-0.25],[0.0,0.25,0.25]], + [[-0.125,-0.125,0.125],[0.125,-0.125,-0.125]], + [[0.0,-0.5,0.0],[0.25,0.25,0.25],[0.125,0.125,0.125]], + [[0.125,-0.125,0.125],[0.125,-0.125,-0.125]], + [[0.0,0.0,-0.5],[0.25,0.25,0.25],[-0.125,-0.125,-0.125]], + [[-0.125,-0.125,0.125],[0.125,-0.125,0.125],[0.125,-0.125,-0.125]], + [[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25],[0.25,0.25,0.25],[0.125,0.125,0.125]], + [[-0.125,0.125,0.125],[0.125,-0.125,-0.125]], + [[0.0,-0.25,-0.25],[0.0,0.25,0.25],[-0.125,0.125,0.125]], + [[-0.25,0.0,0.25],[-0.25,0.0,0.25],[0.125,-0.125,-0.125]], + [[0.125,0.125,0.125],[0.375,0.375,0.375],[0.0,-0.25,0.25],[-0.25,0.0,0.25]], + [[0.125,-0.125,-0.125],[0.25,-0.25,0.0],[0.25,-0.25,0.0]], + [[0.375,0.375,0.375],[0.0,0.25,-0.25],[-0.125,-0.125,-0.125],[-0.25,0.25,0.0]], + [[-0.5,0.0,0.0],[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25],[0.125,0.125,0.125]], + [[-0.5,0.0,0.0],[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25]], + [[0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[0.125,-0.125,0.125]], + [[0.0,-0.25,0.25],[0.0,0.25,-0.25]], + [[0.0,-0.5,0.0],[0.125,0.125,-0.125],[0.25,0.25,-0.25]], + [[0.125,-0.125,0.125],[0.125,-0.125,0.125]], + 
[[0.125,-0.125,0.125],[-0.25,-0.0,-0.25],[0.25,0.0,0.25]], + [[0.0,-0.25,0.25],[0.0,0.25,-0.25],[0.125,-0.125,0.125]], + [[-0.375,-0.375,0.375],[-0.0,0.25,0.25],[0.125,0.125,-0.125],[-0.25,-0.0,-0.25]], + [[-0.125,0.125,0.125],[0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,0.125,0.125]], + [[-0.0,0.0,0.5],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]], + [[0.25,0.25,-0.25],[0.25,0.25,-0.25],[0.125,0.125,-0.125],[-0.125,-0.125,0.125]], + [[0.125,-0.125,0.125],[0.25,-0.25,0.0],[0.25,-0.25,0.0]], + [[0.5,0.0,0.0],[0.25,-0.25,0.25],[-0.125,0.125,-0.125],[0.125,-0.125,0.125]], + [[0.0,0.25,-0.25],[0.375,-0.375,-0.375],[-0.125,0.125,0.125],[0.25,0.25,0.0]], + [[-0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]], + [[0.25,-0.25,0.0],[-0.25,0.25,0.0]], + [[0.0,0.5,0.0],[-0.25,0.25,0.25],[0.125,-0.125,-0.125]], + [[0.0,0.5,0.0],[0.125,-0.125,0.125],[-0.25,0.25,-0.25]], + [[0.0,0.5,0.0],[0.0,-0.5,0.0]], + [[0.25,-0.25,0.0],[-0.25,0.25,0.0],[0.125,-0.125,0.125]], + [[-0.375,-0.375,-0.375],[-0.25,0.0,0.25],[-0.125,-0.125,-0.125],[-0.25,0.25,0.0]], + [[0.125,0.125,0.125],[0.0,-0.5,0.0],[-0.25,-0.25,-0.25],[-0.125,-0.125,-0.125]], + [[0.0,-0.5,0.0],[-0.25,-0.25,-0.25],[-0.125,-0.125,-0.125]], + [[-0.125,0.125,0.125],[0.25,-0.25,0.0],[-0.25,0.25,0.0]], + [[0.0,0.5,0.0],[0.25,0.25,-0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[-0.375,0.375,-0.375],[-0.25,-0.25,0.0],[-0.125,0.125,-0.125],[-0.25,0.0,0.25]], + [[0.0,0.5,0.0],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]], + [[0.25,-0.25,0.0],[-0.25,0.25,0.0],[0.25,-0.25,0.0],[0.25,-0.25,0.0]], + [[-0.25,-0.25,0.0],[-0.25,-0.25,0.0],[-0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[-0.25,-0.25,0.0],[-0.25,-0.25,0.0]], + [[-0.25,-0.25,0.0],[-0.25,-0.25,0.0]], + [[-0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[-0.125,-0.125,0.125]], + [[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[-0.125,-0.125,0.125],[-0.25,-0.25,0.0],[0.25,0.25,-0.0]], + [[0.0,-0.25,0.25],[0.0,-0.25,0.25]], + 
[[0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125]], + [[0.0,-0.25,0.25],[0.0,-0.25,0.25],[-0.125,-0.125,0.125]], + [[0.375,-0.375,0.375],[0.0,-0.25,-0.25],[-0.125,0.125,-0.125],[0.25,0.25,0.0]], + [[-0.125,-0.125,0.125],[-0.125,0.125,0.125]], + [[0.125,0.125,0.125],[-0.125,-0.125,0.125],[-0.125,0.125,0.125]], + [[-0.125,-0.125,0.125],[-0.25,0.0,0.25],[-0.25,0.0,0.25]], + [[0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[-0.0,0.5,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125]], + [[-0.25,0.25,-0.25],[-0.25,0.25,-0.25],[-0.125,0.125,-0.125],[-0.125,0.125,-0.125]], + [[-0.25,0.0,-0.25],[0.375,-0.375,-0.375],[0.0,0.25,-0.25],[-0.125,0.125,0.125]], + [[0.5,0.0,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125]], + [[-0.25,0.0,0.25],[0.25,0.0,-0.25]], + [[-0.0,0.0,0.5],[-0.25,0.25,0.25],[-0.125,0.125,0.125]], + [[-0.125,-0.125,0.125],[-0.25,0.0,0.25],[0.25,0.0,-0.25]], + [[-0.25,-0.0,-0.25],[-0.375,0.375,0.375],[-0.25,-0.25,0.0],[-0.125,0.125,0.125]], + [[0.0,0.0,-0.5],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]], + [[-0.0,0.0,0.5],[0.0,0.0,0.5]], + [[0.125,0.125,0.125],[0.125,0.125,0.125],[0.25,0.25,0.25],[0.0,0.0,0.5]], + [[0.125,0.125,0.125],[0.25,0.25,0.25],[0.0,0.0,0.5]], + [[-0.25,0.0,0.25],[0.25,0.0,-0.25],[-0.125,0.125,0.125]], + [[-0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]], + [[-0.25,0.0,0.25],[-0.25,0.0,0.25],[-0.25,0.0,0.25],[0.25,0.0,-0.25]], + [[0.125,-0.125,0.125],[0.25,0.0,0.25],[0.25,0.0,0.25]], + [[0.25,0.0,0.25],[-0.375,-0.375,0.375],[-0.25,0.25,0.0],[-0.125,-0.125,0.125]], + [[-0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[0.25,0.0,0.25],[0.25,0.0,0.25]], + [[0.25,0.0,0.25],[0.25,0.0,0.25]], + [[-0.125,-0.125,0.125],[0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[-0.125,-0.125,0.125],[0.125,-0.125,0.125]], + [[-0.125,-0.125,0.125],[0.0,-0.25,0.25],[0.0,0.25,-0.25]], + [[0.0,-0.5,0.0],[0.125,0.125,-0.125],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]], + 
[[0.0,-0.25,0.25],[0.0,-0.25,0.25],[0.125,-0.125,0.125]], + [[0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]], + [[0.0,-0.25,0.25],[0.0,-0.25,0.25],[0.0,-0.25,0.25],[0.0,0.25,-0.25]], + [[0.0,0.25,0.25],[0.0,0.25,0.25],[0.125,-0.125,-0.125]], + [[-0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[-0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,-0.125,0.125],[0.125,0.125,0.125]], + [[-0.0,0.0,0.5],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[0.125,-0.125,0.125],[0.125,-0.125,-0.125]], + [[-0.0,0.5,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[-0.125,-0.125,0.125],[0.125,-0.125,-0.125]], + [[0.0,-0.25,-0.25],[0.0,0.25,0.25],[0.125,0.125,0.125]], + [[0.125,0.125,0.125],[0.125,-0.125,-0.125]], + [[0.5,0.0,-0.0],[0.25,-0.25,-0.25],[0.125,-0.125,-0.125]], + [[-0.25,0.25,0.25],[-0.125,0.125,0.125],[-0.25,0.25,0.25],[0.125,-0.125,-0.125]], + [[0.375,-0.375,0.375],[0.0,0.25,0.25],[-0.125,0.125,-0.125],[-0.25,0.0,0.25]], + [[0.0,-0.5,0.0],[-0.25,0.25,0.25],[-0.125,0.125,0.125]], + [[-0.375,-0.375,0.375],[0.25,-0.25,0.0],[0.0,0.25,0.25],[-0.125,-0.125,0.125]], + [[-0.125,0.125,0.125],[-0.25,0.25,0.25],[0.0,0.0,0.5]], + [[0.125,0.125,0.125],[0.0,0.25,0.25],[0.0,0.25,0.25]], + [[0.0,0.25,0.25],[0.0,0.25,0.25]], + [[0.5,0.0,-0.0],[0.25,0.25,0.25],[0.125,0.125,0.125],[0.125,0.125,0.125]], + [[0.125,-0.125,0.125],[-0.125,-0.125,0.125],[0.125,0.125,0.125]], + [[-0.25,-0.0,-0.25],[0.25,0.0,0.25],[0.125,0.125,0.125]], + [[0.125,0.125,0.125],[0.125,-0.125,0.125]], + [[-0.25,-0.25,0.0],[0.25,0.25,-0.0],[0.125,0.125,0.125]], + [[0.125,0.125,0.125],[-0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[0.125,0.125,0.125]], + [[0.125,0.125,0.125]], + [[0.125,0.125,0.125]], + [[0.125,0.125,0.125],[0.125,0.125,0.125]], + [[0.125,0.125,0.125],[-0.125,-0.125,0.125]], + [[-0.25,-0.25,0.0],[0.25,0.25,-0.0],[0.125,0.125,0.125]], + 
[[0.125,0.125,0.125],[0.125,-0.125,0.125]], + [[-0.25,-0.0,-0.25],[0.25,0.0,0.25],[0.125,0.125,0.125]], + [[0.125,-0.125,0.125],[-0.125,-0.125,0.125],[0.125,0.125,0.125]], + [[0.5,0.0,-0.0],[0.25,0.25,0.25],[0.125,0.125,0.125],[0.125,0.125,0.125]], + [[0.0,0.25,0.25],[0.0,0.25,0.25]], + [[0.125,0.125,0.125],[0.0,0.25,0.25],[0.0,0.25,0.25]], + [[-0.125,0.125,0.125],[-0.25,0.25,0.25],[0.0,0.0,0.5]], + [[-0.375,-0.375,0.375],[0.25,-0.25,0.0],[0.0,0.25,0.25],[-0.125,-0.125,0.125]], + [[0.0,-0.5,0.0],[-0.25,0.25,0.25],[-0.125,0.125,0.125]], + [[0.375,-0.375,0.375],[0.0,0.25,0.25],[-0.125,0.125,-0.125],[-0.25,0.0,0.25]], + [[-0.25,0.25,0.25],[-0.125,0.125,0.125],[-0.25,0.25,0.25],[0.125,-0.125,-0.125]], + [[0.5,0.0,-0.0],[0.25,-0.25,-0.25],[0.125,-0.125,-0.125]], + [[0.125,0.125,0.125],[0.125,-0.125,-0.125]], + [[0.0,-0.25,-0.25],[0.0,0.25,0.25],[0.125,0.125,0.125]], + [[0.125,0.125,0.125],[-0.125,-0.125,0.125],[0.125,-0.125,-0.125]], + [[-0.0,0.5,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[0.125,-0.125,0.125],[0.125,-0.125,-0.125]], + [[-0.0,0.0,0.5],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[-0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,-0.125,0.125],[0.125,0.125,0.125]], + [[-0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[0.0,0.25,0.25],[0.0,0.25,0.25],[0.125,-0.125,-0.125]], + [[0.0,-0.25,-0.25],[0.0,0.25,0.25],[0.0,0.25,0.25],[0.0,0.25,0.25]], + [[0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]], + [[0.0,-0.25,0.25],[0.0,-0.25,0.25],[0.125,-0.125,0.125]], + [[0.0,-0.5,0.0],[0.125,0.125,-0.125],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]], + [[-0.125,-0.125,0.125],[0.0,-0.25,0.25],[0.0,0.25,-0.25]], + [[0.125,0.125,0.125],[-0.125,-0.125,0.125],[0.125,-0.125,0.125]], + [[-0.125,-0.125,0.125],[0.125,-0.125,0.125]], + [[0.25,0.0,0.25],[0.25,0.0,0.25]], + [[0.125,0.125,0.125],[0.25,0.0,0.25],[0.25,0.0,0.25]], + 
[[-0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125]], + [[0.25,0.0,0.25],[-0.375,-0.375,0.375],[-0.25,0.25,0.0],[-0.125,-0.125,0.125]], + [[0.125,-0.125,0.125],[0.25,0.0,0.25],[0.25,0.0,0.25]], + [[-0.25,-0.0,-0.25],[0.25,0.0,0.25],[0.25,0.0,0.25],[0.25,0.0,0.25]], + [[-0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125],[0.125,-0.125,0.125]], + [[-0.25,0.0,0.25],[0.25,0.0,-0.25],[-0.125,0.125,0.125]], + [[0.125,0.125,0.125],[0.25,0.25,0.25],[0.0,0.0,0.5]], + [[0.125,0.125,0.125],[0.125,0.125,0.125],[0.25,0.25,0.25],[0.0,0.0,0.5]], + [[-0.0,0.0,0.5],[0.0,0.0,0.5]], + [[0.0,0.0,-0.5],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]], + [[-0.25,-0.0,-0.25],[-0.375,0.375,0.375],[-0.25,-0.25,0.0],[-0.125,0.125,0.125]], + [[-0.125,-0.125,0.125],[-0.25,0.0,0.25],[0.25,0.0,-0.25]], + [[-0.0,0.0,0.5],[-0.25,0.25,0.25],[-0.125,0.125,0.125]], + [[-0.25,0.0,0.25],[0.25,0.0,-0.25]], + [[0.5,0.0,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125]], + [[-0.25,0.0,-0.25],[0.375,-0.375,-0.375],[0.0,0.25,-0.25],[-0.125,0.125,0.125]], + [[-0.25,0.25,-0.25],[-0.25,0.25,-0.25],[-0.125,0.125,-0.125],[-0.125,0.125,-0.125]], + [[-0.0,0.5,0.0],[-0.25,0.25,-0.25],[0.125,-0.125,0.125]], + [[0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[-0.125,-0.125,0.125],[-0.25,0.0,0.25],[-0.25,0.0,0.25]], + [[0.125,0.125,0.125],[-0.125,-0.125,0.125],[-0.125,0.125,0.125]], + [[-0.125,-0.125,0.125],[-0.125,0.125,0.125]], + [[0.375,-0.375,0.375],[0.0,-0.25,-0.25],[-0.125,0.125,-0.125],[0.25,0.25,0.0]], + [[0.0,-0.25,0.25],[0.0,-0.25,0.25],[-0.125,-0.125,0.125]], + [[0.0,0.0,0.5],[0.25,-0.25,0.25],[0.125,-0.125,0.125]], + [[0.0,-0.25,0.25],[0.0,-0.25,0.25]], + [[-0.125,-0.125,0.125],[-0.25,-0.25,0.0],[0.25,0.25,-0.0]], + [[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[-0.125,-0.125,0.125]], + [[-0.125,-0.125,0.125]], + [[-0.25,-0.25,0.0],[-0.25,-0.25,0.0]], + [[0.125,0.125,0.125],[-0.25,-0.25,0.0],[-0.25,-0.25,0.0]], + 
[[-0.25,-0.25,0.0],[-0.25,-0.25,0.0],[-0.125,-0.125,0.125]], + [[-0.25,-0.25,0.0],[-0.25,-0.25,0.0],[-0.25,-0.25,0.0],[0.25,0.25,-0.0]], + [[0.0,0.5,0.0],[0.25,0.25,-0.25],[-0.125,-0.125,0.125]], + [[-0.375,0.375,-0.375],[-0.25,-0.25,0.0],[-0.125,0.125,-0.125],[-0.25,0.0,0.25]], + [[0.0,0.5,0.0],[0.25,0.25,-0.25],[-0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[-0.125,0.125,0.125],[0.25,-0.25,0.0],[-0.25,0.25,0.0]], + [[0.0,-0.5,0.0],[-0.25,-0.25,-0.25],[-0.125,-0.125,-0.125]], + [[0.125,0.125,0.125],[0.0,-0.5,0.0],[-0.25,-0.25,-0.25],[-0.125,-0.125,-0.125]], + [[-0.375,-0.375,-0.375],[-0.25,0.0,0.25],[-0.125,-0.125,-0.125],[-0.25,0.25,0.0]], + [[0.25,-0.25,0.0],[-0.25,0.25,0.0],[0.125,-0.125,0.125]], + [[0.0,0.5,0.0],[0.0,-0.5,0.0]], + [[0.0,0.5,0.0],[0.125,-0.125,0.125],[-0.25,0.25,-0.25]], + [[0.0,0.5,0.0],[-0.25,0.25,0.25],[0.125,-0.125,-0.125]], + [[0.25,-0.25,0.0],[-0.25,0.25,0.0]], + [[-0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]], + [[0.0,0.25,-0.25],[0.375,-0.375,-0.375],[-0.125,0.125,0.125],[0.25,0.25,0.0]], + [[0.5,0.0,0.0],[0.25,-0.25,0.25],[-0.125,0.125,-0.125],[0.125,-0.125,0.125]], + [[0.125,-0.125,0.125],[0.25,-0.25,0.0],[0.25,-0.25,0.0]], + [[0.25,0.25,-0.25],[0.25,0.25,-0.25],[0.125,0.125,-0.125],[-0.125,-0.125,0.125]], + [[-0.0,0.0,0.5],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]], + [[0.125,0.125,0.125],[0.125,-0.125,0.125],[-0.125,0.125,0.125]], + [[-0.125,0.125,0.125],[0.125,-0.125,0.125]], + [[-0.375,-0.375,0.375],[-0.0,0.25,0.25],[0.125,0.125,-0.125],[-0.25,-0.0,-0.25]], + [[0.0,-0.25,0.25],[0.0,0.25,-0.25],[0.125,-0.125,0.125]], + [[0.125,-0.125,0.125],[-0.25,-0.0,-0.25],[0.25,0.0,0.25]], + [[0.125,-0.125,0.125],[0.125,-0.125,0.125]], + [[0.0,-0.5,0.0],[0.125,0.125,-0.125],[0.25,0.25,-0.25]], + [[0.0,-0.25,0.25],[0.0,0.25,-0.25]], + [[0.125,0.125,0.125],[0.125,-0.125,0.125]], + [[0.125,-0.125,0.125]], + [[-0.5,0.0,0.0],[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25]], + 
[[-0.5,0.0,0.0],[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25],[0.125,0.125,0.125]], + [[0.375,0.375,0.375],[0.0,0.25,-0.25],[-0.125,-0.125,-0.125],[-0.25,0.25,0.0]], + [[0.125,-0.125,-0.125],[0.25,-0.25,0.0],[0.25,-0.25,0.0]], + [[0.125,0.125,0.125],[0.375,0.375,0.375],[0.0,-0.25,0.25],[-0.25,0.0,0.25]], + [[-0.25,0.0,0.25],[-0.25,0.0,0.25],[0.125,-0.125,-0.125]], + [[0.0,-0.25,-0.25],[0.0,0.25,0.25],[-0.125,0.125,0.125]], + [[-0.125,0.125,0.125],[0.125,-0.125,-0.125]], + [[-0.125,-0.125,-0.125],[-0.25,-0.25,-0.25],[0.25,0.25,0.25],[0.125,0.125,0.125]], + [[-0.125,-0.125,0.125],[0.125,-0.125,0.125],[0.125,-0.125,-0.125]], + [[0.0,0.0,-0.5],[0.25,0.25,0.25],[-0.125,-0.125,-0.125]], + [[0.125,-0.125,0.125],[0.125,-0.125,-0.125]], + [[0.0,-0.5,0.0],[0.25,0.25,0.25],[0.125,0.125,0.125]], + [[-0.125,-0.125,0.125],[0.125,-0.125,-0.125]], + [[0.0,-0.25,-0.25],[0.0,0.25,0.25]], + [[0.125,-0.125,-0.125]], + [[0.5,0.0,0.0],[0.5,0.0,0.0]], + [[-0.5,0.0,0.0],[-0.25,0.25,0.25],[-0.125,0.125,0.125]], + [[0.5,0.0,0.0],[0.25,-0.25,0.25],[-0.125,0.125,-0.125]], + [[0.25,-0.25,0.0],[0.25,-0.25,0.0]], + [[0.5,0.0,0.0],[-0.25,-0.25,0.25],[-0.125,-0.125,0.125]], + [[-0.25,0.0,0.25],[-0.25,0.0,0.25]], + [[0.125,0.125,0.125],[-0.125,0.125,0.125]], + [[-0.125,0.125,0.125]], + [[0.5,0.0,-0.0],[0.25,0.25,0.25],[0.125,0.125,0.125]], + [[0.125,-0.125,0.125],[-0.125,-0.125,0.125]], + [[-0.25,-0.0,-0.25],[0.25,0.0,0.25]], + [[0.125,-0.125,0.125]], + [[-0.25,-0.25,0.0],[0.25,0.25,-0.0]], + [[-0.125,-0.125,0.125]], + [[0.125,0.125,0.125]], + [[0,0,0]]] + + +def compute_surface_distances(mask_gt, mask_pred, spacing_mm): + """Compute closest distances from all surface points to the other surface. + + Finds all surface elements "surfels" in the ground truth mask `mask_gt` and + the predicted mask `mask_pred`, computes their area in mm^2 and the distance + to the closest point on the other surface. It returns two sorted lists of + distances together with the corresponding surfel areas. 
If one of the masks + is empty, the corresponding lists are empty and all distances in the other + list are `inf` + + Args: + mask_gt: 3-dim Numpy array of type bool. The ground truth mask. + mask_pred: 3-dim Numpy array of type bool. The predicted mask. + spacing_mm: 3-element list-like structure. Voxel spacing in x0, x1 and x2 + direction + + Returns: + A dict with + "distances_gt_to_pred": 1-dim numpy array of type float. The distances in mm + from all ground truth surface elements to the predicted surface, + sorted from smallest to largest + "distances_pred_to_gt": 1-dim numpy array of type float. The distances in mm + from all predicted surface elements to the ground truth surface, + sorted from smallest to largest + "surfel_areas_gt": 1-dim numpy array of type float. The area in mm^2 of + the ground truth surface elements in the same order as + distances_gt_to_pred + "surfel_areas_pred": 1-dim numpy array of type float. The area in mm^2 of + the predicted surface elements in the same order as + distances_pred_to_gt + + """ + + # compute the area for all 256 possible surface elements + # (given a 2x2x2 neighbourhood) according to the spacing_mm + neighbour_code_to_surface_area = np.zeros([256]) + for code in range(256): + normals = np.array(neighbour_code_to_normals[code]) + sum_area = 0 + for normal_idx in range(normals.shape[0]): + # normal vector + n = np.zeros([3]) + n[0] = normals[normal_idx,0] * spacing_mm[1] * spacing_mm[2] + n[1] = normals[normal_idx,1] * spacing_mm[0] * spacing_mm[2] + n[2] = normals[normal_idx,2] * spacing_mm[0] * spacing_mm[1] + area = np.linalg.norm(n) + sum_area += area + neighbour_code_to_surface_area[code] = sum_area + + # compute the bounding box of the masks to trim + # the volume to the smallest possible processing subvolume + mask_all = mask_gt | mask_pred + bbox_min = np.zeros(3, np.int64) + bbox_max = np.zeros(3, np.int64) + + # max projection to the x0-axis + proj_0 = np.max(np.max(mask_all, axis=2), axis=1) + 
idx_nonzero_0 = np.nonzero(proj_0)[0] + if len(idx_nonzero_0) == 0: + return {"distances_gt_to_pred": np.array([]), + "distances_pred_to_gt": np.array([]), + "surfel_areas_gt": np.array([]), + "surfel_areas_pred": np.array([])} + + bbox_min[0] = np.min(idx_nonzero_0) + bbox_max[0] = np.max(idx_nonzero_0) + + # max projection to the x1-axis + proj_1 = np.max(np.max(mask_all, axis=2), axis=0) + idx_nonzero_1 = np.nonzero(proj_1)[0] + bbox_min[1] = np.min(idx_nonzero_1) + bbox_max[1] = np.max(idx_nonzero_1) + + # max projection to the x2-axis + proj_2 = np.max(np.max(mask_all, axis=1), axis=0) + idx_nonzero_2 = np.nonzero(proj_2)[0] + bbox_min[2] = np.min(idx_nonzero_2) + bbox_max[2] = np.max(idx_nonzero_2) + +# print("bounding box min = {}".format(bbox_min)) +# print("bounding box max = {}".format(bbox_max)) + + # crop the processing subvolume. + # we need to zeropad the cropped region with 1 voxel at the lower, + # the right and the back side. This is required to obtain the "full" + # convolution result with the 2x2x2 kernel + cropmask_gt = np.zeros((bbox_max - bbox_min)+2, np.uint8) + cropmask_pred = np.zeros((bbox_max - bbox_min)+2, np.uint8) + + cropmask_gt[0:-1, 0:-1, 0:-1] = mask_gt[bbox_min[0]:bbox_max[0]+1, + bbox_min[1]:bbox_max[1]+1, + bbox_min[2]:bbox_max[2]+1] + + cropmask_pred[0:-1, 0:-1, 0:-1] = mask_pred[bbox_min[0]:bbox_max[0]+1, + bbox_min[1]:bbox_max[1]+1, + bbox_min[2]:bbox_max[2]+1] + + # compute the neighbour code (local binary pattern) for each voxel + # the resultsing arrays are spacially shifted by minus half a voxel in each axis. + # i.e. 
the points are located at the corners of the original voxels + kernel = np.array([[[128,64], + [32,16]], + [[8,4], + [2,1]]]) + neighbour_code_map_gt = scipy.ndimage.filters.correlate(cropmask_gt.astype(np.uint8), kernel, mode="constant", cval=0) + neighbour_code_map_pred = scipy.ndimage.filters.correlate(cropmask_pred.astype(np.uint8), kernel, mode="constant", cval=0) + + # create masks with the surface voxels + borders_gt = ((neighbour_code_map_gt != 0) & (neighbour_code_map_gt != 255)) + borders_pred = ((neighbour_code_map_pred != 0) & (neighbour_code_map_pred != 255)) + + # compute the distance transform (closest distance of each voxel to the surface voxels) + if borders_gt.any(): + distmap_gt = scipy.ndimage.morphology.distance_transform_edt(~borders_gt, sampling=spacing_mm) + else: + distmap_gt = np.Inf * np.ones(borders_gt.shape) + + if borders_pred.any(): + distmap_pred = scipy.ndimage.morphology.distance_transform_edt(~borders_pred, sampling=spacing_mm) + else: + distmap_pred = np.Inf * np.ones(borders_pred.shape) + + # compute the area of each surface element + surface_area_map_gt = neighbour_code_to_surface_area[neighbour_code_map_gt] + surface_area_map_pred = neighbour_code_to_surface_area[neighbour_code_map_pred] + + # create a list of all surface elements with distance and area + distances_gt_to_pred = distmap_pred[borders_gt] + distances_pred_to_gt = distmap_gt[borders_pred] + surfel_areas_gt = surface_area_map_gt[borders_gt] + surfel_areas_pred = surface_area_map_pred[borders_pred] + + # sort them by distance + if distances_gt_to_pred.shape != (0,): + sorted_surfels_gt = np.array(sorted(zip(distances_gt_to_pred, surfel_areas_gt))) + distances_gt_to_pred = sorted_surfels_gt[:,0] + surfel_areas_gt = sorted_surfels_gt[:,1] + + if distances_pred_to_gt.shape != (0,): + sorted_surfels_pred = np.array(sorted(zip(distances_pred_to_gt, surfel_areas_pred))) + distances_pred_to_gt = sorted_surfels_pred[:,0] + surfel_areas_pred = sorted_surfels_pred[:,1] + + + 
return {"distances_gt_to_pred": distances_gt_to_pred, + "distances_pred_to_gt": distances_pred_to_gt, + "surfel_areas_gt": surfel_areas_gt, + "surfel_areas_pred": surfel_areas_pred} + + +def compute_average_surface_distance(surface_distances): + distances_gt_to_pred = surface_distances["distances_gt_to_pred"] + distances_pred_to_gt = surface_distances["distances_pred_to_gt"] + surfel_areas_gt = surface_distances["surfel_areas_gt"] + surfel_areas_pred = surface_distances["surfel_areas_pred"] + average_distance_gt_to_pred = np.sum( distances_gt_to_pred * surfel_areas_gt) / np.sum(surfel_areas_gt) + average_distance_pred_to_gt = np.sum( distances_pred_to_gt * surfel_areas_pred) / np.sum(surfel_areas_pred) + return (average_distance_gt_to_pred, average_distance_pred_to_gt) + +def compute_robust_hausdorff(surface_distances, percent): + distances_gt_to_pred = surface_distances["distances_gt_to_pred"] + distances_pred_to_gt = surface_distances["distances_pred_to_gt"] + surfel_areas_gt = surface_distances["surfel_areas_gt"] + surfel_areas_pred = surface_distances["surfel_areas_pred"] + if len(distances_gt_to_pred) > 0: + surfel_areas_cum_gt = np.cumsum(surfel_areas_gt) / np.sum(surfel_areas_gt) + idx = np.searchsorted(surfel_areas_cum_gt, percent/100.0) + perc_distance_gt_to_pred = distances_gt_to_pred[min(idx, len(distances_gt_to_pred)-1)] + else: + perc_distance_gt_to_pred = np.Inf + + if len(distances_pred_to_gt) > 0: + surfel_areas_cum_pred = np.cumsum(surfel_areas_pred) / np.sum(surfel_areas_pred) + idx = np.searchsorted(surfel_areas_cum_pred, percent/100.0) + perc_distance_pred_to_gt = distances_pred_to_gt[min(idx, len(distances_pred_to_gt)-1)] + else: + perc_distance_pred_to_gt = np.Inf + + return max( perc_distance_gt_to_pred, perc_distance_pred_to_gt) + +def compute_surface_overlap_at_tolerance(surface_distances, tolerance_mm): + distances_gt_to_pred = surface_distances["distances_gt_to_pred"] + distances_pred_to_gt = surface_distances["distances_pred_to_gt"] + 
surfel_areas_gt = surface_distances["surfel_areas_gt"] + surfel_areas_pred = surface_distances["surfel_areas_pred"] + rel_overlap_gt = np.sum(surfel_areas_gt[distances_gt_to_pred <= tolerance_mm]) / np.sum(surfel_areas_gt) + rel_overlap_pred = np.sum(surfel_areas_pred[distances_pred_to_gt <= tolerance_mm]) / np.sum(surfel_areas_pred) + return (rel_overlap_gt, rel_overlap_pred) + +def compute_surface_dice_at_tolerance(surface_distances, tolerance_mm): + distances_gt_to_pred = surface_distances["distances_gt_to_pred"] + distances_pred_to_gt = surface_distances["distances_pred_to_gt"] + surfel_areas_gt = surface_distances["surfel_areas_gt"] + surfel_areas_pred = surface_distances["surfel_areas_pred"] + overlap_gt = np.sum(surfel_areas_gt[distances_gt_to_pred <= tolerance_mm]) + overlap_pred = np.sum(surfel_areas_pred[distances_pred_to_gt <= tolerance_mm]) + surface_dice = (overlap_gt + overlap_pred) / ( + np.sum(surfel_areas_gt) + np.sum(surfel_areas_pred)) + return surface_dice + + +def compute_dice_coefficient(mask_gt, mask_pred): + """Compute soerensen-dice coefficient. + + compute the soerensen-dice coefficient between the ground truth mask `mask_gt` + and the predicted mask `mask_pred`. + + Args: + mask_gt: 3-dim Numpy array of type bool. The ground truth mask. + mask_pred: 3-dim Numpy array of type bool. The predicted mask. + + Returns: + the dice coeffcient as float. 
If both masks are empty, the result is NaN + """ + volume_sum = mask_gt.sum() + mask_pred.sum() + if volume_sum == 0: + return np.NaN + volume_intersect = (mask_gt & mask_pred).sum() + return 2*volume_intersect / volume_sum + \ No newline at end of file diff --git a/evaluate/__init__.py b/evaluate/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/evaluate/evaluator.py b/evaluate/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..e609628c6f0529ac009081b92e260b04343231cf --- /dev/null +++ b/evaluate/evaluator.py @@ -0,0 +1,379 @@ +import os +import time + +import torch +from torch.cuda.amp import autocast as autocast +from tqdm import tqdm +from einops import rearrange, repeat, reduce +import numpy as np +import pandas as pd +from pathlib import Path +import nibabel as nib +import shutil +import pickle +from scipy.ndimage import gaussian_filter +import torch.distributed as dist + +from evaluate.metric import calculate_metric_percase +from evaluate.merge_after_evaluate import merge +from train.dist import is_master + +def compute_gaussian(tile_size, sigma_scale: float = 1. / 8, value_scaling_factor: float = 10, dtype=np.float16): + tmp = np.zeros(tile_size) + center_coords = [i // 2 for i in tile_size] + sigmas = [i * sigma_scale for i in tile_size] + tmp[tuple(center_coords)] = 1 + gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0) + + # gaussian_importance_map = torch.from_numpy(gaussian_importance_map) + + gaussian_importance_map = gaussian_importance_map / np.max(gaussian_importance_map) * value_scaling_factor + gaussian_importance_map = gaussian_importance_map.astype(dtype) + + # gaussian_importance_map cannot be 0, otherwise we may end up with nans! 
+ gaussian_importance_map[gaussian_importance_map == 0] = np.min( + gaussian_importance_map[gaussian_importance_map != 0]) + + return gaussian_importance_map + +def evaluate(model, + text_encoder, + device, + testset, + testloader, + dice_score, + nsd_score, + csv_path, + resume, + save_interval, + visualization): + + # if to store pred、gt、img (as nii.gz + if visualization: + nib_dir = csv_path.replace('.csv', '') + + # collate in master process + if is_master(): + # datasets --> labels --> metrics + datasets_labels_metrics = {} # {'COVID19':{'covid19_infection':{'dice':[0.8, 0.9, ...], ...} ...}, ...} + + # datasets --> samples --> labels --> metrics + samples_labels_metrics = {} # {'COVID19':{'0.npy':{'covid19_infection':{'dice':0.8, ...} ...}, ...} 记录每个dataset里的sample(行) + + # datsets --> labels + datasets_labels_sets = {} # {'COVID19':set('covid19_infection', ...), ...} 记录每个dataset里的label种类(列) + + # accumulate scores of each sample in each process + results_of_samples = [] # each element : [dataset_name, modality, sample_id, scores_of_labels(dict), label_names] + + # load results from an interrupted eval (only in master process) + if resume and is_master(): + root_dir = os.path.dirname(csv_path) + prefix = os.path.basename(csv_path).replace('.csv', '_tmp_rank') # xxx/test/step_xxx.csv --> step_xxx_tmp_rank + pkl_to_del = [] + for f in os.listdir(root_dir): + if prefix in f: + # load list of results + pkl_path = f'{root_dir}/{f}' + with open(pkl_path, 'rb') as f: + results_of_samples += pickle.load(f) + print(f'Load results from {pkl_path}') + pkl_to_del.append(pkl_path) + + # there may be duplication? 
We leave the deduplication to the final merge + # merge all the loaded samples, del the tmp pickle files in previous evaluation task + for pkl_path in pkl_to_del: + os.remove(pkl_path) + print(f'Del {pkl_path}') + merge_pkl = csv_path.replace('.csv', f'_tmp_rank0.pkl') + with open(merge_pkl, 'wb') as f: + pickle.dump(results_of_samples, f) + print(f'Load results of {len(results_of_samples)} samples, Merge into {merge_pkl}') + + model.eval() + text_encoder.eval() + + with torch.no_grad(): + + data_time = 0 + pred_time = 0 + metric_time = 0 + + avg_patch_batch_num = 0 + avg_query_batch_num = 0 + + # in ddp, only master process display the progress bar + if is_master(): + testloader = tqdm(testloader, disable=False) + else: + testloader = tqdm(testloader, disable=True) + + # gaussian kernel to accumulate predcition + gaussian = torch.tensor(compute_gaussian((288, 288, 96))).to(device) # hwd + + end_time = time.time() + for sample in testloader: # in evaluation/inference, a "batch" in loader is a volume + # data loading + dataset_name = sample['dataset_name'] + sample_id = sample['sample_id'] + batched_patches = sample['batched_patches'] + batched_y1y2_x1x2_z1z2 = sample['batched_y1y2_x1x2_z1z2'] + labels = sample['labels'] + gt_segmentation = sample['gt_segmentation'].numpy() # n h w d + modality = sample['modality'] + image_path = sample['image_path'] + + n,h,w,d = gt_segmentation.shape + prediction = torch.zeros((n, h, w, d)) + accumulation = torch.zeros((n, h, w, d)) + + data_time += (time.time()-end_time) + end_time = time.time() + + with autocast(): + + queries = text_encoder(labels, modality) + + # for each batch of patches, query with all labels + for patches, y1y2_x1x2_z1z2_ls in zip(batched_patches, batched_y1y2_x1x2_z1z2): # [b, c, h, w, d] + patches = patches.to(device=device) + prediction_patch = model(queries=queries, image_input=patches, train_mode=False) + prediction_patch = torch.sigmoid(prediction_patch) # bnhwd + prediction_patch = 
prediction_patch.detach() # .cpu().numpy() + + # fill in + for b in range(len(y1y2_x1x2_z1z2_ls)): + y1, y2, x1, x2, z1, z2 = y1y2_x1x2_z1z2_ls[b] + + # gaussian accumulation + tmp = prediction_patch[b, :, :y2-y1, :x2-x1, :z2-z1] * gaussian[:y2-y1, :x2-x1, :z2-z1] # on gpu + prediction[:, y1:y2, x1:x2, z1:z2] += tmp.cpu() + accumulation[:, y1:y2, x1:x2, z1:z2] += gaussian[:y2-y1, :x2-x1, :z2-z1].cpu() + + pred_time += (time.time()-end_time) + end_time = time.time() + + # avg + prediction = prediction / accumulation + prediction = torch.where(prediction>0.5, 1.0, 0.0) + prediction = prediction.numpy() + + # cal metrics : [{'dice':x, ...}, ...] + scores = [] + for j in range(len(labels)): + scores.append(calculate_metric_percase(prediction[j, :, :, :], gt_segmentation[j, :, :, :], dice_score, nsd_score)) # {'dice':0.9, 'nsd':0.8} 每个label一个dict + + # visualization + if visualization: + Path(f'{nib_dir}/{dataset_name}').mkdir(exist_ok=True, parents=True) + # 将image、gt和prediction保存下来 + results = np.zeros((h, w, d)) # hwd + for j, label in enumerate(labels): + results += prediction[j, :, :, :] * (j+1) # 0 --> 1 (skip background) + Path(f'{nib_dir}/{dataset_name}/seg_{sample_id}').mkdir(exist_ok=True, parents=True) + # 每个label单独一个nii.gz + segobj = nib.nifti2.Nifti1Image(prediction[j, :, :, :], np.eye(4)) + nib.save(segobj, f'{nib_dir}/{dataset_name}/seg_{sample_id}/{label}.nii.gz') + segobj = nib.nifti2.Nifti1Image(results, np.eye(4)) + nib.save(segobj, f'{nib_dir}/{dataset_name}/seg_{sample_id}.nii.gz') + + image = testset.load_image(image_path) + image = np.squeeze(image) + imgobj = nib.nifti2.Nifti1Image(image, np.eye(4)) + nib.save(imgobj, f'{nib_dir}/{dataset_name}/img_{sample_id}.nii.gz') + + gt = np.zeros((h, w, d)) # hwd + for j, label in enumerate(labels): + gt += gt_segmentation[j, :, :, :] * (j+1) # 0 --> 1 (skip background) + Path(f'{nib_dir}/{dataset_name}/gt_{sample_id}').mkdir(exist_ok=True, parents=True) + # 每个label单独一个nii.gz + segobj = 
nib.nifti2.Nifti1Image(gt_segmentation[j, :, :, :], np.eye(4)) + nib.save(segobj, f'{nib_dir}/{dataset_name}/gt_{sample_id}/{label}.nii.gz') + gtobj = nib.nifti2.Nifti1Image(gt, np.eye(4)) + nib.save(gtobj, f'{nib_dir}/{dataset_name}/gt_{sample_id}.nii.gz') + + metric_time += (time.time()-end_time) + end_time = time.time() + + # accumulate + results_of_samples.append([dataset_name, modality, sample_id, scores, labels]) + + # save in each process regularly in case of interruption + if len(results_of_samples) % save_interval == 0: + with open(csv_path.replace('.csv', f'_tmp_rank{dist.get_rank()}.pkl'), 'wb') as f: + pickle.dump(results_of_samples, f) + + """ + # gather results from all device to rank-0 (solution 1) + gather_results = [None for i in range(dist.get_world_size())] + dist.gather_object( + results_of_samples, + gather_results if dist.get_rank() == 0 else None, + dst = 0 + ) + + if int(dist.get_rank()) == 0: + results_of_samples = [tmp for ls in results_of_samples for tmp in ls] + """ + + avg_patch_batch_num /= len(testloader) + avg_query_batch_num /= len(testloader) + data_time /= len(testloader) + pred_time /= len(testloader) + metric_time /= len(testloader) + print(f'On Rank {dist.get_rank()}, each sample has {avg_patch_batch_num} batch of patches and {avg_query_batch_num} batch of queries, Data Time: {data_time}, Pred Time: {pred_time}, Dice Time: {metric_time}') + + torch.cuda.empty_cache() + + # save in each process (to a fnl pickle, also denoting this process ends) + with open(csv_path.replace('.csv', f'_fnl_rank{dist.get_rank()}.pkl'), 'wb') as f: + pickle.dump(results_of_samples, f) + + # gather and record in rank 0 (solution 2) + if is_master(): + + # detect the finish of each process + while True: + all_process_finished = True + for rank_id in range(torch.distributed.get_world_size()): + if not os.path.exists(csv_path.replace('.csv', f'_fnl_rank{rank_id}.pkl')): # xxx_tmp_rankx.pkl + all_process_finished = False + break + if 
all_process_finished: + break + else: + time.sleep(10) + + # read results of each process (samples may be duplicated due to the even distribution of ddp, check) + results_of_samples = [] + for rank_id in range(torch.distributed.get_world_size()): + fnl_results_file = csv_path.replace('.csv', f'_fnl_rank{rank_id}.pkl') + tmp_results_file = csv_path.replace('.csv', f'_tmp_rank{rank_id}.pkl') + with open(fnl_results_file, 'rb') as f: + results_of_samples += pickle.load(f) + os.remove(fnl_results_file) + if os.path.exists(tmp_results_file): + os.remove(tmp_results_file) + + # check duplication + unique_set = set() + deduplicated_results_of_samples = [] + for dataset_name, modality, sample_id, scores, labels in results_of_samples: + if f'{dataset_name}/{sample_id}' not in unique_set: + unique_set.add(f'{dataset_name}/{sample_id}') + deduplicated_results_of_samples.append([dataset_name, modality, sample_id, scores, labels]) + results_of_samples = deduplicated_results_of_samples + + # save for tmp + with open(csv_path.replace('.csv', '.pkl'), 'wb') as f: + pickle.dump(results_of_samples, f) + + # collate results + for dataset_name, modality, sample_id, scores, labels in results_of_samples: # [[dataset_name, modality, sample_id, scores_of_labels(dict), label_names], ...] 
+ dataset_name = f'{dataset_name}({modality})' + + if dataset_name not in datasets_labels_metrics: + datasets_labels_metrics[dataset_name] = {} # {'COVID19(CT)':{}} + if dataset_name not in datasets_labels_sets: + datasets_labels_sets[dataset_name] = set() # {'COVID19(CT)':set()} + if dataset_name not in samples_labels_metrics: + samples_labels_metrics[dataset_name] = {} + samples_labels_metrics[dataset_name][sample_id] = {} # {'COVID19(CT)':{'0':{}}} + + for metric_dict, label in zip(scores, labels): + # accumulate metrics (for per dataset per class + # {'COVID19(CT)':{'covid19_infection':{'dice':[0.8, 0.9, ...], 'nsd':[0.8, 0.9, ...], ...} ...}, ...} + if label not in datasets_labels_metrics[dataset_name]: + datasets_labels_metrics[dataset_name][label] = {k:[v] for k,v in metric_dict.items()} + else: + for k,v in metric_dict.items(): + datasets_labels_metrics[dataset_name][label][k].append(v) + + # statistic labels + # {'COVID19(CT)':set('covid19_infection', ...)} + if label not in datasets_labels_sets[dataset_name]: + datasets_labels_sets[dataset_name].add(label) + + # record metrics (for per dataset per sample per class + # {'COVID19':{'0.npy':{'covid19_infection':{'dice':0.8, 'nsd':0.9, ...} ...}, ...} + samples_labels_metrics[dataset_name][sample_id][label] = {k:v for k,v in metric_dict.items()} + + # average and log (列为metrics,例如dice,nsd...) + # create a df like: + # { + # 'TotalSegmentator': [0.xx, 0.xx, ...] # 在T之前,这是一列 + # 'TotalSegmentator, Lung': [0.68, 0.72, ...] + # } + # by defult, print the dice (1st metric) of each dataset + info = 'Metrics of Each Dataset:\n' + avg_df = {} + for dataset in datasets_labels_metrics.keys(): + avg_df[dataset] = {k:[] for k in metric_dict.keys()} # 'TotalSegmentator(CT)': {'dice':[0.8, ...] 
'nsd':[0.5, ...], ...} + for label in datasets_labels_metrics[dataset].keys(): + avg_df[f'{dataset}, {label}'] = [] + for metric in datasets_labels_metrics[dataset][label].keys(): + label_metric = np.average(datasets_labels_metrics[dataset][label][metric]) + avg_df[f'{dataset}, {label}'].append(label_metric) # 'TotalSegmentator, Lung': [0.68, 0.72, ...] list of num_metrics + avg_df[dataset][metric].append(label_metric) + avg_df[dataset] = {k:np.average(v) for k,v in avg_df[dataset].items()} # 'TotalSegmentator': {'dice':[0.8, ...] 'nsd':[0.5, ...], ...} --> 'TotalSegmentator': {'dice':0.x, 'nsd':0.x, ...} + info += f'{dataset} | ' + for k ,v in avg_df[dataset].items(): + info += f'{v}({k}) | ' + info += '\n' + avg_df[dataset] = list(avg_df[dataset].values()) + avg_df = pd.DataFrame(avg_df).T + avg_df.columns = list(metric_dict.keys()) # ['dice', 'nsd'] + avg_df.to_csv(csv_path) + print(info) + + # detailed log (nsd和dice,列为class label + # multi-sheet, two for each dataset + df_list = [['summary', avg_df]] + for dataset, label_set in datasets_labels_sets.items(): + metric_df ={} + if dice_score: + metric_df['dice'] = {} + if nsd_score: + metric_df['nsd'] = {} + + # create dfs like: + # { + # '0.npy': [0.xx, 0.xx, ...] + # ...... 
def merge(mod_label_json, mod_label_statistic, xlsx2load, xlsx2save):
    """Aggregate per-(dataset, label) Dice/NSD scores into merged summary sheets.

    Reads the summary sheet (sheet 0) of ``xlsx2load``, then appends three
    sheets to the workbook before saving it to ``xlsx2save``:
      - 'Dataset Merge':  rows of sheet 0 whose first column has no comma,
        i.e. the per-dataset averages.
      - 'Label Merge':    scores averaged per unique ``modality_label`` key,
        plus annotation statistics looked up in ``mod_label_statistic``.
      - 'Region Merge':   scores averaged per anatomical region, using the
        region->labels mapping loaded from ``mod_label_json``.

    Args:
        mod_label_json: JSON file with 'region_based' (region -> label list)
            and 'abnormal' (label list) mappings.
        mod_label_statistic: JSON file mapping modality_label to a 3-tuple
            whose last two entries are (total_num, aug_ratio).
        xlsx2load: Existing xlsx whose first sheet holds the raw results.
        xlsx2save: Output xlsx path (may equal xlsx2load).

    Returns:
        (avg_dice_over_merged_labels, avg_nsd_over_merged_labels): averages
        over all merged modality_label entries.
    """
    mod_lab2dice = {}

    # Load the first sheet of the Excel file
    excel_file_path = xlsx2load
    df = pd.read_excel(excel_file_path, sheet_name=0)
    # NSD is present only when sheet 0 has more than two columns.
    has_nsd = True if len(df.columns) > 2 else False

    # Write the per-dataset rows into a new 'Dataset Merge' sheet.
    workbook = openpyxl.load_workbook(xlsx2load)
    new_sheet = workbook.create_sheet(title='Dataset Merge', index=1)
    new_sheet.cell(row=1, column=1, value='Dataset')
    new_sheet.cell(row=1, column=2, value='Dice')
    new_sheet.cell(row=1, column=3, value='NSD')
    row = 2
    for i in range(0, len(df)):
        # Rows without a comma are dataset-level entries ('Dataset, label'
        # rows are per-label entries and handled below).
        if ',' not in df.iloc[i, 0]:
            new_sheet.cell(row=row, column=1, value=df.iloc[i, 0])
            new_sheet.cell(row=row, column=2, value=df.iloc[i, 1])
            if has_nsd:
                new_sheet.cell(row=row, column=3, value=df.iloc[i, 2])
            row += 1

    # with pd.ExcelWriter(xlsx2save, engine='openpyxl', mode='a', if_sheet_exists='new') as writer:
    #     filtered_df.to_excel(writer, sheet_name='Dataset Merge', index=False)

    # Take the leading columns: name, dice, (optional) nsd.
    dataset_label_ls = df.iloc[:, 0]
    dice_ls = df.iloc[:, 1]
    nsd_ls = df.iloc[:, 2] if has_nsd else [0] * len(df)

    for dataset_modality_label, dice, nsd in zip(dataset_label_ls, dice_ls, nsd_ls):  # MSD_Pancreas(ct), pancreas 0.89
        if ', ' not in dataset_modality_label:
            continue
        dataset_modality, label = dataset_modality_label.split(', ')
        label = label.lower()  # pancreas
        # label = merge_label(label)
        modality = dataset_modality.split('(')[-1].split(')')[0]  # ct

        # unique id : modality_label
        mod_lab = f'{modality}_{label}'

        # accumulate : dice and where the dice comes from (dataset, label, modality)
        if mod_lab not in mod_lab2dice:
            mod_lab2dice[mod_lab] = {'dice':[], 'nsd':[], 'merge':[]}
        mod_lab2dice[mod_lab]['dice'].append(dice)
        mod_lab2dice[mod_lab]['nsd'].append(nsd)
        mod_lab2dice[mod_lab]['merge'].append(dataset_modality_label)

    # retrieval regions
    with open(mod_label_json, 'r') as f:
        dict = json.load(f)  # NOTE(review): shadows the builtin `dict` (as does the loop below) — works, but rename on next refactor
    region2label = dict['region_based']
    for region, label_ls in region2label.items():
        region2label[region] = [mod_lab.split('_')[-1] for mod_lab in label_ls]  # strip the modality prefix
    region2label['abnormal'] = [mod_lab.split('_')[-1] for mod_lab in dict['abnormal']]

    region_dice_ls = {k:[] for k in region2label.keys()}  # {'brain':[0.9, ...], ...}
    region_nsd_ls = {k:[] for k in region2label.keys()}  # {'brain':[0.9, ...], ...}
    region_merge_ls = {k:[] for k in region2label.keys()}  # {'brain':['frontal lobe', ...], ...}

    mod_lab_ls = []
    dice_ls = []
    nsd_ls = []
    merge_ls = []
    region_ls = []
    for mod_lab, dict in mod_lab2dice.items():
        label = mod_lab.split('_')[-1]
        mod_lab_ls.append(mod_lab)
        # Average the accumulated scores for this modality_label.
        dice_ls.append(sum(dict['dice'])/len(dict['dice']))
        nsd_ls.append(sum(dict['nsd'])/len(dict['nsd']))
        merge_ls.append(' / '.join(dict['merge']))

        # find region ('abnormal' has priority over anatomical regions)
        if label in region2label['abnormal']:
            region_dice_ls['abnormal'].append(dice_ls[-1])
            region_nsd_ls['abnormal'].append(nsd_ls[-1])
            region_merge_ls['abnormal'].append(mod_lab)
            region_ls.append('abnormal')
        else:
            found = False
            for region, labels_in_region in region2label.items():
                if label in labels_in_region:
                    region_dice_ls[region].append(dice_ls[-1])
                    region_nsd_ls[region].append(nsd_ls[-1])
                    region_merge_ls[region].append(mod_lab)
                    region_ls.append(region)
                    found = True
                    break
            if not found:
                # Unmapped label: log it and tag the row 'unknown'.
                print(label)
                region_ls.append('unknown')

    df = pd.DataFrame({
        'Modality_Label': mod_lab_ls,
        'Dice': dice_ls,
        'NSD': nsd_ls,
        'Merge': merge_ls,
        'Region': region_ls
    })

    #book = openpyxl.load_workbook(xlsx2save)
    #writer = pd.ExcelWriter(xlsx2save, engine='openpyxl')
    #writer.book = book

    # with pd.ExcelWriter(xlsx2save, engine='openpyxl', mode='a', if_sheet_exists='new') as writer:
    #     df.to_excel(writer, sheet_name='Label Merge', index=False)

    # Attach annotation count and repeat (augmentation) ratio per label.
    with open(mod_label_statistic, 'r') as f:
        statistic_dict = json.load(f)

    # Write the merged-label DataFrame into a new 'Label Merge' sheet.
    new_sheet = workbook.create_sheet(title='Label Merge', index=1)
    new_sheet.cell(row=1, column=1, value='Modality_Label')
    new_sheet.cell(row=1, column=2, value='Dice')
    new_sheet.cell(row=1, column=3, value='NSD')
    new_sheet.cell(row=1, column=4, value='Merge')
    new_sheet.cell(row=1, column=5, value='Region')
    new_sheet.cell(row=1, column=6, value='Total_Num')
    new_sheet.cell(row=1, column=7, value='Aug_Ratio')
    row = 2
    for mod_lab, dice, nsd, merge, region in zip(mod_lab_ls, dice_ls, nsd_ls, merge_ls, region_ls):
        if mod_lab in statistic_dict:
            # assumes statistic_dict values are 3-tuples ending in (total_num, aug_ratio) — matches the unpack here; TODO confirm against the JSON producer
            _, total_num, aug_ratio = statistic_dict[mod_lab]
        else:
            total_num = aug_ratio = 0
        new_sheet.cell(row=row, column=1, value=mod_lab)
        new_sheet.cell(row=row, column=2, value=dice)
        new_sheet.cell(row=row, column=3, value=nsd)
        new_sheet.cell(row=row, column=4, value=merge)
        new_sheet.cell(row=row, column=5, value=region)
        new_sheet.cell(row=row, column=6, value=total_num)
        new_sheet.cell(row=row, column=7, value=aug_ratio)
        row += 1
    new_sheet.cell(row=row, column=2, value=sum(dice_ls)/len(dice_ls))  # avg over all labels
    new_sheet.cell(row=row, column=3, value=sum(nsd_ls)/len(nsd_ls))

    # Write the per-region averages into a new 'Region Merge' sheet.
    new_sheet = workbook.create_sheet(title='Region Merge', index=1)
    new_sheet.cell(row=1, column=1, value='Region')
    new_sheet.cell(row=1, column=2, value='Dice')
    new_sheet.cell(row=1, column=3, value='NSD')
    new_sheet.cell(row=1, column=4, value='Merge')
    row = 2
    for key in region_dice_ls.keys():
        if len(region_dice_ls[key]) == 0:
            dice = nsd = 0
            merge = None
        else:
            dice = sum(region_dice_ls[key])/len(region_dice_ls[key])
            nsd = sum(region_nsd_ls[key])/len(region_nsd_ls[key])
            merge = ','.join(region_merge_ls[key])
        class_name = f'{key}({len(region_dice_ls[key])})'
        new_sheet.cell(row=row, column=1, value=class_name)
        new_sheet.cell(row=row, column=2, value=dice)
        new_sheet.cell(row=row, column=3, value=nsd)
        new_sheet.cell(row=row, column=4, value=merge)
        row += 1

    workbook.save(xlsx2save)

    # Return the average over all merged labels.
    avg_dice_over_merged_labels = sum(dice_ls) / len(dice_ls)
    avg_nsd_over_merged_labels = sum(nsd_ls) / len(nsd_ls)

    return avg_dice_over_merged_labels, avg_nsd_over_merged_labels
def calculate_metric_percase(pred, gt, dice=True, nsd=True):
    """Score one binary prediction against its ground truth.

    Args:
        pred: Predicted segmentation (any array-like castable to bool).
        gt: Ground-truth segmentation (same shape as ``pred``).
        dice: Include the Dice coefficient in the result.
        nsd: Include normalized surface Dice (spacing [1, 1, 3], tolerance 1).

    Returns:
        Dict with the requested keys ('dice' and/or 'nsd'). When the ground
        truth is empty, both scores are 1.0 for an empty prediction and 0.0
        otherwise.
    """
    pred_mask = pred.astype(bool)
    gt_mask = gt.astype(bool)

    # Degenerate case: nothing to segment. Perfect score only when the
    # prediction is empty as well.
    if not gt_mask.any():
        degenerate = 1.0 if not pred_mask.any() else 0.0
        scores = {}
        if dice:
            scores['dice'] = degenerate
        if nsd:
            scores['nsd'] = degenerate
        return scores

    scores = {}
    if dice:
        scores['dice'] = metric.binary.dc(pred_mask, gt_mask)
    if nsd:
        distances = compute_surface_distances(gt_mask, pred_mask, [1, 1, 3])
        scores['nsd'] = compute_surface_dice_at_tolerance(distances, 1)
    return scores
import argparse

def str2bool(v):
    """Parse a boolean CLI value.

    Fix: the previous implementation (`v.lower() in ('true', 't')`) silently
    mapped every unrecognized string — including typos — to False. This
    version accepts the same truthy spellings (plus 'yes'/'y'/'1') and the
    usual falsy spellings, and raises ``argparse.ArgumentTypeError`` for
    anything else, matching the str2bool used in merge_after_evaluate.py.
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

def parse_args():
    """Build and parse the command line for the evaluation entry point.

    Returns:
        argparse.Namespace with experiment control, metric, dataset, loader,
        text-encoder and MaskFormer options (see the groups below).
    """
    parser = argparse.ArgumentParser()

    # Exp Controller

    parser.add_argument(
        "--rcd_dir",
        type=str,
        help="save the evaluation results (in a directory)",
    )
    parser.add_argument(
        "--rcd_file",
        type=str,
        help="save the evaluation results (in a csv/xlsx file)",
    )
    parser.add_argument(
        "--visualization",
        type=str2bool,
        default=False,
        help="save the visualization for each case (img, gt, pred)",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        help="Checkpoint path",
    )
    parser.add_argument(
        "--partial_load",
        type=str2bool,
        default=True,
        help="Allow to load partial paramters from checkpoint",
    )
    parser.add_argument(
        "--gpu",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--resume",
        type=str2bool,
        default=True,
        help="Inherit medial results from an interrupted evaluation (no harm even if you evaluate from scratch)",
    )
    parser.add_argument(
        "--save_interval",
        type=int,
        default=100
    )

    # Metrics

    parser.add_argument(
        "--dice",
        type=str2bool,
        default=True,
    )
    parser.add_argument(
        "--nsd",
        type=str2bool,
        default=True,
    )

    # Med SAM Dataset

    parser.add_argument(
        "--datasets_jsonl",
        type=str,
    )
    parser.add_argument(
        "--text_prompts_json",
        type=str,
        help='This is needed for CVPR25 challenge, where multiple prompts (synonyms) are required.'
    )

    # Sampler and Loader

    parser.add_argument(
        "--online_crop",
        type=str2bool,
        # NOTE: argparse applies `type` to string defaults, so 'False' is
        # converted through str2bool and yields the boolean False.
        default='False',
        help='load pre-cropped image patches directly, or crop online',
    )
    parser.add_argument(
        "--crop_size",
        type=int,
        nargs='+',
        default=[288, 288, 96],
    )
    parser.add_argument(
        "--max_queries",
        type=int,
        default=256,
    )
    parser.add_argument(
        "--batchsize_3d",
        type=int,
        default=2,
    )
    parser.add_argument(
        "--pin_memory",
        type=str2bool,
        default=False,
        help='load data to gpu to accelerate'
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4
    )

    # Knowledge Encoder
    parser.add_argument(
        "--text_encoder_partial_load",
        type=str2bool,
        default=True,
        help="Allow to load partial paramters from checkpoint",
    )
    parser.add_argument(
        "--text_encoder_checkpoint",
        type=str,
    )
    parser.add_argument(
        "--text_encoder",
        type=str,
    )

    # MaskFormer

    parser.add_argument(
        "--vision_backbone",
        type=str,
        help='UNET or UNET-H'
    )
    parser.add_argument(
        "--patch_size",
        type=int,
        nargs='+',
        default=[32, 32, 32],
        help='patch size on h w and d'
    )
    parser.add_argument(
        "--deep_supervision",
        type=str2bool,
        default=False,
    )

    args = parser.parse_args()
    return args
def adjust_spacing(img_array, img_spacing):
    """Align the coarsest spacing value with the smallest image dimension.

    If the axis with the fewest voxels is not the axis with the largest
    spacing (and that spacing exceeds 0.5), the two spacing entries are
    swapped; otherwise spacing is returned unchanged.

    Args:
        img_array: Image array (only its shape is inspected).
        img_spacing: Per-axis spacing values.

    Returns:
        Spacing as a numpy array, possibly with two entries swapped.
    """
    spacing = np.asarray(img_spacing)
    smallest_axis = np.argmin(img_array.shape)
    coarsest_axis = np.argmax(spacing)

    # No reordering when the axes already agree, or the coarse spacing is
    # too fine (<= 0.5) to indicate a genuine anisotropy mismatch.
    if smallest_axis == coarsest_axis or spacing[coarsest_axis] <= 0.5:
        return spacing

    order = list(range(len(spacing)))
    order[smallest_axis], order[coarsest_axis] = order[coarsest_axis], order[smallest_axis]
    return spacing[order]
def respace_mask(mask: np.ndarray, current_spacing: np.ndarray, target_spacing: np.ndarray, device: torch.device) -> np.ndarray:
    """Resample a segmentation mask to the target spacing.

    Args:
        mask: Mask array with shape (C, H, W, D).
        current_spacing: Spacing of the input mask.
        target_spacing: Desired output spacing.
        device: PyTorch device used by the resampler.

    Returns:
        Mask resampled to the shape implied by ``target_spacing``.
    """
    # Derive the output geometry from the spatial dims (channel dim excluded).
    target_shape = compute_new_shape(mask.shape[1:], current_spacing, target_spacing)
    return resample_torch_fornnunet(
        mask,
        target_shape,
        current_spacing,
        target_spacing,
        is_seg=True,  # segmentation resampling: labels must not be blended
        num_threads=8,
        device=device,
        memefficient_seg_resampling=False,
        force_separate_z=None,
        separate_z_anisotropy_threshold=3.0,
    )
def pad_if_necessary(image, crop_size=[288, 288, 96]):
    """Zero-pad the trailing side of each spatial axis up to ``crop_size``.

    Args:
        image: Tensor with shape (C, H, W, D).
        crop_size: Minimum required spatial extent [h, w, d].

    Returns:
        padded_image: The (possibly padded) tensor.
        padding_info: Tuple (pad_h, pad_w, pad_d) of amounts added, needed
            later by ``remove_padding``.
    """
    _, height, width, depth = image.shape
    need_h = max(crop_size[0] - height, 0)
    need_w = max(crop_size[1] - width, 0)
    need_d = max(crop_size[2] - depth, 0)

    padding_info = (need_h, need_w, need_d)
    if need_h or need_w or need_d:
        # F.pad consumes (before, after) pairs starting from the LAST dim,
        # so the order here is depth, width, height — padding only at the end.
        image = F.pad(image, (0, need_d, 0, need_w, 0, need_h), 'constant', 0)
    return image, padding_info
def internal_maybe_mirror_and_predict(model=None, queries=None, image_input=None, simulated_lowres_sc_pred=None,
                                      simulated_lowres_mc_pred=None, mirror_axes=(0, 1, 2)):
    """
    Apply test-time augmentation with mirroring.

    This function performs inference with multiple mirroring combinations
    and averages the results for improved robustness.

    Args:
        model: Model to use for prediction
        queries: Query tensor
        image_input: Input image tensor
        simulated_lowres_sc_pred: Simulated low-res single-channel prediction
        simulated_lowres_mc_pred: Simulated low-res multi-channel prediction
        mirror_axes: Axes to mirror (0, 1, 2 for spatial dimensions); None disables TTA

    Returns:
        Averaged prediction tensor (mean over the base pass and every
        mirror combination)
    """
    # Base, un-mirrored forward pass.
    prediction = model(queries=queries,
                       image_input=image_input,
                       simulated_lowres_sc_pred=simulated_lowres_sc_pred,
                       simulated_lowres_mc_pred=simulated_lowres_mc_pred,
                       train_mode=False)

    if mirror_axes is not None:
        assert max(mirror_axes) <= image_input.ndim - 3, 'mirror_axes does not match the dimension of the input!'
        # Shift spatial axis indices past the (batch, channel) dims of image_input.
        mirror_axes = [m + 2 for m in mirror_axes]
        # Every non-empty subset of the mirror axes (single flips, pairs, triple).
        axes_combinations = [
            c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1)
        ]
        for axes in axes_combinations:
            image_input_fliped = torch.flip(image_input, axes)
            # The low-res priors lack the batch dim image_input has, so a dummy
            # dim is added before flipping and removed afterwards.
            simulated_lowres_sc_pred_fliped = torch.flip(simulated_lowres_sc_pred.unsqueeze(0), axes).squeeze(0) if simulated_lowres_sc_pred is not None else None
            simulated_lowres_mc_pred_fliped = torch.flip(simulated_lowres_mc_pred.unsqueeze(0), axes).squeeze(0) if simulated_lowres_mc_pred is not None else None
            prediction_fliped = model(queries=queries,
                                      image_input=image_input_fliped,
                                      simulated_lowres_sc_pred=simulated_lowres_sc_pred_fliped,
                                      simulated_lowres_mc_pred=simulated_lowres_mc_pred_fliped,
                                      train_mode=False)
            # Flip the mirrored prediction back into the original frame, then accumulate.
            prediction += torch.flip(prediction_fliped, axes)
        # Mean over base pass plus all mirrored passes.
        prediction /= (len(axes_combinations) + 1)
    return prediction
+ + Args: + queries: Input query tensor, shape (batch, query_dim) + patches: Image patch tensor, shape (batch, channels, h, w, d) + lowres_single_channel_pred: Low-res single-channel prediction, shape (1, 1, h, w, d) + lowres_multi_channel_pred: Low-res multi-channel prediction, shape (1, c, h, w, d) + model: Trained neural network model + possible_block_sizes: List of possible block sizes (e.g., [8, 16, 32]) + n_repeats: Number of times to repeat prediction with different masks + disable_tta: Whether to disable test-time augmentation + + Returns: + Averaged patch prediction, shape (1, c, h, w, d) + """ + # Validate inputs + if not possible_block_sizes: + raise ValueError("possible_block_sizes cannot be empty") + if n_repeats < 1: + raise ValueError("n_repeats must be at least 1") + + _, _, h, w, d = lowres_single_channel_pred.shape + device = lowres_single_channel_pred.device + prediction_sum = torch.zeros_like(lowres_multi_channel_pred, device=device) + + def upsample_block_mask(block_mask: torch.Tensor, block_size: int) -> torch.Tensor: + """Upsample a block mask to full resolution.""" + upsampled = ( + block_mask.unsqueeze(0).unsqueeze(0) + .repeat_interleave(block_size, dim=2) + .repeat_interleave(block_size, dim=3) + .repeat_interleave(block_size, dim=4) + [:, :, :h, :w, :d] + ).float() + return upsampled + + for _ in range(n_repeats): + block_size = random.choice(possible_block_sizes) + n_blocks_h = (h + block_size - 1) // block_size + n_blocks_w = (w + block_size - 1) // block_size + n_blocks_d = (d + block_size - 1) // block_size + total_blocks = n_blocks_h * n_blocks_w * n_blocks_d + + num_selected = max(1, total_blocks // 2) + block_mask = torch.zeros(n_blocks_h, n_blocks_w, n_blocks_d, dtype=torch.bool, device=device) + indices = torch.randperm(total_blocks, device=device)[:num_selected] + block_mask.view(-1)[indices] = True + + mask = upsample_block_mask(block_mask, block_size) + complementary_mask = 1.0 - mask + + masked_sc_pred = 
def read_npz_data(raw_image, raw_spacing, crop_size=[288, 288, 96],
                  target_spacing=[1.5, 1.5, 3.0], scaled_roi_lowres_pred_array=None,
                  class_name_list=[], stage_1_flag=False, device=torch.device("cuda", 0), verbose=True):
    """
    Read and preprocess image data for inference.

    This function handles spacing adjustments, image resampling, padding,
    and patch splitting for the inference pipeline.

    Args:
        raw_image: Input image array with shape (d, h, w)
        raw_spacing: Spacing array with shape (3,)
        crop_size: Target crop size [h, w, d]
        target_spacing: Target spacing [h, w, d]
        scaled_roi_lowres_pred_array: Optional low-res prediction for ROI-based inference
        class_name_list: List of class names (kept for compatibility, not used)
        stage_1_flag: Whether this is Stage 1 inference (kept for compatibility, not used)
        device: PyTorch device for resampling
        verbose: Whether to print detailed information (default: True)

    Returns:
        data_dict: Dictionary containing preprocessed patches and metadata
    """
    # BUGFIX: this function rewrites target_spacing[i] below. Operating on the
    # caller's list — and, worse, on the shared mutable DEFAULT list — leaked
    # adjusted spacings into every subsequent call. Work on a private copy.
    target_spacing = list(target_spacing)

    raw_d, raw_h, raw_w = raw_image.shape
    image = rearrange(raw_image, 'd h w -> h w d')
    spacing = raw_spacing.astype(np.float32)

    # Simplified spacing adjustment following the provided steps
    # Step 1: Handle very small spacing values
    for i in range(3):
        if spacing[i] <= 0.1:
            spacing[i] = 1.0

    # Step 2: Adjust spacing based on image dimensions
    spacing = adjust_spacing(image, spacing)

    # Step 3: Initialize parameters for spacing adjustment
    max_dims = [1000, 1000, 700]
    min_dims = crop_size
    # Geometric ladder 1.25, 1.25^2, ... up to 50, used to relax target spacing
    # for small images.
    thresholds = []
    current = 1.25
    while current <= 50:
        thresholds.append(current)
        current *= 1.25
    raw_target_spacing = target_spacing.copy()

    # Step 4: Adjust spacing based on constraints
    for i in range(3):
        # If spacing is less than 1.0 and image dimension is within max_dims, set to 1.0
        if spacing[i] < 1.0 and image.shape[i] <= max_dims[i]:
            spacing[i] = 1.0  # second stage model resolution

        # If physical dimension exceeds max_dims and spacing is greater than target, use target spacing
        if spacing[i] * image.shape[i] > max_dims[i] * target_spacing[i] and spacing[i] > target_spacing[i]:
            spacing[i] = target_spacing[i]
        # If physical dimension is less than min_dims threshold, adjust target_spacing
        elif spacing[i] * image.shape[i] < min_dims[i] * target_spacing[i]:
            alpha_spacing = 1
            # Largest ladder step such that the image still fits under min_dims.
            for threshold in reversed(thresholds):
                if image.shape[i] <= (min_dims[i] / threshold):
                    alpha_spacing = threshold
                    break

            raw_target_spacing[i] = target_spacing[i]
            target_spacing[i] = max(spacing[i] * image.shape[i] / min_dims[i], spacing[i] / alpha_spacing)
            if verbose:
                print("alpha_spacing: ", alpha_spacing)
                print("spacing[i] * image.shape[i] / min_dims[i], spacing[i] / alpha_spacing: ", spacing[i] * image.shape[i] / min_dims[i], spacing[i] / alpha_spacing)
                print("raw_target_spacing[i], target_spacing[i]: ", raw_target_spacing[i], target_spacing[i])
            # Never exceed the originally requested spacing.
            target_spacing[i] = min(raw_target_spacing[i], target_spacing[i])
        if verbose:
            print("image.shape[i], min_dims[i], target_spacing[i], spacing[i]: ", image.shape[i], min_dims[i], target_spacing[i], spacing[i])

    # Set default num_iterations (no special class handling)
    num_iterations = 1

    image = image[np.newaxis, ...].astype(np.float32)
    if verbose:
        print("image.shape: ", image.shape)
        print("spacing: ", spacing)
        print("target_spacing: ", target_spacing)
    image = respace_image(image, spacing, target_spacing, torch.device('cpu'))
    if verbose:
        print("respace image.shape: ", image.shape)
    image = torch.tensor(image)
    image, padding_info = pad_if_necessary(image, crop_size=crop_size)
    _, h, w, d = image.shape

    patches, y1y2_x1x2_z1z2_ls = split_3d(image, crop_size=crop_size)

    data_dict = {
        'spacing': spacing,
        'original_shape': (raw_h, raw_w, raw_d),
        'current_shape': (h, w, d),
        'patches': patches,
        'y1y2_x1x2_z1z2_ls': y1y2_x1x2_z1z2_ls,
        'padding_info': padding_info,
        'raw_image': raw_image,
        'num_iterations': num_iterations
    }

    if scaled_roi_lowres_pred_array is not None:
        # Put the low-res prior through the same resample/pad/split pipeline
        # so its patches align 1:1 with the image patches.
        lowres_pred = rearrange(scaled_roi_lowres_pred_array, 'd h w -> h w d')
        lowres_pred = lowres_pred[np.newaxis, ...].astype(np.float32)
        lowres_pred = respace_mask(lowres_pred, spacing, target_spacing, torch.device('cpu'))
        lowres_pred = torch.tensor(lowres_pred)
        lowres_pred, padding_info = pad_if_necessary(lowres_pred, crop_size=crop_size)
        lowres_pred_patches, _ = split_3d(lowres_pred, crop_size=crop_size)
        data_dict['lowres_pred_patches'] = lowres_pred_patches
        data_dict['padding_info'] = padding_info

    return data_dict
+ + Args: + sc_mask: Single-channel mask with shape (1, 1, h, w, d) or (h, w, d) + label_values_ls: List of label values to create channels for + + Returns: + Multi-channel mask with shape (1, n_classes, h, w, d) + """ + sc_mask = sc_mask.squeeze(0).squeeze(0) + assert sc_mask.ndim == 3 + h, w, d = sc_mask.shape + n = len(label_values_ls) + mc_mask = torch.zeros((n, h, w, d), dtype=bool).to(sc_mask.device) + for i, label_value in enumerate(label_values_ls): + mc_mask[i] = torch.where(sc_mask == label_value, 1, 0) + mc_mask = mc_mask.to(torch.float32) + mc_mask = mc_mask.unsqueeze(0) + return mc_mask + + +class MedicalSegmentationPipeline: + """ + Pipeline for medical image segmentation. + + This class handles model loading, data preprocessing, and inference execution + for the Medal-S segmentation pipeline. + """ + + def __init__(self, config): + """ + Initialize the segmentation pipeline. + + Args: + config: Dictionary containing pipeline configuration parameters + """ + self.config = config + self.device = torch.device(config['device']) + + def _load_model(self): + """ + Load vision model and text encoder from checkpoints. 
+ + Returns: + model: Loaded vision model (Maskformer) + text_encoder: Loaded text encoder (Knowledge_Encoder) + """ + crop_str = '_'.join(map(str, self.config['crop_size'])) + spacing_str = '_'.join(map(str, self.config['target_spacing_model'])) + + vision_backbone_checkpoint = os.path.join( + self.config['checkpoints_path'], + f"nano_UNet_CVPR2025_crop_size_{crop_str}_spacing_{spacing_str}_step_{self.config['model_step']}.pth") + + model = Maskformer( + self.config['vision_backbone'], + self.config['input_channels'], + self.config['crop_size'], + self.config['patch_size'], + False + ) + model = model.to(self.device) + checkpoint = torch.load(vision_backbone_checkpoint, map_location=self.device) + new_state_dict = { + k[7:] if k.startswith('module.') else k: v + for k, v in checkpoint['model_state_dict'].items() + if 'mid_mask_embed_proj' not in k + } + model.load_state_dict(new_state_dict) + model.eval() + + text_encoder = Knowledge_Encoder( + biolord_checkpoint=os.path.join( + self.config['checkpoints_path'], + 'BioLORD-2023-C' + ) + ) + text_encoder = text_encoder.to(self.device) + checkpoint = torch.load( + os.path.join(self.config['checkpoints_path'], 'text_encoder.pth'), + map_location=self.device + ) + new_state_dict = { + k[7:] if k.startswith('module.') else k: v + for k, v in checkpoint['model_state_dict'].items() + } + text_encoder.load_state_dict(new_state_dict, strict=False) + text_encoder.eval() + + return model, text_encoder + + def run_inference(self, raw_image, raw_spacing, verbose=True): + """ + Run inference on the input image. + + This method performs the complete inference pipeline: + 1. Load models (vision backbone and text encoder) + 2. Preprocess image data (resampling, padding, patch splitting) + 3. Encode text prompts + 4. Process patches and aggregate predictions + 5. 
Post-process results (remove padding, resample to original shape) + + Args: + raw_image: Input image array with shape (d, h, w) + raw_spacing: Spacing array with shape (3,) + verbose: Whether to print detailed information (default: True) + + Returns: + pred_array: Segmentation array with shape (d, h, w), dtype int16 + max_prob_array: Maximum probability array (if return_max_prob=True), or None + """ + model, text_encoder = self._load_model() + pred_array = None + crop_size = self.config['crop_size'] + disable_tta = self.config['disable_tta'] + instance_label = self.config['instance_label'] + modality = self.config['modality'] + text_prompts = self.config['texts'] + label_values = self.config['label_values'] + return_max_prob = self.config['return_max_prob'] + class_name_list = self.config['class_name_list'] + stage_1_flag = self.config['stage_1_flag'] + with torch.no_grad(): + # Gaussian is kept on CPU, as accumulation will now happen on CPU + gaussian = torch.tensor(compute_gaussian(tuple(crop_size)), dtype=torch.float32).cpu() + + data_dict = read_npz_data( + raw_image=raw_image, + raw_spacing=raw_spacing, + crop_size=crop_size, + target_spacing=self.config['target_spacing'], + scaled_roi_lowres_pred_array=self.config['scaled_roi_lowres_pred_array'], + class_name_list=class_name_list, + stage_1_flag=stage_1_flag, + device=self.device, + verbose=verbose + ) + + spacing = data_dict['spacing'] + original_shape = data_dict['original_shape'] + current_shape = data_dict['current_shape'] + batched_patches = data_dict['patches'] + batched_y1y2_x1x2_z1z2 = data_dict['y1y2_x1x2_z1z2_ls'] + padding_info = data_dict['padding_info'] + raw_image = data_dict['raw_image'] + num_iterations = data_dict['num_iterations'] + batched_lowres_pred_patches = data_dict.get('lowres_pred_patches') + + modality_code = torch.tensor([{ + 'ct': 0, 'mri': 1, 'us': 2, 'pet': 3, 'microscopy': 4 + }[modality]]).to(self.device) # Keep modality_code on GPU if text_encoder needs it on GPU + + h, w, d 
= current_shape + n_total_classes = len(text_prompts) + + # Get category batch size from config, default to 24 + category_batch_size = self.config.get('category_batch_size', 24) + background_threshold = self.config.get('background_threshold', 0.5) + + # Initialize max_prob and max_class_label_value on CPU to save GPU memory + max_prob = torch.zeros((h, w, d), dtype=torch.float32, device='cpu') + max_class_label_value = torch.zeros((h, w, d), dtype=torch.int16, device='cpu') + + # Process categories in batches to avoid OOM + category_range = range(0, n_total_classes, category_batch_size) + pbar = tqdm(category_range, desc="Processing Categories") + for i in pbar: + current_category_texts = text_prompts[i:i + category_batch_size] + current_label_values = label_values[i:i + category_batch_size] + current_n = len(current_category_texts) + end_idx = min(i + current_n - 1, n_total_classes - 1) + + # Update progress bar description with current category range + pbar.set_description(f"Processing Categories {i}-{end_idx}") + + # Keep these large tensors on CPU for accumulation + temp_prediction_batch_cpu = torch.zeros((current_n, h, w, d), dtype=torch.float32, device='cpu') + temp_accumulation_batch_cpu = torch.zeros((current_n, h, w, d), dtype=torch.float32, device='cpu') + + # Encode text prompts for current batch + with autocast(enabled=False): + queries = text_encoder(current_category_texts, modality_code, self.device) # queries remain on GPU for model input + + # Process patches for current category batch + for patches, lowres_pred_patches, y1y2_x1x2_z1z2_ls in tqdm( + zip(batched_patches, batched_lowres_pred_patches if batched_lowres_pred_patches is not None else [None]*len(batched_patches), batched_y1y2_x1x2_z1z2), + total=len(batched_patches), + desc="Processing", + ncols=100, + bar_format="{l_bar}{bar:20}{r_bar}", + colour="green", + leave=False + ): + patches = patches.unsqueeze(0).to(device=self.device, dtype=torch.float32) # patches on GPU for model input + y1, 
y2, x1, x2, z1, z2 = y1y2_x1x2_z1z2_ls + + simulated_lowres_sc_pred = None + simulated_lowres_mc_pred = None + + if not self.config['w_lowres_pred_prompts']: + simulated_lowres_sc_pred = torch.zeros((1, 1, *crop_size), device=self.device, dtype=torch.float32) + simulated_lowres_mc_pred = torch.zeros((1, current_n, *crop_size), device=self.device, dtype=torch.float32) + prediction_patch = model( + queries=queries, + image_input=patches, + simulated_lowres_sc_pred=simulated_lowres_sc_pred, + simulated_lowres_mc_pred=simulated_lowres_mc_pred, + train_mode=False + ) if self.config['disable_tta'] else internal_maybe_mirror_and_predict( + model=model, + queries=queries, + image_input=patches, + simulated_lowres_sc_pred=simulated_lowres_sc_pred, + simulated_lowres_mc_pred=simulated_lowres_mc_pred, + mirror_axes=(0, 1, 2) + ) + else: + lowres_pred_patches = lowres_pred_patches.unsqueeze(0).to(device=self.device, dtype=torch.float32) + simulated_lowres_sc_pred = torch.where(lowres_pred_patches > 0, torch.ones_like(lowres_pred_patches), torch.zeros_like(lowres_pred_patches)) + simulated_lowres_mc_pred = sc_mask_to_mc_mask(lowres_pred_patches, [int(val) for val in current_label_values]) + + possible_block_sizes = [8] + if instance_label == 1: + n_repeats = 1 + else: + n_repeats = 1 + prediction_patch = compute_patch_prediction(queries, patches, simulated_lowres_sc_pred, simulated_lowres_mc_pred, model, possible_block_sizes, n_repeats, disable_tta) + + if instance_label == 1: # Instance segmentation mode + for _ in range(num_iterations): + prediction_patch_prob = torch.sigmoid(prediction_patch).detach() + simulated_lowres_mc_pred = torch.where(prediction_patch_prob > 0.5, 1.0, 0.0) + simulated_lowres_sc_pred = (simulated_lowres_mc_pred.sum(dim=1, keepdim=True) > 0).float() + possible_block_sizes = [4] + n_repeats = 1 + prediction_patch = compute_patch_prediction(queries, patches, simulated_lowres_sc_pred, simulated_lowres_mc_pred, model, possible_block_sizes, n_repeats, 
disable_tta) + + prediction_patch_prob_gpu = torch.sigmoid(prediction_patch).detach() + current_gaussian_slice = gaussian[:y2-y1, :x2-x1, :z2-z1] # Already on CPU + + # Perform accumulation on CPU. Move prediction_patch_prob_gpu to CPU here. + temp_prediction_batch_cpu[:, y1:y2, x1:x2, z1:z2] += (prediction_patch_prob_gpu[0, :, :y2-y1, :x2-x1, :z2-z1].cpu() * current_gaussian_slice) + temp_accumulation_batch_cpu[:, y1:y2, x1:x2, z1:z2] += current_gaussian_slice + + # Explicitly delete GPU tensors to free up memory immediately + del prediction_patch, prediction_patch_prob_gpu, patches + if simulated_lowres_sc_pred is not None: + del simulated_lowres_sc_pred + if simulated_lowres_mc_pred is not None: + del simulated_lowres_mc_pred + torch.cuda.empty_cache() # Clear any cached GPU memory after each patch processing + gc.collect() # Python garbage collection + + # Normalize predictions by accumulation + batch_accumulation_cpu = temp_accumulation_batch_cpu + batch_accumulation_cpu[batch_accumulation_cpu == 0] = 1e-8 + batch_prediction_prob_cpu = temp_prediction_batch_cpu / batch_accumulation_cpu + + # Update max_prob and max_class_label_value on CPU + for j in range(current_n): + class_prob_cpu = batch_prediction_prob_cpu[j, ...] 
# Already on CPU + class_label_value_cpu_scalar = torch.tensor(int(current_label_values[j]), dtype=torch.int16, device='cpu') # Already on CPU + + update_mask_cpu = class_prob_cpu > max_prob + max_prob[update_mask_cpu] = class_prob_cpu[update_mask_cpu] + max_class_label_value[update_mask_cpu] = class_label_value_cpu_scalar + + # Clean up batch tensors + del temp_prediction_batch_cpu, temp_accumulation_batch_cpu, batch_accumulation_cpu, batch_prediction_prob_cpu, queries + # Previous patch-level deletions handle GPU memory + + # Final operations on CPU + background_indices = max_prob < background_threshold + max_class_label_value[background_indices] = 0 + results = max_class_label_value.numpy() # Already on CPU, just convert to numpy + + results = remove_padding(results, padding_info) + current_h, current_w, current_d = results.shape + if results.shape != original_shape: + results = resample_torch_simple( + results[np.newaxis, ...], + new_shape=original_shape, + is_seg=True, + num_threads=4, + device=torch.device('cpu'), + memefficient_seg_resampling=False).squeeze(0) + + if verbose: + print(f"Resized segmentation from {current_h, current_w, current_d} to {original_shape}") + + pred_array = rearrange(results, 'h w d -> d h w').astype(np.int16) + + if return_max_prob and instance_label == 0: + # max_prob is already on CPU, just convert to numpy for post-processing + max_prob_numpy = max_prob.numpy() + max_prob_numpy = remove_padding(max_prob_numpy, padding_info) + current_h, current_w, current_d = max_prob_numpy.shape + if max_prob_numpy.shape != original_shape: + max_prob_numpy = resample_torch_simple( + max_prob_numpy[np.newaxis, ...], + new_shape=original_shape, + is_seg=False, + num_threads=4, + device=torch.device('cpu'), + memefficient_seg_resampling=False).squeeze(0) + + if verbose: + print(f"Resized max probability from {current_h, current_w, current_d} to {original_shape}") + max_prob = rearrange(max_prob_numpy, 'h w d -> d h w').astype(np.float32) + + if 
return_max_prob and instance_label == 0: + return pred_array, max_prob + else: + return pred_array, None + + +def run_segmentation( + raw_image, + raw_spacing, + crop_size=[192, 192, 96], + target_spacing=[1.5, 1.5, 3.0], + target_spacing_model=[1.5, 1.5, 3.0], + w_lowres_pred_prompts=False, + scaled_roi_lowres_pred_array=None, + disable_tta=True, + model_step=100000, + vision_backbone="UNET", + input_channels=2, + patch_size=[32, 32, 32], + modality='CT', + instance_label=0, + texts=[], + label_values=[], + return_max_prob=False, + class_name_list=[], + stage_1_flag=False, + device="cuda:0", + checkpoints_path="./checkpoints", + category_batch_size=24, + background_threshold=0.5, + verbose=True, +): + """ + Main segmentation function. + + This function orchestrates the entire segmentation pipeline including + model loading, data preprocessing, patch-based inference, and result aggregation. + + Args: + raw_image: Input image array with shape (d, h, w), dtype uint8, values in [0, 255] + raw_spacing: Spacing array with shape (3,) + crop_size: Crop size for patch processing [h, w, d] + target_spacing: Target spacing for resampling [h, w, d] + target_spacing_model: Target spacing for model (should match target_spacing) + w_lowres_pred_prompts: Whether to use low-res predictions as spatial prompts + scaled_roi_lowres_pred_array: Low-res prediction array for spatial prompts + disable_tta: Disable test-time augmentation + model_step: Model checkpoint step number + vision_backbone: Vision backbone architecture name + input_channels: Number of input channels + patch_size: Patch size for the model + modality: Imaging modality ('CT', 'MRI', 'US', 'PET', 'microscopy') + instance_label: 0 for semantic segmentation, 1 for instance segmentation + texts: List of text prompts (one per class) + label_values: List of label values (one per class) + return_max_prob: Whether to return maximum probability map + class_name_list: List of class names for class-specific adjustments + 
stage_1_flag: Whether this is Stage 1 inference + device: Device string (e.g., 'cuda:0' or 'cpu') + checkpoints_path: Path to model checkpoints directory + category_batch_size: Number of categories to process in each batch (default: 24) + Adjust based on GPU memory. Larger 3D images require smaller batch sizes. + Accumulation operations are performed on CPU for more stable memory usage. + background_threshold: Probability threshold for background (default: 0.5) + Voxels with max probability below this threshold will be labeled as background. + verbose: Whether to print detailed information (default: True) + + Returns: + pred_array: Segmentation array with shape (d, h, w), dtype int16 + max_prob_array: Maximum probability array (if return_max_prob=True), or None + """ + w_lowres_pred_prompts = scaled_roi_lowres_pred_array is not None + config = { + 'device': device, + 'modality': modality, + 'instance_label': instance_label, + 'texts': texts, + 'label_values': label_values, + 'vision_backbone': vision_backbone, + 'crop_size': crop_size, + 'patch_size': patch_size, + 'target_spacing': target_spacing, + 'target_spacing_model': target_spacing_model, + 'model_step': model_step, + 'input_channels': input_channels, + 'w_lowres_pred_prompts': w_lowres_pred_prompts, + 'scaled_roi_lowres_pred_array': scaled_roi_lowres_pred_array, + 'disable_tta': disable_tta, + 'checkpoints_path': checkpoints_path, + 'return_max_prob': return_max_prob, + 'class_name_list': class_name_list, + 'stage_1_flag': stage_1_flag, + 'category_batch_size': category_batch_size, + 'background_threshold': background_threshold, + } + + pipeline = MedicalSegmentationPipeline(config) + return pipeline.run_inference(raw_image, raw_spacing, verbose=verbose) + + +# ============================================================================ +# Main Inference Functions +# ============================================================================ +# These functions provide the high-level interface for running 
inference +# on raw NIfTI images with proper preprocessing and post-processing. +# ============================================================================ + + +def normalize_image_ct(image_data, window_level=40, window_width=400, window_type='soft_tissue'): + """ + Normalize CT image using window/level technique. + + Args: + image_data: Input CT image array + window_level: Window level (center of the window). If None, will use default based on window_type + window_width: Window width (range of the window). If None, will use default based on window_type + window_type: Type of window ('soft_tissue', 'bone', 'lung'). Used if window_level/window_width are None + + Returns: + Normalized image array with dtype uint8, values in [0, 255] + """ + # Default window settings for different window types + default_windows = { + 'soft_tissue': {'window_level': 40, 'window_width': 400}, + 'bone': {'window_level': 500, 'window_width': 1500}, + 'lung': {'window_level': -600, 'window_width': 1500} + } + + # Use defaults if not provided + if window_level is None or window_width is None: + if window_type in default_windows: + window_level = default_windows[window_type]['window_level'] + window_width = default_windows[window_type]['window_width'] + else: + # Fallback to soft_tissue defaults + window_level = default_windows['soft_tissue']['window_level'] + window_width = default_windows['soft_tissue']['window_width'] + + lower_bound = window_level - window_width / 2 + upper_bound = window_level + window_width / 2 + image_data_pre = np.clip(image_data, lower_bound, upper_bound) + image_data_pre = ( + (image_data_pre - np.min(image_data_pre)) + / (np.max(image_data_pre) - np.min(image_data_pre) + 1e-8) + * 255.0 + ) + return image_data_pre.astype(np.uint8) + + +def normalize_image_other(image_data, percentile_lower=None, percentile_upper=None, preserve_zero=None, normalization_settings=None): + """ + Normalize non-CT images using percentile-based normalization. 
+ + This method clips values to specified percentiles, then + normalizes to [0, 255] range while optionally preserving zero values. + + Args: + image_data: Input image array + percentile_lower: Lower percentile for clipping. If None, will use default or value from normalization_settings + percentile_upper: Upper percentile for clipping. If None, will use default or value from normalization_settings + preserve_zero: Whether to preserve zero values. If None, will use default or value from normalization_settings + normalization_settings: Dictionary containing normalization settings from config. + Format: {'percentile_lower': 0.5, 'percentile_upper': 99.5, 'preserve_zero': True} + + Returns: + Normalized image array with dtype uint8, values in [0, 255] + """ + # Default normalization settings + default_percentile_lower = 0.5 + default_percentile_upper = 99.5 + default_preserve_zero = True + + # Use settings from config if provided + if normalization_settings is not None: + if percentile_lower is None: + percentile_lower = normalization_settings.get('percentile_lower', default_percentile_lower) + if percentile_upper is None: + percentile_upper = normalization_settings.get('percentile_upper', default_percentile_upper) + if preserve_zero is None: + preserve_zero = normalization_settings.get('preserve_zero', default_preserve_zero) + else: + # Use defaults if not provided + if percentile_lower is None: + percentile_lower = default_percentile_lower + if percentile_upper is None: + percentile_upper = default_percentile_upper + if preserve_zero is None: + preserve_zero = default_preserve_zero + + # Calculate percentiles from non-zero values + non_zero_data = image_data[image_data > 0] + if len(non_zero_data) > 0: + lower_bound, upper_bound = np.percentile( + non_zero_data, [percentile_lower, percentile_upper] + ) + else: + # If all values are zero, use min/max + lower_bound = np.min(image_data) + upper_bound = np.max(image_data) + + image_data_pre = np.clip(image_data, 
lower_bound, upper_bound) + image_data_pre = ( + (image_data_pre - np.min(image_data_pre)) + / (np.max(image_data_pre) - np.min(image_data_pre) + 1e-8) + * 255.0 + ) + + if preserve_zero: + image_data_pre[image_data == 0] = 0 + + return image_data_pre.astype(np.uint8) + + +def load_nifti_image(image_path): + """ + Load NIfTI image and extract data, spacing, and metadata. + + Args: + image_path: Path to NIfTI image file + + Returns: + image_data: Image array with shape (d, h, w) + spacing_xyz: Spacing tuple (x, y, z) from SimpleITK + metadata: Dictionary containing origin, direction, and spacing_xyz + """ + img_sitk = sitk.ReadImage(image_path) + image_data = sitk.GetArrayFromImage(img_sitk) # Shape: (d, h, w) + spacing_xyz = img_sitk.GetSpacing() # (x, y, z) + + # Save metadata for output + metadata = { + 'origin': img_sitk.GetOrigin(), + 'direction': img_sitk.GetDirection(), + 'spacing_xyz': spacing_xyz + } + + return image_data, spacing_xyz, metadata + + +def convert_spacing(spacing_xyz, image_shape): + """ + Convert spacing from SimpleITK format (x, y, z) to format expected by run_segmentation. + + Following the conversion logic from inference_raw_nifti_2.py: + 1. SimpleITK returns (x, y, z) + 2. Image from SimpleITK is (d, h, w) where d=z, h=y, w=x + 3. Convert to (d, h, w) spacing: (z, x, y) = (d, h, w) + 4. 
Then convert to format expected by run_segmentation: (h, w, d) + + Args: + spacing_xyz: Spacing tuple from SimpleITK (x, y, z) + image_shape: Image shape (d, h, w) + + Returns: + img_spacing: Spacing array in format expected by run_segmentation + """ + img_spacing = np.array(spacing_xyz, dtype=np.float32) + + # Step 1: Convert from (x, y, z) to (d, h, w) spacing + # SimpleITK: (x, y, z) -> Image: (d, h, w) where d=z, h=y, w=x + # So spacing (x, y, z) -> (z, x, y) = (d, h, w) + img_spacing_transposed = img_spacing[[2, 0, 1]] # (z, x, y) = (d, h, w) + + # Step 2: Handle very small spacing values + for i in range(3): + if img_spacing_transposed[i] < 0.1: + img_spacing_transposed[i] = 1.0 + + # Step 3: Optional: Adjust spacing based on image dimensions + # Note: adjust_spacing expects image in (h, w, d) format, so we need to rearrange + # For now, we'll skip this adjustment or use a dummy array + try: + img_spacing_transposed = adjust_spacing( + np.zeros(image_shape), # Dummy array for shape reference + img_spacing_transposed + ).astype(np.float32) + except Exception: + # If adjust_spacing fails, use spacing as-is + pass + + # Step 4: Convert to format expected by run_segmentation + # This converts (d, h, w) to (h, w, d) + img_spacing = img_spacing_transposed[[1, 2, 0]] + + return img_spacing + + +def run_inference_single_window( + image_data, + spacing_xyz, + metadata, + modality='CT', + texts=None, + label_values=None, + inference_mode='stage2_only', + device="cuda:0", + checkpoints_path="./checkpoints", + window_settings=None, + window_type='soft_tissue', + normalization_settings=None, + verbose=True +): + """ + Run inference for a single window type. + + This is an internal function used by run_inference to handle single window type inference. 
+ + Args: + image_data: Raw image data array (d, h, w) + spacing_xyz: Spacing tuple (x, y, z) + metadata: Image metadata dictionary + modality: Imaging modality ('CT', 'MRI', 'US', 'PET', 'microscopy') + texts: List of text prompts (one per class) + label_values: List of label values (one per class) + inference_mode: Inference mode ('stage2_only' or 'stage1+stage2') + device: Device to use ('cuda:0' or 'cpu') + checkpoints_path: Path to model checkpoints + window_settings: Dictionary containing window settings for different window types (CT only) + window_type: Type of window to use ('soft_tissue', 'bone', 'lung') + normalization_settings: Dictionary containing normalization settings for non-CT modalities + verbose: Whether to print detailed information (default: True) + + Returns: + pred_array: Segmentation array (d, h, w) + """ + if texts is None: + texts = [] + if label_values is None: + label_values = [] + + if len(texts) != len(label_values): + raise ValueError("Number of text prompts must match number of label values") + + # Normalize image + if verbose: + print(f"Normalizing image for {window_type} window (modality: {modality})") + if modality.upper() == 'CT': + # Get window settings from config if available + window_level = None + window_width = None + if window_settings is not None and window_type in window_settings: + window_level = window_settings[window_type].get('window_level') + window_width = window_settings[window_type].get('window_width') + if verbose: + print(f"Using {window_type} window: level={window_level}, width={window_width}") + + img_array = normalize_image_ct(image_data, window_level=window_level, + window_width=window_width, window_type=window_type) + else: + # Get normalization settings from config if available + if normalization_settings is not None: + if verbose: + print(f"Using normalization settings from config: {normalization_settings}") + img_array = normalize_image_other(image_data, normalization_settings=normalization_settings) + 
else: + # Use default normalization + if verbose: + print("Using default normalization settings") + img_array = normalize_image_other(image_data) + + if verbose: + print(f"Normalized image range: [{img_array.min()}, {img_array.max()}]") + + # Convert spacing + img_spacing = convert_spacing(spacing_xyz, img_array.shape) + if verbose: + print(f"Converted spacing: {img_spacing}") + + # Run inference + if inference_mode == 'stage1+stage2': + if verbose: + print(f"Running two-stage inference with {window_type} window...") + # Stage 1: Low-resolution + if verbose: + print("Stage 1: Low-resolution segmentation...") + stage_1_pred, _ = run_segmentation( + raw_image=img_array, + raw_spacing=img_spacing, + crop_size=[224, 224, 128], + target_spacing=[1.5, 1.5, 3.0], + target_spacing_model=[1.5, 1.5, 3.0], + w_lowres_pred_prompts=False, + scaled_roi_lowres_pred_array=None, + disable_tta=True, + model_step=358600, + modality=modality.lower(), + instance_label=0, + texts=texts, + label_values=label_values, + return_max_prob=False, + class_name_list=[], + stage_1_flag=True, + device=device, + checkpoints_path=checkpoints_path, + verbose=verbose + ) + + # Check if Stage 1 found anything + if stage_1_pred.sum() == 0: + if verbose: + print("Warning: Stage 1 found no predictions. Using Stage 1 result as final output.") + final_pred = stage_1_pred + else: + if verbose: + print("Stage 1 completed. Extracting ROI for Stage 2...") + + # Remove small objects from Stage 1 prediction + min_size = 10 + lowres_pred_binary = (stage_1_pred > 0).astype(np.int16) + lowres_pred_binary = remove_small_objects_binary(lowres_pred_binary, min_size=min_size).astype(np.int16) + stage_1_pred_cleaned = stage_1_pred * lowres_pred_binary + + # Extract ROI from Stage 1 prediction + # Find bounding box of non-zero regions + non_zero_indices = np.argwhere(stage_1_pred_cleaned > 0) + if len(non_zero_indices) == 0: + if verbose: + print("Warning: No non-zero regions after cleaning. 
Using Stage 1 result.") + final_pred = stage_1_pred_cleaned + else: + z_min, y_min, x_min = non_zero_indices.min(axis=0) + z_max, y_max, x_max = non_zero_indices.max(axis=0) + + # Calculate ROI center and range with scaling factor + m = 1.1 # Scaling factor for ROI expansion + z_center = (z_min + z_max) / 2 + y_center = (y_min + y_max) / 2 + x_center = (x_min + x_max) / 2 + + z_range = (z_max - z_min + 1) * m / 2 + y_range = (y_max - y_min + 1) * m / 2 + x_range = (x_max - x_min + 1) * m / 2 + + # Calculate minimum ranges based on Stage 2 crop size and spacing + stage_2_crop_size = [192, 192, 192] + stage_2_target_spacing = [1.0, 1.0, 1.0] + + img_spacing_for_roi = img_spacing.copy() + + min_z_range = (stage_2_crop_size[2] / 2) * stage_2_target_spacing[2] / img_spacing_for_roi[2] if img_spacing_for_roi[2] > 0 else z_range + min_y_range = (stage_2_crop_size[0] / 2) * stage_2_target_spacing[0] / img_spacing_for_roi[0] if img_spacing_for_roi[0] > 0 else y_range + min_x_range = (stage_2_crop_size[1] / 2) * stage_2_target_spacing[1] / img_spacing_for_roi[1] if img_spacing_for_roi[1] > 0 else x_range + + z_range = max(min_z_range - 1, z_range) + y_range = max(min_y_range - 1, y_range) + x_range = max(min_x_range - 1, x_range) + + z_min_new = max(0, int(z_center - z_range)) + z_max_new = min(stage_1_pred_cleaned.shape[0] - 1, int(z_center + z_range)) + y_min_new = max(0, int(y_center - y_range)) + y_max_new = min(stage_1_pred_cleaned.shape[1] - 1, int(y_center + y_range)) + x_min_new = max(0, int(x_center - x_range)) + x_max_new = min(stage_1_pred_cleaned.shape[2] - 1, int(x_center + x_range)) + + if verbose: + print(f"ROI bounds: z=[{z_min_new}:{z_max_new}], y=[{y_min_new}:{y_max_new}], x=[{x_min_new}:{x_max_new}]") + + roi_array = img_array[z_min_new:z_max_new+1, y_min_new:y_max_new+1, x_min_new:x_max_new+1] + roi_lowres_pred = stage_1_pred_cleaned[z_min_new:z_max_new+1, y_min_new:y_max_new+1, x_min_new:x_max_new+1] + + if verbose: + print(f"ROI image shape: 
{roi_array.shape}") + print(f"ROI prediction shape: {roi_lowres_pred.shape}") + + # Stage 2: High-resolution segmentation on ROI + if verbose: + print("Stage 2: High-resolution segmentation on ROI...") + roi_pred, _ = run_segmentation( + raw_image=roi_array, + raw_spacing=img_spacing, + crop_size=[192, 192, 192], + target_spacing=[1.0, 1.0, 1.0], + target_spacing_model=[1.0, 1.0, 1.0], + w_lowres_pred_prompts=True, + scaled_roi_lowres_pred_array=roi_lowres_pred, + disable_tta=True, + model_step=341300, + modality=modality.lower(), + instance_label=0, + texts=texts, + label_values=label_values, + return_max_prob=False, + class_name_list=[], + stage_1_flag=False, + device=device, + checkpoints_path=checkpoints_path, + verbose=verbose + ) + + # Integrate ROI prediction back into full volume + if verbose: + print("Integrating Stage 2 results back into full volume...") + final_pred = np.zeros_like(stage_1_pred_cleaned, dtype=np.int16) + final_pred[z_min_new:z_max_new+1, y_min_new:y_max_new+1, x_min_new:x_max_new+1] = roi_pred + if verbose: + print("Stage1+Stage2 inference completed.") + elif inference_mode == 'stage2_only': + if verbose: + print(f"Running Stage 2 inference with {window_type} window...") + final_pred, _ = run_segmentation( + raw_image=img_array, + raw_spacing=img_spacing, + crop_size=[192, 192, 192], + target_spacing=[1.0, 1.0, 1.0], + target_spacing_model=[1.0, 1.0, 1.0], + w_lowres_pred_prompts=False, + scaled_roi_lowres_pred_array=None, + disable_tta=True, + model_step=341300, + modality=modality.lower(), + instance_label=0, + texts=texts, + label_values=label_values, + return_max_prob=False, + class_name_list=[], + stage_1_flag=False, + device=device, + checkpoints_path=checkpoints_path, + verbose=verbose + ) + else: + raise ValueError(f"Unknown inference mode: {inference_mode}. 
Must be 'stage2_only' or 'stage1+stage2'") + + return final_pred + + +def run_inference( + image_path, + output_path, + modality='CT', + texts=None, + label_values=None, + inference_mode='stage2_only', + device="cuda:0", + checkpoints_path="./checkpoints", + window_settings=None, + window_type='soft_tissue', + normalization_settings=None, + window_type_mapping=None, + verbose=True +): + """ + Run Medal-S inference on a raw NIfTI image. + + Supports multi-window inference for CT images: if multiple window types are specified + (e.g., soft_tissue, bone, lung), each window type will be processed separately with + its corresponding window settings, and results will be merged. + + Args: + image_path: Path to input NIfTI image + output_path: Path to save output segmentation (will be modified with mode suffix) + modality: Imaging modality ('CT', 'MRI', 'US', 'PET', 'microscopy') + texts: List of text prompts (one per class) + label_values: List of label values (one per class) + inference_mode: Inference mode ('stage2_only' or 'stage1+stage2') + device: Device to use ('cuda:0' or 'cpu') + checkpoints_path: Path to model checkpoints + window_settings: Dictionary containing window settings for different window types (CT only). + Format: {'soft_tissue': {'window_level': 40, 'window_width': 400}, ...} + window_type: Type of window to use ('soft_tissue', 'bone', 'lung'). Default: 'soft_tissue' (CT only) + Ignored if window_type_mapping indicates multiple window types + normalization_settings: Dictionary containing normalization settings for non-CT modalities. + Format: {'percentile_lower': 0.5, 'percentile_upper': 99.5, 'preserve_zero': True} + window_type_mapping: Dictionary mapping each text to its window type. 
+ Format: {'text1': 'soft_tissue', 'text2': 'bone', ...} + If provided and contains multiple window types, will perform separate inference for each + verbose: Whether to print detailed information (default: True) + + Returns: + pred_array: Segmentation array (d, h, w) + inference_time: Total inference time in seconds + """ + if texts is None: + texts = [] + if label_values is None: + label_values = [] + + if len(texts) != len(label_values): + raise ValueError("Number of text prompts must match number of label values") + + # Add mode suffix to output filename + if inference_mode == 'stage1+stage2': + suffix = '_stage1+stage2' + elif inference_mode == 'stage2_only': + suffix = '_stage2_only' + else: + suffix = f'_{inference_mode}' + + # Modify output path to include suffix + base_path, ext = os.path.splitext(output_path) + if ext == '.gz': # Handle .nii.gz + base_path, nii_ext = os.path.splitext(base_path) + output_path = f"{base_path}{suffix}{nii_ext}{ext}" + else: + output_path = f"{base_path}{suffix}{ext}" + + if verbose: + print(f"Output will be saved to: {output_path}") + + # Start timing + start_time = time.time() + + # Load image + if verbose: + print(f"Loading image: {image_path}") + image_data, spacing_xyz, metadata = load_nifti_image(image_path) + if verbose: + print(f"Image shape: {image_data.shape}") + print(f"Original spacing (x, y, z): {spacing_xyz}") + + # Determine inference strategy based on modality and window types + if modality.upper() == 'CT': + # CT modality: check for multiple window types + if window_type_mapping is not None: + window_types = list(set(window_type_mapping.values())) + if len(window_types) > 1: + # Multiple window types: perform separate inference for each window type + if verbose: + print(f"\n{'='*60}") + print(f"CT with {len(window_types)} window types detected: {window_types}") + print("Performing separate inference for each window type...") + print(f"{'='*60}\n") + + all_predictions = [] + + for wt in window_types: + if 
verbose: + print(f"\n{'='*60}") + print(f"Processing {wt} window type...") + print(f"{'='*60}\n") + + # Filter texts and label_values for this window type + wt_texts = [text for text in texts if window_type_mapping.get(text) == wt] + wt_indices = [i for i, text in enumerate(texts) if window_type_mapping.get(text) == wt] + wt_label_values = [label_values[i] for i in wt_indices] + + if len(wt_texts) == 0: + if verbose: + print(f"No classes for {wt} window type, skipping...") + continue + + if verbose: + print(f"Classes for {wt} window: {len(wt_texts)}") + print(f" Texts: {wt_texts}") + print(f" Labels: {wt_label_values}") + + # Run inference for this window type with its specific window settings + wt_pred = run_inference_single_window( + image_data=image_data, + spacing_xyz=spacing_xyz, + metadata=metadata, + modality=modality, + texts=wt_texts, + label_values=wt_label_values, + inference_mode=inference_mode, + device=device, + checkpoints_path=checkpoints_path, + window_settings=window_settings, + window_type=wt, # Use the specific window type + normalization_settings=normalization_settings, + verbose=verbose + ) + + all_predictions.append((wt_pred, wt_label_values)) + + # Merge predictions: use maximum label value when overlapping + if verbose: + print(f"\n{'='*60}") + print("Merging predictions from all window types...") + print(f"{'='*60}\n") + + final_pred = np.zeros_like(all_predictions[0][0], dtype=np.int16) + for wt_pred, wt_labels in all_predictions: + # For each label in this window type's prediction + for label_val in wt_labels: + label_int = int(label_val) + mask = (wt_pred == label_int) + # Only update if current prediction is background (0) or smaller label + final_pred[mask] = np.maximum(final_pred[mask], label_int) + + if verbose: + print("Merging completed.") + else: + # Single window type: use the specific window type + if len(window_types) == 1: + window_type = window_types[0] + if verbose: + print(f"CT with single window type: {window_type}") + + 
final_pred = run_inference_single_window( + image_data=image_data, + spacing_xyz=spacing_xyz, + metadata=metadata, + modality=modality, + texts=texts, + label_values=label_values, + inference_mode=inference_mode, + device=device, + checkpoints_path=checkpoints_path, + window_settings=window_settings, + window_type=window_type, # Use the determined window type + normalization_settings=normalization_settings, + verbose=verbose + ) + else: + # No window_type_mapping: use default window_type + if verbose: + print(f"CT without window_type_mapping, using window type: {window_type}") + final_pred = run_inference_single_window( + image_data=image_data, + spacing_xyz=spacing_xyz, + metadata=metadata, + modality=modality, + texts=texts, + label_values=label_values, + inference_mode=inference_mode, + device=device, + checkpoints_path=checkpoints_path, + window_settings=window_settings, + window_type=window_type, + normalization_settings=normalization_settings, + verbose=verbose + ) + else: + # Non-CT modality: use normalization_settings (other normalization) + if verbose: + print(f"Non-CT modality ({modality}): using normalization_settings") + final_pred = run_inference_single_window( + image_data=image_data, + spacing_xyz=spacing_xyz, + metadata=metadata, + modality=modality, + texts=texts, + label_values=label_values, + inference_mode=inference_mode, + device=device, + checkpoints_path=checkpoints_path, + window_settings=window_settings, # Not used for non-CT + window_type=window_type, # Not used for non-CT + normalization_settings=normalization_settings, # Used for non-CT + verbose=verbose + ) + + # End timing + end_time = time.time() + inference_time = end_time - start_time + + if verbose: + print(f"\n{'='*60}") + print(f"Inference Mode: {inference_mode}") + print(f"Total Inference Time: {inference_time:.2f} seconds ({inference_time/60:.2f} minutes)") + print(f"{'='*60}\n") + + # Save result + if verbose: + print(f"Saving segmentation to: {output_path}") + seg_sitk = 
# NOTE(review): this collapsed-diff span also contained the tail of
# run_inference (saving `final_pred` via SimpleITK with the source image's
# spacing/origin/direction, then `return final_pred, inference_time`) and the
# `def main():` header; only load_config_from_json could be reconstructed here
# as a complete unit.
def load_config_from_json(config_path):
    """
    Load inference configuration from a JSON file.

    Two layouts are supported:
      1. Legacy: a single ``texts`` array (all prompts default to the
         ``soft_tissue`` window type).
      2. Window-typed: separate ``texts_soft_tissue`` / ``texts_bone`` /
         ``texts_lung`` arrays, combined (in that order) into
         ``config['texts']`` with a per-text ``config['window_type_mapping']``.

    If ``labels`` is missing or empty, consecutive integer labels
    ``[1..len(texts)]`` are generated automatically; otherwise labels are
    coerced to int (string or integer inputs both accepted).

    Args:
        config_path: Path to the JSON configuration file.

    Returns:
        config: The parsed dict, augmented with normalized ``texts``,
            ``labels`` and ``window_type_mapping`` entries.

    Raises:
        ValueError: If the same text prompt appears under more than one window
            type (the mapping would otherwise silently keep only the last
            one), or if the number of labels does not match the number of
            texts.

    Example:
        # Legacy format:
        {"texts": ["Aorta", "Liver"], "labels": [1, 2]}

        # New format with window types:
        {
            "texts_soft_tissue": ["Aorta", "Liver"],
            "texts_bone": ["Vertebrae C1"],
            "texts_lung": ["Left lung"],
            "window_settings": {
                "soft_tissue": {"window_level": 40, "window_width": 400},
                "bone": {"window_level": 400, "window_width": 1500},
                "lung": {"window_level": -600, "window_width": 1500}
            }
        }
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # New layout is detected by the presence of any per-window-type text array.
    has_window_types = any(key in config for key in ['texts_soft_tissue', 'texts_bone', 'texts_lung'])

    if has_window_types:
        texts_soft_tissue = config.get('texts_soft_tissue', [])
        texts_bone = config.get('texts_bone', [])
        texts_lung = config.get('texts_lung', [])

        # Combine all texts in a fixed order: soft_tissue, bone, lung.
        texts = texts_soft_tissue + texts_bone + texts_lung

        # BUG GUARD: a text repeated across window types would be silently
        # collapsed in the mapping dict below; fail loudly instead.
        if len(set(texts)) != len(texts):
            raise ValueError(
                "Duplicate text prompts found across window types; each prompt "
                "must be listed under exactly one window type."
            )

        # Record which CT window each prompt belongs to.
        window_type_mapping = {}
        for text in texts_soft_tissue:
            window_type_mapping[text] = 'soft_tissue'
        for text in texts_bone:
            window_type_mapping[text] = 'bone'
        for text in texts_lung:
            window_type_mapping[text] = 'lung'

        config['texts'] = texts
        config['window_type_mapping'] = window_type_mapping
    else:
        # Legacy layout: default every prompt to soft_tissue for
        # backward compatibility.
        texts = config.get('texts', [])
        window_type_mapping = {text: 'soft_tissue' for text in texts}
        config['window_type_mapping'] = window_type_mapping

    # Normalize labels: auto-generate when missing/empty, else coerce to int.
    texts = config.get('texts', [])
    labels = config.get('labels', None)

    if labels is None or len(labels) == 0:
        # Auto-generate consecutive labels starting from 1.
        labels = list(range(1, len(texts) + 1))
        print(f"  Auto-generated consecutive labels: {labels}")
    else:
        labels = [int(label) for label in labels]

    # Validate that number of labels matches number of texts.
    if len(labels) != len(texts):
        raise ValueError(
            f"Number of labels ({len(labels)}) must match number of texts ({len(texts)}). "
            f"Texts: {len(texts)}, Labels: {len(labels)}"
        )

    config['labels'] = labels
    return config
+ """ + parser = argparse.ArgumentParser( + description="Medal-S inference for raw NIfTI images", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Using JSON configuration file: + python inference_medals.py --input image.nii.gz --output result.nii.gz \\ + --config config.json --mode stage2_only + + # Using command-line arguments: + python inference_medals.py --input image.nii.gz --output result.nii.gz \\ + --modality CT --texts "Aorta in CT" --labels 1 --mode stage1+stage2 + """ + ) + parser.add_argument( + "--input", "-i", + type=str, + required=True, + help="Path to input NIfTI image" + ) + parser.add_argument( + "--output", "-o", + type=str, + required=True, + help="Path to save output segmentation (suffix will be added automatically based on inference mode)" + ) + parser.add_argument( + "--config", "-c", + type=str, + default=None, + help="Path to JSON configuration file (if provided, will override --texts, --labels, --modality)" + ) + parser.add_argument( + "--modality", "-m", + type=str, + default="CT", + choices=['CT', 'MRI', 'US', 'PET', 'microscopy'], + help="Imaging modality (default: CT, ignored if --config is provided)" + ) + parser.add_argument( + "--texts", + type=str, + nargs='+', + default=None, + help="Text prompts (one per class, ignored if --config is provided)" + ) + parser.add_argument( + "--labels", + type=str, + nargs='+', + default=None, + help="Label values (one per class, must match texts, ignored if --config is provided)" + ) + parser.add_argument( + "--mode", + type=str, + default="stage2_only", + choices=['stage2_only', 'stage1+stage2'], + help="Inference mode: 'stage2_only' (default) or 'stage1+stage2'" + ) + parser.add_argument( + "--device", + type=str, + default="cuda:0", + help="Device to use (default: cuda:0)" + ) + parser.add_argument( + "--checkpoints", + type=str, + default="./checkpoints", + help="Path to model checkpoints (default: ./checkpoints)" + ) + parser.add_argument( + "--verbose", 
"-v", + action='store_true', + default=False, + help="Print detailed information during inference (default: False)" + ) + + args = parser.parse_args() + verbose = args.verbose + + # Load configuration from JSON file if provided + window_settings = None + window_type = 'soft_tissue' + normalization_settings = None + window_type_mapping = None + + if args.config: + if not os.path.exists(args.config): + raise FileNotFoundError(f"Configuration file not found: {args.config}") + config = load_config_from_json(args.config) + texts = config.get('texts', []) + labels = config.get('labels', []) + modality = config.get('modality', 'CT') + window_settings = config.get('window_settings') + normalization_settings = config.get('normalization_settings') + window_type_mapping = config.get('window_type_mapping') + + # Determine default window type based on texts (for CT only, used as fallback) + if modality.upper() == 'CT': + if window_type_mapping: + window_types = list(set(window_type_mapping.values())) + if len(window_types) == 1: + window_type = window_types[0] + else: + # Default to soft_tissue if mixed types (will be handled by multi-window inference) + window_type = 'soft_tissue' + + # Convert labels to strings for compatibility with run_segmentation + # (run_segmentation expects string labels) + label_values = [str(label) for label in labels] + + if verbose: + print(f"Loaded configuration from: {args.config}") + print(f" Modality: {modality}") + print(f" Number of classes: {len(texts)}") + print(f" Labels: {labels}") + if modality.upper() == 'CT' and window_settings: + print(f" Window settings available for: {list(window_settings.keys())}") + if window_type_mapping: + window_types = list(set(window_type_mapping.values())) + if len(window_types) > 1: + print(f" Multiple window types detected: {window_types}") + print(f" Will perform separate inference for each window type") + else: + print(f" Using window type: {window_type}") + else: + print(f" Using window type: 
{window_type}") + elif normalization_settings: + print(f" Normalization settings: {normalization_settings}") + else: + # Use command line arguments + if args.texts is None or args.labels is None: + raise ValueError("Either --config or both --texts and --labels must be provided") + texts = args.texts + label_values = args.labels + modality = args.modality + + # Create output directory if needed + output_dir = os.path.dirname(args.output) + if output_dir and not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + + # Run inference + run_inference( + image_path=args.input, + output_path=args.output, + modality=modality, + texts=texts, + label_values=label_values, + inference_mode=args.mode, + device=args.device, + checkpoints_path=args.checkpoints, + window_settings=window_settings, + window_type=window_type, + normalization_settings=normalization_settings, + window_type_mapping=window_type_mapping, + verbose=verbose + ) + + +if __name__ == '__main__': + main() + diff --git a/model/SwinUNETR.py b/model/SwinUNETR.py new file mode 100644 index 0000000000000000000000000000000000000000..9fdfbde68116e42fc375ca2ffd58ecaad7f13974 --- /dev/null +++ b/model/SwinUNETR.py @@ -0,0 +1,1116 @@ +from typing import Sequence, Tuple, Type, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from torch.nn import LayerNorm + +from monai.networks.blocks import MLPBlock as Mlp +from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock +from monai.networks.layers import DropPath, trunc_normal_ +from monai.utils import ensure_tuple_rep, optional_import + +rearrange, _ = optional_import("einops", name="rearrange") + + +class SwinUNETR_Enc(nn.Module): + """ + Swin UNETR based on: "Hatamizadeh et al., + Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images + " + """ + + def __init__( + self, + img_size: 
Union[Sequence[int], int], + in_channels: int, + depths: Sequence[int] = (2, 2, 2, 2), + num_heads: Sequence[int] = (3, 6, 12, 24), + feature_size: int = 24, + norm_name: Union[Tuple, str] = "instance", + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + dropout_path_rate: float = 0.0, + normalize: bool = True, + use_checkpoint: bool = False, + spatial_dims: int = 3, + return_skips: bool = True, + ) -> None: + """ + Args: + img_size: dimension of input image. + in_channels: dimension of input channels. + out_channels: dimension of output channels. + feature_size: dimension of network feature size. + depths: number of layers in each stage. + num_heads: number of attention heads. + norm_name: feature normalization type and arguments. + drop_rate: dropout rate. + attn_drop_rate: attention dropout rate. + dropout_path_rate: drop path rate. + normalize: normalize output intermediate features in each stage. + use_checkpoint: use gradient checkpointing for reduced memory usage. + spatial_dims: number of spatial dims. 
+ """ + + super().__init__() + + self.return_skips = return_skips + + img_size = ensure_tuple_rep(img_size, spatial_dims) + patch_size = ensure_tuple_rep(2, spatial_dims) + window_size = ensure_tuple_rep(7, spatial_dims) + + if not (spatial_dims == 2 or spatial_dims == 3): + raise ValueError("spatial dimension should be 2 or 3.") + + for m, p in zip(img_size, patch_size): + for i in range(5): + if m % np.power(p, i + 1) != 0: + raise ValueError("input image size (img_size) should be divisible by stage-wise image resolution.") + + if not (0 <= drop_rate <= 1): + raise ValueError("dropout rate should be between 0 and 1.") + + if not (0 <= attn_drop_rate <= 1): + raise ValueError("attention dropout rate should be between 0 and 1.") + + if not (0 <= dropout_path_rate <= 1): + raise ValueError("drop path rate should be between 0 and 1.") + + if feature_size % 12 != 0: + raise ValueError("feature_size should be divisible by 12.") + + self.normalize = normalize + + self.swinViT = SwinTransformer( + in_chans=in_channels, + embed_dim=feature_size, + window_size=window_size, + patch_size=patch_size, + depths=depths, + num_heads=num_heads, + mlp_ratio=4.0, + qkv_bias=True, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=dropout_path_rate, + norm_layer=nn.LayerNorm, + use_checkpoint=use_checkpoint, + spatial_dims=spatial_dims, + ) + + self.encoder1 = UnetrBasicBlock( # 2 conv layers + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder2 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder3 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=2 * feature_size, + out_channels=2 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder4 = 
UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=4 * feature_size, + out_channels=4 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder5 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=8 * feature_size, + out_channels=8 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder6 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=16 * feature_size, + out_channels=16 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + def load_from(self, weights): + + with torch.no_grad(): + self.swinViT.patch_embed.proj.weight.copy_(weights["state_dict"]["module.patch_embed.proj.weight"]) + self.swinViT.patch_embed.proj.bias.copy_(weights["state_dict"]["module.patch_embed.proj.bias"]) + for bname, block in self.swinViT.layers1[0].blocks.named_children(): + block.load_from(weights, n_block=bname, layer="layers1") + self.swinViT.layers1[0].downsample.reduction.weight.copy_( + weights["state_dict"]["module.layers1.0.downsample.reduction.weight"] + ) + self.swinViT.layers1[0].downsample.norm.weight.copy_( + weights["state_dict"]["module.layers1.0.downsample.norm.weight"] + ) + self.swinViT.layers1[0].downsample.norm.bias.copy_( + weights["state_dict"]["module.layers1.0.downsample.norm.bias"] + ) + for bname, block in self.swinViT.layers2[0].blocks.named_children(): + block.load_from(weights, n_block=bname, layer="layers2") + self.swinViT.layers2[0].downsample.reduction.weight.copy_( + weights["state_dict"]["module.layers2.0.downsample.reduction.weight"] + ) + self.swinViT.layers2[0].downsample.norm.weight.copy_( + weights["state_dict"]["module.layers2.0.downsample.norm.weight"] + ) + self.swinViT.layers2[0].downsample.norm.bias.copy_( + weights["state_dict"]["module.layers2.0.downsample.norm.bias"] + ) + for bname, block in self.swinViT.layers3[0].blocks.named_children(): + block.load_from(weights, 
n_block=bname, layer="layers3") + self.swinViT.layers3[0].downsample.reduction.weight.copy_( + weights["state_dict"]["module.layers3.0.downsample.reduction.weight"] + ) + self.swinViT.layers3[0].downsample.norm.weight.copy_( + weights["state_dict"]["module.layers3.0.downsample.norm.weight"] + ) + self.swinViT.layers3[0].downsample.norm.bias.copy_( + weights["state_dict"]["module.layers3.0.downsample.norm.bias"] + ) + for bname, block in self.swinViT.layers4[0].blocks.named_children(): + block.load_from(weights, n_block=bname, layer="layers4") + self.swinViT.layers4[0].downsample.reduction.weight.copy_( + weights["state_dict"]["module.layers4.0.downsample.reduction.weight"] + ) + self.swinViT.layers4[0].downsample.norm.weight.copy_( + weights["state_dict"]["module.layers4.0.downsample.norm.weight"] + ) + self.swinViT.layers4[0].downsample.norm.bias.copy_( + weights["state_dict"]["module.layers4.0.downsample.norm.bias"] + ) + + def forward(self, x_in): + # print(x_in.shape, task_id.shape) + hidden_states_out = self.swinViT(x_in, self.normalize) + + enc0 = self.encoder1(x_in) + enc1 = self.encoder2(hidden_states_out[0]) + enc2 = self.encoder3(hidden_states_out[1]) + enc3 = self.encoder4(hidden_states_out[2]) + enc4 = self.encoder5(hidden_states_out[3]) + dec4 = self.encoder6(hidden_states_out[4]) + # print(x_in.shape, enc0.shape, enc1.shape, enc2.shape, enc3.shape, dec4.shape) + # torch.Size([6, 1, 64, 64, 64]) torch.Size([6, 48, 64, 64, 64]) torch.Size([6, 48, 32, 32, 32]) + # torch.Size([6, 96, 16, 16, 16]) torch.Size([6, 192, 8,8, 8]) torch.Size([6, 768, 2, 2, 2]) + + if self.return_skips: + return [enc0, enc1, enc2, enc3, enc4, dec4] + else: + return [dec4] + +class SwinUNETR(nn.Module): + """ + Swin UNETR based on: "Hatamizadeh et al., + Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images + " + """ + + def __init__( + self, + img_size: Union[Sequence[int], int], + in_channels: int, + depths: Sequence[int] = (2, 2, 2, 2), + 
num_heads: Sequence[int] = (3, 6, 12, 24), + feature_size: int = 24, + norm_name: Union[Tuple, str] = "instance", + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + dropout_path_rate: float = 0.0, + normalize: bool = True, + use_checkpoint: bool = False, + spatial_dims: int = 3, + encoding: Union[Tuple, str] = 'rand_embedding', ## rand_embedding or word_embedding + deep_supervision: bool = True, + return_skips: bool = True, + ) -> None: + """ + Args: + img_size: dimension of input image. + in_channels: dimension of input channels. + out_channels: dimension of output channels. + feature_size: dimension of network feature size. + depths: number of layers in each stage. + num_heads: number of attention heads. + norm_name: feature normalization type and arguments. + drop_rate: dropout rate. + attn_drop_rate: attention dropout rate. + dropout_path_rate: drop path rate. + normalize: normalize output intermediate features in each stage. + use_checkpoint: use gradient checkpointing for reduced memory usage. + spatial_dims: number of spatial dims. + Examples:: + # for 3D single channel input with size (96,96,96), 4-channel output and feature size of 48. + >>> net = SwinUNETR(img_size=(96,96,96), in_channels=1, out_channels=4, feature_size=48) + # for 3D 4-channel input with size (128,128,128), 3-channel output and (2,4,2,2) layers in each stage. + >>> net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2)) + # for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing. 
+ >>> net = SwinUNETR(img_size=(96,96), in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2) + """ + + super().__init__() + + self.deep_supervision = deep_supervision + self.return_skips = return_skips + + self.encoding = encoding + + img_size = ensure_tuple_rep(img_size, spatial_dims) + patch_size = ensure_tuple_rep(2, spatial_dims) + window_size = ensure_tuple_rep(7, spatial_dims) + + if not (spatial_dims == 2 or spatial_dims == 3): + raise ValueError("spatial dimension should be 2 or 3.") + + for m, p in zip(img_size, patch_size): + for i in range(5): + if m % np.power(p, i + 1) != 0: + raise ValueError("input image size (img_size) should be divisible by stage-wise image resolution.") + + if not (0 <= drop_rate <= 1): + raise ValueError("dropout rate should be between 0 and 1.") + + if not (0 <= attn_drop_rate <= 1): + raise ValueError("attention dropout rate should be between 0 and 1.") + + if not (0 <= dropout_path_rate <= 1): + raise ValueError("drop path rate should be between 0 and 1.") + + if feature_size % 12 != 0: + raise ValueError("feature_size should be divisible by 12.") + + self.normalize = normalize + + self.encoder = SwinUNETR_Enc( + img_size, + in_channels, + depths, + num_heads, + feature_size, + norm_name, + drop_rate, + attn_drop_rate, + dropout_path_rate, + normalize, + use_checkpoint, + spatial_dims, + return_skips=True + ) + + self.decoder5 = UnetrUpBlock( # a transpose conv layer and 2 conv layers + spatial_dims=spatial_dims, + in_channels=16 * feature_size, + out_channels=8 * feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder4 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 8, + out_channels=feature_size * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder3 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 4, + out_channels=feature_size * 2, + 
kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + self.decoder2 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 2, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder1 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + def forward(self, x_in): + enc0, enc1, enc2, enc3, enc4, dec4 = self.encoder(x_in) + + dec3 = self.decoder5(dec4, enc4) + dec2 = self.decoder4(dec3, enc3) + dec1 = self.decoder3(dec2, enc2) + dec0 = self.decoder2(dec1, enc1) + out = self.decoder1(dec0, enc0) + # print(dec3.shape, dec2.shape, dec1.shape, dec0.shape, out.shape) + # torch.Size([6, 384, 4, 4, 4]) torch.Size([6, 192, 8, 8, 8]) torch.Size([6, 96, 16, 16, 16]) + # torch.Size([6, 48, 32, 32, 32]) torch.Size([6, 48, 64, 64, 64]) + + if self.deep_supervision: + out_ls = [out, dec0, dec1, dec2, dec3] + else: + out_ls = [out] + + if self.return_skips: + skips = [enc0, enc1, enc2, enc3, enc4, dec4] + else: + skips = [dec4] + + return skips, out_ls + + +def window_partition(x, window_size): + """window partition operation based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + Args: + x: input tensor. + window_size: local window size. 
+ """ + x_shape = x.size() + if len(x_shape) == 5: + b, d, h, w, c = x_shape + x = x.view( + b, + d // window_size[0], + window_size[0], + h // window_size[1], + window_size[1], + w // window_size[2], + window_size[2], + c, + ) + windows = ( + x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0] * window_size[1] * window_size[2], c) + ) + elif len(x_shape) == 4: + b, h, w, c = x.shape + x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], c) + return windows + + +def window_reverse(windows, window_size, dims): + """window reverse operation based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + Args: + windows: windows tensor. + window_size: local window size. + dims: dimension values. + """ + if len(dims) == 4: + b, d, h, w = dims + x = windows.view( + b, + d // window_size[0], + h // window_size[1], + w // window_size[2], + window_size[0], + window_size[1], + window_size[2], + -1, + ) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(b, d, h, w, -1) + + elif len(dims) == 3: + b, h, w = dims + x = windows.view(b, h // window_size[0], w // window_size[0], window_size[0], window_size[1], -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1) + return x + + +def get_window_size(x_size, window_size, shift_size=None): + """Computing window size based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + Args: + x_size: input size. + window_size: local window size. + shift_size: window shifting size. 
+ """ + + use_window_size = list(window_size) + if shift_size is not None: + use_shift_size = list(shift_size) + for i in range(len(x_size)): + if x_size[i] <= window_size[i]: + use_window_size[i] = x_size[i] + if shift_size is not None: + use_shift_size[i] = 0 + + if shift_size is None: + return tuple(use_window_size) + else: + return tuple(use_window_size), tuple(use_shift_size) + + +class WindowAttention(nn.Module): + """ + Window based multi-head self attention module with relative position bias based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + """ + + def __init__( + self, + dim: int, + num_heads: int, + window_size: Sequence[int], + qkv_bias: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + """ + Args: + dim: number of feature channels. + num_heads: number of attention heads. + window_size: local window size. + qkv_bias: add a learnable bias to query, key, value. + attn_drop: attention dropout rate. + proj_drop: dropout rate of output. 
+ """ + + super().__init__() + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + mesh_args = torch.meshgrid.__kwdefaults__ + + if len(self.window_size) == 3: + self.relative_position_bias_table = nn.Parameter( + torch.zeros( + (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1), + num_heads, + ) + ) + coords_d = torch.arange(self.window_size[0]) + coords_h = torch.arange(self.window_size[1]) + coords_w = torch.arange(self.window_size[2]) + if mesh_args is not None: + coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w, indexing="ij")) + else: + coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 + elif len(self.window_size) == 2: + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + if mesh_args is not None: + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) + else: + coords = torch.stack(torch.meshgrid(coords_h, coords_w)) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + 
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + trunc_normal_(self.relative_position_bias_table, std=0.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask): + b, n, c = x.shape + qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * self.scale + attn = q @ k.transpose(-2, -1) + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.clone()[:n, :n].reshape(-1) + ].reshape(n, n, -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + attn = attn + relative_position_bias.unsqueeze(0) + if mask is not None: + nw = mask.shape[0] + attn = attn.view(b // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, n, n) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(b, n, c) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """ + Swin Transformer block based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + """ + + def __init__( + self, + dim: int, + num_heads: int, + window_size: Sequence[int], + shift_size: Sequence[int], + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + drop_path: float = 0.0, + act_layer: str = "GELU", + norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore + use_checkpoint: bool = False, + ) -> None: + """ + Args: + dim: number of feature channels. 
+ num_heads: number of attention heads. + window_size: local window size. + shift_size: window shift size. + mlp_ratio: ratio of mlp hidden dim to embedding dim. + qkv_bias: add a learnable bias to query, key, value. + drop: dropout rate. + attn_drop: attention dropout rate. + drop_path: stochastic depth rate. + act_layer: activation layer. + norm_layer: normalization layer. + use_checkpoint: use gradient checkpointing for reduced memory usage. + """ + + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.use_checkpoint = use_checkpoint + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=self.window_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(hidden_size=dim, mlp_dim=mlp_hidden_dim, act=act_layer, dropout_rate=drop, dropout_mode="swin") + + def forward_part1(self, x, mask_matrix): + x_shape = x.size() + x = self.norm1(x) + if len(x_shape) == 5: + b, d, h, w, c = x.shape + window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size) + pad_l = pad_t = pad_d0 = 0 + pad_d1 = (window_size[0] - d % window_size[0]) % window_size[0] + pad_b = (window_size[1] - h % window_size[1]) % window_size[1] + pad_r = (window_size[2] - w % window_size[2]) % window_size[2] + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1)) + _, dp, hp, wp, _ = x.shape + dims = [b, dp, hp, wp] + + elif len(x_shape) == 4: + b, h, w, c = x.shape + window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size) + pad_l = pad_t = 0 + pad_r = (window_size[0] - h % window_size[0]) % window_size[0] + pad_b = (window_size[1] - w % window_size[1]) % window_size[1] + x = F.pad(x, (0, 0, pad_l, 
pad_r, pad_t, pad_b)) + _, hp, wp, _ = x.shape + dims = [b, hp, wp] + + if any(i > 0 for i in shift_size): + if len(x_shape) == 5: + shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3)) + elif len(x_shape) == 4: + shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + x_windows = window_partition(shifted_x, window_size) + attn_windows = self.attn(x_windows, mask=attn_mask) + attn_windows = attn_windows.view(-1, *(window_size + (c,))) + shifted_x = window_reverse(attn_windows, window_size, dims) + if any(i > 0 for i in shift_size): + if len(x_shape) == 5: + x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3)) + elif len(x_shape) == 4: + x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2)) + else: + x = shifted_x + + if len(x_shape) == 5: + if pad_d1 > 0 or pad_r > 0 or pad_b > 0: + x = x[:, :d, :h, :w, :].contiguous() + elif len(x_shape) == 4: + if pad_r > 0 or pad_b > 0: + x = x[:, :h, :w, :].contiguous() + + return x + + def forward_part2(self, x): + return self.drop_path(self.mlp(self.norm2(x))) + + def load_from(self, weights, n_block, layer): + root = f"module.{layer}.0.blocks.{n_block}." 
+ block_names = [ + "norm1.weight", + "norm1.bias", + "attn.relative_position_bias_table", + "attn.relative_position_index", + "attn.qkv.weight", + "attn.qkv.bias", + "attn.proj.weight", + "attn.proj.bias", + "norm2.weight", + "norm2.bias", + "mlp.fc1.weight", + "mlp.fc1.bias", + "mlp.fc2.weight", + "mlp.fc2.bias", + ] + with torch.no_grad(): + self.norm1.weight.copy_(weights["state_dict"][root + block_names[0]]) + self.norm1.bias.copy_(weights["state_dict"][root + block_names[1]]) + self.attn.relative_position_bias_table.copy_(weights["state_dict"][root + block_names[2]]) + self.attn.relative_position_index.copy_(weights["state_dict"][root + block_names[3]]) + self.attn.qkv.weight.copy_(weights["state_dict"][root + block_names[4]]) + self.attn.qkv.bias.copy_(weights["state_dict"][root + block_names[5]]) + self.attn.proj.weight.copy_(weights["state_dict"][root + block_names[6]]) + self.attn.proj.bias.copy_(weights["state_dict"][root + block_names[7]]) + self.norm2.weight.copy_(weights["state_dict"][root + block_names[8]]) + self.norm2.bias.copy_(weights["state_dict"][root + block_names[9]]) + self.mlp.linear1.weight.copy_(weights["state_dict"][root + block_names[10]]) + self.mlp.linear1.bias.copy_(weights["state_dict"][root + block_names[11]]) + self.mlp.linear2.weight.copy_(weights["state_dict"][root + block_names[12]]) + self.mlp.linear2.bias.copy_(weights["state_dict"][root + block_names[13]]) + + def forward(self, x, mask_matrix): + shortcut = x + if self.use_checkpoint: + x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix) + else: + x = self.forward_part1(x, mask_matrix) + x = shortcut + self.drop_path(x) + if self.use_checkpoint: + x = x + checkpoint.checkpoint(self.forward_part2, x) + else: + x = x + self.forward_part2(x) + return x + + +class PatchMerging(nn.Module): + """ + Patch merging layer based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + """ 
+ + def __init__( + self, dim: int, norm_layer: Type[LayerNorm] = nn.LayerNorm, spatial_dims: int = 3 + ) -> None: # type: ignore + """ + Args: + dim: number of feature channels. + norm_layer: normalization layer. + spatial_dims: number of spatial dims. + """ + + super().__init__() + self.dim = dim + if spatial_dims == 3: + self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False) + self.norm = norm_layer(8 * dim) + elif spatial_dims == 2: + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + + x_shape = x.size() + if len(x_shape) == 5: + b, d, h, w, c = x_shape + pad_input = (h % 2 == 1) or (w % 2 == 1) or (d % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, d % 2, 0, w % 2, 0, h % 2)) + x0 = x[:, 0::2, 0::2, 0::2, :] + x1 = x[:, 1::2, 0::2, 0::2, :] + x2 = x[:, 0::2, 1::2, 0::2, :] + x3 = x[:, 0::2, 0::2, 1::2, :] + x4 = x[:, 1::2, 0::2, 1::2, :] + x5 = x[:, 0::2, 1::2, 0::2, :] + x6 = x[:, 0::2, 0::2, 1::2, :] + x7 = x[:, 1::2, 1::2, 1::2, :] + x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1) + + elif len(x_shape) == 4: + b, h, w, c = x_shape + pad_input = (h % 2 == 1) or (w % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, w % 2, 0, h % 2)) + x0 = x[:, 0::2, 0::2, :] + x1 = x[:, 1::2, 0::2, :] + x2 = x[:, 0::2, 1::2, :] + x3 = x[:, 1::2, 1::2, :] + x = torch.cat([x0, x1, x2, x3], -1) + + x = self.norm(x) + x = self.reduction(x) + return x + + +def compute_mask(dims, window_size, shift_size, device): + """Computing region masks based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + Args: + dims: dimension values. + window_size: local window size. + shift_size: shift size. + device: device. 
+ """ + + cnt = 0 + + if len(dims) == 3: + d, h, w = dims + img_mask = torch.zeros((1, d, h, w, 1), device=device) + for d in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None): + for h in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None): + for w in slice(-window_size[2]), slice(-window_size[2], -shift_size[2]), slice(-shift_size[2], None): + img_mask[:, d, h, w, :] = cnt + cnt += 1 + + elif len(dims) == 2: + h, w = dims + img_mask = torch.zeros((1, h, w, 1), device=device) + for h in slice(-window_size[0]), slice(-window_size[0], -shift_size[0]), slice(-shift_size[0], None): + for w in slice(-window_size[1]), slice(-window_size[1], -shift_size[1]), slice(-shift_size[1], None): + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, window_size) + mask_windows = mask_windows.squeeze(-1) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + +class BasicLayer(nn.Module): + """ + Basic Swin Transformer layer in one stage based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + """ + + def __init__( + self, + dim: int, + depth: int, + num_heads: int, + window_size: Sequence[int], + drop_path: list, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + drop: float = 0.0, + attn_drop: float = 0.0, + norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore + downsample: isinstance = None, # type: ignore + use_checkpoint: bool = False, + ) -> None: + """ + Args: + dim: number of feature channels. + depths: number of layers in each stage. + num_heads: number of attention heads. + window_size: local window size. + drop_path: stochastic depth rate. + mlp_ratio: ratio of mlp hidden dim to embedding dim. 
+ qkv_bias: add a learnable bias to query, key, value. + drop: dropout rate. + attn_drop: attention dropout rate. + norm_layer: normalization layer. + downsample: downsample layer at the end of the layer. + use_checkpoint: use gradient checkpointing for reduced memory usage. + """ + + super().__init__() + self.window_size = window_size + self.shift_size = tuple(i // 2 for i in window_size) + self.no_shift = tuple(0 for i in window_size) + self.depth = depth + self.use_checkpoint = use_checkpoint + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=self.window_size, + shift_size=self.no_shift if (i % 2 == 0) else self.shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + use_checkpoint=use_checkpoint, + ) + for i in range(depth) + ] + ) + self.downsample = downsample + if self.downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer, spatial_dims=len(self.window_size)) + + def forward(self, x): + x_shape = x.size() + if len(x_shape) == 5: + b, c, d, h, w = x_shape + window_size, shift_size = get_window_size((d, h, w), self.window_size, self.shift_size) + x = rearrange(x, "b c d h w -> b d h w c") + dp = int(np.ceil(d / window_size[0])) * window_size[0] + hp = int(np.ceil(h / window_size[1])) * window_size[1] + wp = int(np.ceil(w / window_size[2])) * window_size[2] + attn_mask = compute_mask([dp, hp, wp], window_size, shift_size, x.device) + for blk in self.blocks: + x = blk(x, attn_mask) + x = x.view(b, d, h, w, -1) + if self.downsample is not None: + x = self.downsample(x) + x = rearrange(x, "b d h w c -> b c d h w") + + elif len(x_shape) == 4: + b, c, h, w = x_shape + window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size) + x = rearrange(x, "b c h w -> b h w c") + hp = int(np.ceil(h / window_size[0])) * window_size[0] 
+ wp = int(np.ceil(w / window_size[1])) * window_size[1] + attn_mask = compute_mask([hp, wp], window_size, shift_size, x.device) + for blk in self.blocks: + x = blk(x, attn_mask) + x = x.view(b, h, w, -1) + if self.downsample is not None: + x = self.downsample(x) + x = rearrange(x, "b h w c -> b c h w") + return x + + +class SwinTransformer(nn.Module): + """ + Swin Transformer based on: "Liu et al., + Swin Transformer: Hierarchical Vision Transformer using Shifted Windows + " + https://github.com/microsoft/Swin-Transformer + """ + + def __init__( + self, + in_chans: int, + embed_dim: int, + window_size: Sequence[int], + patch_size: Sequence[int], + depths: Sequence[int], + num_heads: Sequence[int], + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + norm_layer: Type[LayerNorm] = nn.LayerNorm, # type: ignore + patch_norm: bool = False, + use_checkpoint: bool = False, + spatial_dims: int = 3, + ) -> None: + """ + Args: + in_chans: dimension of input channels. + embed_dim: number of linear projection output channels. + window_size: local window size. + patch_size: patch size. + depths: number of layers in each stage. + num_heads: number of attention heads. + mlp_ratio: ratio of mlp hidden dim to embedding dim. + qkv_bias: add a learnable bias to query, key, value. + drop_rate: dropout rate. + attn_drop_rate: attention dropout rate. + drop_path_rate: stochastic depth rate. + norm_layer: normalization layer. + patch_norm: add normalization after patch embedding. + use_checkpoint: use gradient checkpointing for reduced memory usage. + spatial_dims: spatial dimension. 
+ """ + + super().__init__() + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.patch_norm = patch_norm + self.window_size = window_size + self.patch_size = patch_size + self.patch_embed = PatchEmbed( + patch_size=self.patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, # type: ignore + spatial_dims=spatial_dims, + ) + self.pos_drop = nn.Dropout(p=drop_rate) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + self.layers1 = nn.ModuleList() + self.layers2 = nn.ModuleList() + self.layers3 = nn.ModuleList() + self.layers4 = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=self.window_size, + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + norm_layer=norm_layer, + downsample=PatchMerging, + use_checkpoint=use_checkpoint, + ) + if i_layer == 0: + self.layers1.append(layer) + elif i_layer == 1: + self.layers2.append(layer) + elif i_layer == 2: + self.layers3.append(layer) + elif i_layer == 3: + self.layers4.append(layer) + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + + def proj_out(self, x, normalize=False): + if normalize: + x_shape = x.size() + if len(x_shape) == 5: + n, ch, d, h, w = x_shape + x = rearrange(x, "n c d h w -> n d h w c") + x = F.layer_norm(x, [ch]) + x = rearrange(x, "n d h w c -> n c d h w") + elif len(x_shape) == 4: + n, ch, h, w = x_shape + x = rearrange(x, "n c h w -> n h w c") + x = F.layer_norm(x, [ch]) + x = rearrange(x, "n h w c -> n c h w") + return x + + def forward(self, x, normalize=True): + x0 = self.patch_embed(x) + x0 = self.pos_drop(x0) + x0_out = self.proj_out(x0, normalize) + x1 = self.layers1[0](x0.contiguous()) + x1_out = self.proj_out(x1, normalize) + x2 = 
self.layers2[0](x1.contiguous()) + x2_out = self.proj_out(x2, normalize) + x3 = self.layers3[0](x2.contiguous()) + x3_out = self.proj_out(x3, normalize) + x4 = self.layers4[0](x3.contiguous()) + x4_out = self.proj_out(x4, normalize) + return [x0_out, x1_out, x2_out, x3_out, x4_out] + +if __name__ == '__main__': + import os + def get_parameter_number(model): + total_num = sum(p.numel() for p in model.parameters()) + trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad) + return {'Total': total_num, 'Trainable': trainable_num} + + model = SwinUNETR( + img_size=[288, 288, 96], # the real input should satisfy : d,h,w > 32 + in_channels=3, + feature_size=48, + drop_rate=0.0, + attn_drop_rate=0.0, + dropout_path_rate=0.0, + use_checkpoint=False, + deep_supervision=True, + return_skips=True, + ).cuda() + + if is_master(): + print(f"** UNET ** {get_parameter_number(model)['Total']/1e6}M parameters") + + image = torch.rand((1, 3, 288, 288, 96)).cuda() + skips, outs = model(image) + + for s in skips: + print(s.shape) + for out in outs: + print(out.shape) \ No newline at end of file diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model/base_bert.py b/model/base_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..96edabc9a703aa827287abaefcbde3fe1acfed1b --- /dev/null +++ b/model/base_bert.py @@ -0,0 +1,26 @@ +import torch.nn as nn +import torch + +from transformers import BertModel, AutoTokenizer + +class BaseBERT(nn.Module): + def __init__(self, basebert_checkpoint='bert-base-uncased'): + super().__init__() + self.tokenizer = AutoTokenizer.from_pretrained(basebert_checkpoint) + self.model = BertModel.from_pretrained(basebert_checkpoint) + self.modality_embed = nn.Embedding(4, 768) + + def forward(self, text, modality): + encoded = self.tokenizer( + text, + truncation=True, + padding=True, + 
return_tensors='pt', + max_length=64, + ).to(device=torch.cuda.current_device()) + + text_feature = self.model(**encoded).last_hidden_state[:, 0, :] + modality_feature = self.modality_embed(modality) + text_feature += modality_feature + + return text_feature \ No newline at end of file diff --git a/model/build_model.py b/model/build_model.py new file mode 100644 index 0000000000000000000000000000000000000000..7bd1732c7cf5f2320d78a6b7a34e3407b12381f4 --- /dev/null +++ b/model/build_model.py @@ -0,0 +1,103 @@ +import torch +import torch.nn as nn +import time +import os +from torch.nn.parallel import DistributedDataParallel as DDP + +import numpy as np + +from .maskformer import Maskformer + +from train.dist import is_master + + +def get_parameter_number(model): + total_num = sum(p.numel() for p in model.parameters()) + trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad) + return {'Total': total_num, 'Trainable': trainable_num} + + +def build_maskformer(args, device, gpu_id): + model = Maskformer(args.vision_backbone, args.input_channels, args.crop_size, args.patch_size, args.deep_supervision) + + model = model.to(device) + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu_id], find_unused_parameters=True) + + def get_parameter_number(model): + total_num = sum(p.numel() for p in model.parameters()) + trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad) + return {'Total': total_num, 'Trainable': trainable_num} + + if is_master(): + print(f"** MODEL ** {get_parameter_number(model)['Total']/1e6}M parameters") + + return model + + +def load_checkpoint(checkpoint_file, + resume, + partial_load, + model, + device, + optimizer=None, + ): + + if is_master(): + print('** CHECKPOINT ** : Load checkpoint from %s' % (checkpoint_file)) + + checkpoint = torch.load(checkpoint_file, map_location=device) + + # load part of the checkpoint + if 
partial_load: + model_dict = model.state_dict() + # check difference + unexpected_state_dict = [k for k in checkpoint['model_state_dict'].keys() if k not in model_dict.keys()] + missing_state_dict = [k for k in model_dict.keys() if k not in checkpoint['model_state_dict'].keys()] + unmatchd_state_dict = [k for k,v in checkpoint['model_state_dict'].items() if k in model_dict.keys() and v.shape != model_dict[k].shape] + # load partial parameters + state_dict = {k:v for k,v in checkpoint['model_state_dict'].items() if k in model_dict.keys() and v.shape == model_dict[k].shape} + model_dict.update(state_dict) + model.load_state_dict(model_dict) + if is_master(): + print('The following parameters are unexpected in SAT checkpoint:\n', unexpected_state_dict) + print('The following parameters are missing in SAT checkpoint:\n', missing_state_dict) + print('The following parameters have different shapes in SAT checkpoint:\n', unmatchd_state_dict) + print('The following parameters are loaded in SAT:\n', state_dict.keys()) + else: + model.load_state_dict(checkpoint['model_state_dict']) + + # if resume, load optimizer and step + if resume: + try: + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + except: + print('Optimizer state dict not matched, skip loading optimizer state dict') + pass + start_step = int(checkpoint['step']) + 1 + print('Resume from step %d' % (start_step)) + else: + start_step = 1 + + return model, optimizer, start_step + + +def inherit_knowledge_encoder(knowledge_encoder_checkpoint, + model, + device + ): + # inherit unet encoder and multiscale feature projection layer from knowledge encoder + checkpoint = torch.load(knowledge_encoder_checkpoint, map_location=device) + + model_dict = model.state_dict() + visual_encoder_state_dict = {k.replace('atlas_tower', 'backbone'):v for k,v in checkpoint['model_state_dict'].items() if 'atlas_tower.encoder' in k} # encoder部分 + model_dict.update(visual_encoder_state_dict) + proj_state_dict = 
{k.replace('atlas_tower.', ''):v for k,v in checkpoint['model_state_dict'].items() if 'atlas_tower.projection_layer' in k} # projection layer部分 + model_dict.update(proj_state_dict) + model.load_state_dict(model_dict) + + if is_master(): + print('** CHECKPOINT ** : Inherit pretrained unet encoder from %s' % (knowledge_encoder_checkpoint)) + print('The following parameters are loaded in SAT:\n', list(visual_encoder_state_dict.keys())+list(proj_state_dict.keys())) + + return model \ No newline at end of file diff --git a/model/dynamic-network-architectures-main/.gitignore b/model/dynamic-network-architectures-main/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..b31d01b49fa27d77549ff92d7256347e8c021eff --- /dev/null +++ b/model/dynamic-network-architectures-main/.gitignore @@ -0,0 +1,113 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*,cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# IPython Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# dotenv +.env + +# virtualenv +venv/ +ENV/ + +# Spyder project settings +.spyderproject + +# Rope project settings +.ropeproject + +*.memmap +*.zip +*.npz +*.npy +*.jpg +*.jpeg +.idea +*.txt +.idea/* +*.nii.gz +*.nii +*.tif +*.bmp +*.pkl +*.xml +*.pkl +*.pdf +*.jpg +*.jpeg + +*.model + +cifar_lightning/mlruns* diff --git a/model/dynamic-network-architectures-main/LICENCE b/model/dynamic-network-architectures-main/LICENCE new file mode 100644 index 0000000000000000000000000000000000000000..f58bbd7a84d846ae29d70d66b9aa744e23b858b7 --- /dev/null +++ b/model/dynamic-network-architectures-main/LICENCE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [2022] [Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/model/dynamic-network-architectures-main/README.md b/model/dynamic-network-architectures-main/README.md new file mode 100644 index 0000000000000000000000000000000000000000..698865ab1bde3206c2af6e37fb8c65f65c81e133 --- /dev/null +++ b/model/dynamic-network-architectures-main/README.md @@ -0,0 +1,25 @@ +# Dynamic Network Architectures + +This repository contains several ResNet, U-Net and VGG architectures in pytorch that can be dynamically adapted to a varying number of image dimensions (1D, 2D or 3D) and the number of input channels. + +## Available models +### ResNet +We implement the standard [ResNetD](https://arxiv.org/pdf/1812.01187.pdf) 18, 34, 50 and 152. For ResNets 50 and 152 also bottleneck implementations are available. Moreover, adapted versions that are better suited for smaller image sizes such as CIFAR can be used. + +All models additionally include regularization techniques like [Stochastic Depth](https://arxiv.org/pdf/1603.09382.pdf), [Squeeze & Excitation](https://arxiv.org/pdf/1709.01507.pdf) and [Final Layer Dropout](https://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf). 
+ +### VGG +In contrast to the original [VGG](https://arxiv.org/pdf/1409.1556.pdf) implementation we exclude the final fully-connected layers in the end and replace it by additional convolutional layers and only one fully-connected layer in the end. Adapted versions that are better suited for smaller image sizes such as CIFAR can be used. + +### U-Net +For the [U-Net](https://arxiv.org/pdf/1505.04597.pdf) a plain convolutional encoder as well as a residual encoder are available. + +# Acknowledgements + +

+      + +

+ +This Repository is developed and maintained by the Applied Computer Vision Lab (ACVL) +of [Helmholtz Imaging](https://www.helmholtz-imaging.de/). \ No newline at end of file diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/PKG-INFO b/model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..f0867b032ef1fe229b174c326d18f390c075cb28 --- /dev/null +++ b/model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/PKG-INFO @@ -0,0 +1,16 @@ +Metadata-Version: 2.4 +Name: dynamic_network_architectures +Version: 0.2 +Summary: none +Author: Fabian Isensee +Author-email: f.isensee@dkfz.de +License: private +License-File: LICENCE +Requires-Dist: torch>=1.6.0a +Requires-Dist: numpy +Dynamic: author +Dynamic: author-email +Dynamic: license +Dynamic: license-file +Dynamic: requires-dist +Dynamic: summary diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/SOURCES.txt b/model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..a39a6aaa7328ee5b4b3211e1dcc371e78269544d --- /dev/null +++ b/model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/SOURCES.txt @@ -0,0 +1,24 @@ +LICENCE +README.md +setup.py +dynamic_network_architectures/__init__.py +dynamic_network_architectures.egg-info/PKG-INFO +dynamic_network_architectures.egg-info/SOURCES.txt +dynamic_network_architectures.egg-info/dependency_links.txt +dynamic_network_architectures.egg-info/not-zip-safe +dynamic_network_architectures.egg-info/requires.txt +dynamic_network_architectures.egg-info/top_level.txt +dynamic_network_architectures/architectures/__init__.py +dynamic_network_architectures/architectures/resnet.py +dynamic_network_architectures/architectures/unet.py 
+dynamic_network_architectures/architectures/vgg.py +dynamic_network_architectures/building_blocks/__init__.py +dynamic_network_architectures/building_blocks/helper.py +dynamic_network_architectures/building_blocks/plain_conv_encoder.py +dynamic_network_architectures/building_blocks/regularization.py +dynamic_network_architectures/building_blocks/residual.py +dynamic_network_architectures/building_blocks/residual_encoders.py +dynamic_network_architectures/building_blocks/simple_conv_blocks.py +dynamic_network_architectures/building_blocks/unet_decoder.py +dynamic_network_architectures/initialization/__init__.py +dynamic_network_architectures/initialization/weight_init.py \ No newline at end of file diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/dependency_links.txt b/model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/not-zip-safe b/model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/not-zip-safe new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/not-zip-safe @@ -0,0 +1 @@ + diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/requires.txt b/model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..2238a0ebb2f7515170bee41f153cb55be57ec1cd --- /dev/null +++ 
b/model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/requires.txt @@ -0,0 +1,2 @@ +torch>=1.6.0a +numpy diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/top_level.txt b/model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..e8c83963d5525514ca88d6025eb66180b00ac65b --- /dev/null +++ b/model/dynamic-network-architectures-main/dynamic_network_architectures.egg-info/top_level.txt @@ -0,0 +1 @@ +dynamic_network_architectures diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures/__init__.py b/model/dynamic-network-architectures-main/dynamic_network_architectures/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures/__pycache__/__init__.cpython-310.pyc b/model/dynamic-network-architectures-main/dynamic_network_architectures/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78e4a8d0a91f442759af979d2b62944a60aa2280 Binary files /dev/null and b/model/dynamic-network-architectures-main/dynamic_network_architectures/__pycache__/__init__.cpython-310.pyc differ diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/__init__.py b/model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/__pycache__/__init__.cpython-310.pyc b/model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..a26abb065274cd6f9b42b8990d46c66c8442a423 Binary files /dev/null and b/model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/__pycache__/__init__.cpython-310.pyc differ diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/__pycache__/unet.cpython-310.pyc b/model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/__pycache__/unet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9ccdfda440a42bef3444eaa07a8ffddf7ec0293 Binary files /dev/null and b/model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/__pycache__/unet.cpython-310.pyc differ diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/resnet.py b/model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..5af9f6f0646960a7e7c275f69fbe1268af72e90b --- /dev/null +++ b/model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/resnet.py @@ -0,0 +1,236 @@ +import torch +from dynamic_network_architectures.building_blocks.residual_encoders import ResidualEncoder, BottleneckD, BasicBlockD +from dynamic_network_architectures.building_blocks.helper import get_matching_pool_op, get_default_network_config +from dynamic_network_architectures.building_blocks.simple_conv_blocks import ConvDropoutNormReLU +from torch import nn + +_ResNet_CONFIGS = { + '18': {'features_per_stage': (64, 128, 256, 512), 'n_blocks_per_stage': (2, 2, 2, 2), 'strides': (1, 2, 2, 2), + 'block': BasicBlockD, 'bottleneck_channels': None, 'disable_default_stem': True, 'stem_channels': None}, + '34': {'features_per_stage': (64, 128, 256, 512), 'n_blocks_per_stage': (3, 4, 6, 3), 'strides': (1, 2, 2, 2), + 'block': BasicBlockD, 'bottleneck_channels': None, 
'disable_default_stem': True, 'stem_channels': None}, + '50': {'features_per_stage': (64, 128, 256, 512), 'n_blocks_per_stage': (4, 6, 10, 5), 'strides': (1, 2, 2, 2), + 'block': BasicBlockD, 'bottleneck_channels': None, 'disable_default_stem': True, 'stem_channels': None}, + '152': {'features_per_stage': (64, 128, 256, 512), 'n_blocks_per_stage': (4, 13, 55, 4), 'strides': (1, 2, 2, 2), + 'block': BasicBlockD, 'bottleneck_channels': None, 'disable_default_stem': True, 'stem_channels': None}, + '50_bn': {'features_per_stage': (256, 512, 1024, 2048), 'n_blocks_per_stage': (3, 4, 6, 3), 'strides': (1, 2, 2, 2), + 'block': BottleneckD, 'bottleneck_channels': (64, 128, 256, 512), 'disable_default_stem': True, + 'stem_channels': 64}, + '152_bn': {'features_per_stage': (256, 512, 1024, 2048), 'n_blocks_per_stage': (3, 8, 36, 3), + 'strides': (1, 2, 2, 2), + 'block': BottleneckD, 'bottleneck_channels': (64, 128, 256, 512), 'disable_default_stem': True, + 'stem_channels': 64}, + '18_cifar': {'features_per_stage': (64, 128, 256, 512), 'n_blocks_per_stage': (2, 2, 2, 2), 'strides': (1, 2, 2, 2), + 'block': BasicBlockD, 'bottleneck_channels': None, 'disable_default_stem': False, + 'stem_channels': None}, + '34_cifar': {'features_per_stage': (64, 128, 256, 512), 'n_blocks_per_stage': (3, 4, 6, 3), 'strides': (1, 2, 2, 2), + 'block': BasicBlockD, 'bottleneck_channels': None, 'disable_default_stem': False, + 'stem_channels': None}, + '50_cifar': {'features_per_stage': (64, 128, 256, 512), 'n_blocks_per_stage': (4, 6, 10, 5), + 'strides': (1, 2, 2, 2), + 'block': BasicBlockD, 'bottleneck_channels': None, 'disable_default_stem': False, + 'stem_channels': None}, + '152_cifar': {'features_per_stage': (64, 128, 256, 512), 'n_blocks_per_stage': (4, 13, 55, 4), + 'strides': (1, 2, 2, 2), + 'block': BasicBlockD, 'bottleneck_channels': None, 'disable_default_stem': False, + 'stem_channels': None}, + '50_cifar_bn': {'features_per_stage': (256, 512, 1024, 2048), 'n_blocks_per_stage': (3, 
4, 6, 3), + 'strides': (1, 2, 2, 2), + 'block': BottleneckD, 'bottleneck_channels': (64, 128, 256, 512), 'disable_default_stem': False, + 'stem_channels': 64}, + '152_cifar_bn': {'features_per_stage': (256, 512, 1024, 2048), 'n_blocks_per_stage': (3, 8, 36, 3), + 'strides': (1, 2, 2, 2), + 'block': BottleneckD, 'bottleneck_channels': (64, 128, 256, 512), 'disable_default_stem': False, + 'stem_channels': 64}, +} + + +class ResNetD(nn.Module): + def __init__(self, n_classes: int, n_input_channel: int = 3, config='18', input_dimension=2, + final_layer_dropout=0.0, stochastic_depth_p=0.0, squeeze_excitation=False, + squeeze_excitation_rd_ratio=1./16): + """ + Implements ResNetD (https://arxiv.org/pdf/1812.01187.pdf). + Args: + n_classes: Number of classes + n_input_channel: Number of input channels (e.g. 3 for RGB) + config: Configuration of the ResNet + input_dimension: Number of dimensions of the data (1, 2 or 3) + final_layer_dropout: Probability of dropout before the final classifier + stochastic_depth_p: Stochastic Depth probability + squeeze_excitation: Whether Squeeze and Excitation should be applied + squeeze_excitation_rd_ratio: Squeeze and Excitation Reduction Ratio + Returns: + ResNet Model + """ + super().__init__() + self.input_channels = n_input_channel + self.cfg = _ResNet_CONFIGS[config] + self.ops = get_default_network_config(dimension=input_dimension) + self.final_layer_dropout_p = final_layer_dropout + + if self.cfg['disable_default_stem']: + stem_features = self.cfg['stem_channels'] if self.cfg['stem_channels'] is not None else \ + self.cfg['features_per_stage'][0] + self.stem = self._build_imagenet_stem_D(stem_features) + encoder_input_features = stem_features + else: + encoder_input_features = n_input_channel + self.stem = None + + self.encoder = ResidualEncoder(encoder_input_features, n_stages=len(self.cfg['features_per_stage']), + features_per_stage=self.cfg['features_per_stage'], conv_op=self.ops['conv_op'], + kernel_sizes=3, 
strides=self.cfg['strides'], + n_blocks_per_stage=self.cfg['n_blocks_per_stage'], conv_bias=False, + norm_op=self.ops['norm_op'], norm_op_kwargs=None, dropout_op=None, + dropout_op_kwargs=None, nonlin=nn.ReLU, + nonlin_kwargs={'inplace': True}, block=self.cfg['block'], + bottleneck_channels=self.cfg['bottleneck_channels'], return_skips=False, + disable_default_stem=self.cfg['disable_default_stem'], + stem_channels=self.cfg['stem_channels'], + stochastic_depth_p=stochastic_depth_p, + squeeze_excitation=squeeze_excitation, + squeeze_excitation_reduction_ratio=squeeze_excitation_rd_ratio) + + self.gap = get_matching_pool_op(conv_op=self.ops['conv_op'], adaptive=True, pool_type='avg')(1) + self.classifier = nn.Linear(self.cfg['features_per_stage'][-1], n_classes, True) + self.final_layer_dropout = self.ops['dropout_op'](p=self.final_layer_dropout_p) + + def forward(self, x): + if self.stem is not None: + x = self.stem(x) + x = self.encoder(x) + x = self.gap(x) + x = self.final_layer_dropout(x).squeeze() + + return self.classifier(x) + + def _build_imagenet_stem_D(self, stem_features): + """ + https://arxiv.org/pdf/1812.01187.pdf + + use 3 3x3(x3) convs instead of one 7x7. Stride is located in first conv. 
+ + Fig2 b) describes this + :return: + """ + c1 = ConvDropoutNormReLU(self.ops['conv_op'], self.input_channels, stem_features, 3, 2, False, + self.ops['norm_op'], None, None, None, nn.ReLU, {'inplace': True}) + c2 = ConvDropoutNormReLU(self.ops['conv_op'], stem_features, stem_features, 3, 1, False, + self.ops['norm_op'], None, None, None, nn.ReLU, {'inplace': True}) + c3 = ConvDropoutNormReLU(self.ops['conv_op'], stem_features, stem_features, 3, 1, False, + self.ops['norm_op'], None, None, None, nn.ReLU, {'inplace': True}) + pl = get_matching_pool_op(conv_op=self.ops['conv_op'], adaptive=False, pool_type='max')(2) + stem = nn.Sequential(c1, c2, c3, pl) + return stem + + +class ResNet18_CIFAR(ResNetD): + def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2, + final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: bool = False, + squeeze_excitation_rd_ratio: float = 1./16): + super().__init__(n_classes, n_input_channels, config='18_cifar', input_dimension=input_dimension, + final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p, + squeeze_excitation=squeeze_excitation, squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio) + +class ResNet34_CIFAR(ResNetD): + def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2, + final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: bool = False, + squeeze_excitation_rd_ratio: float = 1./16): + super().__init__(n_classes, n_input_channels, config='34_cifar', input_dimension=input_dimension, + final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p, + squeeze_excitation=squeeze_excitation, squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio) + +class ResNet50_CIFAR(ResNetD): + def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2, + final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: 
bool = False, + squeeze_excitation_rd_ratio: float = 1./16): + super().__init__(n_classes, n_input_channels, config='50_cifar', input_dimension=input_dimension, + final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p, + squeeze_excitation=squeeze_excitation, squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio) + +class ResNet152_CIFAR(ResNetD): + def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2, + final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: bool = False, + squeeze_excitation_rd_ratio: float = 1./16): + super().__init__(n_classes, n_input_channels, config='152_cifar', input_dimension=input_dimension, + final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p, + squeeze_excitation=squeeze_excitation, squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio) + +class ResNet50bn_CIFAR(ResNetD): + def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2, + final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: bool = False, + squeeze_excitation_rd_ratio: float = 1./16): + super().__init__(n_classes, n_input_channels, config='50_cifar_bn', input_dimension=input_dimension, + final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p, + squeeze_excitation=squeeze_excitation, squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio) + +class ResNet152bn_CIFAR(ResNetD): + def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2, + final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: bool = False, + squeeze_excitation_rd_ratio: float = 1./16): + super().__init__(n_classes, n_input_channels, config='152_cifar_bn', input_dimension=input_dimension, + final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p, + squeeze_excitation=squeeze_excitation, 
squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio) + +class ResNet18(ResNetD): + def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2, + final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: bool = False, + squeeze_excitation_rd_ratio: float = 1./16): + super().__init__(n_classes, n_input_channels, config='18', input_dimension=input_dimension, + final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p, + squeeze_excitation=squeeze_excitation, squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio) + +class ResNet34(ResNetD): + def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2, + final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: bool = False, + squeeze_excitation_rd_ratio: float = 1./16): + super().__init__(n_classes, n_input_channels, config='34', input_dimension=input_dimension, + final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p, + squeeze_excitation=squeeze_excitation, squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio) + +class ResNet50(ResNetD): + def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2, + final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: bool = False, + squeeze_excitation_rd_ratio: float = 1./16): + super().__init__(n_classes, n_input_channels, config='50', input_dimension=input_dimension, + final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p, + squeeze_excitation=squeeze_excitation, squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio) + +class ResNet152(ResNetD): + def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2, + final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: bool = False, + squeeze_excitation_rd_ratio: float = 1./16): + super().__init__(n_classes, n_input_channels, 
config='152', input_dimension=input_dimension, + final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p, + squeeze_excitation=squeeze_excitation, squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio) + +class ResNet50bn(ResNetD): + def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2, + final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: bool = False, + squeeze_excitation_rd_ratio: float = 1./16): + super().__init__(n_classes, n_input_channels, config='50_bn', input_dimension=input_dimension, + final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p, + squeeze_excitation=squeeze_excitation, squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio) + +class ResNet152bn(ResNetD): + def __init__(self, n_classes: int, n_input_channels: int = 3, input_dimension: int = 2, + final_layer_dropout: float = 0.0, stochastic_depth_p: float = 0.0, squeeze_excitation: bool = False, + squeeze_excitation_rd_ratio: float = 1./16): + super().__init__(n_classes, n_input_channels, config='152_bn', input_dimension=input_dimension, + final_layer_dropout=final_layer_dropout, stochastic_depth_p=stochastic_depth_p, + squeeze_excitation=squeeze_excitation, squeeze_excitation_rd_ratio=squeeze_excitation_rd_ratio) + + +if __name__ == '__main__': + data = torch.rand((1, 3, 224, 224)) + + model = ResNet50bn(10, 3) + import hiddenlayer as hl + + g = hl.build_graph(model, data, + transforms=None) + g.save("network_architecture.pdf") + del g + + #print(model.compute_conv_feature_map_size((32, 32))) \ No newline at end of file diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/unet.py b/model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..a9d4caca9289906919e729e456fe6b845de199ee --- /dev/null +++ 
b/model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/unet.py @@ -0,0 +1,220 @@ +from typing import Union, Type, List, Tuple + +import torch +from dynamic_network_architectures.building_blocks.residual_encoders import ResidualEncoder +from dynamic_network_architectures.building_blocks.residual import BasicBlockD, BottleneckD +from torch import nn +from torch.nn.modules.conv import _ConvNd +from torch.nn.modules.dropout import _DropoutNd + +from dynamic_network_architectures.building_blocks.plain_conv_encoder import PlainConvEncoder +from dynamic_network_architectures.building_blocks.unet_decoder import UNetDecoder, UNetDecoder_Seg +from dynamic_network_architectures.building_blocks.helper import convert_conv_op_to_dim + + +class PlainConvUNet(nn.Module): + def __init__(self, + input_channels: int, + n_stages: int, + features_per_stage: Union[int, List[int], Tuple[int, ...]], + conv_op: Type[_ConvNd], + kernel_sizes: Union[int, List[int], Tuple[int, ...]], + strides: Union[int, List[int], Tuple[int, ...]], + n_conv_per_stage: Union[int, List[int], Tuple[int, ...]], + n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]], + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: dict = None, + dropout_op: Union[None, Type[_DropoutNd]] = None, + dropout_op_kwargs: dict = None, + nonlin: Union[None, Type[torch.nn.Module]] = None, # activation + nonlin_kwargs: dict = None, + deep_supervision: bool = False, + nonlin_first: bool = False + ): + """ + nonlin_first: if True you get conv -> nonlin -> norm. Else it's conv -> norm -> nonlin + """ + super().__init__() + if isinstance(n_conv_per_stage, int): + n_conv_per_stage = [n_conv_per_stage] * n_stages + if isinstance(n_conv_per_stage_decoder, int): + n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1) + assert len(n_conv_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have " \ + f"resolution stages. 
here: {n_stages}. " \ + f"n_conv_per_stage: {n_conv_per_stage}" + assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one less entries " \ + f"as we have resolution stages. here: {n_stages} " \ + f"stages, so it should have {n_stages - 1} entries. " \ + f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}" + self.encoder = PlainConvEncoder(input_channels, n_stages, features_per_stage, conv_op, kernel_sizes, strides, + n_conv_per_stage, conv_bias, norm_op, norm_op_kwargs, dropout_op, + dropout_op_kwargs, nonlin, nonlin_kwargs, return_skips=True, + nonlin_first=nonlin_first) + + self.decoder = UNetDecoder(self.encoder, n_conv_per_stage_decoder, deep_supervision, + nonlin_first=nonlin_first) + + def forward(self, x): + skips = self.encoder(x) # [2, 32, 256, 256, 96] ... [2, 768, 8, 8, 3] + outs = self.decoder(skips) # [2, 32, 256, 256, 96] ... [2, 512, 16, 16, 6] + return skips, outs # latent_embeddings(a list of multiscale features), perpixel_embeddings(a list of decoder outputs) + + def compute_conv_feature_map_size(self, input_size): + assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), "just give the image size without color/feature channels or " \ + "batch channel. Do not give input_size=(b, c, x, y(, z)). " \ + "Give input_size=(x, y(, z))!" 
+ return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size) + + +class PlainConvUNet_Seg(nn.Module): + def __init__(self, + input_channels: int, + n_stages: int, + features_per_stage: Union[int, List[int], Tuple[int, ...]], + conv_op: Type[_ConvNd], + kernel_sizes: Union[int, List[int], Tuple[int, ...]], + strides: Union[int, List[int], Tuple[int, ...]], + n_conv_per_stage: Union[int, List[int], Tuple[int, ...]], + num_classes: int, + n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]], + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: dict = None, + dropout_op: Union[None, Type[_DropoutNd]] = None, + dropout_op_kwargs: dict = None, + nonlin: Union[None, Type[torch.nn.Module]] = None, # activation + nonlin_kwargs: dict = None, + deep_supervision: bool = False, + nonlin_first: bool = False + ): + """ + nonlin_first: if True you get conv -> nonlin -> norm. Else it's conv -> norm -> nonlin + """ + super().__init__() + if isinstance(n_conv_per_stage, int): + n_conv_per_stage = [n_conv_per_stage] * n_stages + if isinstance(n_conv_per_stage_decoder, int): + n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1) + assert len(n_conv_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have " \ + f"resolution stages. here: {n_stages}. " \ + f"n_conv_per_stage: {n_conv_per_stage}" + assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one less entries " \ + f"as we have resolution stages. here: {n_stages} " \ + f"stages, so it should have {n_stages - 1} entries. 
" \ + f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}" + self.encoder = PlainConvEncoder(input_channels, n_stages, features_per_stage, conv_op, kernel_sizes, strides, + n_conv_per_stage, conv_bias, norm_op, norm_op_kwargs, dropout_op, + dropout_op_kwargs, nonlin, nonlin_kwargs, return_skips=True, + nonlin_first=nonlin_first) + self.decoder = UNetDecoder_Seg(self.encoder, num_classes, n_conv_per_stage_decoder, deep_supervision, + nonlin_first=nonlin_first) + + def forward(self, x): + skips = self.encoder(x) # [2, 32, 256, 256, 96] ... [2, 768, 8, 8, 3] + out = self.decoder(skips) # [2, num_class, 256, 256, 96] + return out + + def compute_conv_feature_map_size(self, input_size): + assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), "just give the image size without color/feature channels or " \ + "batch channel. Do not give input_size=(b, c, x, y(, z)). " \ + "Give input_size=(x, y(, z))!" + return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size) + + +class ResidualEncoderUNet(nn.Module): + def __init__(self, + input_channels: int, + n_stages: int, + features_per_stage: Union[int, List[int], Tuple[int, ...]], + conv_op: Type[_ConvNd], + kernel_sizes: Union[int, List[int], Tuple[int, ...]], + strides: Union[int, List[int], Tuple[int, ...]], + n_blocks_per_stage: Union[int, List[int], Tuple[int, ...]], + n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]], + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: dict = None, + dropout_op: Union[None, Type[_DropoutNd]] = None, + dropout_op_kwargs: dict = None, + nonlin: Union[None, Type[torch.nn.Module]] = None, + nonlin_kwargs: dict = None, + deep_supervision: bool = False, + block: Union[Type[BasicBlockD], Type[BottleneckD]] = BasicBlockD, + bottleneck_channels: Union[int, List[int], Tuple[int, ...]] = None, + stem_channels: int = None + ): + super().__init__() + if 
isinstance(n_blocks_per_stage, int): + n_blocks_per_stage = [n_blocks_per_stage] * n_stages + if isinstance(n_conv_per_stage_decoder, int): + n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1) + assert len(n_blocks_per_stage) == n_stages, "n_blocks_per_stage must have as many entries as we have " \ + f"resolution stages. here: {n_stages}. " \ + f"n_blocks_per_stage: {n_blocks_per_stage}" + assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one less entries " \ + f"as we have resolution stages. here: {n_stages} " \ + f"stages, so it should have {n_stages - 1} entries. " \ + f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}" + self.encoder = ResidualEncoder(input_channels, n_stages, features_per_stage, conv_op, kernel_sizes, strides, + n_blocks_per_stage, conv_bias, norm_op, norm_op_kwargs, dropout_op, + dropout_op_kwargs, nonlin, nonlin_kwargs, block, bottleneck_channels, + return_skips=True, disable_default_stem=False, stem_channels=stem_channels) + + self.decoder = UNetDecoder(self.encoder, n_conv_per_stage_decoder, deep_supervision) + + def forward(self, x): + skips = self.encoder(x) # [2, 32, 256, 256, 96] ... [2, 768, 8, 8, 3] + outs = self.decoder(skips) # [2, 32, 256, 256, 96] ... [2, 512, 16, 16, 6] + return skips, outs # latent_embeddings(a list of multiscale features), perpixel_embeddings(a list of decoder outputs) + + def compute_conv_feature_map_size(self, input_size): + assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), "just give the image size without color/feature channels or " \ + "batch channel. Do not give input_size=(b, c, x, y(, z)). " \ + "Give input_size=(x, y(, z))!" 
+ return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size) + + +if __name__ == '__main__': + import sys + sys.path.append('/remote-home/zihengzhao/Knowledge-Enhanced-Medical-Segmentation/medical-universal-segmentation/model/dynamic-network-architectures-main') + + data = torch.rand((2, 3, 256, 256, 96)).cuda() + + model = PlainConvUNet(3, 6, (32, 64, 128, 256, 512, 768), nn.Conv3d, 3, (1, 2, 2, 2, 2, 2), (2, 2, 2, 2, 2, 2), 4, + (2, 2, 2, 2, 2), False, nn.BatchNorm3d, None, None, None, nn.ReLU, deep_supervision=True).cuda() + + dec_outs, enc_outs = model(data) + print('DEC') + for i in dec_outs: + print(i.shape) # (2, 4, 256, 256, 96) + print('ENC') + for i in enc_outs: + print(i.shape) # () + exit() + + + if False: + import hiddenlayer as hl + + g = hl.build_graph(model, data, + transforms=None) + g.save("network_architecture.pdf") + del g + + print(model.compute_conv_feature_map_size(data.shape[2:])) + + data = torch.rand((1, 4, 512, 512)) + + model = PlainConvUNet(4, 8, (32, 64, 125, 256, 512, 512, 512, 512), nn.Conv2d, 3, (1, 2, 2, 2, 2, 2, 2, 2), (2, 2, 2, 2, 2, 2, 2, 2), 4, + (2, 2, 2, 2, 2, 2, 2), False, nn.BatchNorm2d, None, None, None, nn.ReLU, deep_supervision=True) + + if False: + import hiddenlayer as hl + + g = hl.build_graph(model, data, + transforms=None) + g.save("network_architecture.pdf") + del g + + print(model.compute_conv_feature_map_size(data.shape[2:])) diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/vgg.py b/model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..b31e2fafb281fce169e71b7adcc8790680ec9048 --- /dev/null +++ b/model/dynamic-network-architectures-main/dynamic_network_architectures/architectures/vgg.py @@ -0,0 +1,85 @@ +import torch +from torch import nn + +from 
dynamic_network_architectures.building_blocks.plain_conv_encoder import PlainConvEncoder +from dynamic_network_architectures.building_blocks.helper import get_matching_pool_op, get_default_network_config + +_VGG_CONFIGS = { + '16': {'features_per_stage': (64, 128, 256, 512, 512, 512), 'n_conv_per_stage': (2, 2, 2, 3, 3, 3), + 'strides': (1, 2, 2, 2, 2, 2)}, + '19': {'features_per_stage': (64, 128, 256, 512, 512, 512), 'n_conv_per_stage': (2, 2, 3, 3, 4, 4), + 'strides': (1, 2, 2, 2, 2, 2)}, + '16_cifar': {'features_per_stage': (64, 128, 256, 512), 'n_conv_per_stage': (2, 3, 5, 5), 'strides': (1, 2, 2, 2)}, + '19_cifar': {'features_per_stage': (64, 128, 256, 512), 'n_conv_per_stage': (3, 4, 5, 6), 'strides': (1, 2, 2, 2)}, +} + +_VGG_OPS = { + 1: {'conv_op': nn.Conv1d, 'norm_op': nn.BatchNorm1d}, + 2: {'conv_op': nn.Conv2d, 'norm_op': nn.BatchNorm2d}, + 3: {'conv_op': nn.Conv3d, 'norm_op': nn.BatchNorm3d}, +} + + +class VGG(nn.Module): + def __init__(self, n_classes: int, n_input_channel: int = 3, config='16', input_dimension=2): + """ + This is not 1:1 VGG because it does not have the bloated fully connected layers at the end. 
Since these were + counted towards the XX layers as well, we increase the number of convolutional layers so that we have the + desired number of conv layers in total + + We also use batchnorm + """ + super().__init__() + cfg = _VGG_CONFIGS[config] + ops = get_default_network_config(dimension=input_dimension) + self.encoder = PlainConvEncoder( + n_input_channel, n_stages=len(cfg['features_per_stage']), features_per_stage=cfg['features_per_stage'], + conv_op=ops['conv_op'], + kernel_sizes=3, strides=cfg['strides'], n_conv_per_stage=cfg['n_conv_per_stage'], conv_bias=False, + norm_op=ops['norm_op'], norm_op_kwargs=None, dropout_op=None, dropout_op_kwargs=None, nonlin=nn.ReLU, + nonlin_kwargs={'inplace': True}, return_skips=False + ) + self.gap = get_matching_pool_op(conv_op=ops['conv_op'], adaptive=True, pool_type='avg')(1) + self.classifier = nn.Linear(cfg['features_per_stage'][-1], n_classes, True) + + def forward(self, x): + x = self.encoder(x) + x = self.gap(x).squeeze() + return self.classifier(x) + + def compute_conv_feature_map_size(self, input_size): + return self.encoder.compute_conv_feature_map_size(input_size) + + +class VGG16(VGG): + def __init__(self, n_classes: int, n_input_channel: int = 3, input_dimension: int = 2): + super().__init__(n_classes, n_input_channel, config='16', input_dimension=input_dimension) + + +class VGG19(VGG): + def __init__(self, n_classes: int, n_input_channel: int = 3, input_dimension: int = 2): + super().__init__(n_classes, n_input_channel, config='19', input_dimension=input_dimension) + + +class VGG16_cifar(VGG): + def __init__(self, n_classes: int, n_input_channel: int = 3, input_dimension: int = 2): + super().__init__(n_classes, n_input_channel, config='16_cifar', input_dimension=input_dimension) + + +class VGG19_cifar(VGG): + def __init__(self, n_classes: int, n_input_channel: int = 3, input_dimension: int = 2): + super().__init__(n_classes, n_input_channel, config='19_cifar', input_dimension=input_dimension) + + +if 
__name__ == '__main__': + data = torch.rand((1, 3, 32, 32)) + + model = VGG19_cifar(10, 3) + import hiddenlayer as hl + + g = hl.build_graph(model, data, + transforms=None) + g.save("network_architecture.pdf") + del g + + print(model.compute_conv_feature_map_size((32, 32))) \ No newline at end of file diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__init__.py b/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/__init__.cpython-310.pyc b/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a09c0b8a84be2743059e39180aeb0be48e7c80eb Binary files /dev/null and b/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/__init__.cpython-310.pyc differ diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/helper.cpython-310.pyc b/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/helper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b724e3f33e070d3263d7c8855c030918e2d39a8 Binary files /dev/null and b/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/helper.cpython-310.pyc differ diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/plain_conv_encoder.cpython-310.pyc b/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/plain_conv_encoder.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..2f4961d821b33b89d92be03944962d55b747bba8 Binary files /dev/null and b/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/plain_conv_encoder.cpython-310.pyc differ diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/regularization.cpython-310.pyc b/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/regularization.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f9cd99f25713c639dd86e9bde18d8a7b49ad8d0 Binary files /dev/null and b/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/regularization.cpython-310.pyc differ diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/residual.cpython-310.pyc b/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/residual.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..340f7da8033b509147cf3c7a55007908537031a7 Binary files /dev/null and b/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/residual.cpython-310.pyc differ diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/residual_encoders.cpython-310.pyc b/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/residual_encoders.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a07a4bc20b833bc829dbbac2194f50a89534e1cb Binary files /dev/null and b/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/residual_encoders.cpython-310.pyc differ diff --git 
a/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/simple_conv_blocks.cpython-310.pyc b/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/simple_conv_blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5d5e40b5286a886e26594f721662902f5c0abd4 Binary files /dev/null and b/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/simple_conv_blocks.cpython-310.pyc differ diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/unet_decoder.cpython-310.pyc b/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/unet_decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48dfc784517f7417a18cad4cbcf8aed266bb3aa3 Binary files /dev/null and b/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/__pycache__/unet_decoder.cpython-310.pyc differ diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/helper.py b/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/helper.py new file mode 100644 index 0000000000000000000000000000000000000000..4eeb919680509a2c6e897a9aa7843bde2a68778d --- /dev/null +++ b/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/helper.py @@ -0,0 +1,242 @@ +from typing import Type +import numpy as np +import torch.nn +from torch import nn +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd +from torch.nn.modules.dropout import _DropoutNd +from torch.nn.modules.instancenorm import _InstanceNorm + + +def convert_dim_to_conv_op(dimension: int) -> Type[_ConvNd]: + """ + :param dimension: 1, 2 or 3 + :return: conv Class of corresponding 
dimension + """ + if dimension == 1: + return nn.Conv1d + elif dimension == 2: + return nn.Conv2d + elif dimension == 3: + return nn.Conv3d + else: + raise ValueError("Unknown dimension. Only 1, 2 and 3 are supported") + + +def convert_conv_op_to_dim(conv_op: Type[_ConvNd]) -> int: + """ + :param conv_op: conv class + :return: dimension: 1, 2 or 3 + """ + if conv_op == nn.Conv1d: + return 1 + elif conv_op == nn.Conv2d: + return 2 + elif conv_op == nn.Conv3d: + return 3 + else: + raise ValueError("Unknown dimension. Only 1d 2d and 3d conv are supported. got %s" % str(conv_op)) + + +def get_matching_pool_op(conv_op: Type[_ConvNd] = None, + dimension: int = None, + adaptive=False, + pool_type: str = 'avg') -> Type[torch.nn.Module]: + """ + You MUST set EITHER conv_op OR dimension. Do not set both! + :param conv_op: + :param dimension: + :param adaptive: + :param pool_type: either 'avg' or 'max' + :return: + """ + assert not ((conv_op is not None) and (dimension is not None)), \ + "You MUST set EITHER conv_op OR dimension. Do not set both!" 
+ assert pool_type in ['avg', 'max'], 'pool_type must be either avg or max' + if conv_op is not None: + dimension = convert_conv_op_to_dim(conv_op) + assert dimension in [1, 2, 3], 'Dimension must be 1, 2 or 3' + + if conv_op is not None: + dimension = convert_conv_op_to_dim(conv_op) + + if dimension == 1: + if pool_type == 'avg': + if adaptive: + return nn.AdaptiveAvgPool1d + else: + return nn.AvgPool1d + elif pool_type == 'max': + if adaptive: + return nn.AdaptiveMaxPool1d + else: + return nn.MaxPool1d + elif dimension == 2: + if pool_type == 'avg': + if adaptive: + return nn.AdaptiveAvgPool2d + else: + return nn.AvgPool2d + elif pool_type == 'max': + if adaptive: + return nn.AdaptiveMaxPool2d + else: + return nn.MaxPool2d + elif dimension == 3: + if pool_type == 'avg': + if adaptive: + return nn.AdaptiveAvgPool3d + else: + return nn.AvgPool3d + elif pool_type == 'max': + if adaptive: + return nn.AdaptiveMaxPool3d + else: + return nn.MaxPool3d + + +def get_matching_instancenorm(conv_op: Type[_ConvNd] = None, dimension: int = None) -> Type[_InstanceNorm]: + """ + You MUST set EITHER conv_op OR dimension. Do not set both! + + :param conv_op: + :param dimension: + :return: + """ + assert not ((conv_op is not None) and (dimension is not None)), \ + "You MUST set EITHER conv_op OR dimension. Do not set both!" + if conv_op is not None: + dimension = convert_conv_op_to_dim(conv_op) + if dimension is not None: + assert dimension in [1, 2, 3], 'Dimension must be 1, 2 or 3' + if dimension == 1: + return nn.InstanceNorm1d + elif dimension == 2: + return nn.InstanceNorm2d + elif dimension == 3: + return nn.InstanceNorm3d + + +def get_matching_convtransp(conv_op: Type[_ConvNd] = None, dimension: int = None) -> Type[_ConvTransposeNd]: + """ + You MUST set EITHER conv_op OR dimension. Do not set both! + + :param conv_op: + :param dimension: + :return: + """ + assert not ((conv_op is not None) and (dimension is not None)), \ + "You MUST set EITHER conv_op OR dimension. 
Do not set both!" + if conv_op is not None: + dimension = convert_conv_op_to_dim(conv_op) + assert dimension in [1, 2, 3], 'Dimension must be 1, 2 or 3' + if dimension == 1: + return nn.ConvTranspose1d + elif dimension == 2: + return nn.ConvTranspose2d + elif dimension == 3: + return nn.ConvTranspose3d + + +def get_matching_batchnorm(conv_op: Type[_ConvNd] = None, dimension: int = None) -> Type[_BatchNorm]: + """ + You MUST set EITHER conv_op OR dimension. Do not set both! + + :param conv_op: + :param dimension: + :return: + """ + assert not ((conv_op is not None) and (dimension is not None)), \ + "You MUST set EITHER conv_op OR dimension. Do not set both!" + if conv_op is not None: + dimension = convert_conv_op_to_dim(conv_op) + assert dimension in [1, 2, 3], 'Dimension must be 1, 2 or 3' + if dimension == 1: + return nn.BatchNorm1d + elif dimension == 2: + return nn.BatchNorm2d + elif dimension == 3: + return nn.BatchNorm3d + + +def get_matching_dropout(conv_op: Type[_ConvNd] = None, dimension: int = None) -> Type[_DropoutNd]: + """ + You MUST set EITHER conv_op OR dimension. Do not set both! + + :param conv_op: + :param dimension: + :return: + """ + assert not ((conv_op is not None) and (dimension is not None)), \ + "You MUST set EITHER conv_op OR dimension. Do not set both!" 
+ assert dimension in [1, 2, 3], 'Dimension must be 1, 2 or 3' + if dimension == 1: + return nn.Dropout + elif dimension == 2: + return nn.Dropout2d + elif dimension == 3: + return nn.Dropout3d + + +def maybe_convert_scalar_to_list(conv_op, scalar): + """ + useful for converting, for example, kernel_size=3 to [3, 3, 3] in case of nn.Conv3d + :param conv_op: + :param scalar: + :return: + """ + if not isinstance(scalar, (tuple, list, np.ndarray)): + if conv_op == nn.Conv2d: + return [scalar] * 2 + elif conv_op == nn.Conv3d: + return [scalar] * 3 + elif conv_op == nn.Conv1d: + return [scalar] * 1 + else: + raise RuntimeError("Invalid conv op: %s" % str(conv_op)) + else: + return scalar + + +def get_default_network_config(dimension: int = 2, + nonlin: str = "ReLU", + norm_type: str = "bn") -> dict: + """ + Use this to get a standard configuration. A network configuration looks like this: + + config = {'conv_op': torch.nn.modules.conv.Conv2d, + 'dropout_op': torch.nn.modules.dropout.Dropout2d, + 'norm_op': torch.nn.modules.batchnorm.BatchNorm2d, + 'norm_op_kwargs': {'eps': 1e-05, 'affine': True}, + 'nonlin': torch.nn.modules.activation.ReLU, + 'nonlin_kwargs': {'inplace': True}} + + There is no need to use get_default_network_config. You can create your own. Network configs are a convenient way of + setting dimensionality, normalization and nonlinearity. + + :param dimension: integer denoting the dimension of the data. 
1, 2 and 3 are accepted + :param nonlin: string (ReLU or LeakyReLU) + :param norm_type: string (bn=batch norm, in=instance norm) + torch.nn.Module + :return: dict + """ + config = {} + config['conv_op'] = convert_dim_to_conv_op(dimension) + config['dropout_op'] = get_matching_dropout(dimension=dimension) + if norm_type == "bn": + config['norm_op'] = get_matching_batchnorm(dimension=dimension) + elif norm_type == "in": + config['norm_op'] = get_matching_instancenorm(dimension=dimension) + + config['norm_op_kwargs'] = None # this will use defaults + + if nonlin == "LeakyReLU": + config['nonlin'] = nn.LeakyReLU + config['nonlin_kwargs'] = {'negative_slope': 1e-2, 'inplace': True} + elif nonlin == "ReLU": + config['nonlin'] = nn.ReLU + config['nonlin_kwargs'] = {'inplace': True} + else: + raise NotImplementedError('Unknown nonlin %s. Only "LeakyReLU" and "ReLU" are supported for now' % nonlin) + + return config \ No newline at end of file diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/plain_conv_encoder.py b/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/plain_conv_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..ec5329e429361a116a1d93e09c1bb34f2575749b --- /dev/null +++ b/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/plain_conv_encoder.py @@ -0,0 +1,105 @@ +import torch +from torch import nn +import numpy as np +from typing import Union, Type, List, Tuple + +from torch.nn.modules.conv import _ConvNd +from torch.nn.modules.dropout import _DropoutNd +from dynamic_network_architectures.building_blocks.simple_conv_blocks import StackedConvBlocks +from dynamic_network_architectures.building_blocks.helper import maybe_convert_scalar_to_list, get_matching_pool_op + + +class PlainConvEncoder(nn.Module): + def __init__(self, + input_channels: int, + n_stages: int, + features_per_stage: Union[int, List[int], 
Tuple[int, ...]], + conv_op: Type[_ConvNd], + kernel_sizes: Union[int, List[int], Tuple[int, ...]], + strides: Union[int, List[int], Tuple[int, ...]], + n_conv_per_stage: Union[int, List[int], Tuple[int, ...]], + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: dict = None, + dropout_op: Union[None, Type[_DropoutNd]] = None, + dropout_op_kwargs: dict = None, + nonlin: Union[None, Type[torch.nn.Module]] = None, + nonlin_kwargs: dict = None, + return_skips: bool = False, + nonlin_first: bool = False, + pool: str = 'conv' + ): + + super().__init__() + if isinstance(kernel_sizes, int): + kernel_sizes = [kernel_sizes] * n_stages + if isinstance(features_per_stage, int): + features_per_stage = [features_per_stage] * n_stages + if isinstance(n_conv_per_stage, int): + n_conv_per_stage = [n_conv_per_stage] * n_stages + if isinstance(strides, int): + strides = [strides] * n_stages + assert len(kernel_sizes) == n_stages, "kernel_sizes must have as many entries as we have resolution stages (n_stages)" + assert len(n_conv_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have resolution stages (n_stages)" + assert len(features_per_stage) == n_stages, "features_per_stage must have as many entries as we have resolution stages (n_stages)" + assert len(strides) == n_stages, "strides must have as many entries as we have resolution stages (n_stages). 
" \ + "Important: first entry is recommended to be 1, else we run strided conv drectly on the input" + + stages = [] + for s in range(n_stages): + stage_modules = [] + if pool == 'max' or pool == 'avg': + if (isinstance(strides[s], int) and strides[s] != 1) or \ + isinstance(strides[s], (tuple, list)) and any([i != 1 for i in strides[s]]): + stage_modules.append(get_matching_pool_op(conv_op, pool_type=pool)(kernel_size=strides[s], stride=strides[s])) + conv_stride = 1 + elif pool == 'conv': + conv_stride = strides[s] + else: + raise RuntimeError() + stage_modules.append(StackedConvBlocks( + n_conv_per_stage[s], conv_op, input_channels, features_per_stage[s], kernel_sizes[s], conv_stride, + conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first + )) + stages.append(nn.Sequential(*stage_modules)) + input_channels = features_per_stage[s] + + self.stages = nn.Sequential(*stages) + self.output_channels = features_per_stage + self.strides = [maybe_convert_scalar_to_list(conv_op, i) for i in strides] + self.return_skips = return_skips + + # we store some things that a potential decoder needs + self.conv_op = conv_op + self.norm_op = norm_op + self.norm_op_kwargs = norm_op_kwargs + self.nonlin = nonlin + self.nonlin_kwargs = nonlin_kwargs + self.dropout_op = dropout_op + self.dropout_op_kwargs = dropout_op_kwargs + self.conv_bias = conv_bias + self.kernel_sizes = kernel_sizes + + def forward(self, x): + ret = [] + for s in self.stages: + x = s(x) + ret.append(x) + if self.return_skips: + return ret + else: + return ret[-1] + + def compute_conv_feature_map_size(self, input_size): + output = np.int64(0) + for s in range(len(self.stages)): + if isinstance(self.stages[s], nn.Sequential): + for sq in self.stages[s]: + if hasattr(sq, 'compute_conv_feature_map_size'): + output += self.stages[s][-1].compute_conv_feature_map_size(input_size) + else: + output += self.stages[s].compute_conv_feature_map_size(input_size) + input_size = [i 
// j for i, j in zip(input_size, self.strides[s])] + return output + + diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/regularization.py b/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/regularization.py new file mode 100644 index 0000000000000000000000000000000000000000..c3b10e6f05a8c75b6b3d2735c68990377d91dd8f --- /dev/null +++ b/model/dynamic-network-architectures-main/dynamic_network_architectures/building_blocks/regularization.py @@ -0,0 +1,86 @@ +from torch import nn + + +def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): + """ + This function is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py). + + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """ + This class is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py). + + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ """ + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + +class SqueezeExcite(nn.Module): + """ + This class is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/squeeze_excite.py) + and slightly modified so that the convolution type can be adapted. + + SE Module as defined in original SE-Nets with a few additions + Additions include: + * divisor can be specified to keep channels % div == 0 (default: 8) + * reduction channels can be specified directly by arg (if rd_channels is set) + * reduction channels can be specified by float rd_ratio (default: 1/16) + * global max pooling can be added to the squeeze aggregation + * customizable activation, normalization, and gate layer + """ + def __init__( + self, channels, conv_op, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False, + act_layer=nn.ReLU, norm_layer=None, gate_layer=nn.Sigmoid): + super(SqueezeExcite, self).__init__() + self.add_maxpool = add_maxpool + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
+ self.fc1 = conv_op(channels, rd_channels, kernel_size=1, bias=True) + self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity() + self.act = act_layer(inplace=True) + self.fc2 = conv_op(rd_channels, channels, kernel_size=1, bias=True) + self.gate = gate_layer() + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + if self.add_maxpool: + # experimental codepath, may remove or change + x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) + x_se = self.fc1(x_se) + x_se = self.act(self.bn(x_se)) + x_se = self.fc2(x_se) + return x * self.gate(x_se) + + +def make_divisible(v, divisor=8, min_value=None, round_limit=.9): + """ + This function is taken from the timm package (https://github.com/rwightman/pytorch-image-models/blob/b7cb8d0337b3e7b50516849805ddb9be5fc11644/timm/models/layers/helpers.py#L25) + """ + min_value = min_value or divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. 
class BasicBlockD(nn.Module):
    """
    Residual basic block following ResNet-D:

    He, Tong, et al. "Bag of tricks for image classification with convolutional neural networks."
    Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019.

    Main path: conv(stride) -> conv(stride 1). The skip has an avgpool (if needed) followed by a
    1x1 conv instead of just a strided 1x1 conv.
    """

    def __init__(self,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 stochastic_depth_p: float = 0.0,
                 squeeze_excitation: bool = False,
                 squeeze_excitation_reduction_ratio: float = 1. / 16,
                 # todo wideresnet?
                 ):
        """
        :param conv_op: convolution class (e.g. nn.Conv2d / nn.Conv3d)
        :param input_channels: number of input feature channels
        :param output_channels: number of output feature channels
        :param kernel_size: refers only to convs in feature extraction path, not to 1x1(x1) conv in skip
        :param stride: only applies to first conv (and skip). Second conv always has stride 1
        :param conv_bias: whether the convolutions use a bias term
        :param norm_op: normalization class, or None for no normalization
        :param norm_op_kwargs: kwargs for norm_op; None is treated as {}
        :param dropout_op: only the first conv can have dropout. The second never has
        :param dropout_op_kwargs: kwargs for dropout_op
        :param nonlin: nonlinearity class, or None for no nonlinearity
        :param nonlin_kwargs: kwargs for nonlin; None is treated as {}
        :param stochastic_depth_p: drop probability for stochastic depth (0.0 disables it)
        :param squeeze_excitation: whether to apply a squeeze-and-excitation module to the main path
        :param squeeze_excitation_reduction_ratio: channel reduction ratio for squeeze-and-excitation
        """
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        stride = maybe_convert_scalar_to_list(conv_op, stride)
        self.stride = stride

        kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)

        if norm_op_kwargs is None:
            norm_op_kwargs = {}
        if nonlin_kwargs is None:
            nonlin_kwargs = {}

        self.conv1 = ConvDropoutNormReLU(conv_op, input_channels, output_channels, kernel_size, stride, conv_bias,
                                         norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs)
        # second conv carries neither dropout nor a nonlinearity; the nonlinearity is applied
        # after the residual addition (nonlin2 below)
        self.conv2 = ConvDropoutNormReLU(conv_op, output_channels, output_channels, kernel_size, 1, conv_bias, norm_op,
                                         norm_op_kwargs, None, None, None, None)

        # FIX: nn.Identity() instead of `lambda x: x`. A lambda is not an nn.Module, which makes
        # the block unpicklable and invisible to torchscript/module traversal. nn.Identity has no
        # parameters, so existing checkpoints are unaffected.
        self.nonlin2 = nonlin(**nonlin_kwargs) if nonlin is not None else nn.Identity()

        # Stochastic Depth
        self.apply_stochastic_depth = False if stochastic_depth_p == 0.0 else True
        if self.apply_stochastic_depth:
            self.drop_path = DropPath(drop_prob=stochastic_depth_p)

        # Squeeze Excitation
        self.apply_se = squeeze_excitation
        if self.apply_se:
            self.squeeze_excitation = SqueezeExcite(self.output_channels, conv_op,
                                                    rd_ratio=squeeze_excitation_reduction_ratio, rd_divisor=8)

        has_stride = (isinstance(stride, int) and stride != 1) or any([i != 1 for i in stride])
        requires_projection = (input_channels != output_channels)

        if has_stride or requires_projection:
            ops = []
            if has_stride:
                # ResNet-D: downsample the skip with average pooling rather than a strided 1x1 conv
                ops.append(get_matching_pool_op(conv_op=conv_op, adaptive=False, pool_type='avg')(stride, stride))
            if requires_projection:
                ops.append(
                    ConvDropoutNormReLU(conv_op, input_channels, output_channels, 1, 1, False, norm_op,
                                        norm_op_kwargs, None, None, None, None
                                        )
                )
            self.skip = nn.Sequential(*ops)
        else:
            # FIX: nn.Identity() instead of `lambda x: x` (see nonlin2 above). The
            # isinstance(self.skip, nn.Sequential) checks in compute_conv_feature_map_size
            # still behave as before because nn.Identity is not an nn.Sequential.
            self.skip = nn.Identity()

    def forward(self, x):
        residual = self.skip(x)
        out = self.conv2(self.conv1(x))
        if self.apply_stochastic_depth:
            out = self.drop_path(out)
        if self.apply_se:
            out = self.squeeze_excitation(out)
        out += residual
        return self.nonlin2(out)

    def compute_conv_feature_map_size(self, input_size):
        """
        Number of feature-map elements produced by the convs in this block, used for memory
        estimation. input_size is the spatial size only, i.e. (x, y(, z)).
        """
        assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \
                                                    "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                    "Give input_size=(x, y(, z))!"
        size_after_stride = [i // j for i, j in zip(input_size, self.stride)]
        # conv1
        output_size_conv1 = np.prod([self.output_channels, *size_after_stride], dtype=np.int64)
        # conv2
        output_size_conv2 = np.prod([self.output_channels, *size_after_stride], dtype=np.int64)
        # skip conv (if applicable)
        if (self.input_channels != self.output_channels) or any([i != j for i, j in zip(input_size, size_after_stride)]):
            assert isinstance(self.skip, nn.Sequential)
            output_size_skip = np.prod([self.output_channels, *size_after_stride], dtype=np.int64)
        else:
            assert not isinstance(self.skip, nn.Sequential)
            output_size_skip = 0
        return output_size_conv1 + output_size_conv2 + output_size_skip
class BottleneckD(nn.Module):
    """
    Residual bottleneck block (1x1 -> kxk -> 1x1) following ResNet-D:

    He, Tong, et al. "Bag of tricks for image classification with convolutional neural networks."
    Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019.

    The stride sits in the 3x3 conv instead of the 1x1 conv! The skip has an avgpool (if needed)
    followed by a 1x1 conv instead of just a strided 1x1 conv.
    """

    def __init__(self,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 bottleneck_channels: int,
                 output_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 stochastic_depth_p: float = 0.0,
                 squeeze_excitation: bool = False,
                 squeeze_excitation_reduction_ratio: float = 1. / 16
                 ):
        """
        :param conv_op: convolution class (e.g. nn.Conv2d / nn.Conv3d)
        :param input_channels: number of input feature channels
        :param bottleneck_channels: channel count inside the bottleneck (between the 1x1 convs)
        :param output_channels: number of output feature channels
        :param kernel_size: only affects the conv in the middle (typically 3x3). The other convs remain 1x1
        :param stride: only applies to the conv in the middle (and skip). Note that this deviates from the
            canonical ResNet implementation where the stride is applied to the first 1x1 conv.
            (This implementation follows ResNet-D)
        :param conv_bias: whether the convolutions use a bias term
        :param norm_op: normalization class, or None for no normalization
        :param norm_op_kwargs: kwargs for norm_op; None is treated as {}
        :param dropout_op: only the second (kernel_size) conv can have dropout. The first and last
            conv (1x1(x1)) never have it
        :param dropout_op_kwargs: kwargs for dropout_op
        :param nonlin: nonlinearity class, or None for no nonlinearity
        :param nonlin_kwargs: kwargs for nonlin; None is treated as {}
        :param stochastic_depth_p: drop probability for stochastic depth (0.0 disables it)
        :param squeeze_excitation: whether to apply a squeeze-and-excitation module to the main path
        :param squeeze_excitation_reduction_ratio: channel reduction ratio for squeeze-and-excitation
        """
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.bottleneck_channels = bottleneck_channels
        stride = maybe_convert_scalar_to_list(conv_op, stride)
        self.stride = stride

        kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)
        if norm_op_kwargs is None:
            norm_op_kwargs = {}
        if nonlin_kwargs is None:
            nonlin_kwargs = {}

        # 1x1 reduce -> kxk (strided, may have dropout) -> 1x1 expand (no nonlin; it is applied
        # after the residual addition via nonlin3)
        self.conv1 = ConvDropoutNormReLU(conv_op, input_channels, bottleneck_channels, 1, 1, conv_bias,
                                         norm_op, norm_op_kwargs, None, None, nonlin, nonlin_kwargs)
        self.conv2 = ConvDropoutNormReLU(conv_op, bottleneck_channels, bottleneck_channels, kernel_size, stride,
                                         conv_bias,
                                         norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs)
        self.conv3 = ConvDropoutNormReLU(conv_op, bottleneck_channels, output_channels, 1, 1, conv_bias, norm_op,
                                         norm_op_kwargs, None, None, None, None)

        self.nonlin3 = nonlin(**nonlin_kwargs) if nonlin is not None else lambda x: x

        # Stochastic Depth
        self.apply_stochastic_depth = False if stochastic_depth_p == 0.0 else True
        if self.apply_stochastic_depth:
            self.drop_path = DropPath(drop_prob=stochastic_depth_p)

        # Squeeze Excitation
        self.apply_se = squeeze_excitation
        if self.apply_se:
            self.squeeze_excitation = SqueezeExcite(self.output_channels, conv_op,
                                                    rd_ratio=squeeze_excitation_reduction_ratio, rd_divisor=8)

        # NOTE: stride is always a list here (maybe_convert_scalar_to_list above), so the
        # isinstance(stride, int) branch is a harmless leftover guard
        has_stride = (isinstance(stride, int) and stride != 1) or any([i != 1 for i in stride])
        requires_projection = (input_channels != output_channels)

        if has_stride or requires_projection:
            ops = []
            if has_stride:
                # ResNet-D: downsample the skip with average pooling rather than a strided 1x1 conv
                ops.append(get_matching_pool_op(conv_op=conv_op, adaptive=False, pool_type='avg')(stride, stride))
            if requires_projection:
                ops.append(
                    ConvDropoutNormReLU(conv_op, input_channels, output_channels, 1, 1, False,
                                        norm_op, norm_op_kwargs, None, None, None, None
                                        )
                )
            self.skip = nn.Sequential(*ops)
        else:
            self.skip = lambda x: x

    def forward(self, x):
        residual = self.skip(x)
        out = self.conv3(self.conv2(self.conv1(x)))
        if self.apply_stochastic_depth:
            out = self.drop_path(out)
        if self.apply_se:
            out = self.squeeze_excitation(out)
        out += residual
        return self.nonlin3(out)

    def compute_conv_feature_map_size(self, input_size):
        """
        Number of feature-map elements produced by the convs in this block, used for memory
        estimation. input_size is the spatial size only, i.e. (x, y(, z)).
        """
        assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \
                                                    "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                    "Give input_size=(x, y(, z))!"
        size_after_stride = [i // j for i, j in zip(input_size, self.stride)]
        # conv1 runs at full resolution (stride only applies in conv2)
        output_size_conv1 = np.prod([self.bottleneck_channels, *input_size], dtype=np.int64)
        # conv2
        output_size_conv2 = np.prod([self.bottleneck_channels, *size_after_stride], dtype=np.int64)
        # conv3
        output_size_conv3 = np.prod([self.output_channels, *size_after_stride], dtype=np.int64)
        # skip conv (if applicable)
        if (self.input_channels != self.output_channels) or any([i != j for i, j in zip(input_size, size_after_stride)]):
            assert isinstance(self.skip, nn.Sequential)
            output_size_skip = np.prod([self.output_channels, *size_after_stride], dtype=np.int64)
        else:
            assert not isinstance(self.skip, nn.Sequential)
            output_size_skip = 0
        return output_size_conv1 + output_size_conv2 + output_size_conv3 + output_size_skip
class StackedResidualBlocks(nn.Module):
    """
    Stack multiple instances of a residual block (BasicBlockD or BottleneckD). Only the first
    block in the stack applies initial_stride; all subsequent blocks have stride 1.
    """

    def __init__(self,
                 n_blocks: int,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: Union[int, List[int], Tuple[int, ...]],
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 initial_stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 block: Union[Type[BasicBlockD], Type[BottleneckD]] = BasicBlockD,
                 bottleneck_channels: Union[int, List[int], Tuple[int, ...]] = None,
                 stochastic_depth_p: float = 0.0,
                 squeeze_excitation: bool = False,
                 squeeze_excitation_reduction_ratio: float = 1. / 16
                 ):
        """
        :param n_blocks: number of residual blocks
        :param conv_op: nn.ConvNd class
        :param input_channels: only relevant for the first block in the sequence. This is the input
            number of features. After the first block, the number of features in the main path to
            which the residuals are added is output_channels
        :param output_channels: number of features in the main path to which the residuals are
            added (and also the number of features of the output). Can be int or one entry per block
        :param kernel_size: kernel size for all nxn (n!=1) convolutions
        :param initial_stride: only affects the first block. All subsequent blocks have stride 1
        :param conv_bias: usually False
        :param norm_op: nn.BatchNormNd, InstanceNormNd etc
        :param norm_op_kwargs: dictionary of kwargs. Leave empty ({}) for defaults
        :param dropout_op: nn.DropoutNd, can be None for no dropout
        :param dropout_op_kwargs: kwargs for dropout_op
        :param nonlin: nonlinearity class
        :param nonlin_kwargs: kwargs for nonlin
        :param block: BasicBlockD or BottleneckD
        :param bottleneck_channels: if block is BottleneckD then we need to know the number of
            bottleneck features. Bottleneck will use the first 1x1 conv to reduce input to
            bottleneck features, then run the nxn (see kernel_size) conv on that
            (bottleneck -> bottleneck). Finally the output is projected back to output_channels
            with the final 1x1 conv. Ignored for BasicBlockD
        :param stochastic_depth_p: probability of applying stochastic depth in residual blocks
        :param squeeze_excitation: whether to apply squeeze and excitation or not
        :param squeeze_excitation_reduction_ratio: ratio by how much squeeze and excitation should
            reduce channels relative to the number of out channels of the respective block
        """
        super().__init__()
        assert n_blocks > 0, 'n_blocks must be > 0'
        assert block in [BasicBlockD, BottleneckD], 'block must be BasicBlockD or BottleneckD'
        # broadcast scalars so each block has its own channel spec
        if not isinstance(output_channels, (tuple, list)):
            output_channels = [output_channels] * n_blocks
        if not isinstance(bottleneck_channels, (tuple, list)):
            # may broadcast None for BasicBlockD, where bottleneck_channels is unused
            bottleneck_channels = [bottleneck_channels] * n_blocks

        if block == BasicBlockD:
            # first block handles the channel change and stride; the rest are stride-1,
            # same-channel blocks
            blocks = nn.Sequential(
                block(conv_op, input_channels, output_channels[0], kernel_size, initial_stride, conv_bias,
                      norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, stochastic_depth_p,
                      squeeze_excitation, squeeze_excitation_reduction_ratio),
                *[block(conv_op, output_channels[n - 1], output_channels[n], kernel_size, 1, conv_bias, norm_op,
                        norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, stochastic_depth_p,
                        squeeze_excitation, squeeze_excitation_reduction_ratio) for n in range(1, n_blocks)]
            )
        else:
            # BottleneckD takes the extra bottleneck_channels argument
            blocks = nn.Sequential(
                block(conv_op, input_channels, bottleneck_channels[0], output_channels[0], kernel_size,
                      initial_stride, conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
                      nonlin, nonlin_kwargs, stochastic_depth_p, squeeze_excitation, squeeze_excitation_reduction_ratio),
                *[block(conv_op, output_channels[n - 1], bottleneck_channels[n], output_channels[n], kernel_size,
                        1, conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
                        nonlin, nonlin_kwargs, stochastic_depth_p, squeeze_excitation,
                        squeeze_excitation_reduction_ratio) for n in range(1, n_blocks)]
            )
        self.blocks = blocks
        self.initial_stride = maybe_convert_scalar_to_list(conv_op, initial_stride)
        self.output_channels = output_channels[-1]

    def forward(self, x):
        return self.blocks(x)

    def compute_conv_feature_map_size(self, input_size):
        """
        Sum of the feature-map sizes of all blocks; the first block sees input_size, all later
        blocks see the downsampled size. input_size is spatial only, i.e. (x, y(, z)).
        """
        assert len(input_size) == len(self.initial_stride), "just give the image size without color/feature channels or " \
                                                            "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                            "Give input_size=(x, y(, z))!"
        output = self.blocks[0].compute_conv_feature_map_size(input_size)
        size_after_stride = [i // j for i, j in zip(input_size, self.initial_stride)]
        for b in self.blocks[1:]:
            output += b.compute_conv_feature_map_size(size_after_stride)
        return output
class ResidualEncoder(nn.Module):
    """
    Residual CNN encoder: an optional non-strided stem followed by n_stages of
    StackedResidualBlocks, each stage optionally preceded by a pooling op (if pool_type != 'conv').
    Can return all per-stage feature maps (return_skips=True) for use in U-Net-like architectures.
    """

    def __init__(self,
                 input_channels: int,
                 n_stages: int,
                 features_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_op: Type[_ConvNd],
                 kernel_sizes: Union[int, List[int], Tuple[int, ...]],
                 strides: Union[int, List[int], Tuple[int, ...], Tuple[Tuple[int, ...], ...]],
                 n_blocks_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 block: Union[Type[BasicBlockD], Type[BottleneckD]] = BasicBlockD,
                 bottleneck_channels: Union[int, List[int], Tuple[int, ...]] = None,
                 return_skips: bool = False,
                 disable_default_stem: bool = False,
                 stem_channels: int = None,
                 pool_type: str = 'conv',
                 stochastic_depth_p: float = 0.0,
                 squeeze_excitation: bool = False,
                 squeeze_excitation_reduction_ratio: float = 1. / 16
                 ):
        """
        :param input_channels: number of input (image) channels
        :param n_stages: number of resolution stages
        :param features_per_stage: Note: If the block is BottleneckD, then this number is supposed
            to be the number of features AFTER the expansion (which is not coded implicitly in this
            repository)! See todo!
        :param conv_op: convolution class (e.g. nn.Conv2d / nn.Conv3d)
        :param kernel_sizes: per-stage kernel size (scalar is broadcast to all stages)
        :param strides: per-stage stride (scalar is broadcast to all stages)
        :param n_blocks_per_stage: residual blocks per stage (scalar is broadcast)
        :param conv_bias: whether convolutions use a bias term
        :param norm_op: normalization class
        :param norm_op_kwargs: kwargs for norm_op
        :param dropout_op: dropout class, or None
        :param dropout_op_kwargs: kwargs for dropout_op
        :param nonlin: nonlinearity class
        :param nonlin_kwargs: kwargs for nonlin
        :param block: BasicBlockD or BottleneckD
        :param bottleneck_channels: only needed if block is BottleneckD
        :param return_skips: set this to True if used as encoder in a U-Net like network
        :param disable_default_stem: If True then no stem will be created. You need to build your
            own and ensure it is executed first, see todo. The stem in this implementation does not
            do stride/pooling so building your own stem is a necessity if you need this.
        :param stem_channels: if None, features_per_stage[0] will be used for the default stem.
            Not recommended for BottleneckD
        :param pool_type: if conv, strided conv will be used. avg = average pooling, max = max pooling
        :param stochastic_depth_p: stochastic depth probability passed to the residual blocks
        :param squeeze_excitation: whether residual blocks use squeeze-and-excitation
        :param squeeze_excitation_reduction_ratio: SE channel reduction ratio
        """
        super().__init__()
        # broadcast scalar hyperparameters to one entry per stage
        if isinstance(kernel_sizes, int):
            kernel_sizes = [kernel_sizes] * n_stages
        if isinstance(features_per_stage, int):
            features_per_stage = [features_per_stage] * n_stages
        if isinstance(n_blocks_per_stage, int):
            n_blocks_per_stage = [n_blocks_per_stage] * n_stages
        if isinstance(strides, int):
            strides = [strides] * n_stages
        if bottleneck_channels is None or isinstance(bottleneck_channels, int):
            bottleneck_channels = [bottleneck_channels] * n_stages
        assert len(
            bottleneck_channels) == n_stages, "bottleneck_channels must be None or have as many entries as we have resolution stages (n_stages)"
        assert len(
            kernel_sizes) == n_stages, "kernel_sizes must have as many entries as we have resolution stages (n_stages)"
        assert len(
            n_blocks_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have resolution stages (n_stages)"
        assert len(
            features_per_stage) == n_stages, "features_per_stage must have as many entries as we have resolution stages (n_stages)"
        assert len(strides) == n_stages, "strides must have as many entries as we have resolution stages (n_stages). " \
                                         "Important: first entry is recommended to be 1, else we run strided conv drectly on the input"

        pool_op = get_matching_pool_op(conv_op, pool_type=pool_type) if pool_type != 'conv' else None

        # build a stem, Todo maybe we need more flexibility for this in the future. For now, if you need a custom
        # stem you can just disable the stem and build your own.
        # THE STEM DOES NOT DO STRIDE/POOLING IN THIS IMPLEMENTATION
        if not disable_default_stem:
            if stem_channels is None:
                stem_channels = features_per_stage[0]
            self.stem = StackedConvBlocks(1, conv_op, input_channels, stem_channels, kernel_sizes[0], 1, conv_bias,
                                          norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs)
            input_channels = stem_channels
        else:
            self.stem = None

        # now build the network
        stages = []
        for s in range(n_stages):
            # when a pooling op handles downsampling, the residual stack itself runs at stride 1
            stride_for_conv = strides[s] if pool_op is None else 1

            stage = StackedResidualBlocks(
                n_blocks_per_stage[s], conv_op, input_channels, features_per_stage[s], kernel_sizes[s], stride_for_conv,
                conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs,
                block=block, bottleneck_channels=bottleneck_channels[s], stochastic_depth_p=stochastic_depth_p,
                squeeze_excitation=squeeze_excitation,
                squeeze_excitation_reduction_ratio=squeeze_excitation_reduction_ratio
            )

            if pool_op is not None:
                stage = nn.Sequential(pool_op(strides[s]), stage)

            stages.append(stage)
            input_channels = features_per_stage[s]

        self.stages = nn.Sequential(*stages)
        self.output_channels = features_per_stage
        self.strides = [maybe_convert_scalar_to_list(conv_op, i) for i in strides]
        self.return_skips = return_skips

        # we store some things that a potential decoder needs
        self.conv_op = conv_op
        self.norm_op = norm_op
        self.norm_op_kwargs = norm_op_kwargs
        self.nonlin = nonlin
        self.nonlin_kwargs = nonlin_kwargs
        self.dropout_op = dropout_op
        self.dropout_op_kwargs = dropout_op_kwargs
        self.conv_bias = conv_bias
        self.kernel_sizes = kernel_sizes

    def forward(self, x):
        """Run stem (if any) and all stages; returns all stage outputs if return_skips, else a
        one-element list with the final feature map."""
        if self.stem is not None:
            x = self.stem(x)
        ret = []
        for s in self.stages:
            x = s(x)
            ret.append(x)
        if self.return_skips:
            return ret
        else:
            # NOTE: the bottleneck output is wrapped in a list for a uniform return type
            return [ret[-1]]

    def compute_conv_feature_map_size(self, input_size):
        """Total feature-map elements across stem and stages; input_size is spatial only."""
        if self.stem is not None:
            output = self.stem.compute_conv_feature_map_size(input_size)
        else:
            output = np.int64(0)

        for s in range(len(self.stages)):
            output += self.stages[s].compute_conv_feature_map_size(input_size)
            input_size = [i // j for i, j in zip(input_size, self.strides[s])]

        return output
class ConvDropoutNormReLU(nn.Module):
    """
    A single convolution optionally followed by dropout, normalization and a nonlinearity,
    executed in that order (norm and nonlin can be swapped via nonlin_first). Padding is chosen
    so that the spatial size is only changed by the stride ('same' padding for odd kernels).
    """

    def __init__(self,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 nonlin_first: bool = False
                 ):
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        stride = maybe_convert_scalar_to_list(conv_op, stride)
        kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)
        self.stride = stride

        norm_op_kwargs = {} if norm_op_kwargs is None else norm_op_kwargs
        nonlin_kwargs = {} if nonlin_kwargs is None else nonlin_kwargs

        # 'same' padding per spatial dimension
        same_padding = [(k - 1) // 2 for k in kernel_size]
        self.conv = conv_op(
            input_channels,
            output_channels,
            kernel_size,
            stride,
            padding=same_padding,
            dilation=1,
            bias=conv_bias,
        )
        modules = [self.conv]

        if dropout_op is not None:
            self.dropout = dropout_op(**dropout_op_kwargs)
            modules.append(self.dropout)

        if norm_op is not None:
            self.norm = norm_op(output_channels, **norm_op_kwargs)
            modules.append(self.norm)

        if nonlin is not None:
            self.nonlin = nonlin(**nonlin_kwargs)
            modules.append(self.nonlin)

        # optionally run the nonlinearity before the normalization (only meaningful when both exist)
        if nonlin_first and norm_op is not None and nonlin is not None:
            modules[-2], modules[-1] = modules[-1], modules[-2]

        self.all_modules = nn.Sequential(*modules)

    def forward(self, x):
        return self.all_modules(x)

    def compute_conv_feature_map_size(self, input_size):
        """Feature-map element count of this conv's output; input_size is spatial only."""
        assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \
                                                    "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                    "Give input_size=(x, y(, z))!"
        spatial = [dim // st for dim, st in zip(input_size, self.stride)]  # we always do same padding
        return np.prod([self.output_channels, *spatial], dtype=np.int64)
class StackedConvBlocks(nn.Module):
    """
    A sequence of num_convs ConvDropoutNormReLU blocks. Only the first block applies
    initial_stride (and the input->output channel change); all subsequent blocks run at stride 1.
    """

    def __init__(self,
                 num_convs: int,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: Union[int, List[int], Tuple[int, ...]],
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 initial_stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 nonlin_first: bool = False
                 ):
        """
        :param num_convs: number of stacked conv blocks
        :param conv_op: convolution class
        :param input_channels: input channels of the first conv
        :param output_channels: int, or one entry per conv (length must then equal num_convs)
        :param kernel_size: kernel size shared by all convs
        :param initial_stride: stride of the first conv only; the rest use stride 1
        :param conv_bias: whether convolutions use a bias term
        :param norm_op: normalization class, or None
        :param norm_op_kwargs: kwargs for norm_op
        :param dropout_op: dropout class, or None
        :param dropout_op_kwargs: kwargs for dropout_op
        :param nonlin: nonlinearity class, or None
        :param nonlin_kwargs: kwargs for nonlin
        :param nonlin_first: passed through to each ConvDropoutNormReLU
        """
        super().__init__()
        if not isinstance(output_channels, (tuple, list)):
            output_channels = [output_channels] * num_convs

        # first block: channel change + stride; remaining blocks: stride 1, chained channels
        layers = [
            ConvDropoutNormReLU(
                conv_op, input_channels, output_channels[0], kernel_size, initial_stride, conv_bias, norm_op,
                norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
            )
        ]
        for idx in range(1, num_convs):
            layers.append(
                ConvDropoutNormReLU(
                    conv_op, output_channels[idx - 1], output_channels[idx], kernel_size, 1, conv_bias, norm_op,
                    norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
                )
            )
        self.convs = nn.Sequential(*layers)

        self.output_channels = output_channels[-1]
        self.initial_stride = maybe_convert_scalar_to_list(conv_op, initial_stride)

    def forward(self, x):
        return self.convs(x)

    def compute_conv_feature_map_size(self, input_size):
        """Sum of feature-map sizes of all convs; input_size is spatial only, (x, y(, z))."""
        assert len(input_size) == len(self.initial_stride), "just give the image size without color/feature channels or " \
                                                            "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                            "Give input_size=(x, y(, z))!"
        total = self.convs[0].compute_conv_feature_map_size(input_size)
        downsampled = [dim // st for dim, st in zip(input_size, self.initial_stride)]
        for conv_block in self.convs[1:]:
            total += conv_block.compute_conv_feature_map_size(downsampled)
        return total
class UNetDecoder(nn.Module):
    def __init__(self,
                 encoder,
                 n_conv_per_stage: Union[int, Tuple[int, ...], List[int]],
                 deep_supervision, nonlin_first: bool = False):
        """
        This class needs the skips of the encoder as input in its forward.

        the encoder goes all the way to the bottleneck, so that's where the decoder picks up. stages in the decoder
        are sorted by order of computation, so the first stage has the lowest resolution and takes the bottleneck
        features and the lowest skip as inputs
        the decoder has two (three) parts in each stage:
        1) conv transpose to upsample the feature maps of the stage below it (or the bottleneck in case of the first stage)
        2) n_conv_per_stage conv blocks to let the two inputs get to know each other and merge
        3) (optional if deep_supervision=True) a segmentation output Todo: enable upsample logits?

        :param encoder: encoder instance; must expose output_channels, strides, conv_op,
            kernel_sizes, conv_bias, norm_op(_kwargs), dropout_op(_kwargs), nonlin(_kwargs)
        :param n_conv_per_stage: conv blocks per decoder stage (int is broadcast)
        :param deep_supervision: if True, forward returns one output per resolution
        :param nonlin_first: passed through to StackedConvBlocks
        """
        super().__init__()
        self.deep_supervision = deep_supervision
        n_stages_encoder = len(encoder.output_channels)
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * (n_stages_encoder - 1)
        assert len(n_conv_per_stage) == n_stages_encoder - 1, "n_conv_per_stage must have as many entries as we have " \
                                                              "resolution stages - 1 (n_stages in encoder - 1), " \
                                                              "here: %d" % n_stages_encoder

        transpconv_op = get_matching_convtransp(conv_op=encoder.conv_op)

        # we start with the bottleneck and work out way up
        stages = []
        transpconvs = []
        # seg_layers = []
        for s in range(1, n_stages_encoder):
            # negative indexing walks the encoder stages from deepest to shallowest
            input_features_below = encoder.output_channels[-s]
            input_features_skip = encoder.output_channels[-(s + 1)]
            stride_for_transpconv = encoder.strides[-s]
            transpconvs.append(transpconv_op(
                input_features_below, input_features_skip, stride_for_transpconv, stride_for_transpconv,
                bias=encoder.conv_bias
            ))
            # input features to conv is 2x input_features_skip (concat input_features_skip with transpconv output)
            stages.append(StackedConvBlocks(
                n_conv_per_stage[s-1], encoder.conv_op, 2 * input_features_skip, input_features_skip,
                encoder.kernel_sizes[-(s + 1)], 1, encoder.conv_bias, encoder.norm_op, encoder.norm_op_kwargs,
                encoder.dropout_op, encoder.dropout_op_kwargs, encoder.nonlin, encoder.nonlin_kwargs, nonlin_first
            ))

            # we always build the deep supervision outputs so that we can always load parameters. If we don't do this
            # then a model trained with deep_supervision=True could not easily be loaded at inference time where
            # deep supervision is not needed. It's just a convenience thing
            # seg_layers.append(encoder.conv_op(input_features_skip, num_classes, 1, 1, 0, bias=True))

        self.stages = nn.ModuleList(stages)
        self.transpconvs = nn.ModuleList(transpconvs)
        # self.seg_layers = nn.ModuleList(seg_layers)

    def forward(self, skips):
        """
        we expect to get the skips in the order they were computed, so the bottleneck should be the last entry
        :param skips: list of encoder feature maps, bottleneck last
        :return: list of decoder feature maps, highest resolution first (one element unless
            deep_supervision is True)
        """
        lres_input = skips[-1]
        seg_outputs = []
        for s in range(len(self.stages)):
            x = self.transpconvs[s](lres_input)
            # concatenate the upsampled features with the matching encoder skip along channels
            x = torch.cat((x, skips[-(s+2)]), 1)
            x = self.stages[s](x)
            seg_outputs.append(x)
            #if self.deep_supervision:
            #    seg_outputs.append(self.seg_layers[s](x))
            #elif s == (len(self.stages) - 1):
            #    seg_outputs.append(self.seg_layers[-1](x))
            lres_input = x

        # invert seg outputs so that the largest segmentation prediction is returned first
        seg_outputs = seg_outputs[::-1]

        if not self.deep_supervision:
            r = [seg_outputs[0]]
        else:
            r = seg_outputs
        return r
class UNetDecoder_Prompt(nn.Module):
    def __init__(self,
                 encoder,
                 n_conv_per_stage: Union[int, Tuple[int, ...], List[int]],
                 deep_supervision, nonlin_first: bool = False):
        """
        This class needs the skips of the encoder as input in its forward.

        the encoder goes all the way to the bottleneck, so that's where the decoder picks up. stages in the decoder
        are sorted by order of computation, so the first stage has the lowest resolution and takes the bottleneck
        features and the lowest skip as inputs
        the decoder has two (three) parts in each stage:
        1) conv transpose to upsample the feature maps of the stage below it (or the bottleneck in case of the first stage)
        2) n_conv_per_stage conv blocks to let the two inputs get to know each other and merge
        3) (optional if deep_supervision=True) a segmentation output Todo: enable upsample logits?

        Unlike UNetDecoder, each stage's conv stack takes 3 * input_features_skip channels: the
        transpconv output, the encoder skip, AND a simulated low-res prediction embedding derived
        from text/mask prompt embeddings (see forward).

        :param encoder: encoder instance providing channels/strides/op configuration
        :param n_conv_per_stage: conv blocks per decoder stage (int is broadcast)
        :param deep_supervision: if True, forward returns one logit map per resolution
        :param nonlin_first: passed through to StackedConvBlocks
        """
        super().__init__()
        self.deep_supervision = deep_supervision
        n_stages_encoder = len(encoder.output_channels)
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * (n_stages_encoder - 1)
        assert len(n_conv_per_stage) == n_stages_encoder - 1, "n_conv_per_stage must have as many entries as we have " \
                                                              "resolution stages - 1 (n_stages in encoder - 1), " \
                                                              "here: %d" % n_stages_encoder

        transpconv_op = get_matching_convtransp(conv_op=encoder.conv_op)

        # we start with the bottleneck and work out way up
        stages = []
        transpconvs = []
        for s in range(1, n_stages_encoder):
            input_features_below = encoder.output_channels[-s]
            input_features_skip = encoder.output_channels[-(s + 1)]
            stride_for_transpconv = encoder.strides[-s]
            transpconvs.append(transpconv_op(
                input_features_below, input_features_skip, stride_for_transpconv, stride_for_transpconv,
                bias=encoder.conv_bias
            ))
            # input features to conv is 3x input_features_skip here (transpconv output + skip +
            # simulated low-res prompt embedding are concatenated in forward)
            stages.append(StackedConvBlocks(
                n_conv_per_stage[s-1], encoder.conv_op, 3 * input_features_skip, input_features_skip,
                encoder.kernel_sizes[-(s + 1)], 1, encoder.conv_bias, encoder.norm_op, encoder.norm_op_kwargs,
                encoder.dropout_op, encoder.dropout_op_kwargs, encoder.nonlin, encoder.nonlin_kwargs, nonlin_first
            ))

        self.stages = nn.ModuleList(stages)
        self.transpconvs = nn.ModuleList(transpconvs)

    def forward(self, skips, mask_embedding=None, mask_embed_proj=None, mid_mask_embed_proj=None, simulated_lowres_mc_pred=None, B=2, N=32):  # [2, 32, 256, 256, 96] ... [2, 512, 16, 16, 6]
        """
        we expect to get the skips in the order they were computed, so the bottleneck should be the last entry
        :param skips: list of encoder feature maps, bottleneck last
        :param mask_embedding: flattened prompt embeddings; assumes shape (B*N, embed_dim),
            embed_dim presumably 768 given the "768 -> 128/64/48" comment below - TODO confirm
        :param mask_embed_proj: projection module applied at the final (highest-res) stage
        :param mid_mask_embed_proj: per-stage projection modules for all earlier stages
        :param simulated_lowres_mc_pred: per-stage multi-class prediction maps; the einsum below
            implies shape (B, N, h, w, d) per stage - TODO confirm against caller
        :param B: batch size used to un-flatten mask_embedding
        :param N: number of classes/prompts per sample
        :return: list of per-class logit maps, highest resolution first (one element unless
            deep_supervision is True)
        """
        lres_input = skips[-1]
        logits = []
        for s in range(len(self.stages)):
            x = self.transpconvs[s](lres_input)

            # project the prompt embedding to this stage's channel count
            if s == (len(self.stages) - 1):
                mask_embedding_s = mask_embed_proj(mask_embedding)  # 768 -> 128/64/48
            else:
                mask_embedding_s = mid_mask_embed_proj[s](mask_embedding)

            # (B*N, dim) -> (B, N, dim)
            mask_embedding_s = rearrange(mask_embedding_s, '(b n) dim -> b n dim', b=B, n=N)

            # [B,N,C] -> [B,C,N]
            mask_emb_transposed = mask_embedding_s.permute(0, 2, 1)
            # weight the per-class low-res prediction by the class embeddings to get a
            # C-channel spatial embedding: (B,N,h,w,d) x (B,C,N) -> (B,C,h,w,d)
            simulated_lowres_emb = torch.einsum('bnhwd,bcn->bchwd', simulated_lowres_mc_pred[s], mask_emb_transposed)

            x = torch.cat((x, skips[-(s+2)], simulated_lowres_emb), 1)
            x = self.stages[s](x)

            # per-class logits via inner product of features with class embeddings:
            # (B,C,h,w,d) x (B,N,C) -> (B,N,h,w,d)
            logits.append(torch.einsum('bchwd,bnc->bnhwd', x, mask_embedding_s))
            lres_input = x

        # invert logits so that the largest segmentation prediction is returned first
        logits = logits[::-1]

        if not self.deep_supervision:
            r = [logits[0]]
        else:
            r = logits
        return r
stages in the decoder + are sorted by order of computation, so the first stage has the lowest resolution and takes the bottleneck + features and the lowest skip as inputs + the decoder has two (three) parts in each stage: + 1) conv transpose to upsample the feature maps of the stage below it (or the bottleneck in case of the first stage) + 2) n_conv_per_stage conv blocks to let the two inputs get to know each other and merge + 3) (optional if deep_supervision=True) a segmentation output Todo: enable upsample logits? + :param num_classes: + :param n_conv_per_stage: + :param deep_supervision: + """ + super().__init__() + self.deep_supervision = deep_supervision + self.num_classes = num_classes + n_stages_encoder = len(encoder.output_channels) + if isinstance(n_conv_per_stage, int): + n_conv_per_stage = [n_conv_per_stage] * (n_stages_encoder - 1) + assert len(n_conv_per_stage) == n_stages_encoder - 1, "n_conv_per_stage must have as many entries as we have " \ + "resolution stages - 1 (n_stages in encoder - 1), " \ + "here: %d" % n_stages_encoder + + transpconv_op = get_matching_convtransp(conv_op=encoder.conv_op) + + # we start with the bottleneck and work out way up + stages = [] + transpconvs = [] + for s in range(1, n_stages_encoder): + input_features_below = encoder.output_channels[-s] + input_features_skip = encoder.output_channels[-(s + 1)] + stride_for_transpconv = encoder.strides[-s] + transpconvs.append(transpconv_op( + input_features_below, input_features_skip, stride_for_transpconv, stride_for_transpconv, + bias=encoder.conv_bias + )) + # input features to conv is 2x input_features_skip (concat input_features_skip with transpconv output) + stages.append(StackedConvBlocks( + n_conv_per_stage[s-1], encoder.conv_op, 2 * input_features_skip, input_features_skip, + encoder.kernel_sizes[-(s + 1)], 1, encoder.conv_bias, encoder.norm_op, encoder.norm_op_kwargs, + encoder.dropout_op, encoder.dropout_op_kwargs, encoder.nonlin, encoder.nonlin_kwargs, nonlin_first + )) 
+ + self.stages = nn.ModuleList(stages) + self.transpconvs = nn.ModuleList(transpconvs) + self.seg_layer = encoder.conv_op(input_features_skip, num_classes, 1, 1, 0, bias=True) + + def forward(self, skips): + """ + we expect to get the skips in the order they were computed, so the bottleneck should be the last entry + :param skips: + :return: + """ + lres_input = skips[-1] + seg_outputs = [] + for s in range(len(self.stages)): + x = self.transpconvs[s](lres_input) + x = torch.cat((x, skips[-(s+2)]), 1) + x = self.stages[s](x) + seg_outputs.append(x) + lres_input = x + + # invert seg outputs so that the largest segmentation prediction is returned first + seg_outputs = seg_outputs[::-1] + + output = self.seg_layer(seg_outputs[0]) # B C H W D + + return output + diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures/initialization/__init__.py b/model/dynamic-network-architectures-main/dynamic_network_architectures/initialization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures/initialization/__pycache__/__init__.cpython-310.pyc b/model/dynamic-network-architectures-main/dynamic_network_architectures/initialization/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..785fa45db0ddb4abc04fb3ab6d94878a82691c06 Binary files /dev/null and b/model/dynamic-network-architectures-main/dynamic_network_architectures/initialization/__pycache__/__init__.cpython-310.pyc differ diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures/initialization/__pycache__/weight_init.cpython-310.pyc b/model/dynamic-network-architectures-main/dynamic_network_architectures/initialization/__pycache__/weight_init.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8314de14773b7cf71fbb815cc4e18e8616d5a7a 
Binary files /dev/null and b/model/dynamic-network-architectures-main/dynamic_network_architectures/initialization/__pycache__/weight_init.cpython-310.pyc differ diff --git a/model/dynamic-network-architectures-main/dynamic_network_architectures/initialization/weight_init.py b/model/dynamic-network-architectures-main/dynamic_network_architectures/initialization/weight_init.py new file mode 100644 index 0000000000000000000000000000000000000000..7c164e92ccc5ecd9b113a4daee0bba0112c8ad4b --- /dev/null +++ b/model/dynamic-network-architectures-main/dynamic_network_architectures/initialization/weight_init.py @@ -0,0 +1,31 @@ +from torch import nn + +from dynamic_network_architectures.building_blocks.residual import BasicBlockD + + +class InitWeights_He(object): + def __init__(self, neg_slope: float = 1e-2): + self.neg_slope = neg_slope + + def __call__(self, module): + if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d): + module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) + if module.bias is not None: + module.bias = nn.init.constant_(module.bias, 0) + + +class InitWeights_XavierUniform(object): + def __init__(self, gain: int = 1): + self.gain = gain + + def __call__(self, module): + if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d): + module.weight = nn.init.xavier_uniform_(module.weight, self.gain) + if module.bias is not None: + module.bias = nn.init.constant_(module.bias, 0) + + +def init_last_bn_before_add_to_0(module): + if isinstance(module, BasicBlockD): + module.conv2.norm.weight = nn.init.constant_(module.conv2.norm.weight, 0) + module.conv2.norm.bias = nn.init.constant_(module.conv2.norm.bias, 0) diff --git a/model/dynamic-network-architectures-main/imgs/Logos/DKFZ_Logo.png b/model/dynamic-network-architectures-main/imgs/Logos/DKFZ_Logo.png 
new file mode 100644 index 0000000000000000000000000000000000000000..e7f277f3cada65c5e0ed721ca0cd1f81df77465c Binary files /dev/null and b/model/dynamic-network-architectures-main/imgs/Logos/DKFZ_Logo.png differ diff --git a/model/dynamic-network-architectures-main/imgs/Logos/HI_Logo.png b/model/dynamic-network-architectures-main/imgs/Logos/HI_Logo.png new file mode 100644 index 0000000000000000000000000000000000000000..2d29d69625be75bb79534a023a722a33a2e3683e --- /dev/null +++ b/model/dynamic-network-architectures-main/imgs/Logos/HI_Logo.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:50a2300e6ddbd62bc1524fe3b3382a2b7a654cdecbb87269f0671b43fd6dbc88 +size 450178 diff --git a/model/dynamic-network-architectures-main/setup.py b/model/dynamic-network-architectures-main/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..7cd25b4903e6aa98f26f242fdf48904425ac5781 --- /dev/null +++ b/model/dynamic-network-architectures-main/setup.py @@ -0,0 +1,14 @@ +from setuptools import setup, find_namespace_packages + +setup(name='dynamic_network_architectures', + packages=find_namespace_packages(include=["dynamic_network_architectures", "dynamic_network_architectures.*"]), + version='0.2', + description='none', + author='Fabian Isensee', + author_email='f.isensee@dkfz.de', + license='private', + install_requires=[ + "torch>=1.6.0a", + "numpy" + ], + zip_safe=False) diff --git a/model/knowledge_encoder.py b/model/knowledge_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1eaea8a7c2cbb9cb5d75b2a5bd50f0365c5b9068 --- /dev/null +++ b/model/knowledge_encoder.py @@ -0,0 +1,27 @@ +import torch.nn as nn + +from .text_tower import Text_Tower + +class Knowledge_Encoder(nn.Module): + def __init__(self, biolord_checkpoint='FremyCompany/BioLORD-2023-C'): + super().__init__() + # LP + self.text_tower = Text_Tower(biolord_checkpoint) + self.projection_layer = nn.Sequential( + nn.Linear(768, 768), + nn.GELU(), + 
nn.Linear(768, 768) + ) + self.modality_embed = nn.Embedding(5, 768) + + def forward(self, text, modality, device): + text_feature = self.text_tower(text, device) + proj_text_feature = self.projection_layer(text_feature) + + modality_feature = self.modality_embed(modality) + + text_feature = text_feature + modality_feature + proj_text_feature = proj_text_feature + modality_feature + + # return text_feature, proj_text_feature + return proj_text_feature \ No newline at end of file diff --git a/model/maskformer.py b/model/maskformer.py new file mode 100644 index 0000000000000000000000000000000000000000..87f6a1ae7dd90f23089e20d8692756c466b1685c --- /dev/null +++ b/model/maskformer.py @@ -0,0 +1,400 @@ +import random +from typing import Tuple, Union, List + +import torch.nn as nn +import torch.nn.functional as F +import torch +from einops import rearrange, repeat, reduce +from positional_encodings.torch_encodings import PositionalEncoding3D +from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet +from dynamic_network_architectures.building_blocks.simple_conv_blocks import StackedConvBlocks +from dynamic_network_architectures.initialization.weight_init import InitWeights_He + +from .transformer_decoder import TransformerDecoder,TransformerDecoderLayer +from .SwinUNETR import SwinUNETR + +class Maskformer(nn.Module): + def __init__(self, vision_backbone='UNET', input_channels=1, image_size=[288, 288, 96], patch_size=[32, 32, 32], deep_supervision=False): + """ + Args: + vision_backbone (str, optional): visual backbone. Defaults to UNET. + image_size (list, optional): image size. Defaults to [288, 288, 96]. + patch_size (list, optional): maxium downsample ratio of the bottleneck feature map. Defaults to [32, 32, 32]. + deep_supervision (bool, optional): seg results from mid layers of decoder. Defaults to False. 
+ """ + super().__init__() + image_height, image_width, frames = image_size + self.hw_patch_size = patch_size[0] + self.frame_patch_size = patch_size[-1] + + self.deep_supervision = deep_supervision + + # backbone can be any multi-scale enc-dec vision backbone + # the enc outputs multi-scale latent features + # the dec outputs multi-scale per-pixel features + self.backbone = { + 'SwinUNETR' : SwinUNETR( + img_size=[288, 288, 96], # 48, 48, 96, 192, 384, 768 + in_channels=3, + feature_size=48, + drop_rate=0.0, + attn_drop_rate=0.0, + dropout_path_rate=0.0, + use_checkpoint=False, + ), + 'UNET' : PlainConvUNet(input_channels=input_channels, #3, + n_stages=6, + features_per_stage=(64, 64, 128, 256, 512, 768), + conv_op=nn.Conv3d, + kernel_sizes=3, + strides=(1, 2, 2, 2, 2, 2), + n_conv_per_stage=(2, 2, 2, 2, 2, 2), + n_conv_per_stage_decoder=(2, 2, 2, 2, 2), + conv_bias=True, + norm_op=nn.InstanceNorm3d, + norm_op_kwargs={'eps': 1e-5, 'affine': True}, + dropout_op=None, + dropout_op_kwargs=None, + nonlin=nn.LeakyReLU, + nonlin_kwargs=None, + deep_supervision=deep_supervision, + nonlin_first=False + ), + 'UNET-L' : PlainConvUNet(input_channels=3, + n_stages=6, + features_per_stage=(128, 128, 256, 512, 1024, 1536), + conv_op=nn.Conv3d, + kernel_sizes=3, + strides=(1, 2, 2, 2, 2, 2), + n_conv_per_stage=(3, 3, 3, 3, 3, 3), + n_conv_per_stage_decoder=(3, 3, 3, 3, 3), + conv_bias=True, + norm_op=nn.InstanceNorm3d, + norm_op_kwargs={'eps': 1e-5, 'affine': True}, + dropout_op=None, + dropout_op_kwargs=None, + nonlin=nn.LeakyReLU, + nonlin_kwargs=None, + deep_supervision=deep_supervision, + nonlin_first=False + ), + 'UNET-H' : PlainConvUNet(input_channels=3, + n_stages=6, + features_per_stage=(256, 256, 512, 1024, 1536, 2048), + conv_op=nn.Conv3d, + kernel_sizes=3, + strides=(1, 2, 2, 2, 2, 2), + n_conv_per_stage=(3, 3, 3, 3, 3, 3), + n_conv_per_stage_decoder=(3, 3, 3, 3, 3), + conv_bias=True, + norm_op=nn.InstanceNorm3d, + norm_op_kwargs={'eps': 1e-5, 'affine': True}, + 
dropout_op=None, + dropout_op_kwargs=None, + nonlin=nn.LeakyReLU, + nonlin_kwargs=None, + deep_supervision=deep_supervision, + nonlin_first=False + ), + 'UNET-Res' : ResidualEncoderUNet( + input_channels=input_channels, + n_stages=6, + features_per_stage=[32, 64, 128, 256, 320, 320], + conv_op=nn.Conv3d, + kernel_sizes=[[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]], + strides=[[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], + n_blocks_per_stage=[1, 3, 4, 6, 6, 6], + n_conv_per_stage_decoder=[1, 1, 1, 1, 1], + conv_bias=True, + norm_op=nn.InstanceNorm3d, + norm_op_kwargs={"eps": 1e-5, "affine": True}, + nonlin=nn.LeakyReLU, + nonlin_kwargs={"inplace": True}, + deep_supervision=deep_supervision, + ) + }[vision_backbone] + + self.backbone.apply(InitWeights_He(1e-2)) + + # fixed to text encoder out dim + if vision_backbone == 'UNET-H': + query_dim = 1536 + elif vision_backbone == 'UNET-Res': + query_dim = 320 + else: + query_dim = 768 + + # all backbones are 6-depth, thus the first 5 scale latent feature outputs need to be down-sampled + self.avg_pool_ls = [ + nn.AvgPool3d(32, 32), + nn.AvgPool3d(16, 16), + nn.AvgPool3d(8, 8), + nn.AvgPool3d(4, 4), + nn.AvgPool3d(2, 2), + ] + + # multi-scale latent feature are projected to query_dim before query decoder + self.projection_layer = { + 'SwinUNETR' : nn.Sequential( + nn.Linear(1536, 768), + nn.GELU(), + nn.Linear(768, query_dim), + nn.GELU() + ), + 'UNET' : nn.Sequential( + nn.Linear(1792, 768), + nn.GELU(), + nn.Linear(768, query_dim), + nn.GELU() + ), + 'UNET-L' : nn.Sequential( + nn.Linear(3584, 1536), # 128, 128, 256, 512, 1024, 1536 --> 3584 --> 768 + nn.GELU(), + nn.Linear(1536, query_dim), + nn.GELU() + ), + 'UNET-H' : nn.Sequential( + nn.Linear(5632, 3072), + nn.GELU(), + nn.Linear(3072, query_dim), + nn.GELU() + ), + 'UNET-Res' : nn.Sequential( + nn.Linear(1120, 320), + nn.GELU(), + nn.Linear(320, query_dim), + nn.GELU() + ) + }[vision_backbone] + + # positional encoding + 
pos_embedding = PositionalEncoding3D(query_dim)(torch.zeros(1, (image_height//self.hw_patch_size), (image_width//self.hw_patch_size), (frames//self.frame_patch_size), query_dim)) # b h/p w/p d/p dim + self.pos_embedding = rearrange(pos_embedding, 'b h w d c -> (h w d) b c') # n b dim + + # (fused latent embeddings + pe) x query prompts + decoder_layer = TransformerDecoderLayer(d_model=query_dim, nhead=8, normalize_before=True) + decoder_norm = nn.LayerNorm(query_dim) + self.transformer_decoder = TransformerDecoder(decoder_layer=decoder_layer, num_layers=6, norm=decoder_norm) + + if query_dim != 768: + self.query_proj = nn.Sequential( + nn.Linear(768, query_dim), + nn.GELU(), + nn.Linear(query_dim, query_dim), + nn.GELU() + ) + else: + self.query_proj = nn.Identity() + + # mask embedding are projected to perpixel_dim + # mid stage output (only consider the last 3 mid layers of decoder, i.e. feature maps with resolution /2 /4 /8) + if self.deep_supervision: + feature_per_stage = { + 'SwinUNETR':[48, 96, 192], + 'UNET':[64, 128, 256], + 'UNET-L':[128, 256, 512], + 'UNET-H':[256, 512, 1024], + 'UNET-Res':[64, 128, 256] + }[vision_backbone] + mid_dim = { + 'SwinUNETR':[256, 384, 512], + 'UNET':[256, 384, 512], + 'UNET-L':[384, 512, 512], + 'UNET-H':[768, 1024, 1024], + 'UNET-Res':[256, 320, 320] + }[vision_backbone] + self.mid_mask_embed_proj = [] + for hidden_dim, per_pixel_dim in zip(mid_dim, feature_per_stage): + self.mid_mask_embed_proj.append( + nn.Sequential( + nn.Linear(query_dim, hidden_dim), + nn.GELU(), + nn.Linear(hidden_dim, per_pixel_dim), + nn.GELU(), + ), + ) + self.mid_mask_embed_proj = nn.ModuleList(self.mid_mask_embed_proj) + + # largest output + mid_dim, per_pixel_dim = { + 'SwinUNETR' : [256, 48], + 'UNET' : [256, 64], + 'UNET-L' : [384, 128], + 'UNET-H' : [768, 256], + 'UNET-Res' : [128, 32] + }[vision_backbone] + self.mask_embed_proj = nn.Sequential( + nn.Linear(query_dim, mid_dim), + nn.GELU(), + nn.Linear(mid_dim, per_pixel_dim), + nn.GELU(), + ) 
+ + self.fusion_conv = StackedConvBlocks( + 1, nn.Conv3d, 2 * per_pixel_dim, per_pixel_dim, + 3, 1, True, nn.InstanceNorm3d, {'eps': 1e-5, 'affine': True}, + None, None, nn.LeakyReLU, None, False) + + def enhance_with_coarse_pred(self, pixel_emb, mask_emb, coarse_pred): + """ + Enhance pixel embeddings with coarse prediction information + + Args: + pixel_emb (torch.tensor): B,C,H,W,D per-pixel embeddings + mask_emb (torch.tensor): B,N,C mask embeddings + coarse_pred (torch.tensor): B,N,H,W,D coarse prediction probabilities + + Returns: + torch.tensor: enhanced pixel embeddings B,C,H,W,D + """ + + # [B,N,C] -> [B,C,N] + mask_emb_transposed = mask_emb.permute(0, 2, 1) + + enhanced_emb = torch.einsum('bnhwd,bcn->bchwd', coarse_pred, mask_emb_transposed) + + combined = torch.cat([pixel_emb, enhanced_emb], dim=1) + + enhanced_pixel_emb = self.fusion_conv(combined) # B,C,H,W,D + + return enhanced_pixel_emb + + def vision_backbone_forward(self, image_input): + """ + Visual backbone forward + + Args: + image_input (torch.tensor): C,H,W,D (C=1) + + Returns: + image_embedding (torch.tensor): multiscale image features from encoder layers. N,B,d + pos (torch.tensor): position encoding. N,B,d + per_pixel_embedding_ls (List of torch.tensor): perpixel embeddings from decoder layers. B,d,H,W,D + """ + + # Image Encoder and Pixel Decoder + latent_embedding_ls, per_pixel_embedding_ls = self.backbone(image_input) # B Dim H/P W/P D/P + + # avg pooling each multiscale feature to H/P W/P D/P + image_embedding = [] + for latent_embedding, avg_pool in zip(latent_embedding_ls, self.avg_pool_ls): + tmp = avg_pool(latent_embedding) + image_embedding.append(tmp) # B ? 
H/P W/P D/P + image_embedding.append(latent_embedding_ls[-1]) + + # aggregate multiscale features into image embedding (and proj to align with query dim) + image_embedding = torch.cat(image_embedding, dim=1) + image_embedding = rearrange(image_embedding, 'b d h w depth -> b h w depth d') + image_embedding = self.projection_layer(image_embedding) # B H/P W/P D/P Dim + image_embedding = rearrange(image_embedding, 'b h w d dim -> (h w d) b dim') # (H/P W/P D/P) B Dim + + # add pe to image embedding + pos = self.pos_embedding.to(latent_embedding_ls[-1].device) # (H/P W/P D/P) B Dim + + return image_embedding, pos, per_pixel_embedding_ls + + def infer_forward(self, q, image_embedding, pos, per_pixel_embedding_ls, simulated_lowres_mc_pred=None): + """ + infer batches of queries (a list) on a batch of patches + + Args: + q (List of torch.tensor): N,d + simulated_lowres_mc_pred (torch.tensor, optional): B,N,H,W,D low-res multi-channel prediction + + Returns: + logits (torch.tensor): concat seg output of all queries. 
B,N_all,H,W,D + """ + _, B, _ = image_embedding.shape + + # query decoder + N,_ = q.shape # N is the num of query + q = repeat(q, 'n dim -> n b dim', b=B) # N B Dim NOTE:By default, attention in torch is not batch_first + q = self.query_proj(q) + mask_embedding,_ = self.transformer_decoder(q, image_embedding, pos = pos) # N B Dim + mask_embedding = rearrange(mask_embedding, 'n b dim -> (b n) dim') # (B N) Dim + + # Dot product + mask_embedding = self.mask_embed_proj(mask_embedding) # 768 -> 128/64/48 + mask_embedding = rearrange(mask_embedding, '(b n) dim -> b n dim', b=B, n=N) + per_pixel_embedding = per_pixel_embedding_ls[0] # decoder最后一层的输出 + + # Enhance features with low-res multi-channel prediction if available + if simulated_lowres_mc_pred is not None: + per_pixel_embedding = self.enhance_with_coarse_pred( + per_pixel_embedding, + mask_embedding, + simulated_lowres_mc_pred) + + logits = torch.einsum('bchwd,bnc->bnhwd', per_pixel_embedding, mask_embedding) # bnhwd + + return logits + + def train_forward(self, queries, image_embedding, pos, per_pixel_embedding_ls, simulated_lowres_mc_pred=None): + """ + Args: + queries (torch.tensor): B,N,d + simulated_lowres_mc_pred (torch.tensor, optional): B,N,H,W,D low-res multi-channel prediction + + Returns: + logits (List of torch.tensor): list of seg results. 
B,N,H,W,D + """ + _, B, _ = image_embedding.shape + + # query decoder + _, N, _ = queries.shape # N is the num of query + queries = rearrange(queries, 'b n dim -> n b dim') # N B Dim NOTE:By default, attention in torch is not batch_first + queries = self.query_proj(queries) + mask_embedding,_ = self.transformer_decoder(queries, image_embedding, pos = pos) # N B Dim + mask_embedding = rearrange(mask_embedding, 'n b dim -> (b n) dim') # (B N) Dim + + # Dot product + last_mask_embedding = self.mask_embed_proj(mask_embedding) # 768 -> 128/64/48 + last_mask_embedding = rearrange(last_mask_embedding, '(b n) dim -> b n dim', b=B, n=N) + per_pixel_embedding = per_pixel_embedding_ls[0] # decoder最后一层的输出 + + # Enhance features with low-res multi-channel prediction if available + if simulated_lowres_mc_pred is not None: + per_pixel_embedding = self.enhance_with_coarse_pred( + per_pixel_embedding, + last_mask_embedding, + simulated_lowres_mc_pred) + + logits = [torch.einsum('bchwd,bnc->bnhwd', per_pixel_embedding, last_mask_embedding)] + + # deep supervision + if self.deep_supervision: + for mask_embed_proj, per_pixel_embedding in zip(self.mid_mask_embed_proj, per_pixel_embedding_ls[1:]): # H/2 --> H/16 + mid_mask_embedding = mask_embed_proj(mask_embedding) + mid_mask_embedding = rearrange(mid_mask_embedding, '(b n) dim -> b n dim', b=B, n=N) + + logits.append(torch.einsum('bchwd,bnc->bnhwd', per_pixel_embedding, mid_mask_embedding)) + + return logits + + def forward(self, queries, image_input, simulated_lowres_sc_pred=None, simulated_lowres_mc_pred=None, train_mode=True): + # Handle single-channel low-res prediction if provided + if simulated_lowres_sc_pred is not None: + # concatenate image and simulated low-res single channel prediction + image_input = torch.cat([image_input, simulated_lowres_sc_pred], dim=1) # b2whd + + # get vision features + image_embedding, pos, per_pixel_embedding_ls = self.vision_backbone_forward(image_input) + + # Train Forward 
----------------------------------------------------------------------- + if train_mode: + logits = self.train_forward(queries, image_embedding, pos, per_pixel_embedding_ls, simulated_lowres_mc_pred) + + # Infer / Evaluate Forward ------------------------------------------------------------ + else: + del image_input + torch.cuda.empty_cache() + logits = self.infer_forward(queries, image_embedding, pos, per_pixel_embedding_ls, simulated_lowres_mc_pred) + + return logits + +if __name__ == '__main__': + model = Maskformer().cuda() + image = torch.rand((1, 3, 288, 288, 96)).cuda() + query = torch.rand((2, 10, 768)).cuda() + segmentations = model(query, image) + print(segmentations.shape) \ No newline at end of file diff --git a/model/maskformer_multi_scale.py b/model/maskformer_multi_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..b02cd29fc39e8b1d2a1d828516e41b6c6bbfc530 --- /dev/null +++ b/model/maskformer_multi_scale.py @@ -0,0 +1,290 @@ +import random +from typing import Tuple, Union, List + +import torch.nn as nn +import torch.nn.functional as F +import torch +from einops import rearrange, repeat, reduce +from positional_encodings.torch_encodings import PositionalEncoding3D +from dynamic_network_architectures.architectures.unet import PlainConvUNet, ResidualEncoderUNet +from dynamic_network_architectures.building_blocks.simple_conv_blocks import StackedConvBlocks +from dynamic_network_architectures.initialization.weight_init import InitWeights_He + +from .transformer_decoder import TransformerDecoder,TransformerDecoderLayer +from .SwinUNETR import SwinUNETR + +class Maskformer(nn.Module): + def __init__(self, vision_backbone='UNET', input_channels=1, image_size=[288, 288, 96], patch_size=[32, 32, 32], deep_supervision=False): + """ + Args: + vision_backbone (str, optional): visual backbone. Defaults to UNET. + image_size (list, optional): image size. Defaults to [288, 288, 96]. 
+ patch_size (list, optional): maxium downsample ratio of the bottleneck feature map. Defaults to [32, 32, 32]. + deep_supervision (bool, optional): seg results from mid layers of decoder. Defaults to False. + """ + super().__init__() + image_height, image_width, frames = image_size + self.hw_patch_size = patch_size[0] + self.frame_patch_size = patch_size[-1] + + self.deep_supervision = deep_supervision + + # backbone can be any multi-scale enc-dec vision backbone + # the enc outputs multi-scale latent features + # the dec outputs multi-scale per-pixel features + self.backbone = { + 'SwinUNETR' : SwinUNETR( + img_size=[288, 288, 96], # 48, 48, 96, 192, 384, 768 + in_channels=3, + feature_size=48, + drop_rate=0.0, + attn_drop_rate=0.0, + dropout_path_rate=0.0, + use_checkpoint=False, + ), + 'UNET' : PlainConvUNet(input_channels=input_channels, #3, + n_stages=6, + features_per_stage=(64, 64, 128, 256, 512, 768), + conv_op=nn.Conv3d, + kernel_sizes=3, + strides=(1, 2, 2, 2, 2, 2), + n_conv_per_stage=(2, 2, 2, 2, 2, 2), + n_conv_per_stage_decoder=(2, 2, 2, 2, 2), + conv_bias=True, + norm_op=nn.InstanceNorm3d, + norm_op_kwargs={'eps': 1e-5, 'affine': True}, + dropout_op=None, + dropout_op_kwargs=None, + nonlin=nn.LeakyReLU, + nonlin_kwargs=None, + deep_supervision=deep_supervision, + nonlin_first=False + ), + 'UNET-L' : PlainConvUNet(input_channels=3, + n_stages=6, + features_per_stage=(128, 128, 256, 512, 1024, 1536), + conv_op=nn.Conv3d, + kernel_sizes=3, + strides=(1, 2, 2, 2, 2, 2), + n_conv_per_stage=(3, 3, 3, 3, 3, 3), + n_conv_per_stage_decoder=(3, 3, 3, 3, 3), + conv_bias=True, + norm_op=nn.InstanceNorm3d, + norm_op_kwargs={'eps': 1e-5, 'affine': True}, + dropout_op=None, + dropout_op_kwargs=None, + nonlin=nn.LeakyReLU, + nonlin_kwargs=None, + deep_supervision=deep_supervision, + nonlin_first=False + ), + 'UNET-H' : PlainConvUNet(input_channels=3, + n_stages=6, + features_per_stage=(256, 256, 512, 1024, 1536, 2048), + conv_op=nn.Conv3d, + kernel_sizes=3, + 
strides=(1, 2, 2, 2, 2, 2), + n_conv_per_stage=(3, 3, 3, 3, 3, 3), + n_conv_per_stage_decoder=(3, 3, 3, 3, 3), + conv_bias=True, + norm_op=nn.InstanceNorm3d, + norm_op_kwargs={'eps': 1e-5, 'affine': True}, + dropout_op=None, + dropout_op_kwargs=None, + nonlin=nn.LeakyReLU, + nonlin_kwargs=None, + deep_supervision=deep_supervision, + nonlin_first=False + ), + 'UNET-Res' : ResidualEncoderUNet( + input_channels=input_channels, + n_stages=6, + features_per_stage=[32, 64, 128, 256, 320, 320], + conv_op=nn.Conv3d, + kernel_sizes=[[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]], + strides=[[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], + n_blocks_per_stage=[1, 3, 4, 6, 6, 6], + n_conv_per_stage_decoder=[1, 1, 1, 1, 1], + conv_bias=True, + norm_op=nn.InstanceNorm3d, + norm_op_kwargs={"eps": 1e-5, "affine": True}, + nonlin=nn.LeakyReLU, + nonlin_kwargs={"inplace": True}, + deep_supervision=deep_supervision, + ) + }[vision_backbone] + + self.backbone.apply(InitWeights_He(1e-2)) + + # fixed to text encoder out dim + if vision_backbone == 'UNET-H': + query_dim = 1536 + elif vision_backbone == 'UNET-Res': + query_dim = 320 + else: + query_dim = 768 + + # all backbones are 6-depth, thus the first 5 scale latent feature outputs need to be down-sampled + self.avg_pool_ls = [ + nn.AvgPool3d(32, 32), + nn.AvgPool3d(16, 16), + nn.AvgPool3d(8, 8), + nn.AvgPool3d(4, 4), + nn.AvgPool3d(2, 2), + ] + + # multi-scale latent feature are projected to query_dim before query decoder + self.projection_layer = { + 'SwinUNETR' : nn.Sequential( + nn.Linear(1536, 768), + nn.GELU(), + nn.Linear(768, query_dim), + nn.GELU() + ), + 'UNET' : nn.Sequential( + nn.Linear(1792, 768), + nn.GELU(), + nn.Linear(768, query_dim), + nn.GELU() + ), + 'UNET-L' : nn.Sequential( + nn.Linear(3584, 1536), # 128, 128, 256, 512, 1024, 1536 --> 3584 --> 768 + nn.GELU(), + nn.Linear(1536, query_dim), + nn.GELU() + ), + 'UNET-H' : nn.Sequential( + nn.Linear(5632, 3072), + nn.GELU(), + 
nn.Linear(3072, query_dim), + nn.GELU() + ), + 'UNET-Res' : nn.Sequential( + nn.Linear(1120, 320), + nn.GELU(), + nn.Linear(320, query_dim), + nn.GELU() + ) + }[vision_backbone] + + # positional encoding + pos_embedding = PositionalEncoding3D(query_dim)(torch.zeros(1, (image_height//self.hw_patch_size), (image_width//self.hw_patch_size), (frames//self.frame_patch_size), query_dim)) # b h/p w/p d/p dim + self.pos_embedding = rearrange(pos_embedding, 'b h w d c -> (h w d) b c') # n b dim + + # (fused latent embeddings + pe) x query prompts + decoder_layer = TransformerDecoderLayer(d_model=query_dim, nhead=8, normalize_before=True) + decoder_norm = nn.LayerNorm(query_dim) + self.transformer_decoder = TransformerDecoder(decoder_layer=decoder_layer, num_layers=6, norm=decoder_norm) + + if query_dim != 768: + self.query_proj = nn.Sequential( + nn.Linear(768, query_dim), + nn.GELU(), + nn.Linear(query_dim, query_dim), + nn.GELU() + ) + else: + self.query_proj = nn.Identity() + + # mask embedding are projected to perpixel_dim + # mid stage output (only consider the last 3 mid layers of decoder, i.e. 
feature maps with resolution /2 /4 /8) + if self.deep_supervision: + feature_per_stage = { + 'SwinUNETR':[48, 96, 192], + 'UNET':[64, 128, 256], + 'UNET-L':[128, 256, 512], + 'UNET-H':[256, 512, 1024], + 'UNET-Res':[64, 128, 256] + }[vision_backbone] + mid_dim = { + 'SwinUNETR':[256, 384, 512], + 'UNET':[256, 384, 512], + 'UNET-L':[384, 512, 512], + 'UNET-H':[768, 1024, 1024], + 'UNET-Res':[256, 320, 320] + }[vision_backbone] + self.mid_mask_embed_proj = [] + for hidden_dim, per_pixel_dim in zip(mid_dim, feature_per_stage): + self.mid_mask_embed_proj.append( + nn.Sequential( + nn.Linear(query_dim, hidden_dim), + nn.GELU(), + nn.Linear(hidden_dim, per_pixel_dim), + nn.GELU(), + ), + ) + self.mid_mask_embed_proj = nn.ModuleList(self.mid_mask_embed_proj) + + # largest output + mid_dim, per_pixel_dim = { + 'SwinUNETR' : [256, 48], + 'UNET' : [256, 64], + 'UNET-L' : [384, 128], + 'UNET-H' : [768, 256], + 'UNET-Res' : [128, 32] + }[vision_backbone] + self.mask_embed_proj = nn.Sequential( + nn.Linear(query_dim, mid_dim), + nn.GELU(), + nn.Linear(mid_dim, per_pixel_dim), + nn.GELU(), + ) + + def forward(self, queries, image_input, simulated_lowres_mc_pred): + + # Image Encoder + latent_embedding_ls = self.backbone.encoder(image_input) # [2, 32, 256, 256, 96] ... [2, 768, 8, 8, 3] + + # avg pooling each multiscale feature to H/P W/P D/P + image_embedding = [] + for latent_embedding, avg_pool in zip(latent_embedding_ls, self.avg_pool_ls): + tmp = avg_pool(latent_embedding) + image_embedding.append(tmp) # B ? 
H/P W/P D/P + image_embedding.append(latent_embedding_ls[-1]) + + # aggregate multiscale features into image embedding (and proj to align with query dim) + image_embedding = torch.cat(image_embedding, dim=1) + image_embedding = rearrange(image_embedding, 'b d h w depth -> b h w depth d') + image_embedding = self.projection_layer(image_embedding) # B H/P W/P D/P Dim + image_embedding = rearrange(image_embedding, 'b h w d dim -> (h w d) b dim') # (H/P W/P D/P) B Dim + + # add pe to image embedding + pos = self.pos_embedding.to(latent_embedding_ls[-1].device) # (H/P W/P D/P) B Dim + + _, B, _ = image_embedding.shape + + # query decoder + _, N, _ = queries.shape # N is the num of query + queries = rearrange(queries, 'b n dim -> n b dim') # N B Dim NOTE:By default, attention in torch is not batch_first + queries = self.query_proj(queries) + mask_embedding,_ = self.transformer_decoder(queries, image_embedding, pos = pos) # N B Dim + mask_embedding = rearrange(mask_embedding, 'n b dim -> (b n) dim') # (B N) Dim + + logits = self.backbone.decoder(latent_embedding_ls, mask_embedding=mask_embedding, mask_embed_proj=self.mask_embed_proj, mid_mask_embed_proj=self.mid_mask_embed_proj, simulated_lowres_mc_pred=simulated_lowres_mc_pred, B=B, N=N) # [2, 32, 256, 256, 96] ... 
[2, 512, 16, 16, 6] + + return logits + + def forward(self, queries, image_input, simulated_lowres_sc_pred=None, simulated_lowres_mc_pred=None, train_mode=True): + # Handle single-channel low-res prediction if provided + if simulated_lowres_sc_pred is not None: + # concatenate image and simulated low-res single channel prediction + image_input = torch.cat([image_input, simulated_lowres_sc_pred], dim=1) # b2whd + + # Train Forward ----------------------------------------------------------------------- + if train_mode: + logits = self.seg_forward(queries, image_input, simulated_lowres_mc_pred): + + # Infer / Evaluate Forward ------------------------------------------------------------ + else: + del image_input + torch.cuda.empty_cache() + logits = self.seg_forward(queries, image_input, simulated_lowres_mc_pred): + + return logits + +if __name__ == '__main__': + model = Maskformer().cuda() + image = torch.rand((1, 3, 288, 288, 96)).cuda() + query = torch.rand((2, 10, 768)).cuda() + segmentations = model(query, image) + print(segmentations.shape) \ No newline at end of file diff --git a/model/med_cpt.py b/model/med_cpt.py new file mode 100644 index 0000000000000000000000000000000000000000..1d3d4b7ef10bfc49de22277b3c6c1f1e69920739 --- /dev/null +++ b/model/med_cpt.py @@ -0,0 +1,26 @@ +import torch.nn as nn +import torch + +from transformers import AutoModel, AutoTokenizer + +class MedCPT(nn.Module): + def __init__(self, cpt_checkpoint='ncbi/MedCPT-Query-Encoder'): + super().__init__() + self.tokenizer = AutoTokenizer.from_pretrained(cpt_checkpoint) + self.model = AutoModel.from_pretrained(cpt_checkpoint) + self.modality_embed = nn.Embedding(4, 768) + + def forward(self, text, modality): + encoded = self.tokenizer( + text, + truncation=True, + padding=True, + return_tensors='pt', + max_length=64, + ).to(device=torch.cuda.current_device()) + + text_feature = self.model(**encoded).last_hidden_state[:, 0, :] + modality_feature = self.modality_embed(modality) + text_feature 
+= modality_feature + + return text_feature diff --git a/model/text_encoder.py b/model/text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6badef9dab018f196497d3a480674c078b5bc9d8 --- /dev/null +++ b/model/text_encoder.py @@ -0,0 +1,138 @@ +import torch +import numpy as np +import os +import torch.nn as nn + +from einops import rearrange, reduce, repeat + +from .knowledge_encoder import Knowledge_Encoder +from .med_cpt import MedCPT +from .base_bert import BaseBERT +from train.dist import is_master + +def compute_average_gradient(module): + # 初始化梯度总和和参数计数 + total_gradient = 0.0 + total_params = 0 + + # 遍历module的所有参数 + for param in module.parameters(): + if param.grad is not None: + # 累加此参数的梯度绝对值 + total_gradient += param.grad.abs().mean().item() + total_params += 1 + + # 计算平均梯度 + if total_params > 0: + average_gradient = total_gradient / total_params + else: + average_gradient = None + + return average_gradient + +class Text_Encoder(nn.Module): + def __init__(self, + text_encoder, + checkpoint=None, + # other params + open_bert_layer=12, + open_modality_embed=False, + partial_load=False, + gpu_id=None, + device=None): + super().__init__() + + self.device = device + + # choose text encoder + class_name = { + 'ours': Knowledge_Encoder, + 'medcpt': MedCPT, + 'basebert': BaseBERT, + }[text_encoder] + + model = class_name() + model = model.to(device) + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu_id], find_unused_parameters=True) + + # load checkpoint + if checkpoint: + if is_master(): + print(f"** QUERY ** Load encoder from {checkpoint}.") + + checkpoint = torch.load(checkpoint, map_location=device) + checkpoint['model_state_dict'] = {k:v for k,v in checkpoint['model_state_dict'].items() if 'atlas_tower' not in k and 'temperature' not in k} + if partial_load: + model_dict = model.state_dict() + # check difference + unexpected_state_dict = [k for k 
in checkpoint['model_state_dict'].keys() if k not in model_dict.keys()] + missing_state_dict = [k for k in model_dict.keys() if k not in checkpoint['model_state_dict'].keys()] + unmatchd_state_dict = [k for k,v in checkpoint['model_state_dict'].items() if k in model_dict.keys() and v.shape != model_dict[k].shape] + # load partial parameters + state_dict = {k:v for k,v in checkpoint['model_state_dict'].items() if k in model_dict.keys() and v.shape == model_dict[k].shape} + model_dict.update(state_dict) + model.load_state_dict(model_dict) + if is_master(): + print('The following parameters are unexpected in query generator checkpoint:\n', unexpected_state_dict) + print('The following parameters are missing in query generator checkpoint:\n', missing_state_dict) + print('The following parameters have different shapes in query generator checkpoint:\n', unmatchd_state_dict) + print('The following parameters are loaded in query generator :\n', state_dict.keys()) + else: + model.load_state_dict(checkpoint['model_state_dict']) + + # open bert + for name, param in model.named_parameters(): + if 'encoder.layer.' 
in name and int(name.split('encoder.layer.')[-1].split('.')[0])>open_bert_layer: # encoder.layer.11.xxx --> 11 + param.requires_grad = True + elif open_bert_layer < 11 and ('pooler' in name or 'mlp_embed' in name): + param.requires_grad = True + elif open_modality_embed and 'modality_embed' in name: + param.requires_grad = True + else: + param.requires_grad = False + + self.model = model + + def forward(self, label_name, modality_name): + """ + Args: + label_name (List of List of Str / List of Str): B x N / N + modality_name (List / Str): B / 1 + NOTE: a list of labels paired with one modality + + Return: + queries (Tensor): B x N / N + """ + if isinstance(label_name[0], list): + batch_size = len(label_name) + num_query = len(label_name[0]) + input_text = [t for t_ls in label_name for t in t_ls] # BN + modality = [mod for mod in modality_name for n in range(num_query)] # repeat each mod for N times -> BN + else: + num_query = len(label_name) + input_text = label_name # N + modality = [modality_name for n in range(num_query)] # N + + # name to code + modality_code_dict = { + 'ct':0, + 'mri':1, + 'us':2, + 'pet':3, + 'microscopy':4 + } + modality_code = torch.tensor([modality_code_dict[mod] for mod in modality]) # bn + + # get embed + queries = self.model(input_text, modality_code, self.device) + + if isinstance(label_name[0], list): + queries = rearrange(queries, '(b n) d -> b n d', b=batch_size, n=num_query) + + return queries + + + + + \ No newline at end of file diff --git a/model/text_tower.py b/model/text_tower.py new file mode 100644 index 0000000000000000000000000000000000000000..7ae358022cc23b15a995959fb3b5d4fc44e14f1e --- /dev/null +++ b/model/text_tower.py @@ -0,0 +1,31 @@ +import torch +import torch.nn.functional as F +from torch import nn +from transformers import AutoModel + +from .tokenizer import MyTokenizer + + +class Text_Tower(nn.Module): + def __init__(self, biolord_checkpoint: str = None,): + super().__init__() + + self.biolord = 
AutoModel.from_pretrained(biolord_checkpoint) + self.tokenizer = MyTokenizer(biolord_checkpoint, 256) + + def mean_pooling(self, model_output, attention_mask): + token_embeddings = model_output[0] #First element of model_output contains all token embeddings + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) + + def forward(self, text, device): + text = self.tokenizer.tokenize(text) # (n, max_l) + text['input_ids'] = text['input_ids'].to(device=device) + text['attention_mask'] = text['attention_mask'].to(device=device) + + output = self.biolord(**text) + pooler_output = self.mean_pooling(output, text['attention_mask']) + + return pooler_output + + diff --git a/model/tokenizer.py b/model/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..d2ec7110d167b6e61eec8f7558591b158a7b7100 --- /dev/null +++ b/model/tokenizer.py @@ -0,0 +1,47 @@ +import torch +from typing import Union, List +from transformers import AutoTokenizer +import os +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +class MyTokenizer(): + def __init__(self, tokenizer, max_length=256): + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) + self.max_length = max_length + + def tokenize(self, texts:[str, List[str]]) -> torch.LongTensor: + """ + tokenize a lits of strings or a single string, pad/trunctate to max length input of the text tower + + Args: + texts (str, List[str]]): a string + + Returns: + torch.LongTensor: the tokenized tensor and the attention mask(mask out paddings) + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = '[CLS]' + eot_token = '[SEP]' + all_token_ids = [] + max_len_in_this_batch = 0 + for text in texts: # a string + tokens = [sot_token] + self.tokenizer.tokenize(text) + [eot_token] # list of str + + if len(tokens) > max_len_in_this_batch: + max_len_in_this_batch = 
len(tokens) + all_token_ids.append(self.tokenizer.convert_tokens_to_ids(tokens)) + if max_len_in_this_batch > self.max_length: + max_len_in_this_batch = self.max_length + result = torch.zeros(len(all_token_ids), max_len_in_this_batch, dtype=torch.long) + + for i, token_ids in enumerate(all_token_ids): # list of int + if len(token_ids) > max_len_in_this_batch: + token_ids = token_ids[:max_len_in_this_batch] # Truncate + token_ids[-1] = self.tokenizer.convert_tokens_to_ids('[SEP]') + result[i, :len(token_ids)] = torch.tensor(token_ids) + + attn_mask = torch.where(result>0, 1, 0) + + return {'input_ids':result, 'attention_mask':attn_mask} \ No newline at end of file diff --git a/model/transformer_decoder.py b/model/transformer_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..76a84b9cbbf657d3b7b5089c61d9c4ceecc5a6f6 --- /dev/null +++ b/model/transformer_decoder.py @@ -0,0 +1,153 @@ +""" +Code modified from DETR tranformer: +https://github.com/facebookresearch/detr +Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved +""" + +import copy +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + + +class TransformerDecoder(nn.Module): + + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + output = tgt + T,B,C = memory.shape + intermediate = [] + atten_layers = [] + for n,layer in enumerate(self.layers): + + residual=True + output,ws = layer(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos,residual=residual) + atten_layers.append(ws) + if self.return_intermediate: + intermediate.append(self.norm(output)) + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + return output,atten_layers + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + # self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = 
nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + residual=True): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2,ws = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask) + tgt = self.norm1(tgt) + tgt2,ws = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask) + + # attn_weights [B,NUM_Q,T] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt,ws + + def forward_pre(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + + tgt2 = self.norm2(tgt) + tgt2,attn_weights = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask) + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + 
self.dropout3(tgt2) + return tgt,attn_weights + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + residual=True): + if self.normalize_before: + return self.forward_pre(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos,residual) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") \ No newline at end of file diff --git a/model/umamba_mid.py b/model/umamba_mid.py new file mode 100644 index 0000000000000000000000000000000000000000..6bd0cd93837a743e44867f76a319bc403d5cff64 --- /dev/null +++ b/model/umamba_mid.py @@ -0,0 +1,405 @@ +import numpy as np +import torch +from torch import nn +from typing import Union, Type, List, Tuple + +from dynamic_network_architectures.building_blocks.helper import get_matching_convtransp +from dynamic_network_architectures.building_blocks.plain_conv_encoder import PlainConvEncoder + +from dynamic_network_architectures.building_blocks.simple_conv_blocks import StackedConvBlocks +from dynamic_network_architectures.building_blocks.residual import StackedResidualBlocks + +from dynamic_network_architectures.building_blocks.helper import maybe_convert_scalar_to_list, get_matching_pool_op +from dynamic_network_architectures.building_blocks.residual import BasicBlockD, BottleneckD +from 
torch.nn.modules.conv import _ConvNd
from torch.nn.modules.dropout import _DropoutNd
from torch.cuda.amp import autocast
from dynamic_network_architectures.building_blocks.helper import convert_conv_op_to_dim

from mamba_ssm import Mamba


class MambaLayer(nn.Module):
    # Applies a Mamba SSM block to an N-D conv feature map: flatten spatial dims
    # to a token sequence, LayerNorm, Mamba, then reshape back to (B, C, *spatial).
    def __init__(self, dim, d_state = 16, d_conv = 4, expand = 2):
        super().__init__()
        self.dim = dim
        self.norm = nn.LayerNorm(dim)
        self.mamba = Mamba(
            d_model=dim, # Model dimension d_model
            d_state=d_state,  # SSM state expansion factor
            d_conv=d_conv,    # Local convolution width
            expand=expand,    # Block expansion factor
        )

    @autocast(enabled=False)
    def forward(self, x):
        # autocast is disabled and fp16 inputs are upcast: the Mamba kernels here
        # are run in fp32.
        if x.dtype == torch.float16:
            x = x.type(torch.float32)
        B, C = x.shape[:2]
        assert C == self.dim
        n_tokens = x.shape[2:].numel()
        img_dims = x.shape[2:]
        # (B, C, *spatial) -> (B, n_tokens, C)
        x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2)
        x_norm = self.norm(x_flat)
        x_mamba = self.mamba(x_norm)
        out = x_mamba.transpose(-1, -2).reshape(B, C, *img_dims)

        return out


class ResidualMambaMidEncoder(nn.Module):
    # Residual conv encoder (nnU-Net style) with Mamba layers interleaved after
    # the conv stage from stage index 3 onwards ("mid" stages).
    def __init__(self,
                 input_channels: int,
                 n_stages: int,
                 features_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_op: Type[_ConvNd],
                 kernel_sizes: Union[int, List[int], Tuple[int, ...]],
                 strides: Union[int, List[int], Tuple[int, ...], Tuple[Tuple[int, ...], ...]],
                 n_blocks_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 block: Union[Type[BasicBlockD], Type[BottleneckD]] = BasicBlockD,
                 bottleneck_channels: Union[int, List[int], Tuple[int, ...]] = None,
                 return_skips: bool = False,
                 disable_default_stem: bool = False,
                 stem_channels: int = None,
                 pool_type: str = 'conv',
                 stochastic_depth_p: float = 0.0,
                 squeeze_excitation: bool = False,
                 squeeze_excitation_reduction_ratio: float = 1. / 16
                 ):
        super().__init__()
        # broadcast scalar hyperparameters to one entry per stage
        if isinstance(kernel_sizes, int):
            kernel_sizes = [kernel_sizes] * n_stages
        if isinstance(features_per_stage, int):
            features_per_stage = [features_per_stage] * n_stages
        if isinstance(n_blocks_per_stage, int):
            n_blocks_per_stage = [n_blocks_per_stage] * n_stages
        if isinstance(strides, int):
            strides = [strides] * n_stages
        if bottleneck_channels is None or isinstance(bottleneck_channels, int):
            bottleneck_channels = [bottleneck_channels] * n_stages
        assert len(
            bottleneck_channels) == n_stages, "bottleneck_channels must be None or have as many entries as we have resolution stages (n_stages)"
        assert len(
            kernel_sizes) == n_stages, "kernel_sizes must have as many entries as we have resolution stages (n_stages)"
        assert len(
            n_blocks_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have resolution stages (n_stages)"
        assert len(
            features_per_stage) == n_stages, "features_per_stage must have as many entries as we have resolution stages (n_stages)"
        assert len(strides) == n_stages, "strides must have as many entries as we have resolution stages (n_stages). " \
                                         "Important: first entry is recommended to be 1, else we run strided conv drectly on the input"

        pool_op = get_matching_pool_op(conv_op, pool_type=pool_type) if pool_type != 'conv' else None

        # build a stem, Todo maybe we need more flexibility for this in the future. For now, if you need a custom
        # stem you can just disable the stem and build your own.
        # THE STEM DOES NOT DO STRIDE/POOLING IN THIS IMPLEMENTATION
        if not disable_default_stem:
            if stem_channels is None:
                stem_channels = features_per_stage[0]
            self.stem = StackedConvBlocks(1, conv_op, input_channels, stem_channels, kernel_sizes[0], 1, conv_bias,
                                          norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs)
            input_channels = stem_channels
        else:
            self.stem = None

        # now build the network
        stages = []
        mamba_layers = []
        for s in range(n_stages):
            # when a pooling op handles downsampling, the conv stage runs with stride 1
            stride_for_conv = strides[s] if pool_op is None else 1

            stage = StackedResidualBlocks(
                n_blocks_per_stage[s], conv_op, input_channels, features_per_stage[s], kernel_sizes[s], stride_for_conv,
                conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs,
                block=block, bottleneck_channels=bottleneck_channels[s], stochastic_depth_p=stochastic_depth_p,
                squeeze_excitation=squeeze_excitation,
                squeeze_excitation_reduction_ratio=squeeze_excitation_reduction_ratio
            )

            if pool_op is not None:
                stage = nn.Sequential(pool_op(strides[s]), stage)

            stages.append(stage)
            input_channels = features_per_stage[s]

            # stages >= 3 get a Mamba layer appended (applied after the conv stage in forward)
            if s >= 3:
                mamba_layers.append(MambaLayer(input_channels))

        #self.stages = nn.Sequential(*stages)
        self.stages = nn.ModuleList(stages)
        self.output_channels = features_per_stage
        self.strides = [maybe_convert_scalar_to_list(conv_op, i) for i in strides]
        self.return_skips = return_skips

        # we store some things that a potential decoder needs
        self.conv_op = conv_op
        self.norm_op = norm_op
        self.norm_op_kwargs = norm_op_kwargs
        self.nonlin = nonlin
        self.nonlin_kwargs = nonlin_kwargs
        self.dropout_op = dropout_op
        self.dropout_op_kwargs = dropout_op_kwargs
        self.conv_bias = conv_bias
        self.kernel_sizes = kernel_sizes

        self.mamba_layers = nn.ModuleList(mamba_layers)

    def forward(self, x):
        # Returns all per-stage feature maps when return_skips, else the bottleneck only.
        if self.stem is not None:
            x = self.stem(x)
        ret = []
        #for s in self.stages:
        for s in range(len(self.stages)):
            #x = s(x)
            x = self.stages[s](x)
            # stages >= 3 are followed by their matching Mamba layer
            if s >= 3:
                x = self.mamba_layers[s-3](x)
            ret.append(x)
        if self.return_skips:
            return ret
        else:
            return [ret[-1]]

    def compute_conv_feature_map_size(self, input_size):
        # Activation-size estimate: sum of conv feature-map sizes over stem + stages,
        # shrinking input_size by each stage's stride.
        if self.stem is not None:
            output = self.stem.compute_conv_feature_map_size(input_size)
        else:
            output = np.int64(0)

        for s in range(len(self.stages)):
            output += self.stages[s].compute_conv_feature_map_size(input_size)
            input_size = [i // j for i, j in zip(input_size, self.strides[s])]

        return output


class UNetResDecoder(nn.Module):
    def __init__(self,
                 encoder: Union[PlainConvEncoder, ResidualMambaMidEncoder],
                 n_conv_per_stage: Union[int, Tuple[int, ...], List[int]],
                 deep_supervision, nonlin_first: bool = False):
        """
        This class needs the skips of the encoder as input in its forward.

        the encoder goes all the way to the bottleneck, so that's where the decoder picks up. stages in the decoder
        are sorted by order of computation, so the first stage has the lowest resolution and takes the bottleneck
        features and the lowest skip as inputs
        the decoder has two (three) parts in each stage:
        1) conv transpose to upsample the feature maps of the stage below it (or the bottleneck in case of the first stage)
        2) n_conv_per_stage conv blocks to let the two inputs get to know each other and merge
        3) (optional if deep_supervision=True) a segmentation output Todo: enable upsample logits?
        :param encoder:
        :param n_conv_per_stage:
        :param deep_supervision:
        """
        super().__init__()
        self.deep_supervision = deep_supervision
        self.encoder = encoder
        n_stages_encoder = len(encoder.output_channels)
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * (n_stages_encoder - 1)
        assert len(n_conv_per_stage) == n_stages_encoder - 1, "n_conv_per_stage must have as many entries as we have " \
                                                              "resolution stages - 1 (n_stages in encoder - 1), " \
                                                              "here: %d" % n_stages_encoder

        transpconv_op = get_matching_convtransp(conv_op=encoder.conv_op)

        # we start with the bottleneck and work out way up
        stages = []
        transpconvs = []
        # seg_layers = []
        for s in range(1, n_stages_encoder):
            input_features_below = encoder.output_channels[-s]
            input_features_skip = encoder.output_channels[-(s + 1)]
            stride_for_transpconv = encoder.strides[-s]
            transpconvs.append(transpconv_op(
                input_features_below, input_features_skip, stride_for_transpconv, stride_for_transpconv,
                bias=encoder.conv_bias
            ))
            # input features to conv is 2x input_features_skip (concat input_features_skip with transpconv output)
            stages.append(StackedResidualBlocks(
                n_blocks = n_conv_per_stage[s-1],
                conv_op = encoder.conv_op,
                input_channels = 2 * input_features_skip,
                output_channels = input_features_skip,
                kernel_size = encoder.kernel_sizes[-(s + 1)],
                initial_stride = 1,
                conv_bias = encoder.conv_bias,
                norm_op = encoder.norm_op,
                norm_op_kwargs = encoder.norm_op_kwargs,
                dropout_op = encoder.dropout_op,
                dropout_op_kwargs = encoder.dropout_op_kwargs,
                nonlin = encoder.nonlin,
                nonlin_kwargs = encoder.nonlin_kwargs,
            ))
            # we always build the deep supervision outputs so that we can always load parameters. If we don't do this
            # then a model trained with deep_supervision=True could not easily be loaded at inference time where
            # deep supervision is not needed. It's just a convenience thing
            # seg_layers.append(encoder.conv_op(input_features_skip, num_classes, 1, 1, 0, bias=True))

        self.stages = nn.ModuleList(stages)
        self.transpconvs = nn.ModuleList(transpconvs)
        # self.seg_layers = nn.ModuleList(seg_layers)

    def forward(self, skips):
        """
        we expect to get the skips in the order they were computed, so the bottleneck should be the last entry
        :param skips:
        :return:
        """
        lres_input = skips[-1]
        seg_outputs = []
        for s in range(len(self.stages)):
            # upsample the features from below, concat the matching skip, then conv
            x = self.transpconvs[s](lres_input)
            x = torch.cat((x, skips[-(s+2)]), 1)
            x = self.stages[s](x)
            # NOTE: seg heads are disabled here — raw decoder features are returned
            seg_outputs.append(x)
            #if self.deep_supervision:
            #    seg_outputs.append(self.seg_layers[s](x))
            #elif s == (len(self.stages) - 1):
            #    seg_outputs.append(self.seg_layers[-1](x))
            lres_input = x

        # invert seg outputs so that the largest segmentation prediction is returned first
        seg_outputs = seg_outputs[::-1]

        if not self.deep_supervision:
            r = [seg_outputs[0]]
        else:
            r = seg_outputs
        return r

    def compute_conv_feature_map_size(self, input_size):
        """
        IMPORTANT: input_size is the input_size of the encoder!
        :param input_size:
        :return:
        """
        # first we need to compute the skip sizes.
Skip bottleneck because all output feature maps of our ops will at
        # least have the size of the skip above that (therefore -1)
        skip_sizes = []
        for s in range(len(self.encoder.strides) - 1):
            skip_sizes.append([i // j for i, j in zip(input_size, self.encoder.strides[s])])
            input_size = skip_sizes[-1]
        # print(skip_sizes)

        assert len(skip_sizes) == len(self.stages)

        # our ops are the other way around, so let's match things up
        output = np.int64(0)
        for s in range(len(self.stages)):
            # print(skip_sizes[-(s+1)], self.encoder.output_channels[-(s+2)])
            # conv blocks
            output += self.stages[s].compute_conv_feature_map_size(skip_sizes[-(s+1)])
            # trans conv
            output += np.prod([self.encoder.output_channels[-(s+2)], *skip_sizes[-(s+1)]], dtype=np.int64)
            # segmentation
            # NOTE(review): `self.num_classes` is never assigned in this class (the
            # seg_layers that defined it are commented out), so this branch would
            # raise AttributeError if executed — confirm before relying on this
            # estimate.
            if self.deep_supervision or (s == (len(self.stages) - 1)):
                output += np.prod([self.num_classes, *skip_sizes[-(s+1)]], dtype=np.int64)
        return output


class UMambaMid(nn.Module):
    # U-Net-style encoder/decoder whose encoder interleaves Mamba layers in the
    # deeper ("mid") stages; forward returns (encoder skips, decoder outputs).
    def __init__(self,
                 input_channels: int,
                 n_stages: int,
                 features_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_op: Type[_ConvNd],
                 kernel_sizes: Union[int, List[int], Tuple[int, ...]],
                 strides: Union[int, List[int], Tuple[int, ...]],
                 n_conv_per_stage: Union[int, List[int], Tuple[int, ...]],
                 n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 deep_supervision: bool = False,
                 block: Union[Type[BasicBlockD], Type[BottleneckD]] = BasicBlockD,
                 bottleneck_channels: Union[int, List[int], Tuple[int, ...]] = None,
                 stem_channels: int = None
                 ):
        super().__init__()
        n_blocks_per_stage = n_conv_per_stage
        # broadcast scalar stage configs to per-stage lists
        if isinstance(n_blocks_per_stage, int):
            n_blocks_per_stage = [n_blocks_per_stage] * n_stages
        if isinstance(n_conv_per_stage_decoder, int):
            n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1)
        assert len(n_blocks_per_stage) == n_stages, "n_blocks_per_stage must have as many entries as we have " \
                                                    f"resolution stages. here: {n_stages}. " \
                                                    f"n_blocks_per_stage: {n_blocks_per_stage}"
        assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one less entries " \
                                                                f"as we have resolution stages. here: {n_stages} " \
                                                                f"stages, so it should have {n_stages - 1} entries. " \
                                                                f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}"
        self.encoder = ResidualMambaMidEncoder(input_channels, n_stages, features_per_stage, conv_op, kernel_sizes, strides,
                                               n_blocks_per_stage, conv_bias, norm_op, norm_op_kwargs, dropout_op,
                                               dropout_op_kwargs, nonlin, nonlin_kwargs, block, bottleneck_channels,
                                               return_skips=True, disable_default_stem=False, stem_channels=stem_channels)
        self.decoder = UNetResDecoder(self.encoder, n_conv_per_stage_decoder, deep_supervision)

    def forward(self, x):
        # returns both the encoder skips and the decoder feature maps
        skips = self.encoder(x)
        outs = self.decoder(skips)
        return skips, outs

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), "just give the image size without color/feature channels or " \
                                                                                "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                                                "Give input_size=(x, y(, z))!"
+ return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size) + + +if __name__ == '__main__': + from dynamic_network_architectures.architectures.unet import PlainConvUNet + import os + + model = UMambaEnc( + input_channels=3, + n_stages=6, + features_per_stage=(64, 64, 128, 256, 512, 768), + conv_op=nn.Conv3d, + kernel_sizes=3, + strides=(1, 2, 2, 2, 2, 2), + n_conv_per_stage=(2, 2, 2, 2, 2, 2), + n_conv_per_stage_decoder=(2, 2, 2, 2, 2), + conv_bias=True, + norm_op=nn.InstanceNorm3d, + norm_op_kwargs={'eps': 1e-5, 'affine': True}, + dropout_op=None, + dropout_op_kwargs=None, + nonlin=nn.LeakyReLU, + nonlin_kwargs=None, + ).cuda() + + image_input = torch.rand((1,3,288,288,96)).cuda() + latent_embedding_ls, per_pixel_embedding_ls = model(image_input) # B Dim H/P W/P D/P + + for tmp in latent_embedding_ls: + print(tmp.shape) + + print('----------') + + for tmp in per_pixel_embedding_ls: + print(tmp.shape) + + import time + + time.sleep(1) + + def get_parameter_number(model): + total_num = sum(p.numel() for p in model.parameters()) + trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad) + return {'Total': total_num, 'Trainable': trainable_num} + + if is_master(): + print(f"** MODEL ** {get_parameter_number(model)['Total']/1e6}M parameters") \ No newline at end of file diff --git a/run_inference_medals_nifti.sh b/run_inference_medals_nifti.sh new file mode 100644 index 0000000000000000000000000000000000000000..b4ff0a8cd02f9b4fd4ed60fdd2887dc2ab920971 --- /dev/null +++ b/run_inference_medals_nifti.sh @@ -0,0 +1,123 @@ +#!/bin/bash +# +# Medal-S Inference Script +# +# This script runs Medal-S inference in Stage 1 + Stage 2 mode: +# - Stage 1 + Stage 2: Accurate two-stage inference with ROI refinement +# +# Usage: +# bash run_inference_medals_nifti.sh +# bash run_inference_medals_nifti.sh [input_path] [output_dir] [device] [checkpoints_path] +# +# Configuration Files: +# - CT images: Use 
config_CT.json (supports multi-window types: soft_tissue, bone, lung)
#   * Multiple window types: Each window type will be processed separately and merged
#   * Single window type: Uses the corresponding window settings
#   - Non-CT images: Use config_nonCT.json (MRI, US, PET, microscopy)
#   * Uses normalization_settings for percentile-based normalization
#
# To control verbose output, edit VERBOSE variable in Configuration section:
#   VERBOSE=""           # Default: verbose disabled
#   VERBOSE="--verbose"  # Explicitly enable verbose output
#
# Output files will be automatically named with mode suffix:
#   - *_stage1+stage2.nii.gz

# Fail on unset variables and propagate failures through pipelines. Plain
# 'set -e' is intentionally NOT used: the inference exit status is checked
# explicitly below so a failure can be reported with a readable message.
set -uo pipefail

# ============================================================================
# Configuration
# ============================================================================
# Optional positional arguments override the defaults, matching the usage
# documented at the top of this script (previously the arguments were
# documented but silently ignored):
#   $1 = input image, $2 = output dir, $3 = device, $4 = checkpoints path
IMAGE_PATH="${1:-./inputs/CT_chest_and_abdomen_large_depth_row_8_0000.nii.gz}" #Totalsegmentator_s0059_0000.nii.gz
OUTPUT_DIR="${2:-./outputs}"
DEVICE="${3:-cuda:1}"
CHECKPOINTS_PATH="${4:-./checkpoints}"
CONFIG_FILE="./config_CT.json"  # Use config_CT.json for CT, config_nonCT.json for non-CT
VERBOSE="--verbose"             # Set to "--verbose" to explicitly enable verbose output, empty for default (disabled)

# ============================================================================
# Setup
# ============================================================================

# Create output directory if it doesn't exist
mkdir -p "$OUTPUT_DIR"

# Fail fast if the input image is missing (previously only the config file
# was validated and a bad input path surfaced much later inside Python).
if [ ! -f "$IMAGE_PATH" ]; then
    echo "Error: Input image not found: $IMAGE_PATH" >&2
    echo "Please check IMAGE_PATH or pass a valid path as the first argument." >&2
    exit 1
fi

# Check if config file exists
if [ ! -f "$CONFIG_FILE" ]; then
    echo "Error: Config file not found: $CONFIG_FILE" >&2
    echo "Please check the CONFIG_FILE path in the script configuration section." >&2
    exit 1
fi

# Get output filename (without extension)
OUTPUT_FILENAME=$(basename "$IMAGE_PATH")
OUTPUT_BASE_PATH="$OUTPUT_DIR/$OUTPUT_FILENAME"

# Print configuration
echo "=========================================="
echo "Medal-S Inference - Stage 1 + Stage 2"
echo "=========================================="
echo "Input: $IMAGE_PATH"
echo "Output directory: $OUTPUT_DIR"
echo "Config file: $CONFIG_FILE"
echo "Device: $DEVICE"
echo "Checkpoints: $CHECKPOINTS_PATH"
if [ -n "$VERBOSE" ]; then
    echo "Verbose: $VERBOSE"
else
    echo "Verbose: default (disabled)"
fi
echo "=========================================="
echo ""
echo "Note:"
echo "  - CT images with multiple window types will be processed separately"
echo "  - Each window type uses its corresponding window settings"
echo "  - Results from all window types will be merged automatically"
echo ""

# ============================================================================
# Stage 1 + Stage 2 Inference
# ============================================================================
echo "=========================================="
echo "Stage 1 + Stage 2 Inference"
echo "=========================================="
echo "Running Stage 1 + Stage 2 inference..."
echo ""

# Build the command as an array so every argument stays a single word even if
# a path contains spaces; the optional verbose flag is appended only when set
# (the old unquoted $VERBOSE relied on implicit word splitting).
CMD=(
    python inference_medals_nifti.py
    --input "$IMAGE_PATH"
    --output "$OUTPUT_BASE_PATH"
    --config "$CONFIG_FILE"
    --mode stage1+stage2
    --device "$DEVICE"
    --checkpoints "$CHECKPOINTS_PATH"
)
if [ -n "$VERBOSE" ]; then
    CMD+=("$VERBOSE")
fi

"${CMD[@]}"
EXIT_CODE=$?
if [ "$EXIT_CODE" -eq 0 ]; then
    echo ""
    echo "✓ Stage 1 + Stage 2 inference completed successfully!"
    echo ""
else
    echo ""
    echo "✗ Error: Stage 1 + Stage 2 inference failed!" >&2
    exit 1
fi

# ============================================================================
# Summary
# ============================================================================
echo "=========================================="
echo "Inference completed successfully!"
echo "=========================================="
echo ""
echo "Output file:"
# Handle .nii.gz extension properly (a naive ${name%.*} would strip only .gz)
if [[ "$OUTPUT_FILENAME" == *.nii.gz ]]; then
    BASE_NAME="${OUTPUT_FILENAME%.nii.gz}"
    echo "  - $OUTPUT_DIR/${BASE_NAME}_stage1+stage2.nii.gz"
elif [[ "$OUTPUT_FILENAME" == *.nii ]]; then
    BASE_NAME="${OUTPUT_FILENAME%.nii}"
    echo "  - $OUTPUT_DIR/${BASE_NAME}_stage1+stage2.nii"
else
    echo "  - $OUTPUT_DIR/${OUTPUT_FILENAME}_stage1+stage2"
fi
echo ""