"""
app.py — Main Flask application for the SQuAD QA System.

Endpoints:
    Public:
        GET    /
        GET    /api/health
        POST   /api/auth/register
        POST   /api/auth/verify
        POST   /api/auth/resend-otp
        POST   /api/auth/login
    Authenticated (any user):
        GET    /api/auth/me
        GET    /api/models
        POST   /api/ask
        GET    /api/history
        DELETE /api/history/<chat_id>
        DELETE /api/history
    Admin only:
        GET    /api/admin/users
        PUT    /api/admin/users/<user_id>
        DELETE /api/admin/users/<user_id>
        GET    /api/admin/stats
        GET    /api/admin/settings
        PUT    /api/admin/settings
        PUT    /api/admin/models/<model_id>
"""
import os
import sys
import logging
import re
from datetime import datetime, timezone, timedelta
from flask import Flask, request, jsonify, g
from flask_cors import CORS
from flask_bcrypt import Bcrypt
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from bson import ObjectId
from dotenv import load_dotenv
# ─── Load environment ─────────────────────────────────────────────────────────
# Populate os.environ from a .env file (when one exists) before any of the
# configuration below reads it.
load_dotenv()

# ─── Logging ──────────────────────────────────────────────────────────────────
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)

# ─── App init ─────────────────────────────────────────────────────────────────
app = Flask(__name__)
bcrypt = Bcrypt(app)
# App-wide default limits, keyed by client address (get_remote_address);
# individual routes add tighter @limiter.limit rules on top.
# NOTE(review): "memory://" storage lives inside this process, so limits are
# presumably not shared between multiple workers/instances — confirm deployment.
limiter = Limiter(
    get_remote_address,
    app=app,
    storage_uri="memory://",
    default_limits=["1000 per day", "100 per hour"],
)
# Cap request bodies (uploads included) at 5 MB via Flask's MAX_CONTENT_LENGTH.
app.config["MAX_CONTENT_LENGTH"] = 5 * 1024 * 1024

# ─── CORS (reads from env for cloud safety) ──────────────────────────────────
# ALLOWED_ORIGINS is a comma-separated list; surrounding whitespace and empty
# entries are dropped.
raw_origins = os.getenv("ALLOWED_ORIGINS", "http://localhost:5173,http://localhost:3000")
allowed_origins = [origin for origin in map(str.strip, raw_origins.split(",")) if origin]
CORS(app, origins=allowed_origins, supports_credentials=True)
# ─── Internal imports (after app init) ──────────────────────────────────────
from auth import generate_token, require_auth, require_admin
from utils.db import users_col, chats_col, settings_col, is_using_mock
from utils.pdf_parser import extract_text
import qa_engine
@app.route("/")
def index():
    """Root endpoint: confirm the API is reachable and return static metadata.

    NOTE: ``version`` and ``models_loaded`` are hard-coded here; they are not
    read from ``qa_engine``.
    """
    payload = {
        "message": "SQuAD QA System API is Live",
        "status": "success",
        "version": "1.1.0",
        "models_loaded": ["BERT", "BiLSTM"],
        "docs": "https://github.com/tnp554/squad-qa-system",
    }
    return jsonify(payload)
# ─── Helpers ────────────────────────────────────────────────────────────────
def _serialize(doc: dict) -> dict:
    """Return a JSON-friendly shallow copy of a MongoDB document.

    The ``_id`` key (typically an ``ObjectId``) is replaced by a string ``id``
    key; all other fields are copied as-is. ``None`` is passed through
    unchanged, and the caller's mapping is never mutated.
    """
    if doc is None:
        return None
    out = dict(doc)
    if "_id" not in out:
        return out
    out["id"] = str(out.pop("_id"))
    return out
def _now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string (``+00:00`` offset)."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
def _future_iso(seconds: int) -> str:
    """Return the UTC instant ``seconds`` from now as an ISO-8601 string."""
    offset = timedelta(seconds=seconds)
    moment = datetime.now(tz=timezone.utc) + offset
    return moment.isoformat()
