# s3shastra/main.py
# S3Shastra backend — FastAPI app + scanners + ML models.
import asyncio
import json
import logging
import os
import sys
import time
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from pathlib import Path
from types import SimpleNamespace
from typing import List, Optional

from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import JSONResponse
from fastapi.websockets import WebSocketState
from motor.motor_asyncio import AsyncIOMotorClient
from pydantic import BaseModel

import scanner
import dlp_scanner
import deep_scanner
import auth
# ── Structured Logging ──────────────────────────────────────────
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
_LOG_DATEFMT = "%Y-%m-%d %H:%M:%S"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATEFMT)
logger = logging.getLogger("s3shastra")

# On Windows, switch to the Proactor event loop: the default select()-based
# loop has a hard file-descriptor ceiling and fails with "too many file
# descriptors" under high connection counts.
if sys.platform == "win32":
    asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
# ── Lifespan (startup/shutdown) ─────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: open the MongoDB client on startup, close it on shutdown.

    The shutdown call sits in a ``finally`` block so the client is closed even
    when the serving loop exits with an error (the original version skipped
    cleanup on abnormal exit).
    """
    await _startup_db_client()
    try:
        yield
    finally:
        await _shutdown_db_client()
app = FastAPI(
    title="S3Shastra API",
    version="1.3.0",
    lifespan=lifespan,
)
# ── Middleware Stack (order matters: outermost runs first) ──────
# 1. GZip compression for all responses > 500 bytes
app.add_middleware(GZipMiddleware, minimum_size=500)
# 2. CORS
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers for credentialed requests (the CORS spec forbids a
# wildcard origin with credentials) — consider an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# 3. Request timing + size-limit middleware
MAX_REQUEST_BODY_BYTES = 10 * 1024 * 1024  # 10 MB

@app.middleware("http")
async def request_lifecycle_middleware(request: Request, call_next):
    """Reject oversized bodies, time the request, add default security headers.

    Returns 413 when the declared Content-Length exceeds the cap, 400 when the
    header is present but not an integer (previously this raised ValueError
    and surfaced as a 500).
    """
    # Size guard (skip WebSocket upgrades — they carry no regular body)
    if request.headers.get("upgrade", "").lower() != "websocket":
        content_length = request.headers.get("content-length")
        if content_length:
            try:
                if int(content_length) > MAX_REQUEST_BODY_BYTES:
                    return JSONResponse(
                        {"error": "Request body too large"},
                        status_code=413,
                    )
            except ValueError:
                # Malformed Content-Length from the client — their error, not ours.
                return JSONResponse(
                    {"error": "Invalid Content-Length header"},
                    status_code=400,
                )
    start = time.perf_counter()
    response = await call_next(request)
    elapsed_ms = round((time.perf_counter() - start) * 1000, 2)
    response.headers["X-Process-Time-Ms"] = str(elapsed_ms)
    # Default security headers (setdefault so route handlers may override)
    response.headers.setdefault("X-Content-Type-Options", "nosniff")
    response.headers.setdefault("X-Frame-Options", "DENY")
    response.headers.setdefault("Referrer-Policy", "strict-origin-when-cross-origin")
    return response
# ── Global unhandled-rejection guard ────────────────────────────
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: log the full traceback, answer with a generic 500."""
    logger.error(
        "Unhandled error on %s %s: %s",
        request.method,
        request.url.path,
        exc,
        exc_info=True,
    )
    return JSONResponse({"error": "Internal server error"}, status_code=500)
