tnp554 committed on
Commit
09daf0b
·
1 Parent(s): a4081df

feat: deploy SQuAD backend with all AI models

.env ADDED
@@ -0,0 +1,21 @@
1
+ # ─── Database ───────────────────────────────────────────────────────────────
2
+ MONGO_URI=mongodb+srv://<user>:<password>@ibmcluster.swumgnp.mongodb.net/squad_qa?appName=IBMCLUSTER
3
+
4
+ # ─── Auth ────────────────────────────────────────────────────────────────────
5
+ JWT_SECRET=<64-char-hex-secret>
6
+ JWT_EXPIRY_HOURS=24
7
+
8
+ # ─── Admin Seed ──────────────────────────────────────────────────────────────
9
+ ADMIN_EMAIL=admin@squad.ai
10
+ ADMIN_PASSWORD=Admin@123
11
+
12
+ # ─── App Config ──────────────────────────────────────────────────────────────
13
+ FLASK_ENV=production
14
+ # Comma-separated list of allowed origins (no trailing slash)
15
+ ALLOWED_ORIGINS=http://localhost:5173,http://localhost:5174,http://localhost:3000
16
+
17
+ # ─── Feature Flags ───────────────────────────────────────────────────────────
18
+ PDF_MAX_PAGES=15
19
+
20
+ EMAIL_USER=otp.squad.ai@gmail.com
21
+ EMAIL_PASS=<16-char-gmail-app-password>
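The backend reads these keys with `os.getenv`, usually with a fallback default, so a missing value may only surface later at runtime. A minimal sketch of a fail-fast check follows. `REQUIRED_VARS`, `missing_env_vars`, and `parse_allowed_origins` are hypothetical names that do not appear in this commit, and the choice of required keys is an assumption based on the `.env` file above.

```python
import os

# Hypothetical list of keys the backend cannot run without (assumed from the .env above).
REQUIRED_VARS = ("MONGO_URI", "JWT_SECRET", "ADMIN_EMAIL", "ADMIN_PASSWORD")


def missing_env_vars(env=None, required=REQUIRED_VARS):
    """Return the required keys that are unset or blank in `env`."""
    env = os.environ if env is None else env
    return [key for key in required if not str(env.get(key, "")).strip()]


def parse_allowed_origins(raw):
    """Split a comma-separated origin list, dropping blanks and trailing slashes."""
    return [o.strip().rstrip("/") for o in raw.split(",") if o.strip()]
```

Calling `missing_env_vars()` right after `load_dotenv()` and aborting when the list is non-empty would turn a silent fallback into an explicit startup error.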
Dockerfile ADDED
@@ -0,0 +1,21 @@
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install system deps for PyPDF2, python-docx, torch, and file security (libmagic)
6
+ RUN apt-get update && apt-get install -y --no-install-recommends \
7
+ gcc \
8
+ libgomp1 \
9
+ libmagic1 \
10
+ && rm -rf /var/lib/apt/lists/*
11
+
12
+ # Install Python dependencies
13
+ COPY requirements.txt .
14
+ RUN pip install --no-cache-dir -r requirements.txt
15
+
16
+ # Copy source
17
+ COPY . .
18
+
19
+ EXPOSE 7860
20
+
21
+ CMD ["gunicorn", "-c", "gunicorn.conf.py", "app:app"]
README.md CHANGED
@@ -1,10 +1,38 @@
1
- ---
2
- title: SQuAD
3
- emoji: 🏆
4
- colorFrom: indigo
5
- colorTo: green
6
- sdk: docker
7
- pinned: false
8
- ---
9
-
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # 🐍 Backend Architecture (Flask + PyTorch)
2
+
3
+ The backend handles MongoDB persistence, authentication routing, and machine-learning inference, which runs locally on your own server inside a Python virtual environment.
4
+
5
+ ## 🔑 Environment Variables
6
+ The root of this folder requires a `.env` file to function:
7
+ ```env
8
+ MONGO_URI=mongodb+srv://<your-creds>.mongodb.net
9
+ JWT_SECRET=super_secure_hash_string_here
10
+ ADMIN_EMAIL=admin@squad.ai
11
+ ADMIN_PASSWORD=Admin@123
12
+ EMAIL_USER=your_gmail@gmail.com
13
+ EMAIL_PASS=your_16_char_gmail_app_password
14
+ FLASK_ENV=development
15
+ ```
16
+
17
+ ## 🧠 AI Inference Matrix (`/models`)
18
+ Each request carries a `model_id`, and the backend routes the question to the matching model, which is kept loaded in memory.
19
+ 1. **Model 1: `bert_model.py` (BERT)**
20
+ * Leverages HuggingFace `transformers` for `deepset/bert-base-cased-squad2`.
21
+ 2. **Model 3: `model3.py` (BiLSTM)**
22
+ * Native PyTorch model whose weights are loaded from a local `qa_model.pth` checkpoint.
23
+
24
+ ## 📜 Database Collections
25
+ All application data is stored in MongoDB:
26
+ - `users`: User accounts, OTP codes, and hashed passwords.
27
+ - `chats`: Inference requests and results, including error flags; history is soft-deleted with `user_deleted: True`.
28
+ - `settings`: A singleton `system_config` document that stores administrative configuration.
29
+
30
+ ## 🚀 Running Locally
31
+ ```bash
32
+ # 1. Activate Virtual Env
33
+ .\.venv\Scripts\activate
34
+ # 2. Install Dependencies
35
+ pip install -r requirements.txt
36
+ # 3. Boot Server
37
+ python app.py
38
+ ```
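The README describes routing each question by `model_id`. One way to picture that dispatch is a registry that maps each ID to a callable. The sketch below is an assumption about the shape of such a router, not the contents of `qa_engine.py`, whose source is not shown in this part of the diff. `MODEL_REGISTRY`, `register`, and the stub `echo` model are hypothetical; the result keys mirror the ones `app.py` reads (`answer`, `score`, `error`, `model`).

```python
from typing import Callable, Dict

# Hypothetical registry: model_id -> function(context, question) -> result dict.
MODEL_REGISTRY: Dict[str, Callable[[str, str], dict]] = {}


def register(model_id: str):
    """Decorator that adds an inference function to the registry."""
    def wrap(fn):
        MODEL_REGISTRY[model_id] = fn
        return fn
    return wrap


@register("echo")
def echo_model(context: str, question: str) -> dict:
    """Stub standing in for a real model: returns the first word of the context."""
    words = context.split()
    return {"answer": words[0] if words else "", "score": 0.0, "model": "echo"}


def run_inference(model_id: str, context: str, question: str) -> dict:
    """Dispatch to the registered model, or return an error payload for unknown IDs."""
    fn = MODEL_REGISTRY.get(model_id)
    if fn is None:
        return {"answer": "", "score": 0.0, "error": True, "model": model_id}
    return fn(context, question)
```

Returning an error payload instead of raising keeps the caller's persistence step simple, since every request yields a dictionary with the same keys.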
__pycache__/auth.cpython-314.pyc ADDED
Binary file (5.05 kB). View file
 
__pycache__/qa_engine.cpython-314.pyc ADDED
Binary file (4.33 kB). View file
 
app.py ADDED
@@ -0,0 +1,638 @@
1
+ """
2
+ app.py — Main Flask application for the SQuAD QA System.
3
+
4
+ Endpoints:
5
+ Public:
6
+ POST /api/auth/register
7
+ POST /api/auth/login
8
+ GET /api/health
9
+
10
+ Authenticated (any user):
11
+ GET /api/auth/me
12
+ GET /api/models
13
+ POST /api/ask
14
+ GET /api/history
15
+ DELETE /api/history/<chat_id>
16
+ DELETE /api/history
17
+
18
+ Admin only:
19
+ GET /api/admin/users
20
+ PUT /api/admin/users/<user_id>
21
+ DELETE /api/admin/users/<user_id>
22
+ GET /api/admin/stats
23
+ """
24
+
25
+ import os
26
+ import sys
27
+ import logging
28
+ import re
29
+ from datetime import datetime, timezone, timedelta
30
+
31
+ from flask import Flask, request, jsonify, g
32
+ from flask_cors import CORS
33
+ from flask_bcrypt import Bcrypt
34
+ from flask_limiter import Limiter
35
+ from flask_limiter.util import get_remote_address
36
+ from bson import ObjectId
37
+ from dotenv import load_dotenv
38
+
39
+ # ─── Load environment ─────────────────────────────────────────────────────────
40
+ load_dotenv()
41
+
42
+ # ─── Logging ─────────────────────────────────────────────────────────────────
43
+ logging.basicConfig(
44
+ level=logging.INFO,
45
+ format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
46
+ stream=sys.stdout,
47
+ )
48
+ logger = logging.getLogger(__name__)
49
+
50
+ # ─── App init ─────────────────────────────────────────────────────────────────
51
+ app = Flask(__name__)
52
+ bcrypt = Bcrypt(app)
53
+ limiter = Limiter(
54
+ get_remote_address,
55
+ app=app,
56
+ default_limits=["1000 per day", "100 per hour"],
57
+ storage_uri="memory://"
58
+ )
59
+ app.config['MAX_CONTENT_LENGTH'] = 5 * 1024 * 1024 # 5 MB max constraint
60
+
61
+ # ─── CORS (reads from env for cloud safety) ───────────────────────────────────
62
+ raw_origins = os.getenv("ALLOWED_ORIGINS", "http://localhost:5173,http://localhost:3000")
63
+ allowed_origins = [o.strip() for o in raw_origins.split(",") if o.strip()]
64
+ CORS(app, origins=allowed_origins, supports_credentials=True)
65
+
66
+ # ─── Internal imports (after app init) ───────────────────────────────────────
67
+ from auth import generate_token, require_auth, require_admin
68
+ from utils.db import users_col, chats_col, settings_col, is_using_mock
69
+ from utils.pdf_parser import extract_text
70
+ import qa_engine
71
+
72
+ # ─── Helpers ─────────────────────────────────────────────────────────────────
73
+
74
+ def _serialize(doc: dict) -> dict:
75
+ """Convert MongoDB ObjectId fields to strings for JSON serialization."""
76
+ if doc is None:
77
+ return None
78
+ doc = dict(doc)
79
+ if "_id" in doc:
80
+ doc["id"] = str(doc.pop("_id"))
81
+ return doc
82
+
83
+
84
+ def _now_iso() -> str:
85
+ return datetime.now(timezone.utc).isoformat()
86
+
87
+ def _future_iso(seconds: int) -> str:
88
+ return (datetime.now(timezone.utc) + timedelta(seconds=seconds)).isoformat()
89
+
90
+ def safe_str(val) -> str:
91
+ """Ensure the input is strictly a string, preventing NoSQL injection dicts."""
92
+ if not isinstance(val, str):
93
+ return ""
94
+ return val.strip()
95
+
96
+ def send_otp_email(to_email, otp):
97
+ """Sends OTP via real Gmail SMTP if ENV vars exist."""
98
+ email_user = os.getenv("EMAIL_USER")
99
+ email_pass = os.getenv("EMAIL_PASS")
100
+ if not email_user or not email_pass:
101
+ # Fallback to mock logging if user hasn't put in valid app passwords yet
102
+ logger.warning("=" * 60)
103
+ logger.warning(f" [MOCK EMAIL OTP] Verification code for {to_email}: {otp}")
104
+ logger.warning("=" * 60)
105
+ return False
106
+
107
+ try:
108
+ import smtplib
109
+ from email.mime.text import MIMEText
110
+ from email.mime.multipart import MIMEMultipart
111
+ msg = MIMEMultipart()
112
+ msg['From'] = email_user
113
+ msg['To'] = to_email
114
+ msg['Subject'] = "SQuAD QA - Your Verification Code"
115
+ body = f"Welcome to SQuAD QA!!!\n\nYour 6-digit registration verification code is: {otp}\n\nPlease enter this code to complete your registration.\n\nThank you!!!"
116
+ msg.attach(MIMEText(body, 'plain'))
117
+
118
+ server = smtplib.SMTP_SSL('smtp.gmail.com', 465)
119
+ server.login(email_user, email_pass)
120
+ server.send_message(msg)
121
+ server.quit()
122
+ logger.info(f"[SMTP] Successfully dispatched OTP to {to_email}")
123
+ return True
124
+ except Exception as e:
125
+ logger.error(f"[SMTP ERROR] Failed to send actual email to {to_email}: {e}")
126
+ return False
127
+
128
+
129
+ # ─── Admin Seed ───────────────────────────────────────────────────────────────
130
+
131
+ def _seed_admin():
132
+ """Create the default admin user if it doesn't exist."""
133
+ admin_email = os.getenv("ADMIN_EMAIL", "admin@squad.ai")
134
+ admin_password = os.getenv("ADMIN_PASSWORD", "Admin@123")
135
+
136
+ col = users_col()
137
+ if col.find_one({"email": admin_email}):
138
+ logger.info(f"[Seed] Admin user '{admin_email}' already exists.")
139
+ return
140
+
141
+ hashed = bcrypt.generate_password_hash(admin_password).decode("utf-8")
142
+ col.insert_one({
143
+ "name": "Administrator",
144
+ "email": admin_email,
145
+ "password": hashed,
146
+ "role": "admin",
147
+ "is_active": True,
148
+ "created_at": _now_iso(),
149
+ "last_login": None,
150
+ })
151
+ logger.info(f"[Seed] Admin user '{admin_email}' created.")
152
+
153
+
154
+ # ─── Health ───────────────────────────────────────────────────────────────────
155
+
156
+ @app.route("/api/health", methods=["GET"])
157
+ def health():
158
+ return jsonify({
159
+ "status": "ok",
160
+ "db_mode": "mock" if is_using_mock() else "atlas",
161
+ "timestamp": _now_iso(),
162
+ })
163
+
164
+
165
+ # ─── Auth Routes ──────────────────────────────────────────────────────────────
166
+
167
+ @app.route("/api/auth/register", methods=["POST"])
168
+ @limiter.limit("10 per hour")
169
+ def register():
170
+ data = request.get_json(silent=True) or {}
171
+ name = safe_str(data.get("name"))
172
+ email = safe_str(data.get("email")).lower()
173
+ password = safe_str(data.get("password"))
174
+
175
+ if not name or not email or not password:
176
+ return jsonify({"error": "Name, email, and password are required."}), 400
177
+
178
+ password_regex = r"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[@$!%*?&#^])[A-Za-z\d@$!%*?&#^]{8,}$"
179
+ if not re.match(password_regex, password):
180
+ return jsonify({"error": "Password must be at least 8 characters and include uppercase, lowercase, number, and a special character."}), 400
181
+
182
+ col = users_col()
183
+
184
+ sys_col = settings_col()
185
+ sys_conf = sys_col.find_one({"_id": "system_config"}) or {}
186
+ if sys_conf.get("disable_registrations", False):
187
+ return jsonify({"error": "New user registrations are currently disabled by the administrator."}), 403
188
+
189
+ if col.find_one({"email": email}):
190
+ return jsonify({"error": "An account with this email already exists."}), 409
191
+
192
+ hashed = bcrypt.generate_password_hash(password).decode("utf-8")
193
+ import secrets
194
+ otp = str(secrets.randbelow(900000) + 100000)  # 6-digit code from a CSPRNG
195
+ send_otp_email(email, otp)
196
+
197
+ result = col.insert_one({
198
+ "name": name,
199
+ "email": email,
200
+ "password": hashed,
201
+ "role": "user",
202
+ "is_active": False,
203
+ "is_verified": False,
204
+ "otp": otp,
205
+ "otp_expires_at": _future_iso(60),
206
+ "created_at": _now_iso(),
207
+ "last_login": None,
208
+ })
209
+
210
+ return jsonify({
211
+ "message": "OTP sent to email. Please verify your account.",
212
+ "requires_otp": True
213
+ }), 201
214
+
215
+ @app.route("/api/auth/verify", methods=["POST"])
216
+ @limiter.limit("5 per minute")
217
+ def verify_otp():
218
+ data = request.get_json(silent=True) or {}
219
+ email = safe_str(data.get("email")).lower()
220
+ otp = safe_str(data.get("otp"))
221
+
222
+ if not email or not otp:
223
+ return jsonify({"error": "Email and OTP are required."}), 400
224
+
225
+ col = users_col()
226
+ user = col.find_one({"email": email})
227
+
228
+ if not user:
229
+ return jsonify({"error": "User not found."}), 404
230
+
231
+ if user.get("is_verified", False):
232
+ return jsonify({"error": "Account already verified."}), 400
233
+
234
+ expires_at = user.get("otp_expires_at")
235
+ if expires_at and _now_iso() > expires_at:
236
+ return jsonify({"error": "OTP has expired. Please request a new one."}), 400
237
+
238
+ if str(user.get("otp")) != str(otp):
239
+ return jsonify({"error": "Invalid verification code."}), 400
240
+
241
+ col.update_one({"_id": user["_id"]}, {"$set": {"is_verified": True, "is_active": True, "otp": None}})
242
+
243
+ user_id = str(user["_id"])
244
+ from auth import generate_token
245
+ role = user.get("role", "user")
246
+ token = generate_token(user_id, role)
247
+ col.update_one({"_id": user["_id"]}, {"$set": {"last_login": _now_iso()}})
248
+
249
+ return jsonify({
250
+ "message": "Account verified successfully.",
251
+ "token": token,
252
+ "user": {"id": user_id, "name": user["name"], "email": user["email"], "role": role},
253
+ }), 200
254
+
255
+ @app.route("/api/auth/resend-otp", methods=["POST"])
256
+ @limiter.limit("3 per minute")
257
+ def resend_otp():
258
+ data = request.get_json(silent=True) or {}
259
+ email = safe_str(data.get("email")).lower()
260
+
261
+ if not email:
262
+ return jsonify({"error": "Email is required."}), 400
263
+
264
+ col = users_col()
265
+ user = col.find_one({"email": email})
266
+
267
+ if not user:
268
+ return jsonify({"error": "User not found."}), 404
269
+
270
+ if user.get("is_verified", False):
271
+ return jsonify({"error": "Account is already verified."}), 400
272
+
273
+ import secrets
274
+ new_otp = str(secrets.randbelow(900000) + 100000)  # 6-digit code from a CSPRNG
275
+
276
+ col.update_one({"_id": user["_id"]}, {"$set": {"otp": new_otp, "otp_expires_at": _future_iso(60)}})
277
+
278
+ send_otp_email(email, new_otp)
279
+
280
+ return jsonify({"message": "A new OTP has been sent to your email."}), 200
281
+
282
+
283
+ @app.route("/api/auth/login", methods=["POST"])
284
+ @limiter.limit("15 per minute")
285
+ def login():
286
+ data = request.get_json(silent=True) or {}
287
+ email = safe_str(data.get("email")).lower()
288
+ password = safe_str(data.get("password"))
289
+
290
+ if not email or not password:
291
+ return jsonify({"error": "Email and password are required."}), 400
292
+
293
+ col = users_col()
294
+ user = col.find_one({"email": email})
295
+
296
+ if not user or not bcrypt.check_password_hash(user["password"], password):
297
+ return jsonify({"error": "Invalid email or password."}), 401
298
+ if not user.get("is_verified", True):
299
+ # We can trigger verify if they try to login while unverified, but for simplicity:
300
+ return jsonify({"error": "Your account is not verified. Please check your email for the OTP."}), 403
301
+ if not user.get("is_active", True):
302
+ return jsonify({"error": "Your account has been deactivated. Contact admin."}), 403
303
+
304
+ user_id = str(user["_id"])
305
+ role = user.get("role", "user")
306
+ token = generate_token(user_id, role)
307
+
308
+ # Update last_login
309
+ col.update_one({"_id": user["_id"]}, {"$set": {"last_login": _now_iso()}})
310
+
311
+ return jsonify({
312
+ "message": "Login successful.",
313
+ "token": token,
314
+ "user": {
315
+ "id": user_id,
316
+ "name": user["name"],
317
+ "email": user["email"],
318
+ "role": role,
319
+ },
320
+ })
321
+
322
+
323
+ @app.route("/api/auth/me", methods=["GET"])
324
+ @require_auth
325
+ def me():
326
+ from bson import ObjectId as ObjId
327
+ col = users_col()
328
+ try:
329
+ user = col.find_one({"_id": ObjId(g.current_user["id"])})
330
+ except Exception:
331
+ user = col.find_one({"_id": g.current_user["id"]})
332
+
333
+ if not user:
334
+ return jsonify({"error": "User not found."}), 404
335
+
336
+ user = _serialize(user)
337
+ user.pop("password", None)
338
+ return jsonify({"user": user})
339
+
340
+
341
+ # ─── Models ───────────────────────────────────────────────────────────────────
342
+
343
+ @app.route("/api/models", methods=["GET"])
344
+ @require_auth
345
+ def get_models():
346
+ models_info = qa_engine.get_models_info()
347
+
348
+ ready_ids = [m["id"] for m in models_info if m.get("status") == "ready"]
349
+ pipeline = [
350
+ {"$match": {"model_id": {"$in": ready_ids}, "error": False}},
351
+ {"$group": {"_id": "$model_id", "avg_score": {"$avg": "$score"}, "count": {"$sum": 1}}}
352
+ ]
353
+ try:
354
+ from utils.db import chats_col
355
+ stats = {doc["_id"]: doc for doc in chats_col().aggregate(pipeline)}
356
+ total_queries = sum(d["count"] for d in stats.values())
357
+ total_score = sum(d["avg_score"] * d["count"] for d in stats.values())
358
+ global_avg = (total_score / total_queries) if total_queries > 0 else 0
359
+ except Exception:
360
+ stats = {}
361
+ global_avg = 0
362
+ total_queries = 0
363
+
364
+ for m in models_info:
365
+ model_stat = stats.get(m["id"], {})
366
+ m["avg_score"] = model_stat.get("avg_score", 0.0)
367
+ m["query_count"] = model_stat.get("count", 0)
368
+
369
+ return jsonify({
370
+ "models": models_info,
371
+ "global_avg": global_avg,
372
+ "total_queries": total_queries
373
+ })
374
+
375
+
376
+ # ─── Ask (QA Inference) ───────────────────────────────────────────────────────
377
+
378
+ @app.route("/api/ask", methods=["POST"])
379
+ @require_auth
380
+ @limiter.limit("30 per minute")
381
+ def ask():
382
+ model_id = "bert"
383
+ context = ""
384
+ question = ""
385
+
386
+ # ── File upload (multipart form) ──
387
+ if request.content_type and "multipart/form-data" in request.content_type:
388
+ model_id = safe_str(request.form.get("model_id")) or "bert"
389
+ question = safe_str(request.form.get("question"))
390
+ file = request.files.get("file")
391
+ if file:
392
+ try:
393
+ import magic
394
+ buffer = file.read()
395
+ mime = magic.from_buffer(buffer, mime=True)
396
+ allowed_mimes = ["application/pdf", "application/vnd.openxmlformats-officedocument.wordprocessingml.document", "text/plain"]
397
+ if mime not in allowed_mimes:
398
+ return jsonify({"error": f"Unsupported file type {mime}. Only PDF, DOCX, and plain-text files are permitted."}), 400
399
+ from utils.pdf_parser import extract_text
400
+ context = extract_text(buffer, file.filename)
401
+ except ValueError as exc:
402
+ return jsonify({"error": str(exc)}), 400
403
+ else:
404
+ context = safe_str(request.form.get("context"))
405
+ else:
406
+ # ── JSON body ──
407
+ data = request.get_json(silent=True) or {}
408
+ model_id = safe_str(data.get("model_id")) or "bert"
409
+ context = safe_str(data.get("context"))
410
+ question = safe_str(data.get("question"))
411
+
412
+ if not context:
413
+ return jsonify({"error": "Context (text or file) is required."}), 400
414
+ if not question:
415
+ return jsonify({"error": "Question is required."}), 400
416
+
417
+ # ── Run inference ──
418
+ result = qa_engine.run_inference(model_id, context, question)
419
+
420
+ # ── Persist to DB ──
421
+ chat_doc = {
422
+ "user_id": g.current_user["id"],
423
+ "model_id": model_id,
424
+ "model_name": result.get("model", model_id),
425
+ "context": context[:2000], # truncate for storage
426
+ "question": question,
427
+ "answer": result.get("answer", ""),
428
+ "score": result.get("score", 0.0),
429
+ "error": result.get("error", False),
430
+ "created_at": _now_iso(),
431
+ }
432
+ insert_result = chats_col().insert_one(chat_doc)
433
+ result["chat_id"] = str(insert_result.inserted_id)
434
+
435
+ return jsonify(result)
436
+
437
+
438
+ # ─── History ──────────────────────────────────────────────────────────────────
439
+
440
+ @app.route("/api/history", methods=["GET"])
441
+ @require_auth
442
+ def get_history():
443
+ col = chats_col()
444
+ docs = list(col.find(
445
+ {"user_id": g.current_user["id"], "user_deleted": {"$ne": True}},
446
+ sort=[("created_at", -1)],
447
+ limit=50,
448
+ ))
449
+ return jsonify({"history": [_serialize(d) for d in docs]})
450
+
451
+
452
+ @app.route("/api/history/<chat_id>", methods=["DELETE"])
453
+ @require_auth
454
+ def delete_chat(chat_id):
455
+ from bson import ObjectId as ObjId
456
+ col = chats_col()
457
+ try:
458
+ res = col.update_one(
459
+ {"_id": ObjId(chat_id), "user_id": g.current_user["id"]},
460
+ {"$set": {"user_deleted": True}}
461
+ )
462
+ except Exception:
463
+ return jsonify({"error": "Invalid chat ID."}), 400
464
+
465
+ if res.matched_count == 0:
466
+ return jsonify({"error": "Chat not found or not owned by you."}), 404
467
+ return jsonify({"message": "Chat deleted."})
468
+
469
+
470
+ @app.route("/api/history", methods=["DELETE"])
471
+ @require_auth
472
+ def clear_history():
473
+ col = chats_col()
474
+ res = col.update_many(
475
+ {"user_id": g.current_user["id"]},
476
+ {"$set": {"user_deleted": True}}
477
+ )
478
+ return jsonify({"message": f"Cleared {res.modified_count} chat(s)."})
479
+
480
+
481
+ # ─── Admin Routes ─────────────────────────────────────────────────────────────
482
+
483
+ @app.route("/api/admin/users", methods=["GET"])
484
+ @require_admin
485
+ def admin_list_users():
486
+ col = users_col()
487
+ users = list(col.find({}, sort=[("created_at", -1)]))
488
+ result = []
489
+ for u in users:
490
+ u = _serialize(u)
491
+ u.pop("password", None)
492
+ result.append(u)
493
+ return jsonify({"users": result, "total": len(result)})
494
+
495
+
496
+ @app.route("/api/admin/users/<user_id>", methods=["PUT"])
497
+ @require_admin
498
+ def admin_update_user(user_id):
499
+ from bson import ObjectId as ObjId
500
+ data = request.get_json(silent=True) or {}
501
+ allowed_fields = {"name", "role", "is_active"}
502
+ update = {k: v for k, v in data.items() if k in allowed_fields}
503
+
504
+ if not update:
505
+ return jsonify({"error": "No valid fields to update."}), 400
506
+
507
+ col = users_col()
508
+ try:
509
+ res = col.update_one({"_id": ObjId(user_id)}, {"$set": update})
510
+ except Exception:
511
+ return jsonify({"error": "Invalid user ID."}), 400
512
+
513
+ if res.matched_count == 0:
514
+ return jsonify({"error": "User not found."}), 404
515
+ return jsonify({"message": "User updated successfully."})
516
+
517
+
518
+ @app.route("/api/admin/users/<user_id>", methods=["DELETE"])
519
+ @require_admin
520
+ def admin_delete_user(user_id):
521
+ from bson import ObjectId as ObjId
522
+ # Prevent self-deletion
523
+ if user_id == g.current_user["id"]:
524
+ return jsonify({"error": "You cannot delete your own account."}), 400
525
+
526
+ col = users_col()
527
+ try:
528
+ res = col.delete_one({"_id": ObjId(user_id)})
529
+ except Exception:
530
+ return jsonify({"error": "Invalid user ID."}), 400
531
+
532
+ if res.deleted_count == 0:
533
+ return jsonify({"error": "User not found."}), 404
534
+
535
+ # Also logically remove their chat history
536
+ chats_col().update_many(
537
+ {"user_id": user_id},
538
+ {"$set": {"user_deleted": True, "admin_deleted_user": True}}
539
+ )
540
+ return jsonify({"message": "User and their history deleted."})
541
+
542
+
543
+ @app.route("/api/admin/stats", methods=["GET"])
544
+ @require_admin
545
+ def admin_stats():
546
+ users = users_col()
547
+ chats = chats_col()
548
+
549
+ total_users = users.count_documents({})
550
+ total_queries = chats.count_documents({})
551
+
552
+ # Model usage breakdown
553
+ pipeline = [
554
+ {"$group": {"_id": "$model_id", "count": {"$sum": 1}}}
555
+ ]
556
+ try:
557
+ model_usage = {doc["_id"]: doc["count"] for doc in chats.aggregate(pipeline)}
558
+ except Exception:
559
+ model_usage = {}
560
+
561
+ # Timeseries data for graphs
562
+ ts_pipeline = [
563
+ {"$project": {"date": {"$substr": ["$created_at", 0, 10]}}},
564
+ {"$group": {"_id": "$date", "queries": {"$sum": 1}}},
565
+ {"$sort": {"_id": 1}},
566
+ {"$limit": 30}
567
+ ]
568
+ try:
569
+ timeseries = [{"date": doc["_id"], "queries": doc["queries"]} for doc in chats.aggregate(ts_pipeline)]
570
+ except Exception:
571
+ timeseries = []
572
+
573
+ return jsonify({
574
+ "total_users": total_users,
575
+ "total_queries": total_queries,
576
+ "model_usage": model_usage,
577
+ "timeseries": timeseries,
578
+ "db_mode": "mock" if is_using_mock() else "atlas",
579
+ })
580
+
581
+ @app.route("/api/admin/settings", methods=["GET"])
582
+ @require_admin
583
+ def get_settings():
584
+ col = settings_col()
585
+ doc = col.find_one({"_id": "system_config"})
586
+ if not doc:
587
+ doc = {"_id": "system_config", "disable_registrations": False, "maintenance_mode": False}
588
+ col.insert_one(doc)
589
+ return jsonify({"settings": _serialize(doc)})
590
+
591
+ @app.route("/api/admin/settings", methods=["PUT"])
592
+ @require_admin
593
+ def update_settings():
594
+ data = request.get_json(silent=True) or {}
595
+ allowed = {"disable_registrations", "maintenance_mode"}
596
+ update = {k: v for k, v in data.items() if k in allowed}
597
+ if not update:
598
+ return jsonify({"error": "No valid settings provided."}), 400
599
+
600
+ col = settings_col()
601
+ col.update_one({"_id": "system_config"}, {"$set": update}, upsert=True)
602
+ return jsonify({"message": "Settings updated."})
603
+
604
+ @app.route("/api/admin/models/<model_id>", methods=["PUT"])
605
+ @require_admin
606
+ def toggle_model_status(model_id):
607
+ if model_id not in qa_engine.MODELS:
608
+ return jsonify({"error": "Invalid model ID."}), 404
609
+
610
+ data = request.get_json(silent=True) or {}
611
+ target_status = data.get("status")
612
+ if target_status not in ["ready", "maintenance"]:
613
+ return jsonify({"error": "Invalid status."}), 400
614
+
615
+ col = settings_col()
616
+ col.update_one({"_id": "system_config"}, {"$set": {f"model_status.{model_id}": target_status}}, upsert=True)
617
+
618
+ return jsonify({"message": f"Model {model_id} status updated to {target_status}."})
619
+
620
+
621
+
622
+ # ─── Entry Point ──────────────────────────────────────────────────────────────
623
+
624
+ if __name__ == "__main__":
625
+ logger.info("=" * 60)
626
+ logger.info(" SQuAD QA System — Backend Starting")
627
+ logger.info("=" * 60)
628
+
629
+ # Initialise AI models
630
+ qa_engine.init_all_models()
631
+
632
+ # Seed admin user
633
+ _seed_admin()
634
+
635
+ flask_env = os.getenv("FLASK_ENV", "development")
636
+ debug = flask_env == "development"
637
+
638
+ app.run(host="0.0.0.0", port=5000, debug=debug)
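The registration flow in `app.py` stores a six-digit code with an expiry timestamp and compares it on `/api/auth/verify`. The sketch below shows one way to do those two steps using only the standard library: `secrets` for the code, parsed timezone-aware datetimes for the expiry check, and `hmac.compare_digest` for a constant-time comparison. `new_otp` and `otp_is_valid` are hypothetical helper names, the zero-padded code format is an assumption that differs slightly from the 100000 to 999999 range used above, and the 60-second lifetime mirrors `_future_iso(60)`.

```python
import hmac
import secrets
from datetime import datetime, timedelta, timezone


def new_otp(ttl_seconds=60, now=None):
    """Return (code, expires_at_iso) for a fresh six-digit code."""
    now = now or datetime.now(timezone.utc)
    code = f"{secrets.randbelow(1_000_000):06d}"  # zero-padded, so always 6 digits
    return code, (now + timedelta(seconds=ttl_seconds)).isoformat()


def otp_is_valid(stored_code, expires_at_iso, submitted, now=None):
    """Check expiry on parsed datetimes, then compare codes in constant time."""
    now = now or datetime.now(timezone.utc)
    if stored_code is None or now > datetime.fromisoformat(expires_at_iso):
        return False
    # Encode to bytes so non-ASCII input cannot make compare_digest raise.
    return hmac.compare_digest(str(stored_code).encode(), str(submitted).encode())
```

Passing `now` explicitly keeps both helpers deterministic under test; in a request handler the default current UTC time would be used.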
auth.py ADDED
@@ -0,0 +1,92 @@
1
+ """
2
+ auth.py — JWT-based authentication helpers.
3
+
4
+ Provides:
5
+ - generate_token(user_id, role) → signed JWT string
6
+ - @require_auth → validates JWT, injects g.current_user
7
+ - @require_admin → same as @require_auth + checks admin role
8
+ """
9
+
10
+ import os
11
+ import jwt
12
+ import logging
13
+ from functools import wraps
14
+ from datetime import datetime, timedelta, timezone
15
+
16
+ from flask import request, jsonify, g
17
+ from dotenv import load_dotenv
18
+
19
+ load_dotenv()
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ JWT_SECRET = os.getenv("JWT_SECRET", "default-insecure-secret-change-me")
24
+ JWT_EXPIRY_HOURS = int(os.getenv("JWT_EXPIRY_HOURS", "24"))
25
+
26
+
27
+ # ─── Token Generation ─────────────────────────────────────────────────────────
28
+
29
+ def generate_token(user_id: str, role: str) -> str:
30
+ """Create a signed JWT valid for JWT_EXPIRY_HOURS hours."""
31
+ payload = {
32
+ "sub": str(user_id),
33
+ "role": role,
34
+ "iat": datetime.now(timezone.utc),
35
+ "exp": datetime.now(timezone.utc) + timedelta(hours=JWT_EXPIRY_HOURS),
36
+ }
37
+ return jwt.encode(payload, JWT_SECRET, algorithm="HS256")
38
+
39
+
40
+ def decode_token(token: str) -> dict:
41
+ """Decode and verify a JWT. Raises jwt.exceptions on failure."""
42
+ return jwt.decode(token, JWT_SECRET, algorithms=["HS256"])
43
+
44
+
45
+ # ─── Decorators ───────────────────────────────────────────────────────────────
46
+
47
+ def require_auth(f):
48
+ """Decorator: validates Bearer JWT and populates g.current_user."""
49
+ @wraps(f)
50
+ def decorated(*args, **kwargs):
51
+ auth_header = request.headers.get("Authorization", "")
52
+ if not auth_header.startswith("Bearer "):
53
+ return jsonify({"error": "Authorization header missing or malformed."}), 401
54
+
55
+ token = auth_header.split(" ", 1)[1]
56
+ try:
57
+ payload = decode_token(token)
58
+
59
+ # Real-time suspension check
60
+ from utils.db import users_col
61
+ from bson import ObjectId as ObjId
62
+ col = users_col()
63
+ try:
64
+ user = col.find_one({"_id": ObjId(payload["sub"])})
65
+ except Exception:
66
+ user = col.find_one({"_id": payload["sub"]})
67
+
68
+ if not user or not user.get("is_active", True):
69
+ return jsonify({"error": "Your account has been suspended by an administrator."}), 403
70
+
71
+ g.current_user = {
72
+ "id": payload["sub"],
73
+ "role": payload["role"],
74
+ }
75
+ except jwt.ExpiredSignatureError:
76
+ return jsonify({"error": "Token expired. Please log in again."}), 401
77
+ except jwt.InvalidTokenError as exc:
78
+ return jsonify({"error": f"Invalid token: {exc}"}), 401
79
+
80
+ return f(*args, **kwargs)
81
+ return decorated
82
+
83
+
84
+ def require_admin(f):
85
+ """Decorator: validates JWT AND checks for admin role."""
86
+ @wraps(f)
87
+ @require_auth
88
+ def decorated(*args, **kwargs):
89
+ if g.current_user.get("role") != "admin":
90
+ return jsonify({"error": "Admin access required."}), 403
91
+ return f(*args, **kwargs)
92
+ return decorated
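`require_admin` works by stacking on top of `require_auth`, so the token check always runs before the role check. The Flask-free sketch below illustrates that ordering. It is an illustration under stated assumptions: plain dictionaries (`request_headers`, `ctx`) stand in for Flask's `request` and `g`, and the hypothetical `TOKENS` table stands in for JWT decoding and the database lookup.

```python
from functools import wraps

# Stand-ins for Flask's request and g objects (assumption: plain dicts for illustration).
request_headers = {}
ctx = {}

# Hypothetical token table standing in for JWT decoding.
TOKENS = {
    "tok-user": {"sub": "1", "role": "user"},
    "tok-admin": {"sub": "2", "role": "admin"},
}


def require_auth(f):
    """Reject the call unless a known Bearer token is present; populate ctx."""
    @wraps(f)
    def decorated(*args, **kwargs):
        header = request_headers.get("Authorization", "")
        if not header.startswith("Bearer "):
            return {"error": "Authorization header missing or malformed."}, 401
        payload = TOKENS.get(header.split(" ", 1)[1])
        if payload is None:
            return {"error": "Invalid token."}, 401
        ctx["current_user"] = {"id": payload["sub"], "role": payload["role"]}
        return f(*args, **kwargs)
    return decorated


def require_admin(f):
    """Run the auth check first, then require the admin role."""
    @wraps(f)
    @require_auth
    def decorated(*args, **kwargs):
        if ctx["current_user"].get("role") != "admin":
            return {"error": "Admin access required."}, 403
        return f(*args, **kwargs)
    return decorated


@require_admin
def admin_stats():
    """Example protected handler."""
    return {"ok": True}, 200
```

With no header the call returns 401, with a user token 403, and with an admin token the handler's own response, which matches the order of checks in `auth.py`.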
data_loader/load_squad_json.py ADDED
@@ -0,0 +1,25 @@
1
+ import json
2
+
3
+ def load_squad_json(path):
4
+ with open(path, "r", encoding="utf-8") as f:
5
+ data = json.load(f)
6
+
7
+ samples = []
8
+
9
+ for article in data["data"]:
10
+ for para in article["paragraphs"]:
11
+ context = para["context"]
12
+
13
+ for qa in para["qas"]:
14
+ if not qa["answers"]:
15
+ continue
16
+
17
+ ans = qa["answers"][0]
18
+
19
+ samples.append({
20
+ "context": context,
21
+ "question": qa["question"],
22
+ "answer_text": ans["text"]
23
+ })
24
+
25
+ return samples
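The loader keeps only the first answer for each question and skips questions whose `answers` list is empty. The sketch below repeats the same logic and runs it on a tiny sample so the skip is visible. The sample content is invented for illustration and is not taken from the SQuAD dataset.

```python
import json
import os
import tempfile


def load_squad_json(path):
    """Same logic as the loader above: flatten articles into QA samples."""
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    samples = []
    for article in data["data"]:
        for para in article["paragraphs"]:
            for qa in para["qas"]:
                if not qa["answers"]:  # unanswerable question: skipped
                    continue
                samples.append({
                    "context": para["context"],
                    "question": qa["question"],
                    "answer_text": qa["answers"][0]["text"],
                })
    return samples


# Invented two-question sample in the SQuAD layout (illustrative data only).
SAMPLE = {"data": [{"paragraphs": [{
    "context": "Flask is a Python web framework.",
    "qas": [
        {"question": "What is Flask?",
         "answers": [{"text": "a Python web framework", "answer_start": 9}]},
        {"question": "Who wrote Django?", "answers": []},
    ],
}]}]}

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False, encoding="utf-8") as tmp:
    json.dump(SAMPLE, tmp)

samples = load_squad_json(tmp.name)  # only the answerable question survives
os.remove(tmp.name)
```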
gunicorn.conf.py ADDED
@@ -0,0 +1,10 @@
1
+ import os
2
+ # Gunicorn configuration for production deployment
3
+ port = os.environ.get("PORT", "7860")  # default matches EXPOSE 7860 in the Dockerfile
4
+ bind = f"0.0.0.0:{port}"
5
+ workers = 2 # Keep low — each worker loads BERT (~400MB RAM)
6
+ timeout = 120 # BERT inference can take a few seconds
7
+ accesslog = "-" # stdout
8
+ errorlog = "-" # stdout
9
+ loglevel = "info"
10
+ preload_app = True # Load model once, share across workers
main.py ADDED
@@ -0,0 +1,68 @@
1
+ import torch
2
+ from utils.file_loader import load_txt, load_pdf, load_docx
3
+ from models.qa_model import QAModel
4
+ from utils.vocab import encode
5
+ from utils.preprocess import tokenize
6
+
7
+
8
+ checkpoint = torch.load("qa_model.pth", map_location="cpu")
9
+ vocab = checkpoint["vocab"]
10
+
11
+ model = QAModel(len(vocab))
12
+ model.load_state_dict(checkpoint["model_state"])
13
+ model.eval()
14
+
15
+
16
+ def load_context(path):
17
+ if path.endswith(".txt"):
18
+ return load_txt(path)
19
+ elif path.endswith(".pdf"):
20
+ return load_pdf(path)
21
+ elif path.endswith(".docx"):
22
+ return load_docx(path)
23
+ else:
24
+ raise ValueError("Unsupported file format")
25
+
26
+
27
+ def extract_answer(question, context):
28
+ q_tokens = tokenize(question)
29
+ c_tokens = tokenize(context)
30
+
31
+ tokens = q_tokens + ["[SEP]"] + c_tokens
32
+ encoded = encode(tokens, vocab)
33
+
34
+ max_len = 300
35
+ if len(encoded) < max_len:
36
+ encoded += [0] * (max_len - len(encoded))
37
+ else:
38
+ encoded = encoded[:max_len]
39
+
40
+ x = torch.tensor(encoded).unsqueeze(0)
41
+
42
+ with torch.no_grad():
43
+ start_logits, end_logits = model(x)
44
+
45
+ start = torch.argmax(start_logits, dim=1).item()
46
+ end = torch.argmax(end_logits, dim=1).item()
47
+
48
+ if start > end or start >= len(tokens):
49
+ return "No answer found"
50
+
51
+ return " ".join(tokens[start:end+1])
52
+
53
+
54
+ def main():
55
+ print("===== BiLSTM QA (Fixed) =====\n")
56
+
57
+ path = input("Enter file path: ")
58
+ context = load_context(path)
59
+
60
+ question = input("Enter question: ")
61
+
62
+ answer = extract_answer(question, context)
63
+
64
+ print("\nAnswer:", answer)
65
+
66
+
67
+ if __name__ == "__main__":
68
+ main()
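The fixed-length encoding and the span check in `extract_answer` can be isolated as two small helpers. This is a sketch under the assumption that id `0` is the padding id, as in the code above; `pad_or_truncate` and `span_is_valid` are hypothetical names.

```python
def pad_or_truncate(ids, max_len=300, pad_id=0):
    """Return a new list of exactly max_len ids (right-padded or cut)."""
    if len(ids) < max_len:
        return ids + [pad_id] * (max_len - len(ids))
    return ids[:max_len]


def span_is_valid(start, end, num_tokens):
    """Mirror of the check in extract_answer: ordered, and start inside the real tokens."""
    return start <= end and start < num_tokens
```

Note that the check does not bound `end`: if the model points `end` into the padding, the slice `tokens[start:end+1]` simply stops at the last real token. A stricter variant could also require `end < num_tokens`.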
models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # models package
models/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (144 Bytes).
models/__pycache__/bert_model.cpython-314.pyc ADDED
Binary file (4.78 kB).
models/__pycache__/model2.cpython-314.pyc ADDED
Binary file (1.31 kB).
models/__pycache__/model3.cpython-314.pyc ADDED
Binary file (4.6 kB).
models/__pycache__/qa_model.cpython-314.pyc ADDED
Binary file (1.7 kB).
models/bert_model.py ADDED
@@ -0,0 +1,123 @@
1
+ """
2
+ bert_model.py — HuggingFace BERT Question Answering Model.
3
+
4
+ Model: deepset/bert-base-cased-squad2
5
+ Uses direct PyTorch inference (compatible with transformers 5.x).
6
+ """
7
+
8
+ import logging
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ _tokenizer = None
13
+ _model = None
14
+ MODEL_NAME = "deepset/bert-base-cased-squad2"
15
+
16
+
17
+ def init_bert_model():
18
+ """Load the BERT QA model. Called once at app startup."""
19
+ global _tokenizer, _model
20
+ try:
21
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
22
+ logger.info(f"[BERT] Loading model '{MODEL_NAME}' ...")
23
+ _tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
24
+ _model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
25
+ _model.eval()
26
+ logger.info("[BERT] Model loaded and ready.")
27
+ except Exception as exc:
28
+ logger.error(f"[BERT] Failed to load model: {exc}")
29
+ _tokenizer = None
30
+ _model = None
31
+
32
+
33
+ def _run_qa_inference(context: str, question: str) -> dict:
34
+ """Direct PyTorch inference — works with any transformers version."""
35
+ import torch
36
+ import torch.nn.functional as F
37
+
38
+ inputs = _tokenizer(
39
+ question, context,
40
+ return_tensors="pt",
41
+ truncation=True,
42
+ max_length=512,
43
+ )
44
+
45
+ with torch.no_grad():
46
+ outputs = _model(**inputs)
47
+
48
+ start_logits = outputs.start_logits[0]
49
+ end_logits = outputs.end_logits[0]
50
+
51
+ start_idx = int(torch.argmax(start_logits))
52
+ end_idx = int(torch.argmax(end_logits)) + 1
53
+
54
+ if end_idx <= start_idx:
55
+ end_idx = start_idx + 1
56
+
57
+ input_ids = inputs["input_ids"][0]
58
+ answer_tokens = input_ids[start_idx:end_idx]
59
+ answer = _tokenizer.decode(answer_tokens, skip_special_tokens=True).strip()
60
+
61
+ # Confidence approximation via softmax
62
+ start_prob = float(F.softmax(start_logits, dim=0)[start_idx])
63
+ end_prob = float(F.softmax(end_logits, dim=0)[end_idx - 1])
64
+ score = round((start_prob + end_prob) / 2, 4)
65
+
66
+ return {"answer": answer, "score": score}
67
+
68
+
69
+ def predict(context: str, question: str) -> dict:
70
+ """
71
+ Run QA inference.
72
+
73
+ Returns:
74
+ {
75
+ "answer": str,
76
+ "score": float (0.0–1.0),
77
+ "model": "BERT",
78
+ "model_id": "bert"
79
+ }
80
+ """
81
+ if _model is None or _tokenizer is None:
82
+ return {
83
+ "answer": "BERT model is not loaded. Please check server logs.",
84
+ "score": 0.0,
85
+ "model": "BERT",
86
+ "model_id": "bert",
87
+ "error": True,
88
+ }
89
+
90
+ if not context or not question:
91
+ return {
92
+ "answer": "Context and question must not be empty.",
93
+ "score": 0.0,
94
+ "model": "BERT",
95
+ "model_id": "bert",
96
+ "error": True,
97
+ }
98
+
99
+ try:
100
+ result = _run_qa_inference(context=context, question=question)
101
+ score = result["score"]
102
+ answer = result["answer"]
103
+
104
+ if score < 0.05 or "[CLS]" in answer or not answer:
105
+ answer = "Answer not found with sufficient confidence. Try rephrasing your question or providing more context."
106
+ score = 0.0
107
+
108
+ return {
109
+ "answer": answer,
110
+ "score": score,
111
+ "model": "BERT",
112
+ "model_id": "bert",
113
+ "error": False,
114
+ }
115
+ except Exception as exc:
116
+ logger.error(f"[BERT] Inference error: {exc}")
117
+ return {
118
+ "answer": f"Inference error: {exc}",
119
+ "score": 0.0,
120
+ "model": "BERT",
121
+ "model_id": "bert",
122
+ "error": True,
123
+ }
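The confidence value in `_run_qa_inference` is the mean of two softmax probabilities. The sketch below reproduces that arithmetic in plain Python so it can be checked by hand; `softmax` and `span_confidence` are hypothetical helpers, and the end index here is inclusive (the model code above uses an exclusive `end_idx`).

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def span_confidence(start_logits, end_logits):
    """Return (start_idx, end_idx, score) using independent argmax positions."""
    start_idx = max(range(len(start_logits)), key=start_logits.__getitem__)
    end_idx = max(range(len(end_logits)), key=end_logits.__getitem__)
    start_prob = softmax(start_logits)[start_idx]
    end_prob = softmax(end_logits)[end_idx]
    return start_idx, end_idx, round((start_prob + end_prob) / 2, 4)
```

With four equal logits on each side, both probabilities are 0.25 and so is the score. This is worth keeping in mind when reading the `score < 0.05` cut-off: on long inputs an undecided model can fall below it purely because the probability mass is spread over many positions.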
models/model2.py ADDED
@@ -0,0 +1,28 @@
1
+ """
2
+ model2.py — Placeholder for Model 2.
3
+
4
+ Replace this file with your actual model implementation.
5
+ The predict() function signature must match:
6
+ predict(context: str, question: str) -> dict
7
+ """
8
+
9
+ import logging
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def init_model2():
15
+ """Called at startup. No-op until model is integrated."""
16
+ logger.info("[Model2] Placeholder — not yet integrated.")
17
+
18
+
19
+ def predict(context: str, question: str) -> dict:
20
+ """Stub: returns a friendly 'coming soon' response."""
21
+ return {
22
+ "answer": "Model 2 is not yet integrated. Please use BERT for now.",
23
+ "score": 0.0,
24
+ "model": "Model 2",
25
+ "model_id": "model2",
26
+ "error": False,
27
+ "stub": True,
28
+ }
models/model3.py ADDED
@@ -0,0 +1,100 @@
1
+ """
2
+ model3.py — Integration for BiLSTM Model.
3
+ """
4
+
5
+ import logging
6
+ import torch
7
+ import os
8
+ from models.qa_model import QAModel
9
+
10
+ # Import vocab utilities and preprocess utilities
11
+ from utils.preprocess import tokenize
12
+ from utils.vocab import encode
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ model = None
17
+ vocab = None
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+
20
+ def init_model3():
21
+ global model, vocab
22
+ logger.info("[Model3] Initialising BiLSTM from qa_model.pth...")
23
+
24
+ # Assumes qa_model.pth is at the root of the backend directory
25
+ model_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "qa_model.pth")
26
+ if not os.path.exists(model_path):
27
+ logger.warning(f"[Model3] qa_model.pth not found at {model_path}! Model 3 inference will fail.")
28
+ return
29
+
30
+ try:
31
+ checkpoint = torch.load(model_path, map_location=device)
32
+ vocab = checkpoint["vocab"]
33
+
34
+ model = QAModel(len(vocab))
35
+ model.load_state_dict(checkpoint["model_state"])
36
+ model.to(device)
37
+ model.eval()
38
+ logger.info("[Model3] BiLSTM successfully loaded.")
39
+ except Exception as e:
40
+ logger.error(f"[Model3] Failed to load BiLSTM model: {e}")
41
+
42
+
43
+ def predict(context: str, question: str) -> dict:
44
+ """Predict using the loaded BiLSTM."""
45
+ if model is None or vocab is None:
46
+ return {
47
+ "answer": "BiLSTM model weights (qa_model.pth) not found or failed to load. Please make sure the trained model is placed in the backend folder.",
48
+ "score": 0.0,
49
+ "model": "BiLSTM",
50
+ "model_id": "model3",
51
+ "error": True,
52
+ "stub": False,
53
+ }
54
+
55
+ try:
56
+ q_tokens = tokenize(question)
57
+ c_tokens = tokenize(context)
58
+
59
+ tokens = q_tokens + ["[SEP]"] + c_tokens
60
+ encoded = encode(tokens, vocab)
61
+
62
+ max_len = 300
63
+ if len(encoded) < max_len:
64
+ encoded += [0] * (max_len - len(encoded))
65
+ else:
66
+ encoded = encoded[:max_len]
67
+
68
+ x = torch.tensor(encoded).unsqueeze(0).to(device)
69
+
70
+ with torch.no_grad():
71
+ start_logits, end_logits = model(x)
72
+
73
+ start = torch.argmax(start_logits, dim=1).item()
74
+ end = torch.argmax(end_logits, dim=1).item()
75
+
76
+ if start > end or start >= len(tokens):
77
+ answer = "No answer found"
78
+ score = 0.0
79
+ else:
80
+ answer = " ".join(tokens[start:end+1])
81
+ # Confidence approximation: mean of the start and end softmax probabilities.
82
+ score = round(float((torch.softmax(start_logits, dim=1)[0, start] + torch.softmax(end_logits, dim=1)[0, end]) / 2), 4)
83
+
84
+ return {
85
+ "answer": answer,
86
+ "score": score,
87
+ "model": "BiLSTM",
88
+ "model_id": "model3",
89
+ "error": False,
90
+ }
91
+ except Exception as e:
92
+ logger.error(f"[Model3] Inference error: {e}")
93
+ return {
94
+ "answer": "Inference error occurred.",
95
+ "score": 0.0,
96
+ "model": "BiLSTM",
97
+ "model_id": "model3",
98
+ "error": True,
99
+ "stub": False,
100
+ }
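`predict` takes the start and end argmax independently and gives up when `start > end`. One common alternative, shown here only as a sketch and not as what this commit does, is to search for the best-scoring pair with `start <= end`. The `max_answer_len` cap is an assumed parameter, not something defined in the repository.

```python
def best_span(start_logits, end_logits, max_answer_len=30):
    """Return ((start, end), score) maximising start_logit + end_logit with start <= end."""
    best = None
    best_score = float("-inf")
    for s, s_logit in enumerate(start_logits):
        # Only consider ends at or after the start, within the length cap.
        for e in range(s, min(s + max_answer_len, len(end_logits))):
            score = s_logit + end_logits[e]
            if score > best_score:
                best_score = score
                best = (s, e)
    return best, best_score
```

For `start_logits = [1.0, 5.0, 0.0]` and `end_logits = [9.0, 2.0, 3.0]` the independent argmax gives start 1 and end 0, which the current code rejects, while the constrained search still returns a usable span.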
models/qa_model.py ADDED
@@ -0,0 +1,27 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class QAModel(nn.Module):
5
+ def __init__(self, vocab_size, embed_dim=200, hidden_dim=256):
6
+ super().__init__()
7
+
8
+ self.embedding = nn.Embedding(vocab_size, embed_dim)
9
+
10
+ self.lstm = nn.LSTM(
11
+ embed_dim,
12
+ hidden_dim,
13
+ batch_first=True,
14
+ bidirectional=True
15
+ )
16
+
17
+ self.fc_start = nn.Linear(hidden_dim*2, 1)
18
+ self.fc_end = nn.Linear(hidden_dim*2, 1)
19
+
20
+ def forward(self, x):
21
+ x = self.embedding(x)
22
+ out, _ = self.lstm(x)
23
+
24
+ start = self.fc_start(out).squeeze(-1)
25
+ end = self.fc_end(out).squeeze(-1)
26
+
27
+ return start, end
qa_engine.py ADDED
@@ -0,0 +1,115 @@
1
+ """
2
+ qa_engine.py — Model router.
3
+
4
+ Routes inference requests to the correct model module based on model_id.
5
+ Initialises all models at startup.
6
+ """
7
+
8
+ import logging
9
+ from models import bert_model, model2, model3
10
+ from utils.db import settings_col
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ # ─── Registry ────────────────────────────────────────────────────────────────
15
+
16
+ MODELS = {
17
+ "bert": {
18
+ "id": "bert",
19
+ "name": "BERT",
20
+ "description": "",
21
+ "status": "ready",
22
+ "module": bert_model,
23
+ },
24
+ "model2": {
25
+ "id": "model2",
26
+ "name": "DistilBERT",
27
+ "description": "",
28
+ "status": "coming_soon",
29
+ "module": model2,
30
+ },
31
+ "model3": {
32
+ "id": "model3",
33
+ "name": "BiLSTM",
34
+ "description": "",
35
+ "status": "ready",
36
+ "module": model3,
37
+ },
38
+ }
39
+
40
+
41
+ def init_all_models():
42
+ """Initialise all models at application startup."""
43
+ logger.info("[QAEngine] Initialising models...")
44
+ bert_model.init_bert_model()
45
+ model2.init_model2()
46
+ model3.init_model3()
47
+ logger.info("[QAEngine] All models initialised.")
48
+
49
+
50
+ def get_models_info() -> list:
51
+ """Return metadata list for all models (used by /api/models endpoint)."""
52
+ try:
53
+ sys_conf = settings_col().find_one({"_id": "system_config"}) or {}
54
+ model_status_overrides = sys_conf.get("model_status", {})
55
+ except Exception:
56
+ model_status_overrides = {}
57
+
58
+ return [
59
+ {
60
+ "id": m["id"],
61
+ "name": m["name"],
62
+ "description": m["description"],
63
+ "status": model_status_overrides.get(m["id"], m["status"]),
64
+ }
65
+ for m in MODELS.values()
66
+ ]
67
+
68
+
69
+ def run_inference(model_id: str, context: str, question: str) -> dict:
70
+ """
71
+ Route a QA request to the appropriate model.
72
+
73
+ Args:
74
+ model_id: One of "bert", "model2", "model3"
75
+ context: The passage/document text
76
+ question: The question to answer
77
+
78
+ Returns:
79
+ dict with keys: answer, score, model, model_id, error
80
+ """
81
+ if model_id not in MODELS:
82
+ return {
83
+ "answer": f"Unknown model '{model_id}'. Available: {list(MODELS.keys())}",
84
+ "score": 0.0,
85
+ "model": "Unknown",
86
+ "model_id": model_id,
87
+ "error": True,
88
+ }
89
+
90
+ try:
91
+ sys_conf = settings_col().find_one({"_id": "system_config"}) or {}
92
+ if sys_conf.get("maintenance_mode", False):
93
+ return {
94
+ "answer": "System is currently under maintenance. Please try again later.",
95
+ "score": 0.0,
96
+ "model": "System",
97
+ "model_id": model_id,
98
+ "error": True
99
+ }
100
+
101
+ status_override = sys_conf.get("model_status", {}).get(model_id)
102
+ current_status = status_override if status_override else MODELS[model_id]["status"]
103
+ if current_status != "ready":
104
+ return {
105
+ "answer": "This model is currently disabled by an administrator.",
106
+ "score": 0.0,
107
+ "model": MODELS[model_id]["name"],
108
+ "model_id": model_id,
109
+ "error": True
110
+ }
111
+ except Exception:
112
+ logger.warning("[QAEngine] Could not read system config; using default model status.", exc_info=True)
113
+
114
+ module = MODELS[model_id]["module"]
115
+ return module.predict(context=context, question=question)
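The routing logic in `run_inference` boils down to a registry lookup, a status check with optional overrides, and a call to the module's `predict`. A minimal sketch with stub modules follows; `REGISTRY`, `route`, and `_echo_predict` are invented for illustration, and the database lookup is replaced by a plain `status_overrides` argument.

```python
from types import SimpleNamespace


def _echo_predict(context, question):
    # Stub model: "answers" with the first word of the context.
    return {"answer": context.split()[0], "score": 1.0, "error": False}


REGISTRY = {
    "echo": {"name": "Echo", "status": "ready",
             "module": SimpleNamespace(predict=_echo_predict)},
    "future": {"name": "Future", "status": "coming_soon",
               "module": SimpleNamespace(predict=_echo_predict)},
}


def route(model_id, context, question, status_overrides=None):
    """Dispatch to a registered model unless it is unknown or not ready."""
    overrides = status_overrides or {}
    entry = REGISTRY.get(model_id)
    if entry is None:
        return {"answer": f"Unknown model '{model_id}'.", "score": 0.0, "error": True}
    status = overrides.get(model_id) or entry["status"]
    if status != "ready":
        return {"answer": "This model is currently disabled.", "score": 0.0, "error": True}
    return entry["module"].predict(context=context, question=question)
```

An override can both disable a ready model and enable one marked `coming_soon`, which matches how `model_status` from the settings collection takes precedence over the static registry entry.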
qa_model.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d5ff35d1b92957d46df75fa375df83cf39c8998e51d4098cdb061a8b7fa7d028
3
+ size 43858657
requirements.txt ADDED
@@ -0,0 +1,16 @@
1
+ flask==3.0.3
2
+ flask-cors==4.0.1
3
+ flask-bcrypt==1.0.1
4
+ pymongo==4.7.3
5
+ dnspython==2.6.1
6
+ pyjwt==2.8.0
7
+ python-dotenv==1.0.1
8
+ transformers>=4.40.0
9
+ torch>=2.0.0
10
+ PyPDF2==3.0.1
11
+ python-docx==1.1.2
12
+ gunicorn==22.0.0
13
+ Werkzeug==3.0.3
14
+ mongomock==4.1.2
15
+ python-magic==0.4.27
16
+ flask-limiter==3.7.0
train.py ADDED
@@ -0,0 +1,100 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.data import Dataset, DataLoader
4
+
5
+ from data_loader.load_squad_json import load_squad_json
6
+ from utils.squad_preprocess import process_sample
7
+ from utils.vocab import build_vocab, encode
8
+ from models.qa_model import QAModel
9
+
10
+
11
+ class QADataset(Dataset):
12
+ def __init__(self, samples, vocab, max_len=300):
13
+ self.data = []
14
+
15
+ for s in samples:
16
+ item = process_sample(s)
17
+ if not item:
18
+ continue
19
+
20
+ tokens = item["tokens"]
21
+ encoded = encode(tokens, vocab)
22
+
23
+ if len(encoded) < max_len:
24
+ encoded += [0] * (max_len - len(encoded))
25
+ else:
26
+ encoded = encoded[:max_len]
27
+
28
+ start = item["start"]
29
+ end = item["end"]
30
+
31
+ if start >= max_len or end >= max_len:
32
+ continue
33
+
34
+ self.data.append((encoded, start, end))
35
+
36
+ def __len__(self):
37
+ return len(self.data)
38
+
39
+ def __getitem__(self, idx):
40
+ x, s, e = self.data[idx]
41
+ return torch.tensor(x), torch.tensor(s), torch.tensor(e)
42
+
43
+
44
+ def train():
45
+ print("Loading data...")
46
+ raw = load_squad_json("data/train-v2.0.json")[:30000]
47
+
48
+ print("Building vocab...")
49
+ all_tokens = []
50
+ for s in raw:
51
+ item = process_sample(s)
52
+ if item:
53
+ all_tokens += item["tokens"]
54
+
55
+ vocab = build_vocab(all_tokens)
56
+
57
+ print("Preparing dataset...")
58
+ dataset = QADataset(raw, vocab)
59
+ loader = DataLoader(dataset, batch_size=32, shuffle=True)
60
+
61
+ print("Initializing model...")
62
+ model = QAModel(len(vocab))
63
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
64
+ loss_fn = nn.CrossEntropyLoss()
65
+
66
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
67
+ model.to(device)
68
+
69
+ print("Training...\n")
70
+
71
+ for epoch in range(5):
72
+ total_loss = 0
73
+
74
+ for x, start, end in loader:
75
+ x = x.to(device)
76
+ start = start.to(device)
77
+ end = end.to(device)
78
+
79
+ pred_start, pred_end = model(x)
80
+
81
+ loss = loss_fn(pred_start, start) + loss_fn(pred_end, end)
82
+
83
+ optimizer.zero_grad()
84
+ loss.backward()
85
+ optimizer.step()
86
+
87
+ total_loss += loss.item()
88
+
89
+ print(f"Epoch {epoch+1} Loss: {total_loss:.2f}")
90
+
91
+ torch.save({
92
+ "model_state": model.state_dict(),
93
+ "vocab": vocab
94
+ }, "qa_model.pth")
95
+
96
+ print("\n✅ Model trained and saved!")
97
+
98
+
99
+ if __name__ == "__main__":
100
+ train()
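The training objective sums two cross-entropy terms, one over start positions and one over end positions. The sketch below computes that loss for a single example in plain Python; `cross_entropy` and `span_loss` are hypothetical helpers, and unlike `nn.CrossEntropyLoss` they do not average over a batch.

```python
import math


def cross_entropy(logits, target):
    """-log softmax(logits)[target], computed with the log-sum-exp trick."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum - logits[target]


def span_loss(start_logits, end_logits, start, end):
    """Sum of the start-position and end-position cross-entropy terms."""
    return cross_entropy(start_logits, start) + cross_entropy(end_logits, end)
```

For a sequence of 300 positions and an untrained model with roughly uniform logits, each term is close to `log(300)`, about 5.7, so a per-batch loss near 11.4 at the start of training would be expected. The script prints the loss summed over all batches, so that figure is multiplied by the number of batches.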
utils/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # utils package
utils/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (143 Bytes).
utils/__pycache__/db.cpython-314.pyc ADDED
Binary file (3.57 kB).
utils/__pycache__/pdf_parser.cpython-314.pyc ADDED
Binary file (3.67 kB).
utils/__pycache__/preprocess.cpython-314.pyc ADDED
Binary file (448 Bytes).
utils/__pycache__/vocab.cpython-314.pyc ADDED
Binary file (713 Bytes).
utils/db.py ADDED
@@ -0,0 +1,82 @@
1
+ """
2
+ db.py — MongoDB Atlas connection with mongomock fallback.
3
+ If MONGO_URI is not set or the connection fails, the app runs on an
4
+ in-memory mock store so development works without any database.
5
+ """
6
+
7
+ import os
8
+ import logging
9
+ from dotenv import load_dotenv
10
+
11
+ load_dotenv()
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ MONGO_URI = os.getenv("MONGO_URI") or os.getenv("MONGODB_URI") or ""
16
+ DB_NAME = "squad_qa"
17
+
18
+ _client = None
19
+ _db = None
20
+ _using_mock = False
21
+
22
+
23
+ def _connect_atlas():
24
+ """Attempt to connect to MongoDB Atlas (or local Mongo)."""
25
+ global _client, _db, _using_mock
26
+ try:
27
+ from pymongo import MongoClient
28
+ from pymongo.errors import ConnectionFailure, ConfigurationError, ServerSelectionTimeoutError
29
+
30
+ if not MONGO_URI or "username:password" in MONGO_URI:
31
+ raise ValueError("MONGO_URI not configured — falling back to mock.")
32
+
33
+ _client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=5000)  # mongodb+srv URIs enable TLS by default; certificate validation stays on
34
+ # Trigger actual connection check
35
+ _client.admin.command("ping")
36
+ _db = _client[DB_NAME]
37
+ _using_mock = False
38
+ logger.info("[DB] Connected to MongoDB Atlas successfully.")
39
+ except Exception as exc:
40
+ logger.warning(f"[DB] MongoDB connection failed: {exc}")
41
+ logger.warning("[DB] Falling back to in-memory mongomock.")
42
+ _connect_mock()
43
+
44
+
45
+ def _connect_mock():
46
+ """Fall back to mongomock (in-memory, no persistence)."""
47
+ global _client, _db, _using_mock
48
+ try:
49
+ import mongomock
50
+ _client = mongomock.MongoClient()
51
+ _db = _client[DB_NAME]
52
+ _using_mock = True
53
+ logger.warning("[DB] Running on mongomock — data will NOT persist across restarts.")
54
+ except ImportError:
55
+ logger.error("[DB] mongomock not installed. Database unavailable.")
56
+ _db = None
57
+
58
+
59
+ def get_db():
60
+ """Return the active database handle (Atlas or mock)."""
61
+ global _db
62
+ if _db is None:
63
+ _connect_atlas()
64
+ return _db
65
+
66
+
67
+ def is_using_mock():
68
+ return _using_mock
69
+
70
+
71
+ # Initialise on import
72
+ _connect_atlas()
73
+
74
+ # Convenience collection accessors
75
+ def users_col():
76
+ return get_db()["users"]
77
+
78
+ def chats_col():
79
+ return get_db()["chats"]
80
+
81
+ def settings_col():
82
+ return get_db()["settings"]
utils/file_loader.py ADDED
@@ -0,0 +1,19 @@
1
+ import PyPDF2
2
+ import docx
3
+
4
+ def load_txt(file_path):
5
+ with open(file_path, "r", encoding="utf-8") as f:
6
+ return f.read()
7
+
8
+ def load_pdf(file_path):
9
+ text = ""
10
+ with open(file_path, "rb") as f:
11
+ reader = PyPDF2.PdfReader(f)
12
+ for page in reader.pages:
13
+ if page.extract_text():
14
+ text += page.extract_text()
15
+ return text
16
+
17
+ def load_docx(file_path):
18
+ doc = docx.Document(file_path)
19
+ return "\n".join([p.text for p in doc.paragraphs])
utils/pdf_parser.py ADDED
@@ -0,0 +1,49 @@
1
+ """
2
+ pdf_parser.py — Extract plain text from PDF, DOCX, and TXT files.
3
+ """
4
+
5
+ import os
6
+ import logging
7
+ from io import BytesIO
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ PDF_MAX_PAGES = int(os.getenv("PDF_MAX_PAGES", "15"))
12
+
13
+
14
+ def extract_text_from_pdf(file_bytes: bytes) -> str:
15
+ """Extract text from a PDF byte stream (up to PDF_MAX_PAGES pages)."""
16
+ try:
17
+ import PyPDF2
18
+ reader = PyPDF2.PdfReader(BytesIO(file_bytes))
19
+ pages = reader.pages[:PDF_MAX_PAGES]
20
+ text = "\n".join(page.extract_text() or "" for page in pages)
21
+ return text.strip()
22
+ except Exception as exc:
23
+ logger.error(f"[PDF] Extraction failed: {exc}")
24
+ return ""
25
+
26
+
27
+ def extract_text_from_docx(file_bytes: bytes) -> str:
28
+ """Extract text from a DOCX byte stream."""
29
+ try:
30
+ import docx
31
+ from io import BytesIO as _BytesIO
32
+ doc = docx.Document(_BytesIO(file_bytes))
33
+ return "\n".join(para.text for para in doc.paragraphs).strip()
34
+ except Exception as exc:
35
+ logger.error(f"[DOCX] Extraction failed: {exc}")
36
+ return ""
37
+
38
+
39
+ def extract_text(file_bytes: bytes, filename: str) -> str:
40
+ """Dispatch extraction based on file extension."""
41
+ ext = os.path.splitext(filename.lower())[1]
42
+ if ext == ".pdf":
43
+ return extract_text_from_pdf(file_bytes)
44
+ elif ext == ".docx":  # legacy .doc is a different binary format that python-docx cannot read
45
+ return extract_text_from_docx(file_bytes)
46
+ elif ext == ".txt":
47
+ return file_bytes.decode("utf-8", errors="ignore").strip()
48
+ else:
49
+ raise ValueError(f"Unsupported file type: {ext}. Allowed: PDF, DOCX, TXT.")
utils/preprocess.py ADDED
@@ -0,0 +1,6 @@
1
+ import re
2
+
3
+ def tokenize(text):
4
+ text = text.lower()
5
+ text = re.sub(r"[^\w\s]", "", text)
6
+ return text.split()
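A short usage sketch of `tokenize`, copied from the file above so the block runs on its own. The input strings are illustrative.

```python
import re


def tokenize(text):
    """Lowercase, drop every character that is not a word character or whitespace, split."""
    text = text.lower()
    text = re.sub(r"[^\w\s]", "", text)
    return text.split()


tokens = tokenize("What's the capital of France?")
print(tokens)
```

Punctuation is deleted rather than replaced by a space, so `state-of-the-art` collapses into the single token `stateoftheart`. For the same reason the `[SEP]` marker has to be inserted after tokenization, as the callers do; passing it through `tokenize` would turn it into `sep`.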
utils/squad_preprocess.py ADDED
@@ -0,0 +1,26 @@
1
+ from utils.preprocess import tokenize
2
+
3
+ def process_sample(sample):
4
+ context_tokens = tokenize(sample["context"])
5
+ question_tokens = tokenize(sample["question"])
6
+ answer_tokens = tokenize(sample["answer_text"])
7
+
8
+ # 🔥 Combine question + context
9
+ tokens = question_tokens + ["[SEP]"] + context_tokens
10
+
11
+ start = -1
12
+ for i in range(len(context_tokens)):
13
+ if context_tokens[i:i+len(answer_tokens)] == answer_tokens:
14
+ start = i + len(question_tokens) + 1
15
+ break
16
+
17
+ if start == -1 or not answer_tokens:  # also skip answers that tokenize to nothing
18
+ return None
19
+
20
+ end = start + len(answer_tokens) - 1
21
+
22
+ return {
23
+ "tokens": tokens,
24
+ "start": start,
25
+ "end": end
26
+ }
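The index arithmetic in `process_sample` is easy to get wrong by one, so here is a worked sketch. `locate_answer` is a hypothetical helper that follows the same idea (first exact token match in the context, shifted by the question length plus one for `[SEP]`); the early return for an empty answer is an added guard, and the token lists are made up.

```python
def locate_answer(question_tokens, context_tokens, answer_tokens):
    """Return (start, end) indices of the answer inside question + [SEP] + context, or None."""
    if not answer_tokens:
        return None  # an empty answer would otherwise "match" at position 0
    n = len(answer_tokens)
    for i in range(len(context_tokens) - n + 1):
        if context_tokens[i:i + n] == answer_tokens:
            start = i + len(question_tokens) + 1  # +1 skips the [SEP] marker
            return start, start + n - 1
    return None


question = ["where", "is", "paris"]
context = ["paris", "is", "in", "france"]
tokens = question + ["[SEP]"] + context
span = locate_answer(question, context, ["in", "france"])
```

Here the answer starts at context index 2, the question has 3 tokens, and `[SEP]` adds one more, so the span in the combined sequence is `(6, 7)` and `tokens[6:8]` gives back the answer tokens.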
utils/vocab.py ADDED
@@ -0,0 +1,13 @@
1
+ from collections import Counter
2
+
3
+ def build_vocab(tokens):
4
+ vocab = {"<PAD>":0, "<UNK>":1}
5
+ counter = Counter(tokens)
6
+
7
+ for word in counter:
8
+ vocab[word] = len(vocab)
9
+
10
+ return vocab
11
+
12
+ def encode(tokens, vocab):
13
+ return [vocab.get(t, vocab["<UNK>"]) for t in tokens]
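A usage sketch for the two vocabulary helpers, restated here so the block is self-contained. The token lists are illustrative.

```python
from collections import Counter


def build_vocab(tokens):
    """Assign ids in first-seen order, after the reserved <PAD>=0 and <UNK>=1."""
    vocab = {"<PAD>": 0, "<UNK>": 1}
    for word in Counter(tokens):
        vocab[word] = len(vocab)
    return vocab


def encode(tokens, vocab):
    """Map tokens to ids, falling back to <UNK> for unseen words."""
    return [vocab.get(t, vocab["<UNK>"]) for t in tokens]


vocab = build_vocab(["the", "cat", "saw", "the", "dog"])
ids = encode(["the", "bird"], vocab)
```

The `Counter` is only used to deduplicate; its counts are never consulted, so every token seen during training gets an id regardless of how rare it is. A frequency cut-off would be one way to shrink the embedding table, but that is a design option, not something this commit implements.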