jaimin committed
Commit 665a12b · 1 Parent(s): 76869a5

Update app.py

Files changed (1)
  1. app.py +37 -3
app.py CHANGED
@@ -1,6 +1,40 @@
 import gradio as gr
 
-iface = gr.Interface.load("huggingface/csarron/bert-base-uncased-squad-v1",
-)
-
-iface.launch()
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+import torch
+
+# Run on GPU when available, otherwise fall back to CPU.
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+tokenizer = AutoTokenizer.from_pretrained('jaimin/QA')
+model = AutoModelForSeq2SeqLM.from_pretrained('jaimin/QA').to(device)
+
+
+def get_answer(question, context):
+    # The model expects a single "question: ... context: ..." input string.
+    question_doc = "question: {} context: {}".format(question, context)
+
+    encoding = tokenizer.encode_plus(question_doc, padding="max_length", return_tensors="pt")
+    input_ids = encoding["input_ids"].to(device)
+    attention_masks = encoding["attention_mask"].to(device)
+
+    model_output = model.generate(
+        input_ids=input_ids,
+        attention_mask=attention_masks,
+        num_beams=10,
+        do_sample=True,
+        max_length=30,
+        top_k=50,
+        top_p=0.95,
+        early_stopping=True,
+        no_repeat_ngram_size=1,
+    )
+    # generate() returns a batch of token-id sequences; decode the first one.
+    generated_sent = tokenizer.decode(
+        model_output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
+
+    return generated_sent
+
+
+demo = gr.Interface(get_answer,
+                    inputs=[gr.inputs.Textbox(label="question", optional=False),
+                            gr.inputs.Textbox(label="context", optional=False)],
+                    outputs=[gr.outputs.Textbox(label="Answer")])
+
+if __name__ == "__main__":
+    demo.launch()
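
For a quick sanity check outside the Gradio UI, get_answer() can also be called directly. The context and question below are made-up illustrative inputs, and this sketch assumes the jaimin/QA checkpoint is a seq2seq QA model fine-tuned on the "question: ... context: ..." prompt format used above:

# Illustrative sanity check of get_answer() outside the Gradio UI.
# The strings here are hypothetical examples, not part of the commit.
sample_context = ("The Amazon rainforest covers much of the Amazon basin "
                  "of South America, spanning nine countries.")
sample_question = "How many countries does the Amazon rainforest span?"
print(get_answer(sample_question, sample_context))  # e.g. "nine"; output may vary

Note that do_sample=True combined with num_beams=10 makes decoding stochastic, so repeated calls can return different answers; dropping do_sample (and with it top_k/top_p) would turn this into deterministic beam search.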