# ── MongoDB Configuration ───────────────────────────────────────
MONGODB_URL = os.getenv("MONGODB_URL", "mongodb://localhost:27017")  # override via env in deployments
DATABASE_NAME = "s3shastra"
COLLECTION_NAME = "scan_history"
# MongoDB client and collection handles. All start as None and are populated
# by _startup_db_client(); endpoints treat None as "database unavailable".
mongo_client: Optional[AsyncIOMotorClient] = None
database = None
history_collection = None
users_collection = None
otps_collection = None
sessions_collection = None
async def _startup_db_client():
    """Connect to MongoDB, verify the connection with a ping, ensure indexes.

    On any failure every connection-related global is reset to ``None`` so
    endpoints can reliably detect "database unavailable". (Previously the
    handles were assigned before the ping, so a failed connection still left
    ``history_collection is not None`` true elsewhere in the app.)
    """
    global mongo_client, database, history_collection, users_collection, otps_collection, sessions_collection
    try:
        mongo_client = AsyncIOMotorClient(
            MONGODB_URL,
            maxPoolSize=50,
            minPoolSize=5,
            maxIdleTimeMS=30000,
            serverSelectionTimeoutMS=5000,
        )
        database = mongo_client[DATABASE_NAME]
        history_collection = database[COLLECTION_NAME]
        users_collection = database["users"]
        otps_collection = database["otps"]
        sessions_collection = database["sessions"]
        # Fail fast if the server is unreachable (serverSelectionTimeoutMS applies)
        await mongo_client.admin.command('ping')
        logger.info("Connected to MongoDB database: %s", DATABASE_NAME)
        # ── Create indexes for faster queries ───────────────────
        await history_collection.create_index("timestamp", background=True)
        await history_collection.create_index("domain", background=True)
        await otps_collection.create_index("email", background=True)
        # TTL index: MongoDB purges OTP documents once expires_at passes
        await otps_collection.create_index("expires_at", expireAfterSeconds=0, background=True)
        await sessions_collection.create_index("session_id", unique=True, background=True)
        # TTL index: sessions disappear automatically at expiry
        await sessions_collection.create_index("expires_at", expireAfterSeconds=0, background=True)
        await users_collection.create_index("email", unique=True, background=True)
        await users_collection.create_index("user_id", unique=True, background=True)
        logger.info("MongoDB indexes ensured")
    except Exception as e:
        logger.warning("MongoDB connection failed: %s", e)
        logger.info("Will continue without MongoDB (history won't persist)")
        # Bug fix: reset the handles so `is not None` checks elsewhere behave
        # as intended when the server never answered the startup ping.
        if mongo_client is not None:
            mongo_client.close()
        mongo_client = None
        database = None
        history_collection = None
        users_collection = None
        otps_collection = None
        sessions_collection = None
async def _shutdown_db_client():
    """Close the MongoDB client, if one was ever opened."""
    global mongo_client
    if not mongo_client:
        return
    mongo_client.close()
    logger.info("MongoDB connection closed")
class HistoryData(BaseModel):
    """Request body for POST /history: the client's full scan-history list."""
    # Each entry is a free-form dict (id, domain, bucketsFound, date, results).
    history: List[dict]
