Update file model.py
Browse files
model.py
CHANGED
|
@@ -33,7 +33,7 @@ class Model:
|
|
| 33 |
def compute_loss(self, input_ids, labels=None, criterion=None):
|
| 34 |
lm_logits = self.model(input_ids).logits
|
| 35 |
|
| 36 |
-
labels = input_ids.to(
|
| 37 |
shift_logits = lm_logits[:, :-1, :].contiguous()
|
| 38 |
labels = labels[:, 1:].contiguous()
|
| 39 |
|
|
@@ -48,7 +48,7 @@ class Model:
|
|
| 48 |
|
| 49 |
with torch.no_grad():
|
| 50 |
encoded_ids = tokenizer.encode(seed_text)
|
| 51 |
-
input_ids = torch.tensor(encoded_ids).unsqueeze(0).to(
|
| 52 |
output = self.model.generate(input_ids, max_length=max_len)
|
| 53 |
|
| 54 |
logits = output[0].tolist()
|
|
|
|
| 33 |
def compute_loss(self, input_ids, labels=None, criterion=None):
|
| 34 |
lm_logits = self.model(input_ids).logits
|
| 35 |
|
| 36 |
+
labels = input_ids.to(GetDevice())
|
| 37 |
shift_logits = lm_logits[:, :-1, :].contiguous()
|
| 38 |
labels = labels[:, 1:].contiguous()
|
| 39 |
|
|
|
|
| 48 |
|
| 49 |
with torch.no_grad():
|
| 50 |
encoded_ids = tokenizer.encode(seed_text)
|
| 51 |
+
input_ids = torch.tensor(encoded_ids).unsqueeze(0).to(GetDevice())
|
| 52 |
output = self.model.generate(input_ids, max_length=max_len)
|
| 53 |
|
| 54 |
logits = output[0].tolist()
|