def safe_str(val) -> str:
    """Coerce an untrusted request value to a whitespace-trimmed string.

    Any non-``str`` input (``None``, numbers, lists, and in particular dicts
    such as ``{"$ne": ""}`` that could smuggle MongoDB operators into a query)
    collapses to ``""``, so callers can treat it as "missing".
    """
    return val.strip() if isinstance(val, str) else ""
def send_otp_email(to_email, otp, *, timeout=10.0):
    """Email a one-time verification code, falling back to a log line.

    Credentials come from the ``EMAIL_USER`` / ``EMAIL_PASS`` environment
    variables. The message is sent through ``smtp.gmail.com:465`` over an
    SSL connection (``smtplib.SMTP_SSL``).

    Args:
        to_email: Recipient address.
        otp: The verification code to deliver.
        timeout: Seconds to wait on the SMTP socket before giving up, so a
            stalled mail server cannot hang the calling request indefinitely.

    Returns:
        ``True`` if the message was handed to the SMTP server. ``False`` if
        credentials are missing (the code is only logged) or if sending failed
        (the error is logged).
    """
    email_user = os.getenv("EMAIL_USER")
    email_pass = os.getenv("EMAIL_PASS")
    if not email_user or not email_pass:
        # Fallback to mock logging if user hasn't put in valid app passwords yet.
        # NOTE: this writes the OTP to the log in clear text — make sure
        # EMAIL_USER/EMAIL_PASS are set anywhere logs are shared.
        logger.warning("=" * 60)
        logger.warning(f" [MOCK EMAIL OTP] Verification code for {to_email}: {otp}")
        logger.warning("=" * 60)
        return False
    try:
        import smtplib
        from email.mime.text import MIMEText
        from email.mime.multipart import MIMEMultipart

        msg = MIMEMultipart()
        msg['From'] = email_user
        msg['To'] = to_email
        msg['Subject'] = "SQuAD QA - Your Verification Code"
        body = f"Welcome to SQuAD QA!!!\n\nYour 6-digit registration verification code is: {otp}\n\nPlease enter this code to complete your registration.\n\nThank you!!!"
        msg.attach(MIMEText(body, 'plain'))
        # The with-block sends QUIT and closes the socket even when login() or
        # send_message() raises, so failed attempts no longer leak connections.
        with smtplib.SMTP_SSL('smtp.gmail.com', 465, timeout=timeout) as server:
            server.login(email_user, email_pass)
            server.send_message(msg)
        logger.info(f"[SMTP] Successfully dispatched OTP to {to_email}")
        return True
    except Exception as e:
        # Best-effort delivery: report failure to the caller instead of raising.
        logger.error(f"[SMTP ERROR] Failed to send actual email to {to_email}: {e}")
        return False
# ─── Admin Seed ─────────────────────────────────────────────────────────────
def _seed_admin():
    """Create the default admin account if no user with its email exists yet.

    The address and password come from ``ADMIN_EMAIL`` / ``ADMIN_PASSWORD``.
    If those are unset, built-in development defaults are used, and a warning
    is logged when the default password is applied. The email is stripped and
    lower-cased because ``register``/``login`` lower-case addresses; a
    mixed-case ``ADMIN_EMAIL`` would otherwise create an admin that can never
    log in.
    """
    admin_email = os.getenv("ADMIN_EMAIL", "admin@squad.ai").strip().lower()
    col = users_col()
    if col.find_one({"email": admin_email}):
        logger.info(f"[Seed] Admin user '{admin_email}' already exists.")
        return

    admin_password = os.getenv("ADMIN_PASSWORD")
    if admin_password is None:
        # Publicly known fallback credential: acceptable for local dev only.
        admin_password = "Admin@123"
        logger.warning(
            "[Seed] ADMIN_PASSWORD is not set; seeding the admin account with the "
            "built-in default password. Set ADMIN_PASSWORD before deploying."
        )

    hashed = bcrypt.generate_password_hash(admin_password).decode("utf-8")
    col.insert_one({
        "name": "Administrator",
        "email": admin_email,
        "password": hashed,
        "role": "admin",
        "is_active": True,
        # Explicitly verified: the email-OTP flow must never apply to, or be
        # usable against, the seeded admin account.
        "is_verified": True,
        "created_at": _now_iso(),
        "last_login": None,
    })
    logger.info(f"[Seed] Admin user '{admin_email}' created.")
