File size: 15,486 Bytes
75854b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
import os
import numpy as np
import math
from torch.nn import Tanh, BatchNorm1d
from typing import Optional
import torch.nn as nn
import torch
from transformers import BertModel, BertForSequenceClassification
from transformers import BertTokenizer
from transformers import AutoTokenizer, AutoModel

from torch.utils.data import Dataset as Dataset_n
from torch.utils.data import DataLoader as DataLoader_n
from torch.utils.data import WeightedRandomSampler

def _freeze_bert(
        bert_model: BertModel, freeze_bert=True, freeze_layer_count=-1
):
    """Freeze parameters in BertModel (in place)
    Args:
        bert_model: HuggingFace bert model
        freeze_bert: Bool whether to freeze the bert model
        freeze_layer_count: If freeze_bert, up to what layer to freeze.
    Returns:
        bert_model
    """
    if freeze_bert:
        # freeze the entire bert model
        for param in bert_model.parameters():
            param.requires_grad = False
    else:
        # freeze the embeddings
        for param in bert_model.embeddings.parameters():
            param.requires_grad = False
        if freeze_layer_count != -1:
            if freeze_layer_count > 0 :
                # freeze layers in bert_model.encoder
                for layer in bert_model.encoder.layer[:freeze_layer_count]:
                    for param in layer.parameters():
                        param.requires_grad = False

            if freeze_layer_count < 0 :
                # freeze layers in bert_model.encoder
                for layer in bert_model.encoder.layer[freeze_layer_count:]:
                    for param in layer.parameters():
                        param.requires_grad = False
    return None

def get_frozen_embeder(key_word="bert-large-uncased"):
    """Load a pretrained encoder and its tokenizer, with all encoder weights frozen.

    Args:
        key_word: HuggingFace hub id or local checkpoint directory, passed to
            ``from_pretrained`` for both the tokenizer and the model.

    Returns:
        ``(model, tokenizer)``. Every model parameter has
        ``requires_grad=False``.
    """
    # NOTE(review): do_lower_case=False is requested even for "uncased"
    # checkpoints. str2emb lower-cases its input itself before tokenizing.
    tok = AutoTokenizer.from_pretrained(key_word, do_lower_case=False)
    encoder = AutoModel.from_pretrained(key_word)

    # With freeze_bert=True every parameter is frozen, so the
    # freeze_layer_count value is never consulted.
    _freeze_bert(encoder, freeze_bert=True, freeze_layer_count=None)
    return encoder, tok


def str2emb(string, max_words_num=100, embeder=None, tokenizer=None, reduce_method='mean',
            ignore_padding=False):
    """Embed a text string with a (typically frozen) transformer encoder.

    The string is lower-cased, then tokenized with padding/truncation to
    exactly ``max_words_num`` tokens, and run through ``embeder``.

    Args:
        string: Text to embed.
        max_words_num: Token length the input is padded/truncated to.
        embeder: Model called as ``embeder(**tokens)``. Its output must expose
            ``last_hidden_state``.
        tokenizer: Tokenizer called with ``return_tensors='pt'``,
            ``max_length``, ``padding='max_length'`` and ``truncation=True``.
        reduce_method: ``'mean'`` or ``'max'`` pools ``last_hidden_state``
            over dim 1 (the token axis). Any other value returns
            ``last_hidden_state`` unpooled.
        ignore_padding: If True, ``'mean'``/``'max'`` pooling only uses
            positions where the tokenizer's ``attention_mask`` is non-zero.
            This requires the tokenizer output to contain ``attention_mask``.
            Defaults to False, which pools over every position including
            padding (the historical behaviour).

    Returns:
        A pooled tensor with the token axis removed for ``'mean'``/``'max'``,
        otherwise ``last_hidden_state``.

    Raises:
        TypeError: If ``embeder`` or ``tokenizer`` is None.
    """
    if embeder is None or tokenizer is None:
        raise TypeError("str2emb requires both an embeder and a tokenizer")

    string = string.lower()
    str_token = tokenizer(string, return_tensors='pt', max_length=max_words_num,
                          padding='max_length', truncation=True)
    hidden = embeder(**str_token).last_hidden_state

    if reduce_method not in ('mean', 'max'):
        return hidden

    if ignore_padding:
        # Add a trailing axis so the (batch, tokens) mask broadcasts over the
        # hidden dimension. Assumes hidden is (batch, tokens, hidden) — TODO
        # confirm for non-BERT encoders.
        mask = str_token['attention_mask'].unsqueeze(-1).to(hidden.dtype)
        if reduce_method == 'mean':
            # Clamp guards against division by zero for an all-padding row.
            return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        return hidden.masked_fill(mask == 0, float('-inf')).max(dim=1)[0]

    # Unmasked pooling: padding positions take part (original behaviour).
    if reduce_method == 'mean':
        return torch.mean(hidden, dim=1)
    return torch.max(hidden, dim=1)[0]

