SQuAD / app.py
tnp554's picture
feat: add landing page API info
e4e93df
"""
app.py — Main Flask application for the SQuAD QA System.

Endpoints:
    Public:
        POST   /api/auth/register
        POST   /api/auth/login
        GET    /api/health
        GET    /
        POST   /api/auth/verify
        POST   /api/auth/resend-otp
    Authenticated (any user):
        GET    /api/auth/me
        GET    /api/models
        POST   /api/ask
        GET    /api/history
        DELETE /api/history/<chat_id>
        DELETE /api/history
    Admin only:
        GET    /api/admin/users
        PUT    /api/admin/users/<user_id>
        DELETE /api/admin/users/<user_id>
        GET    /api/admin/stats
        GET    /api/admin/settings
        PUT    /api/admin/settings
        PUT    /api/admin/models/<model_id>
"""
import os
import sys
import logging
import re
from datetime import datetime, timezone, timedelta
from flask import Flask, request, jsonify, g
from flask_cors import CORS
from flask_bcrypt import Bcrypt
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from bson import ObjectId
from dotenv import load_dotenv
# ─── Load environment ─────────────────────────────────────────────────────────
# Must run before any os.getenv() call below so values from a .env file are visible.
load_dotenv()
# ─── Logging ─────────────────────────────────────────────────────────────────
# Root logger: INFO and above, timestamped, written to stdout.
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
stream=sys.stdout,
)
# Module-level logger used by the helpers and routes in this file.
logger = logging.getLogger(__name__)
# ─── App init ─────────────────────────────────────────────────────────────────
app = Flask(__name__)
# Password hashing helper bound to the app.
bcrypt = Bcrypt(app)
# Rate limiter keyed by client address. These defaults apply to every route
# that does not declare its own @limiter.limit(...).
# NOTE(review): "memory://" presumably keeps counters per process, so limits
# would not be shared between multiple workers — confirm for the deployment.
limiter = Limiter(
get_remote_address,
app=app,
default_limits=["1000 per day", "100 per hour"],
storage_uri="memory://"
)
# Upper bound on the size of an incoming request body (uploads included).
app.config['MAX_CONTENT_LENGTH'] = 5 * 1024 * 1024 # 5 MB max constraint
# ─── CORS (reads from env for cloud safety) ───────────────────────────────────
# ALLOWED_ORIGINS is a comma-separated list; the default covers two localhost
# origins (ports 5173 and 3000).
raw_origins = os.getenv("ALLOWED_ORIGINS", "http://localhost:5173,http://localhost:3000")
# Trim whitespace around each entry and drop empty ones (e.g. a trailing comma).
allowed_origins = [o.strip() for o in raw_origins.split(",") if o.strip()]
CORS(app, origins=allowed_origins, supports_credentials=True)
# ─── Internal imports (after app init) ───────────────────────────────────────
from auth import generate_token, require_auth, require_admin
from utils.db import users_col, chats_col, settings_col, is_using_mock
from utils.pdf_parser import extract_text
import qa_engine
@app.route("/")
def index():
    """Landing endpoint: return a static JSON banner describing the API."""
    banner = {
        "message": "SQuAD QA System API is Live",
        "status": "success",
        "version": "1.1.0",
        "models_loaded": ["BERT", "BiLSTM"],
        "docs": "https://github.com/tnp554/squad-qa-system",
    }
    return jsonify(banner)
# ─── Helpers ─────────────────────────────────────────────────────────────────
def _serialize(doc: dict) -> dict:
"""Convert MongoDB ObjectId fields to strings for JSON serialization."""
if doc is None:
return None
doc = dict(doc)
if "_id" in doc:
doc["id"] = str(doc.pop("_id"))
return doc
def _now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
def _future_iso(seconds: int) -> str:
return (datetime.now(timezone.utc) + timedelta(seconds=seconds)).isoformat()
def safe_str(val) -> str:
    """Return ``val`` with surrounding whitespace removed, or "" for non-strings.

    Non-string input (dicts, lists, numbers, ``None``) collapses to the empty
    string so that query-operator objects can never reach a database filter.
    """
    return val.strip() if isinstance(val, str) else ""
