Spaces:
Sleeping
Sleeping
File size: 4,573 Bytes
0a5d897 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 | """
bert_model.py — HuggingFace BERT Question Answering Model.
Model: google-bert/bert-large-uncased-whole-word-masking-finetuned-squad (see MODEL_NAME)
Uses direct PyTorch inference (compatible with transformers 5.x).
"""
import logging
logger = logging.getLogger(__name__)
# Populated by init_bert_model(); both stay None until loading succeeds and are
# reset to None if loading fails. predict() checks both before running inference.
_tokenizer = None
_model = None
# Hugging Face Hub checkpoint id passed to from_pretrained() in init_bert_model().
MODEL_NAME = "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad"
def init_bert_model():
    """Load the BERT QA tokenizer and model into the module globals.

    Intended to be called once at app startup. Loading is best-effort: any
    failure is logged together with its traceback and leaves ``_tokenizer``
    and ``_model`` as ``None``, so ``predict()`` replies with its
    "model is not loaded" error instead of raising.
    """
    global _tokenizer, _model
    try:
        # Imported here rather than at module level, so importing this module
        # does not itself require transformers.
        from transformers import AutoTokenizer, AutoModelForQuestionAnswering

        logger.info("[BERT] Loading model '%s' ...", MODEL_NAME)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
        model.eval()  # inference mode (disables dropout)
    except Exception as exc:
        # logger.exception keeps the traceback; a bare message makes download,
        # version and memory problems very hard to diagnose from server logs.
        logger.exception("[BERT] Failed to load model: %s", exc)
        _tokenizer = None
        _model = None
        return
    # Publish both objects together so the globals never hold a half-loaded pair.
    _tokenizer, _model = tokenizer, model
    logger.info("[BERT] Model loaded and ready.")
def _best_span(start_logits, end_logits, allowed, max_answer_len=30):
    """Return the highest-scoring answer span as ``(start, end)``, or ``None``.

    Args:
        start_logits: 1-D tensor of per-token start logits.
        end_logits: 1-D tensor of per-token end logits (same length).
        allowed: 1-D boolean tensor marking tokens that may begin or end an
            answer (the context tokens).
        max_answer_len: Longest span, in tokens, that is considered.

    Returns:
        ``(start, end)`` token indices (both inclusive) maximising
        ``start_logits[start] + end_logits[end]`` subject to
        ``start <= end < start + max_answer_len`` with both ends allowed,
        or ``None`` when no span satisfies those constraints.
    """
    import torch

    n = start_logits.shape[0]
    allowed = allowed.to(device=start_logits.device, dtype=torch.bool)
    ones = torch.ones(n, n, dtype=torch.bool, device=start_logits.device)
    # triu keeps end >= start; tril(diagonal=k) keeps end <= start + k.
    valid = torch.triu(ones) & torch.tril(ones, diagonal=max_answer_len - 1)
    valid &= allowed[:, None] & allowed[None, :]
    if not bool(valid.any()):
        return None
    pair_scores = start_logits[:, None] + end_logits[None, :]
    pair_scores = pair_scores.masked_fill(~valid, float("-inf"))
    start, end = divmod(int(torch.argmax(pair_scores)), n)
    return start, end


def _is_question_echo(question, answer, threshold=0.7):
    """Return True when more than ``threshold`` of the question's words reappear in the answer."""
    q_words = set(question.lower().replace("?", "").split())
    a_words = set(answer.lower().replace("?", "").split())
    if not q_words or not a_words:
        return False
    return len(q_words & a_words) / len(q_words) > threshold


