| """ |
| model3.py — Integration for BiLSTM Model. |
| """ |
|
|
| import logging |
| import torch |
| import os |
| from models.qa_model import QAModel |
|
|
| |
| from utils.preprocess import tokenize |
| from utils.vocab import encode |
|
|
# Module-wide logger used by both the loader and the inference entry point.
logger = logging.getLogger(__name__)

# Populated by init_model3(); predict() returns an error payload while either
# of these is still None.
model = None  # QAModel instance once loaded
vocab = None  # checkpoint["vocab"] once loaded

# Use the current CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
def init_model3():
    """Load the BiLSTM QA checkpoint into the module-level ``model`` and ``vocab``.

    The checkpoint is read from ``qa_model.pth`` located in the parent of this
    module's directory. It must contain a ``"vocab"`` entry (its ``len()`` is
    passed to ``QAModel`` — presumably the vocabulary size) and a
    ``"model_state"`` state dict for ``QAModel``.

    This is best-effort: a missing file or any loading error is logged and the
    function returns without raising. The globals are replaced only after the
    whole load succeeds, so a failed (re)load never leaves a freshly loaded
    vocab paired with a stale or missing model.
    """
    global model, vocab
    logger.info("[Model3] Initialising BiLSTM from qa_model.pth...")

    model_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "qa_model.pth")
    if not os.path.exists(model_path):
        logger.warning(
            "[Model3] qa_model.pth not found at %s! Model 3 inference will fail.",
            model_path,
        )
        return

    try:
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=device)
        loaded_vocab = checkpoint["vocab"]

        loaded_model = QAModel(len(loaded_vocab))
        loaded_model.load_state_dict(checkpoint["model_state"])
        loaded_model.to(device)
        loaded_model.eval()  # inference mode for dropout/batch-norm layers
    except Exception as e:
        # Deliberately swallowed: predict() reports the missing model instead of
        # the whole service failing at start-up. Keep the traceback in the log.
        logger.exception("[Model3] Failed to load BiLSTM model: %s", e)
        return

    # Publish model and vocab together, only after every step above succeeded.
    model, vocab = loaded_model, loaded_vocab
    logger.info("[Model3] BiLSTM successfully loaded.")
|
|
|
|
def _error_result(message: str) -> dict:
    """Build the failure payload returned by predict() (score 0.0, error=True)."""
    return {
        "answer": message,
        "score": 0.0,
        "model": "BiLSTM",
        "model_id": "model3",
        "error": True,
        "stub": False,
    }


def predict(context: str, question: str) -> dict:
    """Answer *question* from *context* using the loaded BiLSTM.

    The question and context tokens are joined as ``question [SEP] context``,
    encoded with the loaded vocabulary, truncated / right-padded with id 0 to
    300 ids, and run through the model as a batch of one. The answer is the
    token span from the argmax of the start logits to the argmax of the end
    logits.

    Args:
        context: Passage to extract the answer from.
        question: Question to answer.

    Returns:
        A dict with keys ``answer`` (str), ``score`` (float), ``model``,
        ``model_id``, ``error`` (bool) and ``stub`` (always False).
        On success ``score`` is ``softmax(start)[start] * softmax(end)[end]``,
        or 0.0 when no valid span is predicted ("No answer found").
        If the model is not loaded or inference raises, ``error`` is True,
        ``score`` is 0.0 and ``answer`` holds a human-readable message.
    """
    if model is None or vocab is None:
        return _error_result(
            "BiLSTM model weights (qa_model.pth) not found or failed to load. "
            "Please make sure the trained model is placed in the backend folder."
        )

    try:
        q_tokens = tokenize(question)
        c_tokens = tokenize(context)

        tokens = q_tokens + ["[SEP]"] + c_tokens
        encoded = encode(tokens, vocab)

        # Fixed input length; 0 is assumed to be the padding id — TODO confirm
        # against the vocab used in training. Built without mutating the list
        # returned by encode().
        max_len = 300
        encoded = encoded[:max_len]
        encoded = encoded + [0] * (max_len - len(encoded))

        x = torch.tensor(encoded, dtype=torch.long).unsqueeze(0).to(device)

        with torch.no_grad():
            start_logits, end_logits = model(x)

            # Logits are indexed as (batch=1, position); pick the best position
            # independently for the start and the end of the span.
            start = torch.argmax(start_logits, dim=1).item()
            end = torch.argmax(end_logits, dim=1).item()

            # Positions index the concatenated question + [SEP] + context
            # sequence; positions >= len(tokens) fall in the zero padding.
            # NOTE(review): the span is not restricted to the context part, so
            # it may contain question tokens or "[SEP]" — confirm this matches
            # how the model was trained.
            if start > end or start >= len(tokens):
                answer = "No answer found"
                score = 0.0
            else:
                answer = " ".join(tokens[start:end + 1])
                # Confidence of the chosen span, instead of a fixed placeholder.
                # Assumes the model outputs raw logits (or log-probabilities).
                start_prob = torch.softmax(start_logits, dim=1)[0, start]
                end_prob = torch.softmax(end_logits, dim=1)[0, end]
                score = (start_prob * end_prob).item()

        return {
            "answer": answer,
            "score": score,
            "model": "BiLSTM",
            "model_id": "model3",
            "error": False,
            "stub": False,
        }
    except Exception as e:
        # Top-level boundary: report the failure to the caller, keep the traceback.
        logger.exception("[Model3] Inference error: %s", e)
        return _error_result("Inference error occurred.")
|
|