flpelerin commited on
Commit
bab1439
·
1 Parent(s): 301a9d0

Update file model.py

Browse files
Files changed (1) hide show
  1. model.py +2 -2
model.py CHANGED
@@ -33,7 +33,7 @@ class Model:
33
  def compute_loss(self, input_ids, labels=None, criterion=None):
34
  lm_logits = self.model(input_ids).logits
35
 
36
- labels = input_ids.to(self.model.device)
37
  shift_logits = lm_logits[:, :-1, :].contiguous()
38
  labels = labels[:, 1:].contiguous()
39
 
@@ -48,7 +48,7 @@ class Model:
48
 
49
  with torch.no_grad():
50
  encoded_ids = tokenizer.encode(seed_text)
51
- input_ids = torch.tensor(encoded_ids).unsqueeze(0).to(self.model.device)
52
  output = self.model.generate(input_ids, max_length=max_len)
53
 
54
  logits = output[0].tolist()
 
33
  def compute_loss(self, input_ids, labels=None, criterion=None):
34
  lm_logits = self.model(input_ids).logits
35
 
36
+ labels = input_ids.to(GetDevice())
37
  shift_logits = lm_logits[:, :-1, :].contiguous()
38
  labels = labels[:, 1:].contiguous()
39
 
 
48
 
49
  with torch.no_grad():
50
  encoded_ids = tokenizer.encode(seed_text)
51
+ input_ids = torch.tensor(encoded_ids).unsqueeze(0).to(GetDevice())
52
  output = self.model.generate(input_ids, max_length=max_len)
53
 
54
  logits = output[0].tolist()