Spaces:
Runtime error
Runtime error
Eitan Lifshits
commited on
Commit
·
d82cb0c
1
Parent(s):
fd71869
update visual output
Browse files- app.py +7 -6
- requirements.txt +2 -1
app.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
from transformers import pipeline
|
|
|
|
| 3 |
|
| 4 |
model = pipeline("question-answering", model="Eitanli/distilbert-qa-checkpoint-v3")
|
| 5 |
|
|
@@ -10,11 +11,11 @@ questions = ['which ingredients are mentioned?',
|
|
| 10 |
|
| 11 |
|
| 12 |
def predict(context, topk, answer_threshold):
|
| 13 |
-
output =
|
| 14 |
for question in questions:
|
| 15 |
-
|
| 16 |
-
answers =
|
| 17 |
-
output
|
| 18 |
return output
|
| 19 |
|
| 20 |
|
|
@@ -22,8 +23,8 @@ iface = gr.Interface(
|
|
| 22 |
fn=predict,
|
| 23 |
inputs=[
|
| 24 |
gr.Textbox(label="Recipe line"),
|
| 25 |
-
gr.Slider(
|
| 26 |
-
gr.Slider(
|
| 27 |
outputs=gr.Textbox(label='Questions and answers')
|
| 28 |
)
|
| 29 |
iface.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from transformers import pipeline
|
| 3 |
+
import numpy as np
|
| 4 |
|
| 5 |
model = pipeline("question-answering", model="Eitanli/distilbert-qa-checkpoint-v3")
|
| 6 |
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
def predict(context, topk, answer_threshold):
|
| 14 |
+
output = 'Recipe analysis:'
|
| 15 |
for question in questions:
|
| 16 |
+
pred = model(question=question, context=context, topk=topk)
|
| 17 |
+
answers = '\n'.join([f"{ans['answer']} ({np.round(ans['score'], 2)})" for ans in pred if ans['score'] > answer_threshold])
|
| 18 |
+
output += f'\n\n{question}:\n{answers}'
|
| 19 |
return output
|
| 20 |
|
| 21 |
|
|
|
|
| 23 |
fn=predict,
|
| 24 |
inputs=[
|
| 25 |
gr.Textbox(label="Recipe line"),
|
| 26 |
+
gr.Slider(1, 5, step=1.0, value=2, label="top k", info="Choose between 1 and 5"),
|
| 27 |
+
gr.Slider(0, 0.99, step=0.01, value=0.8, label="answer_threshold", info="Select a threshold in [0, 0.99]")],
|
| 28 |
outputs=gr.Textbox(label='Questions and answers')
|
| 29 |
)
|
| 30 |
iface.launch()
|
requirements.txt
CHANGED
|
@@ -1,2 +1,3 @@
|
|
| 1 |
transformers==4.30.2
|
| 2 |
-
torch>=2.0
|
|
|
|
|
|
| 1 |
transformers==4.30.2
|
| 2 |
+
torch>=2.0
|
| 3 |
+
numpy
|