nexus-relay / relay.py
ChandimaPrabath's picture
update. added session token
fab9207 verified
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
import json
from datetime import datetime
import httpx
from dotenv import load_dotenv
import os
# Load variables from a local .env file (if present) into the process
# environment before any of them are read below.
load_dotenv()
# Base URL of the backing user service; this module POSTs to
# f"{SERVICES_URL}/register" and f"{SERVICES_URL}/login".
# os.getenv returns None when the variable is unset.
SERVICES_URL = os.getenv("SERVICES_URL")
# The ASGI application; the WebSocket relay endpoint is attached to it below.
app = FastAPI()
# NOTE(review): wildcard origins/methods/headers are combined with
# allow_credentials=True here. The FastAPI CORS documentation says these
# must be listed explicitly when credentials are allowed -- confirm whether
# credentials are really needed, or pin the allowed origins.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Accept requests from any origin; restrict to known domains in production
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# username -> WebSocket currently registered for that user (set in register_client)
active_connections = {}
# recipient username -> list of message dicts queued while the recipient is offline
message_store = {}
# username -> session token obtained at login/signup (in-memory only; lost on restart)
session_tokens = {}
async def register_client(websocket: WebSocket, username: str, token: str):
    """Register a connected client and flush any messages queued for it.

    Records ``websocket`` and ``token`` under ``username`` in the module-level
    ``active_connections`` and ``session_tokens`` tables, then sends every
    message stored for ``username`` in ``message_store`` over the socket, in
    the order it was queued.

    Args:
        websocket: The accepted connection for this user.
        username: Key under which the connection and token are stored.
        token: Session token issued for this connection.

    Raises:
        Exception: Whatever ``websocket.send_text`` raises. Messages that were
            not sent (including the one that failed) stay queued, and messages
            already sent are not queued again.
    """
    active_connections[username] = websocket
    session_tokens[username] = token
    # The token is a bearer credential, so it is deliberately kept out of the log.
    print(f"DEBUG: {username} connected with a session token.")

    pending = message_store.get(username)
    if pending is None:
        return

    # Remove each message from the queue as it is sent, so a failure part-way
    # through never causes the already-delivered ones to be sent again later.
    while pending:
        message = pending.pop(0)
        print(f"DEBUG: Sending undelivered message to {username}: {message}")
        try:
            await websocket.send_text(json.dumps(message))
        except Exception:
            # Put the message back at the front so it is retried on next login.
            pending.insert(0, message)
            raise

    message_store.pop(username, None)
    print(f"DEBUG: Cleared stored messages for {username}")
async def unregister_client(username: str):
    """Forget ``username``'s live connection and its session token.

    Does nothing when ``username`` has no entry in ``active_connections``.
    """
    if username not in active_connections:
        return

    active_connections.pop(username)
    # The session ends with the connection, so its token is dropped too.
    session_tokens.pop(username)
    print(f"DEBUG: {username} disconnected.")
async def _send_json(websocket: WebSocket, payload: dict) -> None:
    """Serialize ``payload`` to JSON and send it as a single text frame."""
    await websocket.send_text(json.dumps(payload))


def _error_detail(response: httpx.Response, fallback: str) -> str:
    """Return the ``detail`` field of an error response, or ``fallback``."""
    try:
        body = response.json()
    except ValueError:
        # The service answered with a non-JSON body (e.g. a proxy error page).
        return fallback
    if isinstance(body, dict):
        return body.get("detail", fallback)
    return fallback


async def _authenticate(websocket: WebSocket, action, username, password) -> bool:
    """Run the signup/login handshake against the user service.

    Returns True once the client has been registered via ``register_client``.
    Returns False when the handshake was rejected; in that case an error frame
    has already been sent to ``websocket``.
    """
    if action not in ("signup", "login"):
        await _send_json(
            websocket,
            {"status": "error", "message": "Invalid action. Use 'login' or 'signup'."},
        )
        return False

    credentials = {"username": username, "password": password}
    async with httpx.AsyncClient() as client:
        if action == "signup":
            print(f"DEBUG: Registering new user: {username}")
            response = await client.post(f"{SERVICES_URL}/register", json=credentials)
            if response.status_code != 200:
                error_message = _error_detail(response, "Registration failed")
                print(f"DEBUG: Registration failed for {username}: {error_message}")
                await _send_json(websocket, {"status": "error", "message": error_message})
                return False
            print(f"DEBUG: Registration successful for {username}")
            await _send_json(
                websocket, {"status": "success", "message": "Registration successful"}
            )

            # A fresh account still needs a session token, so log it in right away.
            token_response = await client.post(f"{SERVICES_URL}/login", json=credentials)
            if token_response.status_code != 200:
                # Without a token there is no session: stop here rather than
                # letting the connection fall through into the relay loop.
                await _send_json(
                    websocket, {"status": "error", "message": "Token retrieval failed"}
                )
                return False
            token = token_response.json().get("token")
            await _send_json(websocket, {"status": "success", "token": token})
        else:
            print(f"DEBUG: Authenticating user: {username}")
            response = await client.post(f"{SERVICES_URL}/login", json=credentials)
            if response.status_code != 200:
                error_message = _error_detail(response, "Invalid credentials")
                print(f"DEBUG: Authentication failed for {username}: {error_message}")
                await _send_json(websocket, {"status": "error", "message": error_message})
                return False
            print(f"DEBUG: Authentication successful for {username}")
            token = response.json().get("token")
            await _send_json(
                websocket,
                {"status": "success", "message": "Authenticated", "token": token},
            )

    await register_client(websocket, username, token)
    return True


