# BART_QA / app.py
import gradio as gr
from transformers import BartTokenizer, BartForConditionalGeneration
# Load the trained model and tokenizer from the specified directory
model_path = "bart_QA"
tokenizer = BartTokenizer.from_pretrained(model_path)
model = BartForConditionalGeneration.from_pretrained(model_path)
def answer_question(question, context):
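    """Generate an answer to `question` using the supporting `context`."""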
if question.strip() == "" or context.strip() == "":
return "Please provide both a question and context."
# Generate input sequence
input_text = f"question: {question} context: {context}"
inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
# Generate answer
outputs = model.generate(**inputs, max_length=150, num_beams=4, early_stopping=True)
answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
return answer
# Define the Gradio interface
with gr.Blocks() as demo:
gr.Markdown("# BART QA System")
with gr.Row():
question_input = gr.Textbox(lines=1, placeholder="Enter your question here", label="Question")
context_input = gr.Textbox(lines=5, placeholder="Enter the context here", label="Context")
answer_output = gr.Textbox(label="Answer", interactive=False)
with gr.Row():
submit_btn = gr.Button("Submit")
clear_btn = gr.Button("Clear")
submit_btn.click(fn=answer_question, inputs=[question_input, context_input], outputs=answer_output)
clear_btn.click(fn=lambda: ("", ""), inputs=[], outputs=[question_input, context_input])
demo.launch()