| import torch | |
| from transformers import AutoTokenizer | |
| from torch.utils.data import Dataset, DataLoader | |
| from torch.nn.utils.rnn import pad_sequence | |
| import lightning.pytorch as pl | |
| import config | |
| import pandas as pd | |
| import copy | |
| from ast import literal_eval | |
| from sklearn.model_selection import train_test_split | |
| import random | |
def get_code_by_entity(entity, dictionary):
    """
    Look up the code whose synonym list contains ``entity``.

    When several codes list the entity, the code with the longest synonym
    list is returned; on a tie the code that comes first in ``dictionary``
    wins. Returns None when no code lists the entity.
    """
    # Map every matching code to the size of its synonym list.
    sizes = {
        code: len(synonyms)
        for code, synonyms in dictionary.items()
        if entity in synonyms
    }
    if not sizes:
        return None
    return max(sizes, key=sizes.get)
def num_ancestors(df, code):
    """
    Return the length of the ``ancestors`` cell of the first row of ``df``
    whose ``concept`` column equals ``code``.

    Raises IndexError when no row matches.
    """
    is_match = df["concept"] == code
    ancestor_cells = df.loc[is_match, "ancestors"].values
    return len(ancestor_cells[0])
def get_score(df, code1, code2):
    """
    Return the ``score`` of the first row of ``df`` that pairs ``code1`` with
    ``code2`` in either column order (``Code1``/``Code2``).

    Returns None when no such row exists.
    """
    forward = (df["Code1"] == code1) & (df["Code2"] == code2)
    backward = (df["Code1"] == code2) & (df["Code2"] == code1)
    hits = df[forward | backward]
    return None if hits.empty else hits.iloc[0]["score"]
def _find_entity_start(tokenizer, tokens, ent_tokens, target):
    """Return the index of the first token span that decodes to ``target``, or None."""
    span = len(ent_tokens)
    for start, token in enumerate(tokens):
        # Cheap first-token filter before decoding the candidate span.
        if token != ent_tokens[0]:
            continue
        decoded = tokenizer.convert_tokens_to_string(tokens[start : start + span])
        if decoded == target:
            return start
    return None


def _encode_example(tokenizer, input_tokens, label_tokens, masked_positions):
    """Turn one token sequence into (input_ids, attention_mask, mlm_labels) tensors."""
    input_id = torch.tensor(tokenizer.convert_tokens_to_ids(input_tokens))
    keep = set(masked_positions)
    label_ids = tokenizer.convert_tokens_to_ids(label_tokens)
    # Every position outside the masked span is set to -100.
    mlm_label = torch.tensor(
        [tid if pos in keep else -100 for pos, tid in enumerate(label_ids)]
    )
    return input_id, torch.ones_like(input_id), mlm_label