@app.websocket("/ws/scan")
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
try:
while True:
# Wait for scan parameters from the client
data = await websocket.receive_json()
# Basic validation
if "domain" not in data or not data["domain"]:
await websocket.send_json({"type": "error", "message": "Domain is required."})
continue
domain = data["domain"]
# Enforce safe thread limits on Windows to avoid "too many file descriptors" error
# Windows select() has a hard 512 FD limit; connections + overhead must stay under it
req_threads = int(data.get("threads", 500))
req_per_host = int(data.get("per_host", 100))
if sys.platform == "win32":
WIN_MAX_THREADS = 80
WIN_MAX_PER_HOST = 30
if req_threads > WIN_MAX_THREADS:
print(f"⚠️ Windows detected: Clamping threads from {req_threads} to {WIN_MAX_THREADS} (select() FD limit).")
req_threads = WIN_MAX_THREADS
if req_per_host > WIN_MAX_PER_HOST:
req_per_host = WIN_MAX_PER_HOST
# Create an args object similar to the original script's argparse
args = SimpleNamespace(
threads=req_threads,
timeout=float(data.get("timeout", 30.0)),
connect_timeout=float(data.get("connect_timeout", 10.0)),
read_timeout=float(data.get("read_timeout", 20.0)),
subdomain_timeout=float(data.get("subdomain_timeout", 15.0)),
per_host=req_per_host,
insecure=bool(data.get("insecure", False)),
no_takeover=bool(data.get("no_takeover", False)),
no_bucket_checks=bool(data.get("no_bucket_checks", False)),
include=data.get("include"),
exclude=data.get("exclude"),
providers=data.get("providers", []),
user_agent=scanner.DEFAULT_UA,
proxy=None # Not implemented in GUI
)
# Run the scan
await scanner.entrypoint_scan(domain, args, websocket)
# Signal completion for the current domain
await websocket.send_json({"type": "finished", "message": f"Scan finished for {domain}."})
except WebSocketDisconnect:
print(f"Client disconnected: {websocket.client}")
except Exception as e:
logger.error("WebSocket scan error: %s", e, exc_info=True)
try:
await websocket.send_json({"type": "error", "message": "An unexpected server error occurred."})
except Exception:
pass
finally:
try:
if not websocket.client_state.value == 3: # CLIENT_DISCONNECTED
await websocket.close()
logger.info("Connection closed for: %s", websocket.client)
except Exception:
pass
@app.get("/buckets/public")
async def get_public_buckets():
"""Fetch all unique public buckets from history"""
try:
if history_collection is not None:
# Aggregation pipeline to:
# 1. Unwind results array
# 2. Match only "Found" or "Public Listing" status
# 3. Group by bucket name (reference) to get unique buckets
pipeline = [
{"$unwind": "$results"},
{"$match": {
"$or": [
{"results.status": "Found"},
{"results.status": "Public Listing"},
{"results.status": "Exists (200)"} # Also include open buckets
]
}},
{"$group": {
"_id": "$results.reference",
"provider": {"$first": "$results.provider"},
"server": {"$first": "$results.server"}
}},
{"$sort": {"_id": 1}}
]
buckets = await history_collection.aggregate(pipeline).to_list(length=1000)
return JSONResponse(content={"buckets": [{"name": b["_id"], "provider": b["provider"]} for b in buckets]})
else:
return JSONResponse(content={"buckets": []})
except Exception as e:
logger.error("Error fetching public buckets: %s", e)
return JSONResponse(content={"buckets": []})
@app.websocket("/ws/dlp")
async def websocket_dlp_endpoint(websocket: WebSocket):
await websocket.accept()
try:
while True:
data = await websocket.receive_json()
bucket_name = data.get("bucket_name")
if not bucket_name:
await websocket.send_json({"type": "error", "message": "Bucket name is required."})
continue
year_filter = data.get("year_filter")
recent_years_filter = data.get("recent_years_filter")
timeout = data.get("timeout", 10)
# Instantiate and run auditor
# Note: We create a new instance for each request
auditor = dlp_scanner.S3DLPAuditor(
bucket_name=bucket_name,
timeout=int(timeout),
year_filter=int(year_filter) if year_filter else None,
recent_years_filter=int(recent_years_filter) if recent_years_filter else None
)
# This will stream results back to the websocket
await auditor.audit_bucket(websocket)
except WebSocketDisconnect:
logger.info("DLP Client disconnected: %s", websocket.client)
except Exception as e:
logger.error("DLP Error: %s", e, exc_info=True)
try:
await websocket.send_json({"type": "error", "message": "DLP scan encountered an error."})
except Exception:
pass
@app.websocket("/ws/deep_scan")
async def websocket_deep_scan_endpoint(websocket: WebSocket):
await websocket.accept()
try:
while True:
data = await websocket.receive_json()
bucket_name = data.get("bucket_name")
if not bucket_name:
await websocket.send_json({"type": "error", "message": "Bucket name is required."})
continue
timeout = data.get("timeout", 15)
# Instantiate and run the deep auditor
auditor = deep_scanner.S3DeepAuditor(
bucket_name=bucket_name,
timeout=int(timeout)
)
# This will stream deep scan results back to the websocket
await auditor.audit_bucket(websocket)
except WebSocketDisconnect:
logger.info("Deep Scan Client disconnected: %s", websocket.client)
except Exception as e:
logger.error("Deep Scan Error: %s", e, exc_info=True)
try:
await websocket.send_json({"type": "error", "message": "Deep scan encountered an error."})
except Exception:
pass
@app.get("/")
async def read_root():
return {"message": "S3Shastra Backend is running. Connect via WebSocket."}
# ── In-memory cache for exposure timeline (non-sensitive, safe to cache) ──
_exposure_cache: dict = {}  # bucket_ref → (timestamp, data)
EXPOSURE_CACHE_TTL = 300  # 5 minutes
EXPOSURE_CACHE_MAX = 1024  # hard cap so the cache cannot grow without bound

def _prune_exposure_cache() -> None:
    """Drop expired cache entries; if still over the cap, evict the oldest.

    Fixes an unbounded-growth leak: the original cache never removed stale
    entries, so every distinct bucket_ref stayed resident forever.
    """
    now = time.time()
    expired = [k for k, (ts, _) in _exposure_cache.items() if now - ts >= EXPOSURE_CACHE_TTL]
    for k in expired:
        del _exposure_cache[k]
    while len(_exposure_cache) > EXPOSURE_CACHE_MAX:
        oldest = min(_exposure_cache, key=lambda k: _exposure_cache[k][0])
        del _exposure_cache[oldest]