def send_otp_email(to_email, otp):
    """Send a registration OTP by email through Gmail SMTP.

    Credentials are read from the ``EMAIL_USER`` and ``EMAIL_PASS`` environment
    variables. If either is missing, nothing is sent and the code is written
    to the log instead (mock mode).

    Args:
        to_email: Recipient address.
        otp: Verification code placed in the message body.

    Returns:
        True when the message was handed to the SMTP server; False in mock
        mode or when sending failed. Failures are logged, never raised.
    """
    email_user = os.getenv("EMAIL_USER")
    email_pass = os.getenv("EMAIL_PASS")
    if not email_user or not email_pass:
        # Fallback to mock logging if user hasn't put in valid app passwords yet
        logger.warning("=" * 60)
        logger.warning(f" [MOCK EMAIL OTP] Verification code for {to_email}: {otp}")
        logger.warning("=" * 60)
        return False
    try:
        import smtplib
        from email.mime.text import MIMEText
        from email.mime.multipart import MIMEMultipart

        msg = MIMEMultipart()
        msg['From'] = email_user
        msg['To'] = to_email
        msg['Subject'] = "SQuAD QA - Your Verification Code"
        body = f"Welcome to SQuAD QA!!!\n\nYour 6-digit registration verification code is: {otp}\n\nPlease enter this code to complete your registration.\n\nThank you!!!"
        msg.attach(MIMEText(body, 'plain'))
        # The context manager closes the connection even when login/send raise;
        # the timeout keeps a stalled SMTP server from blocking the request.
        with smtplib.SMTP_SSL('smtp.gmail.com', 465, timeout=10) as server:
            server.login(email_user, email_pass)
            server.send_message(msg)
        logger.info(f"[SMTP] Successfully dispatched OTP to {to_email}")
        return True
    except Exception as e:
        # Best effort by design: callers continue even if the email was not sent.
        logger.error(f"[SMTP ERROR] Failed to send actual email to {to_email}: {e}")
        return False
# ─── Admin Seed ───────────────────────────────────────────────────────────────
def _seed_admin():
    """Create the default admin user if it doesn't exist.

    The email comes from ``ADMIN_EMAIL`` and the password from
    ``ADMIN_PASSWORD``; both have built-in fallbacks. The account is stored
    as active and already verified, so the OTP verification flow never
    applies to it.
    """
    admin_email = os.getenv("ADMIN_EMAIL", "admin@squad.ai")
    col = users_col()
    if col.find_one({"email": admin_email}):
        logger.info(f"[Seed] Admin user '{admin_email}' already exists.")
        return
    admin_password = os.getenv("ADMIN_PASSWORD")
    if admin_password is None:
        # The fallback is a publicly known credential; make its use visible.
        logger.warning(
            "[Seed] ADMIN_PASSWORD is not set; seeding the admin account with "
            "the built-in default password. Change it immediately."
        )
        admin_password = "Admin@123"
    hashed = bcrypt.generate_password_hash(admin_password).decode("utf-8")
    col.insert_one({
        "name": "Administrator",
        "email": admin_email,
        "password": hashed,
        "role": "admin",
        "is_active": True,
        # Explicit flag: the account must never look "pending verification".
        "is_verified": True,
        "created_at": _now_iso(),
        "last_login": None,
    })
    logger.info(f"[Seed] Admin user '{admin_email}' created.")
# ─── Health ───────────────────────────────────────────────────────────────────
@app.route("/api/health", methods=["GET"])
def health():
    """Liveness probe: report status, database mode and the current UTC time."""
    db_mode = "mock" if is_using_mock() else "atlas"
    payload = {
        "status": "ok",
        "db_mode": db_mode,
        "timestamp": _now_iso(),
    }
    return jsonify(payload)
