File size: 5,010 Bytes
920cf43
 
 
2702f19
920cf43
 
 
3d963b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
920cf43
 
 
 
 
 
 
 
3d963b9
920cf43
3d963b9
40e25df
ef05b24
 
 
 
2702f19
3d963b9
2702f19
ef05b24
 
 
 
 
 
 
920cf43
 
 
3d963b9
920cf43
 
 
 
 
3d963b9
 
920cf43
3d963b9
920cf43
 
 
 
 
3d963b9
920cf43
 
 
 
 
 
 
 
 
 
 
3d963b9
 
2702f19
 
 
 
3d963b9
 
920cf43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d963b9
920cf43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
import jwt
import datetime
import bcrypt
from fastapi import FastAPI, HTTPException, Depends, Response, Request
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
from slowapi import Limiter
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from starlette.responses import JSONResponse

# Application object that every route, exception handler and middleware below attaches to.
app = FastAPI()

# slowapi expects the limiter on app.state; binding both names in one statement keeps
# the module-level `limiter` (used by @limiter.limit) and the app's copy the same object.
# NOTE(review): get_remote_address keys on the direct peer address; behind a reverse
# proxy every client may share one bucket — confirm the deployment forwards client IPs.
app.state.limiter = limiter = Limiter(key_func=get_remote_address)

# Translate slowapi's rate-limit exception into a JSON 429 response.
@app.exception_handler(RateLimitExceeded)
async def ratelimit_handler(request: Request, exc: RateLimitExceeded):
    """Return a JSON 429 body whenever a rate-limited route rejects a request.

    The wording refers to login attempts because /login is the only route in this
    file decorated with @limiter.limit.
    """
    body = {"detail": "Too many login attempts. Try again later."}
    return JSONResponse(content=body, status_code=429)

# Load environment variables (python-dotenv also reads a .env file if one is found)
# before any of the settings below are read.
load_dotenv()

# Get environment variables
SECRET_KEY = os.getenv("SECRET_KEY")
if not SECRET_KEY:
    # Every JWT is HS256-signed with this key. A missing key would only surface as a
    # 500 on the first login, and an empty key would make tokens trivially forgeable,
    # so refuse to start instead.
    raise RuntimeError("SECRET_KEY environment variable must be set to a non-empty value")
TOKEN_EXPIRATION_MINUTES = int(os.getenv("TOKEN_EXPIRATION_MINUTES", 30))
REFRESH_TOKEN_EXPIRATION_DAYS = int(os.getenv("REFRESH_TOKEN_EXPIRATION_DAYS", 7))
# Comma-separated list of allowed CORS origins, e.g. "https://a.example, https://b.example".
# Whitespace around each entry is stripped and empty entries are dropped so that
# origins written with ", " still match exactly.
ALLOWED_ORIGIN = [
    origin.strip()
    for origin in os.getenv("ALLOWED_ORIGIN", "*").split(",")
    if origin.strip()
]

# Load the dummy user's API key. DUMMY_USER_KEY holds the *plain-text* key; it is
# bcrypt-hashed here so the value kept in API_KEYS_DB is a hash, not the key itself.
# NOTE(review): a fresh salt is generated on every start, so the stored hash changes
# per process; verification still works because bcrypt embeds the salt in the hash.
hashed_password = os.getenv("DUMMY_USER_KEY")
if hashed_password:
    hashed_password = bcrypt.hashpw(hashed_password.encode(), bcrypt.gensalt()).decode()

# Fake database of API keys: user id -> bcrypt hash. Empty when DUMMY_USER_KEY is unset
# or empty, in which case no API key can log in.
API_KEYS_DB = {"user1": hashed_password} if hashed_password else {}

def verify_api_key(api_key: str) -> bool:
    """Check whether ``api_key`` matches any bcrypt hash stored in API_KEYS_DB.

    Args:
        api_key: Candidate key supplied by the client.

    Returns:
        True as soon as one stored hash matches; False if none match, the
        database is empty, or bcrypt rejects the input.
    """
    candidate = api_key.encode()  # encode once rather than once per stored hash
    for hashed_key in API_KEYS_DB.values():
        if not hashed_key:
            continue
        try:
            matched = bcrypt.checkpw(candidate, hashed_key.encode())
        except ValueError:
            # bcrypt raises ValueError for a malformed stored hash, and newer releases
            # also raise it for inputs longer than 72 bytes (older ones truncate).
            # Treat either case as "no match" instead of letting it become a 500.
            continue
        if matched:
            return True
    return False

