File size: 3,471 Bytes
09daf0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
qa_engine.py — Model router.

Routes inference requests to the correct model module based on model_id.
Initialises all models at startup.
"""

import logging
from models import bert_model, model2, model3
from utils.db import settings_col

logger = logging.getLogger(__name__)

# ─── Registry ────────────────────────────────────────────────────────────────

# Registry of the available QA models, keyed by model id.
# Every entry carries the same five fields:
#   id          - identifier, identical to the entry's key in MODELS
#   name        - display name of the model
#   description - free-text description (currently empty for every model)
#   status      - default status string ("ready" or "coming_soon" here)
#   module      - the imported module object registered for this model
MODELS = {
    entry["id"]: entry
    for entry in (
        {
            "id": "bert", "name": "BERT", "description": "",
            "status": "ready", "module": bert_model,
        },
        {
            "id": "model2", "name": "DistilBERT", "description": "",
            "status": "coming_soon", "module": model2,
        },
        {
            "id": "model3", "name": "BiLSTM", "description": "",
            "status": "ready", "module": model3,
        },
    )
}


def init_all_models():
    """
    Run the initialiser of every model module, in a fixed order.

    Calls bert_model.init_bert_model(), model2.init_model2() and
    model3.init_model3() one after another, logging once before the first
    call and once after the last. Nothing is caught here, so an exception
    raised by any initialiser propagates to the caller and the remaining
    initialisers are not run.
    """
    logger.info("[QAEngine] Initialising models...")

    # (module, initialiser name) pairs; the attribute is resolved right
    # before it is called so the order of lookups and calls stays sequential.
    startup_hooks = (
        (bert_model, "init_bert_model"),
        (model2, "init_model2"),
        (model3, "init_model3"),
    )
    for backend, hook_name in startup_hooks:
        getattr(backend, hook_name)()

    logger.info("[QAEngine] All models initialised.")


def get_models_info() -> list:
    """
    Return metadata list for all models (used by /api/models endpoint).

    The "system_config" settings document is read and its "model_status"
    mapping, when present, overrides the registry status per model id.
    A missing document, a failed read, a non-dict "model_status" value or an
    empty/None override for a given model all fall back to the status stored
    in MODELS, so this function always returns one entry per registered model.

    Returns:
        list of dicts with keys: id, name, description, status
        (the "module" field of the registry is deliberately not exposed).
    """
    try:
        sys_conf = settings_col().find_one({"_id": "system_config"}) or {}
        overrides = sys_conf.get("model_status")
    except Exception:
        # Best effort: listing models must keep working without the settings
        # store, but the failure should be visible in the logs.
        logger.warning(
            "[QAEngine] Could not read system_config; using default model statuses.",
            exc_info=True,
        )
        overrides = None

    # A stored null (or any other non-mapping value) would otherwise make the
    # .get() below raise outside the try block.
    if not isinstance(overrides, dict):
        overrides = {}

    return [
        {
            "id": m["id"],
            "name": m["name"],
            "description": m["description"],
            # `or` (not a .get default) so that an empty/None override falls
            # back to the registry status, matching run_inference.
            "status": overrides.get(m["id"]) or m["status"],
        }
        for m in MODELS.values()
    ]


def _qa_error(answer: str, model: str, model_id) -> dict:
    """Build the error-shaped response dict returned when a request is refused."""
    return {
        "answer": answer,
        "score": 0.0,
        "model": model,
        "model_id": model_id,
        "error": True,
    }


def run_inference(model_id: str, context: str, question: str) -> dict:
    """
    Route a QA request to the appropriate model.

    Checks are applied in this order; the first one that fails produces an
    error response and the model is not called:
      1. model_id must be a key of MODELS.
      2. The "system_config" settings document must not have
         "maintenance_mode" set.
      3. The effective status of the model (per-model override from
         "model_status" in the settings document, else the registry status)
         must be "ready".

    If the settings document cannot be read, the failure is logged and the
    request is evaluated against the registry defaults only, so a model whose
    registry status is not "ready" stays unavailable.

    Args:
        model_id: One of "bert", "model2", "model3"
        context:  The passage/document text
        question: The question to answer

    Returns:
        For refused requests, a dict with keys: answer, score, model,
        model_id, error. Otherwise whatever the model module's predict()
        returns (presumably a dict of the same shape - confirm against the
        model modules).
    """
    # The isinstance check keeps an unhashable model_id (e.g. a list) from
    # raising TypeError in the membership test.
    if not isinstance(model_id, str) or model_id not in MODELS:
        return _qa_error(
            f"Unknown model '{model_id}'. Available: {list(MODELS.keys())}",
            "Unknown",
            model_id,
        )

    entry = MODELS[model_id]

    try:
        sys_conf = settings_col().find_one({"_id": "system_config"}) or {}
    except Exception:
        # Best effort: keep serving when the settings store is unreachable,
        # but do not skip the status check below.
        logger.warning(
            "[QAEngine] Could not read system_config; using registry defaults.",
            exc_info=True,
        )
        sys_conf = {}

    if sys_conf.get("maintenance_mode", False):
        return _qa_error(
            "System is currently under maintenance. Please try again later.",
            "System",
            model_id,
        )

    # Tolerate a stored null / non-mapping "model_status" value.
    overrides = sys_conf.get("model_status")
    status_override = overrides.get(model_id) if isinstance(overrides, dict) else None
    current_status = status_override or entry["status"]
    if current_status != "ready":
        return _qa_error(
            "This model is currently disabled by an administrator.",
            entry["name"],
            model_id,
        )

    return entry["module"].predict(context=context, question=question)