def mask(tokenizer, dictionary, unique_d, text, entities, anchor=True):
    """
    Randomly select one entity from ``entities`` and mask its first occurrence
    in ``text``.

    With ``anchor`` truthy, one copy of the text is produced per synonym of the
    entity's code: the entity span is replaced by as many mask tokens as the
    synonym has tokens, and the synonym tokens become the MLM labels. The first
    copy is tagged 0 and the remaining copies 1. At least two usable copies are
    required.

    With ``anchor`` falsy, a single copy is produced in which the entity span is
    masked in place and labelled with the original tokens; it is tagged 2.

    Args:
        tokenizer: object providing ``tokenize``, ``convert_tokens_to_string``
            and ``convert_tokens_to_ids``. Its ``mask_token`` is used when
            available, otherwise ``"[MASK]"``.
        dictionary: mapping from code to a list of synonym strings.
        unique_d: container of codes that are allowed to be returned.
        text: the sentence to mask.
        entities: candidate entity strings found in ``text``.
        anchor: selects the anchor/positive mode (truthy) or the single-copy
            mode (falsy).

    Returns:
        A dictionary {input_ids, attention_mask, mlm_labels, masked_indices,
        tags, codes}, or None when no usable example can be built (no
        entities, unknown code, code not in ``unique_d``, entity not found in
        the tokenized text, or fewer than two synonym copies in anchor mode).
        ``masked_indices`` is a list of index lists in anchor mode and a flat
        index list otherwise.
    """
    if len(entities) == 0:
        return None
    entity = random.choice(entities)
    code = get_code_by_entity(entity, dictionary)

    synonyms = None
    if anchor:
        try:
            synonyms = dictionary[code]
        except KeyError:
            # The entity is not listed under any code.
            return None
    # Reject codes outside the allowed set before doing any tokenizer work.
    if code not in unique_d:
        return None

    target = entity.lower()
    ent_tokens = tokenizer.tokenize(target)
    if not ent_tokens:
        return None
    text_tokens = tokenizer.tokenize(text)
    start = _find_entity_start(tokenizer, text_tokens, ent_tokens, target)
    if start is None:
        return None
    span = len(ent_tokens)
    mask_token = getattr(tokenizer, "mask_token", None) or "[MASK]"

    # Each example is (input tokens, label tokens, masked positions).
    examples = []
    if anchor:
        for synonym in synonyms:
            syn_tokens = tokenizer.tokenize(synonym)
            if not syn_tokens:
                # A synonym with no tokens yields no masked position; drop it.
                continue
            inputs = list(text_tokens)
            labels = list(text_tokens)
            inputs[start : start + span] = [mask_token] * len(syn_tokens)
            labels[start : start + span] = syn_tokens
            positions = list(range(start, start + len(syn_tokens)))
            examples.append((inputs, labels, positions))
        if len(examples) <= 1:
            return None
        tags = [0] + [1] * (len(examples) - 1)
        masked_indices = [positions for _, _, positions in examples]
    else:
        inputs = list(text_tokens)
        inputs[start : start + span] = [mask_token] * span
        positions = list(range(start, start + span))
        examples.append((inputs, text_tokens, positions))
        tags = [2]
        masked_indices = positions

    input_ids_lst = []
    attention_mask_lst = []
    mlm_labels_lst = []
    for inputs, labels, positions in examples:
        input_id, attention, mlm_label = _encode_example(
            tokenizer, inputs, labels, positions
        )
        input_ids_lst.append(input_id)
        attention_mask_lst.append(attention)
        mlm_labels_lst.append(mlm_label)

    return {
        "input_ids": input_ids_lst,
        "attention_mask": attention_mask_lst,
        "mlm_labels": mlm_labels_lst,
        "masked_indices": masked_indices,
        "tags": tags,
        "codes": [code] * len(input_ids_lst),
    }
class CLDataset(Dataset):
    """Dataset view over a DataFrame with ``sentences`` and ``concepts`` columns."""

    def __init__(
        self,
        data: pd.DataFrame,
    ):
        self.data = data

    def __len__(self):
        """Return the number of rows in the wrapped frame."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``[sentence, concepts]`` for the row at position ``index``."""
        row = self.data.iloc[index]
        return [row.sentences, row.concepts]
def collate_func(batch, tokenizer, dictionary, all_d, pairs):
    """
    Build one contrastive batch from a list of ``[sentence, concepts]`` samples.

    The first sample for which ``mask`` succeeds in anchor mode provides the
    anchor and its positives; samples before it are discarded. Every later
    sample is masked in single-copy mode and added as a negative when ``mask``
    succeeds.

    Args:
        batch: list of ``[sentence, concepts]`` items.
        tokenizer: tokenizer forwarded to ``mask``; its ``pad_token_id`` is
            used to pad ``input_ids`` (0 when it is missing or None).
        dictionary: mapping from code to synonym list, forwarded to ``mask``.
        all_d: DataFrame with ``concept`` and ``ancestors`` columns, used by
            ``num_ancestors``.
        pairs: DataFrame with ``Code1``, ``Code2`` and ``score`` columns.

    Returns:
        A dictionary with batch-first padded tensors ``input_ids``,
        ``attention_mask`` and ``mlm_labels`` (labels padded with -100), plus
        the per-example lists ``masked_indices``, ``tags``, ``codes`` and
        ``scores``. A score is None when ``pairs`` has no row for the
        anchor/negative code pair.

    Raises:
        IndexError: when no sample in ``batch`` yields a valid anchor.
    """
    unique_d = pairs["Code1"].unique()

    # Find the first sample that can serve as anchor; the rest are negatives.
    anchor_masked = None
    negatives = []
    for position, sample in enumerate(batch):
        anchor_masked = mask(tokenizer, dictionary, unique_d, sample[0], sample[1])
        if anchor_masked is not None:
            negatives = batch[position + 1 :]
            break
    if anchor_masked is None:
        raise IndexError(
            "collate_func: no sample in the batch produced a valid anchor"
        )

    input_ids_lst = list(anchor_masked["input_ids"])
    attention_mask_lst = list(anchor_masked["attention_mask"])
    mlm_labels_lst = list(anchor_masked["mlm_labels"])
    masked_indices_lst = list(anchor_masked["masked_indices"])
    tags_lst = list(anchor_masked["tags"])
    codes_lst = list(anchor_masked["codes"])

    # The anchor and all of its positives share the anchor code's score.
    ap_code = anchor_masked["codes"][0]
    ap_score = num_ancestors(all_d, ap_code)
    scores_lst = [ap_score] * len(tags_lst)

    for neg in negatives:
        neg_masked = mask(tokenizer, dictionary, unique_d, neg[0], neg[1], False)
        if neg_masked is None:
            continue
        input_ids_lst.extend(neg_masked["input_ids"])
        attention_mask_lst.extend(neg_masked["attention_mask"])
        mlm_labels_lst.extend(neg_masked["mlm_labels"])
        # Single-copy mode returns one flat index list for its one example.
        masked_indices_lst.append(neg_masked["masked_indices"])
        tags_lst.extend(neg_masked["tags"])
        codes_lst.extend(neg_masked["codes"])
        n_code = neg_masked["codes"][0]
        if n_code == ap_code:
            # Same code as the anchor: its ancestor count is already known.
            an_score = ap_score
        else:
            an_score = get_score(pairs, ap_code, n_code)
        scores_lst.append(an_score)

    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is None:
        pad_id = 0
    padded_input_ids = pad_sequence(
        input_ids_lst, batch_first=True, padding_value=pad_id
    )
    padded_attention_mask = pad_sequence(
        attention_mask_lst, batch_first=True, padding_value=0
    )
    padded_mlm_labels = pad_sequence(
        mlm_labels_lst, batch_first=True, padding_value=-100
    )
    return {
        "input_ids": padded_input_ids,
        "attention_mask": padded_attention_mask,
        "mlm_labels": padded_mlm_labels,
        "masked_indices": masked_indices_lst,
        "tags": tags_lst,
        "codes": codes_lst,
        "scores": scores_lst,
    }
