# qa-bert / app.py
# Author: leenthaher
# Commit bcb30a6 (verified): "Create app.py"
import gradio as gr
from transformers import BertTokenizerFast, BertForQuestionAnswering
import torch
import torch.nn.functional as F
# Hugging Face Hub ID of the BERT checkpoint fine-tuned on SQuAD 2.0
# (despite the name, a hub ID rather than a local directory).
MODEL_DIR = "deepset/bert-base-uncased-squad2"
# Token budget for the tokenized (question, context) pair; the tokenizer call
# truncates only the context to fit.
MAX_LENGTH = 384
# Use a GPU when one is present, otherwise the CPU. Inference code moves the
# logits back with .cpu(), so either device works.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the tokenizer and the QA model once, at import time; answer_question
# reuses these module-level objects for every request.
print("[INFO] Loading model...")
tokenizer = BertTokenizerFast.from_pretrained(MODEL_DIR)
# nn.Module.to() and nn.Module.eval() both return the module itself, so the
# device move and the switch to inference mode can be chained.
model = BertForQuestionAnswering.from_pretrained(MODEL_DIR).to(DEVICE).eval()
print("[INFO] Model ready!")
def _best_span(start_logits, end_logits, seq_ids, max_answer_len):
    """Return the most likely answer span restricted to context tokens.

    Spans are ranked by their joint probability P(start) * P(end). Among the
    unmasked tokens this is the same ranking as start_logit + end_logit, the
    usual BERT SQuAD decoding rule.

    Args:
        start_logits: 1-D tensor of per-token start logits.
        end_logits: 1-D tensor of per-token end logits (same length).
        seq_ids: per-token segment ids from ``BatchEncoding.sequence_ids``;
            only tokens whose id is 1 (the context) may belong to the answer.
        max_answer_len: maximum span length in tokens, counting both ends.

    Returns:
        ``(start, end, score)``: token indices with ``start <= end``, and
        ``score = P(start) + P(end)`` (range 0-2), the value the confidence
        thresholds in ``_format_result`` compare against.
    """
    not_context = torch.tensor([sid != 1 for sid in seq_ids], dtype=torch.bool)
    # A large finite negative (not -inf) keeps softmax free of NaNs.
    start_probs = F.softmax(start_logits.masked_fill(not_context, -1e9), dim=-1)
    end_probs = F.softmax(end_logits.masked_fill(not_context, -1e9), dim=-1)

    # joint[s, e] = P(start=s) * P(end=e); keep only s <= e < s + max_answer_len.
    joint = torch.outer(start_probs, end_probs)
    joint = torch.tril(torch.triu(joint), diagonal=max_answer_len - 1)

    best_start, best_end = divmod(int(torch.argmax(joint)), joint.size(1))
    score = float(start_probs[best_start] + end_probs[best_end])
    return best_start, best_end, score


def _format_result(answer, score, truncated=False):
    """Render the answer, its score and a confidence label as Markdown.

    ``score`` is P(start) + P(end) of the chosen span (range 0-2). The
    thresholds below split that range into high / medium / low confidence.
    """
    if score > 1.5:
        confidence = "🟢 High confidence"
    elif score > 1.0:
        confidence = "🟡 Medium confidence"
    else:
        confidence = "🔴 Low confidence"
    result = f"**Answer:** {answer}\n\n**Score:** {score:.4f} — {confidence}"
    if truncated:
        result += (
            "\n\n⚠️ The context was too long for the model and was truncated; "
            "text after that point was not searched for an answer."
        )
    return result


def answer_question(context: str, question: str, *, max_answer_len: int = 50) -> str:
    """Extract the answer to *question* from *context*, formatted as Markdown.

    Args:
        context: paragraph to search for the answer.
        question: question about the paragraph.
        max_answer_len: longest answer considered, in tokens (default 50).

    Returns:
        A Markdown string with the answer, its score and a confidence label,
        or a warning message when an input is blank or the context has no
        tokenizable text.

    Raises:
        ValueError: if ``max_answer_len`` is less than 1.
    """
    if max_answer_len < 1:
        raise ValueError(f"max_answer_len must be >= 1, got {max_answer_len}")
    if not context.strip():
        return "⚠️ Please enter a context paragraph."
    if not question.strip():
        return "⚠️ Please enter a question."

    inputs = tokenizer(
        question, context,
        max_length=MAX_LENGTH,
        truncation="only_second",  # never truncate the question, only the context
        return_offsets_mapping=True,
        return_tensors="pt",
        padding="max_length",
    )
    # Character offsets are not a model input; keep them to map tokens back to text.
    offset_mapping = inputs.pop("offset_mapping")[0].tolist()
    seq_ids = inputs.sequence_ids(0)

    context_offsets = [off for off, sid in zip(offset_mapping, seq_ids) if sid == 1]
    if not context_offsets:
        return "⚠️ The context contains no text the model can read."
    # Non-whitespace text after the last kept context token means the
    # tokenizer dropped the tail of the context to fit MAX_LENGTH.
    truncated = bool(context[context_offsets[-1][1]:].strip())

    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)

    best_start, best_end, score = _best_span(
        outputs.start_logits[0].cpu(),
        outputs.end_logits[0].cpu(),
        seq_ids,
        max_answer_len,
    )
    answer = context[offset_mapping[best_start][0]:offset_mapping[best_end][1]]
    return _format_result(answer, score, truncated)
# Gradio UI: the two textboxes are passed to answer_question as
# (context, question), and its Markdown string is rendered in the output panel.
demo = gr.Interface(
    fn=answer_question,
    inputs=[
        gr.Textbox(label="📖 Context", placeholder="Paste your paragraph here...", lines=6),
        gr.Textbox(label="❓ Question", placeholder="Ask a question about the context...", lines=2),
    ],
    outputs=gr.Markdown(label="💡 Answer"),
    title="🤖 QA with BERT — SQuAD",
    description="Question Answering using BERT fine-tuned on SQuAD v2. Enter a context and ask a question!",
    examples=[
        ["Alan Turing proposed the Turing test in 1950. Deep Blue, developed by IBM, defeated chess world champion Garry Kasparov in 1997.", "Who developed Deep Blue?"],
        ["BERT was created and published in 2018 by Jacob Devlin and his colleagues from Google.", "Who created BERT?"],
    ],
)

# Start the web server only when this file is run as a script, so importing
# the module (e.g. from tests or another app) does not launch it.
if __name__ == "__main__":
    demo.launch()