"""Contrastive-learning data pipeline for clinical-concept MLM training.

Builds batches in which an anchor sentence is duplicated once per synonym of
a chosen entity (entity span replaced by [MASK]), and the remaining sentences
in the batch act as negatives with a single masked entity each.  Padded
tensors plus per-sample tags/codes/scores are produced by ``collate_func``.
"""

import copy  # kept for backward compatibility (was used for token-list copies)
import random
from ast import literal_eval
from functools import partial

import lightning.pytorch as pl
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer

import config


def get_code_by_entity(entity, dictionary):
    """Return the concept code whose synonym list contains ``entity``.

    Args:
        entity: Surface form to look up.
        dictionary: Mapping ``code -> list of synonym strings``.

    Returns:
        The matching code; when several codes contain the entity, the one
        with the longest synonym list wins.  ``None`` when nothing matches.
    """
    matches = {
        code: len(synonyms)
        for code, synonyms in dictionary.items()
        if entity in synonyms
    }
    if not matches:
        return None
    return max(matches, key=matches.get)


def num_ancestors(df, code):
    """Return the number of ancestors recorded for ``code``.

    Args:
        df: Concept table with ``concept`` and ``ancestors`` (list) columns.
        code: Concept code to look up.

    Raises:
        IndexError: if ``code`` is absent from ``df["concept"]``.
    """
    return len(df.loc[df["concept"] == code, "ancestors"].values[0])


def get_score(df, code1, code2):
    """Look up the precomputed score for an unordered pair of codes.

    Args:
        df: Pair table with ``Code1``, ``Code2`` and ``score`` columns.

    Returns:
        The score of the first matching row, or ``None`` when the pair is
        not present in either order.
    """
    hit = df[
        ((df["Code1"] == code1) & (df["Code2"] == code2))
        | ((df["Code1"] == code2) & (df["Code2"] == code1))
    ]
    if hit.empty:
        return None
    return hit.iloc[0]["score"]


def _first_entity_span(tokenizer, tokens, ent_token, entity_lower):
    """Return the start index of the first verbatim occurrence of the entity's
    token span inside ``tokens``, or ``None`` when the entity never appears."""
    num_ent = len(ent_token)
    for start, tok in enumerate(tokens):
        if tok != ent_token[0]:
            continue
        if (
            tokenizer.convert_tokens_to_string(tokens[start : start + num_ent])
            == entity_lower
        ):
            return start
    return None


def _labels_with_ignore_index(tokenizer, tokens, keep_indices):
    """Convert ``tokens`` to a label tensor, setting every position NOT in
    ``keep_indices`` to -100 so the MLM loss ignores it."""
    label_ids = torch.tensor(tokenizer.convert_tokens_to_ids(tokens))
    keep = set(keep_indices)
    for pos in range(len(label_ids)):
        if pos not in keep:
            label_ids[pos] = -100
    return label_ids


def _mask_anchor(tokenizer, dictionary, unique_d, text, entities):
    """Build the anchor group: one masked copy of ``text`` per synonym of the
    chosen entity's code.  Returns ``None`` when the sample is unusable."""
    entity = random.choice(entities)
    code = get_code_by_entity(entity, dictionary)
    # Reject unknown entities and codes outside the pair table up front
    # (the original checked unique_d only after doing all the work).
    if code is None or code not in unique_d:
        return None
    synonyms = dictionary[code]

    text_token = tokenizer.tokenize(text)
    entity_lower = entity.lower()
    ent_token = tokenizer.tokenize(entity_lower)
    num_ent = len(ent_token)

    # The entity span in the text does not depend on the synonym, so find it
    # once instead of re-scanning per synonym copy.
    start = _first_entity_span(tokenizer, text_token, ent_token, entity_lower)
    if start is None:
        return None  # entity never appears verbatim in the text

    input_tokens = []  # per-synonym token lists with [MASK] over the entity
    label_tokens = []  # per-synonym token lists with the synonym spliced in
    masked_indices = []  # per-synonym positions the model must predict
    for synonym in synonyms:
        syn_token = tokenizer.tokenize(synonym)
        labels = list(text_token)
        labels[start : start + num_ent] = syn_token
        inputs = list(text_token)
        inputs[start : start + num_ent] = ["[MASK]"] * len(syn_token)
        input_tokens.append(inputs)
        label_tokens.append(labels)
        masked_indices.append(list(range(start, start + len(syn_token))))

    # Need at least the anchor plus one positive copy.
    if len(input_tokens) <= 1:
        return None

    input_ids_lst, attention_mask_lst, mlm_labels_lst = [], [], []
    for tokens, labels, indices in zip(input_tokens, label_tokens, masked_indices):
        ids = torch.tensor(tokenizer.convert_tokens_to_ids(tokens))
        input_ids_lst.append(ids)
        attention_mask_lst.append(torch.ones_like(ids))
        mlm_labels_lst.append(_labels_with_ignore_index(tokenizer, labels, indices))

    tags = [1] * len(input_ids_lst)
    tags[0] = 0  # first copy is the anchor itself; the rest are positives
    return {
        "input_ids": input_ids_lst,
        "attention_mask": attention_mask_lst,
        "mlm_labels": mlm_labels_lst,
        "masked_indices": masked_indices,
        "tags": tags,
        "codes": [code] * len(input_ids_lst),
    }


