# CHOPT-NEW / dataset.py
# (Hugging Face upload header, kept as a comment so the module parses:
#  uploader "sxtforreal", "Upload 5 files", commit 975624b verified)
import torch
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import lightning.pytorch as pl
import config
import pandas as pd
import copy
from ast import literal_eval
from sklearn.model_selection import train_test_split
import random
def get_code_by_entity(entity, dictionary):
    """
    Return the code (key) whose synonym list contains `entity`.

    When several codes list the entity, the one with the longest synonym
    list wins; when none do, return None.
    """
    matches = {
        code: len(synonyms)
        for code, synonyms in dictionary.items()
        if entity in synonyms
    }
    if not matches:
        return None
    return max(matches, key=matches.get)
def num_ancestors(df, code):
    """Return how many ancestors `df` records for `code` (first matching row)."""
    ancestors = df.loc[df["concept"] == code, "ancestors"].values[0]
    return len(ancestors)
def get_score(df, code1, code2):
    """
    Look up the score for the unordered pair (code1, code2) in `df`.

    The pair table is symmetric, so both column orders are checked.
    Returns None when the pair is absent.
    """
    forward = (df["Code1"] == code1) & (df["Code2"] == code2)
    backward = (df["Code1"] == code2) & (df["Code2"] == code1)
    matches = df[forward | backward]
    if matches.empty:
        return None
    return matches.iloc[0]["score"]
def mask(tokenizer, dictionary, unique_d, text, entities, anchor=True):
    """
    Randomly pick one entity from `entities`, mask its first occurrence in
    `text`, and build model inputs.

    anchor=True : one copy of the tokenized text per synonym of the entity's
        code; copy i has its labels rewritten to synonym i and the input span
        replaced by [MASK] tokens. tags are [0, 1, 1, ...] (0 = anchor copy).
    anchor=False: a single masked copy whose labels keep the original entity
        tokens; tag is [2].

    Returns a dict {input_ids, attention_mask, mlm_labels, masked_indices,
    tags, codes} or None when the example is unusable (entity has no code,
    code is not in `unique_d`, entity not found in the text, or fewer than
    two maskable synonym copies for an anchor).
    """
    entity = random.choice(entities)
    code = get_code_by_entity(entity, dictionary)
    # Bail out early (before any tokenization) when the entity has no code
    # or the code never appears in the precomputed pair table.
    # Replaces the original bare `except:` around dictionary[code].
    if code is None or code not in unique_d:
        return None
    if anchor:
        return _mask_anchor(tokenizer, dictionary, text, entity, code)
    return _mask_negative(tokenizer, text, entity, code)


def _mask_anchor(tokenizer, dictionary, text, entity, code):
    """Build one masked duplicate of `text` per synonym of `code` (tags 0/1)."""
    synonyms = dictionary[code]
    text_token = tokenizer.tokenize(text)
    ent_token = tokenizer.tokenize(entity.lower())
    num_ent_token = len(ent_token)

    input_ids = [copy.deepcopy(text_token) for _ in range(len(synonyms))]
    mlm_labels = [copy.deepcopy(text_token) for _ in range(len(synonyms))]
    masked_indices = []
    for i, tokens in enumerate(mlm_labels):
        # Candidate positions where the entity could start.
        start_indices = [
            pos for pos, tok in enumerate(tokens) if tok == ent_token[0]
        ]
        masked_index = []
        for start in start_indices:
            span_text = tokenizer.convert_tokens_to_string(
                tokens[start : start + num_ent_token]
            )
            # Only the first exact occurrence is masked (guard: masked_index empty).
            if span_text == entity.lower() and not masked_index:
                syn = tokenizer.tokenize(synonyms[i])
                # Labels carry the synonym tokens; inputs carry [MASK]s of the
                # same length (spans may grow/shrink vs. the original entity).
                mlm_labels[i][start : start + num_ent_token] = syn
                input_ids[i][start : start + num_ent_token] = ["[MASK]"] * len(syn)
                masked_index.extend(range(start, start + len(syn)))
        masked_indices.append(masked_index)

    # Drop duplicates where the entity was never found; an anchor needs at
    # least two surviving copies (the anchor itself plus one positive).
    keep = [k for k, sub in enumerate(masked_indices) if sub]
    if len(keep) <= 1:
        return None
    input_ids = [input_ids[k] for k in keep]
    mlm_labels = [mlm_labels[k] for k in keep]
    masked_indices = [masked_indices[k] for k in keep]

    input_ids_lst = []
    attention_mask_lst = []
    mlm_labels_lst = []
    for j, tokens in enumerate(input_ids):
        ids = torch.tensor(tokenizer.convert_tokens_to_ids(tokens))
        input_ids_lst.append(ids)
        attention_mask_lst.append(torch.ones_like(ids))
        labels = torch.tensor(tokenizer.convert_tokens_to_ids(mlm_labels[j]))
        # Ignore (-100) every position the MLM loss should not score.
        for pos in range(len(labels)):
            if pos not in masked_indices[j]:
                labels[pos] = -100
        mlm_labels_lst.append(labels)

    tags = [1] * len(input_ids_lst)
    tags[0] = 0  # first copy is the anchor itself
    return {
        "input_ids": input_ids_lst,
        "attention_mask": attention_mask_lst,
        "mlm_labels": mlm_labels_lst,
        "masked_indices": masked_indices,
        "tags": tags,
        "codes": [code] * len(input_ids_lst),
    }


