# Transformer_Homework_2 / Generate_text.py
# Author: IkIzma
# Duplicate from IkIzma/Transformer_Homework (commit 6098a23)
# Size: 921 Bytes
import numpy as np
import torch
# Seed the NumPy and PyTorch random number generators with the same fixed
# value (17) as soon as this module is imported.
np.random.seed(17)
torch.manual_seed(17)
from transformers import GPT2LMHeadModel, GPT2Tokenizer
def load_tokenizer_and_model(model_name_or_path, device):
    """Load a GPT-2 tokenizer and LM-head model from the same location.

    Both objects are created with ``from_pretrained(model_name_or_path)``.

    Returns:
        A ``(tokenizer, model)`` tuple, where ``model`` is the result of
        calling ``.to(device)`` on the loaded model.
    """
    tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
    language_model = GPT2LMHeadModel.from_pretrained(model_name_or_path)
    return tokenizer, language_model.to(device)
def generate(
    model, tok, text, device,
    do_sample=True, max_length=200, repetition_penalty=5.0,
    top_k=5, top_p=0.95, temperature=1,
    num_beams=None,
    no_repeat_ngram_size=3
):
    """Generate continuations of *text* with *model* and decode them.

    The prompt is encoded with ``tok.encode(..., return_tensors="pt")``,
    moved to *device*, and handed to ``model.generate``.

    Args:
        model: Object exposing ``generate(input_ids, **kwargs)``.
        tok: Tokenizer exposing ``encode`` and ``decode``.
        text: Prompt to encode and continue.
        device: Target passed to ``.to()`` for the encoded prompt.
        do_sample, max_length, repetition_penalty, top_k, top_p,
        temperature, no_repeat_ngram_size: Forwarded unchanged to
            ``model.generate`` as keyword arguments of the same name.
        num_beams: Forwarded to ``model.generate`` only when it is not
            ``None``; otherwise the argument is omitted entirely.

    Returns:
        A list with one decoded string per sequence yielded by iterating
        over the ``model.generate`` result.
    """
    # A single transfer is enough; the tensor is already on `device` afterwards.
    input_ids = tok.encode(text, return_tensors="pt").to(device)

    generation_kwargs = {
        "max_length": max_length,
        "repetition_penalty": repetition_penalty,
        "do_sample": do_sample,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "no_repeat_ngram_size": no_repeat_ngram_size,
    }
    # Leave `num_beams` out when unset instead of sending an explicit None:
    # omitting it lets the model's own generation defaults decide, whereas
    # handling of an explicit None is not uniform across transformers releases.
    if num_beams is not None:
        generation_kwargs["num_beams"] = num_beams

    out = model.generate(input_ids, **generation_kwargs)
    return [tok.decode(sequence) for sequence in out]