import torch
from torch.nn import functional as F
from transformers import PreTrainedTokenizer, AutoTokenizer


class LLMTrainer:

    def __init__(self,
                 model: torch.nn.Module | None = None,
                 tokenizer: PreTrainedTokenizer | None = None,
                 model_returns_logits: bool = False):
        # Run on CUDA when available, otherwise fall back to CPU.
        self.device_type = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.device_type)

        # Default to the GPT-2 tokenizer when none is supplied.
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained("gpt2")

        self.tokenizer = tokenizer
        self.model = model

        self.train_loader = None
        self.current_step: int = 0

        # Plain nn.Module models return raw logits directly; Hugging Face
        # models return an output object with a .logits attribute.
        self.model_returns_logits = model_returns_logits

    def generate_text(self, prompt: str = "Once upon a time", n_return_sequences: int = 4, length: int = 32) -> list[str]:
        self.model.to(self.device)
        self.model.eval()

        # Encode the prompt once and replicate it per requested sequence.
        tokens = self.tokenizer.encode(prompt, return_tensors="pt").to(torch.long)
        tokens = tokens.repeat(n_return_sequences, 1)

        generated_tokens = tokens.to(self.device)
        with torch.no_grad():
            while generated_tokens.size(1) < length:

                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    if self.model_returns_logits:
                        logits = self.model(generated_tokens)
                    else:
                        logits = self.model(generated_tokens).logits

                # Only the logits at the last position matter for sampling.
                logits = logits[:, -1, :]
                probs = F.softmax(logits, dim=-1)

                # Top-k sampling: keep the 10 most likely tokens, draw one,
                # then map the sampled slot back to a vocabulary index.
                topk_probs, topk_indices = torch.topk(probs, k=10, dim=-1)
                sampled_indices = torch.multinomial(topk_probs, 1)
                next_tokens = torch.gather(topk_indices, -1, sampled_indices)

                generated_tokens = torch.cat((generated_tokens, next_tokens), dim=1)

        # Decode, print, and return the generated continuations.
        continuations = []
        for i in range(n_return_sequences):
            tokens = generated_tokens[i, :length].tolist()
            decoded = self.tokenizer.decode(tokens)
            print(f"=== sample {i} ===\n{decoded}")
            continuations.append(decoded)
        return continuations
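
# Example usage (a minimal sketch, assuming a Hugging Face GPT-2 checkpoint):
# GPT2LMHeadModel's forward pass returns an output object with a .logits
# attribute, so model_returns_logits stays at its default of False.
if __name__ == "__main__":
    from transformers import GPT2LMHeadModel

    model = GPT2LMHeadModel.from_pretrained("gpt2")
    trainer = LLMTrainer(model=model)
    trainer.generate_text(prompt="Once upon a time", n_return_sequences=2, length=32)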