| """ | |
| FILE 2: src/model.py β BERT QA Engine | |
| ======================================= | |
| IMPORTED BY: app.py (calls init_model at startup, predict_qa per request) | |
| IMPORTS: transformers (BertForQuestionAnswering, BertTokenizer) | |
| torch, sklearn (not used here but available) | |
| Functions: | |
| init_model() β loads BERT into memory (called once at startup) | |
| predict_qa(q, ctx) β runs extractive QA, returns answer dict | |
| """ | |
import logging
import time

import torch
from transformers import BertConfig, BertForQuestionAnswering, BertTokenizerFast

logger = logging.getLogger(__name__)
# ── Global state (loaded once, reused for every request) ──
model = None
tokenizer = None
MODEL_NAME = "deepset/bert-base-cased-squad2"
def init_model():
    """
    Load BERT QA model + tokenizer into memory.
    Called once by app.py at server startup.
    First run downloads ~440MB from HuggingFace (cached after).
    """
    global model, tokenizer
    start = time.time()
    logger.info(f"Loading model: {MODEL_NAME}")
    # Fast tokenizer required: predict_qa asks for return_offsets_mapping,
    # which the slow (pure-Python) BertTokenizer does not support.
    tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
    config = BertConfig.from_pretrained(MODEL_NAME, output_hidden_states=False)
    model = BertForQuestionAnswering.from_pretrained(MODEL_NAME, config=config)
    model.eval()  # Inference mode: disables dropout (gradients are turned off per call via torch.no_grad())
    logger.info(f"Model loaded in {time.time() - start:.1f}s")
def predict_qa(question: str, context: str) -> dict:
    """
    Run BERT extractive QA.
    Called by: app.py → api_predict route
    Input: question string + context string
    Returns: {
        "answer": "5000mAh",
        "confidence": 0.912,
        "confidence_pct": "91.2%",
        "confidence_level": "high",
        "answer_start_char": 562,
        "answer_end_char": 569,
        "context_used": "...",
        "tokens": [{"text": "what", "type": "question"}, ...],
        "num_tokens": 156,
        "inference_time_ms": 287
    }
    This dict is sent as JSON to main.js, which renders it in the UI.
    """
    # ── Truncate long context (BERT max = 512 tokens ≈ 2500 chars) ──
    max_chars = 2500
    ctx = context[:max_chars]
    if len(context) > max_chars:
        # Prefer to cut at a sentence boundary. rfind returns -1 when there is
        # no ".", which fails the 70% threshold and keeps the hard cut.
        last_dot = ctx.rfind(".")
        if last_dot > max_chars * 0.7:
            ctx = ctx[:last_dot + 1]
    # ── Tokenize: [CLS] question [SEP] context [SEP] ──
    inputs = tokenizer(
        question, ctx,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        return_offsets_mapping=True,  # per-token char offsets; fast-tokenizer only
    )
    # The model's forward() does not accept offset_mapping, so pop it first.
    offset_mapping = inputs.pop("offset_mapping")[0].tolist()
    input_ids = inputs["input_ids"]
    token_type_ids = inputs.get("token_type_ids")
    tokens_raw = tokenizer.convert_ids_to_tokens(input_ids[0])
    # ── Run BERT forward pass ──
    t0 = time.time()
    with torch.no_grad():
        outputs = model(**inputs)
    inference_ms = int((time.time() - t0) * 1000)
    logger.info(f"Inference: {inference_ms}ms")
    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]
    # ── Find best valid answer span ──
    # Score top-5 start × top-5 end combinations jointly: independent argmax
    # over each head can yield end < start, or a span inside the question.
    top_starts = torch.topk(start_logits, 5).indices.tolist()
    top_ends = torch.topk(end_logits, 5).indices.tolist()
    best_score, best_s, best_e = -float("inf"), 0, 0
    for s in top_starts:
        for e in top_ends:
            if e >= s and (e - s) < 50:
                # Must be in context segment (token_type_id == 1)
                if token_type_ids is not None and token_type_ids[0][s].item() == 1:
                    score = start_logits[s].item() + end_logits[e].item()
                    if score > best_score:
                        best_score, best_s, best_e = score, s, e
    # Fallback to raw argmax if no valid pair was found
    if best_score == -float("inf"):
        best_s = torch.argmax(start_logits).item()
        best_e = torch.argmax(end_logits).item()
        if best_e < best_s:
            best_e = best_s
    # ── Decode answer text ──
    answer_ids = input_ids[0][best_s:best_e + 1]
    answer = tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
    if not answer:
        answer = "(No answer found in the given context)"
    # ── Confidence score ──
    s_probs = torch.softmax(start_logits, dim=0)
    e_probs = torch.softmax(end_logits, dim=0)
    conf = (s_probs[best_s] * e_probs[best_e]).item()
    conf_level = "high" if conf > 0.6 else ("medium" if conf > 0.2 else "low")
    # ── Classify each token (question / context / answer / special) ──
    tokens = []
    for i, tok in enumerate(tokens_raw):
        if tok in ("[CLS]", "[SEP]", "[PAD]"):
            t = "special"
        elif token_type_ids is not None and token_type_ids[0][i].item() == 0:
            t = "question"
        else:
            t = "context"
        if best_s <= i <= best_e and t == "context":
            t = "answer"
        tokens.append({"text": tok.replace("##", ""), "type": t})
    # ── Character-level answer position (for highlighting in context) ──
    # Offsets are relative to ctx, which is returned below as "context_used".
    ans_start_char, ans_end_char = -1, -1
    if best_s < len(offset_mapping) and best_e < len(offset_mapping):
        so, eo = offset_mapping[best_s], offset_mapping[best_e]
        if so and eo:
            ans_start_char, ans_end_char = so[0], eo[1]
    logger.info(f"Answer: '{answer}' | Confidence: {conf:.3f} ({conf_level})")
    return {
        "answer": answer,
        "confidence": round(conf, 4),
        "confidence_pct": f"{conf * 100:.1f}%",
        "confidence_level": conf_level,
        "answer_start_char": ans_start_char,
        "answer_end_char": ans_end_char,
        "context_used": ctx,
        "tokens": tokens,
        "num_tokens": len(tokens_raw),
        "inference_time_ms": inference_ms,
    }
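
# Minimal smoke test sketching the documented call pattern. The question and
# context strings here are illustrative, not taken from app.py.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    init_model()
    demo_ctx = (
        "The phone weighs 195g and ships with a 5000mAh battery "
        "that supports 67W fast charging."
    )
    result = predict_qa("What is the battery capacity?", demo_ctx)
    print(result["answer"], result["confidence_pct"], result["confidence_level"])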