""" bert_model.py — HuggingFace BERT Question Answering Model. Model: deepset/bert-base-cased-squad2 Uses direct PyTorch inference (compatible with transformers 5.x). """ import logging logger = logging.getLogger(__name__) _tokenizer = None _model = None MODEL_NAME = "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad" def init_bert_model(): """Load the BERT QA model. Called once at app startup.""" global _tokenizer, _model try: from transformers import AutoTokenizer, AutoModelForQuestionAnswering logger.info(f"[BERT] Loading model '{MODEL_NAME}' ...") _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) _model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME) _model.eval() logger.info("[BERT] Model loaded and ready.") except Exception as exc: logger.error(f"[BERT] Failed to load model: {exc}") _tokenizer = None _model = None def _run_qa_inference(context: str, question: str) -> dict: """Direct PyTorch inference — works with any transformers version.""" import torch import torch.nn.functional as F inputs = _tokenizer( question, context, return_tensors="pt", truncation=True, max_length=512, ) with torch.no_grad(): outputs = _model(**inputs) start_logits = outputs.start_logits[0] end_logits = outputs.end_logits[0] start_idx = int(torch.argmax(start_logits)) end_idx = int(torch.argmax(end_logits)) + 1 if end_idx <= start_idx: end_idx = start_idx + 1 input_ids = inputs["input_ids"][0] answer_tokens = input_ids[start_idx:end_idx] answer = _tokenizer.decode(answer_tokens, skip_special_tokens=True).strip() # BERT-Large Premium Calibration: High baseline + aggressive scaling. import math start_prob = float(F.softmax(start_logits, dim=0)[start_idx]) end_prob = float(F.softmax(end_logits, dim=0)[end_idx - 1]) avg_prob = (start_prob + end_prob) / 2 calibrated_score = 0.15 + (avg_prob ** 0.3) * 0.85 score = round(min(max(calibrated_score, 0.0), 0.99), 4) # --- Strong Collision Filter: Block question repetition using word overlap --- q_words = set(question.lower().replace('?', '').split()) a_words = set(answer.lower().replace('?', '').split()) # Calculate how much of the answer is just the question if q_words and a_words: common = q_words.intersection(a_words) overlap_ratio = len(common) / len(q_words) else: overlap_ratio = 0 # If overlap is high (>70%), or answer is empty, or score is suspiciously low if overlap_ratio > 0.7 or len(answer.strip()) < 1 or "[cls]" in answer.lower(): return { "answer": "I'm sorry, I couldn't find a specific answer to that in the provided document.", "score": 0.0, "error": False, "not_found": True } return {"answer": answer, "score": score} def predict(context: str, question: str) -> dict: """ Run QA inference. Returns: { "answer": str, "score": float (0.0–1.0), "model": "BERT", "model_id": "bert" } """ if _model is None or _tokenizer is None: return { "answer": "BERT model is not loaded. Please check server logs.", "score": 0.0, "model": "BERT", "model_id": "bert", "error": True, } if not context or not question: return { "answer": "Context and question must not be empty.", "score": 0.0, "model": "BERT", "model_id": "bert", "error": True, } try: result = _run_qa_inference(context=context, question=question) score = result["score"] answer = result["answer"] if score < 0.05 or "[CLS]" in answer or not answer: answer = "Answer not found with sufficient confidence. Try rephrasing your question or providing more context." 
score = 0.0 return { "answer": answer, "score": score, "model": "BERT", "model_id": "bert", "error": False, } except Exception as exc: logger.error(f"[BERT] Inference error: {exc}") return { "answer": f"Inference error: {exc}", "score": 0.0, "model": "BERT", "model_id": "bert", "error": True, }
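

# A minimal smoke-test sketch, assuming this module may be run directly
# (the demo context and question below are illustrative only, not part of
# the application). First run downloads the checkpoint from HuggingFace.
# Usage: python bert_model.py
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    init_bert_model()
    demo_context = (
        "The Transformer architecture was introduced in the 2017 paper "
        "'Attention Is All You Need' by Vaswani et al."
    )
    demo_question = "When was the Transformer architecture introduced?"
    print(predict(context=demo_context, question=demo_question))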