feat: deploy SQuAD backend with all AI models
Browse files- .env +21 -0
- Dockerfile +21 -0
- README.md +38 -10
- __pycache__/auth.cpython-314.pyc +0 -0
- __pycache__/qa_engine.cpython-314.pyc +0 -0
- app.py +638 -0
- auth.py +92 -0
- data_loader/load_squad_json.py +25 -0
- gunicorn.conf.py +10 -0
- main.py +68 -0
- models/__init__.py +1 -0
- models/__pycache__/__init__.cpython-314.pyc +0 -0
- models/__pycache__/bert_model.cpython-314.pyc +0 -0
- models/__pycache__/model2.cpython-314.pyc +0 -0
- models/__pycache__/model3.cpython-314.pyc +0 -0
- models/__pycache__/qa_model.cpython-314.pyc +0 -0
- models/bert_model.py +123 -0
- models/model2.py +28 -0
- models/model3.py +100 -0
- models/qa_model.py +27 -0
- qa_engine.py +115 -0
- qa_model.pth +3 -0
- requirements.txt +16 -0
- train.py +100 -0
- utils/__init__.py +1 -0
- utils/__pycache__/__init__.cpython-314.pyc +0 -0
- utils/__pycache__/db.cpython-314.pyc +0 -0
- utils/__pycache__/pdf_parser.cpython-314.pyc +0 -0
- utils/__pycache__/preprocess.cpython-314.pyc +0 -0
- utils/__pycache__/vocab.cpython-314.pyc +0 -0
- utils/db.py +82 -0
- utils/file_loader.py +19 -0
- utils/pdf_parser.py +49 -0
- utils/preprocess.py +6 -0
- utils/squad_preprocess.py +26 -0
- utils/vocab.py +13 -0
.env
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SECURITY: this file must not carry live credentials in version control.
# Every secret that was committed here (MongoDB password, JWT secret, admin
# password, Gmail app password) should be rotated and supplied through the
# deployment's secret store instead.

# ─── Database ───────────────────────────────────────────────────────────────
MONGO_URI=mongodb+srv://<db-user>:<db-password>@<cluster-host>/squad_qa?appName=<app-name>

# ─── Auth ────────────────────────────────────────────────────────────────────
JWT_SECRET=<random-64-hex-character-secret>
JWT_EXPIRY_HOURS=24

# ─── Admin Seed ──────────────────────────────────────────────────────────────
ADMIN_EMAIL=admin@squad.ai
ADMIN_PASSWORD=<strong-admin-password>

# ─── App Config ──────────────────────────────────────────────────────────────
FLASK_ENV=production
# Comma-separated list of allowed origins (no trailing slash)
ALLOWED_ORIGINS=http://localhost:5173,http://localhost:5174,http://localhost:3000

# ─── Feature Flags ───────────────────────────────────────────────────────────
PDF_MAX_PAGES=15

EMAIL_USER=otp.squad.ai@gmail.com
EMAIL_PASS=<gmail-app-password>
|
Dockerfile
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim

WORKDIR /app

# Install system deps for PyPDF2, python-docx, torch, and file security (libmagic)
# The apt package lists are deleted in the same layer so they do not stay in the image.
RUN apt-get update && apt-get install -y --no-install-recommends \
    gcc \
    libgomp1 \
    libmagic1 \
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies
# requirements.txt is copied on its own, before the rest of the source.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy source
# NOTE(review): this copies the whole build context. The same commit adds .env
# and __pycache__ directories to the repository, so unless a .dockerignore
# excludes them they end up inside the image — confirm.
COPY . .

# NOTE(review): presumably the port gunicorn binds to in gunicorn.conf.py — confirm.
EXPOSE 7860