# CORS: browsers may call this API cross-origin, with cookies, only from ALLOWED_ORIGIN.
# NOTE(review): ALLOWED_ORIGIN defaults to "*". Starlette's documentation says wildcard
# origins should not be combined with allow_credentials=True — set ALLOWED_ORIGIN
# explicitly in production.
_cors_options = {
    "allow_origins": ALLOWED_ORIGIN,
    "allow_credentials": True,  # needed so the auth cookies travel with requests
    "allow_methods": ["GET", "POST"],
    "allow_headers": ["Authorization", "Content-Type"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

def create_jwt_token(user_id: str, expiration_minutes: int, token_type: str = "access"):
    """Generate a signed HS256 JWT for ``user_id``.

    Args:
        user_id: Value stored in the ``sub`` claim.
        expiration_minutes: Token lifetime; ``exp`` is set that many minutes from now.
        token_type: Value of the custom ``type`` claim ("access" or "refresh").
            Endpoints check it so a token minted for one purpose cannot be used
            for the other (e.g. the long-lived refresh token as an access token).

    Returns:
        The encoded token, as returned by ``jwt.encode``.
    """
    # Timezone-aware UTC "now"; datetime.utcnow() is deprecated since Python 3.12.
    now = datetime.datetime.now(datetime.timezone.utc)
    payload = {
        "sub": user_id,
        "type": token_type,
        "exp": now + datetime.timedelta(minutes=expiration_minutes),
    }
    return jwt.encode(payload, SECRET_KEY, algorithm="HS256")

def verify_jwt_token(token: str, expected_type=None):
    """Verify and decode a JWT produced by create_jwt_token.

    Args:
        token: Encoded JWT, as read from an auth cookie.
        expected_type: If given ("access" or "refresh"), the token's ``type``
            claim must equal it; ``None`` skips the check.

    Returns:
        The decoded claims dict.

    Raises:
        HTTPException: 401 "Token expired" when ``exp`` has passed; 401
            "Invalid token" for a bad signature, a malformed token, a missing
            ``exp``/``sub`` claim, or a ``type`` mismatch.
    """
    try:
        payload = jwt.decode(
            token,
            SECRET_KEY,
            algorithms=["HS256"],
            # Every caller relies on these claims; reject tokens that lack them.
            options={"require": ["exp", "sub"]},
        )
    except jwt.ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="Token expired")
    except jwt.InvalidTokenError:
        raise HTTPException(status_code=401, detail="Invalid token")
    if expected_type is not None and payload.get("type") != expected_type:
        raise HTTPException(status_code=401, detail="Invalid token")
    return payload

@app.get("/sample-data")
def get_sample_data():
    """Public endpoint: No token required"""
    return {"message": "This is public sample data."}

def _set_auth_cookie(response: Response, key: str, value: str) -> None:
    """Attach an auth token cookie with the hardening flags shared by all auth cookies."""
    response.set_cookie(
        key=key,
        value=value,
        httponly=True,  # not readable from JavaScript, limiting token theft via XSS
        secure=True,  # only sent over HTTPS; ensure HTTPS is used in production
        samesite="Lax",
    )

@app.post("/login")
@limiter.limit("5/minute")  # Allow max 5 login attempts per minute
def login(request: Request, response: Response, api_key: str):
    """Exchange a valid API key for an access token and a refresh token.

    The tokens are delivered only as HTTP-only cookies; the JSON body just
    confirms success. ``request`` is unused here but required by slowapi's
    ``@limiter.limit`` decorator.

    Raises:
        HTTPException: 403 if the key matches no stored hash.
    """
    # NOTE(review): api_key arrives as a query parameter, so it can be recorded in
    # access logs, proxies and browser history — consider a header or request body.
    if not verify_api_key(api_key):
        raise HTTPException(status_code=403, detail="Invalid API key")

    # API_KEYS_DB only ever holds "user1", and verify_api_key returns a bool rather
    # than the matching user id, so the subject is fixed.
    access_jwt = create_jwt_token("user1", TOKEN_EXPIRATION_MINUTES, token_type="access")
    refresh_jwt = create_jwt_token(
        "user1", REFRESH_TOKEN_EXPIRATION_DAYS * 24 * 60, token_type="refresh"
    )

    _set_auth_cookie(response, "access_token", access_jwt)
    _set_auth_cookie(response, "refresh_token", refresh_jwt)
    return {"message": "Login successful"}

@app.get("/protected-data")
def protected_data(request: Request):
    """Protected route: requires a valid *access* token in the HTTP-only cookie.

    Raises:
        HTTPException: 401 when the cookie is missing, or when the token is
            expired, invalid, or is a refresh token.
    """
    access_token = request.cookies.get("access_token")
    if not access_token:
        raise HTTPException(status_code=401, detail="Invalid or expired token")
    # Raises 401 itself on failure; the refresh token must not work here.
    verify_jwt_token(access_token, expected_type="access")
    return {"message": "You are authenticated!"}

@app.post("/refresh-token")
def refresh_token(request: Request, response: Response):
    """Issue a new access-token cookie from a valid refresh-token cookie.

    The new token carries the same subject as the refresh token.

    Raises:
        HTTPException: 401 when the refresh cookie is missing, or when the token
            is expired, invalid, or is an access token.
    """
    # Named `token` so it does not shadow this endpoint function's own name.
    token = request.cookies.get("refresh_token")
    if not token:
        raise HTTPException(status_code=401, detail="Invalid refresh token")
    # Only a refresh token may mint new access tokens; otherwise a short-lived
    # access token could be refreshed indefinitely.
    claims = verify_jwt_token(token, expected_type="refresh")

    new_access_token = create_jwt_token(
        claims["sub"], TOKEN_EXPIRATION_MINUTES, token_type="access"
    )
    _set_auth_cookie(response, "access_token", new_access_token)
    return {"message": "Token refreshed"}

@app.get("/logout")
def logout(response: Response):
    """Expire both authentication cookies and confirm the logout.

    NOTE(review): logging out changes state but is exposed as GET, so any
    cross-site link can trigger it; consider accepting POST as well.
    """
    # Same deletion order as before: access cookie first, then refresh cookie.
    for cookie_name in ("access_token", "refresh_token"):
        response.delete_cookie(cookie_name)
    return {"message": "Logged out"}