# banter-api / main.py
# Author: EbukaGaus — commit d5a3ec4 ("push")
import logging
from fastapi import FastAPI, Depends, Request
from fastapi.middleware.cors import CORSMiddleware
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from fastapi.responses import JSONResponse
from fastapi.templating import Jinja2Templates
from datetime import datetime
import socketio
import jwt
from auth_utils import SECRET_KEY, ALGORITHM
from models import User
from database import SessionLocal
from routers import auth, chat, users, conversations, messages, teams, admin
# Import the Socket.IO server instance and user store from our new module
from global_chat import sio, active_users, create_message_html
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("comp_eb")

# Rate limiter keyed on the caller's remote address; attached to app.state below
# so the RateLimitExceeded handler can find it.
limiter = Limiter(key_func=get_remote_address)

# NOTE(review): `templates` is not referenced in this module's visible code;
# presumably it is imported elsewhere — confirm before removing.
templates = Jinja2Templates(directory="templates")

# --- FastAPI Application ---
app = FastAPI(
    title="Ebuka AI API",
    description="API for the Ebuka AI chatbot with user authentication.",
    version="1.0.0",
)

# Browser origins allowed to make credentialed cross-origin requests.
_CORS_ALLOWED_ORIGINS = [
    "http://localhost:8001",
    "http://127.0.0.1:8001",
    "http://localhost:8050",
    "http://127.0.0.1:8050",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

# --- Include REST Routers ---
# Registration order is preserved: routes are matched in the order they are added.
for _router_module in (auth, chat, users, conversations, messages, teams, admin):
    app.include_router(_router_module.router)
del _router_module
# --- Socket.IO Authentication and Event Handling ---
def _authenticate_socket_user(sid, token):
    """Resolve a JWT into the user fields stored for a Socket.IO session.

    Decodes ``token`` with ``SECRET_KEY``/``ALGORITHM``, then looks up the
    ``User`` whose email equals the token's ``sub`` claim. The database
    session is opened only for that lookup and always closed before
    returning, so the caller never holds it across network awaits.

    Args:
        sid: Socket.IO session id (used only for logging).
        token: The raw JWT supplied by the client.

    Returns:
        dict with ``id``, ``email``, ``username`` and ``display_name`` keys;
        the string fields are never ``None``.

    Raises:
        socketio.exceptions.ConnectionRefusedError: if the token fails to
            decode, has no ``sub`` claim, or names an unknown user.
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as e:
        logger.error(f"JWT Decode Error for sid {sid}: {e}")
        raise socketio.exceptions.ConnectionRefusedError('authentication failed: invalid token') from e

    email_from_token = payload.get("sub")
    if not email_from_token:
        raise socketio.exceptions.ConnectionRefusedError('authentication failed: invalid token payload')

    # NOTE(review): this is a synchronous query inside an async handler, so it
    # blocks the event loop while it runs.
    db = SessionLocal()
    try:
        user = db.query(User).filter(User.email == email_from_token).first()
        if not user:
            raise socketio.exceptions.ConnectionRefusedError('authentication failed: user not found')
        # Read every attribute while the session is still open.
        return {
            "id": user.id,
            "email": user.email or "",
            "username": user.username or "",
            # Fall back to "" so greetings never render the literal "None".
            "display_name": user.display_name or user.username or user.email or "",
        }
    finally:
        db.close()


@sio.on('connect')
async def connect(sid, environ, auth):
    """
    Authenticate a new Socket.IO connection and announce the user.

    Registered as the ``connect`` handler on ``sio``. The client must send
    ``{"token": <JWT>}`` as its auth payload. On success the user is recorded
    in ``active_users[sid]``, then sent a welcome message. Every other
    connected sid receives a "has joined" message with the distinct-user
    online count.

    Raises:
        socketio.exceptions.ConnectionRefusedError: when no token is supplied
            or authentication fails; this rejects the connection.
    """
    # python-socketio passes auth=None when the client sends no auth payload,
    # and a client may send a non-dict value; treat both as "no token".
    token = auth.get('token') if isinstance(auth, dict) else None
    if not token:
        logger.warning(f"Connection rejected for sid {sid}: No token provided.")
        raise socketio.exceptions.ConnectionRefusedError('authentication failed: no token')

    user_data = _authenticate_socket_user(sid, token)
    display_name = user_data["display_name"]

    # Save the user data in the shared active-users store.
    active_users[sid] = user_data
    logger.info(f"User {display_name} connected with sid {sid}")

    # Send a welcome message to the newly connected user only.
    welcome_message = {"type": "system", "message": f"Welcome, {display_name}!"}
    html_content = create_message_html(welcome_message, user_data["id"])
    await sio.emit('htmx_message', html_content, to=sid)

    # Count distinct user ids, since one user may hold several sids.
    online_count = len({u['id'] for u in active_users.values()})
    join_message = {
        "type": "user_joined",
        "message": f"{display_name} has joined.",
        "online_count": online_count,
    }
    # Iterate over a snapshot: each `await` yields to the event loop, where other
    # connection handlers may add entries to `active_users`, and changing the
    # dict mid-iteration would raise RuntimeError.
    for other_sid, other_user_data in list(active_users.items()):
        if other_sid == sid:
            continue
        html_content = create_message_html(join_message, other_user_data['id'])
        await sio.emit('htmx_message', html_content, to=other_sid)
# --- Mount the Socket.IO app ---
# Wrap the FastAPI app with the Socket.IO server: Socket.IO traffic is handled
# by `sio`, and every other request is forwarded to `app`. The ASGI server must
# therefore run `socket_app` (not `app`) for Socket.IO to be reachable.
socket_app = socketio.ASGIApp(
    sio,
    other_asgi_app=app,
)
# --- Root Endpoint ---
@app.get("/", tags=["Root"])
async def read_root():
    """Return a static welcome payload for the API root."""
    greeting = "Welcome to the Ebuka AI API! Database is managed by Alembic."
    return {"message": greeting}