# Source: app.py, commit 3d963b9 ("Update app.py", verified) by alho94.
import os
import jwt
import datetime
import bcrypt
from fastapi import FastAPI, HTTPException, Depends, Response, Request
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
from slowapi import Limiter
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from starlette.responses import JSONResponse
# FastAPI application that every route and middleware below is attached to.
app = FastAPI()

# Rate limiter keyed on the client's remote address. It is bound both to the
# module name used by @limiter.limit and to app.state (slowapi's documented setup).
app.state.limiter = limiter = Limiter(key_func=get_remote_address)
# Custom error handler for too many requests
@app.exception_handler(RateLimitExceeded)
async def ratelimit_handler(request: Request, exc: RateLimitExceeded):
return JSONResponse(
{"detail": "Too many login attempts. Try again later."},
status_code=429
)
# Load variables from a .env file into os.environ (python-dotenv).
load_dotenv()

# Runtime configuration read from the environment.
# NOTE(review): SECRET_KEY is None when unset; jwt.encode/jwt.decode will then
# likely fail on every request - confirm it is always configured.
SECRET_KEY = os.getenv("SECRET_KEY")
TOKEN_EXPIRATION_MINUTES = int(os.getenv("TOKEN_EXPIRATION_MINUTES", 30))
REFRESH_TOKEN_EXPIRATION_DAYS = int(os.getenv("REFRESH_TOKEN_EXPIRATION_DAYS", 7))

# Comma-separated list of trusted origins, as a list for CORSMiddleware.
# Each entry is stripped so "https://a.example, https://b.example" works (an
# unstripped " https://b.example" would never match an Origin header), and
# empty entries such as those from a trailing comma are dropped.
# NOTE(review): the "*" default is permissive when cookies/credentials are
# allowed; set ALLOWED_ORIGIN explicitly in production.
ALLOWED_ORIGIN = [
    origin.strip()
    for origin in os.getenv("ALLOWED_ORIGIN", "*").split(",")
    if origin.strip()
]
# The dummy user's API key is read in plain text from DUMMY_USER_KEY and is
# bcrypt-hashed immediately, so API_KEYS_DB only ever holds the hash.
# When the variable is unset or empty, no API key is accepted at all.
hashed_password = os.getenv("DUMMY_USER_KEY")
if hashed_password:
    hashed_password = bcrypt.hashpw(hashed_password.encode(), bcrypt.gensalt()).decode()
    # Fake in-memory "database": username -> bcrypt hash of that user's API key.
    API_KEYS_DB = {"user1": hashed_password}
else:
    API_KEYS_DB = {}
def verify_api_key(api_key: str) -> bool:
    """Return True if *api_key* matches any bcrypt hash stored in API_KEYS_DB.

    Args:
        api_key: The plain-text API key supplied by the client.

    Returns:
        True as soon as one stored hash matches; False when none match,
        including when API_KEYS_DB is empty or bcrypt rejects the input.
    """
    candidate = api_key.encode()  # encode once, not per stored hash
    for hashed_key in API_KEYS_DB.values():
        if not hashed_key:
            continue
        try:
            if bcrypt.checkpw(candidate, hashed_key.encode()):
                return True
        except ValueError:
            # bcrypt raises ValueError for a malformed stored hash ("Invalid
            # salt") and, depending on the bcrypt version, for over-long keys.
            # Treat that as "no match" so the caller answers 403 instead of a 500.
            continue
    return False
# Restrict cross-origin browser access to the configured origin(s), allowing
# cookies (credentials) and only the methods/headers this API actually uses.
# NOTE(review): ALLOWED_ORIGIN defaults to ["*"]; combined with
# allow_credentials=True, Starlette may echo back any requesting origin -
# confirm ALLOWED_ORIGIN is set explicitly in production.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=ALLOWED_ORIGIN,
    allow_headers=["Authorization", "Content-Type"],
    allow_methods=["GET", "POST"],
)
def create_jwt_token(user_id: str, expiration_minutes: int):
    """Return an HS256 JWT for *user_id* that expires after *expiration_minutes*.

    The payload contains only the subject (``sub``) and expiry (``exp``) claims
    and is signed with SECRET_KEY.

    NOTE(review): access and refresh tokens are both built here with identical
    claims, so a refresh token placed in the access_token cookie is accepted as
    an access token. Consider adding a token-type claim and checking it per route.
    """
    # Timezone-aware "now": datetime.utcnow() returns a naive datetime and is
    # deprecated since Python 3.12. The resulting "exp" value is the same.
    now = datetime.datetime.now(datetime.timezone.utc)
    expiration = now + datetime.timedelta(minutes=expiration_minutes)
    payload = {"sub": user_id, "exp": expiration}
    return jwt.encode(payload, SECRET_KEY, algorithm="HS256")
