| """ |
| app.py — Main Flask application for the SQuAD QA System. |
| |
| Endpoints: |
| Public: |
| GET / |
| POST /api/auth/register |
| POST /api/auth/verify |
| POST /api/auth/resend-otp |
| POST /api/auth/login |
| GET /api/health |
| |
| Authenticated (any user): |
| GET /api/auth/me |
| GET /api/models |
| POST /api/ask |
| GET /api/history |
| DELETE /api/history/<chat_id> |
| DELETE /api/history |
| |
| Admin only: |
| GET /api/admin/users |
| PUT /api/admin/users/<user_id> |
| DELETE /api/admin/users/<user_id> |
| GET /api/admin/stats |
| GET /api/admin/settings |
| PUT /api/admin/settings |
| PUT /api/admin/models/<model_id> |
| """ |
|
|
| import os |
| import sys |
| import logging |
| import re |
| from datetime import datetime, timezone, timedelta |
|
|
| from flask import Flask, request, jsonify, g |
| from flask_cors import CORS |
| from flask_bcrypt import Bcrypt |
| from flask_limiter import Limiter |
| from flask_limiter.util import get_remote_address |
| from bson import ObjectId |
| from dotenv import load_dotenv |
|
|
| |
# Load environment overrides (python-dotenv) before the os.getenv() calls
# further down are evaluated.
load_dotenv()

# Root logging: INFO and above, written to stdout with one shared line format.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)

# Flask application and the extensions bound to it.
app = Flask(__name__)
app.config["MAX_CONTENT_LENGTH"] = 5 * 1024 * 1024  # 5 * 1024 * 1024 bytes
bcrypt = Bcrypt(app)
limiter = Limiter(
    get_remote_address,
    app=app,
    storage_uri="memory://",
    default_limits=["1000 per day", "100 per hour"],
)

# CORS: ALLOWED_ORIGINS is a comma-separated list; surrounding whitespace is
# trimmed and empty entries are dropped.
raw_origins = os.getenv("ALLOWED_ORIGINS", "http://localhost:5173,http://localhost:3000")
allowed_origins = list(filter(None, map(str.strip, raw_origins.split(","))))
CORS(app, supports_credentials=True, origins=allowed_origins)
|
|
| |
| from auth import generate_token, require_auth, require_admin |
| from utils.db import users_col, chats_col, settings_col, is_using_mock |
| from utils.pdf_parser import extract_text |
| import qa_engine |
|
|
@app.route("/")
def index():
    """Return a static JSON banner describing the API."""
    banner = {
        "message": "SQuAD QA System API is Live",
        "status": "success",
        "version": "1.1.0",
        "models_loaded": ["BERT", "BiLSTM"],
        "docs": "https://github.com/tnp554/squad-qa-system",
    }
    return jsonify(banner)
|
|
| |
|
|
def _serialize(doc: dict) -> dict:
    """Return a shallow copy of *doc* with ``_id`` exposed as a string ``id``.

    ``None`` is passed through unchanged. The input mapping is never mutated;
    only the ``_id`` key is converted, every other value is copied as-is.
    """
    if doc is None:
        return None
    serialized = dict(doc)
    try:
        raw_id = serialized.pop("_id")
    except KeyError:
        # Nothing to rename: the document carries no ``_id``.
        return serialized
    serialized["id"] = str(raw_id)
    return serialized
