Spaces:
Sleeping
Sleeping
| import os | |
| import jwt | |
| import datetime | |
| import bcrypt | |
| from fastapi import FastAPI, HTTPException, Depends, Response, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from dotenv import load_dotenv | |
| from slowapi import Limiter | |
| from slowapi.util import get_remote_address | |
| from slowapi.errors import RateLimitExceeded | |
| from starlette.responses import JSONResponse | |
# Create the FastAPI application object that the rest of this module configures.
app = FastAPI()

# Build a slowapi rate limiter keyed on the caller's remote address and
# attach the same instance to the application state.
app.state.limiter = limiter = Limiter(key_func=get_remote_address)
# Custom response for requests rejected by the rate limiter.
# NOTE(review): no registration of this handler (decorator or
# app.add_exception_handler) is visible in this excerpt; confirm it exists.
async def ratelimit_handler(request: Request, exc: RateLimitExceeded):
    """Build the JSON 429 response returned when a rate limit is exceeded.

    Both arguments are accepted to satisfy the exception-handler signature
    but are not inspected.
    """
    body = {"detail": "Too many login attempts. Try again later."}
    return JSONResponse(content=body, status_code=429)
# Load variables from a .env file (if present) into the process environment.
load_dotenv()

# Key used to sign and verify JWTs.
# NOTE(review): this is None when SECRET_KEY is unset; confirm every
# deployment provides it, otherwise token creation/verification will fail.
SECRET_KEY = os.getenv("SECRET_KEY")

# Token lifetimes, overridable through the environment.
TOKEN_EXPIRATION_MINUTES = int(os.getenv("TOKEN_EXPIRATION_MINUTES", 30))
REFRESH_TOKEN_EXPIRATION_DAYS = int(os.getenv("REFRESH_TOKEN_EXPIRATION_DAYS", 7))

# Comma-separated list of allowed CORS origins. Whitespace around each entry
# is stripped and empty entries are dropped, so a value such as
# "https://a.example, https://b.example" yields two usable origins instead of
# one entry with a leading space that can never match a real Origin header.
ALLOWED_ORIGIN = [
    origin.strip()
    for origin in os.getenv("ALLOWED_ORIGIN", "*").split(",")
    if origin.strip()
]

# The dummy user's API key is read from the environment as plain text and
# replaced by its bcrypt hash right away, so only the hash stays in
# `hashed_password` and in the key table below.
hashed_password = os.getenv("DUMMY_USER_KEY")
if hashed_password:
    hashed_password = bcrypt.hashpw(hashed_password.encode(), bcrypt.gensalt()).decode()

# Fake database of API keys (user id -> bcrypt hash); empty when no key is configured.
API_KEYS_DB = {"user1": hashed_password} if hashed_password else {}
def verify_api_key(api_key: str) -> bool:
    """Report whether ``api_key`` matches any stored bcrypt hash.

    Args:
        api_key: The plain-text key supplied by the caller.

    Returns:
        True if the key matches at least one hash in ``API_KEYS_DB``,
        False otherwise (including when bcrypt refuses to process the input).
    """
    candidate = api_key.encode()
    for hashed_key in API_KEYS_DB.values():
        if not hashed_key:
            continue
        try:
            if bcrypt.checkpw(candidate, hashed_key.encode()):
                return True
        except ValueError:
            # bcrypt raises ValueError for input it will not process (for
            # example, recent releases reject passwords longer than 72
            # bytes). The key comes straight from the caller, so treat that
            # as "no match" instead of letting the request fail with a 500.
            continue
    return False
# Configure CORS for security (allow only trusted frontend).
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGIN,
    # Cookies are only permitted cross-origin when the origin list is
    # explicit. With the wildcard default, enabling credentials would let any
    # site make credentialed requests, which contradicts the intent above.
    allow_credentials="*" not in ALLOWED_ORIGIN,
    allow_methods=["GET", "POST"],
    allow_headers=["Authorization", "Content-Type"],
)
def create_jwt_token(user_id: str, expiration_minutes: int):
    """Generate a signed HS256 JWT for ``user_id``.

    Args:
        user_id: Value stored in the token's ``sub`` claim.
        expiration_minutes: Lifetime of the token, in minutes from now.

    Returns:
        The encoded token produced by ``jwt.encode`` using ``SECRET_KEY``.
    """
    # Use a timezone-aware UTC timestamp: datetime.utcnow() is deprecated and
    # returns a naive value that is easy to misinterpret as local time.
    issued_at = datetime.datetime.now(datetime.timezone.utc)
    expiration = issued_at + datetime.timedelta(minutes=expiration_minutes)
    payload = {"sub": user_id, "exp": expiration}
    return jwt.encode(payload, SECRET_KEY, algorithm="HS256")