def get_synonyms_dict(dict_type=None):
    '''
    Return the synonym table for the given category.

    Each table maps a standard term to the list of spellings that should be
    normalised to it.

    Args:
        dict_type: 'ROI', 'Label_tissue', 'Task' or 'Modality'. Any other
            value (including the default None) selects the generic metadata
            table, whose terms are wrapped in single quotes.

    Returns:
        A new dict on every call, so callers may mutate it freely.
    '''
    # NOTE(review): replace_synonyms matches synonyms as case-insensitive
    # substrings and returns the first key in insertion order. Short entries
    # (e.g. 'CT', 'US', 'rib', 'ear') can therefore match inside unrelated
    # words (e.g. 'ct' inside 'spect'). Key order is deliberately left
    # unchanged because callers may depend on it.
    if dict_type == 'ROI':
        dict_synonyms = {
            'whole-body': ['whole-body', 'whole body', 'wholebody', 'whole body', 'whole-body', 'whole body', 'wholebody','polytrauma','head-neck-thorax-abdomen-pelvis-leg','head-neck-thorax-abdomen-pelvis'],
            'neck-thorax-abdomen-pelvis-leg': ['neck-thorax-abdomen-pelvis-leg','neck-thx-abd-pelvis-leg', 'angiography neck-thx-abd-pelvis-leg', 'neck thorax abdomen pelvis leg', 'neck and thorax and abdomen and pelvis and leg', 'neck, thorax, abdomen, pelvis & leg', 'neck/thorax/abdomen/pelvis/leg', 'neck, thorax, abdomen, pelvis and leg', 'neck thorax abdomen pelvis leg'],
            'neck-thorax-abdomen-pelvis': ['neck-thorax-abdomen-pelvis', 'neck-thx-abd-pelvis', 'neck thorax abdomen pelvis', 'neck and thorax and abdomen and pelvis', 'neck, thorax, abdomen & pelvis', 'neck/thorax/abdomen/pelvis', 'neck, thorax, abdomen and pelvis', 'neck thorax abdomen & pelvis'],
            'thorax-abdomen-pelvis-leg': ['thorax-abdomen-pelvis-leg','thx-abd-pelvis-leg', 'angiography thx-abd-pelvis-leg', 'thorax abdomen pelvis leg', 'thorax and abdomen and pelvis and leg', 'thorax, abdomen, pelvis & leg', 'thorax/abdomen/pelvis/leg', 'thorax, abdomen, pelvis and leg', 'thorax abdomen pelvis leg'],
            'neck-thorax-abdomen': ['neck-thorax-abdomen', 'neck-thorax-abdomen', 'neck thorax abdomen', 'neck and thorax and abdomen', 'neck, thorax, abdomen', 'neck/thorax/abdomen', 'neck, thorax, abdomen', 'neck thorax abdomen'],
            'head-neck-thorax-abdomen': ['head-neck-thorax-abdomen', 'head-neck-thorax-abdomen', 'head neck thorax abdomen', 'head and neck and thorax and abdomen', 'head, neck, thorax, abdomen', 'head/thorax/abdomen', 'head, thorax, abdomen', 'head thorax abdomen'],
            'head-neck-thorax': ['head-neck-thorax', 'head neck thorax', 'head and neck and thorax', 'head, neck, thorax', 'head/thorax', 'head, thorax', 'head thorax'],
            'thorax-abdomen-pelvis': ['thorax-abdomen-pelvis', 'thx-abd-pelvis', 'polytrauma', 'thorax abdomen pelvis', 'thorax and abdomen and pelvis', 'thorax, abdomen & pelvis', 'thorax/abdomen/pelvis', 'thorax, abdomen and pelvis', 'thorax abdomen & pelvis'],
            'abdomen-pelvis-leg': ['abdomen-pelvis-leg', 'angiography abdomen-pelvis-leg', 'abd-pelvis-leg', 'abdomen pelvis leg', 'abdomen and pelvis and leg', 'abdomen, pelvis & leg', 'abdomen/pelvis/leg', 'abdomen, pelvis, leg', 'abdomen pelvis leg'],
            'neck-thorax': ['neck-thorax', 'neck thorax', 'neck and thorax', 'neck, thorax', 'thorax-neck', 'thorax neck', 'thorax and neck', 'thorax, neck','thorax/neck'],
            'thorax-abdomen': ['thorax-abdomen', 'thorax abdomen', 'thorax and abdomen', 'thorax, abdomen', 'aortic valve'],
            'abdomen-pelvis': ['abdomen-pelvis', 'abdomen pelvis', 'abdomen and pelvis', 'abdomen & pelvis', 'abdomen/pelvis', 'abdomen-pelvis', 'abdomen pelvis', 'abdomen and pelvis', 'abdomen & pelvis', 'abdomen/pelvis'],
            'pelvis-leg': ['pelvis-leg', 'pelvis leg', 'pelvis and leg', 'pelvis, leg', 'pelvis/leg', 'pelvis-leg', 'pelvis leg', 'pelvis and leg', 'pelvis, leg', 'pelvis/leg'],
            'head-neck': ['head-neck', 'head neck', 'head and neck', 'head, neck', 'head/neck', 'head-neck', 'head neck', 'head and neck', 'head, neck', 'head/neck'],
            'abdomen': ['abdomen', 'abdominal', 'belly', 'stomach', 'tummy', 'gut', 'guts', 'viscera', 'bowels', 'intestines', 'gastrointestinal', 'digestive', 'peritoneum','gastric', 'liver', 'spleen', 'pancreas','kidney','lumbar','renal','hepatic','splenic','pancreatic','intervention'],
            'thorax': ['chest', 'thorax', 'breast', 'lung', 'heart','heart-thorakale aorta', 'heart-thorakale', 'mediastinum', 'pleura', 'bronchus', 'bronchi', 'trachea', 'esophagus', 'diaphragm', 'rib', 'sternum', 'clavicle', 'scapula', 'axilla', 'armpit','breast biopsy','thoracic','mammary','caeiothoracic','mediastinal','pleural','bronchial','bronchial tree','tracheal','esophageal','diaphragmatic','costal','sternal','clavicular','scapular','axillary','axillar','cardiac','pericardial','pericardiac','pericardium'],
            'head': ['head', 'headbasis', 'brain', 'skull', 'face','nose','ear','eye','mouth','jaw','cheek','chin','forehead','temporal','parietal','occipital','frontal','mandible','maxilla','mandibular','maxillary','nasal','orbital','orbita','ocular','auricular','otic','oral','buccal','labial','lingual','palatal'],
            'neck': ['neck', 'throat', 'cervical', 'thyroid', 'trachea', 'larynx', 'pharynx', 'esophagus','pharyngeal','laryngeal','cervical','thyroid','trachea','esophagus','carotid','jugular'],
            'hand': ['hand', 'finger', 'thumb', 'palm', 'wrist', 'knuckle', 'fingernail', 'phalanx', 'metacarpal', 'carpal', 'radius'],
            # 'armpit' and 'clavicle' were previously fused into 'armpitclavicle'
            # by a missing comma (implicit string concatenation).
            'arm': ['arm', 'forearm', 'upper arm', 'bicep', 'tricep', 'brachium', 'brachial', 'humerus', 'radius', 'ulna', 'elbow', 'shoulder', 'armpit', 'clavicle', 'scapula', 'acromion', 'acromioclavicular'],
            'leg': ['leg', 'felsenleg','thigh', 'calf', 'shin', 'knee', 'foot', 'ankle', 'toe', 'heel', 'sole', 'arch', 'instep', 'metatarsal', 'phalanx', 'tibia', 'fibula', 'femur', 'patella', 'kneecap','achilles tendon','achilles'],
            'pelvis': ['pelvis', 'hip', 'groin', 'buttock', 'gluteus', 'gluteal', 'ischium', 'pubis', 'sacrum', 'coccyx', 'acetabulum', 'iliac', 'iliac crest', 'iliac spine', 'iliac wing', 'sacroiliac', 'sacroiliac joint', 'sacroiliac ligament', 'sacroiliac spine', 'ureter', 'bladder', 'urethra', 'prostate', 'testicle', 'ovary', 'uterus',],
            'skeleton': ['skeleton','bone','spine', 'back', 'vertebra', 'sacrum', 'coccyx'],
        }
    elif dict_type == 'Label_tissue':
        dict_synonyms = {
            'liver': ['liver','hepatic'],
            'spleen': ['spleen','splenic'],
            'kidney': ['kidney','renal'],
            'pancreas': ['pancreas','pancreatic'],
            'stomach': ['stomach','gastric'],
            'intestine': ['large intestine', 'small intestine','large bowel','small bowel'],
            'gallbladder': ['gallbladder'],
            'adrenal_gland': ['adrenal_gland','adrenal gland'],
            'bladder': ['bladder'],
            'prostate': ['prostate'],
            'uterus': ['uterus'],
            'ovary': ['ovary'],
            'testicle': ['testicle'],
            'lymph_node': ['lymph_node','lymph node'],
            'bone': ['bone'],
            'lung': ['lung'],
            'heart': ['heart'],
            'esophagus': ['esophagus'],
            'muscle': ['muscle'],
            'fat': ['fat'],
            'skin': ['skin'],
            'vessel': ['vessel'],
            'tumor': ['tumor'],
            'other': ['other']
        }
    elif dict_type == 'Task':
        dict_synonyms = {
            'segmentation': ['segmentation', 'seg', 'mask'],
            'classification': ['classification', 'class', 'diagnosis','identify','identification'],
            'localization': ['localization', 'locate', 'location', 'position'],
            'registration': ['registration', 'register', 'align', 'alignment'],
            'detection': ['detection', 'detect', 'find', 'locate'],
            'quantification': ['quantification', 'quantify', 'measure', 'measurement'],
        }
    elif dict_type == 'Modality':
        dict_synonyms = {
            'CT': ['CT', 'computed tomography'],
            'MRI': ['MRI', 'MR', 'magnetic resonance imaging'],
            'PET': ['PET', 'positron emission tomography'],
            'US': ['US', 'ultrasound'],
            'X-ray': ['X-ray', 'radiography'],
            # Spelling fixed (was 'tomlogy', which could never match real text).
            'SPECT': ['SPECT', 'single-photon emission computed tomography'],
        }
    else:
        # Generic metadata table: terms are matched together with their
        # surrounding single quotes.
        dict_synonyms = {
            '\'gender\'': ['\'gender\'', '\'sex\'', '\'M/F\'', '\'m/f\''],
            '\'modality\'': ['\'modality\'', '\'modal\''],
            '\'male\'': ['\'male\'', '\'m\''],
            '\'female\'': ['\'female\'', '\'f\'','\'woman\''],
            '\'high-grade glioma\'': ['\'high-grade glioma\'', '\'high grade glioma\'', '\'HGG\''],
            '\'low-grade glioma\'': ['\'low-grade glioma\'', '\'low grade glioma\'', '\'LGG\''],
            '\'atlas scaling factor\'': ['\'atlas scaling factor\'', '\'asf\''],
            '\'age\'': ['\'age\'', '\'years\'', '\'year\'', '\'y/o\'', '\'y.o.\''],
            '\'education\'': ['\'educ\'', '\'educat\'', '\'education\''],
            '\'roi\'': ['\'roi\'', '\'region of interest\'', '\'region\''],
            '\'mini-mental state examination\'': ['\'mini-mental state examination\'', '\'mmse\''],
            '\'clinical dementia rating\'': ['\'clinical dementia rating\'', '\'cdr\''],
            '\'socio-economic status\'': ['\'socio-economic status\'', '\'ses\''],
            '\'unknown\'': ['\'unknown\'', '\'unkn\'', '\'not available\'', '\'nan\'', '\'n/a\'', '\'none\'', '\'n.a.\'', '\'not applicable\'','\'not specified\'', '\'unspecified\'', '\'not given\'', '\'null\''],
            # Mapping to '' deletes these tokens when used with replace_text.
            '': [' segmentation', '\'seg\'', '\'registration\''],
        }
    return dict_synonyms