# ─── Auth Routes ──────────────────────────────────────────────────────────────
@app.route("/api/auth/register", methods=["POST"])
@limiter.limit("10 per hour")
def register():
    """Create an unverified account and email a one-time verification code.

    Expects a JSON object with ``name``, ``email`` and ``password``.

    Returns:
        201 with ``requires_otp`` once the account is stored;
        400 for missing fields or a weak password;
        403 when registrations are disabled in the system settings;
        409 when the email is already registered.
    """
    # CSPRNG for the code; `random` is predictable and unsuitable for auth.
    import secrets

    data = request.get_json(silent=True) or {}
    if not isinstance(data, dict):
        # A JSON array/scalar body has no fields; treat it as empty.
        data = {}
    name = safe_str(data.get("name"))
    email = safe_str(data.get("email")).lower()
    password = safe_str(data.get("password"))
    if not name or not email or not password:
        return jsonify({"error": "Name, email, and password are required."}), 400
    password_regex = r"^(?=.*[a-z])(?=.*[A-Z])(?=.*\d)(?=.*[@$!%*?&#^])[A-Za-z\d@$!%*?&#^]{8,}$"
    if not re.match(password_regex, password):
        return jsonify({"error": "Password must be at least 8 characters and include uppercase, lowercase, number, and a special character."}), 400
    col = users_col()
    sys_conf = settings_col().find_one({"_id": "system_config"}) or {}
    if sys_conf.get("disable_registrations", False):
        return jsonify({"error": "New user registrations are currently disabled by the administrator."}), 403
    if col.find_one({"email": email}):
        return jsonify({"error": "An account with this email already exists."}), 409
    hashed = bcrypt.generate_password_hash(password).decode("utf-8")
    # Uniform 6-digit code in [100000, 999999].
    otp = str(secrets.randbelow(900000) + 100000)
    col.insert_one({
        "name": name,
        "email": email,
        "password": hashed,
        "role": "user",
        "is_active": False,
        "is_verified": False,
        "otp": otp,
        "otp_expires_at": _future_iso(60),
        "created_at": _now_iso(),
        "last_login": None,
    })
    # Send only after the account is stored, so an emailed code always
    # corresponds to an existing record.
    send_otp_email(email, otp)
    return jsonify({
        "message": "OTP sent to email. Please verify your account.",
        "requires_otp": True
    }), 201
@app.route("/api/auth/verify", methods=["POST"])
@limiter.limit("5 per minute")
def verify_otp():
    """Verify a pending account with its emailed OTP and sign the user in.

    Expects a JSON object with ``email`` and ``otp``.

    Returns:
        200 with a token and the user summary on success;
        400 for missing input, an already verified account, an expired code
        or a wrong code; 404 when no account matches the email.
    """
    import hmac

    data = request.get_json(silent=True) or {}
    if not isinstance(data, dict):
        # A JSON array/scalar body has no fields; treat it as empty.
        data = {}
    email = safe_str(data.get("email")).lower()
    otp = safe_str(data.get("otp"))
    if not email or not otp:
        return jsonify({"error": "Email and OTP are required."}), 400
    col = users_col()
    user = col.find_one({"email": email})
    if not user:
        return jsonify({"error": "User not found."}), 404
    # Same default as login(): a document without the flag (for example the
    # seeded admin) counts as verified and must never enter the OTP flow.
    if user.get("is_verified", True):
        return jsonify({"error": "Account already verified."}), 400
    if _otp_is_expired(user.get("otp_expires_at")):
        return jsonify({"error": "OTP has expired. Please request a new one."}), 400
    stored_otp = user.get("otp")
    # A missing code must be rejected outright: str(None) == "None" would
    # otherwise let the literal text "None" pass as a valid code.
    if stored_otp is None or not hmac.compare_digest(
        str(stored_otp).encode("utf-8"), otp.encode("utf-8")
    ):
        return jsonify({"error": "Invalid verification code."}), 400
    col.update_one(
        {"_id": user["_id"]},
        {"$set": {
            "is_verified": True,
            "is_active": True,
            "otp": None,
            "last_login": _now_iso(),
        }},
    )
    user_id = str(user["_id"])
    role = user.get("role", "user")
    token = generate_token(user_id, role)
    return jsonify({
        "message": "Account verified successfully.",
        "token": token,
        "user": {"id": user_id, "name": user["name"], "email": user["email"], "role": role},
    }), 200


def _otp_is_expired(raw_expiry) -> bool:
    """Return True when the stored ISO expiry is missing, malformed or past."""
    try:
        return datetime.now(timezone.utc) > datetime.fromisoformat(raw_expiry)
    except (TypeError, ValueError):
        # Missing, non-string, unparsable or timezone-naive: fail closed.
        return True
