# File size: 3,568 Bytes
# 09daf0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
bert_model.py — HuggingFace BERT Question Answering Model.

Model: deepset/bert-base-cased-squad2
Uses direct PyTorch inference (compatible with transformers 5.x).
"""

import logging

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Populated by init_bert_model(); both remain None until loading succeeds and
# are reset to None when loading fails. predict() checks them before inference.
_tokenizer = None
_model = None
# Model identifier passed to from_pretrained() for both tokenizer and model.
MODEL_NAME = "deepset/bert-base-cased-squad2"


def init_bert_model():
    """Load the BERT QA tokenizer and model into the module-level globals.

    Intended to be called once at app startup. ``transformers`` is imported
    inside the function, so a missing installation is handled like any other
    load failure instead of breaking the import of this module.

    On success ``_tokenizer`` and ``_model`` hold the loaded objects and the
    model has been switched to evaluation mode. On any failure the exception
    is logged together with its traceback, both globals are reset to ``None``
    and nothing is re-raised.
    """
    global _tokenizer, _model
    try:
        from transformers import AutoModelForQuestionAnswering, AutoTokenizer

        logger.info("[BERT] Loading model '%s' ...", MODEL_NAME)
        # Load into locals first so the globals are only published once both
        # the tokenizer and the model are fully ready.
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
        model.eval()
    except Exception as exc:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("[BERT] Failed to load model: %s", exc)
        _tokenizer = None
        _model = None
    else:
        _tokenizer, _model = tokenizer, model
        logger.info("[BERT] Model loaded and ready.")


def _run_qa_inference(context: str, question: str) -> dict:
    """Run extractive QA through direct PyTorch inference.

    The question/context pair is tokenized (truncated to at most 512 tokens),
    passed through the model without gradient tracking, and the best-scoring
    answer span is decoded back to text.

    Requires the module-level ``_tokenizer`` and ``_model`` to be loaded;
    ``predict`` checks this before calling.

    Args:
        context: Text in which the answer is searched.
        question: Question to answer.

    Returns:
        ``{"answer": str, "score": float}``. ``answer`` is the decoded span
        with special tokens removed and surrounding whitespace stripped (it
        may be empty). ``score`` is the mean of the start and end softmax
        probabilities of the chosen span, rounded to 4 decimals.
    """
    import torch
    import torch.nn.functional as F

    inputs = _tokenizer(
        question, context,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )

    with torch.no_grad():
        outputs = _model(**inputs)

    # Only one sequence is fed in, so take the first item of the batch.
    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]

    # Pick start and end jointly. Taking the two argmaxes independently can
    # put the end before the start; scoring every (start, end) pair and
    # masking out the inverted ones always yields a valid span, and gives the
    # same result as the independent argmaxes whenever those are ordered.
    seq_len = start_logits.size(0)
    positions = torch.arange(seq_len, device=start_logits.device)
    span_scores = start_logits.unsqueeze(1) + end_logits.unsqueeze(0)
    inverted = positions.unsqueeze(1) > positions.unsqueeze(0)
    span_scores = span_scores.masked_fill(inverted, float("-inf"))
    start_idx, last_idx = divmod(int(torch.argmax(span_scores)), seq_len)
    end_idx = last_idx + 1  # exclusive bound for slicing

    input_ids = inputs["input_ids"][0]
    answer_tokens = input_ids[start_idx:end_idx]
    answer = _tokenizer.decode(answer_tokens, skip_special_tokens=True).strip()

    # Confidence approximation: average of the start and end probabilities.
    start_prob = float(F.softmax(start_logits, dim=0)[start_idx])
    end_prob = float(F.softmax(end_logits, dim=0)[last_idx])
    score = round((start_prob + end_prob) / 2, 4)

    return {"answer": answer, "score": score}


def _bert_response(answer: str, score: float, error: bool) -> dict:
    """Build the response dict shared by every ``predict`` outcome."""
    return {
        "answer": answer,
        "score": score,
        "model": "BERT",
        "model_id": "bert",
        "error": error,
    }


def _is_blank(text) -> bool:
    """Return True for empty values and for whitespace-only strings."""
    return not text or (isinstance(text, str) and not text.strip())


def predict(context: str, question: str) -> dict:
    """
    Run QA inference and return a response dict; never raises.

    Args:
        context: Text in which the answer is searched.
        question: Question to answer.

    Returns:
        {
            "answer": str,
            "score": float (0.0–1.0),
            "model": "BERT",
            "model_id": "bert",
            "error": bool
        }

        ``error`` is True when the model is not loaded, when either input is
        empty or whitespace-only, or when inference raised; ``answer`` then
        carries the explanation and ``score`` is 0.0. A low-confidence or
        empty prediction is not an error: it is replaced by a "not found"
        message with score 0.0.
    """
    if _model is None or _tokenizer is None:
        return _bert_response(
            "BERT model is not loaded. Please check server logs.", 0.0, True
        )

    # Whitespace-only input carries no content, so reject it up front instead
    # of sending it through the model.
    if _is_blank(context) or _is_blank(question):
        return _bert_response("Context and question must not be empty.", 0.0, True)

    try:
        result = _run_qa_inference(context=context, question=question)
        score = result["score"]
        answer = result["answer"]

        if score < 0.05 or "[CLS]" in answer or not answer:
            answer = "Answer not found with sufficient confidence. Try rephrasing your question or providing more context."
            score = 0.0

        return _bert_response(answer, score, False)
    except Exception as exc:
        # Boundary handler: keep the traceback in the logs and report the
        # failure to the caller instead of propagating it.
        logger.exception("[BERT] Inference error: %s", exc)
        return _bert_response(f"Inference error: {exc}", 0.0, True)