# ─── Health ─────────────────────────────────────────────────────────────────
@app.route("/api/health", methods=["GET"])
def health():
    """Liveness probe: report status, the active DB backend and the server time."""
    db_mode = "mock" if is_using_mock() else "atlas"
    return jsonify({"status": "ok", "db_mode": db_mode, "timestamp": _now_iso()})
# ─── Auth Routes ────────────────────────────────────────────────────────────
@app.route("/api/auth/register", methods=["POST"])
@limiter.limit("10 per hour")
def register():
    """Create an unverified account and email it a one-time verification code.

    Expects a JSON body with ``name``, ``email`` and ``password``. The account
    is stored with ``is_active`` and ``is_verified`` set to ``False`` and a
    6-digit code valid for 60 seconds. The code is confirmed via
    ``POST /api/auth/verify``.

    Returns:
        201 on success.
        400 for missing fields or a weak password.
        403 when registrations are disabled in system settings.
        409 if the email is already registered.
    """
    import secrets

    data = request.get_json(silent=True) or {}
    name = safe_str(data.get("name"))
    email = safe_str(data.get("email")).lower()
    password = safe_str(data.get("password"))
    if not name or not email or not password:
        return jsonify({"error": "Name, email, and password are required."}), 400

    # At least 8 chars from [A-Za-z0-9@$!%*?&#^], including at least one
    # lowercase, one uppercase, one digit and one of those special characters.
    password_regex = r"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[@$!%*?&#^])[A-Za-z\d@$!%*?&#^]{8,}$"
    if not re.match(password_regex, password):
        return jsonify({"error": "Password must be at least 8 characters and include uppercase, lowercase, number, and a special character."}), 400

    col = users_col()
    sys_conf = settings_col().find_one({"_id": "system_config"}) or {}
    if sys_conf.get("disable_registrations", False):
        return jsonify({"error": "New user registrations are currently disabled by the administrator."}), 403
    if col.find_one({"email": email}):
        return jsonify({"error": "An account with this email already exists."}), 409

    hashed = bcrypt.generate_password_hash(password).decode("utf-8")
    # CSPRNG (secrets), not `random`: verification codes must not be
    # predictable. randbelow(900000) + 100000 keeps the 6-digit range
    # 100000-999999.
    otp = str(secrets.randbelow(900000) + 100000)
    col.insert_one({
        "name": name,
        "email": email,
        "password": hashed,
        "role": "user",
        "is_active": False,
        "is_verified": False,
        "otp": otp,
        "otp_expires_at": _future_iso(60),
        "created_at": _now_iso(),
        "last_login": None,
    })
    # Send only after the account is persisted, so a failed insert never
    # leaves a code in someone's inbox for an account that doesn't exist.
    send_otp_email(email, otp)
    return jsonify({
        "message": "OTP sent to email. Please verify your account.",
        "requires_otp": True
    }), 201
@app.route("/api/auth/verify", methods=["POST"])
@limiter.limit("5 per minute")
def verify_otp():
    """Confirm a pending account with its emailed code and log the user in.

    Expects JSON ``{"email": ..., "otp": ...}``. On success the account is
    marked verified and active, the stored code is cleared so it cannot be
    replayed, ``last_login`` is recorded, and a token is returned in the same
    shape as ``/api/auth/login``.

    Returns:
        400 for missing input, an already-verified account, or an expired or
        invalid code.
        404 for an unknown email.
    """
    import hmac

    data = request.get_json(silent=True) or {}
    email = safe_str(data.get("email")).lower()
    otp = safe_str(data.get("otp"))
    if not email or not otp:
        return jsonify({"error": "Email and OTP are required."}), 400

    col = users_col()
    user = col.find_one({"email": email})
    if not user:
        return jsonify({"error": "User not found."}), 404
    # Default True, matching login(): accounts created outside the OTP flow
    # (e.g. the seeded admin) have no is_verified field and must be treated as
    # verified, not as pending verification.
    if user.get("is_verified", True):
        return jsonify({"error": "Account already verified."}), 400

    # Stored timestamps come from _future_iso(), so the ISO strings compare
    # chronologically.
    expires_at = user.get("otp_expires_at")
    if expires_at and _now_iso() > expires_at:
        return jsonify({"error": "OTP has expired. Please request a new one."}), 400

    stored_otp = user.get("otp")
    # No pending code means nothing can match. Without this guard,
    # str(None) == "None" would let the literal input "None" verify the account.
    if stored_otp is None or stored_otp == "":
        return jsonify({"error": "Invalid verification code."}), 400
    # Constant-time comparison. Compare bytes: compare_digest rejects non-ASCII str.
    if not hmac.compare_digest(str(stored_otp).encode("utf-8"), otp.encode("utf-8")):
        return jsonify({"error": "Invalid verification code."}), 400

    user_id = str(user["_id"])
    role = user.get("role", "user")
    token = generate_token(user_id, role)
    col.update_one(
        {"_id": user["_id"]},
        {"$set": {
            "is_verified": True,
            "is_active": True,
            "otp": None,
            "otp_expires_at": None,
            "last_login": _now_iso(),
        }},
    )
    return jsonify({
        "message": "Account verified successfully.",
        "token": token,
        "user": {"id": user_id, "name": user["name"], "email": user["email"], "role": role},
    }), 200
