Spaces:
Running
Running
File size: 6,416 Bytes
3338b6d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 | """
BERT Question Answering — singleton model loader + inference.
Key improvements over the original:
* Thread-safe lazy singleton (load once per worker, guarded by a lock)
* Uses AutoTokenizer/AutoModel so the model is swappable via env var
* Optional warmup pass to eliminate first-request latency
* No silent re-loading; failures surface as real exceptions
"""
import logging
import threading
import time
from typing import Optional
import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
from . import config
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Process-wide singletons; both stay None until init_model() assigns them.
_model: Optional[AutoModelForQuestionAnswering] = None
_tokenizer: Optional[AutoTokenizer] = None
# Held by init_model() while it loads, so concurrent callers load only once.
_load_lock = threading.Lock()
# Decided once at import time: CUDA if torch reports it available, else CPU.
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def init_model(warmup: Optional[bool] = None) -> None:
    """Load the model + tokenizer exactly once. Safe to call repeatedly.

    Args:
        warmup: Whether to run a warmup inference right after loading.
            ``None`` (the default) defers to ``config.WARMUP_ON_START``.

    Raises:
        Exception: Anything raised while loading or moving the model is
            propagated. Nothing is published in that case, so a later call
            retries the load from scratch.
    """
    global _model, _tokenizer
    if _model is not None and _tokenizer is not None:
        return
    with _load_lock:
        if _model is not None and _tokenizer is not None:
            return  # another thread won the race

        start = time.time()
        logger.info("Loading model '%s' on %s…", config.MODEL_NAME, _device)

        # Build into locals first: the fast-path checks above (and in
        # _require_model) run without the lock, so the globals must never
        # point at a model that has not yet been moved to the device and
        # switched to eval mode.
        tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)
        model = AutoModelForQuestionAnswering.from_pretrained(config.MODEL_NAME)
        model.to(_device)
        model.eval()

        # Publish only once everything is ready.
        _tokenizer = tokenizer
        _model = model
        logger.info("Model loaded in %.1fs", time.time() - start)

        do_warmup = config.WARMUP_ON_START if warmup is None else warmup
        if do_warmup:
            # Warmup goes through predict_qa(), which sees the singletons
            # already set and therefore never tries to re-acquire the lock.
            try:
                _warmup()
            except Exception:
                # Best effort: a failed warmup must not fail the load.
                logger.warning("Warmup inference failed", exc_info=True)
def _warmup() -> None:
    """Prime the model with one throwaway prediction and log its duration."""
    started = time.time()
    warmup_request = {
        "question": "What is this?",
        "context": "This is a warmup context used to prime the model cache.",
        "_persist": False,
    }
    predict_qa(**warmup_request)
    elapsed_ms = (time.time() - started) * 1000
    logger.info(f"Warmup inference completed in {elapsed_ms:.0f}ms")
def _require_model():
    """Ensure the model and tokenizer singletons exist, loading them if not."""
    already_loaded = _model is not None and _tokenizer is not None
    if already_loaded:
        return
    init_model()
def _truncate_context(context: str, max_chars: int) -> str:
if len(context) <= max_chars:
return context
truncated = context[:max_chars]
last_dot = truncated.rfind(".")
if last_dot > max_chars * 0.7:
truncated = truncated[: last_dot + 1]
return truncated
def _select_answer_span(start_logits, end_logits, sequence_ids,
                        top_k=5, max_span_tokens=50):
    """Pick the best ``(start, end)`` token span lying inside the context.

    Candidates are every pairing of the ``top_k`` start indices with the
    ``top_k`` end indices, scored by the sum of the two logits. A pair is
    rejected when the end precedes the start, when the span covers more
    than ``max_span_tokens`` tokens, or when either endpoint is not a
    context token (``sequence_ids[i] != 1``).

    If no pair qualifies, fall back to the independent argmax of each logit
    vector, clamping the end so it never precedes the start.
    """
    k = min(top_k, start_logits.size(0))
    top_starts = torch.topk(start_logits, k).indices.tolist()
    top_ends = torch.topk(end_logits, k).indices.tolist()

    best_score, best_s, best_e = -float("inf"), 0, 0
    for s in top_starts:
        if sequence_ids[s] != 1:
            continue
        for e in top_ends:
            if e < s or (e - s) >= max_span_tokens or sequence_ids[e] != 1:
                continue
            score = start_logits[s].item() + end_logits[e].item()
            if score > best_score:
                best_score, best_s, best_e = score, s, e

    if best_score == -float("inf"):
        best_s = int(torch.argmax(start_logits).item())
        best_e = max(best_s, int(torch.argmax(end_logits).item()))
    return best_s, best_e