@app.route("/api/auth/resend-otp", methods=["POST"])
@limiter.limit("3 per minute")
def resend_otp():
    """Issue a fresh OTP for a pending account and email it.

    Expects a JSON object with ``email``.

    Returns:
        200 once the new code is stored; 400 when the email is missing or
        the account is already verified; 404 when no account matches.
    """
    # CSPRNG for the code; `random` is predictable and unsuitable for auth.
    import secrets

    data = request.get_json(silent=True) or {}
    if not isinstance(data, dict):
        # A JSON array/scalar body has no fields; treat it as empty.
        data = {}
    email = safe_str(data.get("email")).lower()
    if not email:
        return jsonify({"error": "Email is required."}), 400
    col = users_col()
    user = col.find_one({"email": email})
    if not user:
        return jsonify({"error": "User not found."}), 404
    # Same default as login(): a document without the flag counts as verified,
    # so no OTP is ever attached to it.
    if user.get("is_verified", True):
        return jsonify({"error": "Account is already verified."}), 400
    # Uniform 6-digit code in [100000, 999999].
    new_otp = str(secrets.randbelow(900000) + 100000)
    col.update_one({"_id": user["_id"]}, {"$set": {"otp": new_otp, "otp_expires_at": _future_iso(60)}})
    send_otp_email(email, new_otp)
    return jsonify({"message": "A new OTP has been sent to your email."}), 200
@app.route("/api/auth/login", methods=["POST"])
@limiter.limit("15 per minute")
def login():
    """Authenticate by email and password and return a token.

    Responds 400 when a field is missing, 401 for bad credentials and 403
    for an unverified or deactivated account.
    """
    payload = request.get_json(silent=True) or {}
    email = safe_str(payload.get("email")).lower()
    password = safe_str(payload.get("password"))
    if not (email and password):
        return jsonify({"error": "Email and password are required."}), 400

    users = users_col()
    account = users.find_one({"email": email})
    credentials_ok = bool(account) and bcrypt.check_password_hash(account["password"], password)
    if not credentials_ok:
        return jsonify({"error": "Invalid email or password."}), 401

    # Accounts without the flags are treated as verified and active.
    verified = account.get("is_verified", True)
    if not verified:
        # We can trigger verify if they try to login while unverified, but for simplicity:
        return jsonify({"error": "Your account is not verified. Please check your email for the OTP."}), 403
    active = account.get("is_active", True)
    if not active:
        return jsonify({"error": "Your account has been deactivated. Contact admin."}), 403

    user_id = str(account["_id"])
    role = account.get("role", "user")
    token = generate_token(user_id, role)
    # Record the time of this successful login.
    users.update_one({"_id": account["_id"]}, {"$set": {"last_login": _now_iso()}})

    profile = {
        "id": user_id,
        "name": account["name"],
        "email": account["email"],
        "role": role,
    }
    return jsonify({
        "message": "Login successful.",
        "token": token,
        "user": profile,
    })
@app.route("/api/auth/me", methods=["GET"])
@require_auth
def me():
    """Return the authenticated user's own record without the password hash."""
    from bson import ObjectId as ObjId

    users = users_col()
    try:
        record = users.find_one({"_id": ObjId(g.current_user["id"])})
    except Exception:
        # Retry with the raw id when the ObjectId-based lookup raises.
        record = users.find_one({"_id": g.current_user["id"]})
    if not record:
        return jsonify({"error": "User not found."}), 404

    profile = _serialize(record)
    profile.pop("password", None)
    return jsonify({"user": profile})
# ─── Models ───────────────────────────────────────────────────────────────────
@app.route("/api/models", methods=["GET"])
@require_auth
def get_models():
    """List the QA models together with usage statistics from stored chats.

    Each model entry gains ``avg_score`` and ``query_count``; the response
    also carries the count-weighted ``global_avg`` and ``total_queries``.
    If the aggregation fails, all statistics fall back to zero.
    """
    models_info = qa_engine.get_models_info()
    ready_ids = [entry["id"] for entry in models_info if entry.get("status") == "ready"]
    # Successful queries only, grouped per ready model.
    pipeline = [
        {"$match": {"model_id": {"$in": ready_ids}, "error": False}},
        {"$group": {"_id": "$model_id", "avg_score": {"$avg": "$score"}, "count": {"$sum": 1}}}
    ]
    try:
        from utils.db import chats_col
        per_model = {row["_id"]: row for row in chats_col().aggregate(pipeline)}
        query_total = sum(row["count"] for row in per_model.values())
        weighted_sum = sum(row["avg_score"] * row["count"] for row in per_model.values())
        overall_avg = (weighted_sum / query_total) if query_total > 0 else 0
    except Exception:
        per_model, overall_avg, query_total = {}, 0, 0

    for entry in models_info:
        row = per_model.get(entry["id"], {})
        entry["avg_score"] = row.get("avg_score", 0.0)
        entry["query_count"] = row.get("count", 0)

    return jsonify({
        "models": models_info,
        "global_avg": overall_avg,
        "total_queries": query_total
    })