@app.route("/api/auth/resend-otp", methods=["POST"])
@limiter.limit("3 per minute")
def resend_otp():
    """Issue a fresh 60-second verification code for a pending account.

    Expects JSON ``{"email": ...}``. The new code replaces any previous one in
    the database before it is emailed.

    Returns:
        200 on success.
        400 for a missing email or an already-verified account.
        404 for an unknown email.
    """
    import secrets

    data = request.get_json(silent=True) or {}
    email = safe_str(data.get("email")).lower()
    if not email:
        return jsonify({"error": "Email is required."}), 400
    col = users_col()
    user = col.find_one({"email": email})
    if not user:
        return jsonify({"error": "User not found."}), 404
    # Same default as login(): a missing is_verified field means the account
    # was created outside the OTP flow (e.g. the seeded admin) and is verified.
    if user.get("is_verified", True):
        return jsonify({"error": "Account is already verified."}), 400
    # CSPRNG, same 6-digit range (100000-999999) as before.
    new_otp = str(secrets.randbelow(900000) + 100000)
    # Persist first so the emailed code is the one verification will check.
    col.update_one({"_id": user["_id"]}, {"$set": {"otp": new_otp, "otp_expires_at": _future_iso(60)}})
    send_otp_email(email, new_otp)
    return jsonify({"message": "A new OTP has been sent to your email."}), 200
@app.route("/api/auth/login", methods=["POST"])
@limiter.limit("15 per minute")
def login():
    """Authenticate with email + password; return a token and a short profile.

    Accounts without ``is_verified`` / ``is_active`` fields are treated as
    verified and active. On success ``last_login`` is updated.

    Returns:
        400 for missing fields.
        401 for bad credentials.
        403 for unverified or deactivated accounts.
    """
    payload = request.get_json(silent=True) or {}
    email = safe_str(payload.get("email")).lower()
    password = safe_str(payload.get("password"))
    if not (email and password):
        return jsonify({"error": "Email and password are required."}), 400

    users = users_col()
    account = users.find_one({"email": email})
    if not account or not bcrypt.check_password_hash(account["password"], password):
        return jsonify({"error": "Invalid email or password."}), 401
    if not account.get("is_verified", True):
        # Unverified users must finish the OTP flow before logging in.
        return jsonify({"error": "Your account is not verified. Please check your email for the OTP."}), 403
    if not account.get("is_active", True):
        return jsonify({"error": "Your account has been deactivated. Contact admin."}), 403

    uid = str(account["_id"])
    role = account.get("role", "user")
    token = generate_token(uid, role)
    users.update_one({"_id": account["_id"]}, {"$set": {"last_login": _now_iso()}})

    profile = {
        "id": uid,
        "name": account["name"],
        "email": account["email"],
        "role": role,
    }
    return jsonify({"message": "Login successful.", "token": token, "user": profile})
