JFLa committed on
Commit
2c73d36
·
verified ·
1 Parent(s): 8aaa098

Upload 2 files

Browse files
Files changed (3) hide show
  1. .gitattributes +1 -0
  2. GF_CAB.py +237 -0
  3. Graphic_Abstract.png +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Graphic_Abstract.png filter=lfs diff=lfs merge=lfs -text
GF_CAB.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from datasets import load_from_disk
3
+ import torch
4
+ from transformers import BertForMaskedLM
5
+ import os
6
+ import sys
7
+ from tqdm.notebook import tqdm
8
+ import seaborn as sns
9
+ import matplotlib.pyplot as plt
10
+ # sys.path.append('/Users/chenj0i/Desktop/Lab Work/Geneformer')
11
+ from geneformer.pretrainer import token_dictionary
12
+ import datetime
13
+ import time
14
+ import pickle
15
+ import random
16
+ import subprocess
17
+ import numpy as np
18
+ import pytz
19
+ import torch
20
+ from datasets import load_from_disk, Dataset
21
+ from transformers import BertConfig, BertForMaskedLM, TrainingArguments, TrainerCallback, Trainer, BertModel, BertPreTrainedModel
22
+ from geneformer import GeneformerPretrainer
23
+ from typing import Tuple
24
+ from torch import Tensor
25
+ from transformers.modeling_outputs import MaskedLMOutput
26
+ from transformers.models.bert.modeling_bert import BertLMPredictionHead, BertOnlyMLMHead, BertPredictionHeadTransform
27
+ from transformers.activations import ACT2FN
28
+ from typing import List, Optional, Tuple, Union
29
+ import torch.nn.functional as F
30
+
31
class CustomBertForMaskedLM(BertPreTrainedModel):
    """BERT masked-LM with a hand-rolled prediction head and probability
    post-processing for gene-sequence modeling (Geneformer-style inputs).

    Compared with the stock ``BertForMaskedLM``:
      * the decoder weight is manually tied to the input word embeddings;
      * ``forward`` places softmax *probabilities* (not raw logits) in the
        ``logits`` field of its output;
      * already-known (unmasked) genes are suppressed from re-prediction via
        ``assign_known_gene_probs`` / ``probability_convert``;
      * the MLM loss adds a pairwise-similarity penalty that pushes the
        per-position distributions apart.
    """

    _keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
    _tied_weights_keys = ["decoder.weight", "bert.embeddings.word_embeddings.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config, add_pooling_layer=False)
        self.transform = BertPredictionHeadTransform(config)

        # Vocabulary projection; the bias is kept as a separate parameter so
        # the decoder weight can be tied to the embeddings without a bias clash.
        self.decoder = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = torch.nn.Parameter(torch.zeros(config.vocab_size))

        # Initialize weights, then tie decoder <-> embedding weights.
        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """Tie the output decoder weight to the input word embeddings."""
        self.decoder.weight = self.bert.embeddings.word_embeddings.weight

    def probability_convert(self, probs: Tensor, input_ids: Tensor, labels: Tensor) -> Tensor:
        """Sequentially discount each position's probabilities by the chance
        that earlier positions did *not* emit: p_i * prod_{j<i}(1 - p_j).

        Known (unmasked, label == -100) genes are seeded with probability 1.0
        at a leading slot so the cumulative product zeroes them out downstream.

        Args:
            probs: (batch, seq, vocab) softmax probabilities.
            input_ids: (batch, input_seq) token ids.
            labels: MLM labels; -100 marks positions that were not masked.

        Returns:
            Renormalized (batch, seq, vocab) probabilities.
        """
        device = probs.device
        batch_size, _, vocab_size = probs.size()
        _, input_seq_length = input_ids.size()

        # Truncate labels to the input length so the boolean index below is
        # shape-safe; mirrors assign_known_gene_probs. (No-op when labels and
        # input_ids already share a sequence length.)
        non_mask = labels[:, :input_seq_length] == -100
        non_mask_indices = non_mask.nonzero(as_tuple=True)
        known_gene_indices = input_ids[non_mask]

        # Build the shifted (1-p) matrix while assigning all known genes
        # probability 1 at the leading position.
        seed = torch.zeros((batch_size, 1, vocab_size), device=device)
        seed[non_mask_indices[0], 0, known_gene_indices] = 1.0
        probs_shifted = torch.cat((seed, probs[:, :-1, :]), dim=1)
        inv_probs_shifted = 1 - probs_shifted

        # Cumulative product gives (1-p_1)*(1-p_2)*...*(1-p_{i-1}) per slot.
        cumprod_inv_probs = torch.cumprod(inv_probs_shifted, dim=1)
        modified_probs = probs * cumprod_inv_probs

        # Hard-assigned known genes can drive the product to 0; clamp the
        # normalizer so the division below never divides by zero.
        # NOTE(review): the original author experimented with leaving
        # normalization out entirely (per-position probs need not sum to 1).
        normalizer = modified_probs.sum(dim=-1, keepdim=True).clamp(min=1e-18)
        return modified_probs / normalizer

    def assign_known_gene_probs(self, probs: Tensor, input_ids: Tensor, labels: Tensor) -> Tensor:
        """Overwrite probabilities at unmasked positions (label == -100) with
        a one-hot on the known input gene, then renormalize every position.

        Args:
            probs: (batch, seq, vocab) probabilities.
            input_ids: (batch, input_seq) token ids.
            labels: MLM labels; -100 marks positions that were not masked.

        Returns:
            Renormalized (batch, seq, vocab) probabilities with known genes
            hard-assigned.
        """
        device = probs.device
        batch_size, seq_length, vocab_size = probs.size()
        _, input_seq_length = input_ids.size()

        # Truncate labels to match input_ids along the sequence dimension.
        truncated_labels = labels[:, :input_seq_length]
        non_mask = truncated_labels == -100
        non_mask_indices = non_mask.nonzero(as_tuple=True)

        keep = torch.ones((batch_size, seq_length, vocab_size), device=device)
        one_hot = torch.zeros((batch_size, seq_length, vocab_size), device=device)

        known_gene_indices = input_ids[non_mask]

        # Zero out every vocab entry at known positions, then drop in the
        # one-hot assignment for the known gene.
        keep[non_mask_indices[0], non_mask_indices[1], :] = 0.0
        one_hot[non_mask_indices[0], non_mask_indices[1], known_gene_indices] = 1.0

        modified_probs = probs * keep + one_hot

        # Renormalize; the clamp guards against an all-zero row.
        return modified_probs / modified_probs.sum(dim=-1, keepdim=True).clamp(min=1e-18)

    def compute_similarity_on_probs(self, probs: Tensor) -> Tensor:
        """Average pairwise cosine similarity between position-wise
        probability distributions (used as a diversity penalty in the loss).

        Args:
            probs: (batch, seq, vocab) probability tensor.

        Returns:
            Scalar tensor: mean cosine similarity over all i < j position
            pairs, averaged over the batch.
        """
        batch_size, seq_length, _ = probs.size()

        # Degenerate case: fewer than two positions means no pairs to compare
        # (the original divided by zero here).
        if seq_length < 2:
            return probs.new_zeros(())

        # Normalize along the vocab dimension so the einsum below yields
        # cosine similarities.
        probs_norm = F.normalize(probs, dim=-1)

        # All pairwise similarities: (batch, seq, seq).
        similarities = torch.einsum("biv,bjv->bij", probs_norm, probs_norm)

        # Keep only the strict upper triangle (i < j) to skip self-similarity
        # and double counting.
        mask_sim = torch.triu(torch.ones(seq_length, seq_length, device=probs.device), diagonal=1)
        total_similarity = (similarities * mask_sim).sum()
        total_comparisons = mask_sim.sum().item() * batch_size

        return total_similarity / total_comparisons

    def forward(
        self,
        input_ids: Tensor | None = None,
        attention_mask: Tensor | None = None,
        token_type_ids: Tensor | None = None,
        position_ids: Tensor | None = None,
        head_mask: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        encoder_hidden_states: Tensor | None = None,
        encoder_attention_mask: Tensor | None = None,
        labels: Tensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        """Run BERT, convert head logits to probabilities, reweight them
        against already-known genes, and (when ``labels`` are given) compute
        a masked cross-entropy loss plus a weighted similarity penalty.

        NOTE(review): the ``logits`` field of the returned ``MaskedLMOutput``
        carries softmax probabilities, not raw logits — downstream consumers
        must not re-apply softmax.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_transform = self.transform(hidden_states)
        logits = self.decoder(hidden_transform) + self.bias

        probs = F.softmax(logits, dim=-1)

        # Probability manipulations to avoid repeats of already-known genes.
        # assumes labels is provided whenever input_ids is — these helpers
        # index labels unconditionally (TODO confirm against the trainer).
        probs = self.assign_known_gene_probs(probs, input_ids, labels)
        convert_probs = self.probability_convert(probs, input_ids, labels)
        assigned_probs = self.assign_known_gene_probs(convert_probs, input_ids, labels)

        masked_lm_loss = None
        if labels is not None:
            # Loss is computed on the pre-conversion probabilities; the
            # converted tensor feeds only the similarity penalty below.
            probs_flat = probs.view(-1, self.config.vocab_size)
            labels_flat = labels.view(-1)
            mask = (labels != -100).float().view(-1)

            # Masked cross-entropy on probabilities (clamped to avoid log(0)).
            masked_lm_loss = -torch.log(
                torch.clamp(probs_flat[torch.arange(len(labels_flat)), labels_flat], min=1e-18)
            ) * mask
            masked_lm_loss = masked_lm_loss.sum() / mask.sum()

            # Diversity penalty; weight chosen through experimentation.
            similarity_loss = self.compute_similarity_on_probs(assigned_probs)
            lambda_similarity = 200.0
            masked_lm_loss = masked_lm_loss + lambda_similarity * similarity_loss

        if not return_dict:
            output = (assigned_probs,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=probs,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
        """Append a dummy PAD token (with a matching zeroed attention slot)
        so the model can be driven by ``generate``-style decoding.

        Raises:
            ValueError: if the config defines no PAD token.
        """
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        # A dummy token is appended below; PAD must exist for that to be valid.
        if self.config.pad_token_id is None:
            raise ValueError("The PAD token should be defined for generation")

        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1
        )
        dummy_token = torch.full(
            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
        )
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        return {"input_ids": input_ids, "attention_mask": attention_mask}
Graphic_Abstract.png ADDED

Git LFS Details

  • SHA256: b0250c1358e2325acb458c79a0d19718c05a618db17cae5a87346c470d7902e5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.14 MB