row56 committed on
Commit
528fe74
·
verified ·
1 Parent(s): f4b2cc5

Upload proto_model/utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. proto_model/utils.py +279 -0
proto_model/utils.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Optional, Union, Iterable, List
5
+
6
+ import matplotlib
7
+ import numpy as np
8
+ import torch
9
+ from pytorch_lightning.callbacks import ModelCheckpoint
10
+ import shutil
11
+
12
+
13
def freeze_model_weights(model: torch.nn.Module) -> None:
    """Freeze *model* in place by turning off gradient tracking on all of its parameters."""
    for parameter in model.parameters():
        parameter.requires_grad_(False)
16
+
17
+
18
+ # def init_attention_from_tf_idf(batch, tf_idf, vectorizer, token_vectors,):
19
+ # features = vectorizer.get_feature_names()
20
+ #
21
+ # all_relevant_tokens = []
22
+ # for j, sample in enumerate(batch["tokens"]):
23
+ #
24
+ # global_sample_ind = train_dataloader.dataset.data.id.tolist().index(batch["sample_ids"][j])
25
+ # tf_idf_sample = tf_idf[global_sample_ind]
26
+ # relevant_tokens_sample = []
27
+ # for k in range(batch["input_ids"].shape[1]):
28
+ # if k < len(sample):
29
+ # token = sample[k]
30
+ # if token in features:
31
+ # token_ind = features.index(token)
32
+ # if token_ind in tf_idf_sample.indices:
33
+ # tf_idf_ind = np.where(tf_idf_sample.indices == token_ind)[0][0]
34
+ # token_value = tf_idf_sample.data[tf_idf_ind]
35
+ # if token_value > 0.05:
36
+ # relevant_tokens_sample.append(1)
37
+ # continue
38
+ # relevant_tokens_sample.append(0)
39
+ # all_relevant_tokens.append(relevant_tokens_sample)
40
+ #
41
+ # all_relevant_tokens = torch.tensor(all_relevant_tokens)
42
+ # if self.use_cuda:
43
+ # all_relevant_tokens = all_relevant_tokens.cuda()
44
+ #
45
+ # relevant_tokens = torch.einsum('ik,ikl->ikl', all_relevant_tokens, token_vectors)
46
+ #
47
+ # mean_over_relevant_tokens = relevant_tokens.mean(dim=1)
48
+ #
49
+ # # get tensor of shape batch_size x num_classes x dim
50
+ # masked_att_vectors_per_sample = torch.einsum('ik,il->ilk', mean_over_relevant_tokens,
51
+ # target_tensors)
52
+ #
53
+ # # sum into one vector per prototype. shape: num_classes x dim
54
+ # sum_att_per_prototype = torch.add(sum_att_per_prototype, masked_att_vectors_per_sample.sum(dim=0)
55
+ # .detach())
56
+ #
57
+ # n_att_per_prototype += target_tensors.sum(dim=0).detach()
58
+
59
def attention_mask_from_tokens(masks, token_list):
    """Zero out attention-mask positions that cover boilerplate section headers.

    For every sample's token sequence, any occurrence of one of the known clinical
    section headers (e.g. "chief complaint :") or the special [CLS]/[SEP] tokens is
    masked by writing 0 into the matching span of ``masks``. Mutates and returns
    ``masks``.
    """
    section_headers = [
        ["chief", "complaint", ":"],
        ["present", "illness", ":"],
        ["medical", "history", ":"],
        ["medication", "on", "admission", ":"],
        ["allergies", ":"],
        ["physical", "exam", ":"],
        ["family", "history", ":"],
        ["social", "history", ":"],
        ["[CLS]"],
        ["[SEP]"],
    ]

    for row, tokens in enumerate(token_list):
        for start in range(len(tokens)):
            for header in section_headers:
                end = start + len(header)
                # Slice comparison: the header must match token-for-token.
                if tokens[start:end] == header:
                    masks[row, start:end] = 0

    return masks
79
+
80
+
81
def get_bert_vectors_per_sample(batch, bert, use_cuda, linear=None):
    """Encode a tokenised batch with *bert* and pool over the token dimension.

    Args:
        batch: dict with "input_ids", "attention_masks" and "token_type_ids" tensors.
        bert: the BERT encoder to run.
        use_cuda: move inputs (and ``linear``, if given) to the GPU first.
        linear: optional projection layer applied to the last hidden states.

    Returns:
        Tuple of (mean-pooled vector per sample, per-token vectors).
    """
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_masks"]
    token_type_ids = batch["token_type_ids"]

    if use_cuda:
        input_ids, attention_mask, token_type_ids = (
            input_ids.cuda(), attention_mask.cuda(), token_type_ids.cuda())

    output = bert(input_ids=input_ids,
                  attention_mask=attention_mask,
                  token_type_ids=token_type_ids)

    if linear is None:
        token_vectors = output.last_hidden_state
    else:
        if use_cuda:
            linear = linear.cuda()
        token_vectors = linear(output.last_hidden_state)

    return token_vectors.mean(dim=1), token_vectors