def verify_jwt_token(token: str):
    """Decode ``token`` with ``SECRET_KEY`` (HS256) and return its claims.

    Raises:
        HTTPException: 401 "Token expired" when the signature has expired,
            401 "Invalid token" for any other invalid-token error.
    """
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
    except jwt.ExpiredSignatureError:
        # Handled before the general case so expiry gets its own message.
        raise HTTPException(status_code=401, detail="Token expired")
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token")
    return claims
def get_sample_data():
    """Return the static sample payload; no token is checked here."""
    sample = {"message": "This is public sample data."}
    return sample
# Allow max 5 login attempts per minute
# NOTE(review): no rate-limit or route decorator is visible in this excerpt;
# confirm the limit described above is actually applied to this handler.
def login(request: Request, response: Response, api_key: str):
    """Exchange a valid API key for access and refresh tokens set as cookies.

    Raises:
        HTTPException: 403 when ``api_key`` does not match a stored key.

    Returns:
        A confirmation message; the tokens themselves travel only in cookies.
    """
    if not verify_api_key(api_key):
        raise HTTPException(status_code=403, detail="Invalid API key")

    refresh_lifetime_minutes = REFRESH_TOKEN_EXPIRATION_DAYS * 24 * 60
    issued = {
        "access_token": create_jwt_token("user1", TOKEN_EXPIRATION_MINUTES),
        "refresh_token": create_jwt_token("user1", refresh_lifetime_minutes),
    }

    # Both cookies share the same hardening flags: HTTP-only, HTTPS-only
    # (secure), and SameSite=Lax.
    for cookie_name, token in issued.items():
        response.set_cookie(
            key=cookie_name,
            value=token,
            httponly=True,
            secure=True,
            samesite="Lax",
        )
    return {"message": "Login successful"}
def protected_data(request: Request):
    """Protected route: requires a valid access token in the ``access_token`` cookie.

    Raises:
        HTTPException: 401 when the cookie is missing or its decoded claims
            are empty; ``verify_jwt_token`` raises its own 401 for expired or
            invalid tokens.
    """
    token = request.cookies.get("access_token")
    # Only attempt verification when a cookie value is actually present.
    claims = verify_jwt_token(token) if token else None
    if not claims:
        raise HTTPException(status_code=401, detail="Invalid or expired token")
    return {"message": "You are authenticated!"}
def refresh_token(request: Request, response: Response):
    """Issue a new access token for the subject of the refresh-token cookie.

    The new access token is created for the ``sub`` claim carried by the
    refresh token rather than a hard-coded user id, so it always belongs to
    whoever the refresh token was issued to.

    Raises:
        HTTPException: 401 when the cookie is missing or its claims carry no
            subject; ``verify_jwt_token`` raises its own 401 for expired or
            invalid tokens.

    Returns:
        A confirmation message; the new token is set as the ``access_token``
        cookie on ``response``.
    """
    # Use a distinct local name so the function's own name is not shadowed.
    presented_token = request.cookies.get("refresh_token")
    claims = verify_jwt_token(presented_token) if presented_token else None
    subject = claims.get("sub") if claims else None
    if not subject:
        raise HTTPException(status_code=401, detail="Invalid refresh token")

    # Issue new access token
    new_access_token = create_jwt_token(subject, TOKEN_EXPIRATION_MINUTES)
    response.set_cookie(
        key="access_token",
        value=new_access_token,
        httponly=True,
        secure=True,
        samesite="Lax",
    )
    return {"message": "Token refreshed"}
def logout(response: Response):
    """Remove both authentication cookies from the client and confirm."""
    # Access token first, then refresh token.
    for cookie_name in ("access_token", "refresh_token"):
        response.delete_cookie(cookie_name)
    return {"message": "Logged out"}