AlekMan committed on
Commit
c18f087
·
verified ·
1 Parent(s): a19143d

Update llm_trainer.py

Browse files
Files changed (1) hide show
  1. llm_trainer.py +3 -2
llm_trainer.py CHANGED
@@ -7,7 +7,8 @@ class LLMTrainer:
7
  model: torch.nn.Module = None,
8
  tokenizer: PreTrainedTokenizer | AutoTokenizer = None,
9
  model_returns_logits: bool = False):
10
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
11
 
12
  if tokenizer is None:
13
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
@@ -31,7 +32,7 @@ class LLMTrainer:
31
  with torch.no_grad():
32
  while generated_tokens.size(1) < length:
33
 
34
- with torch.autocast(device_type=self.device, dtype=torch.bfloat16):
35
  if self.model_returns_logits:
36
  logits = self.model(generated_tokens)
37
  else:
 
7
  model: torch.nn.Module = None,
8
  tokenizer: PreTrainedTokenizer | AutoTokenizer = None,
9
  model_returns_logits: bool = False):
10
+ self.device_type = "cuda" if torch.cuda.is_available() else "cpu"
11
+ self.device = torch.device(self.device_type)
12
 
13
  if tokenizer is None:
14
  tokenizer = AutoTokenizer.from_pretrained("gpt2")
 
32
  with torch.no_grad():
33
  while generated_tokens.size(1) < length:
34
 
35
+ with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
36
  if self.model_returns_logits:
37
  logits = self.model(generated_tokens)
38
  else: