Vigen1 commited on
Commit
f94c2c8
·
verified ·
1 Parent(s): 7a8a7e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -29
app.py CHANGED
@@ -1,32 +1,33 @@
1
- import gradio
2
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
 
4
- model_name = 't5-small'
5
 
6
- tokenizer = AutoTokenizer.from_pretrained(model_name)
7
-
8
- finetuned_model = AutoModelForSeq2SeqLM.from_pretrained("finetuned_model_2_epoch")
9
-
10
- question = gradio.Textbox("question")
11
- context = gradio.Textbox("context")
12
-
13
- prompt = f"""Tables:
14
- {context}
15
-
16
- Question:
17
- {question}
18
-
19
- Answer:
20
- """
21
-
22
- inputs = tokenizer(prompt, return_tensors='pt')
23
-
24
- output = tokenizer.decode(
25
- finetuned_model.generate(
26
- inputs["input_ids"],
27
- max_new_tokens=200,
28
- )[0],
29
- skip_special_tokens=True
30
- )
31
-
32
- print(f'MODEL GENERATION - ZERO SHOT:\n{output}')
 
 
 
1
+ import gradio as gr
2
  from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
 
 
4
 
5
+ def get_output(question, context):
6
+ model_name = 't5-small'
7
+
8
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+
10
+ finetuned_model = AutoModelForSeq2SeqLM.from_pretrained("finetuned_model_2_epoch")
11
+
12
+ prompt = f"""Tables:
13
+ {context}
14
+
15
+ Question:
16
+ {question}
17
+
18
+ Answer:
19
+ """
20
+
21
+ inputs = tokenizer(prompt, return_tensors='pt')
22
+
23
+ output = tokenizer.decode(
24
+ finetuned_model.generate(
25
+ inputs["input_ids"],
26
+ max_new_tokens=200,
27
+ )[0],
28
+ skip_special_tokens=True
29
+ )
30
+
31
+ print(f'MODEL GENERATION - ZERO SHOT:\n{output}')
32
+
33
+ interface = gr.Interface(fn=get_output, inputs = ["text", "text"], outputs=["text"])