"""bert_model.py — HuggingFace BERT Question Answering Model.

Model: deepset/bert-base-cased-squad2
Uses direct PyTorch inference (compatible with transformers 5.x).
"""

import logging

logger = logging.getLogger(__name__)

# Populated by init_bert_model(); both stay None if loading fails so that
# predict() can degrade gracefully instead of crashing.
_tokenizer = None
_model = None

MODEL_NAME = "deepset/bert-base-cased-squad2"

# Longest answer span (in tokens) considered during decoding — mirrors the
# HuggingFace QA pipeline's default span-length cap.
MAX_ANSWER_TOKENS = 30


def init_bert_model():
    """Load the BERT QA model. Called once at app startup.

    On any failure (missing package, no network, bad weights) the error is
    logged and both globals are reset to None; predict() then reports the
    model as unavailable instead of raising.
    """
    global _tokenizer, _model
    try:
        from transformers import AutoTokenizer, AutoModelForQuestionAnswering

        logger.info("[BERT] Loading model '%s' ...", MODEL_NAME)
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        _model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
        _model.eval()
        logger.info("[BERT] Model loaded and ready.")
    except Exception as exc:
        logger.error("[BERT] Failed to load model: %s", exc)
        _tokenizer = None
        _model = None


def _run_qa_inference(context: str, question: str) -> dict:
    """Direct PyTorch inference — works with any transformers version.

    Returns:
        {"answer": str, "score": float}

    The answer span is chosen as the jointly best (start, end) pair with
    start <= end and a bounded length, rather than two independent argmaxes
    (which can select an invalid span that then needs an arbitrary clamp).
    """
    import torch
    import torch.nn.functional as F

    inputs = _tokenizer(
        question,
        context,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    with torch.no_grad():
        outputs = _model(**inputs)

    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]
    seq_len = start_logits.size(0)

    # Score every (start, end) pair, then mask out invalid spans:
    # end < start, or spans longer than MAX_ANSWER_TOKENS.
    pair_scores = start_logits.unsqueeze(1) + end_logits.unsqueeze(0)
    valid = torch.ones(seq_len, seq_len, dtype=torch.bool)
    valid = torch.triu(valid) & ~torch.triu(valid, diagonal=MAX_ANSWER_TOKENS)
    pair_scores = pair_scores.masked_fill(~valid, float("-inf"))

    best = int(torch.argmax(pair_scores))
    start_idx = best // seq_len
    end_idx = best % seq_len + 1  # exclusive end, for slicing

    input_ids = inputs["input_ids"][0]
    answer_tokens = input_ids[start_idx:end_idx]
    answer = _tokenizer.decode(answer_tokens, skip_special_tokens=True).strip()

    # Confidence approximation via softmax over each logit vector, read at
    # the indices the joint decode actually selected.
    start_prob = float(F.softmax(start_logits, dim=0)[start_idx])
    end_prob = float(F.softmax(end_logits, dim=0)[end_idx - 1])
    score = round((start_prob + end_prob) / 2, 4)

    return {"answer": answer, "score": score}


def predict(context: str, question: str) -> dict:
    """Run QA inference.

    Returns:
        {
          "answer": str,
          "score": float (0.0–1.0),
          "model": "BERT",
          "model_id": "bert",
          "error": bool,
        }
    """
    # Guard: model failed to load (or init_bert_model was never called).
    if _model is None or _tokenizer is None:
        return {
            "answer": "BERT model is not loaded. Please check server logs.",
            "score": 0.0,
            "model": "BERT",
            "model_id": "bert",
            "error": True,
        }
    # Guard: empty inputs could only produce a meaningless span.
    if not context or not question:
        return {
            "answer": "Context and question must not be empty.",
            "score": 0.0,
            "model": "BERT",
            "model_id": "bert",
            "error": True,
        }
    try:
        result = _run_qa_inference(context=context, question=question)
        score = result["score"]
        answer = result["answer"]
        # Low-confidence, empty, or CLS ("no answer" in SQuAD2) spans are
        # treated as a miss rather than surfaced to the user.
        if score < 0.05 or "[CLS]" in answer or not answer:
            answer = "Answer not found with sufficient confidence. Try rephrasing your question or providing more context."
            score = 0.0
        return {
            "answer": answer,
            "score": score,
            "model": "BERT",
            "model_id": "bert",
            "error": False,
        }
    except Exception as exc:
        logger.error("[BERT] Inference error: %s", exc)
        return {
            "answer": f"Inference error: {exc}",
            "score": 0.0,
            "model": "BERT",
            "model_id": "bert",
            "error": True,
        }