|
|
|
|
def _now_iso() -> str:
    """Return the current UTC time as a timezone-aware ISO 8601 string."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
|
|
def _future_iso(seconds: int) -> str:
    """Return the UTC time *seconds* from now as an ISO 8601 string."""
    offset = timedelta(seconds=seconds)
    moment = datetime.now(tz=timezone.utc) + offset
    return moment.isoformat()
|
|
def safe_str(val) -> str:
    """Return *val* stripped of surrounding whitespace, or ``""`` for non-strings.

    Anything that is not a ``str`` (dicts, lists, numbers, ``None``) collapses
    to the empty string, so operator dicts can never reach a database query.
    """
    return val.strip() if isinstance(val, str) else ""
|
|
def send_otp_email(to_email, otp):
    """Send the registration OTP to *to_email* over Gmail SMTP.

    When ``EMAIL_USER`` or ``EMAIL_PASS`` is missing from the environment no
    mail is sent; the code is written to the log instead.

    Args:
        to_email: Recipient address.
        otp: The verification code to include in the message body.

    Returns:
        ``True`` if the message was handed to the SMTP server, ``False`` if
        credentials are not configured or sending failed. Never raises.
    """
    email_user = os.getenv("EMAIL_USER")
    email_pass = os.getenv("EMAIL_PASS")
    if not email_user or not email_pass:
        # No SMTP credentials: surface the code in the log so the flow can
        # still be completed.
        logger.warning("=" * 60)
        logger.warning(f" [MOCK EMAIL OTP] Verification code for {to_email}: {otp}")
        logger.warning("=" * 60)
        return False

    try:
        import smtplib
        from email.mime.text import MIMEText
        from email.mime.multipart import MIMEMultipart

        msg = MIMEMultipart()
        msg['From'] = email_user
        msg['To'] = to_email
        msg['Subject'] = "SQuAD QA - Your Verification Code"
        body = f"Welcome to SQuAD QA!!!\n\nYour 6-digit registration verification code is: {otp}\n\nPlease enter this code to complete your registration.\n\nThank you!!!"
        msg.attach(MIMEText(body, 'plain'))

        # The context manager closes the connection even when login or send
        # fails; the timeout keeps a stalled SMTP server from blocking the
        # calling request indefinitely.
        with smtplib.SMTP_SSL('smtp.gmail.com', 465, timeout=15) as server:
            server.login(email_user, email_pass)
            server.send_message(msg)
        logger.info(f"[SMTP] Successfully dispatched OTP to {to_email}")
        return True
    except Exception as e:
        # Best-effort delivery: report the failure to the caller via False.
        logger.error(f"[SMTP ERROR] Failed to send actual email to {to_email}: {e}")
        return False
|
|
|
|
| |
|
|
def _seed_admin():
    """Create the default admin user if it doesn't exist.

    The address comes from ``ADMIN_EMAIL`` and the password from
    ``ADMIN_PASSWORD``; both fall back to built-in defaults. The email is
    stored trimmed and lower-cased because the login endpoint lower-cases the
    submitted address before looking it up, so a mixed-case value would never
    match.
    """
    admin_email = os.getenv("ADMIN_EMAIL", "admin@squad.ai").strip().lower()
    admin_password = os.getenv("ADMIN_PASSWORD")

    col = users_col()
    if col.find_one({"email": admin_email}):
        logger.info(f"[Seed] Admin user '{admin_email}' already exists.")
        return

    if admin_password is None:
        # Falling back to a password that is published in the source code.
        admin_password = "Admin@123"
        logger.warning(
            "[Seed] ADMIN_PASSWORD is not set; creating the admin account with "
            "the built-in default password. Set ADMIN_PASSWORD and rotate it."
        )

    hashed = bcrypt.generate_password_hash(admin_password).decode("utf-8")
    col.insert_one({
        "name": "Administrator",
        "email": admin_email,
        "password": hashed,
        "role": "admin",
        "is_active": True,
        "created_at": _now_iso(),
        "last_login": None,
    })
    logger.info(f"[Seed] Admin user '{admin_email}' created.")
|
|
|
|
| |
|
|
@app.route("/api/health", methods=["GET"])
def health():
    """Report liveness, the active database mode and the current UTC time."""
    if is_using_mock():
        db_mode = "mock"
    else:
        db_mode = "atlas"
    payload = {
        "status": "ok",
        "db_mode": db_mode,
        "timestamp": _now_iso(),
    }
    return jsonify(payload)
|
|
|
|
| |
|
|
@app.route("/api/auth/register", methods=["POST"])
@limiter.limit("10 per hour")
def register():
    """Create an unverified account and send a 6-digit OTP to its email.

    Expects a JSON object with ``name``, ``email`` and ``password``.

    Returns:
        201 with ``requires_otp`` on success; 400 for missing fields or a weak
        password; 403 when registrations are disabled in the system settings;
        409 when the email is already registered.
    """
    # Local import, matching the file's style; ``secrets`` gives an
    # unpredictable code, which ``random`` does not.
    import secrets

    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # A JSON array/scalar body has no fields; treat it as empty.
        data = {}
    name = safe_str(data.get("name"))
    email = safe_str(data.get("email")).lower()
    password = safe_str(data.get("password"))

    if not name or not email or not password:
        return jsonify({"error": "Name, email, and password are required."}), 400

    password_regex = r"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[@$!%*?&#^])[A-Za-z\d@$!%*?&#^]{8,}$"
    if not re.match(password_regex, password):
        return jsonify({"error": "Password must be at least 8 characters and include uppercase, lowercase, number, and a special character."}), 400

    col = users_col()

    sys_conf = settings_col().find_one({"_id": "system_config"}) or {}
    if sys_conf.get("disable_registrations", False):
        return jsonify({"error": "New user registrations are currently disabled by the administrator."}), 403

    if col.find_one({"email": email}):
        return jsonify({"error": "An account with this email already exists."}), 409

    hashed = bcrypt.generate_password_hash(password).decode("utf-8")
    # Uniform over 100000..999999, the same range as before.
    otp = str(secrets.randbelow(900000) + 100000)

    col.insert_one({
        "name": name,
        "email": email,
        "password": hashed,
        "role": "user",
        "is_active": False,
        "is_verified": False,
        "otp": otp,
        "otp_expires_at": _future_iso(60),
        "created_at": _now_iso(),
        "last_login": None,
    })

    # Mail only after the record exists, so a code is never delivered for an
    # account that failed to persist.
    send_otp_email(email, otp)

    return jsonify({
        "message": "OTP sent to email. Please verify your account.",
        "requires_otp": True
    }), 201
|
|
@app.route("/api/auth/verify", methods=["POST"])
@limiter.limit("5 per minute")
def verify_otp():
    """Verify a pending account with its emailed OTP and log the user in.

    Expects a JSON object with ``email`` and ``otp``.

    Returns:
        200 with a token and the user summary on success; 400 for missing
        input, an already verified account, an expired code or a wrong code;
        404 when no account matches the email.
    """
    import hmac

    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # A JSON array/scalar body has no fields; treat it as empty.
        data = {}
    email = safe_str(data.get("email")).lower()
    otp = safe_str(data.get("otp"))

    if not email or not otp:
        return jsonify({"error": "Email and OTP are required."}), 400

    col = users_col()
    user = col.find_one({"email": email})

    if not user:
        return jsonify({"error": "User not found."}), 404

    if user.get("is_verified", False):
        return jsonify({"error": "Account already verified."}), 400

    # A record with no pending code (e.g. the seeded admin, which has neither
    # ``otp`` nor ``is_verified``) must never verify. Comparing str(None)
    # against user input would accept the literal text "None" and hand out a
    # token for that account.
    stored_otp = user.get("otp")
    if stored_otp is None or stored_otp == "":
        return jsonify({"error": "Invalid verification code."}), 400

    # Both timestamps are UTC ISO 8601 strings produced by the same helpers,
    # so comparing them as text orders them chronologically.
    expires_at = user.get("otp_expires_at")
    if expires_at and _now_iso() > expires_at:
        return jsonify({"error": "OTP has expired. Please request a new one."}), 400

    # Constant-time comparison; bytes are used because compare_digest rejects
    # non-ASCII str arguments.
    if not hmac.compare_digest(str(stored_otp).encode("utf-8"), otp.encode("utf-8")):
        return jsonify({"error": "Invalid verification code."}), 400

    # One write activates the account, burns the code and stamps the login.
    col.update_one(
        {"_id": user["_id"]},
        {"$set": {
            "is_verified": True,
            "is_active": True,
            "otp": None,
            "last_login": _now_iso(),
        }},
    )

    user_id = str(user["_id"])
    role = user.get("role", "user")
    token = generate_token(user_id, role)

    return jsonify({
        "message": "Account verified successfully.",
        "token": token,
        "user": {"id": user_id, "name": user["name"], "email": user["email"], "role": role},
    }), 200
|
|
@app.route("/api/auth/resend-otp", methods=["POST"])
@limiter.limit("3 per minute")
def resend_otp():
    """Issue a fresh OTP for an unverified account and email it.

    Expects a JSON object with ``email``.

    Returns:
        200 once the new code is stored and dispatched; 400 when the email is
        missing or the account is already verified; 404 for an unknown email.
    """
    # ``secrets`` gives an unpredictable code, which ``random`` does not.
    import secrets

    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # A JSON array/scalar body has no fields; treat it as empty.
        data = {}
    email = safe_str(data.get("email")).lower()

    if not email:
        return jsonify({"error": "Email is required."}), 400

    col = users_col()
    user = col.find_one({"email": email})

    if not user:
        return jsonify({"error": "User not found."}), 404

    if user.get("is_verified", False):
        return jsonify({"error": "Account is already verified."}), 400

    # Uniform over 100000..999999, the same range as before.
    new_otp = str(secrets.randbelow(900000) + 100000)

    # Replacing the stored code invalidates the previous one.
    col.update_one(
        {"_id": user["_id"]},
        {"$set": {"otp": new_otp, "otp_expires_at": _future_iso(60)}},
    )

    send_otp_email(email, new_otp)

    return jsonify({"message": "A new OTP has been sent to your email."}), 200
|
|
|
|
@app.route("/api/auth/login", methods=["POST"])
@limiter.limit("15 per minute")
def login():
    """Authenticate by email and password and return a token.

    Responds 400 when either field is missing, 401 for an unknown email or a
    wrong password, and 403 for an unverified or deactivated account. On
    success the user's ``last_login`` is updated.
    """
    payload = request.get_json(silent=True) or {}
    email = safe_str(payload.get("email")).lower()
    password = safe_str(payload.get("password"))

    if not (email and password):
        return jsonify({"error": "Email and password are required."}), 400

    accounts = users_col()
    user = accounts.find_one({"email": email})

    # Same response for "no such user" and "wrong password".
    credentials_ok = bool(user) and bcrypt.check_password_hash(user["password"], password)
    if not credentials_ok:
        return jsonify({"error": "Invalid email or password."}), 401

    # Records without these flags are treated as verified and active.
    if not user.get("is_verified", True):
        return jsonify({"error": "Your account is not verified. Please check your email for the OTP."}), 403
    if not user.get("is_active", True):
        return jsonify({"error": "Your account has been deactivated. Contact admin."}), 403

    user_id = str(user["_id"])
    role = user.get("role", "user")
    token = generate_token(user_id, role)

    accounts.update_one({"_id": user["_id"]}, {"$set": {"last_login": _now_iso()}})

    summary = {
        "id": user_id,
        "name": user["name"],
        "email": user["email"],
        "role": role,
    }
    return jsonify({
        "message": "Login successful.",
        "token": token,
        "user": summary,
    })
|
|
|
|
@app.route("/api/auth/me", methods=["GET"])
@require_auth
def me():
    """Return the stored profile of the caller, without the password hash.

    The caller is identified by ``g.current_user["id"]``. Responds 404 when no
    matching user document exists.
    """
    accounts = users_col()
    current_id = g.current_user["id"]
    try:
        record = accounts.find_one({"_id": ObjectId(current_id)})
    except Exception:
        # Fall back to looking the id up as-is when the ObjectId lookup fails.
        record = accounts.find_one({"_id": current_id})

    if not record:
        return jsonify({"error": "User not found."}), 404

    profile = {key: value for key, value in _serialize(record).items() if key != "password"}
    return jsonify({"user": profile})
|
|
|
|
| |
|
|
@app.route("/api/models", methods=["GET"])
@require_auth
def get_models():
    """List the QA models together with usage statistics from chat history.

    Each entry from ``qa_engine.get_models_info()`` gains ``avg_score`` and
    ``query_count``. Statistics are aggregated only over models whose status
    is ``"ready"`` and over chats not flagged as errors. The response also
    carries the query-weighted ``global_avg`` and ``total_queries``. If the
    aggregation fails, all statistics fall back to zero.
    """
    models_info = qa_engine.get_models_info()

    ready_ids = [m["id"] for m in models_info if m.get("status") == "ready"]
    pipeline = [
        {"$match": {"model_id": {"$in": ready_ids}, "error": False}},
        {"$group": {"_id": "$model_id", "avg_score": {"$avg": "$score"}, "count": {"$sum": 1}}}
    ]

    stats = {}
    total_queries = 0
    global_avg = 0
    try:
        for doc in chats_col().aggregate(pipeline):
            stats[doc["_id"]] = {
                # $avg yields null when a group has no numeric scores; count it
                # as 0.0 instead of letting the arithmetic below raise.
                "avg_score": doc.get("avg_score") or 0.0,
                "count": doc.get("count", 0),
            }
        total_queries = sum(entry["count"] for entry in stats.values())
        if total_queries > 0:
            weighted = sum(entry["avg_score"] * entry["count"] for entry in stats.values())
            global_avg = weighted / total_queries
    except Exception:
        # Statistics are best-effort, but the failure should be visible.
        logger.exception("[Models] Failed to aggregate model statistics.")
        stats = {}
        total_queries = 0
        global_avg = 0

    for m in models_info:
        model_stat = stats.get(m["id"], {})
        m["avg_score"] = model_stat.get("avg_score", 0.0)
        m["query_count"] = model_stat.get("count", 0)

    return jsonify({
        "models": models_info,
        "global_avg": global_avg,
        "total_queries": total_queries
    })
|
|
|
|
| |
|
|
@app.route("/api/ask", methods=["POST"])
@require_auth
@limiter.limit("30 per minute")
def ask():
    """Answer a question against a context and record the exchange.

    Accepts either ``multipart/form-data`` (fields ``model_id``, ``question``
    and an uploaded ``file`` or a ``context`` text field) or a JSON object
    with ``model_id``, ``context`` and ``question``. ``model_id`` defaults to
    ``"bert"``.

    Returns:
        The inference result extended with ``chat_id``; 400 when the upload
        type is rejected, text extraction raises ``ValueError``, or the
        context/question is empty.
    """
    allowed_mimes = (
        "application/pdf",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "text/plain",
    )

    if request.content_type and "multipart/form-data" in request.content_type:
        model_id = safe_str(request.form.get("model_id")) or "bert"
        question = safe_str(request.form.get("question"))
        file = request.files.get("file")
        if file:
            try:
                import magic
                buffer = file.read()
                # Decide by sniffed content, not by the client-supplied name.
                mime = magic.from_buffer(buffer, mime=True)
                if mime not in allowed_mimes:
                    # The message now names every accepted type (plain text
                    # was accepted but not mentioned).
                    return jsonify({"error": f"Security system rejected {mime}. Only true PDF/DOCX/TXT files permitted."}), 400
                context = extract_text(buffer, file.filename)
            except ValueError as exc:
                return jsonify({"error": str(exc)}), 400
        else:
            context = safe_str(request.form.get("context"))
    else:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            # A JSON array/scalar body has no fields; treat it as empty.
            data = {}
        model_id = safe_str(data.get("model_id")) or "bert"
        context = safe_str(data.get("context"))
        question = safe_str(data.get("question"))

    if not context:
        return jsonify({"error": "Context (text or file) is required."}), 400
    if not question:
        return jsonify({"error": "Question is required."}), 400

    result = qa_engine.run_inference(model_id, context, question)

    chat_doc = {
        "user_id": g.current_user["id"],
        "model_id": model_id,
        "model_name": result.get("model", model_id),
        # Only the first 2000 characters of the context are persisted.
        "context": context[:2000],
        "question": question,
        "answer": result.get("answer", ""),
        "score": result.get("score", 0.0),
        "error": result.get("error", False),
        "created_at": _now_iso(),
    }
    insert_result = chats_col().insert_one(chat_doc)
    result["chat_id"] = str(insert_result.inserted_id)

    return jsonify(result)
|
|
|
|
| |
|
|
@app.route("/api/history", methods=["GET"])
@require_auth
def get_history():
    """Return the caller's 50 most recent chats that are not soft-deleted."""
    visible_chats = {
        "user_id": g.current_user["id"],
        "user_deleted": {"$ne": True},
    }
    cursor = chats_col().find(
        visible_chats,
        sort=[("created_at", -1)],
        limit=50,
    )
    history = [_serialize(entry) for entry in cursor]
    return jsonify({"history": history})
