Spaces:
Runtime error
Runtime error
| import logging | |
| from fastapi import FastAPI, Depends, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from slowapi import Limiter, _rate_limit_exceeded_handler | |
| from slowapi.util import get_remote_address | |
| from slowapi.errors import RateLimitExceeded | |
| from fastapi.responses import JSONResponse | |
| from fastapi.templating import Jinja2Templates | |
| from datetime import datetime | |
| import socketio | |
| import jwt | |
| from auth_utils import SECRET_KEY, ALGORITHM | |
| from models import User | |
| from database import SessionLocal | |
| from routers import auth, chat, users, conversations, messages, teams, admin | |
| # Import the Socket.IO server instance and user store from our new module | |
| from global_chat import sio, active_users, create_message_html | |
# --- Logging, rate limiting and templating ---
# Root logging is configured once at import time; this module logs under
# the "comp_eb" logger name.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger("comp_eb")

# slowapi limiter whose rate-limit key is the client's remote address.
limiter = Limiter(key_func=get_remote_address)

# Jinja2 templates are loaded from the relative "templates" directory.
templates = Jinja2Templates(directory="templates")
# --- FastAPI Application ---
app = FastAPI(
    title="Ebuka AI API",
    description="API for the Ebuka AI chatbot with user authentication.",
    version="1.0.0",
)

# Credentialed cross-origin requests are accepted only from local front-end
# origins: localhost and 127.0.0.1 on ports 8001 and 8050 (in that order).
# NOTE(review): a deployed front-end origin would have to be added here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        f"http://{host}:{port}"
        for port in (8001, 8050)
        for host in ("localhost", "127.0.0.1")
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Expose the shared slowapi limiter on app.state and register slowapi's
# default handler for RateLimitExceeded.
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# --- Include REST Routers ---
# Routers are included in this exact order, matching the original sequence.
for _router_module in (auth, chat, users, conversations, messages, teams, admin):
    app.include_router(_router_module.router)
del _router_module
# --- Socket.IO Authentication and Event Handling ---


def _token_from_auth(auth):
    """Return the ``token`` entry of a Socket.IO ``auth`` payload, or None.

    The client-supplied payload may be absent (``None``) or not a dict; in
    those cases ``None`` is returned instead of raising ``AttributeError``.
    """
    if isinstance(auth, dict):
        return auth.get('token')
    return None


def _authenticate_token(sid, token):
    """Validate a JWT and build the session record for the user it names.

    Args:
        sid: Socket.IO session id; used only in log messages.
        token: Encoded JWT supplied by the client.

    Returns:
        A dict with keys ``"id"``, ``"email"``, ``"username"`` and
        ``"display_name"``. The three string fields are never ``None``.

    Raises:
        socketio.exceptions.ConnectionRefusedError: if the token fails to
            decode/verify, has no ``sub`` claim, or ``sub`` matches no
            ``User.email``.
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as e:
        logger.error(f"JWT Decode Error for sid {sid}: {e}")
        raise socketio.exceptions.ConnectionRefusedError(
            'authentication failed: invalid token'
        ) from e

    email_from_token = payload.get("sub")
    if not email_from_token:
        raise socketio.exceptions.ConnectionRefusedError(
            'authentication failed: invalid token payload'
        )

    # NOTE(review): this is a synchronous query executed inside the async
    # connect path; consider offloading it if lookups become slow.
    db = SessionLocal()
    try:
        user = db.query(User).filter(User.email == email_from_token).first()
        if not user:
            raise socketio.exceptions.ConnectionRefusedError(
                'authentication failed: user not found'
            )
        # Copy every attribute we need before the session is closed, so the
        # session is not held open across the awaits in connect().
        display_name = user.display_name or user.username or user.email
        return {
            "id": user.id,
            "email": user.email or "",
            "username": user.username or "",
            "display_name": display_name or "",
        }
    finally:
        db.close()


@sio.event
async def connect(sid, environ, auth):
    """Authenticate a new Socket.IO connection and announce the user.

    Runs for every connection attempt. On success the user's record is
    stored in ``active_users[sid]``. A welcome message is sent to the new
    client. A "user_joined" notice, carrying the number of distinct online
    user ids, is sent to every other connected sid.

    Args:
        sid: Session id assigned to the connecting client.
        environ: Request environment (not used here).
        auth: Client-supplied auth payload. Expected to be a dict with a
            ``token`` key, but it may be ``None`` or malformed.

    Raises:
        socketio.exceptions.ConnectionRefusedError: when no token is sent or
            the token cannot be matched to a user, rejecting the connection.
    """
    token = _token_from_auth(auth)
    if not token:
        logger.warning(f"Connection rejected for sid {sid}: No token provided.")
        raise socketio.exceptions.ConnectionRefusedError('authentication failed: no token')

    user_data = _authenticate_token(sid, token)
    display_name = user_data["display_name"]

    # Save the user data in the shared active-users store.
    active_users[sid] = user_data
    logger.info(f"User {display_name} connected with sid {sid}")

    # Send the welcome message to the newly connected user only.
    welcome_message = {"type": "system", "message": f"Welcome, {display_name}!"}
    html_content = create_message_html(welcome_message, user_data["id"])
    await sio.emit('htmx_message', html_content, to=sid)

    # Count distinct user ids: one user may hold several sids.
    online_count = len({u['id'] for u in active_users.values()})
    join_message = {
        "type": "user_joined",
        "message": f"{display_name} has joined.",
        "online_count": online_count,
    }
    # Iterate over a snapshot: every `await` below yields to the event loop,
    # where other connect/disconnect handlers may add or remove entries and
    # would otherwise break iteration of the live dict.
    for other_sid, other_user_data in list(active_users.items()):
        if other_sid == sid:
            continue
        html_content = create_message_html(join_message, other_user_data['id'])
        await sio.emit('htmx_message', html_content, to=other_sid)
# --- Mount the Socket.IO app ---
# One ASGI application for both servers: Socket.IO traffic is handled by
# ``sio`` and every other request is forwarded to the FastAPI ``app``.
# NOTE(review): the ASGI server presumably has to serve ``socket_app`` (not
# ``app``) for Socket.IO clients to connect — confirm the start command.
socket_app = socketio.ASGIApp(
    sio,
    other_asgi_app=app,
)
# --- Root Endpoint ---
# Registering the route after ``socket_app`` was built is fine: ``socket_app``
# forwards non-Socket.IO traffic to this same ``app`` object.
@app.get("/")
async def read_root():
    """Handle ``GET /`` with a static welcome payload.

    Touches neither the database nor authentication.
    """
    return {"message": "Welcome to the Ebuka AI API! Database is managed by Alembic."}