@app.get("/exposure-timeline")
async def exposure_timeline(bucket_ref: str = ""):
    """Query the Wayback Machine for historical bucket exposure data.

    Responses are cached in-process for EXPOSURE_CACHE_TTL seconds, since
    Wayback data changes slowly.
    """
    if not bucket_ref:
        return JSONResponse(content={"error": "bucket_ref query parameter required"}, status_code=400)
    # Serve from cache when fresh
    cached = _exposure_cache.get(bucket_ref)
    if cached:
        cache_ts, cache_data = cached
        if time.time() - cache_ts < EXPOSURE_CACHE_TTL:
            resp = JSONResponse(content=cache_data)
            resp.headers["X-Cache"] = "HIT"
            resp.headers["Cache-Control"] = "public, max-age=300"
            return resp
    try:
        import aiohttp  # local import: only needed on cache misses
        timeout = aiohttp.ClientTimeout(total=35)
        async with aiohttp.ClientSession(timeout=timeout) as sess:
            result = await scanner.get_exposure_timeline(bucket_ref, sess)
        # Prune before inserting so memory use stays bounded.
        _prune_exposure_cache()
        _exposure_cache[bucket_ref] = (time.time(), result)
        resp = JSONResponse(content=result)
        resp.headers["Cache-Control"] = "public, max-age=300"
        return resp
    except Exception as e:
        logger.error("Exposure timeline error for %s: %s", bucket_ref, e)
        return JSONResponse(content={"error": "Failed to fetch timeline data"}, status_code=500)