def replace_text(text, dict_synonyms):
    '''
    Rewrite every synonym occurring in ``text`` with its standard term.

    Synonyms are replaced as substrings, walking ``dict_synonyms`` in
    insertion order. The final replacement uses ``str.replace`` and is
    therefore case-sensitive: a synonym that only matches case-insensitively
    leaves the text unchanged.

    Args:
        text: The value to rewrite.
            * str: returns the rewritten string.
            * list: returns a new list with each element rewritten recursively.
            * dict: mutated in place and returned. Values are rewritten
              recursively, and keys are rewritten with the same rules
              (non-str keys are kept). If two keys rewrite to the same
              string, the later one wins.
            * any other type: returned unchanged.
        dict_synonyms: Mapping ``standard term -> list of synonyms``.

    Returns:
        The rewritten value (the same dict object for dict input).
    '''
    if isinstance(text, str):
        for key, value in dict_synonyms.items():
            for v in value:
                if v.lower() in text.lower():
                    text = text.replace(v, key)
        return text
    if isinstance(text, list):
        return [replace_text(t, dict_synonyms) for t in text]
    if isinstance(text, dict):
        # Build the result first. Renaming keys while iterating the same dict
        # raises RuntimeError, and the old code used a list (unhashable) as
        # the new key.
        rewritten = {
            replace_text(key, dict_synonyms): replace_text(value, dict_synonyms)
            for key, value in text.items()
        }
        # Update in place so callers holding a reference see the change.
        text.clear()
        text.update(rewritten)
    return text


