AlekMan committed
Commit 7818d5e · verified · Parent: 87003dd

Update llm_trainer.py

Files changed (1): llm_trainer.py (+55 −80)
llm_trainer.py CHANGED
@@ -1,80 +1,55 @@
- import torch
- from torch.nn import functional as F
- from transformers import PreTrainedTokenizer, AutoTokenizer
-
- from llm_trainer.dataset.DataLoader import DataLoader
-
- class LLMTrainer:
-     def __init__(self,
-                  model: torch.nn.Module = None,
-                  tokenizer: PreTrainedTokenizer | AutoTokenizer = None,
-                  model_returns_logits: bool = False):
-         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         print(f"Training on: {self.device}")
-
-         if tokenizer is None:
-             tokenizer = AutoTokenizer.from_pretrained("gpt2")
-
-         self.tokenizer = tokenizer
-
-         if model is None:
-             raise ValueError("Specify a model.")
-         self.model = model
-
-         self.train_loader = None
-         self.current_step: int = 0
-
-         self.model_returns_logits = model_returns_logits
-
-     def generate_text(self, prompt: str = "Once upon a time", n_return_sequences: int = 4, length: int = 32) -> None:
-         # Make sure the model is on the same device
-         self.model.to(self.device)
-         self.model.eval()
-
-         tokens = self.tokenizer.encode(prompt, return_tensors="pt").type(torch.long)
-         tokens = tokens.repeat(n_return_sequences, 1)
-
-         generated_tokens = tokens.to(self.device)
-         with torch.no_grad():
-             while generated_tokens.size(1) < length:
-
-                 with torch.autocast(device_type=self.device, dtype=torch.bfloat16):
-                     if self.model_returns_logits:
-                         logits = self.model(generated_tokens)
-                     else:
-                         logits = self.model(generated_tokens).logits
-
-                 # logits.shape = (batch_size, context_window, vocab_size)
-
-                 logits = logits[:, -1, :]  # Get last token logits (B, vocab_size)
-                 probs = F.softmax(logits, dim=-1)  # Convert to probabilities
-
-                 # Top-k sampling
-                 topk_probs, topk_indices = torch.topk(probs, k=10, dim=-1)
-                 sampled_indices = torch.multinomial(topk_probs, 1)  # Shape: (B, 1)
-                 next_tokens = torch.gather(topk_indices, -1, sampled_indices)  # (B, 1)
-
-                 # Append generated token to sequence
-                 generated_tokens = torch.cat((generated_tokens, next_tokens), dim=1)
-
-         # print the generated text
-         continuations = []
-         for i in range(n_return_sequences):
-             tokens = generated_tokens[i, :length].tolist()
-             decoded = self.tokenizer.decode(tokens)
-             print(f"=== sample {i} ===\n{decoded}")
-             continuations.append(decoded)
-         return continuations
-
-     def load_checkpoint(self, checkpoint_path: str) -> None:
-         checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
-
-         # If the model was saved after running `torch.compile` then the names of its layers were changed.
-         # Need to change it back.
-         new_state_dict = {k.replace("_orig_mod.", ""): v for k, v in checkpoint['model_state_dict'].items()}
-         self.model.to(self.device)
-         self.model.load_state_dict(new_state_dict)
-         self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-         self.train_loader: DataLoader = checkpoint["train_loader"]
-
-         self.current_step = checkpoint['step']  # Resume from the last step
 
+ import torch
+ from torch.nn import functional as F
+ from transformers import PreTrainedTokenizer, AutoTokenizer
+
+ class LLMTrainer:
+     def __init__(self,
+                  model: torch.nn.Module = None,
+                  tokenizer: PreTrainedTokenizer | AutoTokenizer = None,
+                  model_returns_logits: bool = False):
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+         if tokenizer is None:
+             tokenizer = AutoTokenizer.from_pretrained("gpt2")
+
+         self.tokenizer = tokenizer
+         self.model = model
+
+         self.train_loader = None
+         self.current_step: int = 0
+
+         self.model_returns_logits = model_returns_logits
+
+     def generate_text(self, prompt: str = "Once upon a time", n_return_sequences: int = 4, length: int = 32) -> list[str]:
+         self.model.to(self.device)
+         self.model.eval()
+
+         tokens = self.tokenizer.encode(prompt, return_tensors="pt").type(torch.long)
+         tokens = tokens.repeat(n_return_sequences, 1)
+
+         generated_tokens = tokens.to(self.device)
+         with torch.no_grad():
+             while generated_tokens.size(1) < length:
+
+                 with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
+                     if self.model_returns_logits:
+                         logits = self.model(generated_tokens)
+                     else:
+                         logits = self.model(generated_tokens).logits
+
+                 logits = logits[:, -1, :]  # Get last token logits (B, vocab_size)
+                 probs = F.softmax(logits, dim=-1)  # Convert to probabilities
+
+                 topk_probs, topk_indices = torch.topk(probs, k=10, dim=-1)
+                 sampled_indices = torch.multinomial(topk_probs, 1)  # Shape: (B, 1)
+                 next_tokens = torch.gather(topk_indices, -1, sampled_indices)  # (B, 1)
+
+                 generated_tokens = torch.cat((generated_tokens, next_tokens), dim=1)
+
+         continuations = []
+         for i in range(n_return_sequences):
+             tokens = generated_tokens[i, :length].tolist()
+             decoded = self.tokenizer.decode(tokens)
+             print(f"=== sample {i} ===\n{decoded}")
+             continuations.append(decoded)
+         return continuations
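
For reference, a minimal usage sketch of the class after this commit. The model choice, prompt, and argument values below are illustrative assumptions, not part of the commit; any torch.nn.Module whose forward returns an object with a .logits attribute (as GPT2LMHeadModel does) works with model_returns_logits=False.

    import torch
    from transformers import AutoTokenizer, GPT2LMHeadModel

    from llm_trainer import LLMTrainer

    # Illustrative choices: GPT-2 small and its matching tokenizer.
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    trainer = LLMTrainer(model=model, tokenizer=tokenizer, model_returns_logits=False)

    # Prints and returns two continuations, each truncated to 48 tokens.
    continuations = trainer.generate_text(prompt="Once upon a time",
                                          n_return_sequences=2,
                                          length=48)

Note that generate_text samples only from the 10 highest-probability tokens at each step (top-k sampling with a hard-coded k=10), which trades some diversity for coherence compared with sampling from the full distribution.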