def _mask_negative(tokenizer, dictionary, unique_d, text, entities):
    """Build a single negative sample: ``text`` with the chosen entity replaced
    by same-length [MASK] tokens.  Returns ``None`` when unusable."""
    entity = random.choice(entities)
    code = get_code_by_entity(entity, dictionary)
    if code is None or code not in unique_d:
        return None

    tokens = tokenizer.tokenize(text)
    entity_lower = entity.lower()
    ent_token = tokenizer.tokenize(entity_lower)
    num_ent = len(ent_token)

    start = _first_entity_span(tokenizer, tokens, ent_token, entity_lower)
    if start is None:
        return None

    label_tokens = list(tokens)  # labels keep the original entity tokens
    tokens[start : start + num_ent] = ["[MASK]"] * num_ent
    masked_indices = list(range(start, start + num_ent))

    ids = torch.tensor(tokenizer.convert_tokens_to_ids(tokens))
    return {
        "input_ids": [ids],
        "attention_mask": [torch.ones_like(ids)],
        "mlm_labels": [
            _labels_with_ignore_index(tokenizer, label_tokens, masked_indices)
        ],
        # NOTE: flat list of positions here, unlike the anchor case where
        # masked_indices is a list of per-copy lists.
        "masked_indices": masked_indices,
        "tags": [2],
        "codes": [code],
    }


def mask(tokenizer, dictionary, unique_d, text, entities, anchor=True):
    """Mask the first occurrence of a randomly chosen entity in ``text``.

    With ``anchor=True`` the sentence is duplicated once per synonym of the
    entity's code; the entity span is replaced by [MASK] tokens sized to the
    synonym, and labels carry the synonym tokens (tag 0 = anchor copy,
    tag 1 = positives).  With ``anchor=False`` a single masked copy is built
    (tag 2) whose labels carry the original entity tokens.

    Args:
        tokenizer: HuggingFace tokenizer.
        dictionary: Mapping ``code -> synonyms``.
        unique_d: Codes present in the pair table; samples whose code is not
            in it are rejected.
        text: Sentence to mask.
        entities: Candidate entity strings; one is chosen at random.
        anchor: Whether to build the anchor group or a single negative.

    Returns:
        Dict with keys ``input_ids``, ``attention_mask``, ``mlm_labels``,
        ``masked_indices``, ``tags``, ``codes`` — or ``None`` when the sample
        is unusable (unknown entity, code not in pair table, no verbatim
        match in the text, or fewer than two synonym copies for an anchor).
    """
    # Original code tested `anchor is True` / `anchor is False` and hit a
    # NameError for any truthy non-bool; plain truthiness avoids that.
    if anchor:
        return _mask_anchor(tokenizer, dictionary, unique_d, text, entities)
    return _mask_negative(tokenizer, dictionary, unique_d, text, entities)


class CLDataset(Dataset):
    """Thin ``Dataset`` exposing ``[sentence, concepts]`` pairs from a
    dataframe with ``sentences`` and ``concepts`` columns."""

    def __init__(
        self,
        data: pd.DataFrame,
    ):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        data_row = self.data.iloc[index]
        return [data_row.sentences, data_row.concepts]


