IkIzma commited on
Commit
8f917da
·
1 Parent(s): 861649f

Upload Generate_text_with_RuGPTs_HF.py

Browse files
Files changed (1) hide show
  1. Generate_text_with_RuGPTs_HF.py +28 -0
Generate_text_with_RuGPTs_HF.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+ np.random.seed(17)
5
+ torch.manual_seed(17)
6
+
7
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
8
+
9
+ def load_tokenizer_and_model(model_name_or_path, device):
10
+ return GPT2Tokenizer.from_pretrained(model_name_or_path), GPT2LMHeadModel.from_pretrained(model_name_or_path).cuda()
11
+
12
+ def generate(
13
+ model, tok, text, device,
14
+ do_sample=True, max_length=50, repetition_penalty=5.0,
15
+ top_k=5, top_p=0.95, temperature=1,
16
+ num_beams=None,
17
+ no_repeat_ngram_size=3
18
+ ):
19
+ input_ids = tok.encode(text, return_tensors="pt").to(device)
20
+ out = model.generate(
21
+ input_ids.cuda(),
22
+ max_length=max_length,
23
+ repetition_penalty=repetition_penalty,
24
+ do_sample=do_sample,
25
+ top_k=top_k, top_p=top_p, temperature=temperature,
26
+ num_beams=num_beams, no_repeat_ngram_size=no_repeat_ngram_size
27
+ )
28
+ return list(map(tok.decode, out))