import torch
from torch.nn import functional as F
from transformers import PreTrainedTokenizer, AutoTokenizer


class LLMTrainer:
    def __init__(self,
                 model: torch.nn.Module = None,
                 tokenizer: PreTrainedTokenizer | AutoTokenizer = None,
                 model_returns_logits: bool = False):
        # Prefer the GPU when available; keep both the device string and the torch.device handy.
        self.device_type = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.device_type)
        # Fall back to the GPT-2 tokenizer if none is supplied.
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.tokenizer = tokenizer
        self.model = model
        self.train_loader = None
        self.current_step: int = 0
        # True if calling the model returns raw logits; False if it returns an
        # output object with a .logits attribute (e.g. Hugging Face models).
        self.model_returns_logits = model_returns_logits
    def generate_text(self, prompt: str = "Once upon a time", n_return_sequences: int = 4, length: int = 32) -> list[str]:
        self.model.to(self.device)
        self.model.eval()
        # Encode the prompt and replicate it once per requested sequence.
        tokens = self.tokenizer.encode(prompt, return_tensors="pt").type(torch.long)
        tokens = tokens.repeat(n_return_sequences, 1)
        generated_tokens = tokens.to(self.device)
        with torch.no_grad():
            while generated_tokens.size(1) < length:
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    if self.model_returns_logits:
                        logits = self.model(generated_tokens)
                    else:
                        logits = self.model(generated_tokens).logits
                logits = logits[:, -1, :]  # Get last token logits (B, vocab_size)
                probs = F.softmax(logits, dim=-1)  # Convert to probabilities
                # Top-k sampling: keep the 10 most likely tokens, then sample one of them.
                topk_probs, topk_indices = torch.topk(probs, k=10, dim=-1)
                sampled_indices = torch.multinomial(topk_probs, 1)  # Shape: (B, 1)
                next_tokens = torch.gather(topk_indices, -1, sampled_indices)  # (B, 1)
                generated_tokens = torch.cat((generated_tokens, next_tokens), dim=1)
        continuations = []
        for i in range(n_return_sequences):
            tokens = generated_tokens[i, :length].tolist()
            decoded = self.tokenizer.decode(tokens)
            print(f"=== sample {i} ===\n{decoded}")
            continuations.append(decoded)
        return continuations
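

# Minimal usage sketch (not part of the class above): GPT2LMHeadModel is used
# purely as an illustrative model choice; any causal LM that returns an output
# object with a .logits attribute (or raw logits, with model_returns_logits=True)
# should work the same way.
if __name__ == "__main__":
    from transformers import GPT2LMHeadModel

    model = GPT2LMHeadModel.from_pretrained("gpt2")
    trainer = LLMTrainer(model=model, model_returns_logits=False)
    samples = trainer.generate_text(prompt="Once upon a time", n_return_sequences=2, length=40)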