"""
FILE 2: src/model.py - BERT QA Engine
======================================
IMPORTED BY: app.py (calls init_model at startup, predict_qa per request)
IMPORTS: transformers (BertForQuestionAnswering, BertTokenizerFast, BertConfig),
         torch; stdlib logging and time
Functions:
init_model() → loads BERT into memory (called once at startup)
predict_qa(q, ctx) → runs extractive QA, returns answer dict
"""
import logging
import time

import torch
from transformers import BertConfig, BertForQuestionAnswering, BertTokenizerFast
logger = logging.getLogger(__name__)
# ── Global state (loaded once, reused for every request) ──
model = None
tokenizer = None
MODEL_NAME = "deepset/bert-base-cased-squad2"
def init_model():
"""
Load BERT QA model + tokenizer into memory.
Called once by app.py at server startup.
First run downloads ~440MB from HuggingFace (cached after).
"""
global model, tokenizer
start = time.time()
logger.info(f"Loading model: {MODEL_NAME}")
    # A fast (Rust) tokenizer is required: predict_qa requests return_offsets_mapping,
    # which the slow Python-based BertTokenizer does not support.
    tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
config = BertConfig.from_pretrained(MODEL_NAME, output_hidden_states=False)
model = BertForQuestionAnswering.from_pretrained(MODEL_NAME, config=config)
    model.eval()  # Inference mode: disables dropout (gradients are disabled per-call via torch.no_grad)
logger.info(f"Model loaded in {time.time() - start:.1f}s")
def predict_qa(question: str, context: str) -> dict:
"""
Run BERT extractive QA.
    Called by: app.py → api_predict route
Input: question string + context string
Returns: {
"answer": "5000mAh",
"confidence": 0.912,
"confidence_pct": "91.2%",
"confidence_level": "high",
"answer_start_char": 562,
"answer_end_char": 569,
"context_used": "...",
"tokens": [{"text": "what", "type": "question"}, ...],
"num_tokens": 156,
"inference_time_ms": 287
}
This dict is sent as JSON to main.js which renders it in the UI.
"""
    # ── Truncate long context (BERT max = 512 tokens; ~2500 chars at roughly 5 chars/token) ──
max_chars = 2500
ctx = context[:max_chars]
    if len(context) > max_chars:
        # Snap the cut to the last sentence boundary, if one falls in the final 30%.
        last_dot = ctx.rfind(".")
        if last_dot > max_chars * 0.7:
            ctx = ctx[:last_dot + 1]
# ── Tokenize: [CLS] question [SEP] context [SEP] ──
inputs = tokenizer(
question, ctx,
return_tensors="pt",
max_length=512,
truncation=True,
return_offsets_mapping=True,
)
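    # offset_mapping gives each token's (char_start, char_end) span in its source
    # string; used later to map the predicted token span back to character positions
    # for highlighting.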
offset_mapping = inputs.pop("offset_mapping")[0].tolist()
input_ids = inputs["input_ids"]
token_type_ids = inputs.get("token_type_ids")
tokens_raw = tokenizer.convert_ids_to_tokens(input_ids[0])
# ── Run BERT forward pass ──
t0 = time.time()
with torch.no_grad():
outputs = model(**inputs)
inference_ms = int((time.time() - t0) * 1000)
logger.info(f"Inference: {inference_ms}ms")
start_logits = outputs.start_logits[0]
end_logits = outputs.end_logits[0]
# ── Find best valid answer span ──
    # Check top-5 start × top-5 end combinations
top_starts = torch.topk(start_logits, 5).indices.tolist()
top_ends = torch.topk(end_logits, 5).indices.tolist()
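    # Score each candidate span as start_logit + end_logit (standard SQuAD-style
    # decoding), keeping only plausible spans: end at or after start, under 50 tokens,
    # and inside the context segment. Restricting to the context excludes the [CLS]
    # "no answer" span that SQuAD2 models use, so a span is always produced.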
best_score, best_s, best_e = -float("inf"), 0, 0
for s in top_starts:
for e in top_ends:
if e >= s and (e - s) < 50:
# Must be in context segment (token_type_id == 1)
if token_type_ids is not None and token_type_ids[0][s].item() == 1:
score = start_logits[s].item() + end_logits[e].item()
if score > best_score:
best_score, best_s, best_e = score, s, e
# Fallback to raw argmax
if best_score == -float("inf"):
best_s = torch.argmax(start_logits).item()
best_e = torch.argmax(end_logits).item()
if best_e < best_s:
best_e = best_s
# ── Decode answer text ──
answer_ids = input_ids[0][best_s:best_e + 1]
answer = tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
if not answer:
answer = "(No answer found in the given context)"
# ── Confidence score ──
s_probs = torch.softmax(start_logits, dim=0)
e_probs = torch.softmax(end_logits, dim=0)
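    # Joint confidence = P(start) * P(end) under an independence assumption between
    # the start/end distributions; the high/medium/low cutoffs below are heuristics.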
conf = (s_probs[best_s] * e_probs[best_e]).item()
conf_level = "high" if conf > 0.6 else ("medium" if conf > 0.2 else "low")
# ── Classify each token (question / context / answer / special) ──
tokens = []
for i, tok in enumerate(tokens_raw):
if tok in ("[CLS]", "[SEP]", "[PAD]"):
t = "special"
elif token_type_ids is not None and token_type_ids[0][i].item() == 0:
t = "question"
else:
t = "context"
if best_s <= i <= best_e and t == "context":
t = "answer"
        tokens.append({"text": tok.replace("##", ""), "type": t})  # strip WordPiece "##" continuation prefix
# ── Character-level answer position (for highlighting in context) ──
ans_start_char, ans_end_char = -1, -1
if best_s < len(offset_mapping) and best_e < len(offset_mapping):
so, eo = offset_mapping[best_s], offset_mapping[best_e]
        if so != [0, 0] and eo != [0, 0]:  # (0, 0) marks special tokens with no source span
ans_start_char, ans_end_char = so[0], eo[1]
logger.info(f"Answer: '{answer}' | Confidence: {conf:.3f} ({conf_level})")
return {
"answer": answer,
"confidence": round(conf, 4),
"confidence_pct": f"{conf * 100:.1f}%",
"confidence_level": conf_level,
"answer_start_char": ans_start_char,
"answer_end_char": ans_end_char,
"context_used": ctx,
"tokens": tokens,
"num_tokens": len(tokens_raw),
"inference_time_ms": inference_ms,
}
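

# Minimal smoke test, a sketch: assumes network access for the first-run model
# download and that this file is run directly. The product text below is
# hypothetical demo data, not part of the app.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    init_model()
    demo_ctx = (
        "The X200 smartphone ships with a 5000mAh battery, a 6.5-inch OLED "
        "display, and 128GB of storage. It supports 65W fast charging."
    )
    result = predict_qa("What is the battery capacity?", demo_ctx)
    print(result["answer"], result["confidence_pct"], f"{result['inference_time_ms']}ms")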