Spaces:
Sleeping
Sleeping
File size: 5,010 Bytes
920cf43 2702f19 920cf43 3d963b9 920cf43 3d963b9 920cf43 3d963b9 40e25df ef05b24 2702f19 3d963b9 2702f19 ef05b24 920cf43 3d963b9 920cf43 3d963b9 920cf43 3d963b9 920cf43 3d963b9 920cf43 3d963b9 2702f19 3d963b9 920cf43 3d963b9 920cf43 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 | import os
import jwt
import datetime
import bcrypt
from fastapi import FastAPI, HTTPException, Depends, Response, Request
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
from slowapi import Limiter
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from starlette.responses import JSONResponse
# Rate limiter that buckets requests by the address returned from get_remote_address.
limiter = Limiter(key_func=get_remote_address)

# The FastAPI application. The limiter is attached to app.state, presumably
# where slowapi's @limiter.limit decorator expects to find it at request time.
app = FastAPI()
app.state.limiter = limiter
# Custom error handler for too many requests
@app.exception_handler(RateLimitExceeded)
async def ratelimit_handler(request: Request, exc: RateLimitExceeded):
return JSONResponse(
{"detail": "Too many login attempts. Try again later."},
status_code=429
)
# Load environment variables via python-dotenv before any os.getenv call below.
load_dotenv()

# HMAC key used to sign and verify every JWT (HS256).
SECRET_KEY = os.getenv("SECRET_KEY")
if not SECRET_KEY:
    # Fail at startup rather than on the first login or token check, where a
    # missing key would surface only as an obscure encode/decode error (HTTP 500).
    raise RuntimeError("SECRET_KEY environment variable must be set")

# Token lifetimes (minutes for access tokens, days for refresh tokens).
TOKEN_EXPIRATION_MINUTES = int(os.getenv("TOKEN_EXPIRATION_MINUTES", 30))
REFRESH_TOKEN_EXPIRATION_DAYS = int(os.getenv("REFRESH_TOKEN_EXPIRATION_DAYS", 7))

# Comma-separated list of CORS origins. Whitespace around each entry and empty
# entries are dropped, so "https://a.example, https://b.example" yields two
# exact origin strings that the CORS middleware can actually match.
ALLOWED_ORIGIN = [
    origin.strip()
    for origin in os.getenv("ALLOWED_ORIGIN", "*").split(",")
    if origin.strip()
]
# Optional single demo credential read from DUMMY_USER_KEY.
# When the variable is set and non-empty, its value is bcrypt-hashed with a fresh
# salt and registered for "user1"; otherwise no API key is accepted at all.
# NOTE(review): the value is hashed unconditionally, so if DUMMY_USER_KEY already
# holds a bcrypt hash, clients would have to submit that hash string as their key.
# The plaintext also still remains in the process environment.
hashed_password = os.getenv("DUMMY_USER_KEY")
if hashed_password:
    hashed_password = bcrypt.hashpw(
        hashed_password.encode(),
        bcrypt.gensalt(),
    ).decode()
    API_KEYS_DB = {"user1": hashed_password}
else:
    API_KEYS_DB = {}
def verify_api_key(api_key: str) -> bool:
    """Report whether ``api_key`` matches any bcrypt hash in ``API_KEYS_DB``.

    Empty/falsy stored hashes are skipped. Checking stops at the first match.

    Args:
        api_key: Plaintext key supplied by the client.

    Returns:
        True if some stored hash verifies against ``api_key``, else False.
    """
    return any(
        bcrypt.checkpw(api_key.encode(), stored_hash.encode())
        for stored_hash in API_KEYS_DB.values()
        if stored_hash
    )
# Configure CORS so that only trusted frontends may call the API with cookies.
# Starlette's CORSMiddleware echoes the caller's Origin back when "*" is combined
# with allow_credentials=True. That would let any website send cookie-authenticated
# requests (e.g. to /protected-data) and read the responses. Credentials are
# therefore allowed only when ALLOWED_ORIGIN lists explicit origins. Deployments
# relying on cross-origin cookies must set ALLOWED_ORIGIN to their frontend URL(s).
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGIN,
    allow_credentials="*" not in ALLOWED_ORIGIN,
    allow_methods=["GET", "POST"],
    allow_headers=["Authorization", "Content-Type"],
)
def create_jwt_token(user_id: str, expiration_minutes: int):
    """Return an HS256-signed JWT for ``user_id``.

    The payload carries only the ``sub`` (subject) and ``exp`` (expiry) claims
    and is signed with ``SECRET_KEY``.

    Args:
        user_id: Value stored in the ``sub`` claim.
        expiration_minutes: Lifetime of the token, counted from now.

    Returns:
        The encoded token as returned by ``jwt.encode``.

    NOTE(review): /login mints both the access and the refresh token through this
    function with identical claims, so the two differ only in lifetime and either
    one verifies wherever the other does. Consider adding a token-type claim.
    """
    # Timezone-aware UTC "now": datetime.utcnow() returns a naive datetime and is
    # deprecated since Python 3.12.
    now = datetime.datetime.now(datetime.timezone.utc)
    expiration = now + datetime.timedelta(minutes=expiration_minutes)
    payload = {"sub": user_id, "exp": expiration}
    return jwt.encode(payload, SECRET_KEY, algorithm="HS256")
