AlekMan commited on
Commit
021e532
·
verified ·
1 Parent(s): 4a94fa6

Upload llm_trainer.py

Browse files
Files changed (1) hide show
  1. llm_trainer.py +80 -0
llm_trainer.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import functional as F
3
+ from transformers import PreTrainedTokenizer, AutoTokenizer
4
+
5
+ from llm_trainer.dataset.DataLoader import DataLoader
6
+
7
class LLMTrainer:
    """Wraps a causal language model and tokenizer for sampling and
    checkpoint restoration."""

    def __init__(self,
                 model: torch.nn.Module = None,
                 tokenizer: PreTrainedTokenizer | AutoTokenizer = None,
                 model_returns_logits: bool = False):
        """Pick a device and attach the model/tokenizer.

        Args:
            model: the network to sample from / restore into (required).
            tokenizer: HF tokenizer; defaults to GPT-2's when omitted.
            model_returns_logits: True when ``model(x)`` returns raw logits,
                False when it returns a HF-style output with a ``.logits``
                attribute.

        Raises:
            ValueError: if ``model`` is None.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Training on: {self.device}")

        # Fall back to the GPT-2 tokenizer when none is supplied.
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.tokenizer = tokenizer

        if model is None:
            raise ValueError("Specify a model.")
        self.model = model

        # BUG FIX: load_checkpoint reads self.optimizer, but it was never
        # initialized here — define it so the attribute always exists.
        # Populated by training code / restored by load_checkpoint.
        self.optimizer = None
        self.train_loader = None
        self.current_step: int = 0

        self.model_returns_logits = model_returns_logits

    def generate_text(self, prompt: str = "Once upon a time",
                      n_return_sequences: int = 4, length: int = 32) -> list[str]:
        """Sample ``n_return_sequences`` continuations of ``prompt``.

        Autoregressively extends the prompt with top-k (k=10) sampling until
        each sequence holds ``length`` tokens, then prints and returns the
        decoded samples.  (Return annotation fixed: the original said ``None``
        but the method returns the list of decoded strings.)
        """
        # Make sure the model is on the same device.
        self.model.to(self.device)
        self.model.eval()

        tokens = self.tokenizer.encode(prompt, return_tensors="pt").type(torch.long)
        tokens = tokens.repeat(n_return_sequences, 1)  # one row per sample

        generated_tokens = tokens.to(self.device)
        with torch.no_grad():
            while generated_tokens.size(1) < length:
                # BUG FIX: torch.autocast requires the device *type* string
                # ("cuda"/"cpu"), not a torch.device object — passing the
                # device instance raises at runtime.
                with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
                    if self.model_returns_logits:
                        logits = self.model(generated_tokens)
                    else:
                        logits = self.model(generated_tokens).logits

                # logits.shape = (batch_size, context_window, vocab_size)
                logits = logits[:, -1, :]  # last-position logits (B, vocab_size)
                probs = F.softmax(logits, dim=-1)  # convert to probabilities

                # Top-k sampling (k=10).
                topk_probs, topk_indices = torch.topk(probs, k=10, dim=-1)
                sampled_indices = torch.multinomial(topk_probs, 1)  # (B, 1)
                next_tokens = torch.gather(topk_indices, -1, sampled_indices)  # (B, 1)

                # Append the generated token to every sequence.
                generated_tokens = torch.cat((generated_tokens, next_tokens), dim=1)

        # Decode, print and collect the samples.
        continuations = []
        for i in range(n_return_sequences):
            tokens = generated_tokens[i, :length].tolist()
            decoded = self.tokenizer.decode(tokens)
            print(f"=== sample {i} ===\n{decoded}")
            continuations.append(decoded)
        return continuations

    def load_checkpoint(self, checkpoint_path: str) -> None:
        """Restore model, optimizer, dataloader and step from a checkpoint.

        Args:
            checkpoint_path: path to a file written with ``torch.save``;
                NOTE: loaded with ``weights_only=False``, so only open
                trusted checkpoints (pickle can execute arbitrary code).
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)

        # If the model was saved after `torch.compile`, parameter names carry
        # an "_orig_mod." prefix — strip it to match the uncompiled module.
        new_state_dict = {k.replace("_orig_mod.", ""): v for k, v in checkpoint['model_state_dict'].items()}
        self.model.to(self.device)
        self.model.load_state_dict(new_state_dict)

        # BUG FIX: the original dereferenced self.optimizer unconditionally,
        # but __init__ never created one — restore only when attached.
        optimizer = getattr(self, "optimizer", None)
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        self.train_loader: DataLoader = checkpoint["train_loader"]

        self.current_step = checkpoint['step']  # Resume from the last step