def replace_synonyms(text, dict_synonyms):
    '''
    Map ``text`` to the standard term of the first matching synonym.

    A synonym matches when it occurs, case-insensitively, as a substring of
    the text. The first standard term (in ``dict_synonyms`` insertion order)
    with a matching synonym wins.

    Args:
        text: The value to normalise.
            * str: returns the standard term. If nothing matches, a
              ``UserWarning`` is emitted and the string is returned unchanged.
            * list: returns a new list with each element mapped recursively.
            * dict: mutated in place and returned. Values are mapped
              recursively, and each str key is renamed to its standard term
              when one matches (unmatched keys are kept, without a warning).
              If several keys map to the same term, the later one wins.
            * any other type: returned unchanged.
        dict_synonyms: Mapping ``standard term -> list of synonyms``.

    Returns:
        The mapped value (the same dict object for dict input).
    '''
    import warnings

    not_found = object()

    def _lookup(s):
        # Return the first standard term with a synonym contained in s.
        lowered = s.lower()
        for key, value in dict_synonyms.items():
            for v in value:
                if v.lower() in lowered:
                    return key
        return not_found

    if isinstance(text, str):
        match = _lookup(text)
        if match is not not_found:
            return match
        # The old code built a Warning object without emitting it.
        warnings.warn(f"Value {text} is not in the correct format", stacklevel=2)
        return text
    if isinstance(text, list):
        return [replace_synonyms(t, dict_synonyms) for t in text]
    if isinstance(text, dict):
        # Build the result first. The old loop popped keys during iteration
        # and re-inserted them under an (unhashable) list, so it always crashed.
        renamed = {}
        for key, value in text.items():
            new_key = key
            if isinstance(key, str):
                match = _lookup(key)
                if match is not not_found:
                    new_key = match
            renamed[new_key] = replace_synonyms(value, dict_synonyms)
        text.clear()
        text.update(renamed)
    return text

