"""
model3.py — Integration for BiLSTM Model.

Loads the trained BiLSTM question-answering checkpoint (``qa_model.pth``)
through ``init_model3()`` and answers questions through ``predict()``.
"""

import logging
import os

import torch

from models.qa_model import QAModel
# Import vocab utilities and preprocess utilities
from utils.preprocess import tokenize
from utils.vocab import encode

logger = logging.getLogger(__name__)

# Both globals stay None until a checkpoint has been loaded completely;
# predict() checks them before attempting inference.
model = None
vocab = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fixed length of the id sequence fed to the network: shorter inputs are
# padded with id 0, longer ones are truncated.
_MAX_LEN = 300


def init_model3():
    """Load the BiLSTM checkpoint into the module-level ``model`` and ``vocab``.

    The checkpoint is expected at ``qa_model.pth`` in the parent of this
    file's directory and must contain the keys ``"vocab"`` and
    ``"model_state"``.

    Nothing is raised: a missing file is logged as a warning and a loading
    failure is logged with its traceback. In both cases the globals are left
    as they were, so ``predict()`` keeps reporting the model as unavailable
    instead of running a half-initialised network.
    """
    global model, vocab
    logger.info("[Model3] Initialising BiLSTM from qa_model.pth...")

    # Assumes qa_model.pth is at the root of the backend directory
    model_path = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), "qa_model.pth"
    )
    if not os.path.exists(model_path):
        logger.warning(
            "[Model3] qa_model.pth not found at %s! Model 3 inference will fail.",
            model_path,
        )
        return

    try:
        # NOTE(review): depending on the torch version / ``weights_only``
        # setting, torch.load may unpickle arbitrary objects — only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=device)
        loaded_vocab = checkpoint["vocab"]
        loaded_model = QAModel(len(loaded_vocab))
        loaded_model.load_state_dict(checkpoint["model_state"])
        loaded_model.to(device)
        loaded_model.eval()
    except Exception as e:
        logger.exception("[Model3] Failed to load BiLSTM model: %s", e)
        return

    # Publish only after every step succeeded, so a failure in
    # load_state_dict() can never leave an untrained model visible to
    # predict().
    model, vocab = loaded_model, loaded_vocab
    logger.info("[Model3] BiLSTM successfully loaded.")


def _error_response(message):
    """Build the error payload returned by predict() with *message* as answer."""
    return {
        "answer": message,
        "score": 0.0,
        "model": "BiLSTM",
        "model_id": "model3",
        "error": True,
        "stub": False,
    }


def predict(context: str, question: str) -> dict:
    """Predict using the loaded BiLSTM.

    Args:
        context: Passage in which the answer is searched.
        question: Question to answer.

    Returns:
        A dict with the keys ``answer``, ``score``, ``model``, ``model_id``,
        ``error`` and ``stub``. ``error`` is True when the model is not
        loaded or inference raised; the function itself never raises.
        ``score`` is the product of the softmax probabilities of the chosen
        start and end positions, or 0.0 when no answer span was found.
    """
    if model is None or vocab is None:
        return _error_response(
            "BiLSTM model weights (qa_model.pth) not found or failed to load. \n"
            "Please make sure the trained model is placed in the backend folder."
        )

    try:
        q_tokens = tokenize(question)
        c_tokens = tokenize(context)
        tokens = q_tokens + ["[SEP]"] + c_tokens

        # Truncate, then pad with id 0, so the input is exactly _MAX_LEN long.
        encoded = encode(tokens, vocab)[:_MAX_LEN]
        encoded += [0] * (_MAX_LEN - len(encoded))

        x = torch.tensor(encoded).unsqueeze(0).to(device)
        with torch.no_grad():
            start_logits, end_logits = model(x)

        start = torch.argmax(start_logits, dim=1).item()
        end = torch.argmax(end_logits, dim=1).item()

        # NOTE(review): positions index into question + [SEP] + context, so a
        # predicted span may include question tokens — confirm against how
        # the training labels were built before restricting it to the context.
        if start > end or start >= len(tokens):
            # Inverted span, or a start that landed in the padding.
            answer = "No answer found"
            score = 0.0
        else:
            answer = " ".join(tokens[start:end + 1])
            # Confidence of the chosen span: with a single sequence in the
            # batch, the largest softmax probability is the probability of
            # the argmax position.
            start_prob = torch.softmax(start_logits, dim=1).max().item()
            end_prob = torch.softmax(end_logits, dim=1).max().item()
            score = start_prob * end_prob

        return {
            "answer": answer,
            "score": score,
            "model": "BiLSTM",
            "model_id": "model3",
            "error": False,
            # Same key set as the error payloads.
            "stub": False,
        }
    except Exception as e:
        logger.exception("[Model3] Inference error: %s", e)
        return _error_response("Inference error occurred.")