# Omini3D / Dataloader / embding_gen.py
# Committed by: maxmo2009
# Commit message: "Sync from local: code + epoch-110 checkpoint, clean README"
# Commit: 2af0e94 (verified)
import torch
from torch.utils.data import Dataset, DataLoader
import json
import SimpleITK as sitk
import numpy as np
from skimage.transform import rescale, resize, downscale_local_mean
# from torchvision.transforms import v2
import sys
from bert_helper import *
sys.path.append('./')
from Dataloader.dataloader_utils import *
import random
# Per-dataset input mapping JSONs. Every uncommented entry is loaded and
# processed by the loop at the bottom of this script, in insertion order.
# NOTE(review): OAI_ZIB_KL and OAI_ZIB_WOMAC point at the same input JSON; they
# differ only in which metadata fields (`query`) and extra text (`add_text`)
# go into the caption.
mapping_files = {
    # 'MSD': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/MSD_processed/nifti_mappings_updated.json',
    # 'TotalSegmentor': '/home/data/Github/data/data_gen_def/DATASETS_processed/TotalSegmentorCT_MRI/nifti_mappings.json',
    # 'Kaggle_osic': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/Kaggle_osic_new/nifti_mappings.json',
    # 'CancerImageArchive': '/home/data/Github/data/data_gen_def/DATASETS_processed/CancerImageArchive_test/nifti_mappings.json',
    # 'MnMs': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/MnMs/nifti_mappings.json',
    # 'Brats2019': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/BRATS/BRATS2019/nifti_mappings.json',
    # 'Brats2020': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/BRATS/BRATS2020/nifti_mappings.json',
    # 'Brats2021': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/BRATS/BRATS2021/nifti_mappings.json',
    # 'OASIS_1': '/home/data/Github/data/data_gen_def/DATASETS_processed/OASIS/OASIS_1/CS_SECTIONAL/nifti_mappings.json',
    # 'OASIS_2': '/home/data/Github/data/data_gen_def/DATASETS_processed/OASIS/OASIS_2/RAW_V2/nifti_mappings.json',
    'OAI_ZIB_KL': '/home/dn-zhen2/rds/rds-airr-p51-TWhPgQVLKbA/Data/Omini3D/DATASETS_processed/OAI_ZIB/nifti_mappings.json',
    'OAI_ZIB_WOMAC': '/home/dn-zhen2/rds/rds-airr-p51-TWhPgQVLKbA/Data/Omini3D/DATASETS_processed/OAI_ZIB/nifti_mappings.json',
    # 'PSMA-FDG-PET-CT-LESION':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/PSMA/PSMA-FDG-PET-CT-LESION/V2/nifti_mappings.json',
    # 'PSMA-CT':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/PSMA/Longitudinal-CT/nifti_mappings.json',
    # 'AbdomenAtlas':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/AbdomenAtlas_v2/nifti_mappings.json',
    # 'AbdomenCT1k':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/AbdomenCT1k/nifti_mappings.json',
}
# Output path per dataset. The loaded mapping JSON is written here after each
# entry gains its 'embd' and 'embd_key' fields. Every key that is active in
# `mapping_files` must also appear here, or the loop raises KeyError.
save_paths = {
    'MSD': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/MSD_mappings.json',
    'TotalSegmentor': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/TotalSegmentorCT_MRI_mappings.json',
    'Kaggle_osic': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/Kaggle_osic_mappings.json',
    'CancerImageArchive': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/CIA_mappings.json',
    'MnMs': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/MnMs_mappings.json',
    'Brats2019': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/Brats2019_mappings.json',
    'Brats2020': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/Brats2020_mappings.json',
    'Brats2021': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/Brats2021_mappings.json',
    'OASIS_1': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/OASIS_1_mappings.json',
    'OASIS_2': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/OASIS_2_mappings.json',
    'PSMA-FDG-PET-CT-LESION':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json',
    'PSMA-CT':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json',
    'AbdomenAtlas':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/AbdomenAtlas_mappings.json',
    'AbdomenCT1k':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/AbdomenCT1k_mappings.json',
    'OAI_ZIB_KL': '/home/dn-zhen2/rds/rds-airr-p51-TWhPgQVLKbA/Code/OmniMorph/Dataloader/nifty_mappings/OAI_ZIB_KL_mappings.json',
    'OAI_ZIB_WOMAC': '/home/dn-zhen2/rds/rds-airr-p51-TWhPgQVLKbA/Code/OmniMorph/Dataloader/nifty_mappings/OAI_ZIB_WOMAC_mappings.json',
}
# Metadata fields copied, in this order, from each entry's 'Metadata' dict into
# the caption (after 'Modality' and 'ROI'). A field absent from an entry's
# metadata is recorded as 'N/A'. An empty list means the caption holds only
# Modality/ROI plus any `add_text`.
query = {
    'MSD': ['description'],
    'TotalSegmentor': ['age','gender'],
    'Kaggle_osic': ['Age','Sex','Smoke_Status','Weeks','FVC','Percent'],
    'CancerImageArchive':['Series_Description', 'Study_Description', 'Manufacturer'],
    'MnMs': ['Age','Sex','Height','Weight'],
    'Brats2019': ['Age', 'Grade', 'Survival','ResectionStatus'],
    'Brats2020': ['Age', 'Grade', 'Survival','ResectionStatus'],
    'Brats2021': ['Age', 'Grade', 'Survival','ResectionStatus'],
    'OASIS_1': ['Age', 'M/F','ASF','Educ','SES','MMSE','eTIV','CDR','nWBV'],
    'OASIS_2': ['Age', 'Group','M/F','ASF','Educ','SES','MMSE','eTIV','CDR','nWBV'],
    'PSMA-FDG-PET-CT-LESION':['Study Description', 'diagnosis','age','sex',"pet_radionuclide",'ct_contrast_agent'],
    'PSMA-CT':[],
    'AbdomenAtlas':[],
    'AbdomenCT1k':[],
    'OAI_ZIB_KL': ['Age', 'Gender', 'KL_Grade', 'BMI'],
    'OAI_ZIB_WOMAC': ['Age', 'Gender', 'WOMAC_Pain', 'WOMAC_ADL', 'WOMAC_Stiffness', 'BMI'],
}
# Constant extra text added to every caption of a dataset. If the field already
# exists in the caption (e.g. it was pulled in via `query`), the text is appended
# after ', '. Otherwise the field is created with this text as its value.
add_text = {
    'MSD': {},
    'TotalSegmentor': {},
    'Kaggle_osic': {'description': 'pulmonary fibrosis progression'},
    'CancerImageArchive': {},
    'MnMs': {},
    'Brats2019': {'description': 'could include brain tumor, glioma, glioblastoma, low grade glioma, high grade glioma'},
    'Brats2020': {'description': 'could include brain tumor, glioma, glioblastoma, low grade glioma, high grade glioma'},
    'Brats2021': {'description': 'could include brain tumor, glioma, glioblastoma, low grade glioma, high grade glioma'},
    'OASIS_1': {},
    'OASIS_2': {},
    'PSMA-CT':{'description': 'melanoma patients'},
    'PSMA-FDG-PET-CT-LESION':{'description': 'malignant melanoma, lymphoma, lung cancer, or healthy'},
    'AbdomenAtlas':{},
    'AbdomenCT1k':{},
    'OAI_ZIB_KL': {'description': 'right knee osteoarthritis'},
    'OAI_ZIB_WOMAC': {'description': 'right knee osteoarthritis'},
}
# BERT initialization.
# Local path of the BERT checkpoint handed to `get_frozen_embeder` below.
model_name = '/rds/project/rds-TWhPgQVLKbA/Code/OmniMorph/External/Models/bert_large_uncased'
# Passed to `str2emb` as `reduce_method`. Presumably it selects how token
# embeddings are pooled into one caption vector -- confirm in str2emb.
reduce_method = 'mean'
max_words_num = 32 # upper bound on whitespace-separated words per caption; a longer caption aborts the script; also passed to str2emb
# max_words_num = 64 # alternative, longer caption budget
# `get_frozen_embeder` comes from one of the star imports above. It yields the
# (embedder, tokenizer) pair that the loop below passes to `str2emb`.
embeder, tokenizer = get_frozen_embeder(model_name)
def embed_str_filter(str_input, filter_words=('segmentation', 'registration')):
    '''
    Remove every occurrence of each word in `filter_words` from `str_input`.

    Removal is plain substring replacement, applied in order:
      - surrounding whitespace is left in place, so double spaces may remain;
      - matches inside longer words are removed too
        (e.g. 'segmentations' -> 's').

    Args:
        str_input: Caption text to clean.
        filter_words: Iterable of substrings to strip, or a single string.
            The default is an immutable tuple, so it cannot be mutated
            between calls.

    Returns:
        The filtered string.
    '''
    # A bare string would otherwise be iterated character by character,
    # stripping individual letters instead of the whole word.
    if isinstance(filter_words, str):
        filter_words = (filter_words,)
    for word in filter_words:
        str_input = str_input.replace(word, '')
    return str_input