def _mask_negative(tokenizer, text, entity, code):
    """Build a single masked copy of `text` (tag 2); labels keep the entity."""
    input_ids = tokenizer.tokenize(text)
    mlm_labels = copy.deepcopy(input_ids)
    ent_token = tokenizer.tokenize(entity.lower())
    num_ent_token = len(ent_token)

    masked_indices = []
    start_indices = [
        pos for pos, tok in enumerate(mlm_labels) if tok == ent_token[0]
    ]
    for start in start_indices:
        span_text = tokenizer.convert_tokens_to_string(
            input_ids[start : start + num_ent_token]
        )
        # Mask only the first exact occurrence.
        if span_text == entity.lower() and not masked_indices:
            input_ids[start : start + num_ent_token] = ["[MASK]"] * num_ent_token
            masked_indices.extend(range(start, start + num_ent_token))
    if not masked_indices:
        return None

    ids = torch.tensor(tokenizer.convert_tokens_to_ids(input_ids))
    label_ids = tokenizer.convert_tokens_to_ids(mlm_labels)
    for pos in range(len(label_ids)):
        if pos not in masked_indices:
            label_ids[pos] = -100
    return {
        "input_ids": [ids],
        "attention_mask": [torch.ones_like(ids)],
        "mlm_labels": [torch.tensor(label_ids)],
        "masked_indices": masked_indices,
        "tags": [2],
        "codes": [code],
    }
class CLDataset(Dataset):
    """Thin Dataset wrapper over a DataFrame with `sentences` and `concepts` columns."""

    def __init__(self, data: pd.DataFrame):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data.iloc[index]
        return [row.sentences, row.concepts]
