| """ | |
| BERT Question Answering — singleton model loader + inference. | |
| Key improvements over the original: | |
| * Thread-safe lazy singleton (load once per worker, guarded by a lock) | |
| * Uses AutoTokenizer/AutoModel so the model is swappable via env var | |
| * Optional warmup pass to eliminate first-request latency | |
| * No silent re-loading; failures surface as real exceptions | |
| """ | |
import logging
import threading
import time
from typing import Optional

import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

from . import config
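# Note: this module expects the local `config` module to expose at least the
# attributes used below (exact defaults live in config itself):
#   MODEL_NAME         - Hugging Face model id (per the docstring, resolvable via env var)
#   WARMUP_ON_START    - bool, whether to run a warmup inference after loading
#   MAX_CONTEXT_CHARS  - character cap applied to the context before tokenization
#   MAX_SEQ_LENGTH     - tokenizer max_length for the question+context pair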
logger = logging.getLogger(__name__)

_model: Optional[AutoModelForQuestionAnswering] = None
_tokenizer: Optional[AutoTokenizer] = None
_load_lock = threading.Lock()
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def init_model(warmup: Optional[bool] = None) -> None:
    """Load the model + tokenizer exactly once. Safe to call repeatedly."""
    global _model, _tokenizer
    if _model is not None and _tokenizer is not None:
        return
    with _load_lock:
        if _model is not None and _tokenizer is not None:
            return  # another thread won the race
        start = time.time()
        logger.info(f"Loading model '{config.MODEL_NAME}' on {_device}…")
        _tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)
        _model = AutoModelForQuestionAnswering.from_pretrained(config.MODEL_NAME)
        _model.to(_device)
        _model.eval()
        logger.info(f"Model loaded in {time.time() - start:.1f}s")
    do_warmup = config.WARMUP_ON_START if warmup is None else warmup
    if do_warmup:
        try:
            _warmup()
        except Exception:
            logger.warning("Warmup inference failed", exc_info=True)


def _warmup() -> None:
    """Run a tiny inference so the first real request isn't slow."""
    t0 = time.time()
    predict_qa(
        question="What is this?",
        context="This is a warmup context used to prime the model cache.",
        _persist=False,
    )
    logger.info(f"Warmup inference completed in {(time.time() - t0) * 1000:.0f}ms")


def _require_model() -> None:
    """Lazily load the model if inference is hit before init_model() was called."""
    if _model is None or _tokenizer is None:
        init_model()


def _truncate_context(context: str, max_chars: int) -> str:
    """Cap the context at `max_chars`, preferring to cut at a sentence boundary."""
    if len(context) <= max_chars:
        return context
    truncated = context[:max_chars]
    # Back up to the last full stop, but only if that keeps at least 70% of the budget.
    last_dot = truncated.rfind(".")
    if last_dot > max_chars * 0.7:
        truncated = truncated[: last_dot + 1]
    return truncated


def predict_qa(question: str, context: str, _persist: bool = True) -> dict:
    """
    Extract an answer span for `question` from `context` using BERT.

    Returns a dict with the answer, confidence, token breakdown,
    answer character offsets, and inference time — a superset of
    what the frontend needs to render the full UI.
    """
    _require_model()

    question = (question or "").strip()
    context = (context or "").strip()
    if not question or not context:
        raise ValueError("Both question and context are required.")

    ctx = _truncate_context(context, config.MAX_CONTEXT_CHARS)
    inputs = _tokenizer(
        question,
        ctx,
        return_tensors="pt",
        max_length=config.MAX_SEQ_LENGTH,
        truncation="only_second",
        return_offsets_mapping=True,
        padding=False,
    )
    offset_mapping = inputs.pop("offset_mapping")[0].tolist()
    inputs_on_device = {k: v.to(_device) for k, v in inputs.items()}
    input_ids = inputs["input_ids"]
    token_type_ids = inputs.get("token_type_ids")
    tokens_raw = _tokenizer.convert_ids_to_tokens(input_ids[0])

    t0 = time.time()
    with torch.no_grad():
        outputs = _model(**inputs_on_device)
    inference_ms = int((time.time() - t0) * 1000)

    # Back to CPU for post-processing
    start_logits = outputs.start_logits[0].detach().cpu()
    end_logits = outputs.end_logits[0].detach().cpu()
    # Search the top-k start/end combinations for the best valid answer span.
    best_score, best_s, best_e = -float("inf"), 0, 0
    k = min(5, start_logits.size(0))
    top_starts = torch.topk(start_logits, k).indices.tolist()
    top_ends = torch.topk(end_logits, k).indices.tolist()
    for s in top_starts:
        for e in top_ends:
            if e < s or (e - s) >= 50:
                continue
            # Only accept spans that fall inside the context segment
            if token_type_ids is not None and token_type_ids[0][s].item() != 1:
                continue
            score = start_logits[s].item() + end_logits[e].item()
            if score > best_score:
                best_score, best_s, best_e = score, s, e

    # Fallback: no valid pair among the top-k candidates, take the raw argmaxes.
    if best_score == -float("inf"):
        best_s = int(torch.argmax(start_logits).item())
        best_e = int(torch.argmax(end_logits).item())
        if best_e < best_s:
            best_e = best_s

    answer_ids = input_ids[0][best_s : best_e + 1]
    answer = _tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
    if not answer:
        answer = "(No answer found in the given context)"

    s_probs = torch.softmax(start_logits, dim=0)
    e_probs = torch.softmax(end_logits, dim=0)
    conf = float(s_probs[best_s] * e_probs[best_e])
    conf_level = "high" if conf > 0.6 else ("medium" if conf > 0.2 else "low")

    tokens = []
    for i, tok in enumerate(tokens_raw):
        if tok in ("[CLS]", "[SEP]", "[PAD]"):
            t = "special"
        elif token_type_ids is not None and token_type_ids[0][i].item() == 0:
            t = "question"
        else:
            t = "context"
        if best_s <= i <= best_e and t == "context":
            t = "answer"
        tokens.append({"text": tok.replace("##", ""), "type": t})
    # Map the answer token span back to character offsets in the (truncated) context.
    ans_start_char, ans_end_char = -1, -1
    if best_s < len(offset_mapping) and best_e < len(offset_mapping):
        so, eo = offset_mapping[best_s], offset_mapping[best_e]
        if eo[1] > so[0]:  # zero-width offsets mean special tokens; keep the -1 sentinels
            ans_start_char, ans_end_char = int(so[0]), int(eo[1])

    logger.info(
        f"QA: q={question[:60]!r} → answer={answer[:60]!r} "
        f"conf={conf:.3f} ({conf_level}) in {inference_ms}ms"
    )
    return {
        "answer": answer,
        "confidence": round(conf, 4),
        "confidence_pct": f"{conf * 100:.1f}%",
        "confidence_level": conf_level,
        "answer_start_char": ans_start_char,
        "answer_end_char": ans_end_char,
        "context_used": ctx,
        "tokens": tokens,
        "num_tokens": len(tokens_raw),
        "inference_time_ms": inference_ms,
    }
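

# Minimal smoke test (a sketch, not part of the original module): running this
# module directly loads whatever checkpoint config.MODEL_NAME points at and
# answers one illustrative question. The sample strings below are placeholders;
# run it via `python -m <package>.<this_module>` so the relative `config`
# import resolves.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    init_model()
    result = predict_qa(
        question="Where is the Eiffel Tower located?",
        context="The Eiffel Tower is a wrought-iron lattice tower in Paris, France.",
    )
    print(f"{result['answer']} ({result['confidence_pct']}, {result['inference_time_ms']}ms)")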