# squad-backend/models/distilbert_model.py
# tnp554's picture
# feat: deploy squad backend to hugging face spaces
# 0a5d897
"""
distilbert_model.py — HuggingFace DistilBERT Question Answering Model.
Model: distilbert-base-cased-distilled-squad
40% smaller and 60% faster than BERT with ~97% of its performance.
Uses direct PyTorch inference (compatible with transformers 5.x).
"""
import logging
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Tokenizer and model singletons: assigned by init_distilbert_model() and
# left as None until it runs, or reset to None when loading fails.
_tokenizer = None
_model = None
# Checkpoint name passed to from_pretrained() for both tokenizer and model.
MODEL_NAME = "distilbert-base-cased-distilled-squad"
def init_distilbert_model():
    """Load the DistilBERT QA tokenizer and model into the module globals.

    Called once at app startup. ``transformers`` is imported inside the
    function, so an import failure is handled the same way as a download or
    load failure.

    On success ``_tokenizer`` and ``_model`` hold the loaded objects and the
    model has been switched to eval mode. On any failure the error is logged
    together with its traceback and both globals are reset to ``None``; the
    function itself never raises.
    """
    global _tokenizer, _model
    try:
        from transformers import AutoModelForQuestionAnswering, AutoTokenizer

        logger.info("[DistilBERT] Loading model '%s' ...", MODEL_NAME)
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        _model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
        _model.eval()
        logger.info("[DistilBERT] Model loaded and ready.")
    except Exception as exc:
        # Best-effort startup: keep the app alive, but record the full
        # traceback so the failure can actually be diagnosed from the logs.
        logger.exception("[DistilBERT] Failed to load model: %s", exc)
        # Reset both so a half-loaded state (tokenizer without model) is
        # never observable.
        _tokenizer = None
        _model = None
def _run_qa_inference(context: str, question: str) -> dict:
    """Run extractive QA on one question/context pair via direct PyTorch inference.

    The pair is tokenized (truncated to 512 tokens), the start and end token
    positions are each chosen independently by argmax over the model's
    logits, and the tokens in between are decoded as the answer.

    Uses the module-level ``_tokenizer`` and ``_model`` without checking that
    they are loaded.

    Args:
        context: Passage to extract the answer from.
        question: Question to answer.

    Returns:
        ``{"answer": str, "score": float}`` for an accepted answer, with the
        score clamped to [0.0, 1.0] and rounded to 4 decimals. If the answer
        is empty, contains ``[cls]`` (case-insensitive), or shares more than
        70% of the question's distinct words, a dict with a fixed fallback
        message, ``"score": 0.0``, ``"error": False`` and
        ``"not_found": True`` is returned instead.
    """
    import torch
    import torch.nn.functional as F

    inputs = _tokenizer(
        question,
        context,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    with torch.no_grad():
        outputs = _model(**inputs)

    # Index 0: a single question/context pair was encoded.
    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]

    start_idx = int(torch.argmax(start_logits))
    end_idx = int(torch.argmax(end_logits)) + 1  # exclusive slice bound
    if end_idx <= start_idx:
        # The two argmaxes are independent, so the end can land before the
        # start; fall back to the single token at the start position.
        end_idx = start_idx + 1

    input_ids = inputs["input_ids"][0]
    answer_tokens = input_ids[start_idx:end_idx]
    answer = _tokenizer.decode(answer_tokens, skip_special_tokens=True).strip()

    # Confidence: mean of the start and end softmax probabilities, then a
    # 0.8 exponent (leaves 0 and 1 fixed, lifts values in between).
    start_prob = float(F.softmax(start_logits, dim=0)[start_idx])
    end_prob = float(F.softmax(end_logits, dim=0)[end_idx - 1])
    avg_prob = (start_prob + end_prob) / 2
    calibrated_score = avg_prob ** 0.8
    score = round(min(max(calibrated_score, 0.0), 1.0), 4)

    # Collision filter: reject answers that mostly repeat the question,
    # measured as the share of distinct question words found in the answer.
    q_words = set(question.lower().replace("?", "").split())
    a_words = set(answer.lower().replace("?", "").split())
    if q_words and a_words:
        overlap_ratio = len(q_words & a_words) / len(q_words)
    else:
        overlap_ratio = 0

    # `answer` is already stripped, so a plain truthiness test detects empty.
    if overlap_ratio > 0.7 or not answer or "[cls]" in answer.lower():
        return {
            "answer": "I'm sorry, I couldn't find a specific answer to that in the provided document.",
            "score": 0.0,
            "error": False,
            "not_found": True,
        }
    return {"answer": answer, "score": score}
def predict(context: str, question: str) -> dict:
    """
    Answer ``question`` from ``context`` with the DistilBERT QA model.

    Failures are reported through the returned dict rather than raised: an
    unloaded model, an empty context or question, and any exception during
    inference all yield ``"error": True`` with an explanatory ``"answer"``
    and a score of 0.0. Inference exceptions are also logged with their
    traceback.

    A result scoring below 0.05, containing ``[CLS]``, or with an empty
    answer is replaced by a "not found" message with score 0.0; this is not
    treated as an error.

    Returns:
        {
            "answer": str,
            "score": float (0.0–1.0),
            "model": "DistilBERT",
            "model_id": "distilbert",
            "error": bool
        }
    """
    if _model is None or _tokenizer is None:
        return _distilbert_response(
            "DistilBERT model is not loaded. Please check server logs.",
            0.0,
            error=True,
        )
    if not context or not question:
        return _distilbert_response(
            "Context and question must not be empty.",
            0.0,
            error=True,
        )
    try:
        result = _run_qa_inference(context=context, question=question)
        score = result["score"]
        answer = result["answer"]
        if score < 0.05 or "[CLS]" in answer or not answer:
            answer = "Answer not found with sufficient confidence. Try rephrasing your question or providing more context."
            score = 0.0
        return _distilbert_response(answer, score, error=False)
    except Exception as exc:
        # Top-level boundary: never propagate to the caller, but keep the
        # traceback in the logs so the failure can be diagnosed.
        logger.exception("[DistilBERT] Inference error: %s", exc)
        return _distilbert_response(f"Inference error: {exc}", 0.0, error=True)


def _distilbert_response(answer: str, score: float, *, error: bool) -> dict:
    """Build the response dict shared by every return path of ``predict``."""
    return {
        "answer": answer,
        "score": score,
        "model": "DistilBERT",
        "model_id": "distilbert",
        "error": error,
    }