import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
# --- Load tokenizer and model from the Hugging Face Hub ---
# NOTE(review): trust_remote_code=True executes code shipped with the
# checkpoint — acceptable only because this repo is trusted; confirm.
MODEL_NAME = "Salommee/bert-squad-qa"

print("Loading model...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME, trust_remote_code=True)
    print("Model loaded successfully!")
except Exception as e:
    # Surface the failure in the Space logs, then re-raise so startup aborts.
    print(f"Error loading model: {e}")
    raise
# --- Function to answer questions ---
def answer_question(context, question):
    """Extract an answer to *question* from *context* with the QA model.

    Returns a ``(answer_text, confidence_text)`` tuple of strings for the
    two Gradio output textboxes. Blank inputs short-circuit with a prompt
    message and an ``"N/A"`` confidence.
    """
    # Guard clauses: both fields must contain non-whitespace text.
    if not context.strip():
        return "β Provide context.", "N/A"
    if not question.strip():
        return "β Provide question.", "N/A"

    # BERT QA expects (question, context); truncate only the context
    # ("only_second") so the question itself is never cut off.
    inputs = tokenizer(
        question,
        context,
        truncation="only_second",
        max_length=384,
        return_tensors="pt",
        padding=True,
    )

    with torch.no_grad():
        outputs = model(**inputs)

    # Most likely start/end token positions and their softmax probabilities.
    start_idx = torch.argmax(outputs.start_logits)
    end_idx = torch.argmax(outputs.end_logits)
    start_score = torch.softmax(outputs.start_logits, dim=1)[0][start_idx].item()
    end_score = torch.softmax(outputs.end_logits, dim=1)[0][end_idx].item()
    confidence = start_score * end_score

    # Index 0 is the [CLS] token — SQuAD-style models point there when the
    # answer is absent; an inverted span is likewise unanswerable.
    if start_idx > end_idx or start_idx == 0 or end_idx == 0:
        return "β Answer not found. Try rephrasing your question.", f"{confidence:.2%}"

    answer = tokenizer.decode(
        inputs.input_ids[0][start_idx:end_idx + 1], skip_special_tokens=True
    )
    emoji = "π’" if confidence > 0.8 else "π‘" if confidence > 0.5 else "π΄"
    # BUG FIX: the original f-string was broken across two physical lines,
    # which is a SyntaxError in Python; use an explicit \n escape instead.
    return f"β\n{answer}", f"{emoji} {confidence:.2%}"
# --- Example inputs ---
# Each entry is a [context, question] pair wired into gr.Examples below.
examples = [
    [
        "Paris is the capital of France.",
        "What is the capital of France?",
    ],
    [
        "Eiffel Tower built 1887-1889.",
        "When was the Eiffel Tower built?",
    ],
    [
        "Machine learning automates model building.",
        "What is machine learning?",
    ],
]
# --- Build Gradio interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# π€ BERT Question Answering")

    with gr.Row():
        # Left column: user-supplied context/question plus the submit button.
        with gr.Column(scale=2):
            context_input = gr.Textbox(
                label="π Context", lines=6, placeholder="Enter context here..."
            )
            question_input = gr.Textbox(
                label="β Question", lines=2, placeholder="Ask your question..."
            )
            submit_btn = gr.Button("π Get Answer")
        # Right column: model answer and its confidence score.
        with gr.Column(scale=1):
            answer_output = gr.Textbox(label="π‘ Answer", lines=2)
            confidence_output = gr.Textbox(label="π Confidence", lines=1)

    # Clickable example pairs that pre-fill the inputs and run the model.
    gr.Examples(
        examples,
        inputs=[context_input, question_input],
        outputs=[answer_output, confidence_output],
        fn=answer_question,
    )

    submit_btn.click(
        fn=answer_question,
        inputs=[context_input, question_input],
        outputs=[answer_output, confidence_output],
    )

if __name__ == "__main__":
    demo.launch(share=True)