import torch
from torch.nn import functional as F
from transformers import AutoTokenizer, PreTrainedTokenizer


class LLMTrainer:
    def __init__(
        self,
        model: torch.nn.Module = None,
        tokenizer: PreTrainedTokenizer = None,
        model_returns_logits: bool = False,
    ):
        self.device_type = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.device_type)

        # Fall back to the GPT-2 tokenizer when none is supplied
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.tokenizer = tokenizer

        self.model = model
        self.train_loader = None
        self.current_step: int = 0
        self.model_returns_logits = model_returns_logits

    def generate_text(
        self,
        prompt: str = "Once upon a time",
        n_return_sequences: int = 4,
        length: int = 32,
    ) -> list[str]:
        self.model.to(self.device)
        self.model.eval()

        # Encode the prompt and replicate it once per requested sequence: (B, T)
        tokens = self.tokenizer.encode(prompt, return_tensors="pt").type(torch.long)
        tokens = tokens.repeat(n_return_sequences, 1)
        generated_tokens = tokens.to(self.device)

        with torch.no_grad():
            while generated_tokens.size(1) < length:
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    if self.model_returns_logits:
                        logits = self.model(generated_tokens)
                    else:
                        logits = self.model(generated_tokens).logits

                logits = logits[:, -1, :]  # Get last-token logits (B, vocab_size)
                probs = F.softmax(logits, dim=-1)  # Convert to probabilities

                # Top-k sampling: keep the 10 most likely tokens, sample among them
                topk_probs, topk_indices = torch.topk(probs, k=10, dim=-1)
                sampled_indices = torch.multinomial(topk_probs, 1)  # Shape: (B, 1)
                next_tokens = torch.gather(topk_indices, -1, sampled_indices)  # (B, 1)
                generated_tokens = torch.cat((generated_tokens, next_tokens), dim=1)

        continuations = []
        for i in range(n_return_sequences):
            tokens = generated_tokens[i, :length].tolist()
            decoded = self.tokenizer.decode(tokens)
            print(f"=== sample {i} ===\n{decoded}")
            continuations.append(decoded)
        return continuations
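

# Minimal usage sketch (an assumption, not part of the original file): pairs the
# trainer with a Hugging Face GPT-2 checkpoint. AutoModelForCausalLM returns an
# output object exposing .logits, so model_returns_logits stays False here.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")  # hypothetical model choice
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    trainer = LLMTrainer(model=model, tokenizer=tokenizer, model_returns_logits=False)
    samples = trainer.generate_text(
        prompt="Once upon a time",
        n_return_sequences=2,
        length=48,
    )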