| import os |
| import numpy as np |
| import math |
| from torch.nn import Tanh, BatchNorm1d |
| from typing import Optional |
| import torch.nn as nn |
| import torch |
| from transformers import BertModel, BertForSequenceClassification |
| from transformers import BertTokenizer |
| from transformers import AutoTokenizer, AutoModel |
|
|
| from torch.utils.data import Dataset as Dataset_n |
| from torch.utils.data import DataLoader as DataLoader_n |
| from torch.utils.data import WeightedRandomSampler |
|
|
def _freeze_bert(
    bert_model: "BertModel", freeze_bert=True, freeze_layer_count=-1
):
    """Freeze parameters of a BERT-style model in place.

    Args:
        bert_model: Model exposing ``parameters()``, ``embeddings`` and
            ``encoder.layer`` (e.g. a HuggingFace ``BertModel``).
        freeze_bert: If True, every parameter of the model is frozen and
            ``freeze_layer_count`` is ignored. If False, the embeddings are
            always frozen and ``freeze_layer_count`` selects which encoder
            layers are frozen in addition.
        freeze_layer_count: Only used when ``freeze_bert`` is False.
            ``None``, ``-1`` and ``0`` freeze no encoder layer; ``k > 0``
            freezes the first ``k`` layers; ``k < -1`` freezes the last
            ``abs(k)`` layers.

    Returns:
        None. The model is modified in place.
    """
    if freeze_bert:
        for param in bert_model.parameters():
            param.requires_grad = False
        return None

    # Partial freezing: the embeddings are frozen unconditionally.
    for param in bert_model.embeddings.parameters():
        param.requires_grad = False

    # -1 is the "no encoder layers" sentinel. None (which get_frozen_embeder
    # passes) is treated the same way instead of failing on `None > 0`.
    if freeze_layer_count is None or freeze_layer_count in (-1, 0):
        return None

    layers = bert_model.encoder.layer
    if freeze_layer_count > 0:
        frozen_layers = layers[:freeze_layer_count]
    else:
        frozen_layers = layers[freeze_layer_count:]

    for layer in frozen_layers:
        for param in layer.parameters():
            param.requires_grad = False
    return None
|
|
def get_frozen_embeder(key_word="bert-large-uncased"):
    """Load a pretrained encoder with its tokenizer and freeze the encoder.

    Args:
        key_word: Model name or local path handed to ``from_pretrained``.

    Returns:
        Tuple ``(model, tokenizer)``. Every parameter of ``model`` has
        ``requires_grad`` set to False.
    """
    text_tokenizer = AutoTokenizer.from_pretrained(key_word, do_lower_case=False)
    frozen_model = AutoModel.from_pretrained(key_word)

    # freeze_bert=True freezes all parameters, so the layer count is unused.
    _freeze_bert(frozen_model, freeze_bert=True, freeze_layer_count=None)
    return frozen_model, text_tokenizer
|
|
|
|
def str2emb(string, max_words_num=100, embeder=None, tokenizer=None, reduce_method='mean'):
    """Embed a string with the given tokenizer and encoder.

    The string is lower-cased and tokenized with padding and truncation to
    ``max_words_num`` requested, then passed through ``embeder``.

    Args:
        string: Text to embed.
        max_words_num: ``max_length`` passed to the tokenizer.
        embeder: Callable taking the tokenizer output as keyword arguments and
            returning an object with a ``last_hidden_state`` tensor.
        tokenizer: Callable producing the keyword arguments for ``embeder``.
        reduce_method: ``'mean'`` or ``'max'`` reduces ``last_hidden_state``
            over dim 1 (presumably the token axis); any other value returns
            ``last_hidden_state`` unreduced.

    Returns:
        The (optionally reduced) hidden-state tensor.
    """
    encoded = tokenizer(
        string.lower(),
        return_tensors='pt',
        max_length=max_words_num,
        padding='max_length',
        truncation=True,
    )
    hidden = embeder(**encoded).last_hidden_state

    if reduce_method == 'mean':
        return torch.mean(hidden, dim=1)
    if reduce_method == 'max':
        # torch.max along a dim returns (values, indices); keep the values.
        return torch.max(hidden, dim=1)[0]
    return hidden