105
+
106
+
107
def get_attended_vector_per_sample(batch, bert, use_cuda, linear=None):
    """Encode a batch with *bert* and return (mean-pooled sample vectors, per-token vectors).

    NOTE(review): this function's body was a byte-for-byte duplicate of
    ``get_bert_vectors_per_sample`` (no attention-weighted pooling was actually
    implemented, despite the name). It now delegates to that function so the two
    cannot drift apart; the public interface and behavior are unchanged.
    """
    return get_bert_vectors_per_sample(batch, bert, use_cuda, linear=linear)
131
+
132
+
133
def pad_batch_samples(batch_samples: Iterable, num_tokens: int) -> List:
    """Flatten a batch of token sequences into one list, padding each sample to *num_tokens*.

    Each sample's tokens are appended to a single flat list, followed by enough
    "[PAD]" tokens to reach ``num_tokens`` (samples already at or over that length
    get no padding). The flat layout matches sample-major indexing downstream.
    """
    flat: List = []
    for tokens in batch_samples:
        flat.extend(tokens)
        flat.extend(["[PAD]"] * (num_tokens - len(tokens)))
    return flat
140
+
141
+
142
class ProjectorCallback(ModelCheckpoint):
    """Checkpoint callback that additionally exports token/prototype embeddings
    to the TensorBoard Projector after every validation run.

    Extends ``pytorch_lightning.callbacks.ModelCheckpoint``; all checkpointing
    arguments are forwarded to the parent unchanged.
    """

    def __init__(
            self,
            train_dataloader,
            project_n_batches=-1,  # -1 means project all batches
            dirpath: Optional[Union[str, Path]] = None,
            filename: Optional[str] = None,
            monitor: Optional[str] = None,
            verbose: bool = False,
            save_last: Optional[bool] = None,
            save_top_k: Optional[int] = None,
            save_weights_only: bool = False,
            mode: str = "auto",
            period: int = 1,
            prefix: str = ""
    ):
        # NOTE(review): `period`/`prefix` are arguments of older
        # pytorch-lightning ModelCheckpoint versions — confirm against the
        # pinned PL version before upgrading.
        super().__init__(dirpath=dirpath, filename=filename, monitor=monitor, verbose=verbose, save_last=save_last,
                         save_top_k=save_top_k, save_weights_only=save_weights_only, mode=mode, period=period,
                         prefix=prefix)
        # Dataloader whose batches are embedded and exported to the Projector.
        self.train_dataloader = train_dataloader
        # Upper bound on batches to project per validation end (-1 = all).
        self.project_n_batches = project_n_batches

    def on_validation_end(self, trainer, pl_module):
        """
        After each validation step, save the learned token and prototype embeddings for analysis in the Projector.
        """
        super().on_validation_end(trainer, pl_module)

        with torch.no_grad():

            all_vectors = []
            metadata = []
            for i, batch in enumerate(self.train_dataloader):
                # pl_module forward with return_metadata=True yields
                # (_, _, (features, tokens, prototype_vectors)).
                _, _, batch_features = pl_module(batch, return_metadata=True)

                targets = batch["targets"]

                features = batch_features[0]
                tokens = batch_features[1]
                prototype_vectors = batch_features[2]

                batch_size = features.shape[0]

                window_len = features.shape[1]

                for sample_i in range(batch_size):
                    for window_i in range(window_len):
                        window_vector = features[sample_i][window_i]
                        # `tokens` is a flat, sample-major list; one entry per window.
                        window_tokens = tokens[sample_i * window_len + window_i]

                        # Skip padding and separator windows — nothing meaningful to project.
                        if window_tokens == "[PAD]" or window_tokens == "[SEP]":
                            continue

                        all_vectors.append(window_vector)
                        metadata.append([window_tokens, targets[sample_i]])

                # Add the prototype vectors exactly once (the PROTO_0 sentinel
                # is only absent on the first pass).
                if ["PROTO_0", 0] not in metadata:
                    for j, vector in enumerate(prototype_vectors):
                        # Prototypes are laid out class-contiguously.
                        prototype_class = int(j // pl_module.prototypes_per_class)
                        all_vectors.append(vector.squeeze())
                        metadata.append([f"PROTO_{prototype_class}", prototype_class])

                # Stop early once the configured number of batches is projected.
                if self.project_n_batches != -1 and i >= self.project_n_batches - 1:
                    break

            trainer.logger.experiment.add_embedding(torch.stack(all_vectors), metadata, global_step=trainer.global_step,
                                                    metadata_header=["tokens", "target"])

            # Keep disk usage bounded: drop embeddings from intermediate steps.
            delete_intermediate_embeddings(trainer.logger.experiment.log_dir, trainer.global_step)
211
+
212
+
213
def delete_intermediate_embeddings(log_dir, current_step):
    """Prune intermediate Projector embedding dumps from *log_dir*.

    TensorBoard writes one directory per global step (named with the step
    number). Every step directory except step 0 and ``current_step`` is
    deleted, and ``projector_config.pbtxt`` is rewritten to reference only
    the two surviving embeddings.

    Args:
        log_dir: the TensorBoard experiment log directory.
        current_step: the global step whose embedding must be kept.
    """
    for entry in os.listdir(log_dir):
        # Step directories have purely numeric names; everything else
        # (event files, the config itself) is skipped.
        # Fix: the original bare `except: continue` swallowed every error,
        # including OSError from the deletion path — only the int() parse
        # is expected to fail here.
        try:
            step = int(entry)
        except ValueError:
            continue

        abs_path = os.path.join(log_dir, entry)
        if os.path.isdir(abs_path) and step != current_step and step != 0:
            remove_dir(abs_path)

    embedding_config = """embeddings {{
    tensor_name: "default:{embedding_id}"
    metadata_path: "{embedding_id}/default/metadata.tsv"
    tensor_path: "{embedding_id}/default/tensors.tsv"\n}}"""

    # Projector step directories are zero-padded to five digits.
    config_text = embedding_config.format(embedding_id="00000") + "\n" + \
        embedding_config.format(embedding_id=f"{current_step:05}")

    with open(os.path.join(log_dir, "projector_config.pbtxt"), "w") as config_file_write:
        config_file_write.write(config_text)
236
+
237
+
238
def remove_dir(path):
    """Recursively delete the directory *path*, logging the outcome instead of raising."""
    try:
        shutil.rmtree(path)
    except OSError as e:
        # Best effort: report the failure and carry on.
        print("Error: %s : %s" % (path, e.strerror))
    else:
        print(f"delete dir {path}")
244
+
245
+
246
def load_eval_buckets(eval_bucket_path):
    """Load evaluation buckets from a JSON file; returns None when no path is given."""
    if eval_bucket_path is None:
        return None
    with open(eval_bucket_path) as bucket_file:
        return json.load(bucket_file)
252
+
253
+
254
def build_heatmaps(case_tokens, token_scores, tint="red", amplifier=8):
    """Render per-prototype token-importance heatmaps as HTML strings.

    Args:
        case_tokens: word-piece tokens of the case text.
        token_scores: one score sequence per prototype, aligned with ``case_tokens``.
        tint: "red" or "blue"; any other value highlights in green.
        amplifier: factor applied to raw scores before clipping at 1.

    Returns:
        List of HTML strings, one heatmap per prototype.
    """
    # Loop-invariant markup hoisted out of the loops.
    template = '<span style="color: black; background-color: {}">{}</span>'

    heatmap_per_prototype = []
    for prototype_scores in token_scores:
        heatmap_string = ''
        for word, color in zip(case_tokens, prototype_scores):
            color = min(1, color * amplifier)
            # Higher scores -> more saturated tint.
            if tint == "red":
                hex_color = matplotlib.colors.rgb2hex([1, 1 - color, 1 - color])
            elif tint == "blue":
                hex_color = matplotlib.colors.rgb2hex([1 - color, 1 - color, 1])
            else:
                hex_color = matplotlib.colors.rgb2hex([1 - color, 1, 1 - color])

            # Word-piece continuations ("##…") are glued to the previous token;
            # standalone tokens get a leading non-breaking space.
            if "##" not in word:
                heatmap_string += '&nbsp;'  # fix: entity was missing its terminating ';'
                word_string = word
            else:
                word_string = word.replace("##", "")

            heatmap_string += template.format(hex_color, word_string)

        heatmap_per_prototype.append(heatmap_string)

    return heatmap_per_prototype