import os


def _build_caption_fields(entry, dataset):
    '''
    Collect the ordered fields that make up one entry's caption.

    Args:
        entry: One value of a dataset's mapping JSON. It must contain 'Modality'
            and 'ROI'. A missing or empty 'Metadata' is treated as an empty dict,
            so every queried field becomes 'N/A'.
        dataset: Key into the module-level `query` and `add_text` tables.

    Returns:
        An insertion-ordered dict with:
          - 'Modality' and 'ROI';
          - each `query[dataset]` field, or 'N/A' when the metadata lacks it;
          - the `add_text[dataset]` fields, appended after ', ' when the field
            already exists.
    '''
    fields = {'Modality': entry['Modality'], 'ROI': entry['ROI']}
    meta_data = entry.get('Metadata') or {}
    for q in query[dataset]:
        fields[q] = meta_data.get(q, 'N/A')
    for q, text in add_text[dataset].items():
        if q in fields:
            # Format instead of `+=`: `+=` raises TypeError for numeric values
            # and silently extends list values character by character.
            fields[q] = f'{fields[q]}, {text}'
        else:
            fields[q] = text
    return fields


def _fields_to_caption(fields):
    '''
    Turn caption fields into the normalized text that gets embedded.

    The steps are:
      1. take the dict's repr without the outer braces and lower-case it;
      2. pass it through `replace_text` with `get_synonyms_dict(None)`;
      3. pass the result through `embed_str_filter`.
    '''
    caption = str(fields)[1:-1].lower()
    caption = replace_text(caption, get_synonyms_dict(None))
    return embed_str_filter(caption)