def predict_qa(question: str, context: str, _persist: bool = True) -> dict:
    """
    Extract an answer span for `question` from `context` using BERT.

    Args:
        question: The question text; surrounding whitespace is stripped.
        context: The passage to search. It is stripped and then clipped to
            ``config.MAX_CONTEXT_CHARS`` characters before tokenization.
        _persist: Accepted for call-site compatibility; this function does
            not read it.

    Returns:
        A dict with the answer, confidence, token breakdown, answer
        character offsets, and inference time — a superset of what the
        frontend needs to render the full UI. ``answer_start_char`` and
        ``answer_end_char`` index into ``context_used`` and are ``-1`` when
        the selected span does not lie entirely inside the context.

    Raises:
        ValueError: If ``question`` or ``context`` is empty after stripping.
    """
    _require_model()
    question = (question or "").strip()
    context = (context or "").strip()
    if not question or not context:
        raise ValueError("Both question and context are required.")

    ctx = _truncate_context(context, config.MAX_CONTEXT_CHARS)
    inputs = _tokenizer(
        question,
        ctx,
        return_tensors="pt",
        max_length=config.MAX_SEQ_LENGTH,
        truncation="only_second",
        return_offsets_mapping=True,
        padding=False,
    )
    # offset_mapping is not a model input, so it is removed before the
    # forward pass.
    offset_mapping = inputs.pop("offset_mapping")[0].tolist()
    # Per-token segment: None = special token, 0 = question, 1 = context.
    # Unlike token_type_ids this is available for every fast tokenizer, so
    # the context filter keeps working when the model is swapped for one
    # whose tokenizer emits no token_type_ids.
    sequence_ids = inputs.sequence_ids(0)
    input_ids = inputs["input_ids"]
    tokens_raw = _tokenizer.convert_ids_to_tokens(input_ids[0])
    inputs_on_device = {k: v.to(_device) for k, v in inputs.items()}

    # perf_counter is monotonic, so the duration cannot be skewed by
    # wall-clock adjustments.
    t0 = time.perf_counter()
    with torch.no_grad():
        outputs = _model(**inputs_on_device)
    inference_ms = int((time.perf_counter() - t0) * 1000)

    # Back to CPU for post-processing
    start_logits = outputs.start_logits[0].detach().cpu()
    end_logits = outputs.end_logits[0].detach().cpu()

    best_s, best_e = _select_answer_span(start_logits, end_logits, sequence_ids)

    answer_ids = input_ids[0][best_s : best_e + 1]
    answer = _tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
    if not answer:
        answer = "(No answer found in the given context)"

    s_probs = torch.softmax(start_logits, dim=0)
    e_probs = torch.softmax(end_logits, dim=0)
    conf = float(s_probs[best_s] * e_probs[best_e])
    conf_level = "high" if conf > 0.6 else ("medium" if conf > 0.2 else "low")

    tokens = []
    for i, (tok, seq_id) in enumerate(zip(tokens_raw, sequence_ids)):
        if seq_id is None:
            kind = "special"
        elif seq_id == 0:
            kind = "question"
        elif best_s <= i <= best_e:
            kind = "answer"
        else:
            kind = "context"
        tokens.append({"text": tok.replace("##", ""), "type": kind})

    # Offsets of question/special tokens are not positions in the context,
    # so character offsets are only reported when both span endpoints are
    # context tokens (the fallback span may land elsewhere).
    ans_start_char, ans_end_char = -1, -1
    if sequence_ids[best_s] == 1 and sequence_ids[best_e] == 1:
        ans_start_char = int(offset_mapping[best_s][0])
        ans_end_char = int(offset_mapping[best_e][1])

    logger.info(
        "QA: q=%r → answer=%r conf=%.3f (%s) in %dms",
        question[:60],
        answer[:60],
        conf,
        conf_level,
        inference_ms,
    )
    return {
        "answer": answer,
        "confidence": round(conf, 4),
        "confidence_pct": f"{conf * 100:.1f}%",
        "confidence_level": conf_level,
        "answer_start_char": ans_start_char,
        "answer_end_char": ans_end_char,
        "context_used": ctx,
        "tokens": tokens,
        "num_tokens": len(tokens_raw),
        "inference_time_ms": inference_ms,
    }
|