Unggi commited on
Commit
7a46642
ยท
1 Parent(s): 10cffb0
Files changed (1) hide show
  1. bart_demo_gradio.py +49 -0
bart_demo_gradio.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ import torch
4
+ import transformers
5
+
6
+ # saved_model
7
+ def load_model(model_path, config):
8
+ saved_data = torch.load(
9
+ model_path,
10
+ map_location="cpu" if config.gpu_id < 0 else "cuda:%d" % config.gpu_id
11
+ )
12
+
13
+ bart_best = saved_data["model"]
14
+ train_config = saved_data["config"]
15
+ tokenizer = transformers.PreTrainedTokenizerFast.from_pretrained(config.pretrained_model_name)
16
+
17
+ ## Load weights.
18
+ model = transformers.BartForConditionalGeneration.from_pretrained(config.pretrained_model_name)
19
+ model.load_state_dict(bart_best)
20
+
21
+ return model, tokenizer
22
+
23
+
24
+ # main
25
+ def inference(prompt):
26
+
27
+ config = define_argparser()
28
+ model_path = config.model_fpath
29
+
30
+ model, tokenizer = load_model(
31
+ model_path=model_path,
32
+ config=config
33
+ )
34
+
35
+ input_ids = tokenizer.encode(prompt)
36
+ input_ids = torch.tensor(input_ids)
37
+ input_ids = input_ids.unsqueeze(0)
38
+ output = model.generate(input_ids)
39
+ output = tokenizer.decode(output[0], skip_special_tokens=True)
40
+
41
+ return output
42
+
43
+ demo = gr.Interface(
44
+ fn=inference,
45
+ inputs="text",
46
+ outputs="text" #return ๊ฐ’
47
+ ).launch(share=True) # launch(share=True)๋ฅผ ์„ค์ •ํ•˜๋ฉด ์™ธ๋ถ€์—์„œ ์ ‘์† ๊ฐ€๋Šฅํ•œ ๋งํฌ๊ฐ€ ์ƒ์„ฑ๋จ
48
+
49
+ demo.launch()