@app.route("/api/auth/me", methods=["GET"])
@require_auth
def me():
    """Return the signed-in user's stored profile, minus the password hash.

    Responds 404 if the id from the auth context no longer matches a user.
    """
    users = users_col()
    try:
        record = users.find_one({"_id": ObjectId(g.current_user["id"])})
    except Exception:
        # The id isn't a valid ObjectId (or the first lookup failed): retry with
        # the raw value. NOTE(review): presumably for stores, such as the mock
        # DB, that key users by plain strings — confirm.
        record = users.find_one({"_id": g.current_user["id"]})
    if not record:
        return jsonify({"error": "User not found."}), 404
    profile = _serialize(record)
    profile.pop("password", None)
    return jsonify({"user": profile})
# ─── Models ─────────────────────────────────────────────────────────────────
@app.route("/api/models", methods=["GET"])
@require_auth
def get_models():
    """List the QA models with per-model and overall answer-score statistics.

    Statistics cover chat records with ``error == False`` for models whose
    status is ``"ready"``. Each model gets ``avg_score`` and ``query_count``.
    The response also carries a query-weighted ``global_avg`` and
    ``total_queries``. Statistics are best-effort: if the aggregation fails,
    the models are still returned with zeroed stats and the failure is logged.
    """
    models_info = qa_engine.get_models_info()
    ready_ids = [m["id"] for m in models_info if m.get("status") == "ready"]
    pipeline = [
        {"$match": {"model_id": {"$in": ready_ids}, "error": False}},
        {"$group": {"_id": "$model_id", "avg_score": {"$avg": "$score"}, "count": {"$sum": 1}}},
    ]
    try:
        stats = {doc["_id"]: doc for doc in chats_col().aggregate(pipeline)}
        total_queries = sum(d["count"] for d in stats.values())
        # $avg yields null when none of a model's records has a numeric score;
        # count that as 0 instead of letting None * int raise.
        total_score = sum((d["avg_score"] or 0) * d["count"] for d in stats.values())
        global_avg = (total_score / total_queries) if total_queries > 0 else 0
    except Exception:
        # Don't fail the model list over stats, but don't hide the problem either.
        logger.exception("[Models] Failed to aggregate model statistics")
        stats = {}
        global_avg = 0
        total_queries = 0

    for m in models_info:
        model_stat = stats.get(m["id"], {})
        m["avg_score"] = model_stat.get("avg_score") or 0.0
        m["query_count"] = model_stat.get("count", 0)
    return jsonify({
        "models": models_info,
        "global_avg": global_avg,
        "total_queries": total_queries
    })
# ─── Ask (QA Inference) ─────────────────────────────────────────────────────
@app.route("/api/ask", methods=["POST"])
@require_auth
@limiter.limit("30 per minute")
def ask():
    """Answer a question about a context passage and record the exchange.

    Accepted request shapes:
    - ``multipart/form-data`` with ``model_id`` and ``question``, plus either
      an uploaded ``file`` or a ``context`` text field.
    - A JSON body with ``model_id``, ``context`` and ``question``.

    ``model_id`` defaults to ``"bert"``. Uploads are type-checked from their
    content bytes with libmagic (PDF, DOCX or plain text), not from the
    filename, before text extraction.

    The result of ``qa_engine.run_inference`` is stored in the chats
    collection, with the context truncated to 2000 characters. It is returned
    with the new record's ``chat_id`` added.

    Returns 400 for a rejected or unreadable upload, or a missing
    context/question.
    """
    model_id = "bert"
    context = ""
    question = ""
    # ── File upload (multipart form) ──
    if request.content_type and "multipart/form-data" in request.content_type:
        model_id = safe_str(request.form.get("model_id")) or "bert"
        question = safe_str(request.form.get("question"))
        file = request.files.get("file")
        if file:
            try:
                import magic

                buffer = file.read()
                # Sniff the real type from the bytes; the client-supplied
                # filename and Content-Type cannot be trusted.
                mime = magic.from_buffer(buffer, mime=True)
                allowed_mimes = {
                    "application/pdf",
                    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
                    "text/plain",
                }
                if mime not in allowed_mimes:
                    # Message lists every accepted type (text/plain included).
                    return jsonify({"error": f"Security system rejected {mime}. Only PDF, DOCX, or plain-text files are permitted."}), 400
                context = extract_text(buffer, file.filename)
            except ValueError as exc:
                # NOTE(review): presumably raised by extract_text for
                # unreadable documents — confirm in utils.pdf_parser.
                return jsonify({"error": str(exc)}), 400
        else:
            context = safe_str(request.form.get("context"))
    else:
        # ── JSON body ──
        data = request.get_json(silent=True) or {}
        model_id = safe_str(data.get("model_id")) or "bert"
        context = safe_str(data.get("context"))
        question = safe_str(data.get("question"))

    if not context:
        return jsonify({"error": "Context (text or file) is required."}), 400
    if not question:
        return jsonify({"error": "Question is required."}), 400

    # ── Run inference ──
    result = qa_engine.run_inference(model_id, context, question)

    # ── Persist to DB ──
    chat_doc = {
        "user_id": g.current_user["id"],
        "model_id": model_id,
        "model_name": result.get("model", model_id),
        "context": context[:2000],  # truncate for storage
        "question": question,
        "answer": result.get("answer", ""),
        "score": result.get("score", 0.0),
        "error": result.get("error", False),
        "created_at": _now_iso(),
    }
    insert_result = chats_col().insert_one(chat_doc)
    result["chat_id"] = str(insert_result.inserted_id)
    return jsonify(result)