# For each active dataset:
#   1. load its mapping JSON;
#   2. build and embed a caption for every entry, storing the embedding under
#      'embd' and the caption text under 'embd_key';
#   3. write the augmented JSON to `save_paths[dataset]`.
for dataset, jsn_path in mapping_files.items():
    with open(jsn_path, 'r', encoding='utf-8') as f:
        embd_json = json.load(f)
    for key, entry in embd_json.items():
        embd_json_temp = _build_caption_fields(entry, dataset)
        embd_str = _fields_to_caption(embd_json_temp)
        words_num = len(embd_str.split())
        print(f'embd_json_temp: {str(embd_json_temp)}')
        print(f'embd_str: {embd_str}')
        print(f'words_num: {words_num}')
        # Explicit raise: an `assert` check would be skipped under `python -O`.
        if words_num > max_words_num:
            raise ValueError(f'Too many words in the caption: {embd_str}')
        embd = str2emb(embd_str, max_words_num, embeder, tokenizer, reduce_method=reduce_method)
        print(embd)
        # `tolist()[0]` keeps the first row; presumably str2emb returns a
        # (1, dim) tensor for a single caption -- confirm in str2emb.
        entry['embd'] = embd.tolist()[0]
        entry['embd_key'] = embd_str
    new_jsn_path = save_paths[dataset]
    # Create the output directory up front so a missing folder does not throw
    # away a whole dataset's worth of computed embeddings.
    os.makedirs(os.path.dirname(new_jsn_path) or '.', exist_ok=True)
    with open(new_jsn_path, 'w', encoding='utf-8') as f:
        json.dump(embd_json, f, indent=4)