import asyncio
import json
import logging
import os
import sys
import time
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from pathlib import Path
from types import SimpleNamespace
from typing import List, Optional

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import JSONResponse
from motor.motor_asyncio import AsyncIOMotorClient
from pydantic import BaseModel
|
|
| import scanner |
| import dlp_scanner |
| import deep_scanner |
| import auth |
|
|
| |
# Module-wide logging: timestamped, leveled records on the root logger.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("s3shastra")
|
|
| |
# On Windows, use the Proactor event loop so subprocess/socket operations
# used by the scanners work under asyncio.
if sys.platform == "win32":
    asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
|
|
|
|
| |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: connect MongoDB on startup, close it on shutdown."""
    await _startup_db_client()
    yield
    await _shutdown_db_client()
|
|
|
|
# Application instance; DB lifecycle is managed by the lifespan context above.
app = FastAPI(
    title="S3Shastra API",
    version="1.3.0",
    lifespan=lifespan,
)
|
|
| |
| |
# Compress responses larger than 500 bytes.
app.add_middleware(GZipMiddleware, minimum_size=500)

# NOTE(review): wildcard origins combined with allow_credentials=True is a
# permissive CORS posture — confirm this is intended for deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
| |
# Upper bound (10 MiB) on declared request body size; enforced by the
# request lifecycle middleware below.
MAX_REQUEST_BODY_BYTES = 10 * 1024 * 1024
|
|
|
|
@app.middleware("http")
async def request_lifecycle_middleware(request: Request, call_next):
    """Add timing/security headers and reject oversized request bodies.

    The declared-size check is skipped for WebSocket upgrade requests.
    Handler latency is reported in ``X-Process-Time-Ms``; baseline security
    headers are added without overriding any the handler already set.
    """
    if request.headers.get("upgrade", "").lower() != "websocket":
        content_length = request.headers.get("content-length")
        if content_length:
            # BUG FIX: a malformed Content-Length previously raised
            # ValueError in int() and surfaced as an unhandled 500;
            # reject it as a client error instead.
            try:
                declared_size = int(content_length)
            except ValueError:
                return JSONResponse(
                    {"error": "Invalid Content-Length header"},
                    status_code=400,
                )
            if declared_size > MAX_REQUEST_BODY_BYTES:
                return JSONResponse(
                    {"error": "Request body too large"},
                    status_code=413,
                )
    start = time.perf_counter()
    response = await call_next(request)
    elapsed_ms = round((time.perf_counter() - start) * 1000, 2)
    response.headers["X-Process-Time-Ms"] = str(elapsed_ms)
    # setdefault keeps any stricter values an endpoint may have set itself.
    response.headers.setdefault("X-Content-Type-Options", "nosniff")
    response.headers.setdefault("X-Frame-Options", "DENY")
    response.headers.setdefault("Referrer-Policy", "strict-origin-when-cross-origin")
    return response
|
|
|
|
| |
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: log the full traceback, return an opaque 500."""
    logger.error(
        "Unhandled error on %s %s: %s",
        request.method,
        request.url.path,
        exc,
        exc_info=True,
    )
    body = {"error": "Internal server error"}
    return JSONResponse(body, status_code=500)
|
|
| |
# MongoDB connection settings; the URL can be overridden via environment.
MONGODB_URL = os.getenv("MONGODB_URL", "mongodb://localhost:27017")
DATABASE_NAME = "s3shastra"
COLLECTION_NAME = "scan_history"

