|
|
import gradio as gr |
|
|
from transformers import BartTokenizer, BartForConditionalGeneration |
|
|
|
|
|
|
|
|
# Checkpoint identifier handed to both `from_pretrained` calls below.
# NOTE(review): presumably a local directory holding the fine-tuned BART QA
# weights and tokenizer files -- confirm against the deployment layout.
model_path = "bart_QA"

# Tokenizer and model are loaded once at import time from the same checkpoint
# and shared by every request handled by `answer_question`.
tokenizer = BartTokenizer.from_pretrained(model_path)
model = BartForConditionalGeneration.from_pretrained(model_path)
|
|
|
|
|
def answer_question(question, context):
    """Generate an answer to *question* from *context* with the loaded BART model.

    Args:
        question: Question text coming from the UI. ``None`` is treated the
            same as an empty string.
        context: Passage the answer should be drawn from. ``None`` is treated
            the same as an empty string.

    Returns:
        The decoded text of the first generated sequence, or a fixed prompt
        asking for both fields when either one is missing or contains only
        whitespace.
    """
    # Guard against None as well as blank strings: calling .strip() on None
    # would raise AttributeError instead of showing the friendly message.
    if not (question and question.strip()) or not (context and context.strip()):
        return "Please provide both a question and context."

    # Prompt layout fed to the model: "question: ... context: ...".
    input_text = f"question: {question} context: {context}"

    # Encode as PyTorch tensors; anything beyond 512 tokens is truncated.
    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)

    # Beam search with 4 beams, capped at 150 generated tokens.
    outputs = model.generate(**inputs, max_length=150, num_beams=4, early_stopping=True)

    # Only the first returned sequence is decoded; special tokens are dropped.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
|
|
|
|
# Build the Gradio interface: two input boxes, a read-only answer box, and
# Submit / Clear buttons wired to the handlers below.
with gr.Blocks() as demo:
    gr.Markdown("# BART QA System")

    # Inputs side by side: a one-line question and a multi-line context.
    with gr.Row():
        question_input = gr.Textbox(lines=1, placeholder="Enter your question here", label="Question")
        context_input = gr.Textbox(lines=5, placeholder="Enter the context here", label="Context")

    # Output is display-only; users cannot edit the generated answer.
    answer_output = gr.Textbox(label="Answer", interactive=False)

    with gr.Row():
        submit_btn = gr.Button("Submit")
        clear_btn = gr.Button("Clear")

    # Submit runs the model on the current question/context pair.
    submit_btn.click(fn=answer_question, inputs=[question_input, context_input], outputs=answer_output)

    # Clear resets both inputs AND the answer box, so a stale answer is not
    # left on screen next to empty inputs.
    clear_btn.click(
        fn=lambda: ("", "", ""),
        inputs=[],
        outputs=[question_input, context_input, answer_output],
    )

# Start the web server for the app.
demo.launch()
|
|
|