# ─── History ────────────────────────────────────────────────────────────────
@app.route("/api/history", methods=["GET"])
@require_auth
def get_history():
    """Return the caller's 50 most recent chats, newest first, hiding ones they deleted."""
    chats = chats_col()
    visible = {"user_id": g.current_user["id"], "user_deleted": {"$ne": True}}
    newest_first = [("created_at", -1)]
    records = list(chats.find(visible, sort=newest_first, limit=50))
    return jsonify({"history": [_serialize(record) for record in records]})
@app.route("/api/history/<chat_id>", methods=["DELETE"])
@require_auth
def delete_chat(chat_id):
    """Soft-delete one of the caller's chats by setting ``user_deleted``.

    The record itself is kept. Responds 400 if the id can't be used as an
    ObjectId (or the update fails), and 404 if no chat with that id belongs
    to the caller.
    """
    chats = chats_col()
    try:
        outcome = chats.update_one(
            {"_id": ObjectId(chat_id), "user_id": g.current_user["id"]},
            {"$set": {"user_deleted": True}},
        )
    except Exception:
        return jsonify({"error": "Invalid chat ID."}), 400
    if outcome.matched_count != 0:
        return jsonify({"message": "Chat deleted."})
    return jsonify({"error": "Chat not found or not owned by you."}), 404
@app.route("/api/history", methods=["DELETE"])
@require_auth
def clear_history():
    """Soft-delete all of the caller's chats and report how many records changed."""
    outcome = chats_col().update_many(
        {"user_id": g.current_user["id"]},
        {"$set": {"user_deleted": True}},
    )
    cleared = outcome.modified_count
    return jsonify({"message": f"Cleared {cleared} chat(s)."})
# ─── Admin Routes ───────────────────────────────────────────────────────────
@app.route("/api/admin/users", methods=["GET"])
@require_admin
def admin_list_users():
    """List every user account, newest first, with credential fields removed.

    Both the password hash and any pending verification code (``otp``) are
    stripped. A pending OTP can be exchanged at ``/api/auth/verify`` for a
    login token, so it must not appear in API responses.
    """
    hidden_fields = ("password", "otp")
    result = []
    for raw in users_col().find({}, sort=[("created_at", -1)]):
        user = _serialize(raw)
        for field in hidden_fields:
            user.pop(field, None)
        result.append(user)
    return jsonify({"users": result, "total": len(result)})
