import numpy as np
import torch

from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Fix the random seeds so that sampling is reproducible across runs.
np.random.seed(17)
torch.manual_seed(17)


def load_tokenizer_and_model(model_name_or_path, device):
    """Load a GPT-2 tokenizer and model from a local path or the Hugging Face Hub."""
    tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
    model = GPT2LMHeadModel.from_pretrained(model_name_or_path).to(device)
    return tokenizer, model

def generate(
    model, tok, text, device,
    do_sample=True, max_length=200, repetition_penalty=5.0,
    top_k=5, top_p=0.95, temperature=1.0,
    num_beams=None,
    no_repeat_ngram_size=3,
):
    """Generate continuations of `text` and return them as decoded strings."""
    # Encode the prompt and move it to the target device once.
    input_ids = tok.encode(text, return_tensors="pt").to(device)
    out = model.generate(
        input_ids,
        max_length=max_length,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        num_beams=num_beams,
        no_repeat_ngram_size=no_repeat_ngram_size,
    )
    return [tok.decode(ids) for ids in out]
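

# A minimal usage sketch (an assumption, not part of the original script):
# the "gpt2" checkpoint name and the prompt below are placeholders chosen
# only to illustrate how the two helpers above fit together.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tok, model = load_tokenizer_and_model("gpt2", device)
    samples = generate(model, tok, "The meaning of life is", device, max_length=50)
    print(samples[0])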