# ─── Ask (QA Inference) ───────────────────────────────────────────────────────
@app.route("/api/ask", methods=["POST"])
@require_auth
@limiter.limit("30 per minute")
def ask():
    """Answer a question against a context and store the exchange.

    Accepts either a JSON body (``model_id``, ``context``, ``question``) or a
    multipart form where the context may be supplied as an uploaded file.
    Responds 400 when the upload is rejected or context/question is missing.
    """
    content_type = request.content_type
    if content_type and "multipart/form-data" in content_type:
        # ── File upload (multipart form) ──
        form = request.form
        model_id = safe_str(form.get("model_id")) or "bert"
        question = safe_str(form.get("question"))
        upload = request.files.get("file")
        if not upload:
            context = safe_str(form.get("context"))
        else:
            try:
                import magic
                buffer = upload.read()
                # Detect the type from the content itself, not the filename.
                mime = magic.from_buffer(buffer, mime=True)
                allowed_mimes = ["application/pdf", "application/vnd.openxmlformats-officedocument.wordprocessingml.document", "text/plain"]
                if mime not in allowed_mimes:
                    return jsonify({"error": f"Security system rejected {mime}. Only true PDF/DOCX files permitted."}), 400
                from utils.pdf_parser import extract_text
                context = extract_text(buffer, upload.filename)
            except ValueError as exc:
                return jsonify({"error": str(exc)}), 400
    else:
        # ── JSON body ──
        body = request.get_json(silent=True) or {}
        model_id = safe_str(body.get("model_id")) or "bert"
        context = safe_str(body.get("context"))
        question = safe_str(body.get("question"))

    if not context:
        return jsonify({"error": "Context (text or file) is required."}), 400
    if not question:
        return jsonify({"error": "Question is required."}), 400

    # ── Run inference ──
    result = qa_engine.run_inference(model_id, context, question)

    # ── Persist to DB ──
    record = {
        "user_id": g.current_user["id"],
        "model_id": model_id,
        "model_name": result.get("model", model_id),
        "context": context[:2000],  # truncate for storage
        "question": question,
        "answer": result.get("answer", ""),
        "score": result.get("score", 0.0),
        "error": result.get("error", False),
        "created_at": _now_iso(),
    }
    stored = chats_col().insert_one(record)
    result["chat_id"] = str(stored.inserted_id)
    return jsonify(result)
# ─── History ──────────────────────────────────────────────────────────────────
@app.route("/api/history", methods=["GET"])
@require_auth
def get_history():
    """Return the caller's 50 most recent chats that are not marked deleted."""
    chats = chats_col()
    visible = {"user_id": g.current_user["id"], "user_deleted": {"$ne": True}}
    cursor = chats.find(visible, sort=[("created_at", -1)], limit=50)
    history = [_serialize(row) for row in list(cursor)]
    return jsonify({"history": history})
@app.route("/api/history/<chat_id>", methods=["DELETE"])
@require_auth
def delete_chat(chat_id):
    """Soft-delete one of the caller's chats by flagging it ``user_deleted``."""
    from bson import ObjectId as ObjId

    chats = chats_col()
    try:
        # Scoped to the caller so another user's chat can never be touched.
        owned = {"_id": ObjId(chat_id), "user_id": g.current_user["id"]}
        outcome = chats.update_one(owned, {"$set": {"user_deleted": True}})
    except Exception:
        return jsonify({"error": "Invalid chat ID."}), 400

    if outcome.matched_count != 0:
        return jsonify({"message": "Chat deleted."})
    return jsonify({"error": "Chat not found or not owned by you."}), 404