def create_dataloader(
    dataset, tokenizer, dictionary, all_d, pairs, shuffle, num_workers=1
):
    """
    Wrap ``dataset`` in a DataLoader that collates batches with ``collate_func``.

    Args:
        dataset: dataset yielding ``[sentence, concepts]`` items.
        tokenizer, dictionary, all_d, pairs: forwarded unchanged to
            ``collate_func`` for every batch.
        shuffle: whether the loader shuffles the dataset.
        num_workers: number of loader worker processes (default 1).

    Returns:
        A DataLoader with batch size ``config.batch_size``.
    """
    from functools import partial

    # A partial of a module-level function can be pickled, so the collate
    # callable survives being sent to worker processes that are spawned rather
    # than forked; a lambda cannot.
    collate = partial(
        collate_func,
        tokenizer=tokenizer,
        dictionary=dictionary,
        all_d=all_d,
        pairs=pairs,
    )
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate,
    )
class CLDataModule(pl.LightningDataModule):
    """Lightning data module exposing train and validation loaders."""

    def __init__(self, train_df, val_df, tokenizer, dictionary, all_d, pairs):
        super().__init__()
        self.train_df, self.val_df = train_df, val_df
        self.tokenizer = tokenizer
        self.dictionary = dictionary
        self.all_d, self.pairs = all_d, pairs

    def setup(self, stage=None):
        """Wrap both frames in ``CLDataset``; ``stage`` is not consulted."""
        self.train_dataset = CLDataset(self.train_df)
        self.val_dataset = CLDataset(self.val_df)

    def _loader_for(self, dataset, shuffle):
        """Build a loader for ``dataset`` with this module's shared resources."""
        return create_dataloader(
            dataset,
            self.tokenizer,
            self.dictionary,
            self.all_d,
            self.pairs,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Return the shuffled loader over the training dataset."""
        return self._loader_for(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled loader over the validation dataset."""
        return self._loader_for(self.val_dataset, shuffle=False)
if __name__ == "__main__":
    # Smoke run: build the data module from the CSV exports and pull one
    # training batch.

    # List-valued columns are stored as Python literals and parsed back here.
    query_df = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_df.csv"
    )
    query_df["concepts"] = query_df["concepts"].apply(literal_eval)
    # Parse the codes and drop the None entries in a single pass.
    query_df["codes"] = query_df["codes"].apply(
        lambda raw: [val for val in literal_eval(raw) if val is not None]
    )
    train_df, val_df = train_test_split(query_df, test_size=config.split_ratio)

    all_d = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_all_d.csv"
    ).drop(columns=["finding_sites", "morphology"])
    all_d["synonyms"] = all_d["synonyms"].apply(literal_eval)
    all_d["ancestors"] = all_d["ancestors"].apply(literal_eval)
    # code -> synonym list
    dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))

    pairs = pd.read_csv("/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/pairs.csv")
    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

    d = CLDataModule(train_df, val_df, tokenizer, dictionary, all_d, pairs)
    d.setup()
    train = d.train_dataloader()
    # Keep only the first batch.
    for batch in train:
        b = batch
        break