Маликов Дмитрий Романович commited on
Commit
5fe52a8
·
1 Parent(s): 8f9a32d

Add application

Browse files
Files changed (2) hide show
  1. app.py +53 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
+
5
+
6
+ large_model_name = "DmitryMalikov/t5-base-question-gen"
7
+ small_model_name = "DmitryMalikov/t5-small-question-gen"
8
+
9
+ tokenizers = {
10
+ "Base T5": AutoTokenizer.from_pretrained(large_model_name),
11
+ "Small T5": AutoTokenizer.from_pretrained(small_model_name),
12
+ }
13
+ models = {
14
+ "Base T5": AutoModelForSeq2SeqLM.from_pretrained(large_model_name),
15
+ "Small T5": AutoModelForSeq2SeqLM.from_pretrained(small_model_name),
16
+ }
17
+
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ for model in models.values():
20
+ model.to(device)
21
+
22
+
23
+ def generate_question(context, model_choice):
24
+ tokenizer = tokenizers[model_choice]
25
+ model = models[model_choice]
26
+
27
+ input_text = context.strip()
28
+ inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=256).to(device)
29
+
30
+ outputs = model.generate(
31
+ **inputs,
32
+ max_length=64,
33
+ num_beams=4,
34
+ early_stopping=True
35
+ )
36
+ question = tokenizer.decode(outputs[0], skip_special_tokens=True)
37
+ return question
38
+
39
+
40
+ iface = gr.Interface(
41
+ fn=generate_question,
42
+ inputs=[
43
+ gr.Textbox(label="Context", lines=5, placeholder="Enter text context for question..."),
44
+ gr.Dropdown(choices=["Base T5", "Small T5"], label="Choose model", value="Small T5")
45
+ ],
46
+ outputs=gr.Textbox(label="Generated question"),
47
+ title="Question generation based on context",
48
+ description="Enter text and receive question that can be answered with given context."
49
+ )
50
+
51
+ if __name__ == "__main__":
52
+ iface.launch()
53
+
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers>=4.0.0
2
+ torch
3
+ gradio