@app.route("/api/admin/users/<user_id>", methods=["PUT"])
@require_admin
def admin_update_user(user_id):
    """Update a user's ``name``, ``role`` and/or ``is_active`` flag.

    Unknown keys in the JSON body are ignored. Values are type-checked before
    they reach the database:
    - ``name`` must be a non-empty string (it is stored trimmed).
    - ``role`` must be ``"user"`` or ``"admin"``.
    - ``is_active`` must be a JSON boolean. A string such as ``"false"`` would
      otherwise be stored and read back as truthy by ``login``.

    Like the self-deletion guard on DELETE, an admin may not demote or
    deactivate their own account.

    Returns:
        400 for invalid input or an invalid user ID.
        404 if the user doesn't exist.
    """
    data = request.get_json(silent=True) or {}
    if not isinstance(data, dict):
        # e.g. a JSON array body: nothing usable, fall through to the 400 below.
        data = {}
    allowed_fields = {"name", "role", "is_active"}
    update = {k: v for k, v in data.items() if k in allowed_fields}
    if not update:
        return jsonify({"error": "No valid fields to update."}), 400

    if "name" in update:
        name = safe_str(update["name"])
        if not name:
            return jsonify({"error": "Name must be a non-empty string."}), 400
        update["name"] = name
    if "role" in update and update["role"] not in ("user", "admin"):
        return jsonify({"error": "Role must be 'user' or 'admin'."}), 400
    if "is_active" in update and not isinstance(update["is_active"], bool):
        return jsonify({"error": "is_active must be a boolean."}), 400

    if user_id == g.current_user["id"] and (
        update.get("role", "admin") != "admin" or update.get("is_active", True) is False
    ):
        return jsonify({"error": "You cannot demote or deactivate your own account."}), 400

    # Only the ID conversion is guarded; database errors are not misreported
    # as a bad ID.
    try:
        object_id = ObjectId(user_id)
    except Exception:
        return jsonify({"error": "Invalid user ID."}), 400
    res = users_col().update_one({"_id": object_id}, {"$set": update})
    if res.matched_count == 0:
        return jsonify({"error": "User not found."}), 404
    return jsonify({"message": "User updated successfully."})
@app.route("/api/admin/users/<user_id>", methods=["DELETE"])
@require_admin
def admin_delete_user(user_id):
    """Delete a user account and flag all of that user's chats as deleted.

    Chats are not removed: they get ``user_deleted`` and ``admin_deleted_user``
    set to ``True``.

    Returns:
        400 if the admin targets their own account, or the id cannot be used
        as an ObjectId (or the delete fails).
        404 if no such user exists.
    """
    # Guard against an admin deleting the account they are signed in with.
    if user_id == g.current_user["id"]:
        return jsonify({"error": "You cannot delete your own account."}), 400

    users = users_col()
    try:
        removed = users.delete_one({"_id": ObjectId(user_id)})
    except Exception:
        return jsonify({"error": "Invalid user ID."}), 400
    if removed.deleted_count == 0:
        return jsonify({"error": "User not found."}), 404

    # Chats reference their owner by the string id (as /api/ask stores it),
    # so the raw path parameter is the matching key.
    chats_col().update_many(
        {"user_id": user_id},
        {"$set": {"user_deleted": True, "admin_deleted_user": True}},
    )
    return jsonify({"message": "User and their history deleted."})
@app.route("/api/admin/stats", methods=["GET"])
@require_admin
def admin_stats():
    """Dashboard figures for admins.

    The response contains:
    - User and query totals.
    - Per-model usage counts.
    - Daily query counts (``timeseries``): the 30 most recent days that have
      any queries, in ascending date order, ready to plot.

    Totals include soft-deleted chats. Aggregation failures degrade to empty
    results and are logged.
    """
    users = users_col()
    chats = chats_col()
    total_users = users.count_documents({})
    total_queries = chats.count_documents({})

    # Model usage breakdown
    pipeline = [
        {"$group": {"_id": "$model_id", "count": {"$sum": 1}}}
    ]
    try:
        model_usage = {doc["_id"]: doc["count"] for doc in chats.aggregate(pipeline)}
    except Exception:
        logger.exception("[Stats] Model usage aggregation failed")
        model_usage = {}

    # Timeseries data for graphs. created_at is written by _now_iso(), so its
    # first 10 characters are the UTC date (YYYY-MM-DD).
    # Steps: sort newest-first BEFORE limiting so we keep the latest 30 days
    # (sorting ascending first pinned the chart to the oldest 30 days), then
    # re-sort ascending for plotting.
    ts_pipeline = [
        {"$project": {"date": {"$substr": ["$created_at", 0, 10]}}},
        {"$group": {"_id": "$date", "queries": {"$sum": 1}}},
        {"$sort": {"_id": -1}},
        {"$limit": 30},
        {"$sort": {"_id": 1}},
    ]
    try:
        timeseries = [{"date": doc["_id"], "queries": doc["queries"]} for doc in chats.aggregate(ts_pipeline)]
    except Exception:
        logger.exception("[Stats] Timeseries aggregation failed")
        timeseries = []

    return jsonify({
        "total_users": total_users,
        "total_queries": total_queries,
        "model_usage": model_usage,
        "timeseries": timeseries,
        "db_mode": "mock" if is_using_mock() else "atlas",
    })