|
|
|
|
@app.route("/api/history/<chat_id>", methods=["DELETE"])
@require_auth
def delete_chat(chat_id):
    """Soft-delete one of the caller's chats by flagging it ``user_deleted``.

    Responds 400 when the lookup/update raises (reported as an invalid id) and
    404 when no chat with that id belongs to the caller.
    """
    chats = chats_col()
    try:
        owned_chat = {"_id": ObjectId(chat_id), "user_id": g.current_user["id"]}
        outcome = chats.update_one(owned_chat, {"$set": {"user_deleted": True}})
    except Exception:
        return jsonify({"error": "Invalid chat ID."}), 400

    if outcome.matched_count:
        return jsonify({"message": "Chat deleted."})
    return jsonify({"error": "Chat not found or not owned by you."}), 404
|
|
|
|
@app.route("/api/history", methods=["DELETE"])
@require_auth
def clear_history():
    """Soft-delete every chat of the caller and report how many changed."""
    owner_filter = {"user_id": g.current_user["id"]}
    outcome = chats_col().update_many(owner_filter, {"$set": {"user_deleted": True}})
    cleared = outcome.modified_count
    return jsonify({"message": f"Cleared {cleared} chat(s)."})
|
|
|
|
| |
|
|
@app.route("/api/admin/users", methods=["GET"])
@require_admin
def admin_list_users():
    """Return every user, newest first, with credentials removed.

    Both the password hash and any pending one-time code are stripped: the OTP
    is a login credential for unverified accounts and should not leave the
    server any more than the hash does.
    """
    secret_fields = ("password", "otp")
    col = users_col()
    result = []
    for u in col.find({}, sort=[("created_at", -1)]):
        u = _serialize(u)
        for field in secret_fields:
            u.pop(field, None)
        result.append(u)
    return jsonify({"users": result, "total": len(result)})