@app.post("/compliance-posture")
async def compliance_posture(data: dict):
"""Compute compliance posture against CIS, NIST, SOC2, ISO27001 from scan results."""
results = data.get("results", [])
if not results:
return JSONResponse(content={"error": "No results provided"}, status_code=400)
# CIS AWS Foundations Benchmark v3.0 β€” S3 controls
cis_controls = [
{"id": "2.1.4", "title": "S3 Block Public Access",
"description": "Ensure S3 Buckets are configured with Block Public Access (bucket settings)",
"check": lambda r: not (r.get("permissions", {}).get("read") or r.get("permissions", {}).get("write") or r.get("permissions", {}).get("list")),
"severity": "CRITICAL", "remediation": "aws s3api put-public-access-block --bucket {bucket} --public-access-block-configuration BlockPublicAcls=true,IgnorePublicAcls=true,BlockPublicPolicy=true,RestrictPublicBuckets=true"},
{"id": "2.1.1", "title": "Deny HTTP Requests (Enforce TLS)",
"description": "Ensure S3 Bucket Policy is set to deny HTTP requests (enforce encryption in transit)",
# Unauthenticated scanner cannot read bucket policies β€” check server-side encryption as proxy
"check": lambda r: r.get("securityChecks", {}).get("encryption") in ("enabled", "denied", "unknown"),
"severity": "HIGH", "remediation": "aws s3api put-bucket-encryption --bucket {bucket} --server-side-encryption-configuration '{\"Rules\":[{\"ApplyServerSideEncryptionByDefault\":{\"SSEAlgorithm\":\"AES256\"}}]}'"},
{"id": "2.1.2", "title": "MFA Delete Enabled",
"description": "Ensure MFA Delete is enabled on S3 buckets (requires credentials to verify)",
"check": lambda r: True, # Can't check without creds β€” informational
"severity": "MEDIUM", "remediation": "aws s3api put-bucket-versioning --bucket {bucket} --versioning-configuration Status=Enabled,MFADelete=Enabled --mfa 'arn:aws:iam:::<account>:mfa/<device> <code>'"},
{"id": "2.1.3", "title": "S3 Bucket Versioning",
"description": "Ensure all data in Amazon S3 has been discovered, classified and secured when required",
"check": lambda r: True, # Can't check without creds β€” informational
"severity": "MEDIUM", "remediation": "aws s3api put-bucket-versioning --bucket {bucket} --versioning-configuration Status=Enabled"},
{"id": "S3.1", "title": "S3 ACL Not Public",
"description": "Ensure S3 bucket ACL does not grant access to AllUsers or AuthenticatedUsers",
"check": lambda r: not r.get("permissions", {}).get("aclReadable"),
"severity": "HIGH", "remediation": "aws s3api put-bucket-acl --bucket {bucket} --acl private"},
{"id": "S3.2", "title": "No Public Write Access",
"description": "Ensure no S3 bucket allows public write (PUT/DELETE) access",
"check": lambda r: not r.get("permissions", {}).get("write"),
"severity": "CRITICAL", "remediation": "aws s3api put-bucket-policy --bucket {bucket} --policy '{\"Statement\":[{\"Effect\":\"Deny\",\"Principal\":\"*\",\"Action\":[\"s3:PutObject\",\"s3:DeleteObject\"],\"Resource\":\"arn:aws:s3:::{bucket}/*\"}]}'"},
{"id": "S3.3", "title": "No Public Listing",
"description": "Ensure S3 bucket does not allow public listing of objects",
"check": lambda r: not r.get("permissions", {}).get("list"),
"severity": "HIGH", "remediation": "aws s3api put-bucket-policy --bucket {bucket} --policy '{\"Statement\":[{\"Effect\":\"Deny\",\"Principal\":\"*\",\"Action\":\"s3:ListBucket\",\"Resource\":\"arn:aws:s3:::{bucket}\"}]}'"},
{"id": "S3.4", "title": "No CORS Wildcard",
"description": "Ensure S3 bucket CORS does not allow wildcard (*) origins",
"check": lambda r: r.get("securityChecks", {}).get("corsOpen") not in ("wildcard", "reflects_origin"),
"severity": "MEDIUM", "remediation": "aws s3api delete-bucket-cors --bucket {bucket} # Then reconfigure with specific origins only"},
{"id": "S3.5", "title": "No Presigned URL Bypass",
"description": "Ensure bucket does not serve content for invalid/expired presigned URLs",
"check": lambda r: not r.get("securityChecks", {}).get("presignedBypass"),
"severity": "CRITICAL", "remediation": "aws s3api put-bucket-policy --bucket {bucket} --policy '{\"Statement\":[{\"Effect\":\"Deny\",\"Principal\":\"*\",\"Action\":\"s3:GetObject\",\"Resource\":\"arn:aws:s3:::{bucket}/*\",\"Condition\":{\"StringNotEquals\":{\"s3:authType\":\"REST-QUERY-STRING\"}}}]}'"},
]
# NIST 800-53 mappings
nist_controls = [
{"id": "AC-3", "title": "Access Enforcement",
"description": "Enforce approved authorizations for logical access to information and resources",
"check": lambda r: not (r.get("permissions", {}).get("read") or r.get("permissions", {}).get("write") or r.get("permissions", {}).get("list")),
"severity": "HIGH", "remediation": "Restrict all public access; apply IAM policies with least-privilege."},
{"id": "AC-6", "title": "Least Privilege",
"description": "Employ principle of least privilege for access to resources",
"check": lambda r: not r.get("permissions", {}).get("write"),
"severity": "HIGH", "remediation": "Remove public write access; grant only required PUT permissions to specific IAM roles."},
{"id": "SC-28", "title": "Protection of Information at Rest",
"description": "Protect confidentiality of information at rest using encryption",
"check": lambda r: r.get("securityChecks", {}).get("encryption") in ("enabled", "denied", "unknown"),
"severity": "HIGH", "remediation": "Enable SSE-S3 (AES-256) or SSE-KMS encryption on all buckets."},
{"id": "AU-2", "title": "Audit Events",
"description": "Identify events that require auditing (access logging)",
"check": lambda r: True, # Can't check without creds β€” informational
"severity": "MEDIUM", "remediation": "Enable S3 server access logging: aws s3api put-bucket-logging --bucket {bucket} --bucket-logging-status '{\"LoggingEnabled\":{\"TargetBucket\":\"my-log-bucket\",\"TargetPrefix\":\"logs/\"}}'"},
{"id": "SI-4", "title": "System Monitoring",
"description": "Monitor information system to detect unauthorized access",
"check": lambda r: True,
"severity": "MEDIUM", "remediation": "Enable AWS CloudTrail data events for S3."},
]
# SOC 2 Type II controls
soc2_controls = [
{"id": "CC6.1", "title": "Logical Access Security",
"description": "Restrict logical access to information assets",
"check": lambda r: not (r.get("permissions", {}).get("read") or r.get("permissions", {}).get("write") or r.get("permissions", {}).get("list")),
"severity": "CRITICAL", "remediation": "Implement S3 Block Public Access and IAM policies."},
{"id": "CC6.3", "title": "Access to Data at Rest",
"description": "Apply authorized access controls and encryption to data stored on systems",
"check": lambda r: r.get("securityChecks", {}).get("encryption") in ("enabled", "denied", "unknown"),
"severity": "HIGH", "remediation": "Enable bucket encryption with SSE-S3 or SSE-KMS."},
{"id": "CC6.6", "title": "Protection Against External Threats",
"description": "Protect information assets against threats from sources outside the entity",
"check": lambda r: not r.get("permissions", {}).get("write") and not r.get("securityChecks", {}).get("presignedBypass"),
"severity": "CRITICAL", "remediation": "Block public write access and presigned URL bypass vulnerabilities."},
{"id": "CC7.2", "title": "Monitoring of Systems",
"description": "Monitor system components for anomalies and evaluate for indicators of compromise",
"check": lambda r: True,
"severity": "MEDIUM", "remediation": "Enable CloudTrail and S3 access logging."},
]
# ISO/IEC 27001:2022 controls (Annex A β€” updated numbering)
iso_controls = [
{"id": "A.8.3", "title": "Information Access Restriction",
"description": "Prevent unauthorized access to information stored in cloud resources",
"check": lambda r: not (r.get("permissions", {}).get("read") or r.get("permissions", {}).get("write") or r.get("permissions", {}).get("list")),
"severity": "CRITICAL", "remediation": "Apply S3 Block Public Access settings and IAM least-privilege policies."},
{"id": "A.8.24", "title": "Use of Cryptography",
"description": "Ensure proper use of cryptography to protect confidentiality and integrity of information",
"check": lambda r: r.get("securityChecks", {}).get("encryption") in ("enabled", "denied", "unknown"),
"severity": "HIGH", "remediation": "Enable SSE-S3 (AES-256) or SSE-KMS encryption on all S3 buckets."},
{"id": "A.8.15", "title": "Logging",
"description": "Produce, store, protect and analyse logs that record activities, exceptions and events",
"check": lambda r: True, # Requires credentials to verify β€” informational
"severity": "MEDIUM", "remediation": "Enable S3 server access logging and CloudTrail data events."},
{"id": "A.8.20", "title": "Networks Security",
"description": "Secure networks and network services to protect information in systems and applications",
"check": lambda r: r.get("securityChecks", {}).get("corsOpen") not in ("wildcard", "reflects_origin"),
"severity": "MEDIUM", "remediation": "Restrict CORS origin configuration; remove wildcard (*) origins."},
]
def evaluate_framework(controls, results_list):
findings = []
total = len(controls) * max(len(results_list), 1)
passed = 0
for ctrl in controls:
ctrl_failing = []
ctrl_passed = 0
for r in results_list:
try:
if ctrl["check"](r):
ctrl_passed += 1
else:
bucket_name = r.get("reference", "unknown")
ctrl_failing.append(bucket_name)
except Exception:
# Count evaluation errors as failures (conservative approach)
bucket_name = r.get("reference", "unknown")
ctrl_failing.append(bucket_name)
passed += ctrl_passed
findings.append({
"id": ctrl["id"],
"title": ctrl["title"],
"description": ctrl["description"],
"severity": ctrl["severity"],
"status": "PASS" if not ctrl_failing else "FAIL",
"failingBuckets": ctrl_failing,
"remediation": ctrl["remediation"],
})
score = round((passed / total) * 100) if total > 0 else 100
return {"score": score, "controls": findings}
return JSONResponse(content={
"cis": evaluate_framework(cis_controls, results),
"nist": evaluate_framework(nist_controls, results),
"soc2": evaluate_framework(soc2_controls, results),
"iso27001": evaluate_framework(iso_controls, results),
})
@app.get("/history")
async def get_history():
"""Load scan history from MongoDB"""
try:
if history_collection is not None:
# Fetch all history items from MongoDB, sorted by date (newest first)
cursor = history_collection.find({}).sort("timestamp", -1)
history_items = await cursor.to_list(length=1000)
# Convert MongoDB documents to the expected format
formatted_history = []
for item in history_items:
formatted_history.append({
"id": item.get("id", item.get("_id")),
"domain": item.get("domain", ""),
"bucketsFound": item.get("bucketsFound", 0),
"date": item.get("date", ""),
"results": item.get("results", [])
})
return JSONResponse(content={"history": formatted_history})
else:
return JSONResponse(content={"history": []})
except Exception as e:
logger.error("Error loading history from MongoDB: %s", e)
return JSONResponse(content={"history": []})
@app.post("/history")
async def save_history(data: HistoryData):
"""Save scan history to MongoDB"""
try:
if history_collection is not None:
# Clear existing history and insert new data
await history_collection.delete_many({})
# Add timestamp to each history item
history_items = []
for item in data.history:
history_item = {
"id": item.get("id"),
"domain": item.get("domain"),
"bucketsFound": item.get("bucketsFound"),
"date": item.get("date"),
"results": item.get("results", []),
"timestamp": datetime.utcnow()
}
history_items.append(history_item)
if history_items:
await history_collection.insert_many(history_items)
return JSONResponse(content={"message": "History saved successfully"})
else:
return JSONResponse(content={"error": "Database not connected"}, status_code=503)
except Exception as e:
logger.error("Error saving history: %s", e)
return JSONResponse(content={"error": "Failed to save history"}, status_code=500)
# ==========================================
# OTP AUTHENTICATION ENDPOINTS
# ==========================================
from fastapi import Depends, HTTPException, Header
@app.post("/auth/request-otp")
async def request_otp(data: auth.OTPRequest):
"""Generate a 6-digit OTP, store it in DB with 30s expiry, return it"""
if database is None:
raise HTTPException(status_code=503, detail="Database not connected")
email = data.email.strip().lower()
if not email:
raise HTTPException(status_code=400, detail="Email is required")
if not await auth.check_otp_rate_limit(database, email):
raise HTTPException(status_code=429, detail="Too many OTP requests. Please wait a few minutes.")
otp_code = auth.generate_otp_code()
await auth.store_otp(database, email, otp_code)
email_sent = await auth.send_otp_email(email, otp_code)
return {
"message": "OTP sent successfully" if email_sent else "OTP generated (Check console/demo mode)",
"demo_otp": otp_code if not email_sent else None, # FOR DEMONSTRATION PURPOSES ONLY
"expires_in": auth.OTP_EXPIRY_SECONDS
}
@app.post("/auth/verify-otp")
async def verify_otp(data: auth.OTPVerify):
"""Verify OTP (must be within 30s, single-use) and generate UUID session"""
if database is None:
raise HTTPException(status_code=503, detail="Database not connected")
email = data.email.strip().lower()
otp = data.otp.strip()
if not await auth.check_verification_rate_limit(database, email):
raise HTTPException(status_code=429, detail="Too many failed attempts. Please wait a few minutes.")
is_valid = await auth.verify_stored_otp(database, email, otp)
if not is_valid:
await auth.record_failed_verification(database, email)
raise HTTPException(status_code=401, detail="Invalid, used, or expired OTP")
# Valid OTP. Get or create user
user = await auth.get_user_by_email(database, email)
if not user:
user = await auth.create_user(database, email)
# Generate secure UUID Session Token
session_id = await auth.create_user_session(database, user["user_id"])
return {
"message": "Authentication successful",
"session_id": session_id,
"user": {
"email": user["email"],
"name": user.get("name")
}
}
async def get_current_user(authorization: str = Header(None)):
    """FastAPI dependency: validate the Bearer session UUID and return the user.

    Raises 401 when the header is missing/malformed or the session is dead,
    503 when the database is unavailable (previously this surfaced as a 500).
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid or missing session token")
    if database is None:
        raise HTTPException(status_code=503, detail="Database not connected")
    # Bug fix: str.replace removed *every* "Bearer " occurrence anywhere in
    # the header value; strip only the leading scheme prefix.
    session_id = authorization[len("Bearer "):]
    user = await auth.get_session_user(database, session_id)
    if not user:
        raise HTTPException(status_code=401, detail="Session expired or invalid")
    return user
@app.get("/auth/me")
async def get_me(current_user: dict = Depends(get_current_user)):
"""Validates session and returns user object"""
return {
"user_id": current_user["user_id"],
"email": current_user["email"],
"name": current_user.get("name")
}
@app.delete("/history")
async def clear_history():
"""Clear scan history from MongoDB"""
try:
if history_collection is not None:
await history_collection.delete_many({})
return JSONResponse(content={"status": "success", "message": "History cleared from MongoDB"})
else:
return JSONResponse(content={"status": "warning", "message": "MongoDB not available"})
except Exception as e:
logger.error("Error clearing history from MongoDB: %s", e)
return JSONResponse(content={"status": "error", "message": "Failed to clear history"}, status_code=500)