@app.route("/api/admin/settings", methods=["GET"])
@require_admin
def get_settings():
    """Return the ``system_config`` settings document, creating it with defaults if absent.

    Both defaults (``disable_registrations``, ``maintenance_mode``) are
    ``False``. The document's ``_id`` is exposed as ``id`` via ``_serialize``.
    """
    settings = settings_col()
    config = settings.find_one({"_id": "system_config"})
    if not config:
        config = dict(_id="system_config", disable_registrations=False, maintenance_mode=False)
        settings.insert_one(config)
    return jsonify({"settings": _serialize(config)})
@app.route("/api/admin/settings", methods=["PUT"])
@require_admin
def update_settings():
    """Set ``disable_registrations`` and/or ``maintenance_mode`` in system settings.

    Unknown keys are ignored. Values must be JSON booleans: ``register`` reads
    ``disable_registrations`` as a truthy flag, so a string such as ``"false"``
    would silently disable sign-ups. The settings document is upserted.

    Returns 400 when no known setting is provided or a value is not a boolean.
    """
    data = request.get_json(silent=True) or {}
    if not isinstance(data, dict):
        # e.g. a JSON array body: nothing usable, fall through to the 400 below.
        data = {}
    allowed = {"disable_registrations", "maintenance_mode"}
    update = {k: v for k, v in data.items() if k in allowed}
    if not update:
        return jsonify({"error": "No valid settings provided."}), 400
    invalid = sorted(k for k, v in update.items() if not isinstance(v, bool))
    if invalid:
        return jsonify({"error": f"Settings must be booleans: {', '.join(invalid)}."}), 400
    settings_col().update_one({"_id": "system_config"}, {"$set": update}, upsert=True)
    return jsonify({"message": "Settings updated."})
@app.route("/api/admin/models/<model_id>", methods=["PUT"])
@require_admin
def toggle_model_status(model_id):
    """Set a model's status to ``"ready"`` or ``"maintenance"``.

    The status is stored under ``model_status.<model_id>`` in the
    ``system_config`` settings document, which is upserted. Unknown model ids
    get 404 and any other status value gets 400.

    NOTE(review): nothing in this module reads ``model_status`` back (for
    example, /api/ask doesn't check it); presumably qa_engine does — confirm.
    """
    if model_id not in qa_engine.MODELS:
        return jsonify({"error": "Invalid model ID."}), 404
    body = request.get_json(silent=True) or {}
    status = body.get("status")
    if status not in ("ready", "maintenance"):
        return jsonify({"error": "Invalid status."}), 400
    settings = settings_col()
    status_field = f"model_status.{model_id}"
    settings.update_one({"_id": "system_config"}, {"$set": {status_field: status}}, upsert=True)
    return jsonify({"message": f"Model {model_id} status updated to {status}."})
# ─── Entry Point ────────────────────────────────────────────────────────────
logger.info("=" * 60)
logger.info(" SQuAD QA System — Backend Starting (Production Mode)")
logger.info("=" * 60)
# These run at import time, not only under __main__, so they also happen when
# a WSGI server imports this module (presumably the production path — confirm).
# Initialise AI models
qa_engine.init_all_models()
# Seed admin user
_seed_admin()

if __name__ == "__main__":
    # Debug mode turns on the Werkzeug interactive debugger, which allows
    # arbitrary code execution. Since the dev server binds 0.0.0.0, debug must
    # be opt-in: it is enabled only when FLASK_ENV=development is set
    # explicitly.
    flask_env = os.getenv("FLASK_ENV", "production")
    debug = flask_env == "development"
    # PORT override for hosting platforms; defaults to the previous fixed 5000.
    port = int(os.getenv("PORT", "5000"))
    app.run(host="0.0.0.0", port=port, debug=debug)