|
|
|
|
@app.route("/api/admin/users/<user_id>", methods=["PUT"])
@require_admin
def admin_update_user(user_id):
    """Update a user's ``name``, ``role`` and/or ``is_active``.

    Other keys in the JSON body are ignored. Values are type-checked before
    being written: ``name`` and ``role`` must be non-empty strings and
    ``is_active`` must be a JSON boolean. Without that check a string such as
    ``"false"`` would be stored and evaluate as truthy, leaving the account
    active.

    Returns:
        200 on success; 400 for an invalid value, an empty update or an
        invalid id; 404 when no user matches.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # A JSON array/scalar body has no fields; treat it as empty.
        data = {}

    update = {}
    for field in ("name", "role"):
        if field in data:
            text = safe_str(data[field])
            if not text:
                return jsonify({"error": f"Field '{field}' must be a non-empty string."}), 400
            update[field] = text
    if "is_active" in data:
        if not isinstance(data["is_active"], bool):
            return jsonify({"error": "Field 'is_active' must be true or false."}), 400
        update["is_active"] = data["is_active"]

    if not update:
        return jsonify({"error": "No valid fields to update."}), 400

    col = users_col()
    try:
        res = col.update_one({"_id": ObjectId(user_id)}, {"$set": update})
    except Exception:
        return jsonify({"error": "Invalid user ID."}), 400

    if res.matched_count == 0:
        return jsonify({"error": "User not found."}), 404
    return jsonify({"message": "User updated successfully."})
|
|
|
|
@app.route("/api/admin/users/<user_id>", methods=["DELETE"])
@require_admin
def admin_delete_user(user_id):
    """Delete a user document and flag that user's chats as deleted.

    An admin cannot delete their own account (400). Responds 400 when the
    delete raises (reported as an invalid id) and 404 when nothing was
    removed. Chats are kept but marked ``user_deleted`` and
    ``admin_deleted_user``.
    """
    if g.current_user["id"] == user_id:
        return jsonify({"error": "You cannot delete your own account."}), 400

    accounts = users_col()
    try:
        outcome = accounts.delete_one({"_id": ObjectId(user_id)})
    except Exception:
        return jsonify({"error": "Invalid user ID."}), 400

    if not outcome.deleted_count:
        return jsonify({"error": "User not found."}), 404

    tombstone = {"user_deleted": True, "admin_deleted_user": True}
    chats_col().update_many({"user_id": user_id}, {"$set": tombstone})
    return jsonify({"message": "User and their history deleted."})
|
|
|
|
@app.route("/api/admin/stats", methods=["GET"])
@require_admin
def admin_stats():
    """Return totals, per-model usage and a daily query time series.

    ``timeseries`` holds at most the 30 most recent days that have queries,
    ordered oldest to newest. Each aggregation is best-effort: if one fails it
    is logged and its part of the response is empty.
    """
    users = users_col()
    chats = chats_col()

    total_users = users.count_documents({})
    total_queries = chats.count_documents({})

    pipeline = [
        {"$group": {"_id": "$model_id", "count": {"$sum": 1}}}
    ]
    try:
        model_usage = {doc["_id"]: doc["count"] for doc in chats.aggregate(pipeline)}
    except Exception:
        logger.exception("[Stats] Failed to aggregate model usage.")
        model_usage = {}

    # created_at is an ISO 8601 string, so its first 10 characters are the
    # calendar date. Sort newest-first BEFORE limiting, otherwise the limit
    # keeps the 30 oldest days rather than the 30 latest.
    ts_pipeline = [
        {"$project": {"date": {"$substr": ["$created_at", 0, 10]}}},
        {"$group": {"_id": "$date", "queries": {"$sum": 1}}},
        {"$sort": {"_id": -1}},
        {"$limit": 30}
    ]
    try:
        timeseries = [{"date": doc["_id"], "queries": doc["queries"]} for doc in chats.aggregate(ts_pipeline)]
        # Restore chronological order for the consumer.
        timeseries.reverse()
    except Exception:
        logger.exception("[Stats] Failed to aggregate the query time series.")
        timeseries = []

    return jsonify({
        "total_users": total_users,
        "total_queries": total_queries,
        "model_usage": model_usage,
        "timeseries": timeseries,
        "db_mode": "mock" if is_using_mock() else "atlas",
    })
|
|
@app.route("/api/admin/settings", methods=["GET"])
@require_admin
def get_settings():
    """Return the ``system_config`` settings document, creating it if absent."""
    store = settings_col()
    existing = store.find_one({"_id": "system_config"})
    if existing:
        return jsonify({"settings": _serialize(existing)})

    # First access: persist the defaults so later reads find them.
    defaults = {
        "_id": "system_config",
        "disable_registrations": False,
        "maintenance_mode": False,
    }
    store.insert_one(defaults)
    return jsonify({"settings": _serialize(defaults)})
|
|
@app.route("/api/admin/settings", methods=["PUT"])
@require_admin
def update_settings():
    """Update the ``disable_registrations`` / ``maintenance_mode`` flags.

    Other keys in the JSON body are ignored. Each supplied flag must be a JSON
    boolean: the registration endpoint tests the stored value for truthiness,
    so a string such as ``"false"`` would otherwise switch the flag on.

    Returns:
        200 on success; 400 when no known flag is supplied or a value is not
        a boolean.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # A JSON array/scalar body has no fields; treat it as empty.
        data = {}
    allowed = {"disable_registrations", "maintenance_mode"}
    update = {k: v for k, v in data.items() if k in allowed}
    if not update:
        return jsonify({"error": "No valid settings provided."}), 400

    for key, value in update.items():
        if not isinstance(value, bool):
            return jsonify({"error": f"Setting '{key}' must be true or false."}), 400

    col = settings_col()
    col.update_one({"_id": "system_config"}, {"$set": update}, upsert=True)
    return jsonify({"message": "Settings updated."})