# Global handles populated by _startup_db_client(). They all stay None when
# MongoDB is unavailable, so every endpoint must check before using them.
mongo_client: Optional[AsyncIOMotorClient] = None
database = None
history_collection = None
users_collection = None
otps_collection = None
sessions_collection = None
|
|
|
|
async def _startup_db_client():
    """Open the MongoDB connection pool and ensure the required indexes.

    Connection failure is non-fatal: the app keeps running without
    persistence and only logs a warning.
    """
    global mongo_client, database, history_collection, users_collection, otps_collection, sessions_collection
    try:
        mongo_client = AsyncIOMotorClient(
            MONGODB_URL,
            maxPoolSize=50,
            minPoolSize=5,
            maxIdleTimeMS=30000,
            serverSelectionTimeoutMS=5000,
        )
        database = mongo_client[DATABASE_NAME]
        history_collection = database[COLLECTION_NAME]
        users_collection = database["users"]
        otps_collection = database["otps"]
        sessions_collection = database["sessions"]

        # Force a round-trip now so connection problems surface at startup.
        await mongo_client.admin.command('ping')
        logger.info("Connected to MongoDB database: %s", DATABASE_NAME)

        # expireAfterSeconds=0 makes "expires_at" a TTL field, so OTPs and
        # sessions are purged automatically once they expire.
        index_specs = [
            (history_collection, "timestamp", {}),
            (history_collection, "domain", {}),
            (otps_collection, "email", {}),
            (otps_collection, "expires_at", {"expireAfterSeconds": 0}),
            (sessions_collection, "session_id", {"unique": True}),
            (sessions_collection, "expires_at", {"expireAfterSeconds": 0}),
            (users_collection, "email", {"unique": True}),
            (users_collection, "user_id", {"unique": True}),
        ]
        for collection, field, extra in index_specs:
            await collection.create_index(field, background=True, **extra)
        logger.info("MongoDB indexes ensured")
    except Exception as e:
        logger.warning("MongoDB connection failed: %s", e)
        logger.info("Will continue without MongoDB (history won't persist)")
|
|
|
|
async def _shutdown_db_client():
    """Close the MongoDB client if a connection was ever opened."""
    global mongo_client
    if mongo_client is None:
        return
    mongo_client.close()
    logger.info("MongoDB connection closed")
|
|
class HistoryData(BaseModel):
    """Request body for POST /history: the full client-side history list."""
    # Each dict carries id/domain/bucketsFound/date/results keys (see save_history).
    history: List[dict]