|
|
def get_synonyms_dict(dict_type=None):
    '''
    Get the dictionary of synonyms for the specified dictionary type.

    Each key is a standard term and each value is the list of synonyms that
    should be mapped to it.

    Args:
        dict_type: 'ROI', 'Label_tissue', 'Task' or 'Modality'. Any other
            value (including None) selects the dictionary of quoted metadata
            terms (e.g. "'sex'" -> "'gender'").

    Returns:
        dict mapping standard term -> list of synonyms. In the 'ROI'
        dictionary compound regions are listed before single regions;
        replace_synonyms returns the first matching key, so the order of
        the entries matters.
    '''
    if dict_type == 'ROI':
        dict_synonyms = {
            'whole-body': ['whole-body', 'whole body', 'wholebody', 'whole body', 'whole-body', 'whole body', 'wholebody','polytrauma','head-neck-thorax-abdomen-pelvis-leg','head-neck-thorax-abdomen-pelvis'],
            'neck-thorax-abdomen-pelvis-leg': ['neck-thorax-abdomen-pelvis-leg','neck-thx-abd-pelvis-leg', 'angiography neck-thx-abd-pelvis-leg', 'neck thorax abdomen pelvis leg', 'neck and thorax and abdomen and pelvis and leg', 'neck, thorax, abdomen, pelvis & leg', 'neck/thorax/abdomen/pelvis/leg', 'neck, thorax, abdomen, pelvis and leg', 'neck thorax abdomen pelvis leg'],
            'neck-thorax-abdomen-pelvis': ['neck-thorax-abdomen-pelvis', 'neck-thx-abd-pelvis', 'neck thorax abdomen pelvis', 'neck and thorax and abdomen and pelvis', 'neck, thorax, abdomen & pelvis', 'neck/thorax/abdomen/pelvis', 'neck, thorax, abdomen and pelvis', 'neck thorax abdomen & pelvis'],
            'thorax-abdomen-pelvis-leg': ['thorax-abdomen-pelvis-leg','thx-abd-pelvis-leg', 'angiography thx-abd-pelvis-leg', 'thorax abdomen pelvis leg', 'thorax and abdomen and pelvis and leg', 'thorax, abdomen, pelvis & leg', 'thorax/abdomen/pelvis/leg', 'thorax, abdomen, pelvis and leg', 'thorax abdomen pelvis leg'],
            'neck-thorax-abdomen': ['neck-thorax-abdomen', 'neck-thorax-abdomen', 'neck thorax abdomen', 'neck and thorax and abdomen', 'neck, thorax, abdomen', 'neck/thorax/abdomen', 'neck, thorax, abdomen', 'neck thorax abdomen'],
            'head-neck-thorax-abdomen': ['head-neck-thorax-abdomen', 'head-neck-thorax-abdomen', 'head neck thorax abdomen', 'head and neck and thorax and abdomen', 'head, neck, thorax, abdomen', 'head/thorax/abdomen', 'head, thorax, abdomen', 'head thorax abdomen'],
            'head-neck-thorax': ['head-neck-thorax', 'head neck thorax', 'head and neck and thorax', 'head, neck, thorax', 'head/thorax', 'head, thorax', 'head thorax'],
            'thorax-abdomen-pelvis': ['thorax-abdomen-pelvis', 'thx-abd-pelvis', 'polytrauma', 'thorax abdomen pelvis', 'thorax and abdomen and pelvis', 'thorax, abdomen & pelvis', 'thorax/abdomen/pelvis', 'thorax, abdomen and pelvis', 'thorax abdomen & pelvis'],
            'abdomen-pelvis-leg': ['abdomen-pelvis-leg', 'angiography abdomen-pelvis-leg', 'abd-pelvis-leg', 'abdomen pelvis leg', 'abdomen and pelvis and leg', 'abdomen, pelvis & leg', 'abdomen/pelvis/leg', 'abdomen, pelvis, leg', 'abdomen pelvis leg'],
            'neck-thorax': ['neck-thorax', 'neck thorax', 'neck and thorax', 'neck, thorax', 'thorax-neck', 'thorax neck', 'thorax and neck', 'thorax, neck','thorax/neck'],
            'thorax-abdomen': ['thorax-abdomen', 'thorax abdomen', 'thorax and abdomen', 'thorax, abdomen', 'aortic valve'],
            'abdomen-pelvis': ['abdomen-pelvis', 'abdomen pelvis', 'abdomen and pelvis', 'abdomen & pelvis', 'abdomen/pelvis', 'abdomen-pelvis', 'abdomen pelvis', 'abdomen and pelvis', 'abdomen & pelvis', 'abdomen/pelvis'],
            'pelvis-leg': ['pelvis-leg', 'pelvis leg', 'pelvis and leg', 'pelvis, leg', 'pelvis/leg', 'pelvis-leg', 'pelvis leg', 'pelvis and leg', 'pelvis, leg', 'pelvis/leg'],
            'head-neck': ['head-neck', 'head neck', 'head and neck', 'head, neck', 'head/neck', 'head-neck', 'head neck', 'head and neck', 'head, neck', 'head/neck'],
            'abdomen': ['abdomen', 'abdominal', 'belly', 'stomach', 'tummy', 'gut', 'guts', 'viscera', 'bowels', 'intestines', 'gastrointestinal', 'digestive', 'peritoneum','gastric', 'liver', 'spleen', 'pancreas','kidney','lumbar','renal','hepatic','splenic','pancreatic','intervention'],
            # NOTE(review): 'caeiothoracic' looks like a typo for 'cardiothoracic' - confirm before changing.
            'thorax': ['chest', 'thorax', 'breast', 'lung', 'heart','heart-thorakale aorta', 'heart-thorakale', 'mediastinum', 'pleura', 'bronchus', 'bronchi', 'trachea', 'esophagus', 'diaphragm', 'rib', 'sternum', 'clavicle', 'scapula', 'axilla', 'armpit','breast biopsy','thoracic','mammary','caeiothoracic','mediastinal','pleural','bronchial','bronchial tree','tracheal','esophageal','diaphragmatic','costal','sternal','clavicular','scapular','axillary','axillar','cardiac','pericardial','pericardiac','pericardium'],
            'head': ['head', 'headbasis', 'brain', 'skull', 'face','nose','ear','eye','mouth','jaw','cheek','chin','forehead','temporal','parietal','occipital','frontal','mandible','maxilla','mandibular','maxillary','nasal','orbital','orbita','ocular','auricular','otic','oral','buccal','labial','lingual','palatal'],
            'neck': ['neck', 'throat', 'cervical', 'thyroid', 'trachea', 'larynx', 'pharynx', 'esophagus','pharyngeal','laryngeal','cervical','thyroid','trachea','esophagus','carotid','jugular'],
            'hand': ['hand', 'finger', 'thumb', 'palm', 'wrist', 'knuckle', 'fingernail', 'phalanx', 'metacarpal', 'carpal', 'radius'],
            # 'armpit' and 'clavicle' are separate entries; a missing comma used to fuse them into 'armpitclavicle'.
            'arm': ['arm', 'forearm', 'upper arm', 'bicep', 'tricep', 'brachium', 'brachial', 'humerus', 'radius', 'ulna', 'elbow', 'shoulder', 'armpit', 'clavicle', 'scapula', 'acromion', 'acromioclavicular'],
            'leg': ['leg', 'felsenleg','thigh', 'calf', 'shin', 'knee', 'foot', 'ankle', 'toe', 'heel', 'sole', 'arch', 'instep', 'metatarsal', 'phalanx', 'tibia', 'fibula', 'femur', 'patella', 'kneecap','achilles tendon','achilles'],
            'pelvis': ['pelvis', 'hip', 'groin', 'buttock', 'gluteus', 'gluteal', 'ischium', 'pubis', 'sacrum', 'coccyx', 'acetabulum', 'iliac', 'iliac crest', 'iliac spine', 'iliac wing', 'sacroiliac', 'sacroiliac joint', 'sacroiliac ligament', 'sacroiliac spine', 'ureter', 'bladder', 'urethra', 'prostate', 'testicle', 'ovary', 'uterus',],
            'skeleton': ['skeleton','bone','spine', 'back', 'vertebra', 'sacrum', 'coccyx'],
        }
    elif dict_type == 'Label_tissue':
        dict_synonyms = {
            'liver': ['liver','hepatic'],
            'spleen': ['spleen','splenic'],
            'kidney': ['kidney','renal'],
            'pancreas': ['pancreas','pancreatic'],
            'stomach': ['stomach','gastric'],
            'intestine': ['large intestine', 'small intestine','large bowel','small bowel'],
            'gallbladder': ['gallbladder'],
            'adrenal_gland': ['adrenal_gland','adrenal gland'],
            'bladder': ['bladder'],
            'prostate': ['prostate'],
            'uterus': ['uterus'],
            'ovary': ['ovary'],
            'testicle': ['testicle'],
            'lymph_node': ['lymph_node','lymph node'],
            'bone': ['bone'],
            'lung': ['lung'],
            'heart': ['heart'],
            'esophagus': ['esophagus'],
            'muscle': ['muscle'],
            'fat': ['fat'],
            'skin': ['skin'],
            'vessel': ['vessel'],
            'tumor': ['tumor'],
            'other': ['other']
        }
    elif dict_type == 'Task':
        dict_synonyms = {
            'segmentation': ['segmentation', 'seg', 'mask'],
            'classification': ['classification', 'class', 'diagnosis','identify','identification'],
            'localization': ['localization', 'locate', 'location', 'position'],
            'registration': ['registration', 'register', 'align', 'alignment'],
            'detection': ['detection', 'detect', 'find', 'locate'],
            'quantification': ['quantification', 'quantify', 'measure', 'measurement'],
        }
    elif dict_type == 'Modality':
        dict_synonyms = {
            'CT': ['CT', 'computed tomography'],
            'MRI': ['MRI', 'MR', 'magnetic resonance imaging'],
            'PET': ['PET', 'positron emission tomography'],
            'US': ['US', 'ultrasound'],
            'X-ray': ['X-ray', 'radiography'],
            'SPECT': ['SPECT', 'single-photon emission computed tomography'],
        }
    else:
        # Quoted metadata terms: the surrounding single quotes are part of
        # each term, so only whole quoted tokens match.
        dict_synonyms = {
            '\'gender\'': ['\'gender\'', '\'sex\'', '\'M/F\'', '\'m/f\''],
            '\'modality\'': ['\'modality\'', '\'modal\''],
            '\'male\'': ['\'male\'', '\'m\''],
            '\'female\'': ['\'female\'', '\'f\'','\'woman\''],
            '\'high-grade glioma\'': ['\'high-grade glioma\'', '\'high grade glioma\'', '\'HGG\''],
            '\'low-grade glioma\'': ['\'low-grade glioma\'', '\'low grade glioma\'', '\'LGG\''],
            '\'atlas scaling factor\'': ['\'atlas scaling factor\'', '\'asf\''],
            '\'age\'': ['\'age\'', '\'years\'', '\'year\'', '\'y/o\'', '\'y.o.\''],
            '\'education\'': ['\'educ\'', '\'educat\'', '\'education\''],
            '\'roi\'': ['\'roi\'', '\'region of interest\'', '\'region\''],
            '\'mini-mental state examination\'': ['\'mini-mental state examination\'', '\'mmse\''],
            '\'clinical dementia rating\'': ['\'clinical dementia rating\'', '\'cdr\''],
            '\'socio-economic status\'': ['\'socio-economic status\'', '\'ses\''],
            '\'unknown\'': ['\'unknown\'', '\'unkn\'', '\'not available\'', '\'nan\'', '\'n/a\'', '\'none\'', '\'n.a.\'', '\'not applicable\'','\'not specified\'', '\'unspecified\'', '\'not given\'', '\'null\''],
            # Terms mapped to the empty string are removed from the text.
            '': [' segmentation', '\'seg\'', '\'registration\''],
        }
    return dict_synonyms
