Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel | |
| from typing import List | |
| import jwt | |
| import datetime | |
| from dotenv import load_dotenv | |
| import os | |
| load_dotenv() | |
# FastAPI app setup
app = FastAPI()

# CORS setup to allow cross-origin requests.
# A wildcard origin must not be combined with credentialed requests: browsers
# reject "Access-Control-Allow-Origin: *" when credentials are sent, and
# accepting every origin with credentials would let any site make
# authenticated calls. Allowed origins are read from the comma-separated
# ALLOWED_ORIGINS environment variable (default "*"); credentials are enabled
# only when explicit origins have been configured.
ALLOWED_ORIGINS = [
    origin.strip()
    for origin in os.getenv("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials="*" not in ALLOWED_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)
# JWT signing configuration.
# The signing key is taken from the SECRET_KEY environment variable; the
# "secret" fallback is only suitable for local development.
SECRET_KEY = os.environ.get("SECRET_KEY", "secret")
ALGORITHM = "HS256"
# Default token lifetime, in minutes.
ACCESS_TOKEN_EXPIRE_MINUTES = 30
# In-memory credential store mapping username -> password.
# NOTE(review): passwords are kept in plaintext and live only for the
# lifetime of the process.
USERS = dict(
    user1="password1",
    user2="password2",
)
# Pydantic schema describing a user's credentials
class User(BaseModel):
    """Credentials payload: a username paired with its password."""

    username: str  # account identifier
    password: str  # password as submitted by the client
# JWT token generation function
def create_access_token(data: dict, expires_delta: datetime.timedelta = None):
    """Create a new signed JWT.

    Args:
        data: Claims to embed in the token. The mapping is copied, so the
            caller's dict is never modified.
        expires_delta: Lifetime of the token. ``None`` (the default) selects
            ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes. Any explicit value,
            including a zero delta, is honoured as given.

    Returns:
        The encoded token produced by ``jwt.encode`` using ``SECRET_KEY`` and
        ``ALGORITHM``.
    """
    # Test against None instead of truthiness: timedelta(0) is falsy and would
    # otherwise be silently replaced by the default lifetime.
    if expires_delta is None:
        expires_delta = datetime.timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # Use a timezone-aware "now"; datetime.utcnow() returns a naive value and
    # is deprecated since Python 3.12.
    expire = datetime.datetime.now(datetime.timezone.utc) + expires_delta
    to_encode = {**data, "exp": expire}
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
# JWT token verification function
def verify_access_token(token: str):
    """Verify and decode the JWT token.

    Args:
        token: Encoded JWT, checked against ``SECRET_KEY`` and ``ALGORITHM``.

    Returns:
        The decoded claims payload.

    Raises:
        HTTPException: 401 when the token has expired or fails validation;
            400 for any other unexpected error raised while decoding.
    """
    try:
        return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.ExpiredSignatureError as exc:
        # Must come before InvalidTokenError, of which it is a subclass.
        raise HTTPException(status_code=401, detail="Token has expired") from exc
    except jwt.InvalidTokenError as exc:
        # InvalidTokenError is the base class of PyJWT's validation failures
        # (DecodeError included), so every rejected token yields a 401 rather
        # than only malformed ones.
        raise HTTPException(status_code=401, detail="Token is invalid") from exc
    except Exception as e:
        # Catch any other exception and raise a generic HTTP 400 error
        raise HTTPException(status_code=400, detail=f"An error occurred: {str(e)}") from e
# Register a new user
async def register_user(user: User):
    """Store the supplied credentials in USERS unless the username is taken.

    Raises:
        HTTPException: 400 when the username is already present in USERS.
    """
    name = user.username
    already_taken = name in USERS
    if already_taken:
        raise HTTPException(status_code=400, detail="Username already exists")
    USERS[name] = user.password
    return {"message": "User registered successfully"}
# Login a user and generate a JWT session token
async def login_user(user: User):
    """Authenticate a user and return a JWT token.

    Args:
        user: Credentials to check against the USERS store.

    Returns:
        ``{"token": <jwt>}`` whose ``sub`` claim is the username.

    Raises:
        HTTPException: 401 when the username is unknown or the password does
            not match.
    """
    import hmac  # local import: only this function needs it

    stored_password = USERS.get(user.username)
    # Compare in constant time so response timing does not reveal how much of
    # the password matched. Bytes are compared because compare_digest rejects
    # non-ASCII str arguments. The comparison also runs for unknown users
    # (against an empty value) to keep both failure paths similar.
    expected = (stored_password if stored_password is not None else "").encode("utf-8")
    supplied = user.password.encode("utf-8")
    passwords_match = hmac.compare_digest(supplied, expected)
    if stored_password is None or not passwords_match:
        raise HTTPException(status_code=401, detail="Invalid credentials")
    # Create JWT token
    access_token = create_access_token(data={"sub": user.username})
    return {"token": access_token}
# API to validate the session token
async def validate_token(token: str):
    """Check a JWT and report the username stored in its ``sub`` claim.

    Verification is delegated to ``verify_access_token``; any exception it
    raises propagates unchanged.
    """
    claims = verify_access_token(token)
    subject = claims.get("sub")
    return {"message": "Token is valid", "username": subject}
# Search for users by username
async def search_users(query: str) -> List[str]:
    """Search for users whose username contains ``query`` (case-insensitive).

    Args:
        query: Substring to look for.

    Returns:
        The matching usernames, in USERS iteration order.

    Raises:
        HTTPException: 404 when no username matches.
    """
    # Lower-case the query once instead of once per stored username.
    needle = query.lower()
    matching_users = [username for username in USERS if needle in username.lower()]
    if not matching_users:
        raise HTTPException(status_code=404, detail="No users found matching the query")
    return matching_users
# API routes for CRUD operations
async def get_all_users() -> List[str]:
    """Return every username currently held in USERS."""
    # Iterating a dict yields its keys, so list() gives the same result as
    # list(USERS.keys()).
    return list(USERS)
async def get_user(username: str):
    """Return the stored record for ``username``.

    Raises:
        HTTPException: 404 when the username is not in USERS.
    """
    try:
        stored_password = USERS[username]
    except KeyError:
        raise HTTPException(status_code=404, detail="User not found") from None
    # NOTE(review): the response includes the stored password — confirm that
    # callers really need it before exposing this function.
    return {"username": username, "password": stored_password}
async def delete_user(username: str):
    """Remove ``username`` from USERS.

    Raises:
        HTTPException: 404 when the username is not in USERS.
    """
    if username in USERS:
        del USERS[username]
        return {"message": f"User {username} deleted successfully"}
    raise HTTPException(status_code=404, detail="User not found")
async def update_user(username: str, user: User):
    """Replace the stored password of ``username`` with ``user.password``.

    Only the password field of ``user`` is read; the target account is chosen
    by the ``username`` argument.

    Raises:
        HTTPException: 404 when the username is not in USERS.
    """
    user_known = username in USERS
    if not user_known:
        raise HTTPException(status_code=404, detail="User not found")
    new_password = user.password
    USERS[username] = new_password
    return {"message": f"User {username} password updated successfully"}
async def user_exists(username: str):
    """Report whether ``username`` is present in USERS.

    Raises:
        HTTPException: 404 when the username is not in USERS.
    """
    if username not in USERS:
        raise HTTPException(status_code=404, detail="User not found")
    return {"message": f"User '{username}' exists"}
# Global error handling middleware
async def custom_error_handler(request: Request, call_next):
    """Run the downstream handler and convert escaping exceptions to JSON.

    Args:
        request: The incoming request.
        call_next: Awaitable callable that processes the request and returns
            the response.

    Returns:
        The downstream response, or a JSONResponse carrying a ``detail``
        field: the HTTPException's own status and detail, or a generic 500
        for any other exception.
    """
    import logging  # local import: only this function needs it

    try:
        return await call_next(request)
    except HTTPException as exc:
        return JSONResponse(
            status_code=exc.status_code,
            content={"detail": exc.detail},
        )
    except Exception:
        # Record the traceback before hiding the failure behind a generic
        # response; otherwise the error would vanish without a trace.
        logging.getLogger(__name__).exception(
            "Unhandled error while processing %s %s",
            request.method,
            request.url.path,
        )
        return JSONResponse(
            status_code=500,
            content={"detail": "Internal Server Error"},
        )