"""
distilbert_model.py — HuggingFace DistilBERT Question Answering Model.

Model: distilbert-base-cased-distilled-squad
40% smaller and 60% faster than BERT with ~97% of its performance.
Uses direct PyTorch inference (compatible with transformers 5.x).

Note: question + context are truncated to 512 tokens before inference, so
text beyond that limit in a long context is never seen by the model.
"""

import logging

logger = logging.getLogger(__name__)

# Populated by init_distilbert_model(); both stay None if loading fails.
_tokenizer = None
_model = None

MODEL_NAME = "distilbert-base-cased-distilled-squad"

# Longest answer span (in tokens) considered when decoding start/end logits.
MAX_ANSWER_TOKENS = 30

# Answers containing more than this fraction of the question's distinct words
# are treated as the model echoing the question instead of answering it.
_MAX_QUESTION_OVERLAP = 0.7

# predict() reports "not found" for calibrated scores below this value.
_MIN_CONFIDENCE = 0.05

_NOT_FOUND_MESSAGE = (
    "I'm sorry, I couldn't find a specific answer to that in the provided document."
)
_LOW_CONFIDENCE_MESSAGE = (
    "Answer not found with sufficient confidence. "
    "Try rephrasing your question or providing more context."
)


def init_distilbert_model():
    """Load the DistilBERT QA model. Called once at app startup.

    On any failure the error is logged with its traceback and both module
    globals are reset to None, so predict() reports the model as not loaded
    instead of raising.
    """
    global _tokenizer, _model
    try:
        from transformers import AutoTokenizer, AutoModelForQuestionAnswering

        logger.info("[DistilBERT] Loading model '%s' ...", MODEL_NAME)
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        _model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
        _model.eval()  # disable dropout for deterministic inference
        logger.info("[DistilBERT] Model loaded and ready.")
    except Exception as exc:
        logger.exception("[DistilBERT] Failed to load model: %s", exc)
        _tokenizer = None
        _model = None


def _context_token_mask(encoding, seq_len: int):
    """Return a bool tensor of length ``seq_len`` marking context-token positions.

    Uses ``encoding.sequence_ids(0)`` (fast tokenizers), where context tokens
    carry sequence id 1 and question/special tokens carry 0/None. If that API
    is unavailable, falls back to allowing every position except index 0.
    """
    import torch

    try:
        sequence_ids = encoding.sequence_ids(0)
    except (AttributeError, ValueError):
        # Slow tokenizers raise ValueError; non-BatchEncoding objects lack it.
        sequence_ids = None

    if sequence_ids is None:
        mask = torch.ones(seq_len, dtype=torch.bool)
        if seq_len:
            mask[0] = False  # never start/end an answer on the leading special token
        return mask
    return torch.tensor([sid == 1 for sid in sequence_ids], dtype=torch.bool)


def _best_span(start_logits, end_logits, allowed, max_answer_len: int = MAX_ANSWER_TOKENS):
    """Pick the (start, end) token span maximising start_logit + end_logit.

    Only spans with ``start <= end``, length ``<= max_answer_len`` and both
    endpoints on ``allowed`` positions are considered. ``end`` is inclusive.

    Args:
        start_logits: 1-D tensor of per-token start scores.
        end_logits: 1-D tensor of per-token end scores (same length).
        allowed: 1-D bool tensor; True where a span endpoint may fall.
        max_answer_len: maximum span length in tokens.

    Returns:
        ``(start, end)`` as ints, or None when no valid span exists.
    """
    import torch

    seq_len = start_logits.shape[0]
    device = start_logits.device
    allowed = allowed.to(device=device, dtype=torch.bool)

    positions = torch.arange(seq_len, device=device)
    span_len = positions[None, :] - positions[:, None]  # [start, end] -> end - start
    valid = (span_len >= 0) & (span_len < max_answer_len)
    valid &= allowed[:, None] & allowed[None, :]
    if not bool(valid.any()):
        return None

    pair_scores = start_logits[:, None] + end_logits[None, :]
    pair_scores = pair_scores.masked_fill(~valid, float("-inf"))
    start_idx, end_idx = divmod(int(torch.argmax(pair_scores)), seq_len)
    return start_idx, end_idx