|
|
def replace_text(text, dict_synonyms):
    '''
    Replace every synonym occurring in the text with its standard term.

    Args:
        text: A string, a list, or a dict. Lists are processed element by
            element and a new list is returned. Dict values are processed in
            place and the same dict is returned; dict keys are left
            untouched. Any other type is returned unchanged.
        dict_synonyms: Mapping of standard term -> list of synonyms.

    Returns:
        The input with all synonyms replaced by their standard terms.
        Matching is case-insensitive and literal (synonyms are not treated
        as regular expressions). Replacements are applied sequentially in
        dictionary order.
    '''
    import re  # local import so this block does not depend on the file's import section

    if isinstance(text, str):
        for key, value in dict_synonyms.items():
            for v in value:
                # The match is case-insensitive, so the replacement must be
                # too; str.replace would miss e.g. "'Sex'" for "'sex'".
                # A callable replacement keeps backslashes in `key` from
                # being interpreted as regex escapes.
                text = re.sub(re.escape(v), lambda _match: key, text, flags=re.IGNORECASE)
        return text

    if isinstance(text, list):
        return [replace_text(t, dict_synonyms) for t in text]

    if isinstance(text, dict):
        # Iterate over a snapshot of the keys; only values are rewritten.
        for key in list(text):
            text[key] = replace_text(text[key], dict_synonyms)
    return text
