|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException |
|
|
from fastapi.staticfiles import StaticFiles |
|
|
from fastapi.responses import HTMLResponse |
|
|
from fastapi.templating import Jinja2Templates |
|
|
import json |
|
|
import asyncio |
|
|
from datetime import datetime |
|
|
from typing import Dict, List, Set |
|
|
import base64 |
|
|
import mimetypes |
|
|
import os |
|
|
import uvicorn |
|
|
# The FastAPI application; title/description are its API metadata.
app = FastAPI(
    title="Tri-Chat API",
    description="Real-time chat with WebSocket support",
)

# Jinja2 loader for the relative "templates" directory (resolved against the
# current working directory).
# NOTE(review): no route visible here uses it; get_chat_page reads index.html directly.
templates = Jinja2Templates(directory="templates")
|
|
|
|
|
|
|
|
class ConnectionManager:
    """In-memory registry of chat rooms, their WebSocket members, and recent history.

    Rooms are created on first join and are never removed, so a room's history
    (capped at ``MAX_HISTORY`` messages) is kept even after every member leaves.
    """

    # Number of most recent broadcast messages retained per room for replay.
    MAX_HISTORY = 100

    def __init__(self):
        # room -> list of {"websocket", "username", "joined_at"} records
        self.active_connections: Dict[str, List[Dict]] = {}
        # room -> most recent broadcast messages, oldest first
        self.message_history: Dict[str, List[Dict]] = {}

    async def connect(self, websocket: WebSocket, room: str, username: str):
        """Accept *websocket*, register it in *room*, replay history, announce the join.

        The newcomer first receives the room's stored history, then the
        "<username> joined the room" notice, which is delivered exactly once.
        The history is snapshotted before the notice is broadcast, because the
        broadcast itself appends the notice to the history.
        """
        await websocket.accept()

        if room not in self.active_connections:
            self.active_connections[room] = []
            self.message_history[room] = []

        # Snapshot before registering. Anything broadcast from now on reaches
        # this socket live, so it must not also be replayed from history.
        backlog = list(self.message_history[room])

        self.active_connections[room].append({
            "websocket": websocket,
            "username": username,
            "joined_at": datetime.now().isoformat(),
        })

        for past_message in backlog:
            await websocket.send_text(json.dumps(past_message))

        await self.broadcast_to_room(room, {
            "type": "system",
            "message": f"{username} joined the room",
            "timestamp": datetime.now().isoformat(),
            "room": room,
        })

    def disconnect(self, websocket: WebSocket, room: str):
        """Remove *websocket* from *room*.

        Returns the departing username, or None if the socket was not registered there.
        """
        connections = self.active_connections.get(room, [])
        for index, info in enumerate(connections):
            if info["websocket"] == websocket:
                del connections[index]
                return info["username"]
        return None

    async def broadcast_to_room(self, room: str, message: dict):
        """Record *message* in *room*'s history and send it to every member.

        Members whose send fails are dropped from the room. Unknown rooms are ignored.
        """
        connections = self.active_connections.get(room)
        if connections is None:
            return

        # Serialize once, before touching history: an unserializable message
        # raises here instead of poisoning the history or evicting every member.
        payload = json.dumps(message)

        history = self.message_history[room]
        history.append(message)
        if len(history) > self.MAX_HISTORY:
            del history[:-self.MAX_HISTORY]

        stale = []
        # Iterate a copy: each send awaits, and other tasks may join or leave meanwhile.
        for info in list(connections):
            try:
                await info["websocket"].send_text(payload)
            except Exception:  # not bare: let CancelledError propagate
                stale.append(info)

        if stale:
            # Filter by identity so entries already removed elsewhere cause no error.
            stale_ids = {id(info) for info in stale}
            connections[:] = [info for info in connections if id(info) not in stale_ids]

    def get_room_users(self, room: str) -> List[str]:
        """Return the usernames currently connected to *room*, in join order."""
        return [info["username"] for info in self.active_connections.get(room, [])]
|
|
|
|
|
|
|
|
# Module-level manager shared by every route and WebSocket handler below.
# Its state lives in this process's memory only.
manager: ConnectionManager = ConnectionManager()
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def get_chat_page():
    """Serve the chat UI by returning the raw contents of templates/index.html.

    The path is relative to the current working directory. If the file does not
    exist, a short HTML error page is returned with status 404.
    """
    page_path = "templates/index.html"
    try:
        with open(page_path, "r", encoding="utf-8") as page_file:
            page_html = page_file.read()
    except FileNotFoundError:
        return HTMLResponse(
            content="<h1>Error: templates/index.html not found</h1><p>Please make sure the templates directory exists with index.html</p>",
            status_code=404,
        )
    return HTMLResponse(content=page_html)
