""" FILE 2: src/model.py — BERT QA Engine ======================================= IMPORTED BY: app.py (calls init_model at startup, predict_qa per request) IMPORTS: transformers (BertForQuestionAnswering, BertTokenizer) torch, sklearn (not used here but available) Functions: init_model() → loads BERT into memory (called once at startup) predict_qa(q, ctx) → runs extractive QA, returns answer dict """ import os import torch from transformers import BertForQuestionAnswering, BertTokenizer, BertConfig import logging import time logger = logging.getLogger(__name__) # ── Global state (loaded once, reused for every request) ── model = None tokenizer = None MODEL_NAME = "deepset/bert-base-cased-squad2" def init_model(): """ Load BERT QA model + tokenizer into memory. Called once by app.py at server startup. First run downloads ~440MB from HuggingFace (cached after). """ global model, tokenizer start = time.time() logger.info(f"Loading model: {MODEL_NAME}") tokenizer = BertTokenizer.from_pretrained(MODEL_NAME) config = BertConfig.from_pretrained(MODEL_NAME, output_hidden_states=False) model = BertForQuestionAnswering.from_pretrained(MODEL_NAME, config=config) model.eval() # Switch to inference mode (no dropout, no gradient tracking) logger.info(f"Model loaded in {time.time() - start:.1f}s") def predict_qa(question: str, context: str) -> dict: """ Run BERT extractive QA. Called by: app.py → api_predict route Input: question string + context string Returns: { "answer": "5000mAh", "confidence": 0.912, "confidence_pct": "91.2%", "confidence_level": "high", "answer_start_char": 562, "answer_end_char": 569, "context_used": "...", "tokens": [{"text": "what", "type": "question"}, ...], "num_tokens": 156, "inference_time_ms": 287 } This dict is sent as JSON to main.js which renders it in the UI. """ # ── Truncate long context (BERT max = 512 tokens ≈ 2500 chars) ── max_chars = 2500 ctx = context[:max_chars] if len(context) > max_chars: last_dot = ctx.rfind(".") if last_dot > max_chars * 0.7: ctx = ctx[:last_dot + 1] # ── Tokenize: [CLS] question [SEP] context [SEP] ── inputs = tokenizer( question, ctx, return_tensors="pt", max_length=512, truncation=True, return_offsets_mapping=True, ) offset_mapping = inputs.pop("offset_mapping")[0].tolist() input_ids = inputs["input_ids"] token_type_ids = inputs.get("token_type_ids") tokens_raw = tokenizer.convert_ids_to_tokens(input_ids[0]) # ── Run BERT forward pass ── t0 = time.time() with torch.no_grad(): outputs = model(**inputs) inference_ms = int((time.time() - t0) * 1000) logger.info(f"Inference: {inference_ms}ms") start_logits = outputs.start_logits[0] end_logits = outputs.end_logits[0] # ── Find best valid answer span ── # Check top-5 start × top-5 end combinations top_starts = torch.topk(start_logits, 5).indices.tolist() top_ends = torch.topk(end_logits, 5).indices.tolist() best_score, best_s, best_e = -float("inf"), 0, 0 for s in top_starts: for e in top_ends: if e >= s and (e - s) < 50: # Must be in context segment (token_type_id == 1) if token_type_ids is not None and token_type_ids[0][s].item() == 1: score = start_logits[s].item() + end_logits[e].item() if score > best_score: best_score, best_s, best_e = score, s, e # Fallback to raw argmax if best_score == -float("inf"): best_s = torch.argmax(start_logits).item() best_e = torch.argmax(end_logits).item() if best_e < best_s: best_e = best_s # ── Decode answer text ── answer_ids = input_ids[0][best_s:best_e + 1] answer = tokenizer.decode(answer_ids, skip_special_tokens=True).strip() if not answer: answer = "(No answer found in the given context)" # ── Confidence score ── s_probs = torch.softmax(start_logits, dim=0) e_probs = torch.softmax(end_logits, dim=0) conf = (s_probs[best_s] * e_probs[best_e]).item() conf_level = "high" if conf > 0.6 else ("medium" if conf > 0.2 else "low") # ── Classify each token (question / context / answer / special) ── tokens = [] for i, tok in enumerate(tokens_raw): if tok in ("[CLS]", "[SEP]", "[PAD]"): t = "special" elif token_type_ids is not None and token_type_ids[0][i].item() == 0: t = "question" else: t = "context" if best_s <= i <= best_e and t == "context": t = "answer" tokens.append({"text": tok.replace("##", ""), "type": t}) # ── Character-level answer position (for highlighting in context) ── ans_start_char, ans_end_char = -1, -1 if best_s < len(offset_mapping) and best_e < len(offset_mapping): so, eo = offset_mapping[best_s], offset_mapping[best_e] if so and eo: ans_start_char, ans_end_char = so[0], eo[1] logger.info(f"Answer: '{answer}' | Confidence: {conf:.3f} ({conf_level})") return { "answer": answer, "confidence": round(conf, 4), "confidence_pct": f"{conf * 100:.1f}%", "confidence_level": conf_level, "answer_start_char": ans_start_char, "answer_end_char": ans_end_char, "context_used": ctx, "tokens": tokens, "num_tokens": len(tokens_raw), "inference_time_ms": inference_ms, }