async def _relay_message(websocket: WebSocket, username: str, raw_message: str) -> None:
    """Validate one inbound frame from ``username`` and deliver or queue it."""
    msg_data = json.loads(raw_message)
    recipient = msg_data.get("recipient")

    # Authorize the *sender*: the frame must carry the token issued to this
    # user's current session. A missing session never matches.
    expected_token = session_tokens.get(username)
    if expected_token is None or msg_data.get("token") != expected_token:
        await _send_json(websocket, {"status": "error", "message": "Invalid session token"})
        return

    message_obj = {
        "from": username,
        "message": msg_data.get("message"),
        "timestamp": datetime.now().isoformat(),
    }

    print(f"DEBUG: Sending message to recipient: {recipient}")
    recipient_socket = active_connections.get(recipient)
    if recipient_socket is not None:
        try:
            await recipient_socket.send_text(json.dumps(message_obj))
        except Exception as exc:
            # The recipient's socket is broken; that is not the sender's
            # problem, so fall back to queueing instead of losing the message.
            print(f"DEBUG: Delivery to {recipient} failed, storing instead: {exc}")
        else:
            print(f"DEBUG: Message sent to {recipient}: {message_obj}")
            return

    message_store.setdefault(recipient, []).append(message_obj)
    print(f"DEBUG: Message stored for {recipient}: {message_obj}")
    await _send_json(websocket, {"status": "success", "message": "Message stored for delivery"})


@app.websocket("/ws")
async def relay_server(websocket: WebSocket):
    """Relay server handling WebSocket connections.

    Protocol, as implemented here:

    1. The first text frame is a JSON object with ``action`` (``"login"`` or
       ``"signup"``), ``username`` and ``password``. The credentials are
       checked against the service at ``SERVICES_URL``; on success the client
       receives its session token and is registered.
    2. Every later frame is a JSON object with ``recipient``, ``message`` and
       ``token`` (the sender's own session token). The message is forwarded to
       the recipient if connected, otherwise queued until they connect.

    On any handshake failure an error frame is sent and the handler returns.
    The client is unregistered on exit only if this socket is still the one
    registered for the username.
    """
    username = None
    try:
        await websocket.accept()
        print("DEBUG: WebSocket connection accepted.")

        # Initial login or signup. The raw frame holds the password, so it is
        # never written to the log.
        login_data = await websocket.receive_text()
        try:
            user_data = json.loads(login_data)
        except json.JSONDecodeError:
            user_data = None
        if not isinstance(user_data, dict) or not isinstance(user_data.get("username"), str):
            await _send_json(websocket, {"status": "error", "message": "Invalid login data"})
            return

        action = user_data.get("action")  # 'login' or 'signup'
        username = user_data["username"]
        password = user_data.get("password")
        print(f"DEBUG: Received initial {action!r} request for {username}")

        try:
            authenticated = await _authenticate(websocket, action, username, password)
        except httpx.HTTPError as exc:
            # The user service is unreachable or misbehaving.
            print(f"DEBUG: Authentication service error for {username}: {exc}")
            await _send_json(
                websocket,
                {"status": "error", "message": "Authentication service unavailable"},
            )
            return
        if not authenticated:
            return

        # Relay messages
        while True:
            try:
                message = await websocket.receive_text()
                # The raw frame holds the session token; log only the sender.
                print(f"DEBUG: Received message from {username}")
                await _relay_message(websocket, username, message)
            except WebSocketDisconnect:
                # The sender is gone: leave the loop instead of trying to
                # report an error on a closed socket.
                raise
            except Exception as e:
                # A bad frame must not kill the connection; report and go on.
                print(f"DEBUG: Error processing message: {e}")
                await _send_json(
                    websocket, {"status": "error", "message": "Message processing error"}
                )
    except WebSocketDisconnect:
        print(f"DEBUG: Connection with {username} closed.")
    finally:
        # Only tear down the registration this socket owns. A failed login, or
        # an old socket closing after the user reconnected elsewhere, must not
        # evict the user's live session.
        if username is not None and active_connections.get(username) is websocket:
            await unregister_client(username)