if __name__ == "__main__":
    # Demo: embed two short captions and print how far apart they are.
    #
    # Checkpoint to load: a HuggingFace hub id (alternatives that have been
    # tried: "bert-base-uncased", "bert-large-uncased", "Rostlab/prot_bert",
    # "fspanda/Medical-Bio-BERT2", "GerMedBERT/medbert-512") or a local
    # directory. The default is a machine-specific local copy; set the
    # STR2EMB_MODEL environment variable to use something else.
    model_name = os.environ.get(
        "STR2EMB_MODEL",
        "/home/jachin/data/Github/OmniMorph/External/Models/bert_large_uncased",
    )

    reduce_method = 'mean'
    # Number of tokens (not words) each caption is padded/truncated to.
    max_words_num = 32

    embeder, tokenizer = get_frozen_embeder(model_name)

    string1 = "modality: ct, gender: female, age: 51, roi: abdomen"
    string2 = "modality: ct, gender: female, age: 50, roi: head"

    # The embedder is frozen and only used for inference here, so no
    # autograd graph is needed.
    with torch.no_grad():
        embeder_output1 = str2emb(string1, max_words_num, embeder, tokenizer,
                                  reduce_method=reduce_method)
        embeder_output2 = str2emb(string2, max_words_num, embeder, tokenizer,
                                  reduce_method=reduce_method)

    vocab_size = embeder.config.vocab_size
    hidden_size = embeder.config.hidden_size

    print(embeder, vocab_size, hidden_size)
    print(tokenizer)

    # 'mean'/'max' pooling removes the token axis, so each output is expected
    # to be (1, hidden_size). Other reduce_method values keep
    # (1, max_words_num, hidden_size).
    print(embeder_output1)
    print(embeder_output1.shape)
    print(embeder_output2)
    print(embeder_output2.shape)

    # Element-wise absolute difference; its mean serves as a crude distance.
    error = torch.abs(embeder_output1 - embeder_output2)
    print(error)
    print("Embedding distance between the two sentences: ")
    print(f"String1: {string1}")
    print(f"String2: {string2}")
    print(torch.mean(error))