|
|
|
|
def replace_synonyms(text, dict_synonyms):
    '''
    Map a value to the standard term of the first synonym it contains.

    Args:
        text: A string, a list, or a dict. Lists are processed element by
            element and a new list is returned. Dict values are processed in
            place and the same dict is returned; dict keys are left
            untouched. Any other type is returned unchanged.
        dict_synonyms: Mapping of standard term -> list of synonyms. Entries
            are tried in dictionary order and the first match wins.

    Returns:
        For a string: the standard term whose synonym occurs in the string
        (case-insensitive substring match), or the original string if
        nothing matches, in which case a warning is issued.
    '''
    import warnings  # local import so this block does not depend on the file's import section

    if isinstance(text, str):
        lowered = text.lower()
        for key, value in dict_synonyms.items():
            if any(v.lower() in lowered for v in value):
                return key
        # A bare `Warning(...)` expression only builds an exception object;
        # warnings.warn actually reports the unmatched value.
        warnings.warn(f"Value {text} is not in the correct format", stacklevel=2)
        return text

    if isinstance(text, list):
        return [replace_synonyms(t, dict_synonyms) for t in text]

    if isinstance(text, dict):
        # Iterate over a snapshot of the keys; only values are rewritten.
        for key in list(text):
            text[key] = replace_synonyms(text[key], dict_synonyms)
    return text
|
|
if __name__ == "__main__":
    # Smoke test: embed two metadata strings with the frozen encoder and
    # print the element-wise and mean absolute difference of the embeddings.
    model_name = "/home/jachin/data/Github/OmniMorph/External/Models/bert_large_uncased"
    reduce_method = 'mean'
    max_words_num = 32

    embeder, tokenizer = get_frozen_embeder(model_name)

    string1 = "modality: ct, gender: female, age: 51, roi: abdomen"
    string2 = "modality: ct, gender: female, age: 50, roi: head"

    # Both embeddings are computed before anything is printed.
    embeder_output1, embeder_output2 = (
        str2emb(sentence, max_words_num, embeder, tokenizer, reduce_method=reduce_method)
        for sentence in (string1, string2)
    )

    input_size = embeder.config.vocab_size
    in_size = embeder.config.hidden_size

    print(embeder, input_size, in_size)
    print(tokenizer)

    for embedding in (embeder_output1, embeder_output2):
        print(embedding)
        print(embedding.shape)

    error = torch.abs(embeder_output1 - embeder_output2)
    print(error)
    print("Embedding distance between the two sentences: ")
    print(f"String1: {string1}")
    print(f"String2: {string2}")
    print(torch.mean(error))
    exit()