File size: 1,560 Bytes
03ff5a5 e667377 03ff5a5 e667377 4f6b579 e667377 03ff5a5 e667377 03ff5a5 e667377 03ff5a5 e667377 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import gradio as gr
from transformers import BartTokenizer, BartForConditionalGeneration
# Load the trained tokenizer and model once, at import time, so that every
# call to answer_question() reuses the same instances.
# NOTE(review): "bart_QA" is presumably a local checkpoint directory shipped
# alongside this script (per the original comment) — confirm it is not meant
# to be resolved as a Hub model id.
model_path = "bart_QA"
tokenizer = BartTokenizer.from_pretrained(model_path)
model = BartForConditionalGeneration.from_pretrained(model_path)
def answer_question(question, context):
    """Generate an answer to ``question`` from ``context`` with the BART model.

    Args:
        question: The user's question. ``None`` is treated like an empty
            string.
        context: The passage the answer should be drawn from. ``None`` is
            treated like an empty string.

    Returns:
        The decoded text of the first returned sequence, or a message asking
        for both fields when either one is missing or whitespace-only.
    """
    # Guard against None as well as blank strings, so a missing field yields
    # the friendly message instead of an AttributeError on ``.strip()``.
    if not (question or "").strip() or not (context or "").strip():
        return "Please provide both a question and context."

    # Build the single input sequence in "question: ... context: ..." form.
    # NOTE(review): presumably this mirrors the fine-tuning format — confirm.
    input_text = f"question: {question} context: {context}"
    # Inputs longer than 512 tokens are truncated rather than rejected.
    inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)

    # Beam search (4 beams) with the output capped at 150 tokens.
    outputs = model.generate(**inputs, max_length=150, num_beams=4, early_stopping=True)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Define the Gradio interface: question and context inputs, a read-only
# answer box, and Submit / Clear buttons.
with gr.Blocks() as demo:
    gr.Markdown("# BART QA System")

    # Question and context sit side by side in one row.
    with gr.Row():
        question_input = gr.Textbox(lines=1, placeholder="Enter your question here", label="Question")
        context_input = gr.Textbox(lines=5, placeholder="Enter the context here", label="Context")

    # The answer is display-only; users cannot edit it.
    answer_output = gr.Textbox(label="Answer", interactive=False)

    with gr.Row():
        submit_btn = gr.Button("Submit")
        clear_btn = gr.Button("Clear")

    # Submit runs the model on the two inputs and shows the result.
    submit_btn.click(fn=answer_question, inputs=[question_input, context_input], outputs=answer_output)
    # Clear resets the answer box together with both inputs, so no stale
    # answer stays on screen next to empty fields.
    clear_btn.click(
        fn=lambda: ("", "", ""),
        inputs=[],
        outputs=[question_input, context_input, answer_output],
    )

demo.launch()
|