"""
FILE 2: src/model.py — BERT QA Engine
=======================================
IMPORTED BY: app.py (calls init_model at startup, predict_qa per request)
IMPORTS: transformers (BertForQuestionAnswering, BertTokenizer)
torch, sklearn (not used here but available)
Functions:
init_model() — loads BERT into memory (called once at startup)
predict_qa(q, ctx) — runs extractive QA, returns answer dict
"""
import os
import torch
from transformers import BertForQuestionAnswering, BertTokenizer, BertConfig
import logging
import time
# Checkpoint identifier passed to every from_pretrained() call in init_model().
MODEL_NAME = "deepset/bert-base-cased-squad2"

logger = logging.getLogger(__name__)

# -- Global state: filled in once by init_model(), reused for every request --
tokenizer = None
model = None
def init_model():
    """
    Load the BERT QA model and its tokenizer into the module-level globals.

    Called once by app.py at server startup. The first run downloads the
    checkpoint from HuggingFace (~440MB); later runs reuse the local cache.

    A *fast* tokenizer is loaded on purpose: predict_qa() tokenizes with
    ``return_offsets_mapping=True``, which the pure-Python ``BertTokenizer``
    does not implement (it raises ``NotImplementedError``).

    Side effects:
        Sets the globals ``model`` and ``tokenizer``. Both are assigned only
        after both loads succeed, so a failed load never leaves the module
        half-initialised.
    """
    global model, tokenizer
    # Function-scope import: AutoTokenizer is not among the file-level imports.
    from transformers import AutoTokenizer

    start = time.perf_counter()
    logger.info("Loading model: %s", MODEL_NAME)

    # AutoTokenizer returns the fast (offset-aware) tokenizer when one exists.
    loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    if not getattr(loaded_tokenizer, "is_fast", False):
        logger.warning(
            "Loaded a slow tokenizer for %s; offset mapping will not be available",
            MODEL_NAME,
        )

    config = BertConfig.from_pretrained(MODEL_NAME, output_hidden_states=False)
    loaded_model = BertForQuestionAnswering.from_pretrained(MODEL_NAME, config=config)
    # eval() only switches layers such as dropout to inference behaviour;
    # gradient tracking is turned off separately with torch.no_grad().
    loaded_model.eval()

    tokenizer, model = loaded_tokenizer, loaded_model
    logger.info("Model loaded in %.1fs", time.perf_counter() - start)
def _truncate_context(context: str, max_chars: int = 2500) -> str:
    """Clip *context* to at most *max_chars*, ending on a full sentence when cheap."""
    if len(context) <= max_chars:
        return context
    clipped = context[:max_chars]
    last_dot = clipped.rfind(".")
    # Back up to the last period only if that keeps more than 70% of the budget.
    if last_dot > max_chars * 0.7:
        clipped = clipped[:last_dot + 1]
    return clipped


def _select_span(start_logits, end_logits, answerable, top_k=5, max_span_tokens=50):
    """Pick the best (start, end) token span from the top-k start/end candidates.

    A candidate is valid when both ends are answerable tokens, end >= start and
    the span is shorter than *max_span_tokens*. If no candidate is valid, fall
    back to the raw argmax of each logit vector (end clamped to start).
    """
    # topk raises if k exceeds the sequence length (e.g. empty question + context).
    k = min(top_k, start_logits.numel())
    top_starts = torch.topk(start_logits, k).indices.tolist()
    top_ends = torch.topk(end_logits, k).indices.tolist()

    best_score, best_span = -float("inf"), None
    for s in top_starts:
        if not answerable[s]:
            continue
        for e in top_ends:
            if e < s or (e - s) >= max_span_tokens or not answerable[e]:
                continue
            score = start_logits[s].item() + end_logits[e].item()
            if score > best_score:
                best_score, best_span = score, (s, e)

    if best_span is not None:
        return best_span

    s = torch.argmax(start_logits).item()
    e = torch.argmax(end_logits).item()
    return s, max(s, e)