|
|
@app.route("/api/admin/models/<model_id>", methods=["PUT"])
@require_admin
def toggle_model_status(model_id):
    """Store a ``ready``/``maintenance`` status for one model in the settings.

    Responds 404 when *model_id* is not a key of ``qa_engine.MODELS`` and 400
    when the requested status is anything other than the two allowed values.
    """
    if model_id not in qa_engine.MODELS:
        return jsonify({"error": "Invalid model ID."}), 404

    payload = request.get_json(silent=True) or {}
    target_status = payload.get("status")
    if target_status not in ("ready", "maintenance"):
        return jsonify({"error": "Invalid status."}), 400

    status_field = f"model_status.{model_id}"
    settings_col().update_one(
        {"_id": "system_config"},
        {"$set": {status_field: target_status}},
        upsert=True,
    )

    return jsonify({"message": f"Model {model_id} status updated to {target_status}."})
|
|
|
|
|
|
| |
|
|
# Startup work runs at import time, so it also happens when the module is
# loaded by a WSGI server rather than executed directly.
logger.info("=" * 60)
logger.info(" SQuAD QA System — Backend Starting (Production Mode)")
logger.info("=" * 60)

# Initialise the QA models before the first request is served.
qa_engine.init_all_models()

# Make sure an admin account exists.
_seed_admin()
|
|
if __name__ == "__main__":
    # Debug mode enables the interactive Werkzeug debugger, which must not be
    # reachable on a server bound to all interfaces. It is therefore opt-in:
    # set FLASK_ENV=development to enable it.
    flask_env = os.getenv("FLASK_ENV", "production")
    debug = flask_env == "development"
    # PORT may override the listening port; 5000 remains the default.
    port = int(os.getenv("PORT", "5000"))
    app.run(host="0.0.0.0", port=port, debug=debug)
|
|