Spaces:
Sleeping
Sleeping
| """ | |
| bert_model.py — HuggingFace BERT Question Answering Model. | |
| Model: google-bert/bert-large-uncased-whole-word-masking-finetuned-squad | |
| Uses direct PyTorch inference (compatible with transformers 5.x). | |
| """ | |
| import logging | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Populated by init_bert_model(); predict() treats None as "model not loaded".
_tokenizer = None
_model = None

# Model identifier handed to from_pretrained() for both tokenizer and model.
MODEL_NAME = "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad"
def init_bert_model():
    """Load the BERT QA tokenizer and model into the module globals.

    Intended to be called once at app startup. ``transformers`` is imported
    inside the function so that an import failure is logged like any other
    load failure instead of breaking the import of this module.

    The function never raises: on any failure the full traceback is logged
    and both ``_tokenizer`` and ``_model`` are set to ``None``.
    """
    global _tokenizer, _model
    try:
        from transformers import AutoModelForQuestionAnswering, AutoTokenizer

        logger.info("[BERT] Loading model '%s' ...", MODEL_NAME)
        # Load into locals first so the globals are only ever published as a
        # complete tokenizer/model pair.
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
        model.eval()
    except Exception as exc:
        # logger.exception records the traceback, which a plain error line loses.
        logger.exception("[BERT] Failed to load model: %s", exc)
        _tokenizer = None
        _model = None
        return

    _tokenizer, _model = tokenizer, model
    logger.info("[BERT] Model loaded and ready.")
def _select_answer_span(start_logits, end_logits, candidate_mask):
    """Pick the best-scoring span whose start does not come after its end.

    Args:
        start_logits: 1-D tensor of per-token start scores.
        end_logits: 1-D tensor of per-token end scores.
        candidate_mask: 1-D bool tensor, True where a token may start or end
            an answer.

    Returns:
        ``(start, end, score)`` with inclusive token indices and
        ``score = start_logits[start] + end_logits[end]``, or ``None`` when
        the mask admits no token at all.
    """
    import torch

    if not bool(candidate_mask.any()):
        return None

    seq_len = start_logits.size(0)
    positions = torch.arange(seq_len, device=start_logits.device)
    # allowed[i, j] is True when i <= j and both tokens are candidates.
    ordered = positions[:, None] <= positions[None, :]
    allowed = ordered & candidate_mask[:, None] & candidate_mask[None, :]

    pair_scores = start_logits[:, None] + end_logits[None, :]
    pair_scores = pair_scores.masked_fill(~allowed, float("-inf"))

    # argmax without `dim` indexes the flattened matrix; unravel it by hand.
    start, end = divmod(int(torch.argmax(pair_scores)), seq_len)
    return start, end, float(pair_scores[start, end])