@app.route("/api/history", methods=["DELETE"])
@require_auth
def clear_history():
    """Soft-delete every chat owned by the caller and report how many changed."""
    mine = {"user_id": g.current_user["id"]}
    outcome = chats_col().update_many(mine, {"$set": {"user_deleted": True}})
    return jsonify({"message": f"Cleared {outcome.modified_count} chat(s)."})
# ─── Admin Routes ─────────────────────────────────────────────────────────────
@app.route("/api/admin/users", methods=["GET"])
@require_admin
def admin_list_users():
    """List all users, newest first, with password hashes removed."""
    def _without_password(raw):
        view = _serialize(raw)
        view.pop("password", None)
        return view

    records = list(users_col().find({}, sort=[("created_at", -1)]))
    listing = [_without_password(raw) for raw in records]
    return jsonify({"users": listing, "total": len(listing)})
@app.route("/api/admin/users/<user_id>", methods=["PUT"])
@require_admin
def admin_update_user(user_id):
    """Update a user's ``name``, ``role`` and/or ``is_active``.

    Unknown fields in the JSON body are ignored; supplied fields are
    validated before being written.

    Returns:
        200 on success; 400 when no valid field is supplied, a value is
        invalid or the id is malformed; 404 when the user does not exist.
    """
    from bson import ObjectId as ObjId

    data = request.get_json(silent=True) or {}
    if not isinstance(data, dict):
        # A JSON array/scalar body has no fields; treat it as empty.
        data = {}
    allowed_fields = {"name", "role", "is_active"}
    update = {k: v for k, v in data.items() if k in allowed_fields}
    if not update:
        return jsonify({"error": "No valid fields to update."}), 400
    problem = _validate_user_update(update)
    if problem:
        return jsonify({"error": problem}), 400
    col = users_col()
    try:
        res = col.update_one({"_id": ObjId(user_id)}, {"$set": update})
    except Exception:
        return jsonify({"error": "Invalid user ID."}), 400
    if res.matched_count == 0:
        return jsonify({"error": "User not found."}), 404
    return jsonify({"message": "User updated successfully."})


def _validate_user_update(update: dict):
    """Check (and normalise in place) admin-supplied user fields.

    Returns an error message for the first invalid field, or None when all
    supplied fields are acceptable.
    """
    if "name" in update:
        name = update["name"]
        if not isinstance(name, str) or not name.strip():
            return "Name must be a non-empty string."
        update["name"] = name.strip()
    # The two roles this application assigns: "user" on registration and
    # "admin" for the seeded administrator.
    if "role" in update and update["role"] not in ("user", "admin"):
        return "Role must be 'user' or 'admin'."
    if "is_active" in update and not isinstance(update["is_active"], bool):
        return "is_active must be a boolean."
    return None
@app.route("/api/admin/users/<user_id>", methods=["DELETE"])
@require_admin
def admin_delete_user(user_id):
    """Delete a user account and flag that user's chats as deleted.

    An admin cannot delete their own account.
    """
    from bson import ObjectId as ObjId

    # Prevent self-deletion
    if g.current_user["id"] == user_id:
        return jsonify({"error": "You cannot delete your own account."}), 400

    users = users_col()
    try:
        outcome = users.delete_one({"_id": ObjId(user_id)})
    except Exception:
        return jsonify({"error": "Invalid user ID."}), 400
    if outcome.deleted_count == 0:
        return jsonify({"error": "User not found."}), 404

    # Also logically remove their chat history
    soft_delete = {"$set": {"user_deleted": True, "admin_deleted_user": True}}
    chats_col().update_many({"user_id": user_id}, soft_delete)
    return jsonify({"message": "User and their history deleted."})