def verify_jwt_token(token: str):
    """Decode ``token`` and return its claims, or raise a 401 HTTPException.

    Only HS256 signatures made with ``SECRET_KEY`` are accepted.

    Raises:
        HTTPException: 401 with detail "Token expired" for an expired token,
            or "Invalid token" for any other decoding/validation failure.
    """
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
    except jwt.InvalidTokenError as err:
        # ExpiredSignatureError is the more specific subclass; give it its own message.
        if isinstance(err, jwt.ExpiredSignatureError):
            reason = "Token expired"
        else:
            reason = "Invalid token"
        raise HTTPException(status_code=401, detail=reason)
    return claims
@app.get("/sample-data")
def get_sample_data():
    """Public endpoint: returns static demo data without checking any token."""
    payload = {"message": "This is public sample data."}
    return payload
@app.post("/login")
@limiter.limit("5/minute")  # Allow max 5 login attempts per minute
def login(request: Request, response: Response, api_key: str):
    """Exchange a valid API key for access- and refresh-token cookies.

    Both tokens are issued for the fixed subject ``"user1"``. Each is set as an
    HTTP-only, Secure, SameSite=Lax cookie whose ``max_age`` mirrors the token's
    lifetime. As a result, the refresh cookie outlives a browser restart for as
    long as the refresh token itself is valid.

    Args:
        request: Unused in the body. Presumably required in the signature by
            slowapi's ``limiter.limit`` decorator; confirm before removing.
        response: Response on which the auth cookies are set.
        api_key: Plaintext API key checked with ``verify_api_key``.
            NOTE(review): as a bare ``str`` parameter FastAPI likely reads this
            from the query string, so keys may appear in access logs.

    Returns:
        ``{"message": "Login successful"}``.

    Raises:
        HTTPException: 403 when the API key does not match.
    """
    if not verify_api_key(api_key):
        raise HTTPException(status_code=403, detail="Invalid API key")

    access_lifetime_minutes = TOKEN_EXPIRATION_MINUTES
    refresh_lifetime_minutes = REFRESH_TOKEN_EXPIRATION_DAYS * 24 * 60
    # Distinct local names avoid shadowing the module-level refresh_token endpoint.
    access_jwt = create_jwt_token("user1", access_lifetime_minutes)
    refresh_jwt = create_jwt_token("user1", refresh_lifetime_minutes)

    # HTTP-only cookies are not readable from JavaScript, limiting token theft via XSS.
    response.set_cookie(
        key="access_token",
        value=access_jwt,
        max_age=access_lifetime_minutes * 60,  # seconds; mirrors the token lifetime
        httponly=True,
        secure=True,  # Ensure HTTPS is used in production
        samesite="Lax",
    )
    response.set_cookie(
        key="refresh_token",
        value=refresh_jwt,
        max_age=refresh_lifetime_minutes * 60,  # seconds; mirrors the token lifetime
        httponly=True,
        secure=True,
        samesite="Lax",
    )
    return {"message": "Login successful"}
@app.get("/protected-data")
def protected_data(request: Request):
    """Protected route: succeeds only with a valid ``access_token`` cookie.

    An expired or malformed token makes ``verify_jwt_token`` raise its own 401
    ("Token expired" / "Invalid token"). The "Invalid or expired token" detail
    below is reached when the cookie is absent/empty or the decoded claims are falsy.
    """
    token = request.cookies.get("access_token")
    if token and verify_jwt_token(token):
        return {"message": "You are authenticated!"}
    raise HTTPException(status_code=401, detail="Invalid or expired token")
@app.post("/refresh-token")
def refresh_token(request: Request, response: Response):
    """Issue a fresh access-token cookie from a valid ``refresh_token`` cookie.

    The new access token is minted for the subject (``sub``) carried by the
    verified refresh token, rather than for a hard-coded user. Its cookie uses
    the same HTTP-only/Secure/SameSite=Lax attributes as at login, with
    ``max_age`` matching the access-token lifetime.

    NOTE(review): any token signed with SECRET_KEY passes verification here,
    including an access token, since create_jwt_token adds no token-type claim.

    Returns:
        ``{"message": "Token refreshed"}``.

    Raises:
        HTTPException: 401 if the cookie is missing, the token is expired or
            invalid (raised by ``verify_jwt_token``), or it carries no subject.
    """
    # Named to avoid shadowing this function's own module-level name.
    presented_token = request.cookies.get("refresh_token")
    if not presented_token:
        raise HTTPException(status_code=401, detail="Invalid refresh token")

    claims = verify_jwt_token(presented_token)
    subject = claims.get("sub") if claims else None
    if not subject:
        raise HTTPException(status_code=401, detail="Invalid refresh token")

    # Issue new access token for the same user the refresh token was issued to.
    new_access_token = create_jwt_token(subject, TOKEN_EXPIRATION_MINUTES)
    response.set_cookie(
        key="access_token",
        value=new_access_token,
        max_age=TOKEN_EXPIRATION_MINUTES * 60,  # seconds; mirrors the token lifetime
        httponly=True,
        secure=True,
        samesite="Lax",
    )
    return {"message": "Token refreshed"}
@app.get("/logout")
def logout(response: Response):
    """Log out by deleting both authentication cookies on the response."""
    for cookie_name in ("access_token", "refresh_token"):
        response.delete_cookie(cookie_name)
    return {"message": "Logged out"}