Update modelling_cxrmate2.py
Browse files- modelling_cxrmate2.py +1 -3
modelling_cxrmate2.py
CHANGED
|
@@ -112,8 +112,6 @@ class CXRMate2ForConditionalGeneration(CXRMate2PreTrainedModel, GenerationMixin)
|
|
| 112 |
hidden_size=config.text_config.hidden_size,
|
| 113 |
)
|
| 114 |
|
| 115 |
-
self.register_buffer('missing_time_delta_token_id', torch.tensor(self.config.missing_time_delta_token_id), persistent=False)
|
| 116 |
-
|
| 117 |
self.post_init()
|
| 118 |
|
| 119 |
def get_input_embeddings(self):
|
|
@@ -220,7 +218,7 @@ class CXRMate2ForConditionalGeneration(CXRMate2PreTrainedModel, GenerationMixin)
|
|
| 220 |
missing_time_delta_mask = time_deltas.isnan()
|
| 221 |
time_deltas = time_deltas.nan_to_num(0) # Replace NaN with dummy value before projection.
|
| 222 |
time_delta_embeddings = self.time_delta_encoder(time_deltas.unsqueeze(-1))
|
| 223 |
-
time_delta_embeddings[missing_time_delta_mask] = self.get_input_embeddings()(self.missing_time_delta_token_id)
|
| 224 |
time_delta_embeddings *= time_deltas_mask.unsqueeze(-1)
|
| 225 |
inputs_embeds += time_delta_embeddings
|
| 226 |
|
|
|
|
| 112 |
hidden_size=config.text_config.hidden_size,
|
| 113 |
)
|
| 114 |
|
|
|
|
|
|
|
| 115 |
self.post_init()
|
| 116 |
|
| 117 |
def get_input_embeddings(self):
|
|
|
|
| 218 |
missing_time_delta_mask = time_deltas.isnan()
|
| 219 |
time_deltas = time_deltas.nan_to_num(0) # Replace NaN with dummy value before projection.
|
| 220 |
time_delta_embeddings = self.time_delta_encoder(time_deltas.unsqueeze(-1))
|
| 221 |
+
time_delta_embeddings[missing_time_delta_mask] = self.get_input_embeddings()(torch.tensor(self.config.missing_time_delta_token_id, device=inputs_embeds.device))
|
| 222 |
time_delta_embeddings *= time_deltas_mask.unsqueeze(-1)
|
| 223 |
inputs_embeds += time_delta_embeddings
|
| 224 |
|