Upload modeling_gluformer.py
Browse files- modeling_gluformer.py +9 -2
modeling_gluformer.py
CHANGED
|
@@ -259,7 +259,7 @@ class Gluformer(nn.Module):
|
|
| 259 |
dec_out = self.dec_embedding(x_id, x_dec, x_mark_dec)
|
| 260 |
dec_out = self.decoder(dec_out, enc_out)
|
| 261 |
dec_out = self.projection(dec_out)
|
| 262 |
-
return dec_out[:, -self.len_pred:, :], var_out
|
| 263 |
|
| 264 |
class GluformerConfig(PretrainedConfig):
|
| 265 |
model_type = "gluformer"
|
|
@@ -356,5 +356,12 @@ class GluformerForTimeSeries(PreTrainedModel):
|
|
| 356 |
timestamps = timestamps.unsqueeze(0)
|
| 357 |
glucose_values = glucose_values.unsqueeze(0)
|
| 358 |
x_id, x_enc, x_dec, x_mark_enc, y_mark_dec = self.preprocessor(subject_id, timestamps, glucose_values)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
output, log_var = self.model(x_id, x_enc, x_mark_enc, x_dec, y_mark_dec)
|
| 360 |
-
return self.preprocessor.unnormalize_glucose(output), log_var
|
|
|
|
| 259 |
dec_out = self.dec_embedding(x_id, x_dec, x_mark_dec)
|
| 260 |
dec_out = self.decoder(dec_out, enc_out)
|
| 261 |
dec_out = self.projection(dec_out)
|
| 262 |
+
return dec_out[:, -self.len_pred:, :], var_out
|
| 263 |
|
| 264 |
class GluformerConfig(PretrainedConfig):
|
| 265 |
model_type = "gluformer"
|
|
|
|
| 356 |
timestamps = timestamps.unsqueeze(0)
|
| 357 |
glucose_values = glucose_values.unsqueeze(0)
|
| 358 |
x_id, x_enc, x_dec, x_mark_enc, y_mark_dec = self.preprocessor(subject_id, timestamps, glucose_values)
|
| 359 |
+
if self.device is not None:
|
| 360 |
+
x_id = x_id.to(self.device)
|
| 361 |
+
x_enc = x_enc.to(self.device)
|
| 362 |
+
x_dec = x_dec.to(self.device)
|
| 363 |
+
x_mark_enc = x_mark_enc.to(self.device)
|
| 364 |
+
y_mark_dec = y_mark_dec.to(self.device)
|
| 365 |
+
self.model.to(self.device)
|
| 366 |
output, log_var = self.model(x_id, x_enc, x_mark_enc, x_dec, y_mark_dec)
|
| 367 |
+
return self.preprocessor.unnormalize_glucose(output).cpu(), log_var.cpu()
|