@app.route("/api/admin/stats", methods=["GET"])
@require_admin
def admin_stats():
    """Return dashboard statistics for administrators.

    The response holds total user and query counts, a per-model usage
    breakdown and a daily query timeseries covering the 30 most recent days
    that have data, in chronological order. A failing aggregation degrades
    to an empty value instead of failing the request.
    """
    users = users_col()
    chats = chats_col()
    total_users = users.count_documents({})
    total_queries = chats.count_documents({})
    # Model usage breakdown
    pipeline = [
        {"$group": {"_id": "$model_id", "count": {"$sum": 1}}}
    ]
    try:
        model_usage = {doc["_id"]: doc["count"] for doc in chats.aggregate(pipeline)}
    except Exception:
        model_usage = {}
    # Timeseries data for graphs: the first 10 characters of the ISO
    # timestamp are the calendar date (YYYY-MM-DD).
    ts_pipeline = [
        {"$project": {"date": {"$substr": ["$created_at", 0, 10]}}},
        {"$group": {"_id": "$date", "queries": {"$sum": 1}}},
        # Newest first so the limit keeps the latest 30 days, not the oldest.
        {"$sort": {"_id": -1}},
        {"$limit": 30}
    ]
    try:
        timeseries = [{"date": doc["_id"], "queries": doc["queries"]} for doc in chats.aggregate(ts_pipeline)]
        # Back to chronological order for plotting.
        timeseries.reverse()
    except Exception:
        timeseries = []
    return jsonify({
        "total_users": total_users,
        "total_queries": total_queries,
        "model_usage": model_usage,
        "timeseries": timeseries,
        "db_mode": "mock" if is_using_mock() else "atlas",
    })
@app.route("/api/admin/settings", methods=["GET"])
@require_admin
def get_settings():
    """Return the system configuration, creating the default one if absent."""
    settings = settings_col()
    config = settings.find_one({"_id": "system_config"})
    if config:
        return jsonify({"settings": _serialize(config)})

    # First access: persist the defaults so later reads find them.
    config = {"_id": "system_config", "disable_registrations": False, "maintenance_mode": False}
    settings.insert_one(config)
    return jsonify({"settings": _serialize(config)})
@app.route("/api/admin/settings", methods=["PUT"])
@require_admin
def update_settings():
    """Update the boolean system flags in the ``system_config`` document.

    Accepts ``disable_registrations`` and/or ``maintenance_mode``; other keys
    in the JSON body are ignored.

    Returns:
        200 on success; 400 when no known setting is supplied or a value is
        not a boolean.
    """
    data = request.get_json(silent=True) or {}
    if not isinstance(data, dict):
        # A JSON array/scalar body has no fields; treat it as empty.
        data = {}
    allowed = {"disable_registrations", "maintenance_mode"}
    update = {k: v for k, v in data.items() if k in allowed}
    if not update:
        return jsonify({"error": "No valid settings provided."}), 400
    # The flags are read by truthiness elsewhere, so a string such as "false"
    # would silently enable them; accept real booleans only.
    if not all(isinstance(v, bool) for v in update.values()):
        return jsonify({"error": "Setting values must be booleans."}), 400
    col = settings_col()
    col.update_one({"_id": "system_config"}, {"$set": update}, upsert=True)
    return jsonify({"message": "Settings updated."})
@app.route("/api/admin/models/<model_id>", methods=["PUT"])
@require_admin
def toggle_model_status(model_id):
    """Store a model's status ("ready" or "maintenance") in the system config."""
    known = model_id in qa_engine.MODELS
    if not known:
        return jsonify({"error": "Invalid model ID."}), 404

    body = request.get_json(silent=True) or {}
    target_status = body.get("status")
    if target_status not in ("ready", "maintenance"):
        return jsonify({"error": "Invalid status."}), 400

    # Dotted key: only this model's entry under model_status is written.
    change = {"$set": {f"model_status.{model_id}": target_status}}
    settings_col().update_one({"_id": "system_config"}, change, upsert=True)
    return jsonify({"message": f"Model {model_id} status updated to {target_status}."})
# ─── Entry Point ──────────────────────────────────────────────────────────────
# Startup runs at import time so that a WSGI server importing this module gets
# the same initialisation as running the file directly.
logger.info("=" * 60)
logger.info(" SQuAD QA System — Backend Starting (Production Mode)")
logger.info("=" * 60)
# Initialise AI models
qa_engine.init_all_models()
# Seed admin user
_seed_admin()
if __name__ == "__main__":
    # NOTE(review): FLASK_ENV defaults to "development", which turns on the
    # debugger while listening on all interfaces; set FLASK_ENV explicitly
    # on any reachable host.
    flask_env = os.getenv("FLASK_ENV", "development")
    debug = flask_env == "development"
    app.run(host="0.0.0.0", port=5000, debug=debug)