| """ |
| bert_model.py — HuggingFace BERT Question Answering Model. |
| |
| Model: deepset/bert-base-cased-squad2 |
| Uses direct PyTorch inference (compatible with transformers 5.x). |
| """ |
|
|
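
# Typical usage, as a sketch (assumes this file is importable as bert_model;
# a runnable smoke test also lives under __main__ at the bottom of the file):
#
#     import bert_model
#     bert_model.init_bert_model()  # once, at app startup
#     result = bert_model.predict(context="...", question="...")
#     # result: {"answer": str, "score": float, "model": "BERT", ...}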

import logging

logger = logging.getLogger(__name__)

# Module-level handles, populated once by init_bert_model().
_tokenizer = None
_model = None
MODEL_NAME = "deepset/bert-base-cased-squad2"


def init_bert_model():
    """Load the BERT QA model and tokenizer. Called once at app startup."""
    global _tokenizer, _model
    try:
        from transformers import AutoTokenizer, AutoModelForQuestionAnswering

        logger.info(f"[BERT] Loading model '{MODEL_NAME}' ...")
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        _model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
        _model.eval()  # inference mode: disables dropout
        logger.info("[BERT] Model loaded and ready.")
    except Exception as exc:
        logger.error(f"[BERT] Failed to load model: {exc}")
        _tokenizer = None
        _model = None


def _run_qa_inference(context: str, question: str) -> dict:
    """Direct PyTorch inference; works with any transformers version."""
    import torch
    import torch.nn.functional as F

    # Encode as a single "[CLS] question [SEP] context [SEP]" sequence.
    inputs = _tokenizer(
        question,
        context,
        return_tensors="pt",
        truncation=True,  # context past 512 tokens is dropped (see the strided sketch below)
        max_length=512,
    )

    with torch.no_grad():
        outputs = _model(**inputs)

    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]

    # Greedy span selection: most likely start and end positions.
    start_idx = int(torch.argmax(start_logits))
    end_idx = int(torch.argmax(end_logits)) + 1  # slice end is exclusive

    # Guard against a degenerate span (end before start).
    if end_idx <= start_idx:
        end_idx = start_idx + 1

    input_ids = inputs["input_ids"][0]
    answer_tokens = input_ids[start_idx:end_idx]
    answer = _tokenizer.decode(answer_tokens, skip_special_tokens=True).strip()

    # Confidence: mean of the softmax probabilities at the chosen start/end.
    start_prob = float(F.softmax(start_logits, dim=0)[start_idx])
    end_prob = float(F.softmax(end_logits, dim=0)[end_idx - 1])
    score = round((start_prob + end_prob) / 2, 4)

    return {"answer": answer, "score": score}
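

# `truncation=True` above silently drops any context past 512 tokens, so an
# answer that lives beyond that point can never be found. Below is a minimal
# sketch of one way to lift that limit using the tokenizer's overlapping-window
# support. This helper is hypothetical (not part of the original module); it
# assumes the same module-level _tokenizer/_model and keeps the best-scoring
# window's span.
def _run_qa_inference_strided(context: str, question: str, stride: int = 128) -> dict:
    """Sketch: score every overlapping 512-token window and keep the best span."""
    import torch
    import torch.nn.functional as F

    inputs = _tokenizer(
        question,
        context,
        return_tensors="pt",
        truncation="only_second",        # never truncate the question
        max_length=512,
        stride=stride,                   # overlap between consecutive windows
        return_overflowing_tokens=True,  # one batch row per window
        padding=True,                    # rows must share a length for "pt"
    )
    # Bookkeeping field from the tokenizer that the model does not accept.
    inputs.pop("overflow_to_sample_mapping", None)

    with torch.no_grad():
        outputs = _model(**inputs)

    best = {"answer": "", "score": 0.0}
    for i in range(inputs["input_ids"].shape[0]):
        start_logits = outputs.start_logits[i]
        end_logits = outputs.end_logits[i]
        start_idx = int(torch.argmax(start_logits))
        end_idx = max(int(torch.argmax(end_logits)) + 1, start_idx + 1)
        answer = _tokenizer.decode(
            inputs["input_ids"][i][start_idx:end_idx], skip_special_tokens=True
        ).strip()
        start_prob = float(F.softmax(start_logits, dim=0)[start_idx])
        end_prob = float(F.softmax(end_logits, dim=0)[end_idx - 1])
        score = round((start_prob + end_prob) / 2, 4)
        if answer and score > best["score"]:
            best = {"answer": answer, "score": score}
    return best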


def predict(context: str, question: str) -> dict:
    """
    Run QA inference against the loaded model.

    Returns:
        {
            "answer": str,
            "score": float (0.0-1.0),
            "model": "BERT",
            "model_id": "bert",
            "error": bool
        }
    """
    if _model is None or _tokenizer is None:
        return {
            "answer": "BERT model is not loaded. Please check server logs.",
            "score": 0.0,
            "model": "BERT",
            "model_id": "bert",
            "error": True,
        }

    if not context or not question:
        return {
            "answer": "Context and question must not be empty.",
            "score": 0.0,
            "model": "BERT",
            "model_id": "bert",
            "error": True,
        }

    try:
        result = _run_qa_inference(context=context, question=question)
        score = result["score"]
        answer = result["answer"]

        # An unanswerable (SQuAD2) prediction points at [CLS]; since decode()
        # strips special tokens, it surfaces here as an empty answer string.
        if score < 0.05 or not answer:
            answer = (
                "Answer not found with sufficient confidence. Try rephrasing "
                "your question or providing more context."
            )
            score = 0.0

        return {
            "answer": answer,
            "score": score,
            "model": "BERT",
            "model_id": "bert",
            "error": False,
        }
    except Exception as exc:
        logger.error(f"[BERT] Inference error: {exc}")
        return {
            "answer": f"Inference error: {exc}",
            "score": 0.0,
            "model": "BERT",
            "model_id": "bert",
            "error": True,
        }
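

if __name__ == "__main__":
    # Minimal smoke test, as a sketch: downloads the model on first run, then
    # answers one question. The sample context and question are illustrative.
    logging.basicConfig(level=logging.INFO)
    init_bert_model()
    demo_context = (
        "The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars "
        "in Paris, France. It was completed in 1889."
    )
    print(predict(context=demo_context, question="When was the Eiffel Tower completed?"))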