File size: 7,449 Bytes
54ebe81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f59a9a4
 
 
 
 
54ebe81
f59a9a4
fab9207
54ebe81
 
 
 
 
 
 
 
 
 
f59a9a4
54ebe81
 
f59a9a4
54ebe81
 
 
 
 
 
f59a9a4
54ebe81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f59a9a4
 
 
 
 
 
 
 
 
 
 
54ebe81
 
 
 
 
 
018bd07
 
 
 
 
 
 
f59a9a4
 
 
018bd07
 
 
 
 
 
 
 
54ebe81
 
 
 
 
 
 
 
 
 
 
f59a9a4
 
 
 
 
 
 
54ebe81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
import json
from datetime import datetime
import httpx
from dotenv import load_dotenv
import os

# Pull variables from a local .env file (if one exists) into os.environ so the
# configuration lookups below can see them.
load_dotenv()

# Base URL of the auth service that serves /register and /login.
# None when the variable is not set in the environment.
SERVICES_URL = os.environ.get("SERVICES_URL")

app = FastAPI()

# Wide-open CORS policy: every origin, method and header, with credentials.
# Replace the "*" origin with explicit domains before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# username -> WebSocket currently registered for that user.
active_connections = dict()

# recipient username -> list of message dicts queued while they were offline.
message_store = dict()

# username -> session token obtained at login/signup.
# Process-local and in-memory only: contents are lost when the server restarts.
session_tokens = dict()

async def register_client(websocket: WebSocket, username: str, token: str):
    """Register *websocket* as the live connection for *username*.

    Records the user's session token, then flushes any messages that were
    queued in ``message_store`` while the user was offline, in queue order.
    A previous connection registered under the same username is replaced.

    Args:
        websocket: The accepted client socket.
        username: Authenticated user name; used as the registry key.
        token: Session token returned by the auth service (may be None if the
            service omitted it).

    Raises:
        Whatever ``websocket.send_text`` raises while flushing (e.g. the client
        dropped). Messages not yet delivered are put back in the queue first.
    """
    active_connections[username] = websocket
    session_tokens[username] = token
    # The token is a bearer credential: never write it to the logs.
    print(f"DEBUG: {username} connected with a session token.")

    # Detach the queue before any await so it is not iterated while it could
    # be mutated, and so delivered messages are not left behind on failure.
    pending = message_store.pop(username, [])
    for index, message in enumerate(pending):
        print(f"DEBUG: Sending undelivered message to {username}: {message}")
        try:
            await websocket.send_text(json.dumps(message))
        except Exception:
            # Re-queue the undelivered tail ahead of anything queued meanwhile
            # so a dropped connection neither loses nor duplicates messages.
            message_store[username] = pending[index:] + message_store.get(username, [])
            raise
    if pending:
        print(f"DEBUG: Cleared stored messages for {username}")

async def unregister_client(username: str):
    """Forget *username*'s registered socket and session token.

    Does nothing when *username* has no registered connection.
    """
    if username not in active_connections:
        return
    active_connections.pop(username)
    # Plain pop (no default): a missing token raises KeyError, exactly as del.
    session_tokens.pop(username)
    print(f"DEBUG: {username} disconnected.")

def _error_detail(response, default):
    """Return the ``detail`` field of an auth-service error response.

    Falls back to *default* when the body is not a JSON object, so an
    unexpected (e.g. plain-text or HTML) error page cannot crash the handshake.
    """
    try:
        body = response.json()
    except ValueError:
        return default
    return body.get("detail", default) if isinstance(body, dict) else default


def _token_from(response):
    """Return the ``token`` field of a successful login response, or None.

    Raises:
        ValueError: the response body is not valid JSON.
    """
    body = response.json()
    return body.get("token") if isinstance(body, dict) else None


async def _authenticate(websocket, action, username, password):
    """Run the signup or login handshake against the auth service.

    Sends the status frames for the handshake to the client and returns
    ``(True, token)`` on success, or ``(False, None)`` after an error frame
    has been sent.

    Raises:
        httpx.HTTPError: the auth service could not be reached.
        ValueError: a success response body was not valid JSON.
    """
    credentials = {"username": username, "password": password}
    async with httpx.AsyncClient() as client:
        if action == "signup":
            print(f"DEBUG: Registering new user: {username}")
            response = await client.post(f"{SERVICES_URL}/register", json=credentials)
            if response.status_code != 200:
                error_message = _error_detail(response, "Registration failed")
                print(f"DEBUG: Registration failed for {username}: {error_message}")
                await websocket.send_text(json.dumps({"status": "error", "message": error_message}))
                return False, None
            print(f"DEBUG: Registration successful for {username}")
            await websocket.send_text(json.dumps({"status": "success", "message": "Registration successful"}))

            # Obtain a session token by logging in with the same credentials.
            token_response = await client.post(f"{SERVICES_URL}/login", json=credentials)
            if token_response.status_code != 200:
                await websocket.send_text(json.dumps({"status": "error", "message": "Token retrieval failed"}))
                return False, None
            token = _token_from(token_response)
            await websocket.send_text(json.dumps({"status": "success", "token": token}))
            return True, token

        print(f"DEBUG: Authenticating user: {username}")
        response = await client.post(f"{SERVICES_URL}/login", json=credentials)
        if response.status_code != 200:
            error_message = _error_detail(response, "Invalid credentials")
            print(f"DEBUG: Authentication failed for {username}: {error_message}")
            await websocket.send_text(json.dumps({"status": "error", "message": error_message}))
            return False, None
        print(f"DEBUG: Authentication successful for {username}")
        token = _token_from(response)
        await websocket.send_text(json.dumps({"status": "success", "message": "Authenticated", "token": token}))
        return True, token


