Spaces:
Sleeping
Sleeping
| """ | |
| bilstm_model.py — Integration for BiLSTM Model. | |
| """ | |
| import logging | |
| import torch | |
| import os | |
| from models.qa_model import QAModel | |
| # Import vocab utilities and preprocess utilities | |
| from utils.preprocess import tokenize | |
| from utils.vocab import encode | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Both are filled in by init_bilstm_model(); predict() treats None as "not loaded".
model = vocab = None

# Run on CUDA when torch reports it as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def init_bilstm_model():
    """Load the BiLSTM checkpoint into the module-level ``model`` and ``vocab``.

    The checkpoint is expected at ``qa_model.pth`` in the parent of this
    file's directory and must contain the keys ``"vocab"`` and
    ``"model_state"``.

    The globals are only assigned after the whole load sequence (build,
    ``load_state_dict``, move to device, ``eval``) has succeeded. On a missing
    file or any load error they are left untouched, so ``predict`` never sees
    a half-initialised model with untrained weights.

    Returns:
        None. Failures are logged, not raised.
    """
    global model, vocab
    logger.info("[BiLSTM] Initialising BiLSTM from qa_model.pth...")

    # Assumes qa_model.pth is at the root of the backend directory
    model_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "qa_model.pth")
    if not os.path.exists(model_path):
        logger.warning(f"[BiLSTM] qa_model.pth not found at {model_path}! BiLSTM inference will fail.")
        return

    try:
        # NOTE(security): torch.load deserialises via pickle; only load
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(model_path, map_location=device)
        loaded_vocab = checkpoint["vocab"]
        loaded_model = QAModel(len(loaded_vocab))
        loaded_model.load_state_dict(checkpoint["model_state"])
        loaded_model.to(device)
        loaded_model.eval()
    except Exception as exc:
        # Best-effort initialisation: keep the service running, but record
        # the full traceback so the failure can be diagnosed.
        logger.exception("[BiLSTM] Failed to load BiLSTM model: %s", exc)
        return

    # Publish both objects together, and only once everything above worked.
    model, vocab = loaded_model, loaded_vocab
    logger.info("[BiLSTM] BiLSTM successfully loaded.")
def _bilstm_result(answer: str, score: float, *, error: bool) -> dict:
    """Build the response dict shared by every ``predict`` outcome."""
    return {
        "answer": answer,
        "score": score,
        "model": "BiLSTM",
        "model_id": "bilstm",
        "error": error,
        "stub": False,
    }


def predict(context: str, question: str) -> dict:
    """Predict an answer span for ``question`` within ``context`` using the loaded BiLSTM.

    The question tokens, a ``"[SEP]"`` marker and the context tokens are
    concatenated, encoded with the loaded vocabulary, padded or truncated to
    300 ids and fed to the model as a batch of one. The arg-max of the start
    and end logits selects the answer tokens from the combined token list.

    Args:
        context: Passage to search for the answer.
        question: Question to answer.

    Returns:
        A dict with the keys ``answer``, ``score``, ``model``, ``model_id``,
        ``error`` and ``stub``. ``error`` is True when the model is not loaded
        or inference raised; ``score`` is 0.0 in those cases and when no valid
        span is found. Exceptions are logged and never propagated.
    """
    import math  # function-scope: the file does not import math at top level

    if model is None or vocab is None:
        return _bilstm_result(
            "BiLSTM model weights (qa_model.pth) not found or failed to load. Please make sure the trained model is placed in the backend folder.",
            0.0,
            error=True,
        )

    try:
        q_tokens = tokenize(question)
        c_tokens = tokenize(context)
        tokens = q_tokens + ["[SEP]"] + c_tokens

        max_len = 300
        encoded = encode(tokens, vocab)
        # Pad with 0 (presumably the padding id - confirm against utils.vocab)
        # or truncate to max_len, without mutating the list encode() returned.
        encoded = (list(encoded) + [0] * (max_len - len(encoded)))[:max_len]

        x = torch.tensor(encoded).unsqueeze(0).to(device)
        with torch.no_grad():
            start_logits, end_logits = model(x)

        start = torch.argmax(start_logits, dim=1).item()
        end = torch.argmax(end_logits, dim=1).item()

        # Reject inverted spans and spans that begin in the padding region.
        # NOTE(review): a span starting inside the question or on "[SEP]" is
        # still returned as-is - confirm whether that is intended.
        if start > end or start >= len(tokens):
            return _bilstm_result("No answer found", 0.0, error=False)

        answer = " ".join(tokens[start:end + 1])

        # Confidence: mean of the peak start/end softmax probabilities,
        # square-rooted to give a more human-readable figure.
        s_prob = float(torch.softmax(start_logits, dim=1).max())
        e_prob = float(torch.softmax(end_logits, dim=1).max())
        avg_prob = (s_prob + e_prob) / 2
        score = round(math.sqrt(avg_prob), 4)

        return _bilstm_result(answer, score, error=False)
    except Exception as exc:
        # Boundary handler: callers always get a well-formed response, while
        # the traceback is kept in the log.
        logger.exception("[BiLSTM] Inference error: %s", exc)
        return _bilstm_result("Inference error occurred.", 0.0, error=True)