def predict_qa(question: str, context: str) -> dict:
    """
    Run BERT extractive QA.

    Called by: app.py -> api_predict route

    Args:
        question: The question text. ``None`` is treated as an empty string.
        context: The passage to search. ``None`` is treated as an empty
            string. Only the first 2500 characters are used (see
            ``context_used`` in the result).

    Returns:
        A dict that is sent as JSON to main.js, which renders it in the UI:
        {
            "answer": "5000mAh",
            "confidence": 0.912,
            "confidence_pct": "91.2%",
            "confidence_level": "high",
            "answer_start_char": 562,
            "answer_end_char": 569,
            "context_used": "...",
            "tokens": [{"text": "what", "type": "question"}, ...],
            "num_tokens": 156,
            "inference_time_ms": 287
        }
        ``answer_start_char`` / ``answer_end_char`` index into ``context_used``
        (end is exclusive). Both are -1 when the selected span does not lie
        inside the context, e.g. when no answer was found.

    Raises:
        RuntimeError: If init_model() has not been called yet.
    """
    if model is None or tokenizer is None:
        raise RuntimeError("Model not loaded: call init_model() before predict_qa()")

    question = question or ""
    context = context or ""

    # -- Truncate long context (BERT max = 512 tokens, roughly 2500 chars) --
    ctx = _truncate_context(context, max_chars=2500)

    # -- Tokenize: [CLS] question [SEP] context [SEP] --
    inputs = tokenizer(
        question, ctx,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        return_offsets_mapping=True,
    )
    # The model's forward() does not accept offset_mapping, so remove it first.
    offset_mapping = inputs.pop("offset_mapping")[0].tolist()
    input_ids = inputs["input_ids"]
    token_type_ids = inputs.get("token_type_ids")
    tokens_raw = tokenizer.convert_ids_to_tokens(input_ids[0])
    segment_ids = token_type_ids[0].tolist() if token_type_ids is not None else None

    # -- Classify each token once (special / question / context) --
    special_tokens = ("[CLS]", "[SEP]", "[PAD]")
    kinds = []
    for i, tok in enumerate(tokens_raw):
        if tok in special_tokens:
            kinds.append("special")
        elif segment_ids is not None and segment_ids[i] == 0:
            kinds.append("question")
        else:
            kinds.append("context")

    # A token may belong to an answer only if it is a real context token.
    # The closing [SEP] also has segment id 1, so special tokens are excluded
    # explicitly. Without segment ids the context cannot be identified.
    answerable = [segment_ids is not None and kind == "context" for kind in kinds]

    # -- Run BERT forward pass --
    t0 = time.perf_counter()
    with torch.no_grad():
        outputs = model(**inputs)
    inference_ms = int((time.perf_counter() - t0) * 1000)
    logger.info("Inference: %dms", inference_ms)

    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]

    # -- Find best valid answer span (top-5 start x top-5 end combinations) --
    best_s, best_e = _select_span(start_logits, end_logits, answerable)

    # -- Decode answer text --
    answer_ids = input_ids[0][best_s:best_e + 1]
    answer = tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
    if not answer:
        answer = "(No answer found in the given context)"

    # -- Confidence score: P(start) * P(end) over the whole sequence --
    s_probs = torch.softmax(start_logits, dim=0)
    e_probs = torch.softmax(end_logits, dim=0)
    conf = (s_probs[best_s] * e_probs[best_e]).item()
    conf_level = "high" if conf > 0.6 else ("medium" if conf > 0.2 else "low")

    # -- Build the per-token list for the UI, marking answer tokens --
    tokens = []
    for i, (tok, kind) in enumerate(zip(tokens_raw, kinds)):
        if kind == "context" and best_s <= i <= best_e:
            kind = "answer"
        # Strip only the word-piece continuation prefix, not every "##".
        text = tok[2:] if tok.startswith("##") else tok
        tokens.append({"text": text, "type": kind})

    # -- Character-level answer position (for highlighting in context) --
    # Offsets of question and special tokens do not index into ctx, so only
    # report a position when both ends of the span are context tokens.
    ans_start_char, ans_end_char = -1, -1
    if answerable[best_s] and answerable[best_e]:
        ans_start_char = offset_mapping[best_s][0]
        ans_end_char = offset_mapping[best_e][1]

    logger.info("Answer: '%s' | Confidence: %.3f (%s)", answer, conf, conf_level)

    return {
        "answer": answer,
        "confidence": round(conf, 4),
        "confidence_pct": f"{conf * 100:.1f}%",
        "confidence_level": conf_level,
        "answer_start_char": ans_start_char,
        "answer_end_char": ans_end_char,
        "context_used": ctx,
        "tokens": tokens,
        "num_tokens": len(tokens_raw),
        "inference_time_ms": inference_ms,
    }
|