Spaces:
Sleeping
Sleeping
| """ | |
| auth.py – Lightweight authentication micro-service | |
| ---------------------------------------------------- | |
| • Exposes /auth/verify_user (login check) | |
| • Optional /auth/register_user for sign-up | |
| • Uses MySQL, bcrypt, and Pydantic | |
| • No Streamlit, cookies, or front-end logic | |
| """ | |
| from fastapi import FastAPI, Request, Response, Cookie, HTTPException, Depends | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.security import OAuth2PasswordRequestForm | |
| from jose import jwt, JWTError | |
| from pydantic import BaseModel | |
| from dotenv import load_dotenv | |
| import os | |
| import time | |
| import requests | |
| import os | |
| import pymysql | |
| import bcrypt | |
| import sys | |
| import importlib.util | |
| from services.api.db.token_utils import create_token, decode_token | |
| from sqlalchemy.orm import Session | |
| from contextlib import asynccontextmanager | |
| # Add parent directory to path to enable imports | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| # Define lifespan context manager for startup/shutdown events | |
| async def lifespan(app: FastAPI): | |
| # Startup | |
| print("Starting up FastAPI application...") | |
| # Start background tasks from upload endpoints | |
| try: | |
| # Import and start upload cleanup task | |
| from services.api.upload_endpoints import start_cleanup_task | |
| cleanup_task = start_cleanup_task() | |
| print("Upload cleanup task started successfully") | |
| except Exception as e: | |
| print(f"Failed to start upload cleanup task: {e}") | |
| yield | |
| # Shutdown | |
| print("Shutting down FastAPI application...") | |
| # Any cleanup can be done here | |
| # Create single FastAPI instance with lifespan | |
| app = FastAPI(lifespan=lifespan) | |
| # Configure CORS with all needed origins | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=[ | |
| "http://localhost:8503", | |
| "http://localhost:3000", | |
| "http://127.0.0.1:3000", | |
| "https://tlong-ds.github.io", | |
| "https://tlong-ds.github.io/thelearninghouse/", | |
| "https://*.hf.space", | |
| "https://*.huggingface.co" | |
| ], | |
| allow_credentials=True, | |
| allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], | |
| allow_headers=["*"], | |
| expose_headers=["*"] | |
| ) | |
| # Import API endpoints | |
| try: | |
| from services.api.api_endpoints import router as api_router | |
| except ImportError: | |
| # Try a different import strategy for direct module execution | |
| spec = importlib.util.spec_from_file_location( | |
| "api_endpoints", | |
| os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "api_endpoints.py") | |
| ) | |
| api_endpoints = importlib.util.module_from_spec(spec) | |
| spec.loader.exec_module(api_endpoints) | |
| api_router = api_endpoints.router | |
| # Load environment variables | |
| load_dotenv() | |
| SECRET_KEY = os.getenv("SECRET_TOKEN", "dev-secret") | |
| ALGORITHM = os.getenv("ALGORITHM", "HS256") | |
| MYSQL_USER = os.getenv("MYSQL_USER") | |
| MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD") | |
| MYSQL_HOST = os.getenv("MYSQL_HOST", "localhost") | |
| MYSQL_DB = os.getenv("MYSQL_DB") | |
| MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306")) | |
| # Import and include chat router | |
| try: | |
| from services.api.chat_endpoints import router as chat_router | |
| except ImportError: | |
| spec = importlib.util.spec_from_file_location( | |
| "chat_endpoints", | |
| os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "chat_endpoints.py") | |
| ) | |
| chat_endpoints = importlib.util.module_from_spec(spec) | |
| spec.loader.exec_module(chat_endpoints) | |
| chat_router = chat_endpoints.router | |
| # Import and include upload router | |
| try: | |
| from services.api.upload_endpoints import router as upload_router | |
| except ImportError: | |
| spec = importlib.util.spec_from_file_location( | |
| "upload_endpoints", | |
| os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "upload_endpoints.py") | |
| ) | |
| upload_endpoints = importlib.util.module_from_spec(spec) | |
| spec.loader.exec_module(upload_endpoints) | |
| upload_router = upload_endpoints.router | |
| # Debug print of available routes | |
| print("DEBUG: Available routes in api_router:") | |
| for route in api_router.routes: | |
| print(f"Route: {route.path}, Methods: {route.methods}") | |
| # Include API routers | |
| app.include_router(api_router, prefix="/api") | |
| app.include_router(chat_router, prefix="/api") | |
| app.include_router(upload_router, prefix="/api") | |
| class LoginPayload(BaseModel): | |
| username: str | |
| password: str | |
| role: str | |
| class PasswordChangePayload(BaseModel): | |
| current_password: str | |
| new_password: str | |
| class PasswordChangePayload(BaseModel): | |
| current_password: str | |
| new_password: str | |
| # Token functions moved to token_utils.py | |
| async def login(response: Response, payload: LoginPayload): | |
| try: | |
| print(f"Login attempt: {payload.username}, role: {payload.role}") | |
| # Directly verify the user credentials | |
| table_map = { | |
| "Learner": "Learners", | |
| "Instructor": "Instructors" | |
| } | |
| table = table_map.get(payload.role) | |
| if not table: | |
| print(f"Invalid role: {payload.role}") | |
| raise HTTPException(status_code=400, detail="Invalid role") | |
| conn = connect_db() | |
| user_id = None | |
| full_name = None | |
| try: | |
| with conn.cursor() as cur: | |
| # Get password, user ID, and full name for token creation | |
| if payload.role == "Learner": | |
| query = f"SELECT Password, LearnerID, LearnerName FROM {table} WHERE AccountName=%s LIMIT 1" | |
| else: # Instructor | |
| query = f"SELECT Password, InstructorID, InstructorName FROM {table} WHERE AccountName=%s LIMIT 1" | |
| print(f"Executing query: {query} with username: {payload.username}") | |
| cur.execute(query, (payload.username,)) | |
| row = cur.fetchone() | |
| if not row: | |
| print(f"No user found with username: {payload.username} in table: {table}") | |
| raise HTTPException(status_code=401, detail="Incorrect username or password") | |
| password_valid = check_password(payload.password, row[0]) | |
| user_id = row[1] # Get the LearnerID or InstructorID | |
| full_name = row[2] # Get the LearnerName or InstructorName | |
| print(f"Password check result: {password_valid}, User ID: {user_id}, Full Name: {full_name}") | |
| if not password_valid: | |
| print(f"Authentication failed: Invalid password") | |
| raise HTTPException(status_code=401, detail="Incorrect username or password") | |
| except Exception as db_err: | |
| print(f"Database error: {str(db_err)}") | |
| raise HTTPException(status_code=500, detail=f"Database error: {str(db_err)}") | |
| finally: | |
| conn.close() | |
| # User authenticated successfully - include user_id and full name in token | |
| user_data = { | |
| "username": payload.username, | |
| "role": payload.role, | |
| "user_id": user_id, | |
| "full_name": full_name | |
| } | |
| print(f"Authentication successful for: {payload.username} (ID: {user_id}, Name: {full_name})") | |
| token = create_token(user_data) | |
| # Set cookie with settings that work for both Chrome and Safari | |
| # For localhost development, we need different settings than production | |
| is_localhost = os.getenv("ENVIRONMENT", "development") == "development" | |
| response.set_cookie( | |
| key="auth_token", | |
| value=token, | |
| httponly=False, # Allow JavaScript access for localStorage fallback | |
| samesite="Lax" if is_localhost else "None", # Lax for localhost, None for cross-origin | |
| secure=False if is_localhost else True, # False for HTTP localhost, True for HTTPS production | |
| path="/", | |
| max_age=604800, # 7 days | |
| domain=None # Let browser set the domain automatically | |
| ) | |
| return { | |
| "message": f"Login successful for {user_data['username']}", | |
| "username": user_data["username"], | |
| "role": user_data["role"], | |
| "user_id": user_data["user_id"], | |
| "full_name": user_data["full_name"], | |
| "token": token | |
| } | |
| except Exception as e: | |
| print(f"Login exception: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Login error: {str(e)}") | |
| def logout(response: Response): | |
| response.delete_cookie("auth_token") | |
| return HTMLResponse("<h3>Logged out — cookie cleared!</h3>") | |
| async def whoami(auth_token: str = Cookie(None)): | |
| payload = decode_token(auth_token) | |
| return { | |
| "username": payload["username"], | |
| "role": payload["role"], | |
| "user_id": payload.get("user_id"), | |
| "full_name": payload.get("full_name") | |
| } | |
| def protected_route(auth_token: str = Cookie(None)): | |
| payload = decode_token(auth_token) | |
| return {"msg": "Access granted", "user": payload} | |
| async def change_password(payload: PasswordChangePayload, request: Request, auth_token: str = Cookie(None)): | |
| """Change user password endpoint.""" | |
| # Get token from cookie or Authorization header | |
| token = auth_token | |
| if not token and "Authorization" in request.headers: | |
| auth_header = request.headers["Authorization"] | |
| if auth_header.startswith("Bearer "): | |
| token = auth_header[7:] # Remove 'Bearer ' prefix | |
| if not token: | |
| raise HTTPException(status_code=401, detail="No authentication token provided") | |
| try: | |
| # Decode token to get user info | |
| user_data = decode_token(token) | |
| if not user_data: | |
| raise HTTPException(status_code=401, detail="Invalid token") | |
| username = user_data["username"] | |
| role = user_data["role"] | |
| # Get the correct table based on role | |
| table_map = { | |
| "Learner": "Learners", | |
| "Instructor": "Instructors" | |
| } | |
| table = table_map.get(role) | |
| if not table: | |
| raise HTTPException(status_code=400, detail="Invalid role") | |
| conn = connect_db() | |
| try: | |
| with conn.cursor() as cur: | |
| # First verify current password | |
| query = f"SELECT Password FROM {table} WHERE AccountName=%s LIMIT 1" | |
| cur.execute(query, (username,)) | |
| row = cur.fetchone() | |
| if not row or not check_password(payload.current_password, row[0]): | |
| raise HTTPException(status_code=401, detail="Current password is incorrect") | |
| # Hash the new password | |
| new_password_hash = hash_password(payload.new_password) | |
| # Update the password | |
| update_query = f"UPDATE {table} SET Password=%s WHERE AccountName=%s" | |
| cur.execute(update_query, (new_password_hash, username)) | |
| conn.commit() | |
| return {"message": "Password updated successfully"} | |
| except Exception as db_err: | |
| conn.rollback() | |
| print(f"Database error: {str(db_err)}") | |
| raise HTTPException(status_code=500, detail=f"Database error: {str(db_err)}") | |
| finally: | |
| conn.close() | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| print(f"Unexpected error in change_password: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Error: {str(e)}") | |
| def connect_db(): | |
| return pymysql.connect( | |
| host=MYSQL_HOST, | |
| user=MYSQL_USER, | |
| password=MYSQL_PASSWORD, | |
| database=MYSQL_DB, | |
| port=MYSQL_PORT | |
| ) | |
| def check_password(plain: str, hashed: str) -> bool: | |
| """Compare plaintext vs bcrypt hash stored in DB.""" | |
| return bcrypt.checkpw(plain.encode("utf-8"), hashed.encode("utf-8")) | |
| def hash_password(plain: str) -> str: | |
| """Return bcrypt hash for storage.""" | |
| return bcrypt.hashpw(plain.encode("utf-8"), bcrypt.gensalt()).decode("utf-8") | |
| # ──────────────────────────────────────────────────────────────────────────────── | |
| # Pydantic models | |
| # ──────────────────────────────────────────────────────────────────────────────── | |
| class UserCredentials(BaseModel): | |
| username: str | |
| password: str | |
| role: str # "Learner" or "Instructor" | |
| class NewUser(UserCredentials): | |
| name: str | |
| email: str | |
| def verify_user(creds: UserCredentials): | |
| """ | |
| Validate user credentials. | |
| Returns {username, role} on success; 401 on failure. | |
| """ | |
| try: | |
| print(f"Verify user attempt for: {creds.username}, role: {creds.role}") | |
| table_map = { | |
| "Learner": "Learners", | |
| "Instructor": "Instructors" | |
| } | |
| table = table_map.get(creds.role) | |
| if not table: | |
| print(f"Invalid role: {creds.role}") | |
| raise HTTPException(status_code=400, detail="Invalid role") | |
| conn = connect_db() | |
| try: | |
| with conn.cursor() as cur: | |
| query = f"SELECT Password FROM {table} WHERE AccountName=%s LIMIT 1" | |
| print(f"Executing query: {query} with username: {creds.username}") | |
| cur.execute(query, (creds.username,)) | |
| row = cur.fetchone() | |
| if not row: | |
| print(f"No user found with username: {creds.username} in table: {table}") | |
| except Exception as db_err: | |
| print(f"Database error: {str(db_err)}") | |
| raise HTTPException(status_code=500, detail=f"Database error: {str(db_err)}") | |
| finally: | |
| conn.close() | |
| if not row: | |
| print(f"Authentication failed: User not found") | |
| raise HTTPException(status_code=401, detail="Incorrect username or password") | |
| password_valid = check_password(creds.password, row[0]) | |
| print(f"Password check result: {password_valid}") | |
| if not password_valid: | |
| print(f"Authentication failed: Invalid password") | |
| raise HTTPException(status_code=401, detail="Incorrect username or password") | |
| print(f"Authentication successful for: {creds.username}") | |
| return {"username": creds.username, "role": creds.role} | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| print(f"Unexpected error in verify_user: {str(e)}") | |
| raise HTTPException(status_code=500, detail=f"Error: {str(e)}") | |
| def register_user(user: NewUser): | |
| """ | |
| Simple sign-up endpoint. | |
| Returns 201 Created on success, 409 if username/email exists. | |
| """ | |
| table_cfg = { | |
| "Learner": ("Learners", "LearnerName"), | |
| "Instructor": ("Instructors", "InstructorName") | |
| } | |
| cfg = table_cfg.get(user.role) | |
| if not cfg: | |
| raise HTTPException(status_code=400, detail="Invalid role") | |
| table, name_col = cfg | |
| hashed_pw = hash_password(user.password) | |
| conn = connect_db() | |
| try: | |
| with conn.cursor() as cur: | |
| # Uniqueness check | |
| cur.execute( | |
| f"SELECT 1 FROM {table} WHERE AccountName=%s OR Email=%s LIMIT 1", | |
| (user.username, user.email) | |
| ) | |
| if cur.fetchone(): | |
| raise HTTPException(status_code=409, detail="Username or email already exists") | |
| # Insert new record | |
| cur.execute( | |
| f"INSERT INTO {table} ({name_col}, Email, AccountName, Password) " | |
| f"VALUES (%s, %s, %s, %s)", | |
| (user.name, user.email, user.username, hashed_pw) | |
| ) | |
| conn.commit() | |
| finally: | |
| conn.close() | |
| return {"msg": "User created", "username": user.username, "role": user.role}, 201 | |
| async def logout(): | |
| response = JSONResponse(content={"message": "Logged out successfully"}) | |
| # Clear all auth-related cookies | |
| response.delete_cookie(key="auth_token", path="/") | |
| response.delete_cookie(key="session_id", path="/") | |
| return response | |
| async def get_users(role: str): | |
| table_map = { | |
| "Learner": "Learners", | |
| "Instructor": "Instructors" | |
| } | |
| table = table_map.get(role) | |
| if not table: | |
| raise HTTPException(status_code=400, detail="Invalid role") | |
| conn = connect_db() | |
| try: | |
| with conn.cursor() as cur: | |
| cur.execute(f"SELECT * FROM {table}") | |
| columns = [desc[0] for desc in cur.description] | |
| users = [dict(zip(columns, row)) for row in cur.fetchall()] | |
| return users | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| finally: | |
| conn.close() | |
| async def get_user_count(role: str): | |
| table_map = { | |
| "Learner": "Learners", | |
| "Instructor": "Instructors" | |
| } | |
| table = table_map.get(role) | |
| if not table: | |
| raise HTTPException(status_code=400, detail="Invalid role") | |
| conn = connect_db() | |
| try: | |
| with conn.cursor() as cur: | |
| cur.execute(f"SELECT COUNT(*) FROM {table}") | |
| count = cur.fetchone()[0] | |
| return {"count": count} | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| finally: | |
| conn.close() |