Spaces:
Runtime error
Runtime error
import numpy as np
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Seed used for the global RNGs below, so sampled generations repeat for the
# same prompt/settings (given the same device and library versions).
SEED = 17

# torch's global RNG is what sampling in model.generate() draws from; numpy is
# not used by the code visible here but is seeded as in the original script.
np.random.seed(SEED)
torch.manual_seed(SEED)
def load_tokenizer_and_model(model_name_or_path, device):
    """Load a GPT-2 tokenizer and LM-head model from the same checkpoint.

    Args:
        model_name_or_path: Model id or local path handed to both
            ``from_pretrained`` calls.
        device: Destination passed to the model's ``.to()``.

    Returns:
        tuple: ``(tokenizer, model)``, where the model has been moved to
        ``device``.
    """
    tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
    model = GPT2LMHeadModel.from_pretrained(model_name_or_path)
    model = model.to(device)
    return tokenizer, model
def generate(
    model, tok, text, device,
    do_sample=True, max_length=200, repetition_penalty=5.0,
    top_k=5, top_p=0.95, temperature=1,
    num_beams=None,
    no_repeat_ngram_size=3,
):
    """Generate continuations of ``text`` with ``model`` and decode them.

    Args:
        model: Model exposing ``generate`` (e.g. ``GPT2LMHeadModel``); it is
            assumed to already live on ``device``.
        tok: Tokenizer used both to encode ``text`` and to decode the output.
        text: Prompt string.
        device: Device the encoded prompt tensors are moved to.
        do_sample, top_k, top_p, temperature: Sampling controls forwarded
            unchanged to ``model.generate``.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        repetition_penalty, no_repeat_ngram_size: Repetition controls
            forwarded unchanged to ``model.generate``.
        num_beams: Beam count. ``None`` means "do not pass it", so the
            model's own generation default applies.

    Returns:
        list[str]: One string per generated sequence, produced by
        ``tok.decode`` with its default arguments.
    """
    encoded = tok(text, return_tensors="pt")
    input_ids = encoded["input_ids"].to(device)
    # GPT-2 has no pad token, so generate() cannot infer the mask from the ids
    # and warns about unreliable results unless the mask is passed explicitly.
    attention_mask = encoded["attention_mask"].to(device)
    # Fall back to EOS for padding when the tokenizer defines no pad token --
    # the same fallback generate() applies itself, but without the warning.
    pad_token_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id

    gen_kwargs = {
        "attention_mask": attention_mask,
        "pad_token_id": pad_token_id,
        "max_length": max_length,
        "repetition_penalty": repetition_penalty,
        "do_sample": do_sample,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "no_repeat_ngram_size": no_repeat_ngram_size,
    }
    # Explicitly forwarding num_beams=None can overwrite the model's
    # generation-config default with None in newer transformers versions;
    # only pass a real value.
    if num_beams is not None:
        gen_kwargs["num_beams"] = num_beams

    out = model.generate(input_ids, **gen_kwargs)
    return [tok.decode(sequence) for sequence in out]