"""
bert_model.py — HuggingFace BERT Question Answering Model.
Model: deepset/bert-base-cased-squad2
Uses direct PyTorch inference rather than the pipeline API, so it works across
transformers versions (including 5.x).
"""
import logging
logger = logging.getLogger(__name__)
_tokenizer = None
_model = None
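# BERT-base fine-tuned on SQuAD 2.0; SQuAD2 models can also signal "no answer"
# by pointing both span logits at the [CLS] token.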
MODEL_NAME = "deepset/bert-base-cased-squad2"
def init_bert_model():
"""Load the BERT QA model. Called once at app startup."""
global _tokenizer, _model
try:
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
logger.info(f"[BERT] Loading model '{MODEL_NAME}' ...")
_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
_model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
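        # eval() switches off dropout so inference is deterministic.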
_model.eval()
logger.info("[BERT] Model loaded and ready.")
except Exception as exc:
logger.error(f"[BERT] Failed to load model: {exc}")
_tokenizer = None
_model = None
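# Intended wiring (a sketch, assuming a FastAPI server and that this file
# lives at models/bert_model.py; neither is required by the module itself):
#
#     from contextlib import asynccontextmanager
#     from fastapi import FastAPI
#     from models.bert_model import init_bert_model, predict
#
#     @asynccontextmanager
#     async def lifespan(app: FastAPI):
#         init_bert_model()  # load once, before serving traffic
#         yield
#
#     app = FastAPI(lifespan=lifespan)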
def _run_qa_inference(context: str, question: str) -> dict:
"""Direct PyTorch inference — works with any transformers version."""
import torch
import torch.nn.functional as F
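    # The tokenizer encodes the pair as [CLS] question [SEP] context [SEP],
    # truncated to BERT's 512-token window (overlong contexts lose their tail).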
inputs = _tokenizer(
question, context,
return_tensors="pt",
truncation=True,
max_length=512,
)
with torch.no_grad():
outputs = _model(**inputs)
start_logits = outputs.start_logits[0]
end_logits = outputs.end_logits[0]
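    # Pick the best start and end positions independently; the +1 makes
    # end_idx exclusive for slicing. Note the argmax runs over the whole
    # sequence (question tokens included). A stricter variant would mask
    # non-context positions, but [CLS] must stay reachable because SQuAD2
    # models point at it to signal "no answer".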
start_idx = int(torch.argmax(start_logits))
end_idx = int(torch.argmax(end_logits)) + 1
if end_idx <= start_idx:
end_idx = start_idx + 1
input_ids = inputs["input_ids"][0]
answer_tokens = input_ids[start_idx:end_idx]
answer = _tokenizer.decode(answer_tokens, skip_special_tokens=True).strip()
    # Rough confidence: mean of the start/end softmax probabilities.
    # A heuristic for thresholding, not a calibrated probability.
start_prob = float(F.softmax(start_logits, dim=0)[start_idx])
end_prob = float(F.softmax(end_logits, dim=0)[end_idx - 1])
score = round((start_prob + end_prob) / 2, 4)
return {"answer": answer, "score": score}
def predict(context: str, question: str) -> dict:
"""
Run QA inference.
Returns:
{
"answer": str,
"score": float (0.0–1.0),
"model": "BERT",
"model_id": "bert"
}
"""
if _model is None or _tokenizer is None:
return {
"answer": "BERT model is not loaded. Please check server logs.",
"score": 0.0,
"model": "BERT",
"model_id": "bert",
"error": True,
}
if not context or not question:
return {
"answer": "Context and question must not be empty.",
"score": 0.0,
"model": "BERT",
"model_id": "bert",
"error": True,
}
try:
result = _run_qa_inference(context=context, question=question)
score = result["score"]
answer = result["answer"]
        # SQuAD2 signals "no answer" by pointing at [CLS]; decoding with
        # skip_special_tokens strips it, so that case surfaces here as an
        # empty answer string.
        if score < 0.05 or not answer:
            answer = "Answer not found with sufficient confidence. Try rephrasing your question or providing more context."
            score = 0.0
return {
"answer": answer,
"score": score,
"model": "BERT",
"model_id": "bert",
"error": False,
}
except Exception as exc:
logger.error(f"[BERT] Inference error: {exc}")
return {
"answer": f"Inference error: {exc}",
"score": 0.0,
"model": "BERT",
"model_id": "bert",
"error": True,
}
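
if __name__ == "__main__":
    # Minimal smoke test (a sketch: downloads the checkpoint on first run,
    # so it needs network access and disk space for the weights).
    logging.basicConfig(level=logging.INFO)
    init_bert_model()
    print(predict(
        context="The Apollo program was carried out by NASA between 1961 and 1972.",
        question="Who carried out the Apollo program?",
    ))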