njeffrie commited on
Commit
337c284
·
verified ·
1 Parent(s): f6c67a4

Upload modeling_gluformer.py

Browse files
Files changed (1) hide show
  1. modeling_gluformer.py +9 -2
modeling_gluformer.py CHANGED
@@ -259,7 +259,7 @@ class Gluformer(nn.Module):
259
  dec_out = self.dec_embedding(x_id, x_dec, x_mark_dec)
260
  dec_out = self.decoder(dec_out, enc_out)
261
  dec_out = self.projection(dec_out)
262
- return dec_out[:, -self.len_pred:, :], var_out
263
 
264
  class GluformerConfig(PretrainedConfig):
265
  model_type = "gluformer"
@@ -356,5 +356,12 @@ class GluformerForTimeSeries(PreTrainedModel):
356
  timestamps = timestamps.unsqueeze(0)
357
  glucose_values = glucose_values.unsqueeze(0)
358
  x_id, x_enc, x_dec, x_mark_enc, y_mark_dec = self.preprocessor(subject_id, timestamps, glucose_values)
 
 
 
 
 
 
 
359
  output, log_var = self.model(x_id, x_enc, x_mark_enc, x_dec, y_mark_dec)
360
- return self.preprocessor.unnormalize_glucose(output), log_var
 
259
  dec_out = self.dec_embedding(x_id, x_dec, x_mark_dec)
260
  dec_out = self.decoder(dec_out, enc_out)
261
  dec_out = self.projection(dec_out)
262
+ return dec_out[:, -self.len_pred:, :], var_out
263
 
264
  class GluformerConfig(PretrainedConfig):
265
  model_type = "gluformer"
 
356
  timestamps = timestamps.unsqueeze(0)
357
  glucose_values = glucose_values.unsqueeze(0)
358
  x_id, x_enc, x_dec, x_mark_enc, y_mark_dec = self.preprocessor(subject_id, timestamps, glucose_values)
359
+ if self.device is not None:
360
+ x_id = x_id.to(self.device)
361
+ x_enc = x_enc.to(self.device)
362
+ x_dec = x_dec.to(self.device)
363
+ x_mark_enc = x_mark_enc.to(self.device)
364
+ y_mark_dec = y_mark_dec.to(self.device)
365
+ self.model.to(self.device)
366
  output, log_var = self.model(x_id, x_enc, x_mark_enc, x_dec, y_mark_dec)
367
+ return self.preprocessor.unnormalize_glucose(output).cpu(), log_var.cpu()