def collate_func(batch, tokenizer, dictionary, all_d, pairs):
    """Collate a batch of ``[sentence, concepts]`` samples.

    The first usable sample becomes the anchor group (anchor + synonym
    positives); every remaining sample contributes one negative.  Sequences
    are padded to a common length.

    Returns:
        Dict with padded ``input_ids``/``attention_mask`` (pad 0) and
        ``mlm_labels`` (pad -100), all shaped (num_samples, max_len), plus
        per-sample ``masked_indices``, ``tags``, ``codes`` and ``scores``
        (one ancestor-count score per anchor-group sample, one pair score
        per negative).

    Raises:
        ValueError: when no sample in the batch can serve as an anchor.
    """
    unique_d = pairs["Code1"].unique()

    # Find the first sample that yields a usable anchor group.
    anchor_masked = None
    while anchor_masked is None:
        if not batch:
            raise ValueError("no sample in the batch could serve as an anchor")
        anchor, batch = batch[0], batch[1:]
        anchor_masked = mask(tokenizer, dictionary, unique_d, anchor[0], anchor[1])

    # BUGFIX: the original extended masked_indices/tags/codes with the FULL
    # list on every loop iteration, duplicating them n times and breaking
    # alignment with the padded rows.  Extend each list exactly once.
    input_ids_lst = list(anchor_masked["input_ids"])
    attention_mask_lst = list(anchor_masked["attention_mask"])
    mlm_labels_lst = list(anchor_masked["mlm_labels"])
    masked_indices_lst = list(anchor_masked["masked_indices"])
    tags_lst = list(anchor_masked["tags"])
    codes_lst = list(anchor_masked["codes"])

    ap_code = anchor_masked["codes"][0]
    ap_score = num_ancestors(all_d, ap_code)
    scores_lst = [ap_score] * len(tags_lst)

    for neg in batch:
        neg_masked = mask(tokenizer, dictionary, unique_d, neg[0], neg[1], False)
        if neg_masked is None:
            continue
        input_ids_lst.extend(neg_masked["input_ids"])
        attention_mask_lst.extend(neg_masked["attention_mask"])
        mlm_labels_lst.extend(neg_masked["mlm_labels"])
        # Negative masked_indices is a flat list; wrap it so there is one
        # entry per padded row.
        masked_indices_lst.append(neg_masked["masked_indices"])
        tags_lst.extend(neg_masked["tags"])
        codes_lst.extend(neg_masked["codes"])

        n_code = neg_masked["codes"][0]
        if n_code == ap_code:
            an_score = num_ancestors(all_d, n_code)
        else:
            an_score = get_score(pairs, ap_code, n_code)
        scores_lst.append(an_score)

    padded_input_ids = pad_sequence(input_ids_lst, batch_first=True, padding_value=0)
    padded_attention_mask = pad_sequence(
        attention_mask_lst, batch_first=True, padding_value=0
    )
    padded_mlm_labels = pad_sequence(
        mlm_labels_lst, batch_first=True, padding_value=-100
    )
    return {
        "input_ids": padded_input_ids,
        "attention_mask": padded_attention_mask,
        "mlm_labels": padded_mlm_labels,
        "masked_indices": masked_indices_lst,
        "tags": tags_lst,
        "codes": codes_lst,
        "scores": scores_lst,
    }


def create_dataloader(dataset, tokenizer, dictionary, all_d, pairs, shuffle):
    """Wrap ``dataset`` in a ``DataLoader`` using ``collate_func``.

    ``functools.partial`` (not a lambda) is used so the collate function is
    picklable when workers are started with the spawn method.
    """
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        num_workers=1,
        collate_fn=partial(
            collate_func,
            tokenizer=tokenizer,
            dictionary=dictionary,
            all_d=all_d,
            pairs=pairs,
        ),
    )


class CLDataModule(pl.LightningDataModule):
    """LightningDataModule serving train/val ``CLDataset`` loaders built from
    pre-split dataframes and the shared concept/pair lookup tables."""

    def __init__(self, train_df, val_df, tokenizer, dictionary, all_d, pairs):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.tokenizer = tokenizer
        self.dictionary = dictionary
        self.all_d = all_d
        self.pairs = pairs

    def setup(self, stage=None):
        """Materialize the train/val datasets (cheap; dataframes are pre-split)."""
        self.train_dataset = CLDataset(self.train_df)
        self.val_dataset = CLDataset(self.val_df)

    def train_dataloader(self):
        return create_dataloader(
            self.train_dataset,
            self.tokenizer,
            self.dictionary,
            self.all_d,
            self.pairs,
            shuffle=True,
        )

    def val_dataloader(self):
        return create_dataloader(
            self.val_dataset,
            self.tokenizer,
            self.dictionary,
            self.all_d,
            self.pairs,
            shuffle=False,
        )


if __name__ == "__main__":
    # Smoke test: load the query sentences, concept table and pair scores,
    # then pull a single collated batch through the train dataloader.
    query_df = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_df.csv"
    )
    query_df["concepts"] = query_df["concepts"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(
        lambda x: [val for val in x if val is not None]
    )
    train_df, val_df = train_test_split(query_df, test_size=config.split_ratio)

    all_d = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_all_d.csv"
    )
    all_d.drop(columns=["finding_sites", "morphology"], inplace=True)
    all_d["synonyms"] = all_d["synonyms"].apply(literal_eval)
    all_d["ancestors"] = all_d["ancestors"].apply(literal_eval)
    dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))

    pairs = pd.read_csv("/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/pairs.csv")

    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    d = CLDataModule(train_df, val_df, tokenizer, dictionary, all_d, pairs)
    d.setup()
    train = d.train_dataloader()
    for batch in train:
        b = batch
        break