File size: 5,973 Bytes
202ae51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""

FILE 2: src/model.py — BERT QA Engine

=======================================

IMPORTED BY: app.py (calls init_model at startup, predict_qa per request)

IMPORTS:     transformers (BertForQuestionAnswering, BertTokenizer)

             torch, sklearn (not used here but available)



Functions:

  init_model()          → loads BERT into memory (called once at startup)

  predict_qa(q, ctx)    → runs extractive QA, returns answer dict

"""

import os
import torch
from transformers import BertForQuestionAnswering, BertTokenizer, BertConfig
import logging
import time

logger = logging.getLogger(__name__)

# ── Global state (loaded once, reused for every request) ──
# Hugging Face Hub id of the checkpoint that init_model() loads.
MODEL_NAME = "deepset/bert-base-cased-squad2"

# Filled in by init_model(); both stay None until it has run.
model, tokenizer = None, None


def init_model():
    """
    Load the BERT QA model and its tokenizer into the module-level globals.

    Assigns the global ``tokenizer`` and ``model`` from the checkpoint named
    by ``MODEL_NAME`` and logs how long loading took. According to the module
    header this is called once by app.py at server startup; the first run
    downloads the checkpoint from the Hugging Face Hub (cached afterwards).

    Returns:
        None. The loaded objects are published through the module globals.
    """
    global model, tokenizer

    # The fast tokenizer is required: predict_qa() asks for
    # return_offsets_mapping=True, which the pure-Python BertTokenizer does
    # not implement (it raises NotImplementedError). Imported locally so this
    # function does not depend on a change to the file-level import block.
    from transformers import BertTokenizerFast

    start = time.time()
    logger.info(f"Loading model: {MODEL_NAME}")

    tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
    config = BertConfig.from_pretrained(MODEL_NAME, output_hidden_states=False)
    model = BertForQuestionAnswering.from_pretrained(MODEL_NAME, config=config)
    # eval() only switches layers such as dropout to inference behaviour.
    # Gradient tracking is disabled separately, via torch.no_grad() around
    # the forward pass in predict_qa().
    model.eval()

    logger.info(f"Model loaded in {time.time() - start:.1f}s")


# Character budget applied to the context before tokenizing (BERT's window is
# 512 tokens; 2500 characters is the file's rough equivalent).
_MAX_CONTEXT_CHARS = 2500
# Number of highest-scoring start / end positions combined during span search.
_TOP_K = 5
# Longest answer span, in tokens, accepted by the span search.
_MAX_ANSWER_TOKENS = 50
# Token strings reported with type "special" in the token list.
_SPECIAL_TOKENS = ("[CLS]", "[SEP]", "[PAD]")


def _truncate_context(context: str, max_chars: int = _MAX_CONTEXT_CHARS) -> str:
    """Cut *context* to *max_chars*, ending on a sentence boundary if one is near."""
    if len(context) <= max_chars:
        return context
    ctx = context[:max_chars]
    last_dot = ctx.rfind(".")
    # Only back up to the last full stop if that keeps at least 70% of the text.
    if last_dot > max_chars * 0.7:
        ctx = ctx[:last_dot + 1]
    return ctx


def _select_span(start_logits, end_logits, type_ids):
    """Return (start, end) token indices of the best-scoring answer span."""
    # topk raises if k exceeds the sequence length (very short inputs).
    k = min(_TOP_K, start_logits.shape[0])
    top_starts = torch.topk(start_logits, k).indices.tolist()
    top_ends = torch.topk(end_logits, k).indices.tolist()
    # Convert once instead of calling .item() inside the nested loop.
    start_scores = start_logits.tolist()
    end_scores = end_logits.tolist()

    best_score, best_s, best_e = -float("inf"), 0, 0
    if type_ids is not None:
        for s in top_starts:
            # The span must start in the context segment (token_type_id == 1).
            if type_ids[s] != 1:
                continue
            for e in top_ends:
                if s <= e < s + _MAX_ANSWER_TOKENS:
                    score = start_scores[s] + end_scores[e]
                    if score > best_score:
                        best_score, best_s, best_e = score, s, e

    # Fallback to raw argmax when no candidate pair was valid.
    if best_score == -float("inf"):
        best_s = int(torch.argmax(start_logits))
        best_e = max(int(torch.argmax(end_logits)), best_s)
    return best_s, best_e


def _classify_tokens(tokens_raw, type_ids, best_s, best_e):
    """Label every token and flag which ones are real context tokens.

    Returns a pair ``(tokens, in_context)``: the list of
    ``{"text", "type"}`` dicts for the response, and a parallel list of
    booleans that is True for non-special tokens outside the question segment.
    """
    tokens, in_context = [], []
    for i, tok in enumerate(tokens_raw):
        if tok in _SPECIAL_TOKENS:
            kind = "special"
        elif type_ids is not None and type_ids[i] == 0:
            kind = "question"
        else:
            kind = "context"
        in_context.append(kind == "context")
        if kind == "context" and best_s <= i <= best_e:
            kind = "answer"
        tokens.append({"text": tok.replace("##", ""), "type": kind})
    return tokens, in_context


def _char_span(offsets, in_context, first, last):
    """Map a token span to character offsets, using only context tokens.

    Special tokens carry a (0, 0) offset and question tokens are measured
    against the question string, so neither may contribute to a position that
    is meant to index the context. Returns (-1, -1) when the span contains no
    context token.
    """
    last = min(last, len(offsets) - 1)
    covered = [i for i in range(first, last + 1) if in_context[i]]
    if not covered:
        return -1, -1
    return offsets[covered[0]][0], offsets[covered[-1]][1]


def predict_qa(question: str, context: str) -> dict:
    """
    Run BERT extractive QA on a question/context pair.

    Called by: app.py → api_predict route (per the module header).

    Args:
        question: The question text.
        context: The passage to search. Only the first 2500 characters are
            used, trimmed back to the last full stop when one is close enough
            to the cut.

    Returns:
        A dict of the form::

            {
                "answer": "5000mAh",
                "confidence": 0.912,
                "confidence_pct": "91.2%",
                "confidence_level": "high",
                "answer_start_char": 562,
                "answer_end_char": 569,
                "context_used": "...",
                "tokens": [{"text": "what", "type": "question"}, ...],
                "num_tokens": 156,
                "inference_time_ms": 287
            }

        ``answer_start_char`` / ``answer_end_char`` are the tokenizer's
        character offsets into ``context_used``; both are -1 when the selected
        span contains no context token (for example when the model points at
        ``[CLS]`` or at the question). ``confidence`` is the product of the
        softmax probabilities of the chosen start and end positions.

    Raises:
        RuntimeError: If init_model() has not been called yet.
    """
    if model is None or tokenizer is None:
        raise RuntimeError("Model not loaded; call init_model() before predict_qa().")

    # Truncate long context (BERT max = 512 tokens, roughly 2500 chars).
    ctx = _truncate_context(context)

    # Tokenize: [CLS] question [SEP] context [SEP]
    inputs = tokenizer(
        question, ctx,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        return_offsets_mapping=True,
    )
    # The offsets are popped so they are not passed to the model's forward().
    offset_mapping = inputs.pop("offset_mapping")[0].tolist()
    input_ids = inputs["input_ids"]
    token_type_ids = inputs.get("token_type_ids")
    type_ids = token_type_ids[0].tolist() if token_type_ids is not None else None
    tokens_raw = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

    # Run BERT forward pass.
    t0 = time.time()
    with torch.no_grad():
        outputs = model(**inputs)
    inference_ms = int((time.time() - t0) * 1000)
    logger.info("Inference: %dms", inference_ms)

    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]

    # Find the best valid answer span among top-k start x top-k end pairs.
    best_s, best_e = _select_span(start_logits, end_logits, type_ids)

    # Decode answer text.
    answer_ids = input_ids[0][best_s:best_e + 1]
    answer = tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
    if not answer:
        answer = "(No answer found in the given context)"

    # Confidence score.
    s_probs = torch.softmax(start_logits, dim=0)
    e_probs = torch.softmax(end_logits, dim=0)
    conf = (s_probs[best_s] * e_probs[best_e]).item()
    conf_level = "high" if conf > 0.6 else ("medium" if conf > 0.2 else "low")

    # Classify each token (question / context / answer / special).
    tokens, in_context = _classify_tokens(tokens_raw, type_ids, best_s, best_e)

    # Character-level answer position (for highlighting in context).
    ans_start_char, ans_end_char = _char_span(
        offset_mapping, in_context, best_s, best_e
    )

    logger.info(
        "Answer: '%s' | Confidence: %.3f (%s)", answer, conf, conf_level
    )

    return {
        "answer": answer,
        "confidence": round(conf, 4),
        "confidence_pct": f"{conf * 100:.1f}%",
        "confidence_level": conf_level,
        "answer_start_char": ans_start_char,
        "answer_end_char": ans_end_char,
        "context_used": ctx,
        "tokens": tokens,
        "num_tokens": len(tokens_raw),
        "inference_time_ms": inference_ms,
    }