import json
import os
from datetime import datetime
from typing import Optional, Tuple

import httpx
from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware

load_dotenv()

# Base URL of the auth service that exposes POST /register and POST /login.
SERVICES_URL = os.getenv("SERVICES_URL")

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins; use specific domains for security
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# username -> WebSocket of the client currently connected under that name
active_connections = {}

# username -> list of message dicts queued while that user was offline
message_store = {}

# username -> session token returned by the auth service (in-memory only)
session_tokens = {}


async def _send_json(websocket: WebSocket, payload: dict) -> None:
    """Serialize *payload* as JSON and send it as a single text frame."""
    await websocket.send_text(json.dumps(payload))


def _parse_json_object(raw: str) -> Optional[dict]:
    """Decode *raw* as JSON; return the dict, or None if it is not a JSON object."""
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return None
    return data if isinstance(data, dict) else None


def _error_detail(response: httpx.Response, default: str):
    """Return the auth service's ``detail`` field, or *default* if unavailable.

    A non-JSON body (e.g. a proxy's HTML error page) yields *default* instead
    of raising out of the handshake.
    """
    try:
        body = response.json()
    except ValueError:
        return default
    return body.get("detail", default) if isinstance(body, dict) else default


async def register_client(websocket: WebSocket, username: str, token: str):
    """Register *websocket* as *username*'s live connection and flush its queue.

    *token* becomes the session token that this client's message frames must
    carry. Queued messages are dropped from the store only after all of them
    were sent, so a failed flush leaves them queued for the next connection.
    """
    active_connections[username] = websocket
    session_tokens[username] = token
    # The token is a credential, so it is deliberately not logged.
    print(f"DEBUG: {username} connected.")

    # Deliver undelivered messages if any
    if username in message_store:
        for message in message_store[username]:
            print(f"DEBUG: Sending undelivered message to {username}: {message}")
            await websocket.send_text(json.dumps(message))
        message_store.pop(username, None)  # Clear delivered messages
        print(f"DEBUG: Cleared stored messages for {username}")


async def unregister_client(username: str, websocket: Optional[WebSocket] = None):
    """Remove *username*'s live connection and session token.

    When *websocket* is given, the entry is removed only if it still belongs
    to that socket. This stops a closing stale connection from evicting a
    newer login under the same username.
    """
    current = active_connections.get(username)
    if current is None:
        return
    if websocket is not None and current is not websocket:
        return
    del active_connections[username]
    session_tokens.pop(username, None)  # Delete session token on disconnect
    print(f"DEBUG: {username} disconnected.")


async def _authenticate(
    websocket: WebSocket, user_data: dict
) -> Tuple[bool, Optional[str]]:
    """Run the signup/login handshake described by *user_data*.

    Returns ``(True, token)`` on success and ``(False, None)`` otherwise. Every
    outcome has already been reported to the client over *websocket*. Raises
    ``httpx.HTTPError`` if the auth service cannot be reached, and
    ``ValueError`` if a successful service response is not valid JSON.
    """
    action = user_data.get("action")  # 'login' or 'signup'
    username = user_data.get("username")
    credentials = {"username": username, "password": user_data.get("password")}

    if action not in ("signup", "login"):
        await _send_json(
            websocket,
            {"status": "error", "message": "Invalid action. Use 'login' or 'signup'."},
        )
        return False, None

    async with httpx.AsyncClient() as client:
        if action == "signup":
            print(f"DEBUG: Registering new user: {username}")
            response = await client.post(f"{SERVICES_URL}/register", json=credentials)
            if response.status_code != 200:
                error_message = _error_detail(response, "Registration failed")
                print(f"DEBUG: Registration failed for {username}: {error_message}")
                await _send_json(websocket, {"status": "error", "message": error_message})
                return False, None
            print(f"DEBUG: Registration successful for {username}")
            await _send_json(
                websocket, {"status": "success", "message": "Registration successful"}
            )
            # Create session token upon registration
            token_response = await client.post(f"{SERVICES_URL}/login", json=credentials)
            if token_response.status_code != 200:
                await _send_json(
                    websocket, {"status": "error", "message": "Token retrieval failed"}
                )
                return False, None
            token = token_response.json().get("token")
            await _send_json(websocket, {"status": "success", "token": token})
            return True, token

        print(f"DEBUG: Authenticating user: {username}")
        response = await client.post(f"{SERVICES_URL}/login", json=credentials)
        if response.status_code != 200:
            error_message = _error_detail(response, "Invalid credentials")
            print(f"DEBUG: Authentication failed for {username}: {error_message}")
            await _send_json(websocket, {"status": "error", "message": error_message})
            return False, None
        print(f"DEBUG: Authentication successful for {username}")
        token = response.json().get("token")
        await _send_json(
            websocket, {"status": "success", "message": "Authenticated", "token": token}
        )
        return True, token


