Spaces:
Sleeping
Sleeping
"""
qa_engine.py — Model router.
Routes inference requests to the correct model module based on model_id.
Provides init_all_models() to initialise every model at application startup.
"""
| import logging | |
| from models import bert_model, distilbert_model, bilstm_model | |
| from utils.db import settings_col | |
logger = logging.getLogger(__name__)

# --- Registry ---------------------------------------------------------------
# One entry per supported model, keyed by its model_id. Each entry holds the
# display metadata returned by get_models_info(), a default "status", and the
# module whose predict() serves inference requests for that model.
MODELS = {
    model_id: {
        "id": model_id,
        "name": display_name,
        "description": "",
        "status": "ready",
        "module": module,
    }
    for model_id, display_name, module in (
        ("bert", "BERT", bert_model),
        ("distilbert", "DistilBERT", distilbert_model),
        ("bilstm", "BiLSTM", bilstm_model),
    )
}
def init_all_models():
    """Call every model module's initialiser; intended to run at application startup.

    Initialisers run in the order BERT, DistilBERT, BiLSTM. An exception raised
    by any of them propagates immediately, so later models are not initialised
    and the completion message is not logged.
    """
    logger.info("[QAEngine] Initialising models...")
    startup_hooks = (
        bert_model.init_bert_model,
        distilbert_model.init_distilbert_model,
        bilstm_model.init_bilstm_model,
    )
    for hook in startup_hooks:
        hook()
    logger.info("[QAEngine] All models initialised.")
def get_models_info() -> list:
    """Return public metadata for every registered model (used by /api/models).

    Each item has the keys ``id``, ``name``, ``description`` and ``status``,
    in registry order. A non-empty per-model value stored under
    ``model_status`` in the ``system_config`` settings document overrides the
    registry's default status. This is the same rule run_inference uses when
    deciding whether a model may serve requests. If the settings document
    cannot be read, or its ``model_status`` field is missing or malformed,
    the registry defaults are used.

    Returns:
        list of metadata dicts, one per entry in MODELS.
    """
    try:
        sys_conf = settings_col().find_one({"_id": "system_config"}) or {}
    except Exception:
        # Best effort: a settings/DB failure must not break the model listing.
        logger.warning(
            "[QAEngine] Could not read system_config; using default model statuses.",
            exc_info=True,
        )
        sys_conf = {}

    # A stored null (or a missing field) means "no overrides".
    overrides = sys_conf.get("model_status") or {}
    if not isinstance(overrides, dict):
        # Malformed override document: ignore it rather than fail the endpoint.
        overrides = {}

    return [
        {
            "id": m["id"],
            "name": m["name"],
            "description": m["description"],
            # Falsy overrides (None, "") fall back to the registry default.
            "status": overrides.get(m["id"]) or m["status"],
        }
        for m in MODELS.values()
    ]
def _error_response(model_id: str, model: str, answer: str) -> dict:
    """Build the error-shaped result that run_inference returns when it refuses a request."""
    return {
        "answer": answer,
        "score": 0.0,
        "model": model,
        "model_id": model_id,
        "error": True,
    }


def run_inference(model_id: str, context: str, question: str) -> dict:
    """
    Route a QA request to the appropriate model.

    Before dispatching, the ``system_config`` settings document is consulted.
    If ``maintenance_mode`` is set, the request is refused. If the model's
    effective status is not "ready", the request is also refused. The
    effective status is a non-empty override under ``model_status``, or
    otherwise the registry default. If the settings cannot be read, the
    gates are skipped (fail open) and the request is served.

    Args:
        model_id: One of the keys of MODELS ("bert", "distilbert", "bilstm")
        context: The passage/document text
        question: The question to answer

    Returns:
        dict with keys: answer, score, model, model_id, error. Refusals are
        built here. Successful results come from the model module's
        predict(), which is presumably expected to return the same keys;
        verify against the model modules.
    """
    if model_id not in MODELS:
        return _error_response(
            model_id,
            "Unknown",
            f"Unknown model '{model_id}'. Available: {list(MODELS)}",
        )

    try:
        sys_conf = settings_col().find_one({"_id": "system_config"}) or {}
    except Exception:
        # Deliberately fail open: a settings/DB outage should not block inference.
        logger.warning(
            "[QAEngine] Could not read system_config; skipping maintenance/status checks.",
            exc_info=True,
        )
        sys_conf = {}

    if sys_conf.get("maintenance_mode", False):
        return _error_response(
            model_id,
            "System",
            "System is currently under maintenance. Please try again later.",
        )

    overrides = sys_conf.get("model_status") or {}
    if not isinstance(overrides, dict):
        # Malformed override document: ignore it and use the registry default.
        overrides = {}
    entry = MODELS[model_id]
    current_status = overrides.get(model_id) or entry["status"]
    if current_status != "ready":
        return _error_response(
            model_id,
            entry["name"],
            "This model is currently disabled by an administrator.",
        )

    return entry["module"].predict(context=context, question=question)