Upload modeling_gluformer.py

#2
by njeffrie - opened
Files changed (1) hide show
  1. modeling_gluformer.py +8 -1
modeling_gluformer.py CHANGED
@@ -356,5 +356,12 @@ class GluformerForTimeSeries(PreTrainedModel):
356
  timestamps = timestamps.unsqueeze(0)
357
  glucose_values = glucose_values.unsqueeze(0)
358
  x_id, x_enc, x_dec, x_mark_enc, y_mark_dec = self.preprocessor(subject_id, timestamps, glucose_values)
359
- output, log_var = self.model(x_id, x_enc.to(self.device), x_mark_enc.to(self.device), x_dec.to(self.device), y_mark_dec.to(self.device))
 
 
 
 
 
 
 
360
  return self.preprocessor.unnormalize_glucose(output).cpu(), log_var.cpu()
 
356
  timestamps = timestamps.unsqueeze(0)
357
  glucose_values = glucose_values.unsqueeze(0)
358
  x_id, x_enc, x_dec, x_mark_enc, y_mark_dec = self.preprocessor(subject_id, timestamps, glucose_values)
359
+ if self.device is not None:
360
+ x_id = x_id.to(self.device)
361
+ x_enc = x_enc.to(self.device)
362
+ x_dec = x_dec.to(self.device)
363
+ x_mark_enc = x_mark_enc.to(self.device)
364
+ y_mark_dec = y_mark_dec.to(self.device)
365
+ self.model.to(self.device)
366
+ output, log_var = self.model(x_id, x_enc, x_mark_enc, x_dec, y_mark_dec)
367
  return self.preprocessor.unnormalize_glucose(output).cpu(), log_var.cpu()