def _calibrated_score(start_logits, end_logits, start_idx: int, end_idx: int) -> float:
    """Return the calibrated confidence for a span, clamped to [0, 1], 4 decimals.

    The score is the mean of the start/end softmax probabilities raised to 0.8.
    An exponent below 1 lifts mid-range probabilities; the original note
    describes this as "softer scaling to respect BERT-Large's superiority"
    (presumably relative to a BERT-Large model elsewhere in the app).
    """
    import torch.nn.functional as F

    start_prob = float(F.softmax(start_logits, dim=0)[start_idx])
    end_prob = float(F.softmax(end_logits, dim=0)[end_idx])
    calibrated = ((start_prob + end_prob) / 2) ** 0.8
    return round(min(max(calibrated, 0.0), 1.0), 4)


def _question_overlap(question: str, answer: str) -> float:
    """Fraction of the question's distinct words (case-folded, '?' removed) found in the answer."""
    q_words = set(question.lower().replace("?", "").split())
    a_words = set(answer.lower().replace("?", "").split())
    if not q_words or not a_words:
        return 0.0
    return len(q_words & a_words) / len(q_words)


def _not_found_result() -> dict:
    """Fresh result dict signalling that no usable answer was extracted."""
    return {
        "answer": _NOT_FOUND_MESSAGE,
        "score": 0.0,
        "error": False,
        "not_found": True,
    }


def _run_qa_inference(context: str, question: str, max_answer_len: int = MAX_ANSWER_TOKENS) -> dict:
    """Direct PyTorch inference — works with any transformers version.

    Encodes (question, context), runs the model, and decodes the best span
    restricted to context tokens.

    Returns:
        ``{"answer": str, "score": float}`` on success, or the not-found dict
        (score 0.0, ``"not_found": True``) when no span is valid, the decoded
        answer is empty, or it mostly repeats the question.
    """
    import torch

    inputs = _tokenizer(
        question,
        context,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    with torch.no_grad():
        outputs = _model(**inputs)

    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]
    input_ids = inputs["input_ids"][0]

    # Only context tokens may start/end an answer; independent argmaxes could
    # otherwise land on the question or pick an end before the start.
    allowed = _context_token_mask(inputs, input_ids.shape[0])
    span = _best_span(start_logits, end_logits, allowed, max_answer_len)
    if span is None:
        return _not_found_result()

    start_idx, end_idx = span
    answer = _tokenizer.decode(
        input_ids[start_idx:end_idx + 1], skip_special_tokens=True
    ).strip()
    score = _calibrated_score(start_logits, end_logits, start_idx, end_idx)

    # Safety net: reject empty answers, leaked special tokens, or answers that
    # mostly repeat the question.
    if (
        not answer
        or "[cls]" in answer.lower()
        or _question_overlap(question, answer) > _MAX_QUESTION_OVERLAP
    ):
        return _not_found_result()

    return {"answer": answer, "score": score}


def _response(answer: str, score: float, error: bool) -> dict:
    """Build predict()'s result dict with the model identifiers filled in."""
    return {
        "answer": answer,
        "score": score,
        "model": "DistilBERT",
        "model_id": "distilbert",
        "error": error,
    }


def _is_blank(text) -> bool:
    """True for None/empty values and whitespace-only strings."""
    return not text or (isinstance(text, str) and not text.strip())


def predict(context: str, question: str) -> dict:
    """
    Run QA inference.

    Returns:
        {
            "answer": str,
            "score": float (0.0–1.0),
            "model": "DistilBERT",
            "model_id": "distilbert",
            "error": bool
        }

    Never raises: a missing model, blank input, or an inference exception is
    reported with ``"error": True`` and score 0.0.
    """
    if _model is None or _tokenizer is None:
        return _response("DistilBERT model is not loaded. Please check server logs.", 0.0, True)

    if _is_blank(context) or _is_blank(question):
        return _response("Context and question must not be empty.", 0.0, True)

    try:
        result = _run_qa_inference(context=context, question=question)
    except Exception as exc:
        logger.exception("[DistilBERT] Inference error: %s", exc)
        return _response(f"Inference error: {exc}", 0.0, True)

    answer = result["answer"]
    score = result["score"]
    # Not-found results from _run_qa_inference carry score 0.0, so they also
    # take this branch; callers therefore always see this single message.
    if score < _MIN_CONFIDENCE or "[CLS]" in answer or not answer:
        answer = _LOW_CONFIDENCE_MESSAGE
        score = 0.0

    return _response(answer, score, False)