# HSE_AI / llm_trainer.py
# Author: AlekMan — commit c18f087 ("Update llm_trainer.py", verified)
import torch
from torch.nn import functional as F
from transformers import PreTrainedTokenizer, AutoTokenizer
class LLMTrainer:
    """Thin wrapper around a causal language model for sampled text generation.

    Holds the model, its tokenizer, and the target device; `generate_text`
    produces continuations of a prompt with top-k multinomial sampling.
    """

    def __init__(self,
                 model: torch.nn.Module = None,
                 tokenizer: PreTrainedTokenizer | AutoTokenizer = None,
                 model_returns_logits: bool = False):
        """Store the model/tokenizer and pick the compute device.

        Args:
            model: A causal LM whose forward pass accepts a (B, T) LongTensor
                of token ids.
            tokenizer: Tokenizer matching the model's vocabulary; defaults to
                the GPT-2 tokenizer when omitted.
            model_returns_logits: True if ``model(tokens)`` yields a raw logits
                tensor; False if it yields an object with a ``.logits`` field
                (e.g. a Hugging Face ``CausalLMOutput``).
        """
        self.device_type = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.device_type)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.tokenizer = tokenizer
        self.model = model
        self.train_loader = None          # set later by training code (not shown here)
        self.current_step: int = 0        # training-step counter, unused in generation
        self.model_returns_logits = model_returns_logits

    def generate_text(self,
                      prompt: str = "Once upon a time",
                      n_return_sequences: int = 4,
                      length: int = 32,
                      top_k: int = 10) -> list[str]:
        """Sample ``n_return_sequences`` continuations of ``prompt``.

        Tokens are appended one at a time until each sequence holds ``length``
        tokens (prompt included). At every step, sampling is restricted to the
        ``top_k`` most probable next tokens. Each decoded sample is printed and
        all of them are returned.

        Args:
            prompt: Seed text shared by all sequences.
            n_return_sequences: Number of independent samples to draw.
            length: Total token count per sequence, counting the prompt.
            top_k: Number of highest-probability tokens to sample from
                (was hard-coded to 10; default preserves old behavior).

        Returns:
            The decoded continuations, one string per sequence.
        """
        self.model.to(self.device)
        self.model.eval()
        # Encode once, then replicate the prompt across the batch dimension.
        prompt_ids = self.tokenizer.encode(prompt, return_tensors="pt").type(torch.long)
        generated = prompt_ids.repeat(n_return_sequences, 1).to(self.device)
        with torch.no_grad():
            while generated.size(1) < length:
                # bf16 autocast only around the forward pass; sampling math
                # below runs at full precision.
                with torch.autocast(device_type=self.device_type, dtype=torch.bfloat16):
                    if self.model_returns_logits:
                        logits = self.model(generated)
                    else:
                        logits = self.model(generated).logits
                logits = logits[:, -1, :]                 # last position: (B, vocab)
                probs = F.softmax(logits, dim=-1)
                topk_probs, topk_indices = torch.topk(probs, k=top_k, dim=-1)
                # Draw within the top-k bucket, then map back to vocab ids.
                sampled = torch.multinomial(topk_probs, 1)          # (B, 1)
                next_tokens = torch.gather(topk_indices, -1, sampled)  # (B, 1)
                generated = torch.cat((generated, next_tokens), dim=1)
        continuations = []
        for i in range(n_return_sequences):
            seq = generated[i, :length].tolist()
            decoded = self.tokenizer.decode(seq)
            print(f"=== sample {i} ===\n{decoded}")
            continuations.append(decoded)
        return continuations