def collate_func(batch, tokenizer, dictionary, all_d, pairs):
    """
    Collate a batch of [sentence, concepts] pairs into padded tensors.

    The first maskable example becomes the anchor (synonym duplicates,
    tags 0/1); each remaining example is masked as a negative (tag 2).
    Unmaskable examples are silently skipped.

    Returns a dict with batch-first padded input_ids / attention_mask /
    mlm_labels plus per-sequence masked_indices, tags, codes and scores.
    Raises ValueError when no example in the batch can serve as the anchor
    (the original code would fail here with an opaque IndexError).
    """
    input_ids_lst = []
    attention_mask_lst = []
    mlm_labels_lst = []
    masked_indices_lst = []
    tags_lst = []
    codes_lst = []
    scores_lst = []
    unique_d = pairs["Code1"].unique()

    # Walk the batch until one example yields a usable anchor.
    anchor_masked = None
    while batch and anchor_masked is None:
        anchor = batch[0]
        anchor_masked = mask(tokenizer, dictionary, unique_d, anchor[0], anchor[1])
        if anchor_masked is None:
            batch = batch[1:]
    if anchor_masked is None:
        raise ValueError("collate_func: no example in the batch could be masked as an anchor")

    input_ids_lst.extend(anchor_masked["input_ids"])
    attention_mask_lst.extend(anchor_masked["attention_mask"])
    mlm_labels_lst.extend(anchor_masked["mlm_labels"])
    masked_indices_lst.extend(anchor_masked["masked_indices"])
    tags_lst.extend(anchor_masked["tags"])
    codes_lst.extend(anchor_masked["codes"])

    ap_code = anchor_masked["codes"][0]
    # Anchor/positive score: number of ancestors of the anchor code,
    # repeated once per anchor-group sequence.
    ap_score = num_ancestors(all_d, ap_code)
    scores_lst.extend([ap_score] * len(tags_lst))

    for neg in batch[1:]:
        neg_masked = mask(tokenizer, dictionary, unique_d, neg[0], neg[1], False)
        if neg_masked is None:
            continue  # unmaskable example: drop it from the batch
        # Negative masks always hold exactly one sequence; extend/append once
        # (the original extended mlm_labels inside a per-item loop, which was
        # only correct because that loop ran a single time).
        input_ids_lst.extend(neg_masked["input_ids"])
        attention_mask_lst.extend(neg_masked["attention_mask"])
        mlm_labels_lst.extend(neg_masked["mlm_labels"])
        masked_indices_lst.append(neg_masked["masked_indices"])
        tags_lst.extend(neg_masked["tags"])
        codes_lst.extend(neg_masked["codes"])

        n_code = neg_masked["codes"][0]
        if n_code == ap_code:
            an_score = num_ancestors(all_d, n_code)
        else:
            # May be None when the pair is absent from the score table.
            an_score = get_score(pairs, ap_code, n_code)
        scores_lst.append(an_score)

    # Pad directly in batch-first layout (equivalent to the original
    # pad-then-transpose, with one fewer tensor op).
    padded_input_ids = pad_sequence(input_ids_lst, batch_first=True, padding_value=0)
    padded_attention_mask = pad_sequence(attention_mask_lst, batch_first=True, padding_value=0)
    padded_mlm_labels = pad_sequence(mlm_labels_lst, batch_first=True, padding_value=-100)
    return {
        "input_ids": padded_input_ids,
        "attention_mask": padded_attention_mask,
        "mlm_labels": padded_mlm_labels,
        "masked_indices": masked_indices_lst,
        "tags": tags_lst,
        "codes": codes_lst,
        "scores": scores_lst,
    }
def create_dataloader(dataset, tokenizer, dictionary, all_d, pairs, shuffle):
    """
    Build a DataLoader over `dataset` whose collate_fn closes over the
    tokenizer/dictionary/score tables needed by `collate_func`.
    """
    # functools.partial instead of a lambda: a lambda collate_fn cannot be
    # pickled, which breaks num_workers > 0 under the spawn start method.
    from functools import partial

    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        num_workers=1,
        collate_fn=partial(
            collate_func,
            tokenizer=tokenizer,
            dictionary=dictionary,
            all_d=all_d,
            pairs=pairs,
        ),
    )
class CLDataModule(pl.LightningDataModule):
    """LightningDataModule wiring train/val DataFrames into CLDataset loaders."""

    def __init__(self, train_df, val_df, tokenizer, dictionary, all_d, pairs):
        super().__init__()
        self.train_df = train_df
        self.val_df = val_df
        self.tokenizer = tokenizer
        self.dictionary = dictionary
        self.all_d = all_d
        self.pairs = pairs

    def setup(self, stage=None):
        self.train_dataset = CLDataset(self.train_df)
        self.val_dataset = CLDataset(self.val_df)

    def _loader(self, dataset, shuffle):
        # Shared construction path for the train and val loaders.
        return create_dataloader(
            dataset,
            self.tokenizer,
            self.dictionary,
            self.all_d,
            self.pairs,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        return self._loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._loader(self.val_dataset, shuffle=False)
if __name__ == "__main__":
    # Smoke test: load the preprocessed CSVs, build the data module, and
    # pull a single training batch.
    query_df = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_df.csv"
    )
    # Columns are stored as stringified Python literals; rebuild real lists.
    query_df["concepts"] = query_df["concepts"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(literal_eval)
    query_df["codes"] = query_df["codes"].apply(
        lambda codes: [c for c in codes if c is not None]
    )
    train_df, val_df = train_test_split(query_df, test_size=config.split_ratio)

    all_d = pd.read_csv(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_all_d.csv"
    )
    all_d.drop(columns=["finding_sites", "morphology"], inplace=True)
    all_d["synonyms"] = all_d["synonyms"].apply(literal_eval)
    all_d["ancestors"] = all_d["ancestors"].apply(literal_eval)
    dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))

    # NOTE(review): this path lacks the /home/sunx prefix the other CSVs use —
    # confirm it is intentional.
    pairs = pd.read_csv("/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/pairs.csv")

    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    d = CLDataModule(train_df, val_df, tokenizer, dictionary, all_d, pairs)
    d.setup()
    train = d.train_dataloader()
    b = next(iter(train))