Update llm_trainer.py
Browse files- llm_trainer.py +3 -2
llm_trainer.py
CHANGED
|
@@ -7,7 +7,8 @@ class LLMTrainer:
|
|
| 7 |
model: torch.nn.Module = None,
|
| 8 |
tokenizer: PreTrainedTokenizer | AutoTokenizer = None,
|
| 9 |
model_returns_logits: bool = False):
|
| 10 |
-
self.
|
|
|
|
| 11 |
|
| 12 |
if tokenizer is None:
|
| 13 |
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
|
@@ -31,7 +32,7 @@ class LLMTrainer:
|
|
| 31 |
with torch.no_grad():
|
| 32 |
while generated_tokens.size(1) < length:
|
| 33 |
|
| 34 |
-
with torch.autocast(device_type=self.
|
| 35 |
if self.model_returns_logits:
|
| 36 |
logits = self.model(generated_tokens)
|
| 37 |
else:
|
|
|
|
| 7 |
model: torch.nn.Module = None,
|
| 8 |
tokenizer: PreTrainedTokenizer | AutoTokenizer = None,
|
| 9 |
model_returns_logits: bool = False):
|
| 10 |
+
self.device_type = "cuda" if torch.cuda.is_available() else "cpu"
|
| 11 |
+
self.device = torch.device(self.device_type)
|
| 12 |
|
| 13 |
if tokenizer is None:
|
| 14 |
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
|
|
|
| 32 |
with torch.no_grad():
|
| 33 |
while generated_tokens.size(1) < length:
|
| 34 |
|
| 35 |
+
with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
|
| 36 |
if self.model_returns_logits:
|
| 37 |
logits = self.model(generated_tokens)
|
| 38 |
else:
|