async def _deliver_live(recipient: str, message_obj: dict) -> bool:
    """Push *message_obj* to *recipient*'s open socket; return True on success.

    If the send fails (e.g. the recipient dropped but its handler has not yet
    unregistered it), this returns False so the caller queues the message
    instead of losing it.
    """
    recipient_socket = active_connections.get(recipient)
    if recipient_socket is None:
        return False
    try:
        await recipient_socket.send_text(json.dumps(message_obj))
    except Exception as e:  # the recipient's socket failing must not affect the sender
        print(f"DEBUG: Live delivery to {recipient} failed, storing instead: {e}")
        return False
    print(f"DEBUG: Message sent to {recipient}: {message_obj}")
    return True


async def _handle_message(websocket: WebSocket, username: str, raw: str) -> None:
    """Validate one chat frame sent by *username* and deliver or queue it.

    Raises ValueError if *raw* is not a JSON object.
    """
    msg_data = _parse_json_object(raw)
    if msg_data is None:
        raise ValueError("message frame is not a JSON object")

    # The token proves the frame comes from this session's owner, so it is
    # compared with the sender's own session token.
    if msg_data.get("token") != session_tokens.get(username):
        await _send_json(websocket, {"status": "error", "message": "Invalid session token"})
        return

    recipient = msg_data.get("recipient")
    if not isinstance(recipient, str) or not recipient:
        await _send_json(websocket, {"status": "error", "message": "Missing recipient"})
        return

    message_obj = {
        "from": username,
        "message": msg_data.get("message"),
        "timestamp": datetime.now().isoformat(),  # naive server-local time
    }

    print(f"DEBUG: Sending message to recipient: {recipient}")
    if await _deliver_live(recipient, message_obj):
        return

    # Store undelivered message until the recipient (re)connects.
    message_store.setdefault(recipient, []).append(message_obj)
    print(f"DEBUG: Message stored for {recipient}: {message_obj}")
    await _send_json(
        websocket, {"status": "success", "message": "Message stored for delivery"}
    )


async def _relay_messages(websocket: WebSocket, username: str) -> None:
    """Relay frames from *username*'s socket until the client disconnects.

    ``receive_text`` sits outside the per-message error handler, and
    ``WebSocketDisconnect`` is re-raised. This ends the loop on disconnect
    instead of answering with an error frame on a dead connection.
    """
    while True:
        message = await websocket.receive_text()
        # The raw frame carries the session token, so only the sender is logged here.
        print(f"DEBUG: Received message from {username}")
        try:
            await _handle_message(websocket, username, message)
        except WebSocketDisconnect:
            raise
        except Exception as e:
            # Bad frame (invalid JSON, wrong shape, ...): report it and keep relaying.
            await _send_json(
                websocket, {"status": "error", "message": "Message processing error"}
            )
            print(f"DEBUG: Error processing message: {e}")


@app.websocket("/ws")
async def relay_server(websocket: WebSocket):
    """Relay server handling WebSocket connections.

    Protocol: the first text frame must be a JSON object
    ``{"action": "login" | "signup", "username": ..., "password": ...}``.
    After a successful handshake, each frame is
    ``{"recipient": ..., "message": ..., "token": <this client's token>}``.
    Each frame is pushed to the recipient live, or queued until they connect.
    """
    username = None
    registered = False
    try:
        await websocket.accept()
        print("DEBUG: WebSocket connection accepted.")

        # Initial login or signup. The raw frame contains the password, so only
        # non-secret fields are logged.
        user_data = _parse_json_object(await websocket.receive_text())
        if user_data is None:
            await _send_json(websocket, {"status": "error", "message": "Invalid login payload"})
            return
        username = user_data.get("username")
        print(
            f"DEBUG: Received initial data: action={user_data.get('action')!r} "
            f"username={username!r}"
        )

        try:
            authenticated, token = await _authenticate(websocket, user_data)
        except (httpx.HTTPError, ValueError) as e:
            # Auth service unreachable or returned a malformed success body.
            print(f"DEBUG: Auth service error for {username}: {e}")
            await _send_json(
                websocket, {"status": "error", "message": "Authentication service error"}
            )
            return
        if not authenticated:
            return

        # Set the flag before registering: register_client records this socket
        # before flushing queued messages, so cleanup must run even if the
        # flush fails.
        registered = True
        await register_client(websocket, username, token)
        await _relay_messages(websocket, username)
    except WebSocketDisconnect:
        print(f"DEBUG: Connection with {username} closed.")
    finally:
        # Only a connection that registered may unregister. Otherwise a failed
        # login could evict another user's live session.
        if registered:
            await unregister_client(username, websocket)