|
|
@app.websocket("/ws/scan")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket loop: receive scan requests and stream scanner results.

    Each JSON message must contain a "domain" key; optional tuning keys
    (threads, timeouts, per_host, providers, ...) are packed into a
    SimpleNamespace mimicking the CLI args object that
    scanner.entrypoint_scan expects.
    """
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_json()

            if "domain" not in data or not data["domain"]:
                await websocket.send_json({"type": "error", "message": "Domain is required."})
                continue

            domain = data["domain"]

            req_threads = int(data.get("threads", 500))
            req_per_host = int(data.get("per_host", 100))
            if sys.platform == "win32":
                # Windows' select() has a low file-descriptor limit, so
                # concurrency settings are clamped to safe ceilings.
                WIN_MAX_THREADS = 80
                WIN_MAX_PER_HOST = 30
                if req_threads > WIN_MAX_THREADS:
                    # Was a print() with mojibake; use the module logger for
                    # consistency with the rest of the file.
                    logger.warning(
                        "Windows detected: clamping threads from %d to %d (select() FD limit).",
                        req_threads,
                        WIN_MAX_THREADS,
                    )
                    req_threads = WIN_MAX_THREADS
                if req_per_host > WIN_MAX_PER_HOST:
                    req_per_host = WIN_MAX_PER_HOST

            args = SimpleNamespace(
                threads=req_threads,
                timeout=float(data.get("timeout", 30.0)),
                connect_timeout=float(data.get("connect_timeout", 10.0)),
                read_timeout=float(data.get("read_timeout", 20.0)),
                subdomain_timeout=float(data.get("subdomain_timeout", 15.0)),
                per_host=req_per_host,
                insecure=bool(data.get("insecure", False)),
                no_takeover=bool(data.get("no_takeover", False)),
                no_bucket_checks=bool(data.get("no_bucket_checks", False)),
                include=data.get("include"),
                exclude=data.get("exclude"),
                providers=data.get("providers", []),
                user_agent=scanner.DEFAULT_UA,
                proxy=None
            )

            await scanner.entrypoint_scan(domain, args, websocket)

            await websocket.send_json({"type": "finished", "message": f"Scan finished for {domain}."})

    except WebSocketDisconnect:
        logger.info("Client disconnected: %s", websocket.client)
    except Exception as e:
        logger.error("WebSocket scan error: %s", e, exc_info=True)
        try:
            await websocket.send_json({"type": "error", "message": "An unexpected server error occurred."})
        except Exception:
            pass
    finally:
        try:
            # BUG FIX: the old guard compared client_state.value to 3, a value
            # that does not exist in starlette's WebSocketState enum (0-2), so
            # close() was attempted even on already-disconnected sockets.
            if websocket.client_state.name != "DISCONNECTED":
                await websocket.close()
                logger.info("Connection closed for: %s", websocket.client)
        except Exception:
            pass
|
|
@app.get("/buckets/public")
async def get_public_buckets():
    """Return the de-duplicated list of public buckets seen in scan history.

    Aggregates over stored scan results, keeping one entry per bucket
    reference. Returns an empty list when the DB is unavailable or on error.
    """
    if history_collection is None:
        return JSONResponse(content={"buckets": []})
    try:
        public_statuses = ["Found", "Public Listing", "Exists (200)"]
        pipeline = [
            {"$unwind": "$results"},
            {"$match": {"$or": [{"results.status": status} for status in public_statuses]}},
            {"$group": {
                "_id": "$results.reference",
                "provider": {"$first": "$results.provider"},
                "server": {"$first": "$results.server"},
            }},
            {"$sort": {"_id": 1}},
        ]
        rows = await history_collection.aggregate(pipeline).to_list(length=1000)
        payload = [{"name": row["_id"], "provider": row["provider"]} for row in rows]
        return JSONResponse(content={"buckets": payload})
    except Exception as e:
        logger.error("Error fetching public buckets: %s", e)
        return JSONResponse(content={"buckets": []})
|
|
@app.websocket("/ws/dlp")
async def websocket_dlp_endpoint(websocket: WebSocket):
    """WebSocket loop that runs a DLP audit against a named bucket.

    Each JSON message needs "bucket_name"; optional year/recency filters
    and a timeout tune the audit. Results stream back over the socket.
    """
    await websocket.accept()
    try:
        while True:
            request = await websocket.receive_json()

            bucket_name = request.get("bucket_name")
            if not bucket_name:
                await websocket.send_json({"type": "error", "message": "Bucket name is required."})
                continue

            year = request.get("year_filter")
            recent_years = request.get("recent_years_filter")

            auditor = dlp_scanner.S3DLPAuditor(
                bucket_name=bucket_name,
                timeout=int(request.get("timeout", 10)),
                year_filter=int(year) if year else None,
                recent_years_filter=int(recent_years) if recent_years else None,
            )
            await auditor.audit_bucket(websocket)

    except WebSocketDisconnect:
        logger.info("DLP Client disconnected: %s", websocket.client)
    except Exception as e:
        logger.error("DLP Error: %s", e, exc_info=True)
        try:
            await websocket.send_json({"type": "error", "message": "DLP scan encountered an error."})
        except Exception:
            pass
|
|
@app.websocket("/ws/deep_scan")
async def websocket_deep_scan_endpoint(websocket: WebSocket):
    """WebSocket loop that runs a deep security audit against a named bucket."""
    await websocket.accept()
    try:
        while True:
            request = await websocket.receive_json()

            target = request.get("bucket_name")
            if not target:
                await websocket.send_json({"type": "error", "message": "Bucket name is required."})
                continue

            auditor = deep_scanner.S3DeepAuditor(
                bucket_name=target,
                timeout=int(request.get("timeout", 15)),
            )
            await auditor.audit_bucket(websocket)

    except WebSocketDisconnect:
        logger.info("Deep Scan Client disconnected: %s", websocket.client)
    except Exception as e:
        logger.error("Deep Scan Error: %s", e, exc_info=True)
        try:
            await websocket.send_json({"type": "error", "message": "Deep scan encountered an error."})
        except Exception:
            pass
|
|
@app.get("/")
async def read_root():
    """Health-check endpoint confirming the backend is up."""
    greeting = "S3Shastra Backend is running. Connect via WebSocket."
    return {"message": greeting}
|
|
| |
# In-process cache for /exposure-timeline: bucket_ref -> (fetch_time, data).
_exposure_cache: dict = {}
EXPOSURE_CACHE_TTL = 300  # seconds
|
|
@app.get("/exposure-timeline")
async def exposure_timeline(bucket_ref: str = ""):
    """Query Wayback Machine for historical bucket exposure data.

    Results are cached in-process for EXPOSURE_CACHE_TTL seconds. Expired
    entries are evicted on each request so the cache cannot grow without
    bound (previously stale entries were never removed).
    """
    if not bucket_ref:
        return JSONResponse(content={"error": "bucket_ref query parameter required"}, status_code=400)

    now = time.time()
    # Evict every stale entry, not just the requested key, to bound memory.
    stale_keys = [key for key, (ts, _) in _exposure_cache.items()
                  if now - ts >= EXPOSURE_CACHE_TTL]
    for key in stale_keys:
        _exposure_cache.pop(key, None)

    cached = _exposure_cache.get(bucket_ref)
    if cached:
        # Still fresh (eviction above removed anything expired).
        _, cache_data = cached
        resp = JSONResponse(content=cache_data)
        resp.headers["X-Cache"] = "HIT"
        resp.headers["Cache-Control"] = "public, max-age=300"
        return resp

    try:
        import aiohttp  # local import keeps startup independent of aiohttp
        timeout = aiohttp.ClientTimeout(total=35)
        async with aiohttp.ClientSession(timeout=timeout) as sess:
            result = await scanner.get_exposure_timeline(bucket_ref, sess)
            _exposure_cache[bucket_ref] = (time.time(), result)
            resp = JSONResponse(content=result)
            resp.headers["Cache-Control"] = "public, max-age=300"
            return resp
    except Exception as e:
        logger.error("Exposure timeline error for %s: %s", bucket_ref, e)
        return JSONResponse(content={"error": "Failed to fetch timeline data"}, status_code=500)
|
|
|
|
@app.post("/compliance-posture")
async def compliance_posture(data: dict):
    """Compute compliance posture against CIS, NIST, SOC2, ISO27001 from scan results.

    Expects ``{"results": [...]}`` where each result dict carries the
    scanner-produced "permissions" and "securityChecks" sub-dicts plus a
    "reference" (bucket name). Each framework is scored as the percentage
    of (control, bucket) pairs whose check predicate passes.
    """
    results = data.get("results", [])
    if not results:
        return JSONResponse(content={"error": "No results provided"}, status_code=400)

    # Each control: id/title/description, a check(result) -> bool predicate,
    # a severity label, and a concrete remediation command template.
    # Controls whose check is "lambda r: True" cannot be verified from scan
    # data alone (they need AWS credentials) and therefore always pass here.
    cis_controls = [
        {"id": "2.1.4", "title": "S3 Block Public Access",
         "description": "Ensure S3 Buckets are configured with Block Public Access (bucket settings)",
         "check": lambda r: not (r.get("permissions", {}).get("read") or r.get("permissions", {}).get("write") or r.get("permissions", {}).get("list")),
         "severity": "CRITICAL", "remediation": "aws s3api put-public-access-block --bucket {bucket} --public-access-block-configuration BlockPublicAcls=true,IgnorePublicAcls=true,BlockPublicPolicy=true,RestrictPublicBuckets=true"},
        {"id": "2.1.1", "title": "Deny HTTP Requests (Enforce TLS)",
         "description": "Ensure S3 Bucket Policy is set to deny HTTP requests (enforce encryption in transit)",
         # NOTE(review): the title refers to TLS in transit, but the check and
         # remediation test at-rest encryption — confirm which was intended.
         "check": lambda r: r.get("securityChecks", {}).get("encryption") in ("enabled", "denied", "unknown"),
         "severity": "HIGH", "remediation": "aws s3api put-bucket-encryption --bucket {bucket} --server-side-encryption-configuration '{\"Rules\":[{\"ApplyServerSideEncryptionByDefault\":{\"SSEAlgorithm\":\"AES256\"}}]}'"},
        {"id": "2.1.2", "title": "MFA Delete Enabled",
         "description": "Ensure MFA Delete is enabled on S3 buckets (requires credentials to verify)",
         "check": lambda r: True,
         "severity": "MEDIUM", "remediation": "aws s3api put-bucket-versioning --bucket {bucket} --versioning-configuration Status=Enabled,MFADelete=Enabled --mfa 'arn:aws:iam:::<account>:mfa/<device> <code>'"},
        {"id": "2.1.3", "title": "S3 Bucket Versioning",
         "description": "Ensure all data in Amazon S3 has been discovered, classified and secured when required",
         "check": lambda r: True,
         "severity": "MEDIUM", "remediation": "aws s3api put-bucket-versioning --bucket {bucket} --versioning-configuration Status=Enabled"},
        {"id": "S3.1", "title": "S3 ACL Not Public",
         "description": "Ensure S3 bucket ACL does not grant access to AllUsers or AuthenticatedUsers",
         "check": lambda r: not r.get("permissions", {}).get("aclReadable"),
         "severity": "HIGH", "remediation": "aws s3api put-bucket-acl --bucket {bucket} --acl private"},
        {"id": "S3.2", "title": "No Public Write Access",
         "description": "Ensure no S3 bucket allows public write (PUT/DELETE) access",
         "check": lambda r: not r.get("permissions", {}).get("write"),
         "severity": "CRITICAL", "remediation": "aws s3api put-bucket-policy --bucket {bucket} --policy '{\"Statement\":[{\"Effect\":\"Deny\",\"Principal\":\"*\",\"Action\":[\"s3:PutObject\",\"s3:DeleteObject\"],\"Resource\":\"arn:aws:s3:::{bucket}/*\"}]}'"},
        {"id": "S3.3", "title": "No Public Listing",
         "description": "Ensure S3 bucket does not allow public listing of objects",
         "check": lambda r: not r.get("permissions", {}).get("list"),
         "severity": "HIGH", "remediation": "aws s3api put-bucket-policy --bucket {bucket} --policy '{\"Statement\":[{\"Effect\":\"Deny\",\"Principal\":\"*\",\"Action\":\"s3:ListBucket\",\"Resource\":\"arn:aws:s3:::{bucket}\"}]}'"},
        {"id": "S3.4", "title": "No CORS Wildcard",
         "description": "Ensure S3 bucket CORS does not allow wildcard (*) origins",
         "check": lambda r: r.get("securityChecks", {}).get("corsOpen") not in ("wildcard", "reflects_origin"),
         "severity": "MEDIUM", "remediation": "aws s3api delete-bucket-cors --bucket {bucket} # Then reconfigure with specific origins only"},
        {"id": "S3.5", "title": "No Presigned URL Bypass",
         "description": "Ensure bucket does not serve content for invalid/expired presigned URLs",
         "check": lambda r: not r.get("securityChecks", {}).get("presignedBypass"),
         "severity": "CRITICAL", "remediation": "aws s3api put-bucket-policy --bucket {bucket} --policy '{\"Statement\":[{\"Effect\":\"Deny\",\"Principal\":\"*\",\"Action\":\"s3:GetObject\",\"Resource\":\"arn:aws:s3:::{bucket}/*\",\"Condition\":{\"StringNotEquals\":{\"s3:authType\":\"REST-QUERY-STRING\"}}}]}'"},
    ]

    nist_controls = [
        {"id": "AC-3", "title": "Access Enforcement",
         "description": "Enforce approved authorizations for logical access to information and resources",
         "check": lambda r: not (r.get("permissions", {}).get("read") or r.get("permissions", {}).get("write") or r.get("permissions", {}).get("list")),
         "severity": "HIGH", "remediation": "Restrict all public access; apply IAM policies with least-privilege."},
        {"id": "AC-6", "title": "Least Privilege",
         "description": "Employ principle of least privilege for access to resources",
         "check": lambda r: not r.get("permissions", {}).get("write"),
         "severity": "HIGH", "remediation": "Remove public write access; grant only required PUT permissions to specific IAM roles."},
        {"id": "SC-28", "title": "Protection of Information at Rest",
         "description": "Protect confidentiality of information at rest using encryption",
         "check": lambda r: r.get("securityChecks", {}).get("encryption") in ("enabled", "denied", "unknown"),
         "severity": "HIGH", "remediation": "Enable SSE-S3 (AES-256) or SSE-KMS encryption on all buckets."},
        {"id": "AU-2", "title": "Audit Events",
         "description": "Identify events that require auditing (access logging)",
         "check": lambda r: True,
         "severity": "MEDIUM", "remediation": "Enable S3 server access logging: aws s3api put-bucket-logging --bucket {bucket} --bucket-logging-status '{\"LoggingEnabled\":{\"TargetBucket\":\"my-log-bucket\",\"TargetPrefix\":\"logs/\"}}'"},
        {"id": "SI-4", "title": "System Monitoring",
         "description": "Monitor information system to detect unauthorized access",
         "check": lambda r: True,
         "severity": "MEDIUM", "remediation": "Enable AWS CloudTrail data events for S3."},
    ]

    soc2_controls = [
        {"id": "CC6.1", "title": "Logical Access Security",
         "description": "Restrict logical access to information assets",
         "check": lambda r: not (r.get("permissions", {}).get("read") or r.get("permissions", {}).get("write") or r.get("permissions", {}).get("list")),
         "severity": "CRITICAL", "remediation": "Implement S3 Block Public Access and IAM policies."},
        {"id": "CC6.3", "title": "Access to Data at Rest",
         "description": "Apply authorized access controls and encryption to data stored on systems",
         "check": lambda r: r.get("securityChecks", {}).get("encryption") in ("enabled", "denied", "unknown"),
         "severity": "HIGH", "remediation": "Enable bucket encryption with SSE-S3 or SSE-KMS."},
        {"id": "CC6.6", "title": "Protection Against External Threats",
         "description": "Protect information assets against threats from sources outside the entity",
         "check": lambda r: not r.get("permissions", {}).get("write") and not r.get("securityChecks", {}).get("presignedBypass"),
         "severity": "CRITICAL", "remediation": "Block public write access and presigned URL bypass vulnerabilities."},
        {"id": "CC7.2", "title": "Monitoring of Systems",
         "description": "Monitor system components for anomalies and evaluate for indicators of compromise",
         "check": lambda r: True,
         "severity": "MEDIUM", "remediation": "Enable CloudTrail and S3 access logging."},
    ]

    iso_controls = [
        {"id": "A.8.3", "title": "Information Access Restriction",
         "description": "Prevent unauthorized access to information stored in cloud resources",
         "check": lambda r: not (r.get("permissions", {}).get("read") or r.get("permissions", {}).get("write") or r.get("permissions", {}).get("list")),
         "severity": "CRITICAL", "remediation": "Apply S3 Block Public Access settings and IAM least-privilege policies."},
        {"id": "A.8.24", "title": "Use of Cryptography",
         "description": "Ensure proper use of cryptography to protect confidentiality and integrity of information",
         "check": lambda r: r.get("securityChecks", {}).get("encryption") in ("enabled", "denied", "unknown"),
         "severity": "HIGH", "remediation": "Enable SSE-S3 (AES-256) or SSE-KMS encryption on all S3 buckets."},
        {"id": "A.8.15", "title": "Logging",
         "description": "Produce, store, protect and analyse logs that record activities, exceptions and events",
         "check": lambda r: True,
         "severity": "MEDIUM", "remediation": "Enable S3 server access logging and CloudTrail data events."},
        {"id": "A.8.20", "title": "Networks Security",
         "description": "Secure networks and network services to protect information in systems and applications",
         "check": lambda r: r.get("securityChecks", {}).get("corsOpen") not in ("wildcard", "reflects_origin"),
         "severity": "MEDIUM", "remediation": "Restrict CORS origin configuration; remove wildcard (*) origins."},
    ]

    def evaluate_framework(controls, results_list):
        # Run every control's check against every result; a control FAILs if
        # any bucket fails it (or its check raises). Score is the pass ratio
        # over all (control, bucket) pairs, as a rounded percentage.
        findings = []
        total = len(controls) * max(len(results_list), 1)
        passed = 0
        for ctrl in controls:
            ctrl_failing = []
            ctrl_passed = 0
            for r in results_list:
                try:
                    if ctrl["check"](r):
                        ctrl_passed += 1
                    else:
                        bucket_name = r.get("reference", "unknown")
                        ctrl_failing.append(bucket_name)
                except Exception:
                    # Malformed result data counts as a failure for this control.
                    bucket_name = r.get("reference", "unknown")
                    ctrl_failing.append(bucket_name)
            passed += ctrl_passed
            findings.append({
                "id": ctrl["id"],
                "title": ctrl["title"],
                "description": ctrl["description"],
                "severity": ctrl["severity"],
                "status": "PASS" if not ctrl_failing else "FAIL",
                "failingBuckets": ctrl_failing,
                "remediation": ctrl["remediation"],
            })
        score = round((passed / total) * 100) if total > 0 else 100
        return {"score": score, "controls": findings}

    return JSONResponse(content={
        "cis": evaluate_framework(cis_controls, results),
        "nist": evaluate_framework(nist_controls, results),
        "soc2": evaluate_framework(soc2_controls, results),
        "iso27001": evaluate_framework(iso_controls, results),
    })
|
|
@app.get("/history")
async def get_history():
    """Load scan history (newest first) from MongoDB.

    Returns an empty list when the DB is unavailable or the query fails.
    """
    try:
        if history_collection is None:
            return JSONResponse(content={"history": []})
        cursor = history_collection.find({}).sort("timestamp", -1)
        documents = await cursor.to_list(length=1000)
        formatted = [
            {
                "id": doc.get("id", doc.get("_id")),
                "domain": doc.get("domain", ""),
                "bucketsFound": doc.get("bucketsFound", 0),
                "date": doc.get("date", ""),
                "results": doc.get("results", []),
            }
            for doc in documents
        ]
        return JSONResponse(content={"history": formatted})
    except Exception as e:
        logger.error("Error loading history from MongoDB: %s", e)
        return JSONResponse(content={"history": []})
|
|
@app.post("/history")
async def save_history(data: HistoryData):
    """Replace the persisted scan history with the client-supplied list.

    The collection is cleared and rewritten wholesale; each entry gets a
    server-side UTC timestamp used by the sort in GET /history.
    """
    try:
        if history_collection is None:
            return JSONResponse(content={"error": "Database not connected"}, status_code=503)

        await history_collection.delete_many({})

        # datetime.utcnow() is deprecated (Python 3.12+); use an aware UTC
        # timestamp instead — PyMongo stores both as UTC datetimes.
        now = datetime.now(timezone.utc)
        history_items = [
            {
                "id": item.get("id"),
                "domain": item.get("domain"),
                "bucketsFound": item.get("bucketsFound"),
                "date": item.get("date"),
                "results": item.get("results", []),
                "timestamp": now,
            }
            for item in data.history
        ]

        if history_items:
            await history_collection.insert_many(history_items)

        return JSONResponse(content={"message": "History saved successfully"})
    except Exception as e:
        logger.error("Error saving history: %s", e)
        return JSONResponse(content={"error": "Failed to save history"}, status_code=500)
|
|
|
|
| |
| |
| |
|
|
# NOTE(review): auth-related FastAPI imports; these ideally belong in the
# top-of-file import block rather than mid-module.
from fastapi import Depends, HTTPException, Header
|
|
@app.post("/auth/request-otp")
async def request_otp(data: auth.OTPRequest):
    """Generate a 6-digit OTP, store it in DB with 30s expiry, return it"""
    if database is None:
        raise HTTPException(status_code=503, detail="Database not connected")

    email = data.email.strip().lower()
    if not email:
        raise HTTPException(status_code=400, detail="Email is required")

    # Throttle OTP generation per email before creating anything.
    allowed = await auth.check_otp_rate_limit(database, email)
    if not allowed:
        raise HTTPException(status_code=429, detail="Too many OTP requests. Please wait a few minutes.")

    otp_code = auth.generate_otp_code()
    await auth.store_otp(database, email, otp_code)
    email_sent = await auth.send_otp_email(email, otp_code)

    # When no mail transport is configured, the OTP is handed back directly
    # (demo mode) instead of being delivered by email.
    message = "OTP sent successfully" if email_sent else "OTP generated (Check console/demo mode)"
    return {
        "message": message,
        "demo_otp": None if email_sent else otp_code,
        "expires_in": auth.OTP_EXPIRY_SECONDS,
    }
|
|
@app.post("/auth/verify-otp")
async def verify_otp(data: auth.OTPVerify):
    """Verify OTP (must be within 30s, single-use) and generate UUID session"""
    if database is None:
        raise HTTPException(status_code=503, detail="Database not connected")

    email = data.email.strip().lower()
    otp = data.otp.strip()

    # Throttle brute-force attempts before even touching the stored OTP.
    if not await auth.check_verification_rate_limit(database, email):
        raise HTTPException(status_code=429, detail="Too many failed attempts. Please wait a few minutes.")

    if not await auth.verify_stored_otp(database, email, otp):
        await auth.record_failed_verification(database, email)
        raise HTTPException(status_code=401, detail="Invalid, used, or expired OTP")

    # First successful login auto-provisions the account.
    user = await auth.get_user_by_email(database, email)
    if not user:
        user = await auth.create_user(database, email)

    session_id = await auth.create_user_session(database, user["user_id"])

    return {
        "message": "Authentication successful",
        "session_id": session_id,
        "user": {
            "email": user["email"],
            "name": user.get("name"),
        },
    }
|
|
async def get_current_user(authorization: str = Header(None)):
    """Dependency: validate the Bearer session UUID and return the user.

    Raises HTTPException 401 when the header is missing/malformed or the
    session is expired/unknown.
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid or missing session token")

    # BUG FIX: str.replace() stripped *every* "Bearer " occurrence anywhere in
    # the header value; removeprefix strips only the leading scheme.
    session_id = authorization.removeprefix("Bearer ")
    user = await auth.get_session_user(database, session_id)

    if not user:
        raise HTTPException(status_code=401, detail="Session expired or invalid")

    return user
|
|
@app.get("/auth/me")
async def get_me(current_user: dict = Depends(get_current_user)):
    """Validates session and returns user object"""
    profile = {
        "user_id": current_user["user_id"],
        "email": current_user["email"],
        "name": current_user.get("name"),
    }
    return profile
|
|
|
|
|
|
@app.delete("/history")
async def clear_history():
    """Delete all persisted scan history documents from MongoDB."""
    try:
        if history_collection is None:
            return JSONResponse(content={"status": "warning", "message": "MongoDB not available"})
        await history_collection.delete_many({})
        return JSONResponse(content={"status": "success", "message": "History cleared from MongoDB"})
    except Exception as e:
        logger.error("Error clearing history from MongoDB: %s", e)
        return JSONResponse(content={"status": "error", "message": "Failed to clear history"}, status_code=500)
|
|