def verify_jwt_token(token: str):
    """Decode *token* (HS256, signed with SECRET_KEY) and return its claims.

    Raises:
        HTTPException: 401 "Token expired" when the ``exp`` claim has passed;
            401 "Invalid token" for any other decoding or validation failure.
    """
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
    except jwt.ExpiredSignatureError:
        # Must be caught before InvalidTokenError, which is its base class.
        raise HTTPException(status_code=401, detail="Token expired")
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token")
    return claims
@app.get("/sample-data")
def get_sample_data():
    """Return fixed sample data; this route requires no token."""
    payload = {"message": "This is public sample data."}
    return payload
@app.post("/login")
@limiter.limit("5/minute")  # at most 5 login attempts per minute per client address
def login(request: Request, response: Response, api_key: str):
    """Exchange a valid API key for access and refresh tokens stored in cookies.

    ``request`` is not read here but must stay in the signature because
    slowapi's ``@limiter.limit`` requires it.

    NOTE(review): ``api_key`` is a plain ``str`` parameter, so FastAPI reads it
    from the query string, which often ends up in access logs. Consider moving
    it to the request body or a header.

    Raises:
        HTTPException: 403 when the API key does not match.
    """
    if not verify_api_key(api_key):
        raise HTTPException(status_code=403, detail="Invalid API key")

    issued = (
        ("access_token", create_jwt_token("user1", TOKEN_EXPIRATION_MINUTES)),
        ("refresh_token", create_jwt_token("user1", REFRESH_TOKEN_EXPIRATION_DAYS * 24 * 60)),
    )
    # HttpOnly keeps the tokens away from page JavaScript (XSS); Secure limits
    # them to HTTPS; SameSite=Lax withholds them from most cross-site requests.
    for cookie_name, token in issued:
        response.set_cookie(
            key=cookie_name,
            value=token,
            httponly=True,
            secure=True,
            samesite="Lax",
        )
    return {"message": "Login successful"}
@app.get("/protected-data")
def protected_data(request: Request):
    """Return a confirmation message if the access_token cookie holds a valid JWT.

    Raises:
        HTTPException: 401 when the cookie is missing. verify_jwt_token raises
            its own 401 when the token is expired or invalid.
    """
    cookie_value = request.cookies.get("access_token")
    authenticated = bool(cookie_value) and bool(verify_jwt_token(cookie_value))
    if not authenticated:
        raise HTTPException(status_code=401, detail="Invalid or expired token")
    return {"message": "You are authenticated!"}
@app.post("/refresh-token")
def refresh_token(request: Request, response: Response):
    """Set a fresh access_token cookie, given a valid refresh_token cookie.

    The new access token is issued for the subject (``sub``) of the refresh
    token rather than for a hard-coded user.

    Raises:
        HTTPException: 401 when the refresh-token cookie is missing or carries
            no subject. verify_jwt_token raises its own 401 when the token is
            expired or invalid.
    """
    # Named `token` so the local does not shadow this endpoint function.
    token = request.cookies.get("refresh_token")
    if not token:
        raise HTTPException(status_code=401, detail="Invalid refresh token")

    claims = verify_jwt_token(token)
    subject = claims.get("sub") if claims else None
    if not subject:
        raise HTTPException(status_code=401, detail="Invalid refresh token")

    # Issue the new access token for the same user as the refresh token.
    new_access_token = create_jwt_token(subject, TOKEN_EXPIRATION_MINUTES)
    response.set_cookie(
        key="access_token",
        value=new_access_token,
        httponly=True,
        secure=True,
        samesite="Lax",
    )
    return {"message": "Token refreshed"}
@app.get("/logout")
def logout(response: Response):
    """Log the client out by expiring the access and refresh token cookies.

    NOTE(review): this is a GET route, so another site can trigger it (e.g.
    with an <img> tag). Consider also accepting or switching to POST.
    """
    for cookie_name in ("access_token", "refresh_token"):
        response.delete_cookie(cookie_name)
    return {"message": "Logged out"}