async def _handle_message(websocket, username, raw):
    """Validate one chat frame from *username* and deliver or queue it.

    Raises on malformed input (bad JSON, non-object payload) or if pushing to
    the recipient's socket fails; the caller reports that to the sender.
    """
    msg_data = json.loads(raw)
    if not isinstance(msg_data, dict):
        raise ValueError("message payload must be a JSON object")
    recipient = msg_data.get("recipient")
    msg_content = msg_data.get("message")
    timestamp = datetime.now().isoformat()
    # Do not log the raw frame: it carries the sender's session token.
    print(f"DEBUG: Received message from {username} for {recipient!r}")

    # Check the token against the *sender's* own session: a client is only
    # ever sent its own token at login, never anyone else's.
    if msg_data.get("token") != session_tokens.get(username):
        await websocket.send_text(json.dumps({"status": "error", "message": "Invalid session token"}))
        return

    # Without this, messages with no recipient would pile up forever under
    # message_store[None].
    if not isinstance(recipient, str) or not recipient:
        await websocket.send_text(json.dumps({"status": "error", "message": "Invalid recipient"}))
        return

    message_obj = {
        "from": username,
        "message": msg_content,
        "timestamp": timestamp,
    }

    print(f"DEBUG: Sending message to recipient: {recipient}")
    recipient_socket = active_connections.get(recipient)
    if recipient_socket is not None:
        await recipient_socket.send_text(json.dumps(message_obj))
        print(f"DEBUG: Message sent to {recipient}: {message_obj}")
    else:
        # Recipient offline: queue it; register_client flushes it on their login.
        message_store.setdefault(recipient, []).append(message_obj)
        print(f"DEBUG: Message stored for {recipient}: {message_obj}")
        await websocket.send_text(json.dumps({"status": "success", "message": "Message stored for delivery"}))


async def _relay_messages(websocket, username):
    """Relay chat frames from *username*'s socket until the client disconnects.

    Raises:
        WebSocketDisconnect: when the client closes the connection.
    """
    while True:
        # Receive outside the error handler below. WebSocketDisconnect
        # subclasses Exception, and a disconnect must end the loop rather than
        # be answered with an error frame on a socket that is already gone.
        frame = await websocket.receive()
        if frame["type"] == "websocket.disconnect":
            raise WebSocketDisconnect(frame.get("code", 1005))
        try:
            raw = frame.get("text")
            if raw is None:
                raise ValueError("only text frames are supported")
            await _handle_message(websocket, username, raw)
        except Exception as e:
            # Per-message failures are reported to the sender without dropping
            # the session; if this send itself fails, the error propagates.
            await websocket.send_text(json.dumps({"status": "error", "message": "Message processing error"}))
            print(f"DEBUG: Error processing message: {e}")


@app.websocket("/ws")
async def relay_server(websocket: WebSocket):
    """Authenticate a WebSocket client, then relay its chat messages.

    Protocol:
      1. The first text frame is a JSON object with ``action`` ("login" or
         "signup"), ``username`` and ``password``. The handler answers with
         status frames and ends the connection on any failure.
      2. Every later frame is a JSON object with ``recipient``, ``message`` and
         the sender's session ``token``. It is pushed to the recipient when
         they are connected, otherwise queued in ``message_store``.
    """
    username = None
    try:
        await websocket.accept()
        print("DEBUG: WebSocket connection accepted.")

        login_data = await websocket.receive_text()
        try:
            user_data = json.loads(login_data)
        except json.JSONDecodeError:
            user_data = None
        if not isinstance(user_data, dict):
            await websocket.send_text(json.dumps({"status": "error", "message": "Invalid login payload"}))
            return

        action = user_data.get("action")  # 'login' or 'signup'
        requested_username = user_data.get("username")
        password = user_data.get("password")
        # Log action and username only; the raw frame holds the plaintext password.
        print(f"DEBUG: Received {action!r} request for user {requested_username!r}")

        if action not in ("login", "signup"):
            await websocket.send_text(json.dumps({"status": "error", "message": "Invalid action. Use 'login' or 'signup'."}))
            return
        if not isinstance(requested_username, str) or not requested_username or not isinstance(password, str):
            await websocket.send_text(json.dumps({"status": "error", "message": "Username and password are required"}))
            return
        # Assigned only once validated, so the cleanup below can safely use it
        # as a dict key.
        username = requested_username

        try:
            authenticated, token = await _authenticate(websocket, action, username, password)
        except (httpx.HTTPError, ValueError) as e:
            print(f"DEBUG: Auth service error for {username}: {e}")
            await websocket.send_text(json.dumps({"status": "error", "message": "Authentication service error"}))
            return
        if not authenticated:
            return

        # Both signup and login continue into the relay loop once registered.
        await register_client(websocket, username, token)
        await _relay_messages(websocket, username)

    except WebSocketDisconnect:
        print(f"DEBUG: Connection with {username} closed.")
    finally:
        # Only drop the registration if it still belongs to *this* socket:
        # a failed login attempt, or a connection superseded by a newer login
        # for the same user, must not evict that user's live session.
        if username is not None and active_connections.get(username) is websocket:
            await unregister_client(username)