anicolson commited on
Commit
d7b1b50
·
verified ·
1 Parent(s): c2798c2

Update modelling_cxrmate2.py

Browse files
Files changed (1) hide show
  1. modelling_cxrmate2.py +1 -3
modelling_cxrmate2.py CHANGED
@@ -112,8 +112,6 @@ class CXRMate2ForConditionalGeneration(CXRMate2PreTrainedModel, GenerationMixin)
112
  hidden_size=config.text_config.hidden_size,
113
  )
114
 
115
- self.register_buffer('missing_time_delta_token_id', torch.tensor(self.config.missing_time_delta_token_id), persistent=False)
116
-
117
  self.post_init()
118
 
119
  def get_input_embeddings(self):
@@ -220,7 +218,7 @@ class CXRMate2ForConditionalGeneration(CXRMate2PreTrainedModel, GenerationMixin)
220
  missing_time_delta_mask = time_deltas.isnan()
221
  time_deltas = time_deltas.nan_to_num(0) # Replace NaN with dummy value before projection.
222
  time_delta_embeddings = self.time_delta_encoder(time_deltas.unsqueeze(-1))
223
- time_delta_embeddings[missing_time_delta_mask] = self.get_input_embeddings()(self.missing_time_delta_token_id)
224
  time_delta_embeddings *= time_deltas_mask.unsqueeze(-1)
225
  inputs_embeds += time_delta_embeddings
226
 
 
112
  hidden_size=config.text_config.hidden_size,
113
  )
114
 
 
 
115
  self.post_init()
116
 
117
  def get_input_embeddings(self):
 
218
  missing_time_delta_mask = time_deltas.isnan()
219
  time_deltas = time_deltas.nan_to_num(0) # Replace NaN with dummy value before projection.
220
  time_delta_embeddings = self.time_delta_encoder(time_deltas.unsqueeze(-1))
221
+ time_delta_embeddings[missing_time_delta_mask] = self.get_input_embeddings()(torch.tensor(self.config.missing_time_delta_token_id, device=inputs_embeds.device))
222
  time_delta_embeddings *= time_deltas_mask.unsqueeze(-1)
223
  inputs_embeds += time_delta_embeddings
224