|
|
|
|
|
# Largest accepted upload, measured on the decoded bytes (5 MB).
MAX_FILE_BYTES = 5 * 1024 * 1024
# Longest chat text forwarded, in characters.
MAX_TEXT_CHARS = 500
# Bounds on client-supplied file metadata that gets echoed to every member.
MAX_FILE_NAME_CHARS = 100
MAX_FILE_TYPE_CHARS = 100


def _parse_client_frame(data: str):
    """Decode one client frame. Return the parsed dict, or None if it is not a JSON object."""
    try:
        parsed = json.loads(data)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return None
    return parsed if isinstance(parsed, dict) else None


def _build_text_message(message_data: dict, username: str, room: str):
    """Return a broadcastable text message, or None if there is nothing valid to send."""
    text = message_data.get("text", "")
    if not isinstance(text, str):
        return None
    text = text.strip()
    if not text:
        return None
    return {
        "type": "text",
        "username": username,
        "text": text[:MAX_TEXT_CHARS],
        "timestamp": datetime.now().isoformat(),
        "room": room,
    }


def _build_file_message(message_data: dict, username: str, room: str):
    """Validate a file upload.

    Returns ``(message, None)`` on success, or ``(None, error_text)`` when the
    upload must be rejected.
    """
    declared_size = message_data.get("fileSize", 0)
    # Cheap early reject on the client's own claim, before decoding anything.
    if isinstance(declared_size, (int, float)) and declared_size > MAX_FILE_BYTES:
        return None, "File size exceeds 5MB limit"

    file_data = message_data.get("fileData", "")
    try:
        decoded = base64.b64decode(file_data)
    except Exception:  # binascii.Error, or TypeError/ValueError for non-ASCII/non-str input
        return None, "Invalid file data"
    # The declared size is untrusted, so enforce the limit on the real payload as well.
    if len(decoded) > MAX_FILE_BYTES:
        return None, "File size exceeds 5MB limit"

    file_name = message_data.get("fileName", "unknown")
    if not isinstance(file_name, str):
        file_name = "unknown"
    file_type = message_data.get("fileType", "application/octet-stream")
    if not isinstance(file_type, str):
        file_type = "application/octet-stream"

    return {
        "type": "file",
        "username": username,
        "fileName": file_name[:MAX_FILE_NAME_CHARS],
        "fileType": file_type[:MAX_FILE_TYPE_CHARS],
        "fileSize": len(decoded),
        "fileData": file_data,
        "timestamp": datetime.now().isoformat(),
        "room": room,
    }, None


@app.websocket("/ws/{room}")
async def websocket_endpoint(websocket: WebSocket, room: str, username: str):
    """WebSocket endpoint for real-time chat in *room*.

    ``username`` is not part of the path, so FastAPI reads it from the query
    string. Each incoming frame should be a JSON object whose ``"type"`` is
    ``"text"`` or ``"file"``. Any other type, malformed JSON, and non-object
    JSON are ignored. A rejected upload gets a private ``{"type": "error"}``
    reply.

    The socket is always deregistered on exit, including on unexpected errors,
    and the room is told the user left.
    """
    if not username or not username.strip():
        await websocket.close(code=1008, reason="Username is required")
        return

    if not room or not room.strip():
        room = "global"

    username = username.strip()[:20]
    room = room.strip()[:30]

    try:
        # Inside the try so a failure during accept/history replay still deregisters.
        await manager.connect(websocket, room, username)
        while True:
            message_data = _parse_client_frame(await websocket.receive_text())
            if message_data is None:
                continue

            kind = message_data.get("type")
            if kind == "text":
                message = _build_text_message(message_data, username, room)
                if message is not None:
                    await manager.broadcast_to_room(room, message)
            elif kind == "file":
                message, error = _build_file_message(message_data, username, room)
                if error is not None:
                    await websocket.send_text(json.dumps({
                        "type": "error",
                        "message": error,
                    }))
                else:
                    await manager.broadcast_to_room(room, message)
    except WebSocketDisconnect:
        pass
    finally:
        disconnected_username = manager.disconnect(websocket, room)
        if disconnected_username:
            await manager.broadcast_to_room(room, {
                "type": "system",
                "message": f"{disconnected_username} left the room",
                "timestamp": datetime.now().isoformat(),
                "room": room,
            })
|
|
|
|
|
@app.get("/api/rooms")
async def get_active_rooms():
    """List every room with at least one live connection, including its members."""
    return {
        "rooms": [
            {
                "name": name,
                "user_count": len(members),
                "users": [member["username"] for member in members],
            }
            for name, members in manager.active_connections.items()
            if members  # rooms are never deleted, so skip the empty ones
        ]
    }
|
|
|
|
|
@app.get("/api/rooms/{room}/users")
async def get_room_users(room: str):
    """Report who is connected to *room*; an unknown room yields an empty list."""
    members = manager.get_room_users(room)
    payload = {"room": room, "users": members, "user_count": len(members)}
    return payload
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Bind on all interfaces. The PORT environment variable overrides the default 7860.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))
|
|
|