# ecom-qa-bert/src/model.py
"""
BERT Question Answering — singleton model loader + inference.
Key improvements over the original:
* Thread-safe lazy singleton (load once per worker, guarded by a lock)
* Uses AutoTokenizer/AutoModel so the model is swappable via env var
* Optional warmup pass to eliminate first-request latency
* No silent re-loading; failures surface as real exceptions
"""
import logging
import threading
import time
from typing import Optional

import torch
from transformers import (
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)

from . import config

logger = logging.getLogger(__name__)

# Module-level singletons, populated exactly once by init_model().
# Annotated with the base classes, since the Auto* factories return
# concrete subclasses rather than instances of themselves.
_model: Optional[PreTrainedModel] = None
_tokenizer: Optional[PreTrainedTokenizerBase] = None
_load_lock = threading.Lock()
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

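# The sibling `config` module is expected to provide the following attributes
# (names inferred from their uses below; the example values are illustrative,
# not the project's actual settings):
#   MODEL_NAME        - HF hub id of a QA checkpoint, e.g. a SQuAD-tuned BERT
#   WARMUP_ON_START   - bool, run a priming inference right after loading
#   MAX_CONTEXT_CHARS - int, character cap enforced by _truncate_context()
#   MAX_SEQ_LENGTH    - int, passed to the tokenizer as max_length
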
def init_model(warmup: Optional[bool] = None) -> None:
    """Load the model + tokenizer exactly once. Safe to call repeatedly."""
    global _model, _tokenizer
    if _model is not None and _tokenizer is not None:
        return
    with _load_lock:
        # Re-check under the lock: another thread may have won the race.
        if _model is not None and _tokenizer is not None:
            return
        start = time.time()
        logger.info(f"Loading model '{config.MODEL_NAME}' on {_device}…")
        _tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)
        _model = AutoModelForQuestionAnswering.from_pretrained(config.MODEL_NAME)
        _model.to(_device)
        _model.eval()
        logger.info(f"Model loaded in {time.time() - start:.1f}s")
    # Warm up outside the lock; predict_qa() re-enters init_model(), which
    # now returns immediately via the early check above.
    do_warmup = config.WARMUP_ON_START if warmup is None else warmup
    if do_warmup:
        try:
            _warmup()
        except Exception:
            logger.warning("Warmup inference failed", exc_info=True)

def _warmup() -> None:
    """Run a tiny inference so the first real request isn't slow."""
    t0 = time.time()
    predict_qa(
        question="What is this?",
        context="This is a warmup context used to prime the model cache.",
        _persist=False,
    )
    logger.info(f"Warmup inference completed in {(time.time() - t0) * 1000:.0f}ms")

def _require_model() -> None:
    """Ensure the singletons exist before running inference."""
    if _model is None or _tokenizer is None:
        init_model()

def _truncate_context(context: str, max_chars: int) -> str:
    """Hard-cap the context length, preferring to cut at a sentence boundary."""
    if len(context) <= max_chars:
        return context
    truncated = context[:max_chars]
    # Back up to the last period, but only if it falls past 70% of the cap,
    # so we don't throw away most of the character budget.
    last_dot = truncated.rfind(".")
    if last_dot > max_chars * 0.7:
        truncated = truncated[: last_dot + 1]
    return truncated

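# Illustrative example (hypothetical values, not project settings):
#   _truncate_context("One two. Three four. Five six seven.", 25)
# first hard-cuts to "One two. Three four. Five", then backs up to the
# last "." because it falls past 70% of the cap, returning
# "One two. Three four."
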
def predict_qa(question: str, context: str, _persist: bool = True) -> dict:
    """
    Extract an answer span for `question` from `context` using BERT.

    Returns a dict with the answer, confidence, token breakdown,
    answer character offsets, and inference time — a superset of
    what the frontend needs to render the full UI.

    `_persist` is accepted for caller compatibility (the warmup path
    passes False) but is currently unused here.
    """
    _require_model()
    question = (question or "").strip()
    context = (context or "").strip()
    if not question or not context:
        raise ValueError("Both question and context are required.")
    ctx = _truncate_context(context, config.MAX_CONTEXT_CHARS)
    inputs = _tokenizer(
        question,
        ctx,
        return_tensors="pt",
        max_length=config.MAX_SEQ_LENGTH,
        truncation="only_second",  # truncate only the context, never the question
        return_offsets_mapping=True,  # requires a fast tokenizer
        padding=False,
    )
    # Offsets are character positions within `ctx` (per-sequence for pairs).
    # Keep CPU copies of the ids for decoding before moving inputs to the device.
    offset_mapping = inputs.pop("offset_mapping")[0].tolist()
    inputs_on_device = {k: v.to(_device) for k, v in inputs.items()}
    input_ids = inputs["input_ids"]
    token_type_ids = inputs.get("token_type_ids")
    tokens_raw = _tokenizer.convert_ids_to_tokens(input_ids[0])
    t0 = time.time()
    with torch.no_grad():
        outputs = _model(**inputs_on_device)
    inference_ms = int((time.time() - t0) * 1000)

    # Back to CPU for post-processing
    start_logits = outputs.start_logits[0].detach().cpu()
    end_logits = outputs.end_logits[0].detach().cpu()
    # Search the top-k start/end candidates for the best-scoring valid span.
    best_score, best_s, best_e = -float("inf"), 0, 0
    k = min(5, start_logits.size(0))
    top_starts = torch.topk(start_logits, k).indices.tolist()
    top_ends = torch.topk(end_logits, k).indices.tolist()
    special_tokens = ("[CLS]", "[SEP]", "[PAD]")
    for s in top_starts:
        for e in top_ends:
            # Reject inverted spans and implausibly long answers (>= 50 tokens).
            if e < s or (e - s) >= 50:
                continue
            # Only accept spans that fall inside the context segment
            if token_type_ids is not None and token_type_ids[0][s].item() != 1:
                continue
            # Never anchor a span on a special token (e.g. the final [SEP]).
            if tokens_raw[s] in special_tokens or tokens_raw[e] in special_tokens:
                continue
            score = start_logits[s].item() + end_logits[e].item()
            if score > best_score:
                best_score, best_s, best_e = score, s, e
    if best_score == -float("inf"):
        # No valid span among the candidates: fall back to the raw argmaxes
        # and clamp so the span is at least one token long.
        best_s = int(torch.argmax(start_logits).item())
        best_e = int(torch.argmax(end_logits).item())
        if best_e < best_s:
            best_e = best_s
    answer_ids = input_ids[0][best_s : best_e + 1]
    answer = _tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
    if not answer:
        answer = "(No answer found in the given context)"

    # Confidence: joint probability of the chosen start and end positions.
    s_probs = torch.softmax(start_logits, dim=0)
    e_probs = torch.softmax(end_logits, dim=0)
    conf = float(s_probs[best_s] * e_probs[best_e])
    conf_level = "high" if conf > 0.6 else ("medium" if conf > 0.2 else "low")
    # Classify every token for the frontend's color-coded breakdown.
    tokens = []
    for i, tok in enumerate(tokens_raw):
        if tok in ("[CLS]", "[SEP]", "[PAD]"):
            t = "special"
        elif token_type_ids is not None and token_type_ids[0][i].item() == 0:
            t = "question"
        else:
            t = "context"
        if best_s <= i <= best_e and t == "context":
            t = "answer"
        tokens.append({"text": tok.replace("##", ""), "type": t})
    # Map the answer span back to character offsets within `ctx`. Special
    # tokens carry zero-width (0, 0) offsets, so require a non-empty span.
    ans_start_char, ans_end_char = -1, -1
    if best_s < len(offset_mapping) and best_e < len(offset_mapping):
        so, eo = offset_mapping[best_s], offset_mapping[best_e]
        if eo[1] > so[0]:
            ans_start_char, ans_end_char = int(so[0]), int(eo[1])

    logger.info(
        f"QA: q={question[:60]!r} → answer={answer[:60]!r} "
        f"conf={conf:.3f} ({conf_level}) in {inference_ms}ms"
    )
    return {
        "answer": answer,
        "confidence": round(conf, 4),
        "confidence_pct": f"{conf * 100:.1f}%",
        "confidence_level": conf_level,
        "answer_start_char": ans_start_char,
        "answer_end_char": ans_end_char,
        "context_used": ctx,
        "tokens": tokens,
        "num_tokens": len(tokens_raw),
        "inference_time_ms": inference_ms,
    }
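
# Minimal smoke-test sketch. Assumes the package is run as a module
# (`python -m src.model`) so the relative `config` import resolves; the
# question/context strings below are illustrative only.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    init_model(warmup=False)
    result = predict_qa(
        question="What color is the jacket?",
        context="The jacket comes in navy blue and ships within two business days.",
    )
    print(f'{result["answer"]} ({result["confidence_pct"]}, {result["inference_time_ms"]}ms)')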