sxtforreal committed
Commit 975624b · verified · 1 Parent(s): d82d662

Upload 5 files

Files changed (5)
  1. config.py +29 -0
  2. dataset.py +351 -0
  3. loss.py +58 -0
  4. model.py +305 -0
  5. train.py +81 -0
config.py ADDED
@@ -0,0 +1,29 @@
+ import torch
+
+ # Use CUDA if a GPU is available, otherwise fall back to the CPU
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Loss
+ t_p = 0.25  # positive temperature; optimal: 1/8 ~ 1/32
+ zeta = 3  # temperature ratio; optimal: 2 ~ 5
+ # m = 0.2
+
+ # Training hyperparameters
+ min_epochs = 3
+ max_epochs = 30
+ learning_rate = 5e-5
+ unfreeze_ratio = 1
+ mlm_weight = 0.5  # optimal: 0.5 ~ 0.75
+
+ # Dataset
+ batch_size = 100
+ split_ratio = 0.2
+
+ # Logger
+ log_every_n_steps = 50
+ ckpt_every_n_steps = 5000
+
+ # Compute
+ accelerator = "gpu"
+ devices = 1  # number of GPUs
+ precision = "16-mixed"
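For reference, how these values combine downstream (a quick sketch; `t_n` is derived in loss.py as `t_p * zeta`, and `mlm_weight` blends the two losses in model.py):

    import config

    t_n = config.t_p * config.zeta          # negative temperature used by PCCL: 0.25 * 3 = 0.75
    assert 0.0 <= config.mlm_weight <= 1.0  # total loss = (1 - w) * contrastive + w * MLM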
dataset.py ADDED
@@ -0,0 +1,351 @@
+ import torch
+ from transformers import AutoTokenizer
+ from torch.utils.data import Dataset, DataLoader
+ from torch.nn.utils.rnn import pad_sequence
+ import lightning.pytorch as pl
+ import config
+ import pandas as pd
+ import copy
+ from ast import literal_eval
+ from sklearn.model_selection import train_test_split
+ import random
+
+
+ def get_code_by_entity(entity, dictionary):
+     """
+     Query the dictionary by entity and return its code.
+     If multiple keys match, return the key with the longest synonym list.
+     """
+     keys = []
+     length = []
+     for key, values in dictionary.items():
+         if entity in values:
+             keys.append(key)
+             length.append(len(values))
+     d = dict(zip(keys, length))
+     if len(d) > 0:
+         return max(d, key=d.get)
+     else:
+         return None
+
+
+ def num_ancestors(df, code):
+     """Number of ancestors of `code` in the concept table."""
+     return len(df.loc[df["concept"] == code, "ancestors"].values[0])
+
+
+ def get_score(df, code1, code2):
+     """Look up the pairwise similarity score of two codes, in either order."""
+     result = df[
+         ((df["Code1"] == code1) & (df["Code2"] == code2))
+         | ((df["Code1"] == code2) & (df["Code2"] == code1))
+     ]
+     if result.empty:
+         return None
+     return result.iloc[0]["score"]
+
+
+ def mask(tokenizer, dictionary, unique_d, text, entities, anchor=True):
+     """
+     Randomly select one entity from `entities`, mask its first occurrence in
+     the text, and (for anchors) create one copy per synonym. The remaining
+     tokens are treated as context.
+
+     Returns a dictionary {input_ids, attention_mask, mlm_labels,
+     masked_indices, tags, codes}, or None if the entity cannot be resolved.
+     """
+     if anchor:
+         entity = random.choice(entities)
+         code = get_code_by_entity(entity, dictionary)
+         try:
+             synonyms = dictionary[code]
+         except KeyError:
+             return None
+         text_token = tokenizer.tokenize(text)
+         ent_token = tokenizer.tokenize(entity.lower())
+         num_ent_token = len(ent_token)
+
+         input_ids = [copy.deepcopy(text_token) for _ in range(len(synonyms))]
+         mlm_labels = [copy.deepcopy(text_token) for _ in range(len(synonyms))]
+         masked_indices = []
+
+         # For each synonym, mask the first occurrence of the entity span and
+         # write the synonym tokens into the MLM labels.
+         for i, t in enumerate(mlm_labels):
+             start_indices = [
+                 index for index, value in enumerate(t) if value == ent_token[0]
+             ]
+             masked_index = []
+             for start in start_indices:
+                 if (
+                     tokenizer.convert_tokens_to_string(t[start : start + num_ent_token])
+                     == entity.lower()
+                 ) and len(masked_index) == 0:
+                     syn = tokenizer.tokenize(synonyms[i])
+                     mlm_labels[i][start : start + num_ent_token] = syn
+                     input_ids[i][start : start + num_ent_token] = ["[MASK]"] * len(syn)
+                     masked_index.extend(list(range(start, start + len(syn))))
+             masked_indices.append(masked_index)
+
+         # Drop copies where the entity span was never found.
+         if any(not sublist for sublist in masked_indices):
+             empty_mask_idx = [
+                 k for k, sublist in enumerate(masked_indices) if not sublist
+             ]
+             input_ids = [x for i, x in enumerate(input_ids) if i not in empty_mask_idx]
+             mlm_labels = [
+                 x for i, x in enumerate(mlm_labels) if i not in empty_mask_idx
+             ]
+             masked_indices = [
+                 sublist for k, sublist in enumerate(masked_indices) if sublist
+             ]
+
+         if len(input_ids) <= 1:
+             return None
+
+         input_ids_lst = []
+         attention_mask_lst = []
+         mlm_labels_lst = []
+
+         for j, token in enumerate(input_ids):
+             input_id = torch.tensor(tokenizer.convert_tokens_to_ids(token))
+             input_ids_lst.append(input_id)
+             attention_mask_lst.append(torch.ones_like(input_id))
+             mlm_label = torch.tensor(tokenizer.convert_tokens_to_ids(mlm_labels[j]))
+             for l in range(len(mlm_label)):
+                 if l not in masked_indices[j]:
+                     mlm_label[l] = -100  # ignore unmasked positions in the MLM loss
+             mlm_labels_lst.append(mlm_label)
+
+         tags = [1] * len(input_ids_lst)
+         tags[0] = 0  # the first copy is the anchor; the rest are positives
+         codes = [code] * len(input_ids_lst)
+         if code not in unique_d:
+             return None
+
+         out = {
+             "input_ids": input_ids_lst,
+             "attention_mask": attention_mask_lst,
+             "mlm_labels": mlm_labels_lst,
+             "masked_indices": masked_indices,
+             "tags": tags,
+             "codes": codes,
+         }
+     else:
+         entity = random.choice(entities)
+         code = get_code_by_entity(entity, dictionary)
+         input_ids = tokenizer.tokenize(text)
+         mlm_labels = copy.deepcopy(input_ids)
+         ent_token = tokenizer.tokenize(entity.lower())
+         num_ent_token = len(ent_token)
+         masked_indices = []
+
+         start_indices = []
+         for i, t in enumerate(mlm_labels):
+             if t == ent_token[0]:
+                 start_indices.append(i)
+
+         for start in start_indices:
+             if (
+                 tokenizer.convert_tokens_to_string(
+                     input_ids[start : start + num_ent_token]
+                 )
+                 == entity.lower()
+             ) and len(masked_indices) == 0:
+                 input_ids[start : start + num_ent_token] = ["[MASK]"] * num_ent_token
+                 masked_indices.extend(list(range(start, start + num_ent_token)))
+
+         if len(masked_indices) == 0:
+             return None
+
+         input_ids_lst = []
+         attention_mask_lst = []
+         mlm_labels_lst = []
+
+         input_id = torch.tensor(tokenizer.convert_tokens_to_ids(input_ids))
+         input_ids_lst.append(input_id)
+         attention_mask_lst.append(torch.ones_like(input_id))
+         mlm_labels = tokenizer.convert_tokens_to_ids(mlm_labels)
+         for l in range(len(mlm_labels)):
+             if l not in masked_indices:
+                 mlm_labels[l] = -100
+         mlm_labels_lst.append(torch.tensor(mlm_labels))
+
+         tags = [2] * len(input_ids_lst)  # negatives
+
+         if code not in unique_d:
+             return None
+
+         codes = [code] * len(input_ids_lst)
+
+         out = {
+             "input_ids": input_ids_lst,
+             "attention_mask": attention_mask_lst,
+             "mlm_labels": mlm_labels_lst,
+             "masked_indices": masked_indices,
+             "tags": tags,
+             "codes": codes,
+         }
+
+     return out
+
+
+ class CLDataset(Dataset):
+     def __init__(
+         self,
+         data: pd.DataFrame,
+     ):
+         self.data = data
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, index):
+         data_row = self.data.iloc[index]
+         sentence = data_row.sentences
+         concepts = data_row.concepts
+         return [sentence, concepts]
+
+
+ def collate_func(batch, tokenizer, dictionary, all_d, pairs):
+     input_ids_lst = []
+     attention_mask_lst = []
+     mlm_labels_lst = []
+     masked_indices_lst = []
+     tags_lst = []
+     codes_lst = []
+     scores_lst = []
+
+     unique_d = pairs["Code1"].unique()
+
+     # Advance through the batch until a usable anchor is found (this assumes
+     # at least one row yields a valid anchor; otherwise an IndexError is raised).
+     anchor = batch[0]
+     anchor_masked = mask(tokenizer, dictionary, unique_d, anchor[0], anchor[1])
+     while anchor_masked is None:
+         batch = batch[1:]
+         anchor = batch[0]
+         anchor_masked = mask(tokenizer, dictionary, unique_d, anchor[0], anchor[1])
+
+     for i in range(len(anchor_masked["input_ids"])):
+         input_ids_lst.append(anchor_masked["input_ids"][i])
+         attention_mask_lst.append(anchor_masked["attention_mask"][i])
+         mlm_labels_lst.append(anchor_masked["mlm_labels"][i])
+     # Extend the per-sample metadata once, outside the loop, so each anchor
+     # copy contributes exactly one entry.
+     masked_indices_lst.extend(anchor_masked["masked_indices"])
+     tags_lst.extend(anchor_masked["tags"])
+     codes_lst.extend(anchor_masked["codes"])
+     ap_code = anchor_masked["codes"][0]
+     ap_score = num_ancestors(all_d, ap_code)
+     scores_lst.extend([ap_score] * len(tags_lst))
+
+     negatives = batch[1:]
+     for neg in negatives:
+         neg_masked = mask(tokenizer, dictionary, unique_d, neg[0], neg[1], False)
+         if neg_masked is None:
+             continue
+
+         for j in range(len(neg_masked["input_ids"])):
+             input_ids_lst.append(neg_masked["input_ids"][j])
+             attention_mask_lst.append(neg_masked["attention_mask"][j])
+             mlm_labels_lst.append(neg_masked["mlm_labels"][j])
+         masked_indices_lst.append(neg_masked["masked_indices"])
+         tags_lst.extend(neg_masked["tags"])
+         codes_lst.extend(neg_masked["codes"])
+         n_code = neg_masked["codes"][0]
+         if n_code == ap_code:
+             an_score = num_ancestors(all_d, n_code)
+         else:
+             an_score = get_score(pairs, ap_code, n_code)
+         scores_lst.append(an_score)
+
+     padded_input_ids = pad_sequence(input_ids_lst, batch_first=True, padding_value=0)
+     padded_attention_mask = pad_sequence(
+         attention_mask_lst, batch_first=True, padding_value=0
+     )
+     padded_mlm_labels = pad_sequence(
+         mlm_labels_lst, batch_first=True, padding_value=-100
+     )
+
+     return {
+         "input_ids": padded_input_ids,
+         "attention_mask": padded_attention_mask,
+         "mlm_labels": padded_mlm_labels,
+         "masked_indices": masked_indices_lst,
+         "tags": tags_lst,
+         "codes": codes_lst,
+         "scores": scores_lst,
+     }
+
+
+ def create_dataloader(dataset, tokenizer, dictionary, all_d, pairs, shuffle):
+     return DataLoader(
+         dataset,
+         batch_size=config.batch_size,
+         shuffle=shuffle,
+         num_workers=1,
+         collate_fn=lambda batch: collate_func(
+             batch, tokenizer, dictionary, all_d, pairs
+         ),
+     )
+
+
+ class CLDataModule(pl.LightningDataModule):
+     def __init__(self, train_df, val_df, tokenizer, dictionary, all_d, pairs):
+         super().__init__()
+         self.train_df = train_df
+         self.val_df = val_df
+         self.tokenizer = tokenizer
+         self.dictionary = dictionary
+         self.all_d = all_d
+         self.pairs = pairs
+
+     def setup(self, stage=None):
+         self.train_dataset = CLDataset(self.train_df)
+         self.val_dataset = CLDataset(self.val_df)
+
+     def train_dataloader(self):
+         return create_dataloader(
+             self.train_dataset,
+             self.tokenizer,
+             self.dictionary,
+             self.all_d,
+             self.pairs,
+             shuffle=True,
+         )
+
+     def val_dataloader(self):
+         return create_dataloader(
+             self.val_dataset,
+             self.tokenizer,
+             self.dictionary,
+             self.all_d,
+             self.pairs,
+             shuffle=False,
+         )
+
+
+ if __name__ == "__main__":
+     query_df = pd.read_csv(
+         "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_df.csv"
+     )
+     query_df["concepts"] = query_df["concepts"].apply(literal_eval)
+     query_df["codes"] = query_df["codes"].apply(literal_eval)
+     query_df["codes"] = query_df["codes"].apply(
+         lambda x: [val for val in x if val is not None]
+     )
+     train_df, val_df = train_test_split(query_df, test_size=config.split_ratio)
+
+     all_d = pd.read_csv(
+         "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_all_d.csv"
+     )
+     all_d.drop(columns=["finding_sites", "morphology"], inplace=True)
+     all_d["synonyms"] = all_d["synonyms"].apply(literal_eval)
+     all_d["ancestors"] = all_d["ancestors"].apply(literal_eval)
+     dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))
+
+     pairs = pd.read_csv("/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/pairs.csv")
+
+     tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
+
+     d = CLDataModule(train_df, val_df, tokenizer, dictionary, all_d, pairs)
+     d.setup()
+     train = d.train_dataloader()
+     for batch in train:
+         b = batch
+         break
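For intuition, a minimal sketch of what `mask` produces for an anchor. The entity and synonym data here are hypothetical; the tokenizer is the same one the module loads:

    from transformers import AutoTokenizer
    from dataset import mask

    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    # Hypothetical code -> synonym-list dictionary and code universe
    dictionary = {"22298006": ["myocardial infarction", "heart attack"]}
    unique_d = ["22298006"]

    out = mask(
        tokenizer,
        dictionary,
        unique_d,
        text="patient presented with myocardial infarction last night",
        entities=["myocardial infarction"],
        anchor=True,
    )
    # One masked copy per synonym: tag 0 marks the anchor, tag 1 its positives
    print(out["tags"])   # expected: [0, 1]
    print(out["codes"])  # expected: ['22298006', '22298006']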
loss.py ADDED
@@ -0,0 +1,58 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import config
+
+
+ class PCCL(nn.Module):
+     """
+     Pair-wise Cost-sensitive Contrastive Loss.
+
+     feature_matrix: (B, F)
+     label: (B,) tag list -- 0 = anchor, 1 = positives, 2 = negatives
+     """
+
+     def __init__(self):
+         super().__init__()
+         self.t_p = config.t_p  # positive temperature
+         self.zeta = config.zeta  # temperature ratio
+         self.t_n = config.t_p * self.zeta  # negative temperature
+         # self.m = config.m  # fixed margin
+
+     def forward(self, feature_matrix, label, score):
+         feature_matrix_normalized = F.normalize(feature_matrix, p=2, dim=1)
+         # The batch is laid out as [anchor, positives..., negatives...]
+         anchor = feature_matrix_normalized[0 : label.index(1)]
+         positives = feature_matrix_normalized[label.index(1) : label.index(2)]
+         pos_cardinal = positives.shape[0]
+         negatives = feature_matrix_normalized[label.index(2) :]
+         # Min-max normalize the hierarchy scores (guard against a zero range)
+         min_score = min(score)
+         max_score = max(score)
+         denom = max_score - min_score
+         normalized_score = [
+             (x - min_score) / denom if denom > 0 else 0.0 for x in score
+         ]
+         pos_scores = torch.tensor(normalized_score[label.index(1) : label.index(2)])
+         neg_scores = torch.tensor(normalized_score[label.index(2) :])
+
+         # within-class similarity
+         s_i_p = F.cosine_similarity(positives, anchor, dim=1)
+         # between-class similarity
+         s_i_n = F.cosine_similarity(negatives, anchor, dim=1)
+
+         pos_scores = pos_scores.to(s_i_p.device)
+         neg_scores = neg_scores.to(s_i_n.device)
+
+         # pair-wise relaxation factors
+         alpha_i_p = 1 + torch.max(torch.zeros_like(s_i_p), (pos_scores - s_i_p))
+         alpha_i_n = 1 + torch.max(torch.zeros_like(s_i_n), (neg_scores + s_i_n))
+
+         # normalization factor
+         z = torch.sum(torch.exp(alpha_i_p * s_i_p / self.t_p)) + torch.sum(
+             torch.exp(alpha_i_n * s_i_n / self.t_n)
+         )
+
+         # loss, averaged over the positives and negated
+         loss = torch.sum(torch.log(torch.exp(alpha_i_p * s_i_p / self.t_p) / z))
+         scale = -1 / pos_cardinal
+
+         return scale * loss
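A quick smoke test of the loss on synthetic features (a sketch; the tag layout follows the collate convention of one anchor `0`, then positives `1`, then negatives `2`, and the feature width here is arbitrary):

    import torch
    from loss import PCCL

    loss_fn = PCCL()
    features = torch.randn(6, 1536)    # (B, F); e.g. CLS (768) concatenated with the mask embedding (768)
    tags = [0, 1, 1, 2, 2, 2]          # anchor, two positives, three negatives
    scores = [5, 4, 6, 2.0, 1.5, 0.5]  # hierarchy-derived similarity scores
    print(loss_fn(features, tags, scores))  # a scalar tensor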
model.py ADDED
@@ -0,0 +1,305 @@
+ import lightning.pytorch as pl
+ from torch.optim import AdamW
+ from transformers import (
+     AutoModel,
+     AutoConfig,
+     get_linear_schedule_with_warmup,
+ )
+ from transformers.models.bert.modeling_bert import BertLMPredictionHead
+ import torch
+ from torch import nn
+ from loss import PCCL
+ import config
+
+
+ class CL_model(pl.LightningModule):
+     def __init__(
+         self, n_batches=None, n_epochs=None, lr=None, mlm_weight=None, **kwargs
+     ):
+         super().__init__()
+
+         ## Params
+         self.n_batches = n_batches
+         self.n_epochs = n_epochs
+         self.lr = lr
+         self.mlm_weight = mlm_weight
+         self.config = AutoConfig.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
+
+         ## Encoder
+         self.bert = AutoModel.from_pretrained(
+             "emilyalsentzer/Bio_ClinicalBERT", return_dict=True
+         )
+         # Unfreeze layers
+         self.bert_layer_num = sum(1 for _ in self.bert.named_parameters())
+         self.num_unfreeze_layer = self.bert_layer_num
+         self.ratio_unfreeze_layer = 0.0
+         for key, value in kwargs.items():
+             if key == "unfreeze" and isinstance(value, (int, float)):
+                 assert (
+                     0.0 <= value <= 1.0
+                 ), "ValueError: value must be a ratio between 0.0 and 1.0"
+                 self.ratio_unfreeze_layer = float(value)
+         if self.ratio_unfreeze_layer > 0.0:
+             self.num_unfreeze_layer = int(
+                 self.bert_layer_num * self.ratio_unfreeze_layer
+             )
+             # Freeze everything except the last `num_unfreeze_layer` parameter groups
+             for param in list(self.bert.parameters())[: -self.num_unfreeze_layer]:
+                 param.requires_grad = False
+
+         self.lm_head = BertLMPredictionHead(self.config)
+         # self.projector = nn.Linear(self.bert.config.hidden_size, 128)
+         print("Model Initialized!")
+
+         ## Losses
+         self.cl_loss = PCCL()
+         self.mlm_loss = nn.CrossEntropyLoss()
+
+         ## Logs (separate running counters for train and validation)
+         self.train_num_batches, self.val_num_batches = 0, 0
+         self.train_loss, self.val_loss = 0, 0
+         self.train_loss_cl, self.val_loss_cl = 0, 0
+         self.train_loss_mlm, self.val_loss_mlm = 0, 0
+         self.training_step_outputs, self.validation_step_outputs = [], []
+
+     def forward(self, input_ids, attention_mask, masked_indices, eval=False):
+         embs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
+         cls_tokens = embs.pooler_output
+         # Average the hidden states of each sample's masked positions
+         mask_tokens = []
+         for idx, value in enumerate(masked_indices):
+             masks = embs.last_hidden_state[idx][value]
+             avg_mask = torch.mean(masks, dim=0)
+             mask_tokens.append(avg_mask)
+         mask_tokens = torch.stack(mask_tokens)
+         cls_concat_mask = torch.cat((cls_tokens, mask_tokens), dim=1)
+         if eval:
+             return cls_tokens, mask_tokens, cls_concat_mask
+
+         mlm_pred = self.lm_head(embs.last_hidden_state)
+         mlm_pred = mlm_pred.view(-1, self.config.vocab_size)
+         return cls_concat_mask, mlm_pred
+
+     def training_step(self, batch, batch_idx):
+         input_ids = batch["input_ids"]
+         attention_mask = batch["attention_mask"]
+         mlm_labels = batch["mlm_labels"]
+         masked_indices = batch["masked_indices"]
+         tags = batch["tags"]
+         scores = batch["scores"]
+         cls_concat_mask, mlm_pred = self(input_ids, attention_mask, masked_indices)
+         loss_cl = self.cl_loss(cls_concat_mask, tags, scores)
+         loss_mlm = self.mlm_loss(mlm_pred, mlm_labels.reshape(-1))
+         loss = (1 - self.mlm_weight) * loss_cl + self.mlm_weight * loss_mlm
+         logs = {"loss": loss, "loss_cl": loss_cl, "loss_mlm": loss_mlm}
+         self.training_step_outputs.append(logs)
+         self.log("train_loss", loss, prog_bar=True, logger=True, sync_dist=True)
+
+         self.train_num_batches += 1
+         self.train_loss_cl += loss_cl
+         self.train_loss_mlm += loss_mlm
+         self.train_loss += loss
+
+         if self.train_num_batches % config.log_every_n_steps == 0:
+             avg_loss_cl = self.train_loss_cl / self.train_num_batches
+             avg_loss_mlm = self.train_loss_mlm / self.train_num_batches
+             avg_loss = self.train_loss / self.train_num_batches
+             self.log(
+                 "train_avg_cl_loss",
+                 avg_loss_cl,
+                 prog_bar=True,
+                 logger=True,
+                 sync_dist=True,
+             )
+             self.log(
+                 "train_avg_mlm_loss",
+                 avg_loss_mlm,
+                 prog_bar=True,
+                 logger=True,
+                 sync_dist=True,
+             )
+             self.log(
+                 "train_avg_loss", avg_loss, prog_bar=True, logger=True, sync_dist=True
+             )
+             self.train_loss_cl = 0
+             self.train_loss_mlm = 0
+             self.train_loss = 0
+             self.train_num_batches = 0
+
+         return loss
+
+     def on_train_epoch_end(self):
+         e_t_avg_loss = (
+             torch.stack([x["loss"] for x in self.training_step_outputs])
+             .mean()
+             .detach()
+             .cpu()
+             .numpy()
+         )
+         self.log(
+             "avg_loss_train_epoch",
+             e_t_avg_loss.item(),
+             on_step=False,
+             on_epoch=True,
+             sync_dist=True,
+         )
+         e_t_avg_loss_cl = (
+             torch.stack([x["loss_cl"] for x in self.training_step_outputs])
+             .mean()
+             .detach()
+             .cpu()
+             .numpy()
+         )
+         self.log(
+             "avg_loss_cl_train_epoch",
+             e_t_avg_loss_cl.item(),
+             on_step=False,
+             on_epoch=True,
+             sync_dist=True,
+         )
+         e_t_avg_loss_mlm = (
+             torch.stack([x["loss_mlm"] for x in self.training_step_outputs])
+             .mean()
+             .detach()
+             .cpu()
+             .numpy()
+         )
+         self.log(
+             "avg_loss_mlm_train_epoch",
+             e_t_avg_loss_mlm.item(),
+             on_step=False,
+             on_epoch=True,
+             sync_dist=True,
+         )
+         print(
+             "train_epoch:",
+             self.current_epoch,
+             "avg_loss:",
+             e_t_avg_loss,
+             "avg_cl_loss:",
+             e_t_avg_loss_cl,
+             "avg_mlm_loss:",
+             e_t_avg_loss_mlm,
+         )
+         self.training_step_outputs.clear()
+
+     def validation_step(self, batch, batch_idx):
+         input_ids = batch["input_ids"]
+         attention_mask = batch["attention_mask"]
+         mlm_labels = batch["mlm_labels"]
+         masked_indices = batch["masked_indices"]
+         tags = batch["tags"]
+         scores = batch["scores"]
+         cls_concat_mask, mlm_pred = self(input_ids, attention_mask, masked_indices)
+         loss_cl = self.cl_loss(cls_concat_mask, tags, scores)
+         loss_mlm = self.mlm_loss(mlm_pred, mlm_labels.reshape(-1))
+         loss = (1 - self.mlm_weight) * loss_cl + self.mlm_weight * loss_mlm
+         logs = {"loss": loss, "loss_cl": loss_cl, "loss_mlm": loss_mlm}
+         self.validation_step_outputs.append(logs)
+         self.log("val_loss", loss, prog_bar=True, logger=True, sync_dist=True)
+
+         self.val_num_batches += 1
+         self.val_loss_cl += loss_cl
+         self.val_loss_mlm += loss_mlm
+         self.val_loss += loss
+
+         if self.val_num_batches % config.log_every_n_steps == 0:
+             avg_loss_cl = self.val_loss_cl / self.val_num_batches
+             avg_loss_mlm = self.val_loss_mlm / self.val_num_batches
+             avg_loss = self.val_loss / self.val_num_batches
+             self.log(
+                 "val_avg_cl_loss",
+                 avg_loss_cl,
+                 prog_bar=True,
+                 logger=True,
+                 sync_dist=True,
+             )
+             self.log(
+                 "val_avg_mlm_loss",
+                 avg_loss_mlm,
+                 prog_bar=True,
+                 logger=True,
+                 sync_dist=True,
+             )
+             self.log(
+                 "val_avg_loss",
+                 avg_loss,
+                 prog_bar=True,
+                 logger=True,
+                 sync_dist=True,
+             )
+             self.val_loss_cl = 0
+             self.val_loss_mlm = 0
+             self.val_loss = 0
+             self.val_num_batches = 0
+
+         return loss
+
+     def on_validation_epoch_end(self):
+         e_v_avg_loss = (
+             torch.stack([x["loss"] for x in self.validation_step_outputs])
+             .mean()
+             .detach()
+             .cpu()
+             .numpy()
+         )
+         self.log(
+             "avg_loss_val_epoch",
+             e_v_avg_loss.item(),
+             on_step=False,
+             on_epoch=True,
+             sync_dist=True,
+         )
+         e_v_avg_loss_cl = (
+             torch.stack([x["loss_cl"] for x in self.validation_step_outputs])
+             .mean()
+             .detach()
+             .cpu()
+             .numpy()
+         )
+         self.log(
+             "avg_loss_cl_val_epoch",
+             e_v_avg_loss_cl.item(),
+             on_step=False,
+             on_epoch=True,
+             sync_dist=True,
+         )
+         e_v_avg_loss_mlm = (
+             torch.stack([x["loss_mlm"] for x in self.validation_step_outputs])
+             .mean()
+             .detach()
+             .cpu()
+             .numpy()
+         )
+         self.log(
+             "avg_loss_mlm_val_epoch",
+             e_v_avg_loss_mlm.item(),
+             on_step=False,
+             on_epoch=True,
+             sync_dist=True,
+         )
+         print(
+             "val_epoch:",
+             self.current_epoch,
+             "avg_loss:",
+             e_v_avg_loss,
+             "avg_cl_loss:",
+             e_v_avg_loss_cl,
+             "avg_mlm_loss:",
+             e_v_avg_loss_mlm,
+         )
+         self.validation_step_outputs.clear()
+
+     def configure_optimizers(self):
+         # Optimizer
+         self.trainable_params = [
+             param for param in self.parameters() if param.requires_grad
+         ]
+         optimizer = AdamW(self.trainable_params, lr=self.lr)
+
+         # Scheduler (stepped per batch, since the counts are in batches)
+         warmup_steps = self.n_batches // 3
+         total_steps = self.n_batches * self.n_epochs - warmup_steps
+         scheduler = get_linear_schedule_with_warmup(
+             optimizer, warmup_steps, total_steps
+         )
+         return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
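With `eval=True`, `forward` skips the LM head and returns the three embedding views. A sketch of extracting them for a padded batch (the inputs here are random placeholders):

    import torch
    from model import CL_model

    model = CL_model(n_batches=100, n_epochs=3, lr=5e-5, mlm_weight=0.5)
    model.eval()

    input_ids = torch.randint(0, 28996, (2, 16))  # (batch, seq_len); random ids as placeholders
    attention_mask = torch.ones_like(input_ids)
    masked_indices = [[3, 4], [5]]                # per-sample positions of [MASK] tokens
    with torch.no_grad():
        cls_tokens, mask_tokens, cls_concat_mask = model(
            input_ids, attention_mask, masked_indices, eval=True
        )
    print(cls_tokens.shape, mask_tokens.shape, cls_concat_mask.shape)
    # expected: (2, 768), (2, 768), (2, 1536)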
train.py ADDED
@@ -0,0 +1,81 @@
+ from lightning.pytorch import seed_everything
+ from lightning.pytorch.callbacks import ModelCheckpoint
+ from lightning.pytorch.callbacks.early_stopping import EarlyStopping
+ from lightning.pytorch.loggers import TensorBoardLogger
+ import lightning.pytorch as pl
+ import pandas as pd
+ from sklearn.model_selection import train_test_split
+ from transformers import AutoTokenizer
+ from ast import literal_eval
+
+ # imports from our own modules
+ import config
+ from model import CL_model
+ from dataset import CLDataModule
+
+ if __name__ == "__main__":
+     seed_everything(0, workers=True)
+     logger = TensorBoardLogger(
+         "/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/prompt/logs", name="CL"
+     )
+
+     query_df = pd.read_csv(
+         "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_df.csv"
+     )
+     query_df["concepts"] = query_df["concepts"].apply(literal_eval)
+     query_df["codes"] = query_df["codes"].apply(literal_eval)
+     query_df["codes"] = query_df["codes"].apply(
+         lambda x: [val for val in x if val is not None]
+     )
+     train_df, val_df = train_test_split(query_df, test_size=config.split_ratio)
+
+     all_d = pd.read_csv(
+         "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/query_all_d.csv"
+     )
+     all_d.drop(columns=["finding_sites", "morphology"], inplace=True)
+     all_d["synonyms"] = all_d["synonyms"].apply(literal_eval)
+     all_d["ancestors"] = all_d["ancestors"].apply(literal_eval)
+     dictionary = dict(zip(all_d["concept"], all_d["synonyms"]))
+
+     pairs = pd.read_csv("/data/aiiih/projects/sunx/ccf_fuzzy_diag/data_proc/pairs.csv")
+
+     tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
+
+     data_module = CLDataModule(train_df, val_df, tokenizer, dictionary, all_d, pairs)
+     # setup() is called here so the train dataset length is available below
+     data_module.setup()
+
+     model = CL_model(
+         n_batches=len(data_module.train_dataset) // config.batch_size,
+         n_epochs=config.max_epochs,
+         lr=config.learning_rate,
+         mlm_weight=config.mlm_weight,
+         unfreeze=config.unfreeze_ratio,
+     )
+
+     checkpoint = ModelCheckpoint(
+         dirpath="/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/ckpt/v2",
+         filename="{epoch}-{step}",
+         save_weights_only=True,
+         save_last=True,
+         every_n_train_steps=config.ckpt_every_n_steps,
+         monitor=None,
+         save_top_k=-1,
+     )
+
+     trainer = pl.Trainer(
+         accelerator=config.accelerator,
+         devices=config.devices,
+         strategy="ddp",
+         logger=logger,
+         max_epochs=config.max_epochs,
+         min_epochs=config.min_epochs,
+         precision=config.precision,
+         callbacks=[
+             EarlyStopping(monitor="val_loss", min_delta=1e-3, patience=2, mode="min"),
+             checkpoint,
+         ],
+         profiler="simple",
+         log_every_n_steps=config.log_every_n_steps,
+     )
+
+     trainer.fit(model, data_module)
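Because checkpoints are saved with `save_weights_only=True` and `CL_model` does not call `save_hyperparameters()`, reloading later needs the constructor arguments passed again. A sketch (the checkpoint path is hypothetical):

    from model import CL_model

    model = CL_model.load_from_checkpoint(
        "/home/sunx/data/aiiih/projects/sunx/ccf_fuzzy_diag/train/ckpt/v2/last.ckpt",
        n_batches=100,
        n_epochs=3,
        lr=5e-5,
        mlm_weight=0.5,
    )
    model.eval()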