def _run_qa_inference(context: str, question: str) -> dict:
    """Run extractive QA with direct PyTorch inference.

    The question/context pair is tokenized (truncated to 512 tokens), passed
    through ``_model``, and the answer is the best *valid* span: start and
    end are chosen jointly so the end never precedes the start, and both must
    lie on context tokens rather than on the question or on ``[CLS]``/``[SEP]``.
    If the ``[CLS]`` "null" span outscores every context span, the answer is
    reported as not found.

    Args:
        context: Passage to search for the answer.
        question: Question to answer.

    Returns:
        ``{"answer": str, "score": float}`` when an answer is found; the
        calibrated score lies in [0.15, 0.99]. Otherwise a not-found dict with
        ``score`` 0.0, ``error`` False and ``not_found`` True.
    """
    import torch
    import torch.nn.functional as F

    not_found = {
        "answer": "I'm sorry, I couldn't find a specific answer to that in the provided document.",
        "score": 0.0,
        "error": False,
        "not_found": True
    }

    inputs = _tokenizer(
        question, context,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )
    with torch.no_grad():
        outputs = _model(**inputs)

    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]
    input_ids = inputs["input_ids"][0]

    # Tokens that may belong to an answer: never the structural special
    # tokens, and, when the tokenizer provides segment ids, only the second
    # segment (the context).
    candidate_mask = torch.ones_like(input_ids, dtype=torch.bool)
    for special_id in (
        _tokenizer.cls_token_id,
        _tokenizer.sep_token_id,
        _tokenizer.pad_token_id,
    ):
        if special_id is not None:
            candidate_mask &= input_ids != special_id
    if "token_type_ids" in inputs:
        candidate_mask &= inputs["token_type_ids"][0] == 1

    span = _select_answer_span(start_logits, end_logits, candidate_mask)
    if span is None:
        return not_found
    start_idx, last_idx, span_score = span

    # A [CLS] span that beats every context span means "no answer".
    cls_id = _tokenizer.cls_token_id
    if cls_id is not None and int(input_ids[0]) == cls_id:
        null_score = float(start_logits[0] + end_logits[0])
        if null_score > span_score:
            return not_found

    answer_tokens = input_ids[start_idx:last_idx + 1]
    answer = _tokenizer.decode(answer_tokens, skip_special_tokens=True).strip()

    # Calibration: a 0.15 baseline plus a concave (power 0.3) rescaling of
    # the mean start/end probability, capped at 0.99.
    start_prob = float(F.softmax(start_logits, dim=0)[start_idx])
    end_prob = float(F.softmax(end_logits, dim=0)[last_idx])
    avg_prob = (start_prob + end_prob) / 2
    calibrated_score = 0.15 + (avg_prob ** 0.3) * 0.85
    score = round(min(max(calibrated_score, 0.0), 0.99), 4)

    # Collision filter: reject answers that merely repeat the question.
    # overlap_ratio is the fraction of the question's words that also appear
    # in the answer.
    q_words = set(question.lower().replace('?', '').split())
    a_words = set(answer.lower().replace('?', '').split())
    if q_words and a_words:
        overlap_ratio = len(q_words & a_words) / len(q_words)
    else:
        overlap_ratio = 0

    # Not found when the answer echoes more than 70% of the question, is
    # empty, or still contains a literal [CLS] marker.
    if overlap_ratio > 0.7 or not answer or "[cls]" in answer.lower():
        return not_found

    return {"answer": answer, "score": score}
def _bert_response(answer: str, score: float, error: bool) -> dict:
    """Build the response dict shape shared by every predict() outcome."""
    return {
        "answer": answer,
        "score": score,
        "model": "BERT",
        "model_id": "bert",
        "error": error,
    }


def predict(context: str, question: str) -> dict:
    """Answer ``question`` from ``context`` with the loaded BERT model.

    Never raises: every failure is reported through the returned dict.

    Args:
        context: Passage to search for the answer.
        question: Question to answer.

    Returns:
        {
            "answer": str,
            "score": float (0.0–1.0),
            "model": "BERT",
            "model_id": "bert",
            "error": bool
        }
        ``error`` is True when the model is not loaded, an input is empty or
        whitespace-only, or inference raised. A low-confidence or missing
        answer is not an error: it is returned with ``score`` 0.0 and a
        generic "not found" message.
    """
    if _model is None or _tokenizer is None:
        return _bert_response(
            "BERT model is not loaded. Please check server logs.", 0.0, True
        )

    # Whitespace-only text carries nothing to tokenize, so treat it as empty.
    has_blank_input = any(
        not text or (isinstance(text, str) and not text.strip())
        for text in (context, question)
    )
    if has_blank_input:
        return _bert_response("Context and question must not be empty.", 0.0, True)

    try:
        result = _run_qa_inference(context=context, question=question)
        score = result["score"]
        answer = result["answer"]
    except Exception as exc:
        # logger.exception keeps the traceback for the server logs.
        logger.exception("[BERT] Inference error: %s", exc)
        return _bert_response(f"Inference error: {exc}", 0.0, True)

    if score < 0.05 or "[CLS]" in answer or not answer:
        answer = "Answer not found with sufficient confidence. Try rephrasing your question or providing more context."
        score = 0.0
    return _bert_response(answer, score, False)