| """ |
| qa_engine.py — Model router. |
| |
| Routes inference requests to the correct model module based on model_id. |
| Initialises all models at startup. |
| """ |
|
|
| import logging |
| from models import bert_model, model2, model3 |
| from utils.db import settings_col |
|
|
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
| |
|
|
# Registry of the QA models this router can dispatch to, keyed by model id.
# Fields of each entry:
#   id          - key callers pass as run_inference's model_id; also the key
#                 looked up in the "model_status" map of the system_config
#                 settings document
#   name        - display name
#   description - display text (currently empty for every model)
#   status      - static default status; the settings document may override
#                 it, and run_inference only serves a model whose effective
#                 status is "ready"
#   module      - object whose predict(context=..., question=...) answers
#                 requests for this model
MODELS = {
    spec["id"]: spec
    for spec in (
        {
            "id": "bert",
            "name": "BERT",
            "description": "",
            "status": "ready",
            "module": bert_model,
        },
        {
            "id": "model2",
            "name": "DistilBERT",
            "description": "",
            "status": "coming_soon",
            "module": model2,
        },
        {
            "id": "model3",
            "name": "BiLSTM",
            "description": "",
            "status": "ready",
            "module": model3,
        },
    )
}
|
|
|
|
def init_all_models():
    """
    Run each model module's initialiser once, at application startup.

    Initialisers run in a fixed order: bert_model, model2, model3. Nothing
    here catches exceptions, so a failing initialiser propagates to the
    caller and the remaining initialisers are not run.
    """
    logger.info("[QAEngine] Initialising models...")
    # Resolve each initialiser just before calling it, so lookup and call
    # happen in the same order as calling them one by one.
    for module, init_name in (
        (bert_model, "init_bert_model"),
        (model2, "init_model2"),
        (model3, "init_model3"),
    ):
        getattr(module, init_name)()
    logger.info("[QAEngine] All models initialised.")
|
|
|
|
def get_models_info() -> list:
    """
    Return metadata for every registered model (used by /api/models endpoint).

    Each entry's "status" is the non-empty value stored under
    ``model_status[<id>]`` in the ``system_config`` settings document when
    present, otherwise the static default from ``MODELS``. This is the same
    rule run_inference uses to decide whether a model may serve requests, so
    the listing never reports a status the router would not honour.

    If the settings document cannot be read, the failure is logged and the
    static defaults are reported instead.

    Returns:
        list of dicts with keys: id, name, description, status
        (in ``MODELS`` iteration order).
    """
    try:
        sys_conf = settings_col().find_one({"_id": "system_config"}) or {}
        overrides = sys_conf.get("model_status")
    except Exception as exc:
        # Best effort: the model list should still render when the settings
        # store is unreachable, so fall back to the static statuses.
        logger.warning(
            "[QAEngine] Could not read system_config (%s); using default model statuses.",
            exc,
        )
        overrides = None

    # model_status may be missing, null or malformed in the document; treat
    # any non-dict as "no overrides" instead of failing the whole listing.
    if not isinstance(overrides, dict):
        overrides = {}

    return [
        {
            "id": m["id"],
            "name": m["name"],
            "description": m["description"],
            # An empty/None override falls back to the static default.
            "status": overrides.get(m["id"]) or m["status"],
        }
        for m in MODELS.values()
    ]
|
|
|
|
def run_inference(model_id: str, context: str, question: str) -> dict:
    """
    Route a QA request to the appropriate model.

    Checks run in this order, each short-circuiting with an error dict:
      1. ``model_id`` must be a key of ``MODELS``.
      2. ``maintenance_mode`` in the ``system_config`` settings document
         must not be truthy.
      3. The model's effective status must be ``"ready"``. The effective
         status is the non-empty value under ``model_status[model_id]`` in
         ``system_config`` when present, otherwise the static default from
         ``MODELS``.

    If the settings document cannot be read, the failure is logged and an
    empty configuration is assumed: maintenance mode counts as off and the
    static statuses apply. The status check is still enforced, so a model
    whose default status is not "ready" is never served just because the
    settings store is unavailable.

    Args:
        model_id: One of "bert", "model2", "model3"
        context: The passage/document text
        question: The question to answer

    Returns:
        dict with keys: answer, score, model, model_id, error. Error dicts
        are built here. On success the result is whatever the model
        module's ``predict`` returns (presumably the same keys; verify in
        the model modules).
    """
    if model_id not in MODELS:
        return {
            "answer": f"Unknown model '{model_id}'. Available: {list(MODELS.keys())}",
            "score": 0.0,
            "model": "Unknown",
            "model_id": model_id,
            "error": True,
        }

    model = MODELS[model_id]

    # Guard only the DB read. A settings failure must not bypass the status
    # checks below, otherwise a model that is not "ready" could be served.
    try:
        sys_conf = settings_col().find_one({"_id": "system_config"}) or {}
    except Exception as exc:
        logger.warning(
            "[QAEngine] Could not read system_config (%s); using default model statuses.",
            exc,
        )
        sys_conf = {}

    if sys_conf.get("maintenance_mode", False):
        return {
            "answer": "System is currently under maintenance. Please try again later.",
            "score": 0.0,
            "model": "System",
            "model_id": model_id,
            "error": True,
        }

    # model_status may be missing, null or malformed in the document; treat
    # any non-dict as "no overrides". An empty/None override also falls back
    # to the static default.
    overrides = sys_conf.get("model_status")
    if not isinstance(overrides, dict):
        overrides = {}
    current_status = overrides.get(model_id) or model["status"]
    if current_status != "ready":
        return {
            "answer": "This model is currently disabled by an administrator.",
            "score": 0.0,
            "model": model["name"],
            "model_id": model_id,
            "error": True,
        }

    return model["module"].predict(context=context, question=question)
|
|