def _run_qa_inference(
    context: str,
    question: str,
    max_answer_len: int = 30,
    stride: int = 128,
) -> dict:
    """Extract an answer span for ``question`` from ``context`` with direct PyTorch inference.

    Uses the module-level ``_tokenizer`` / ``_model`` (no pipeline). The
    context is split into 512-token windows that overlap by ``stride``
    tokens, and the question is never truncated. Every window is searched,
    so answers anywhere in a long document are reachable. Cost grows
    linearly with the number of windows. Only spans lying entirely inside
    the context, with ``start <= end`` and at most ``max_answer_len``
    tokens, are considered.

    Returns:
        ``{"answer": str, "score": float}`` on success. Otherwise it returns
        ``{"answer": <apology>, "score": 0.0, "error": False, "not_found": True}``.
        That happens when no span is found, the span is empty, or the span
        mostly repeats the question.

    Raises:
        Whatever the tokenizer/model raise (e.g. a question too long for the
        window); ``predict()`` catches these.
    """
    import torch
    import torch.nn.functional as F

    # Character offsets let the answer be sliced from the original context,
    # which keeps its casing and spacing. Only fast tokenizers provide them.
    use_offsets = bool(getattr(_tokenizer, "is_fast", False))
    encoding = _tokenizer(
        question,
        context,
        return_tensors="pt",
        truncation="only_second",  # window the context, never cut the question
        max_length=512,
        stride=stride,
        return_overflowing_tokens=True,
        return_special_tokens_mask=True,
        return_offsets_mapping=use_offsets,
        padding=True,
    )
    # Pass only real model inputs (the encoding also carries offsets, masks
    # and the overflow-to-sample mapping).
    model_keys = [
        key for key in ("input_ids", "attention_mask", "token_type_ids")
        if key in encoding
    ]
    neg_inf = float("-inf")

    best = None  # (avg_prob, window, start, end)
    for window in range(encoding["input_ids"].shape[0]):
        # One window per forward pass keeps peak memory bounded for long documents.
        model_inputs = {key: encoding[key][window:window + 1] for key in model_keys}
        with torch.no_grad():
            outputs = _model(**model_inputs)
        start_logits = outputs.start_logits[0]
        end_logits = outputs.end_logits[0]

        # Candidate answer tokens: second segment (the context), not special
        # tokens ([CLS]/[SEP]), not padding.
        allowed = encoding["special_tokens_mask"][window] == 0
        if "token_type_ids" in encoding:
            allowed &= encoding["token_type_ids"][window] == 1
        if "attention_mask" in encoding:
            padding = encoding["attention_mask"][window] == 0
            allowed &= ~padding
            # Keep padding (present when windows differ in length) out of the softmax.
            start_logits = start_logits.masked_fill(padding, neg_inf)
            end_logits = end_logits.masked_fill(padding, neg_inf)

        span = _best_span(start_logits, end_logits, allowed, max_answer_len)
        if span is None:
            continue
        start, end = span
        start_prob = float(F.softmax(start_logits, dim=0)[start])
        end_prob = float(F.softmax(end_logits, dim=0)[end])
        avg_prob = (start_prob + end_prob) / 2
        if best is None or avg_prob > best[0]:
            best = (avg_prob, window, start, end)

    not_found_result = {
        "answer": "I'm sorry, I couldn't find a specific answer to that in the provided document.",
        "score": 0.0,
        "error": False,
        "not_found": True,
    }
    if best is None:
        return not_found_result

    avg_prob, window, start, end = best
    if use_offsets:
        offsets = encoding["offset_mapping"][window]
        answer = context[int(offsets[start, 0]):int(offsets[end, 1])].strip()
    else:
        token_ids = encoding["input_ids"][window][start:end + 1]
        answer = _tokenizer.decode(token_ids, skip_special_tokens=True).strip()

    # Display score: 0.15 + 0.85 * avg_prob**0.3, clipped to [0, 0.99]. This is a
    # monotonic stretch of the mean start/end probability (floor 0.15), not a
    # calibrated confidence.
    calibrated_score = 0.15 + (avg_prob ** 0.3) * 0.85
    score = round(min(max(calibrated_score, 0.0), 0.99), 4)

    # Reject empty spans and spans that mostly parrot the question back.
    if not answer or "[cls]" in answer.lower() or _is_question_echo(question, answer):
        return not_found_result
    return {"answer": answer, "score": score}
def _bert_response(answer, score, **extra):
    """Build a predict() response dict tagged with this model's identifiers."""
    return {"answer": answer, "score": score, "model": "BERT", "model_id": "bert", **extra}


def predict(context: str, question: str) -> dict:
    """
    Run QA inference.

    Never raises: a missing model, blank input and inference failures are all
    reported through the ``error`` flag.

    Returns:
        {
            "answer": str,
            "score": float (0.0–1.0),
            "model": "BERT",
            "model_id": "bert",
            "error": bool,
            "not_found": bool   # present when error is False
        }
    """
    if _model is None or _tokenizer is None:
        return _bert_response(
            "BERT model is not loaded. Please check server logs.", 0.0, error=True
        )
    # Whitespace-only text has nothing to search or ask, so treat it as empty.
    if not context or not context.strip() or not question or not question.strip():
        return _bert_response(
            "Context and question must not be empty.", 0.0, error=True
        )

    try:
        result = _run_qa_inference(context=context, question=question)
    except Exception as exc:
        # Keep the traceback in the server log; the caller only gets the message.
        logger.exception("[BERT] Inference error: %s", exc)
        return _bert_response(f"Inference error: {exc}", 0.0, error=True)

    answer = result["answer"]
    score = result["score"]
    # Respect an explicit "not found" from the inference helper instead of
    # overwriting its answer with the generic low-confidence message.
    not_found = bool(result.get("not_found", False))
    if not not_found and (score < 0.05 or "[CLS]" in answer or not answer):
        answer = (
            "Answer not found with sufficient confidence. "
            "Try rephrasing your question or providing more context."
        )
        score = 0.0
        not_found = True
    return _bert_response(answer, score, error=False, not_found=not_found)
|