# Serve the Flask object `app` from app.py with the settings in gunicorn.conf.py.
CMD ["gunicorn", "-c", "gunicorn.conf.py", "app:app"]
|
README.md
CHANGED
|
@@ -1,10 +1,38 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🐍 Backend Architecture (Flask + PyTorch)
|
| 2 |
+
|
| 3 |
+
The core engine: it handles MongoDB persistence and authentication, and runs the machine-learning inference locally on your own server inside a Python virtual environment.
|
| 4 |
+
|
| 5 |
+
## 🔑 Environment Variables
|
| 6 |
+
The root of this folder requires a `.env` file to function:
|
| 7 |
+
```env
|
| 8 |
+
MONGO_URI=mongodb+srv://<your-creds>.mongodb.net
|
| 9 |
+
JWT_SECRET=super_secure_hash_string_here
|
| 10 |
+
ADMIN_EMAIL=admin@squad.ai
|
| 11 |
+
ADMIN_PASSWORD=Admin@123
|
| 12 |
+
EMAIL_USER=your_gmail@gmail.com
|
| 13 |
+
EMAIL_PASS=your_16_char_gmail_app_password
|
| 14 |
+
FLASK_ENV=development
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
## 🧠 AI Inference Matrix (`/models`)
|
| 18 |
+
Each request carries a `model_id` that selects which of the following models answers the question.
|
| 19 |
+
1. **Model 1: `bert_model.py` (BERT)**
|
| 20 |
+
* Leverages HuggingFace `transformers` for `deepset/bert-base-cased-squad2`.
|
| 21 |
+
2. **Model 3: `model3.py` (BiLSTM)**
|
| 22 |
+
* Native PyTorch model whose weights are loaded from the local `qa_model.pth` state dictionary.
|
| 23 |
+
|
| 24 |
+
## 📜 Database Collections
|
| 25 |
+
All queries are funneled cleanly into MongoDB:
|
| 26 |
+
- `users`: Standard user tracking, OTP storage, password hashing tracking.
|
| 27 |
+
- `chats`: Detailed inference payloads, system diagnostics, user-soft deletion patterns (`user_deleted: True`).
|
| 28 |
+
- `settings`: Central singleton objects storing administrative configurations.
|
| 29 |
+
|
| 30 |
+
## 🚀 Running Locally
|
| 31 |
+
```bash
|
| 32 |
+
# 1. Activate Virtual Env
|
| 33 |
+
.\.venv\Scripts\activate
|
| 34 |
+
# 2. Install Dependencies
|
| 35 |
+
pip install -r requirements.txt
|
| 36 |
+
# 3. Boot Server
|
| 37 |
+
python app.py
|
| 38 |
+
```
|
__pycache__/auth.cpython-314.pyc
ADDED
|
Binary file (5.05 kB). View file
|
|
|
__pycache__/qa_engine.cpython-314.pyc
ADDED
|
Binary file (4.33 kB). View file
|
|
|
app.py
ADDED
|
@@ -0,0 +1,638 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
app.py — Main Flask application for the SQuAD QA System.

Endpoints:
  Public:
    POST   /api/auth/register
    POST   /api/auth/verify
    POST   /api/auth/resend-otp
    POST   /api/auth/login
    GET    /api/health

  Authenticated (any user):
    GET    /api/auth/me
    GET    /api/models
    POST   /api/ask
    GET    /api/history
    DELETE /api/history/<chat_id>
    DELETE /api/history

  Admin only:
    GET    /api/admin/users
    PUT    /api/admin/users/<user_id>
    DELETE /api/admin/users/<user_id>
    GET    /api/admin/stats
    GET    /api/admin/settings
    PUT    /api/admin/settings
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import os
|
| 26 |
+
import sys
|
| 27 |
+
import logging
|
| 28 |
+
import re
|
| 29 |
+
from datetime import datetime, timezone, timedelta
|
| 30 |
+
|
| 31 |
+
from flask import Flask, request, jsonify, g
|
| 32 |
+
from flask_cors import CORS
|
| 33 |
+
from flask_bcrypt import Bcrypt
|
| 34 |
+
from flask_limiter import Limiter
|
| 35 |
+
from flask_limiter.util import get_remote_address
|
| 36 |
+
from bson import ObjectId
|
| 37 |
+
from dotenv import load_dotenv
|
| 38 |
+
|
| 39 |
+
# ─── Load environment ─────────────────────────────────────────────────────────
|
| 40 |
+
# Read variables from a local .env file before any os.getenv() call below.
load_dotenv()

# ─── Logging ─────────────────────────────────────────────────────────────────
# Root logger: INFO and above, timestamped, written to stdout.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)

# ─── App init ─────────────────────────────────────────────────────────────────
app = Flask(__name__)
bcrypt = Bcrypt(app)
# Default rate limits apply to every route; individual routes add their own
# limits with @limiter.limit(...).
# NOTE(review): "memory://" presumably keeps the counters inside each worker
# process, so limits would not be shared between gunicorn workers — confirm.
# NOTE(review): get_remote_address presumably sees the proxy's address when the
# app runs behind a reverse proxy — confirm for the deployment target.
limiter = Limiter(
    get_remote_address,
    app=app,
    default_limits=["1000 per day", "100 per hour"],
    storage_uri="memory://"
)
app.config['MAX_CONTENT_LENGTH'] = 5 * 1024 * 1024 # 5 MB max constraint

# ─── CORS (reads from env for cloud safety) ───────────────────────────────────
# ALLOWED_ORIGINS is a comma-separated list; blank entries are dropped.
raw_origins = os.getenv("ALLOWED_ORIGINS", "http://localhost:5173,http://localhost:3000")
allowed_origins = [o.strip() for o in raw_origins.split(",") if o.strip()]
CORS(app, origins=allowed_origins, supports_credentials=True)
|
| 65 |
+
|
| 66 |
+
# ─── Internal imports (after app init) ───────────────────────────────────────
|
| 67 |
+
from auth import generate_token, require_auth, require_admin
|
| 68 |
+
from utils.db import users_col, chats_col, settings_col, is_using_mock
|
| 69 |
+
from utils.pdf_parser import extract_text
|
| 70 |
+
import qa_engine
|
| 71 |
+
|
| 72 |
+
# ─── Helpers ─────────────────────────────────────────────────────────────────
|
| 73 |
+
|
| 74 |
+
def _serialize(doc: dict) -> dict:
    """Return a JSON-friendly copy of a MongoDB document.

    The ``_id`` key, when present, is replaced by an ``id`` key holding its
    string form.  The input document is not modified.  ``None`` is passed
    through unchanged.
    """
    if doc is None:
        return None
    converted = {key: value for key, value in doc.items() if key != "_id"}
    if "_id" in doc:
        converted["id"] = str(doc["_id"])
    return converted
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
|
| 86 |
+
|
| 87 |
+
def _future_iso(seconds: int) -> str:
    """Return the UTC time ``seconds`` seconds from now as an ISO-8601 string."""
    moment = datetime.now(tz=timezone.utc)
    moment += timedelta(seconds=seconds)
    return moment.isoformat()
|
| 89 |
+
|
| 90 |
+
def safe_str(val) -> str:
    """Return ``val`` stripped of surrounding whitespace, or "" if it is not a string.

    Non-string JSON values (dicts, lists, numbers, null) collapse to the empty
    string so they can never reach a MongoDB query as operator documents.
    """
    return val.strip() if isinstance(val, str) else ""
|
| 95 |
+
|
| 96 |
+
def send_otp_email(to_email, otp):
    """Send a registration OTP to ``to_email`` through Gmail SMTP.

    Credentials come from the ``EMAIL_USER`` and ``EMAIL_PASS`` environment
    variables.  When either is missing, nothing is sent and the code is
    written to the log instead (development fallback).

    Args:
        to_email: Recipient address.
        otp: Verification code placed in the message body.

    Returns:
        True when the message was handed to the SMTP server; False when the
        credentials are missing or sending failed (the failure is logged).
    """
    email_user = os.getenv("EMAIL_USER")
    email_pass = os.getenv("EMAIL_PASS")
    if not email_user or not email_pass:
        # Fallback to mock logging if user hasn't put in valid app passwords yet
        logger.warning("=" * 60)
        logger.warning(f" [MOCK EMAIL OTP] Verification code for {to_email}: {otp}")
        logger.warning("=" * 60)
        return False

    try:
        import smtplib
        from email.mime.text import MIMEText
        from email.mime.multipart import MIMEMultipart

        msg = MIMEMultipart()
        msg['From'] = email_user
        msg['To'] = to_email
        msg['Subject'] = "SQuAD QA - Your Verification Code"
        body = f"Welcome to SQuAD QA!!!\n\nYour 6-digit registration verification code is: {otp}\n\nPlease enter this code to complete your registration.\n\nThank you!!!"
        msg.attach(MIMEText(body, 'plain'))

        # The context manager closes the connection even when login or send
        # raises; the timeout keeps an unresponsive SMTP server from blocking
        # the request indefinitely.
        with smtplib.SMTP_SSL('smtp.gmail.com', 465, timeout=15) as server:
            server.login(email_user, email_pass)
            server.send_message(msg)
        logger.info(f"[SMTP] Successfully dispatched OTP to {to_email}")
        return True
    except Exception as e:
        # Best effort by design: callers get False rather than an exception.
        logger.error(f"[SMTP ERROR] Failed to send actual email to {to_email}: {e}")
        return False
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# ─── Admin Seed ────────────────────────────��──────────────────────────────────
|
| 130 |
+
|
| 131 |
+
def _seed_admin():
    """Create the default admin user if it doesn't exist.

    The e-mail and password come from ``ADMIN_EMAIL`` / ``ADMIN_PASSWORD``,
    falling back to built-in defaults.  Nothing is written when a user with
    that e-mail is already present.
    """
    admin_email = os.getenv("ADMIN_EMAIL", "admin@squad.ai")
    admin_password = os.getenv("ADMIN_PASSWORD", "Admin@123")

    col = users_col()
    if col.find_one({"email": admin_email}):
        logger.info(f"[Seed] Admin user '{admin_email}' already exists.")
        return

    hashed = bcrypt.generate_password_hash(admin_password).decode("utf-8")
    col.insert_one({
        "name": "Administrator",
        "email": admin_email,
        "password": hashed,
        "role": "admin",
        "is_active": True,
        # Stored explicitly: the OTP endpoints read is_verified with a default
        # of False, so a record without the field looks unverified to them.
        "is_verified": True,
        "created_at": _now_iso(),
        "last_login": None,
    })
    logger.info(f"[Seed] Admin user '{admin_email}' created.")
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# ─── Health ───────────────────────────────────────────────────────────────────
|
| 155 |
+
|
| 156 |
+
@app.route("/api/health", methods=["GET"])
def health():
    """Report service status, the active database backend, and the server time."""
    db_mode = "mock" if is_using_mock() else "atlas"
    payload = {"status": "ok", "db_mode": db_mode, "timestamp": _now_iso()}
    return jsonify(payload)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# ─── Auth Routes ──────────────────────────────────────────────────────────────
|
| 166 |
+
|
| 167 |
+
@app.route("/api/auth/register", methods=["POST"])
@limiter.limit("10 per hour")
def register():
    """Create an unverified account and e-mail a one-time verification code.

    Expects a JSON object with ``name``, ``email`` and ``password``.

    Returns:
        201 with ``requires_otp`` on success; 400 for missing fields or a
        password that fails the complexity rule; 403 when registrations are
        disabled in the system settings; 409 when the e-mail is already taken.
    """
    # secrets instead of random: the OTP is an authentication credential and
    # must come from a cryptographically secure generator.
    import secrets

    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # A JSON array or scalar has no fields; treat it like an empty object
        # so the request fails validation instead of raising.
        data = {}
    name = safe_str(data.get("name"))
    email = safe_str(data.get("email")).lower()
    password = safe_str(data.get("password"))

    if not name or not email or not password:
        return jsonify({"error": "Name, email, and password are required."}), 400

    password_regex = r"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[@$!%*?&#^])[A-Za-z\d@$!%*?&#^]{8,}$"
    if not re.match(password_regex, password):
        return jsonify({"error": "Password must be at least 8 characters and include uppercase, lowercase, number, and a special character."}), 400

    col = users_col()

    sys_col = settings_col()
    sys_conf = sys_col.find_one({"_id": "system_config"}) or {}
    if sys_conf.get("disable_registrations", False):
        return jsonify({"error": "New user registrations are currently disabled by the administrator."}), 403

    if col.find_one({"email": email}):
        return jsonify({"error": "An account with this email already exists."}), 409

    hashed = bcrypt.generate_password_hash(password).decode("utf-8")
    # Six digits, 100000-999999 inclusive.
    otp = str(secrets.randbelow(900000) + 100000)

    # Store the account first so a code is never mailed for a record that
    # failed to persist.
    col.insert_one({
        "name": name,
        "email": email,
        "password": hashed,
        "role": "user",
        "is_active": False,
        "is_verified": False,
        "otp": otp,
        "otp_expires_at": _future_iso(60),
        "created_at": _now_iso(),
        "last_login": None,
    })
    send_otp_email(email, otp)

    return jsonify({
        "message": "OTP sent to email. Please verify your account.",
        "requires_otp": True
    }), 201
|
| 214 |
+
|
| 215 |
+
@app.route("/api/auth/verify", methods=["POST"])
@limiter.limit("5 per minute")
def verify_otp():
    """Verify a registration OTP, activate the account, and log the user in.

    Expects a JSON object with ``email`` and ``otp``.

    Returns:
        200 with a token and the user summary on success; 400 for missing
        fields, an already verified account, an expired code, or a wrong code;
        404 when no account has that e-mail.
    """
    import hmac

    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # A JSON array or scalar has no fields; fail validation, do not raise.
        data = {}
    email = safe_str(data.get("email")).lower()
    otp = safe_str(data.get("otp"))

    if not email or not otp:
        return jsonify({"error": "Email and OTP are required."}), 400

    col = users_col()
    user = col.find_one({"email": email})

    if not user:
        return jsonify({"error": "User not found."}), 404

    if user.get("is_verified", False):
        return jsonify({"error": "Account already verified."}), 400

    # A record with no pending OTP must never verify.  Comparing
    # str(stored) with the submitted text would otherwise accept the literal
    # string "None" for any account whose otp field is missing or null.
    stored_otp = user.get("otp")
    if stored_otp is None or stored_otp == "":
        return jsonify({"error": "Invalid verification code."}), 400

    # Both timestamps are ISO-8601 UTC strings produced by this module, so
    # they are compared as strings.
    expires_at = user.get("otp_expires_at")
    if expires_at and _now_iso() > expires_at:
        return jsonify({"error": "OTP has expired. Please request a new one."}), 400

    # Constant-time comparison; bytes are used because compare_digest rejects
    # non-ASCII str input.
    if not hmac.compare_digest(str(stored_otp).encode("utf-8"), otp.encode("utf-8")):
        return jsonify({"error": "Invalid verification code."}), 400

    # One write: mark verified, activate, consume the OTP, record the login.
    col.update_one(
        {"_id": user["_id"]},
        {"$set": {
            "is_verified": True,
            "is_active": True,
            "otp": None,
            "last_login": _now_iso(),
        }},
    )

    user_id = str(user["_id"])
    role = user.get("role", "user")
    token = generate_token(user_id, role)

    return jsonify({
        "message": "Account verified successfully.",
        "token": token,
        "user": {"id": user_id, "name": user["name"], "email": user["email"], "role": role},
    }), 200
|
| 254 |
+
|
| 255 |
+
@app.route("/api/auth/resend-otp", methods=["POST"])
@limiter.limit("3 per minute")
def resend_otp():
    """Issue a fresh OTP for an unverified account and e-mail it.

    Expects a JSON object with ``email``.

    Returns:
        200 when a new code was stored and dispatched; 400 when the e-mail is
        missing or the account is already verified; 404 when no account has
        that e-mail.
    """
    # secrets instead of random: the OTP is an authentication credential.
    import secrets

    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # A JSON array or scalar has no fields; fail validation, do not raise.
        data = {}
    email = safe_str(data.get("email")).lower()

    if not email:
        return jsonify({"error": "Email is required."}), 400

    col = users_col()
    user = col.find_one({"email": email})

    if not user:
        return jsonify({"error": "User not found."}), 404

    if user.get("is_verified", False):
        return jsonify({"error": "Account is already verified."}), 400

    # Six digits, 100000-999999 inclusive.
    new_otp = str(secrets.randbelow(900000) + 100000)

    # Replacing the stored code invalidates the previous one.
    col.update_one({"_id": user["_id"]}, {"$set": {"otp": new_otp, "otp_expires_at": _future_iso(60)}})

    send_otp_email(email, new_otp)

    return jsonify({"message": "A new OTP has been sent to your email."}), 200
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
@app.route("/api/auth/login", methods=["POST"])
@limiter.limit("15 per minute")
def login():
    """Authenticate with e-mail and password and return a token.

    Responds 400 when a field is missing, 401 on bad credentials, and 403 when
    the account is unverified or deactivated.
    """
    payload = request.get_json(silent=True) or {}
    email = safe_str(payload.get("email")).lower()
    password = safe_str(payload.get("password"))

    if not (email and password):
        return jsonify({"error": "Email and password are required."}), 400

    accounts = users_col()
    account = accounts.find_one({"email": email})

    # The hash is only checked when an account was found.
    credentials_ok = bool(account) and bcrypt.check_password_hash(account["password"], password)
    if not credentials_ok:
        return jsonify({"error": "Invalid email or password."}), 401

    # Records without the flags are treated as verified and active.
    if not account.get("is_verified", True):
        return jsonify({"error": "Your account is not verified. Please check your email for the OTP."}), 403
    if not account.get("is_active", True):
        return jsonify({"error": "Your account has been deactivated. Contact admin."}), 403

    account_id = str(account["_id"])
    role = account.get("role", "user")
    token = generate_token(account_id, role)

    # Record the time of this successful login.
    accounts.update_one({"_id": account["_id"]}, {"$set": {"last_login": _now_iso()}})

    summary = {
        "id": account_id,
        "name": account["name"],
        "email": account["email"],
        "role": role,
    }
    return jsonify({"message": "Login successful.", "token": token, "user": summary})
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
@app.route("/api/auth/me", methods=["GET"])
@require_auth
def me():
    """Return the authenticated caller's user record without the password hash."""
    accounts = users_col()
    try:
        record = accounts.find_one({"_id": ObjectId(g.current_user["id"])})
    except Exception:
        # Retry with the raw id when it cannot be used as an ObjectId.
        record = accounts.find_one({"_id": g.current_user["id"]})

    if not record:
        return jsonify({"error": "User not found."}), 404

    profile = _serialize(record)
    profile.pop("password", None)
    return jsonify({"user": profile})
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
# ─── Models ───────────────────────────────────────────────────────────────────
|
| 342 |
+
|
| 343 |
+
@app.route("/api/models", methods=["GET"])
@require_auth
def get_models():
    """List the QA models together with usage statistics from the chat log.

    Each model entry gains ``avg_score`` and ``query_count``, computed over the
    non-error chats of models whose status is "ready".  The response also
    carries the query-weighted ``global_avg`` and ``total_queries``.  When the
    aggregation fails, the failure is logged and all statistics are zero.
    """
    models_info = qa_engine.get_models_info()

    ready_ids = [m["id"] for m in models_info if m.get("status") == "ready"]
    pipeline = [
        {"$match": {"model_id": {"$in": ready_ids}, "error": False}},
        {"$group": {"_id": "$model_id", "avg_score": {"$avg": "$score"}, "count": {"$sum": 1}}}
    ]
    try:
        stats = {doc["_id"]: doc for doc in chats_col().aggregate(pipeline)}
        total_queries = sum(d["count"] for d in stats.values())
        # A missing or null average counts as 0 instead of raising, which
        # would otherwise throw away the statistics of every model.
        total_score = sum((d.get("avg_score") or 0) * d["count"] for d in stats.values())
        global_avg = (total_score / total_queries) if total_queries > 0 else 0
    except Exception:
        # Statistics are optional for this endpoint, but the failure must be
        # visible in the logs.
        logger.exception("[Models] Failed to aggregate model statistics")
        stats = {}
        global_avg = 0
        total_queries = 0

    for m in models_info:
        model_stat = stats.get(m["id"], {})
        m["avg_score"] = model_stat.get("avg_score") or 0.0
        m["query_count"] = model_stat.get("count", 0)

    return jsonify({
        "models": models_info,
        "global_avg": global_avg,
        "total_queries": total_queries
    })
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
# ─── Ask (QA Inference) ───────────────────────────────────────────────────────
|
| 377 |
+
|
| 378 |
+
@app.route("/api/ask", methods=["POST"])
@require_auth
@limiter.limit("30 per minute")
def ask():
    """Answer a question against a context and store the exchange.

    Accepts either a multipart form (``model_id``, ``question`` and a ``file``
    or a ``context`` field) or a JSON object (``model_id``, ``context``,
    ``question``).  ``model_id`` defaults to "bert".

    Returns:
        The inference result with the stored ``chat_id`` added; 400 when the
        context or question is missing, the uploaded file type is rejected, or
        text extraction raises ``ValueError``.
    """
    model_id = "bert"
    context = ""
    question = ""

    # ── File upload (multipart form) ──
    if request.content_type and "multipart/form-data" in request.content_type:
        model_id = safe_str(request.form.get("model_id")) or "bert"
        question = safe_str(request.form.get("question"))
        upload = request.files.get("file")
        if upload:
            try:
                import magic
                buffer = upload.read()
                # Decide from the file content, not the client-supplied name.
                mime = magic.from_buffer(buffer, mime=True)
                allowed_mimes = ["application/pdf", "application/vnd.openxmlformats-officedocument.wordprocessingml.document", "text/plain"]
                if mime not in allowed_mimes:
                    # Message lists every accepted type, including plain text.
                    return jsonify({"error": f"Security system rejected {mime}. Only PDF, DOCX, or plain-text files are permitted."}), 400
                context = extract_text(buffer, upload.filename)
            except ValueError as exc:
                return jsonify({"error": str(exc)}), 400
        else:
            context = safe_str(request.form.get("context"))
    else:
        # ── JSON body ──
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            # A JSON array or scalar has no fields; fail validation below.
            data = {}
        model_id = safe_str(data.get("model_id")) or "bert"
        context = safe_str(data.get("context"))
        question = safe_str(data.get("question"))

    if not context:
        return jsonify({"error": "Context (text or file) is required."}), 400
    if not question:
        return jsonify({"error": "Question is required."}), 400

    # ── Run inference ──
    result = qa_engine.run_inference(model_id, context, question)

    # ── Persist to DB ──
    chat_doc = {
        "user_id": g.current_user["id"],
        "model_id": model_id,
        "model_name": result.get("model", model_id),
        "context": context[:2000],  # truncate for storage
        "question": question,
        "answer": result.get("answer", ""),
        "score": result.get("score", 0.0),
        "error": result.get("error", False),
        "created_at": _now_iso(),
    }
    insert_result = chats_col().insert_one(chat_doc)
    result["chat_id"] = str(insert_result.inserted_id)

    return jsonify(result)
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
# ─── History ──────────────────────────────────────────────────────────────────
|
| 439 |
+
|
| 440 |
+
@app.route("/api/history", methods=["GET"])
@require_auth
def get_history():
    """Return the caller's 50 most recent chats that are not soft-deleted."""
    chats = chats_col()
    visible_to_caller = {"user_id": g.current_user["id"], "user_deleted": {"$ne": True}}
    cursor = chats.find(visible_to_caller, sort=[("created_at", -1)], limit=50)
    history = [_serialize(chat) for chat in cursor]
    return jsonify({"history": history})
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
@app.route("/api/history/<chat_id>", methods=["DELETE"])
@require_auth
def delete_chat(chat_id):
    """Soft-delete one of the caller's chats by flagging it ``user_deleted``.

    Responds 400 when the update raises (reported as an invalid id) and 404
    when no chat with that id belongs to the caller.
    """
    chats = chats_col()
    try:
        outcome = chats.update_one(
            {"_id": ObjectId(chat_id), "user_id": g.current_user["id"]},
            {"$set": {"user_deleted": True}},
        )
    except Exception:
        return jsonify({"error": "Invalid chat ID."}), 400

    if outcome.matched_count:
        return jsonify({"message": "Chat deleted."})
    return jsonify({"error": "Chat not found or not owned by you."}), 404
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
@app.route("/api/history", methods=["DELETE"])
@require_auth
def clear_history():
    """Soft-delete every chat of the caller and report how many were changed."""
    chats = chats_col()
    owned_by_caller = {"user_id": g.current_user["id"]}
    mark_deleted = {"$set": {"user_deleted": True}}
    outcome = chats.update_many(owned_by_caller, mark_deleted)
    cleared = outcome.modified_count
    return jsonify({"message": f"Cleared {cleared} chat(s)."})
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
# ─── Admin Routes ─────────────────────────────────────────────────────────────
|
| 482 |
+
|
| 483 |
+
@app.route("/api/admin/users", methods=["GET"])
@require_admin
def admin_list_users():
    """Return every user account, newest first, with credentials removed.

    The password hash and any pending one-time code are stripped from each
    record before it is returned.
    """
    col = users_col()
    result = []
    for u in col.find({}, sort=[("created_at", -1)]):
        u = _serialize(u)
        u.pop("password", None)
        # A pending OTP is a login credential for an unverified account; it
        # must not be readable through the API.
        u.pop("otp", None)
        result.append(u)
    return jsonify({"users": result, "total": len(result)})
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
@app.route("/api/admin/users/<user_id>", methods=["PUT"])
@require_admin
def admin_update_user(user_id):
    """Update a user's ``name``, ``role`` and/or ``is_active`` flag.

    Only those three fields are accepted; other keys in the JSON body are
    ignored.  Each supplied field is type-checked before it is written.

    Returns:
        200 on success; 400 when no valid field is supplied, a field has the
        wrong type, or the update raises (reported as an invalid id); 404
        when no user has that id.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # A JSON array or scalar has no fields; fail validation below.
        data = {}

    update = {}
    # name and role must be non-empty strings: safe_str() turns any other
    # JSON type (object, list, number, null) into "".
    for field in ("name", "role"):
        if field in data:
            value = safe_str(data[field])
            if not value:
                return jsonify({"error": f"Field '{field}' must be a non-empty string."}), 400
            update[field] = value
    if "is_active" in data:
        if not isinstance(data["is_active"], bool):
            return jsonify({"error": "Field 'is_active' must be a boolean."}), 400
        update["is_active"] = data["is_active"]

    if not update:
        return jsonify({"error": "No valid fields to update."}), 400

    col = users_col()
    try:
        res = col.update_one({"_id": ObjectId(user_id)}, {"$set": update})
    except Exception:
        return jsonify({"error": "Invalid user ID."}), 400

    if res.matched_count == 0:
        return jsonify({"error": "User not found."}), 404
    return jsonify({"message": "User updated successfully."})
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
@app.route("/api/admin/users/<user_id>", methods=["DELETE"])
@require_admin
def admin_delete_user(user_id):
    """Delete a user account and soft-delete that user's chat history.

    Responds 400 when an admin targets their own account or the delete raises
    (reported as an invalid id), and 404 when no user has that id.
    """
    # An admin may not remove the account they are signed in with.
    if g.current_user["id"] == user_id:
        return jsonify({"error": "You cannot delete your own account."}), 400

    accounts = users_col()
    try:
        outcome = accounts.delete_one({"_id": ObjectId(user_id)})
    except Exception:
        return jsonify({"error": "Invalid user ID."}), 400

    if not outcome.deleted_count:
        return jsonify({"error": "User not found."}), 404

    # Chats are kept but flagged, so they drop out of user-facing history.
    flags = {"user_deleted": True, "admin_deleted_user": True}
    chats_col().update_many({"user_id": user_id}, {"$set": flags})
    return jsonify({"message": "User and their history deleted."})
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
@app.route("/api/admin/stats", methods=["GET"])
@require_admin
def admin_stats():
    """Return dashboard totals, per-model usage, and a daily query timeseries.

    The timeseries covers the 30 most recent days that have at least one
    chat, in chronological order.  When an aggregation fails, the failure is
    logged and that section of the response is empty.
    """
    users = users_col()
    chats = chats_col()

    total_users = users.count_documents({})
    total_queries = chats.count_documents({})

    # Model usage breakdown
    pipeline = [
        {"$group": {"_id": "$model_id", "count": {"$sum": 1}}}
    ]
    try:
        model_usage = {doc["_id"]: doc["count"] for doc in chats.aggregate(pipeline)}
    except Exception:
        logger.exception("[Admin] Failed to aggregate model usage")
        model_usage = {}

    # Timeseries data for graphs.  created_at is an ISO-8601 string, so its
    # first 10 characters are the date.  Sorting newest-first before the limit
    # keeps the latest 30 days; an ascending sort there would keep the oldest.
    ts_pipeline = [
        {"$project": {"date": {"$substr": ["$created_at", 0, 10]}}},
        {"$group": {"_id": "$date", "queries": {"$sum": 1}}},
        {"$sort": {"_id": -1}},
        {"$limit": 30}
    ]
    try:
        timeseries = [{"date": doc["_id"], "queries": doc["queries"]} for doc in chats.aggregate(ts_pipeline)]
        # Back to chronological order for plotting.
        timeseries.reverse()
    except Exception:
        logger.exception("[Admin] Failed to aggregate query timeseries")
        timeseries = []

    return jsonify({
        "total_users": total_users,
        "total_queries": total_queries,
        "model_usage": model_usage,
        "timeseries": timeseries,
        "db_mode": "mock" if is_using_mock() else "atlas",
    })
|
| 580 |
+
|
| 581 |
+
@app.route("/api/admin/settings", methods=["GET"])
@require_admin
def get_settings():
    """Return the ``system_config`` settings document (admin only).

    When the document does not exist yet it is created with both
    ``disable_registrations`` and ``maintenance_mode`` set to False.
    """
    settings = settings_col()
    config = settings.find_one({"_id": "system_config"})
    if config:
        return jsonify({"settings": _serialize(config)})

    # First access: persist the defaults, then return them.
    config = {
        "_id": "system_config",
        "disable_registrations": False,
        "maintenance_mode": False,
    }
    settings.insert_one(config)
    return jsonify({"settings": _serialize(config)})
|
| 590 |
+
|
| 591 |
+
@app.route("/api/admin/settings", methods=["PUT"])
@require_admin
def update_settings():
    """Update the boolean system flags in ``system_config`` (admin only).

    Accepts a JSON object containing ``disable_registrations`` and/or
    ``maintenance_mode``; other keys are ignored. The document is created
    if it does not exist.

    Returns:
        200 with a confirmation message; 400 when no recognised setting is
        supplied or a supplied value is not a JSON boolean.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # A JSON array or scalar has no .items(); treat it as an empty payload.
        data = {}

    allowed = {"disable_registrations", "maintenance_mode"}
    update = {k: v for k, v in data.items() if k in allowed}
    if not update:
        return jsonify({"error": "No valid settings provided."}), 400

    # The flags are evaluated by truthiness when read, so a string such as
    # "false" would switch a flag ON. Accept real booleans only.
    invalid = sorted(k for k, v in update.items() if not isinstance(v, bool))
    if invalid:
        return jsonify({
            "error": f"Settings must be true or false: {', '.join(invalid)}."
        }), 400

    col = settings_col()
    col.update_one({"_id": "system_config"}, {"$set": update}, upsert=True)
    return jsonify({"message": "Settings updated."})
|
| 603 |
+
|
| 604 |
+
@app.route("/api/admin/models/<model_id>", methods=["PUT"])
@require_admin
def toggle_model_status(model_id):
    """Store a status override for one registered model (admin only).

    The JSON body must carry ``status`` set to "ready" or "maintenance".
    The value is written to ``model_status.<model_id>`` in the
    ``system_config`` document, which is created if missing.

    Returns:
        200 with a confirmation message; 404 when ``model_id`` is not in
        ``qa_engine.MODELS``; 400 for any other status value.
    """
    if model_id not in qa_engine.MODELS:
        return jsonify({"error": "Invalid model ID."}), 404

    payload = request.get_json(silent=True) or {}
    target_status = payload.get("status")
    if target_status not in ("ready", "maintenance"):
        return jsonify({"error": "Invalid status."}), 400

    status_field = f"model_status.{model_id}"
    settings_col().update_one(
        {"_id": "system_config"},
        {"$set": {status_field: target_status}},
        upsert=True,
    )

    return jsonify({"message": f"Model {model_id} status updated to {target_status}."})
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
# ─── Entry Point ──────────────────────────────────────────────────────────────
|
| 623 |
+
|
| 624 |
+
if __name__ == "__main__":
    logger.info("=" * 60)
    logger.info(" SQuAD QA System — Backend Starting")
    logger.info("=" * 60)

    # Initialise AI models
    qa_engine.init_all_models()

    # Seed admin user
    _seed_admin()

    # The server binds every interface, so the interactive debugger must be
    # opt-in: it is enabled only when FLASK_ENV is explicitly "development".
    flask_env = os.getenv("FLASK_ENV", "production")
    debug = flask_env == "development"

    # Honour PORT the same way gunicorn.conf.py does; default stays 5000.
    port = int(os.getenv("PORT", "5000"))

    app.run(host="0.0.0.0", port=port, debug=debug)
|
auth.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
auth.py — JWT-based authentication helpers.
|
| 3 |
+
|
| 4 |
+
Provides:
|
| 5 |
+
- generate_token(user_id, role) → signed JWT string
|
| 6 |
+
- @require_auth → validates JWT, injects g.current_user
|
| 7 |
+
- @require_admin → same as @require_auth + checks admin role
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import jwt
|
| 12 |
+
import logging
|
| 13 |
+
from functools import wraps
|
| 14 |
+
from datetime import datetime, timedelta, timezone
|
| 15 |
+
|
| 16 |
+
from flask import request, jsonify, g
|
| 17 |
+
from dotenv import load_dotenv
|
| 18 |
+
|
| 19 |
+
load_dotenv()
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)

# Fallback signing key, used only when JWT_SECRET is missing or empty.
_DEFAULT_JWT_SECRET = "default-insecure-secret-change-me"

# Key used to sign and verify tokens. An empty environment value is treated
# as missing so tokens are never signed with an empty key.
JWT_SECRET = os.getenv("JWT_SECRET") or _DEFAULT_JWT_SECRET
if JWT_SECRET == _DEFAULT_JWT_SECRET:
    # The fallback is a fixed, publicly known string: anyone can forge tokens.
    logger.warning(
        "JWT_SECRET is not set; using the built-in insecure default. "
        "Set JWT_SECRET before deploying."
    )

# Token lifetime in hours.
JWT_EXPIRY_HOURS = int(os.getenv("JWT_EXPIRY_HOURS", "24"))
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# ─── Token Generation ─────────────────────────────────────────────────────────
|
| 28 |
+
|
| 29 |
+
def generate_token(user_id: str, role: str) -> str:
    """Create a signed JWT valid for JWT_EXPIRY_HOURS hours.

    Args:
        user_id: Identifier stored (as a string) in the ``sub`` claim.
        role: Value stored in the ``role`` claim.

    Returns:
        The HS256-signed token produced by ``jwt.encode``.
    """
    # Read the clock once so that exp is exactly iat + JWT_EXPIRY_HOURS.
    issued_at = datetime.now(timezone.utc)
    payload = {
        "sub": str(user_id),
        "role": role,
        "iat": issued_at,
        "exp": issued_at + timedelta(hours=JWT_EXPIRY_HOURS),
    }
    return jwt.encode(payload, JWT_SECRET, algorithm="HS256")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def decode_token(token: str) -> dict:
    """Verify a JWT against JWT_SECRET (HS256 only) and return its claims.

    Failures raised by ``jwt.decode`` propagate to the caller.
    """
    claims = jwt.decode(token, JWT_SECRET, algorithms=["HS256"])
    return claims
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ─── Decorators ───────────────────────────────────────────────────────────────
|
| 46 |
+
|
| 47 |
+
def require_auth(f):
    """Decorator: validates Bearer JWT and populates g.current_user.

    The wrapped view runs only when the request carries a valid, unexpired
    ``Authorization: Bearer <token>`` header and the token's user exists
    and is active. ``g.current_user`` is set to ``{"id", "role"}``; the
    role is read from the stored user document so that a role change takes
    effect immediately, falling back to the token claim when the document
    has no ``role`` field.

    Responses produced by the decorator itself: 401 for a missing or
    malformed header and for an expired or invalid token, 403 when the
    user is missing or has ``is_active`` set to a falsy value.
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        auth_header = request.headers.get("Authorization", "")
        if not auth_header.startswith("Bearer "):
            return jsonify({"error": "Authorization header missing or malformed."}), 401

        token = auth_header.split(" ", 1)[1]
        try:
            payload = decode_token(token)
            user_id = payload["sub"]
            token_role = payload["role"]
        except jwt.ExpiredSignatureError:
            return jsonify({"error": "Token expired. Please log in again."}), 401
        except jwt.InvalidTokenError as exc:
            return jsonify({"error": f"Invalid token: {exc}"}), 401
        except KeyError as exc:
            # Correctly signed but missing a claim this app relies on.
            return jsonify({"error": f"Invalid token: missing claim {exc}"}), 401

        # Real-time suspension check
        from utils.db import users_col
        from bson import ObjectId as ObjId
        col = users_col()
        try:
            user = col.find_one({"_id": ObjId(user_id)})
        except Exception:
            # Retry with the raw identifier when the ObjectId lookup raises.
            user = col.find_one({"_id": user_id})

        if not user or not user.get("is_active", True):
            return jsonify({"error": "Your account has been suspended by an administrator."}), 403

        g.current_user = {
            "id": user_id,
            # Prefer the stored role: the token claim may be out of date.
            "role": user.get("role", token_role),
        }

        return f(*args, **kwargs)
    return decorated
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def require_admin(f):
    """Decorator: require a valid JWT and the "admin" role.

    Authentication is delegated to ``require_auth``; the wrapped view is
    called only when ``g.current_user["role"]`` equals "admin", otherwise a
    403 JSON error is returned.
    """
    def admin_only(*args, **kwargs):
        if g.current_user.get("role") == "admin":
            return f(*args, **kwargs)
        return jsonify({"error": "Admin access required."}), 403

    # Same stacking as decorating with @wraps(f) above @require_auth.
    return wraps(f)(require_auth(admin_only))
|
data_loader/load_squad_json.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
def load_squad_json(path):
    """Flatten a SQuAD-format JSON file into a list of QA samples.

    Each question that has at least one answer yields a dict with the keys
    ``context``, ``question`` and ``answer_text`` (the text of its first
    answer). Questions with an empty ``answers`` list are skipped.
    """
    with open(path, "r", encoding="utf-8") as handle:
        dataset = json.load(handle)

    samples = []
    for article in dataset["data"]:
        for paragraph in article["paragraphs"]:
            context = paragraph["context"]
            samples.extend(
                {
                    "context": context,
                    "question": qa["question"],
                    "answer_text": qa["answers"][0]["text"],
                }
                for qa in paragraph["qas"]
                if qa["answers"]
            )

    return samples
|
gunicorn.conf.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
# Gunicorn configuration for production deployment

# Networking: listen on every interface, on $PORT (5000 when unset).
port = os.environ.get("PORT", "5000")
bind = "0.0.0.0:" + port

# Processes
workers = 2  # Keep low — each worker loads BERT (~400MB RAM)
timeout = 120  # BERT inference can take a few seconds
preload_app = True  # Load model once, share across workers

# Logging: "-" sends both logs to stdout.
accesslog = "-"
errorlog = "-"
loglevel = "info"
|
main.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from utils.file_loader import load_txt, load_pdf, load_docx
|
| 3 |
+
from models.qa_model import QAModel
|
| 4 |
+
from utils.vocab import encode
|
| 5 |
+
from utils.preprocess import tokenize
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Restore the trained BiLSTM once, at import time, on the CPU. The file holds
# the vocabulary under "vocab" and the weights under "model_state".
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load("qa_model.pth", map_location="cpu")
vocab = checkpoint["vocab"]

# The embedding table is sized from the stored vocabulary.
model = QAModel(len(vocab))
model.load_state_dict(checkpoint["model_state"])
model.eval()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def load_context(path):
    """Load a document's text, choosing the loader from the file extension.

    Supports ``.txt``, ``.pdf`` and ``.docx``. The extension is matched
    case-insensitively, so ``REPORT.PDF`` is handled like ``report.pdf``.

    Raises:
        ValueError: If the extension is not one of the supported formats.
    """
    # Only the comparison is lower-cased; the loader gets the original path.
    lowered = path.lower()
    if lowered.endswith(".txt"):
        return load_txt(path)
    if lowered.endswith(".pdf"):
        return load_pdf(path)
    if lowered.endswith(".docx"):
        return load_docx(path)
    raise ValueError("Unsupported file format")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def extract_answer(question, context):
    """Answer ``question`` from ``context`` with the module-level BiLSTM.

    The model input is the question tokens, a "[SEP]" token and the context
    tokens, encoded with the checkpoint vocabulary and padded with 0 or
    truncated to 300 ids. Returns the tokens between the highest-scoring
    start and end positions joined by spaces, or "No answer found" when the
    start lies after the end or beyond the real tokens.
    """
    max_len = 300

    question_tokens = tokenize(question)
    context_tokens = tokenize(context)
    tokens = question_tokens + ["[SEP]"] + context_tokens

    ids = encode(tokens, vocab)
    # Fixed-length input: right-pad with 0, or cut down to max_len.
    ids = ids[:max_len] + [0] * max(0, max_len - len(ids))

    batch = torch.tensor(ids).unsqueeze(0)
    with torch.no_grad():
        start_logits, end_logits = model(batch)

    start = torch.argmax(start_logits, dim=1).item()
    end = torch.argmax(end_logits, dim=1).item()

    span_is_valid = start <= end and start < len(tokens)
    if not span_is_valid:
        return "No answer found"

    return " ".join(tokens[start:end + 1])
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def main():
    """Interactive CLI: prompt for a file path and a question, print the answer."""
    print("===== BiLSTM QA (Fixed) =====\n")

    context = load_context(input("Enter file path: "))
    question = input("Enter question: ")

    print("\nAnswer:", extract_answer(question, context))
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
if __name__ == "__main__":
|
| 68 |
+
main()
|
models/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# models package
|
models/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (144 Bytes). View file
|
|
|
models/__pycache__/bert_model.cpython-314.pyc
ADDED
|
Binary file (4.78 kB). View file
|
|
|
models/__pycache__/model2.cpython-314.pyc
ADDED
|
Binary file (1.31 kB). View file
|
|
|
models/__pycache__/model3.cpython-314.pyc
ADDED
|
Binary file (4.6 kB). View file
|
|
|
models/__pycache__/qa_model.cpython-314.pyc
ADDED
|
Binary file (1.7 kB). View file
|
|
|
models/bert_model.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
bert_model.py — HuggingFace BERT Question Answering Model.
|
| 3 |
+
|
| 4 |
+
Model: deepset/bert-base-cased-squad2
|
| 5 |
+
Uses direct PyTorch inference (compatible with transformers 5.x).
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
_tokenizer = None
|
| 13 |
+
_model = None
|
| 14 |
+
MODEL_NAME = "deepset/bert-base-cased-squad2"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def init_bert_model():
    """Load the tokenizer and QA model named by MODEL_NAME into module state.

    On any failure the error is logged and both ``_tokenizer`` and
    ``_model`` are reset to None.
    """
    global _tokenizer, _model
    try:
        from transformers import AutoModelForQuestionAnswering, AutoTokenizer

        logger.info("[BERT] Loading model '%s' ...", MODEL_NAME)
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        _model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
        _model.eval()
        logger.info("[BERT] Model loaded and ready.")
    except Exception as exc:
        logger.error("[BERT] Failed to load model: %s", exc)
        _tokenizer = _model = None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _run_qa_inference(context: str, question: str) -> dict:
    """Run one extractive QA forward pass with the loaded tokenizer and model.

    The question and context are encoded as a pair, truncated to 512
    tokens. The answer is decoded from the highest-scoring start token
    through the highest-scoring end token; when the end falls before the
    start, the span collapses to the start token alone.

    Returns:
        ``{"answer": str, "score": float}``, where score is the mean of the
        softmax probabilities of the chosen start and end positions,
        rounded to 4 decimals.
    """
    import torch
    import torch.nn.functional as F

    encoding = _tokenizer(
        question,
        context,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )

    with torch.no_grad():
        outputs = _model(**encoding)

    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]

    first = int(torch.argmax(start_logits))
    last = int(torch.argmax(end_logits))  # inclusive index of the final token
    if last < first:
        last = first

    span_ids = encoding["input_ids"][0][first:last + 1]
    answer = _tokenizer.decode(span_ids, skip_special_tokens=True).strip()

    # Confidence approximation via softmax
    start_prob = float(F.softmax(start_logits, dim=0)[first])
    end_prob = float(F.softmax(end_logits, dim=0)[last])
    score = round((start_prob + end_prob) / 2, 4)

    return {"answer": answer, "score": score}
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def predict(context: str, question: str) -> dict:
    """
    Run QA inference with BERT.

    Returns:
        A dict with the keys "answer" (str), "score" (float), "model"
        ("BERT"), "model_id" ("bert") and "error" (bool). "error" is True
        when the model is not loaded, an input is empty, or inference
        raised. An answer that is empty, contains "[CLS]" or scores below
        0.05 is replaced by a "not found" message with score 0.0.
    """
    def _reply(answer, score, error):
        # Every outcome shares the same envelope.
        return {
            "answer": answer,
            "score": score,
            "model": "BERT",
            "model_id": "bert",
            "error": error,
        }

    if _model is None or _tokenizer is None:
        return _reply("BERT model is not loaded. Please check server logs.", 0.0, True)

    if not context or not question:
        return _reply("Context and question must not be empty.", 0.0, True)

    try:
        result = _run_qa_inference(context=context, question=question)
        score = result["score"]
        answer = result["answer"]

        if score < 0.05 or "[CLS]" in answer or not answer:
            answer = "Answer not found with sufficient confidence. Try rephrasing your question or providing more context."
            score = 0.0

        return _reply(answer, score, False)
    except Exception as exc:
        logger.error("[BERT] Inference error: %s", exc)
        return _reply(f"Inference error: {exc}", 0.0, True)
|
models/model2.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
model2.py — Placeholder for Model 2.
|
| 3 |
+
|
| 4 |
+
Replace this file with your actual model implementation.
|
| 5 |
+
The predict() function signature must match:
|
| 6 |
+
predict(context: str, question: str) -> dict
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def init_model2():
    """Startup hook for Model 2; it only logs that the model is a placeholder."""
    placeholder_notice = "[Model2] Placeholder — not yet integrated."
    logger.info(placeholder_notice)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def predict(context: str, question: str) -> dict:
    """Stub: ignore both arguments and return a fixed "not yet integrated" reply."""
    reply = dict(
        answer="Model 2 is not yet integrated. Please use BERT for now.",
        score=0.0,
        model="Model 2",
        model_id="model2",
        error=False,
        stub=True,
    )
    return reply
|
models/model3.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
model3.py — Integration for BiLSTM Model.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import torch
|
| 7 |
+
import os
|
| 8 |
+
from models.qa_model import QAModel
|
| 9 |
+
|
| 10 |
+
# Import vocab utilities and preprocess utilities
|
| 11 |
+
from utils.preprocess import tokenize
|
| 12 |
+
from utils.vocab import encode
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
model = None
|
| 17 |
+
vocab = None
|
| 18 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 19 |
+
|
| 20 |
+
def init_model3():
    """Load the BiLSTM checkpoint into the module-level ``model`` and ``vocab``.

    The checkpoint is expected at ``qa_model.pth`` in the parent directory
    of this package. The globals are assigned only after the weights have
    loaded successfully, so a failed load never publishes a network with
    random weights; on failure they keep their previous values.
    """
    global model, vocab
    logger.info("[Model3] Initialising BiLSTM from qa_model.pth...")

    # Assumes qa_model.pth is at the root of the backend directory
    model_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "qa_model.pth")
    if not os.path.exists(model_path):
        logger.warning(f"[Model3] qa_model.pth not found at {model_path}! Model 3 inference will fail.")
        return

    try:
        checkpoint = torch.load(model_path, map_location=device)
        loaded_vocab = checkpoint["vocab"]

        network = QAModel(len(loaded_vocab))
        network.load_state_dict(checkpoint["model_state"])
        network.to(device)
        network.eval()
    except Exception as e:
        logger.error(f"[Model3] Failed to load BiLSTM model: {e}")
        return

    # Publish both together so predict() never sees a half-initialised pair.
    model, vocab = network, loaded_vocab
    logger.info("[Model3] BiLSTM successfully loaded.")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def predict(context: str, question: str) -> dict:
    """Predict using the loaded BiLSTM.

    The input is the question tokens, a "[SEP]" token and the context
    tokens, encoded with the checkpoint vocabulary and padded with 0 or
    truncated to 300 ids.

    Returns:
        A dict with the keys "answer", "score", "model" ("BiLSTM"),
        "model_id" ("model3") and "error"; error replies also carry
        "stub": False. The score is the mean of the softmax probabilities
        of the chosen start and end positions, rounded to 4 decimals, or
        0.0 when no answer is produced.
    """
    if model is None or vocab is None:
        return {
            "answer": "BiLSTM model weights (qa_model.pth) not found or failed to load. Please make sure the trained model is placed in the backend folder.",
            "score": 0.0,
            "model": "BiLSTM",
            "model_id": "model3",
            "error": True,
            "stub": False,
        }

    if not context or not question:
        # Same guard and message as the BERT predict().
        return {
            "answer": "Context and question must not be empty.",
            "score": 0.0,
            "model": "BiLSTM",
            "model_id": "model3",
            "error": True,
            "stub": False,
        }

    try:
        q_tokens = tokenize(question)
        c_tokens = tokenize(context)

        tokens = q_tokens + ["[SEP]"] + c_tokens
        encoded = encode(tokens, vocab)

        max_len = 300
        if len(encoded) < max_len:
            encoded += [0] * (max_len - len(encoded))
        else:
            encoded = encoded[:max_len]

        x = torch.tensor(encoded).unsqueeze(0).to(device)

        with torch.no_grad():
            start_logits, end_logits = model(x)

        start = torch.argmax(start_logits, dim=1).item()
        end = torch.argmax(end_logits, dim=1).item()

        if start > end or start >= len(tokens):
            answer = "No answer found"
            score = 0.0
        else:
            answer = " ".join(tokens[start:end + 1])
            # Real confidence instead of a constant: mean softmax
            # probability of the selected start and end positions.
            start_prob = torch.softmax(start_logits, dim=1)[0, start].item()
            end_prob = torch.softmax(end_logits, dim=1)[0, end].item()
            score = round((start_prob + end_prob) / 2, 4)

        return {
            "answer": answer,
            "score": score,
            "model": "BiLSTM",
            "model_id": "model3",
            "error": False,
        }
    except Exception as e:
        logger.error(f"[Model3] Inference error: {e}")
        return {
            "answer": "Inference error occurred.",
            "score": 0.0,
            "model": "BiLSTM",
            "model_id": "model3",
            "error": True,
            "stub": False,
        }
|
models/qa_model.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
class QAModel(nn.Module):
    """Bidirectional LSTM that scores every position as an answer start or end."""

    def __init__(self, vocab_size, embed_dim=200, hidden_dim=256):
        super().__init__()

        # Attribute names are the keys of the saved state dict.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )

        # Forward and backward hidden states are concatenated.
        encoder_dim = hidden_dim * 2
        self.fc_start = nn.Linear(encoder_dim, 1)
        self.fc_end = nn.Linear(encoder_dim, 1)

    def forward(self, x):
        """Return ``(start, end)`` scores, one value per input position."""
        embedded = self.embedding(x)
        encoded, _ = self.lstm(embedded)

        start_scores = self.fc_start(encoded).squeeze(-1)
        end_scores = self.fc_end(encoded).squeeze(-1)
        return start_scores, end_scores
|
qa_engine.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
qa_engine.py — Model router.
|
| 3 |
+
|
| 4 |
+
Routes inference requests to the correct model module based on model_id.
|
| 5 |
+
Initialises all models at startup.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
from models import bert_model, model2, model3
|
| 10 |
+
from utils.db import settings_col
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
# ─── Registry ────────────────────────────────────────────────────────────────
|
| 15 |
+
|
| 16 |
+
MODELS = {
|
| 17 |
+
"bert": {
|
| 18 |
+
"id": "bert",
|
| 19 |
+
"name": "BERT",
|
| 20 |
+
"description": "",
|
| 21 |
+
"status": "ready",
|
| 22 |
+
"module": bert_model,
|
| 23 |
+
},
|
| 24 |
+
"model2": {
|
| 25 |
+
"id": "model2",
|
| 26 |
+
"name": "DistilBERT",
|
| 27 |
+
"description": "",
|
| 28 |
+
"status": "coming_soon",
|
| 29 |
+
"module": model2,
|
| 30 |
+
},
|
| 31 |
+
"model3": {
|
| 32 |
+
"id": "model3",
|
| 33 |
+
"name": "BiLSTM",
|
| 34 |
+
"description": "",
|
| 35 |
+
"status": "ready",
|
| 36 |
+
"module": model3,
|
| 37 |
+
},
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def init_all_models():
    """Run each model module's initialiser in order: BERT, Model 2, Model 3."""
    logger.info("[QAEngine] Initialising models...")
    initialisers = (
        bert_model.init_bert_model,
        model2.init_model2,
        model3.init_model3,
    )
    for initialise in initialisers:
        initialise()
    logger.info("[QAEngine] All models initialised.")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_models_info() -> list:
    """Return id, name, description and status for every registered model.

    A status stored under ``model_status`` in the ``system_config``
    settings document overrides the registry default. If the settings
    lookup raises, the registry defaults are used.
    """
    try:
        sys_conf = settings_col().find_one({"_id": "system_config"}) or {}
        overrides = sys_conf.get("model_status", {})
    except Exception:
        overrides = {}

    infos = []
    for entry in MODELS.values():
        model_id = entry["id"]
        infos.append({
            "id": model_id,
            "name": entry["name"],
            "description": entry["description"],
            "status": overrides.get(model_id, entry["status"]),
        })
    return infos
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def run_inference(model_id: str, context: str, question: str) -> dict:
    """
    Route a QA request to the appropriate model.

    Before dispatching, the ``system_config`` settings document is
    consulted: the request is refused while ``maintenance_mode`` is set, or
    when the model's effective status (stored override, else registry
    default) is not "ready". If that lookup fails, the failure is logged
    and the request is still served.

    Args:
        model_id: One of "bert", "model2", "model3"
        context: The passage/document text
        question: The question to answer

    Returns:
        dict with keys: answer, score, model, model_id, error
    """
    if model_id not in MODELS:
        return {
            "answer": f"Unknown model '{model_id}'. Available: {list(MODELS.keys())}",
            "score": 0.0,
            "model": "Unknown",
            "model_id": model_id,
            "error": True,
        }

    try:
        sys_conf = settings_col().find_one({"_id": "system_config"}) or {}
        if sys_conf.get("maintenance_mode", False):
            return {
                "answer": "System is currently under maintenance. Please try again later.",
                "score": 0.0,
                "model": "System",
                "model_id": model_id,
                "error": True,
            }

        status_override = sys_conf.get("model_status", {}).get(model_id)
        current_status = status_override or MODELS[model_id]["status"]
        if current_status != "ready":
            return {
                "answer": "This model is currently disabled by an administrator.",
                "score": 0.0,
                "model": MODELS[model_id]["name"],
                "model_id": model_id,
                "error": True,
            }
    except Exception:
        # Still serve the request, but record that gating was skipped.
        logger.warning(
            "[QAEngine] Could not read system settings; skipping maintenance/status checks.",
            exc_info=True,
        )

    module = MODELS[model_id]["module"]
    return module.predict(context=context, question=question)
|
qa_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d5ff35d1b92957d46df75fa375df83cf39c8998e51d4098cdb061a8b7fa7d028
|
| 3 |
+
size 43858657
|
requirements.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
flask==3.0.3
|
| 2 |
+
flask-cors==4.0.1
|
| 3 |
+
flask-bcrypt==1.0.1
|
| 4 |
+
pymongo==4.7.3
|
| 5 |
+
dnspython==2.6.1
|
| 6 |
+
pyjwt==2.8.0
|
| 7 |
+
python-dotenv==1.0.1
|
| 8 |
+
transformers>=4.40.0
|
| 9 |
+
torch>=2.0.0
|
| 10 |
+
PyPDF2==3.0.1
|
| 11 |
+
python-docx==1.1.2
|
| 12 |
+
gunicorn==22.0.0
|
| 13 |
+
Werkzeug==3.0.3
|
| 14 |
+
mongomock==4.1.2
|
| 15 |
+
python-magic==0.4.27
|
| 16 |
+
flask-limiter==3.7.0
|
train.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.utils.data import Dataset, DataLoader
|
| 4 |
+
|
| 5 |
+
from data_loader.load_squad_json import load_squad_json
|
| 6 |
+
from utils.squad_preprocess import process_sample
|
| 7 |
+
from utils.vocab import build_vocab, encode
|
| 8 |
+
from models.qa_model import QAModel
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class QADataset(Dataset):
    """Fixed-length encoded QA samples paired with answer start/end positions."""

    def __init__(self, samples, vocab, max_len=300):
        """Encode ``samples``, dropping any that cannot be used.

        A sample is dropped when ``process_sample`` returns a falsy value
        or when its answer start or end is at or beyond ``max_len``.
        """
        self.data = []

        for sample in samples:
            processed = process_sample(sample)
            if not processed:
                continue

            ids = encode(processed["tokens"], vocab)
            # Fixed-length input: right-pad with 0, or cut down to max_len.
            ids = ids[:max_len] + [0] * max(0, max_len - len(ids))

            start, end = processed["start"], processed["end"]
            if start < max_len and end < max_len:
                self.data.append((ids, start, end))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        ids, start, end = self.data[idx]
        return torch.tensor(ids), torch.tensor(start), torch.tensor(end)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def train(
    data_path="data/train-v2.0.json",
    max_samples=30000,
    epochs=5,
    batch_size=32,
    lr=0.001,
    output_path="qa_model.pth",
):
    """Train ``QAModel`` on SQuAD-style data and save a checkpoint.

    Called with no arguments this reproduces the original hard-coded run.

    Args:
        data_path: JSON file passed to ``load_squad_json``.
        max_samples: Keep only the first ``max_samples`` loaded samples;
            ``None`` keeps them all.
        epochs: Number of passes over the dataset.
        batch_size: Mini-batch size for the ``DataLoader``.
        lr: Adam learning rate.
        output_path: Destination of the checkpoint, a dict with the keys
            ``"model_state"`` and ``"vocab"``.

    Raises:
        ValueError: If no sample survives preprocessing.
    """
    print("Loading data...")
    raw = load_squad_json(data_path)[:max_samples]

    print("Building vocab...")
    all_tokens = []
    for sample in raw:
        item = process_sample(sample)
        if item:
            all_tokens += item["tokens"]

    vocab = build_vocab(all_tokens)

    print("Preparing dataset...")
    dataset = QADataset(raw, vocab)
    if len(dataset) == 0:
        # Fail loudly instead of saving an untrained model.
        raise ValueError(f"No usable training samples found in {data_path!r}.")
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    print("Initializing model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = QAModel(len(vocab))
    # Place the model on its device before the optimizer is created.
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    print("Training...\n")
    model.train()

    for epoch in range(epochs):
        total_loss = 0

        for x, start, end in loader:
            x = x.to(device)
            start = start.to(device)
            end = end.to(device)

            # The model yields one prediction for the span start and one for
            # the span end; each is scored against its target index.
            pred_start, pred_end = model(x)
            loss = loss_fn(pred_start, start) + loss_fn(pred_end, end)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Reported value is the sum of per-batch losses for the epoch.
        print(f"Epoch {epoch+1} Loss: {total_loss:.2f}")

    torch.save({
        "model_state": model.state_dict(),
        "vocab": vocab
    }, output_path)

    print("\n✅ Model trained and saved!")
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
if __name__ == "__main__":
|
| 100 |
+
train()
|
utils/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# utils package
|
utils/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (143 Bytes). View file
|
|
|
utils/__pycache__/db.cpython-314.pyc
ADDED
|
Binary file (3.57 kB). View file
|
|
|
utils/__pycache__/pdf_parser.cpython-314.pyc
ADDED
|
Binary file (3.67 kB). View file
|
|
|
utils/__pycache__/preprocess.cpython-314.pyc
ADDED
|
Binary file (448 Bytes). View file
|
|
|
utils/__pycache__/vocab.cpython-314.pyc
ADDED
|
Binary file (713 Bytes). View file
|
|
|
utils/db.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
db.py — MongoDB Atlas connection with mongomock fallback.
|
| 3 |
+
If MONGO_URI is not set or the connection fails, the app runs on an
|
| 4 |
+
in-memory mock store so development works without any database.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import logging
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
|
| 11 |
+
load_dotenv()
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
MONGO_URI = os.getenv("MONGO_URI") or os.getenv("MONGODB_URI") or ""
|
| 16 |
+
DB_NAME = "squad_qa"
|
| 17 |
+
|
| 18 |
+
_client = None
|
| 19 |
+
_db = None
|
| 20 |
+
_using_mock = False
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _connect_atlas():
    """Try to connect to the MongoDB deployment named by ``MONGO_URI``.

    On success the module-level ``_client`` and ``_db`` refer to the real
    server and ``_using_mock`` is False. On any failure (missing or
    placeholder URI, pymongo not importable, server unreachable) the error is
    logged and ``_connect_mock()`` takes over.

    TLS certificate validation is enabled by default. Setting the environment
    variable ``MONGO_TLS_ALLOW_INVALID_CERTS`` to ``1``/``true``/``yes``
    disables it; do that only for local debugging, because it allows
    man-in-the-middle interception of database traffic.
    """
    global _client, _db, _using_mock
    try:
        from pymongo import MongoClient

        if not MONGO_URI or "username:password" in MONGO_URI:
            raise ValueError("MONGO_URI not configured — falling back to mock.")

        allow_invalid_certs = os.getenv(
            "MONGO_TLS_ALLOW_INVALID_CERTS", ""
        ).strip().lower() in ("1", "true", "yes")

        client = MongoClient(
            MONGO_URI,
            serverSelectionTimeoutMS=5000,
            tls=True,
            tlsAllowInvalidCertificates=allow_invalid_certs,
        )
        # MongoClient connects lazily; the ping forces a real round trip.
        client.admin.command("ping")

        # Publish the handles only once the server is known to be reachable,
        # so a failed attempt never leaves a dead client behind.
        _client = client
        _db = client[DB_NAME]
        _using_mock = False
        logger.info("[DB] Connected to MongoDB Atlas successfully.")
    except Exception as exc:
        # Deliberate best-effort: any failure degrades to the in-memory store.
        logger.warning("[DB] MongoDB connection failed: %s", exc)
        logger.warning("[DB] Falling back to in-memory mongomock.")
        _connect_mock()
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _connect_mock():
    """Switch the module to an in-memory mongomock database.

    Data held by mongomock is lost when the process exits. If mongomock is
    not installed, the module state is reset so that ``_client`` and ``_db``
    are both ``None`` and ``_using_mock`` is False.
    """
    global _client, _db, _using_mock
    try:
        import mongomock
    except ImportError:
        logger.error("[DB] mongomock not installed. Database unavailable.")
        # Clear everything, including a client left over from a failed
        # real-server attempt, so the state is consistently "no database".
        _client = None
        _db = None
        _using_mock = False
        return

    _client = mongomock.MongoClient()
    _db = _client[DB_NAME]
    _using_mock = True
    logger.warning("[DB] Running on mongomock — data will NOT persist across restarts.")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def get_db():
    """Return the active database handle, connecting first if there is none.

    The value is whatever ``_connect_atlas()`` left in the module-level
    ``_db``: the real database, the mongomock database, or ``None`` when no
    backend could be initialised.
    """
    # Read-only access to the module global; the connect helper rebinds it.
    if _db is None:
        _connect_atlas()
    return _db
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def is_using_mock() -> bool:
    """Report whether the in-memory mongomock store is the active backend."""
    return _using_mock
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# Initialise on import
|
| 72 |
+
_connect_atlas()
|
| 73 |
+
|
| 74 |
+
# Convenience collection accessors
|
| 75 |
+
def users_col():
    """Return the ``users`` collection of the active database."""
    database = get_db()
    return database["users"]
|
| 77 |
+
|
| 78 |
+
def chats_col():
    """Return the ``chats`` collection of the active database."""
    database = get_db()
    return database["chats"]
|
| 80 |
+
|
| 81 |
+
def settings_col():
    """Return the ``settings`` collection of the active database."""
    database = get_db()
    return database["settings"]
|
utils/file_loader.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import PyPDF2
|
| 2 |
+
import docx
|
| 3 |
+
|
| 4 |
+
def load_txt(file_path):
    """Read the UTF-8 text file at ``file_path`` and return its contents."""
    with open(file_path, mode="r", encoding="utf-8") as handle:
        contents = handle.read()
    return contents
|
| 7 |
+
|
| 8 |
+
def load_pdf(file_path):
    """Return the text of every page of the PDF at ``file_path``.

    Page texts are concatenated with no separator; pages for which the
    extractor returns nothing are skipped.
    """
    chunks = []
    with open(file_path, "rb") as f:
        reader = PyPDF2.PdfReader(f)
        for page in reader.pages:
            # Extract once per page: extraction is the expensive step.
            page_text = page.extract_text()
            if page_text:
                chunks.append(page_text)
    return "".join(chunks)
|
| 16 |
+
|
| 17 |
+
def load_docx(file_path):
    """Return the paragraphs of the DOCX at ``file_path`` joined by newlines."""
    document = docx.Document(file_path)
    paragraph_texts = (paragraph.text for paragraph in document.paragraphs)
    return "\n".join(paragraph_texts)
|
utils/pdf_parser.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
pdf_parser.py — Extract plain text from PDF, DOCX, and TXT files.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import logging
|
| 7 |
+
from io import BytesIO
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
PDF_MAX_PAGES = int(os.getenv("PDF_MAX_PAGES", "15"))
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def extract_text_from_pdf(file_bytes: bytes) -> str:
    """Return the text of the leading pages of an in-memory PDF.

    Only the first ``PDF_MAX_PAGES`` pages are read; their texts are joined
    with newlines and the result is stripped. Any failure is logged and an
    empty string is returned instead of raising.
    """
    try:
        import PyPDF2

        pdf = PyPDF2.PdfReader(BytesIO(file_bytes))
        page_texts = []
        for page in pdf.pages[:PDF_MAX_PAGES]:
            # A page without extractable text contributes an empty line.
            page_texts.append(page.extract_text() or "")
        return "\n".join(page_texts).strip()
    except Exception as exc:
        logger.error(f"[PDF] Extraction failed: {exc}")
        return ""
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def extract_text_from_docx(file_bytes: bytes) -> str:
    """Return the paragraph text of an in-memory DOCX document.

    Paragraphs are joined with newlines and the result is stripped. Any
    failure is logged and an empty string is returned instead of raising.
    """
    try:
        import docx

        # Uses the module-level BytesIO import, as the PDF extractor does.
        doc = docx.Document(BytesIO(file_bytes))
        return "\n".join(para.text for para in doc.paragraphs).strip()
    except Exception as exc:
        logger.error(f"[DOCX] Extraction failed: {exc}")
        return ""
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def extract_text(file_bytes: bytes, filename: str) -> str:
    """Extract plain text from an uploaded file, chosen by its extension.

    Args:
        file_bytes: Raw file contents.
        filename: Original file name; only its extension is used, and the
            comparison is case-insensitive.

    Returns:
        The extracted text with surrounding whitespace stripped.

    Raises:
        ValueError: If the extension is not one of the supported types.
    """
    ext = os.path.splitext(filename.lower())[1]

    if ext == ".pdf":
        return extract_text_from_pdf(file_bytes)

    # NOTE(review): ".doc" goes through the DOCX extractor; confirm it can
    # read legacy .doc uploads, otherwise they come back as an empty string.
    if ext in (".docx", ".doc"):
        return extract_text_from_docx(file_bytes)

    if ext == ".txt":
        # "utf-8-sig" drops a leading byte-order mark, which str.strip()
        # would otherwise leave at the start of the text.
        return file_bytes.decode("utf-8-sig", errors="ignore").strip()

    raise ValueError(f"Unsupported file type: {ext}. Allowed: PDF, DOCX, TXT.")
|
utils/preprocess.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
def tokenize(text):
    """Lowercase ``text``, delete punctuation, and split on whitespace.

    Every character that is neither a word character nor whitespace is
    removed (not replaced by a space), so "don't" becomes "dont".
    """
    cleaned = re.sub(r"[^\w\s]", "", text.lower())
    return cleaned.split()
|
utils/squad_preprocess.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from utils.preprocess import tokenize
|
| 2 |
+
|
| 3 |
+
def process_sample(sample):
    """Turn one QA sample into a token sequence with answer-span indices.

    The question and context are tokenized and combined as
    ``question + ["[SEP]"] + context``. The answer is located as the first
    exact token-level match inside the context.

    Args:
        sample: Mapping with the keys ``"context"``, ``"question"`` and
            ``"answer_text"``.

    Returns:
        ``{"tokens", "start", "end"}`` where ``start``/``end`` are inclusive
        indices into ``tokens``, or ``None`` when the answer is empty or does
        not occur in the context.
    """
    context_tokens = tokenize(sample["context"])
    question_tokens = tokenize(sample["question"])
    answer_tokens = tokenize(sample["answer_text"])

    # An empty answer would "match" at position 0 and yield end == start - 1,
    # an inverted span; treat it as having no locatable answer.
    if not answer_tokens:
        return None

    answer_len = len(answer_tokens)
    # Span indices are shifted past the question tokens and the separator.
    offset = len(question_tokens) + 1

    # Only positions where the whole answer still fits can match.
    for i in range(len(context_tokens) - answer_len + 1):
        if context_tokens[i:i + answer_len] == answer_tokens:
            start = i + offset
            return {
                "tokens": question_tokens + ["[SEP]"] + context_tokens,
                "start": start,
                "end": start + answer_len - 1,
            }

    return None
|
utils/vocab.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import Counter
|
| 2 |
+
|
| 3 |
+
def build_vocab(tokens):
    """Build a token-to-id mapping from an iterable of tokens.

    Ids 0 and 1 are reserved for ``"<PAD>"`` and ``"<UNK>"``. Every other
    distinct token receives the next free id in order of first appearance.

    Args:
        tokens: Iterable of token strings; duplicates are allowed.

    Returns:
        Dict mapping each token to a unique integer id.
    """
    vocab = {"<PAD>": 0, "<UNK>": 1}

    for word in tokens:
        # Skipping known words handles duplicates and also protects the
        # reserved entries: reassigning one of them would hand out an id that
        # the next new word receives as well.
        if word not in vocab:
            vocab[word] = len(vocab)

    return vocab
|
| 11 |
+
|
| 12 |
+
def encode(tokens, vocab):
    """Map ``tokens`` to their ids, using the ``"<UNK>"`` id for unknown ones.

    Args:
        tokens: Iterable of token strings.
        vocab: Token-to-id mapping that contains a ``"<UNK>"`` entry.

    Returns:
        List of ids, one per input token, in order.

    Raises:
        KeyError: If ``vocab`` has no ``"<UNK>"`` entry.
    """
    # Look the fallback id up once instead of once per token.
    unk_id = vocab["<UNK>"]
    return [vocab.get(token, unk_id) for token in tokens]
|