Junyi42 commited on
Commit
d79eae5
·
1 Parent(s): fed9f1a

improve socket stability

Browse files
Files changed (1) hide show
  1. viser_proxy_manager.py +69 -58
viser_proxy_manager.py CHANGED
@@ -27,7 +27,7 @@ class ViserProxyManager:
27
  app: FastAPI,
28
  min_local_port: int = 8000,
29
  max_local_port: int = 9000,
30
- max_message_size: int = 100 * 1024 * 1024, # 100MB default
31
  ) -> None:
32
  self._min_port = min_local_port
33
  self._max_port = max_local_port
@@ -83,68 +83,79 @@ class ViserProxyManager:
83
  @app.websocket("/viser/{server_id}")
84
  async def websocket_proxy(websocket: WebSocket, server_id: str):
85
  """Proxy WebSocket connections to the appropriate Viser server."""
86
- await websocket.accept()
87
-
88
- server = self._server_from_session_hash.get(server_id)
89
- if server is None:
90
- await websocket.close(code=1008, reason="Not Found")
91
- return
92
-
93
- # Determine target WebSocket URL
94
- target_ws_url = f"ws://127.0.0.1:{server.get_port()}"
95
-
96
- if not target_ws_url:
97
- await websocket.close(code=1008, reason="Not Found")
98
- return
99
-
100
  try:
101
- # Connect to the target WebSocket with increased message size
102
- async with websockets.connect(
103
- target_ws_url,
104
- max_size=self._max_message_size, # Set max message size for receiving
105
- ) as ws_target:
106
- # Create tasks for bidirectional communication
107
- async def forward_to_target():
108
- """Forward messages from the client to the target WebSocket."""
109
- try:
110
- while True:
111
- data = await websocket.receive_bytes()
112
- await ws_target.send(data, text=False)
113
- except WebSocketDisconnect:
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  try:
115
- await ws_target.close()
116
- except RuntimeError:
117
- pass
118
-
119
- async def forward_from_target():
120
- """Forward messages from the target WebSocket to the client."""
121
- try:
122
- while True:
123
- data = await ws_target.recv(decode=False)
124
- await websocket.send_bytes(data)
125
- except websockets.exceptions.ConnectionClosed:
126
  try:
127
- await websocket.close()
128
- except RuntimeError:
129
- pass
130
-
131
- # Run both forwarding tasks concurrently
132
- forward_task = asyncio.create_task(forward_to_target())
133
- backward_task = asyncio.create_task(forward_from_target())
134
-
135
- # Wait for either task to complete (which means a connection was closed)
136
- done, pending = await asyncio.wait(
137
- [forward_task, backward_task],
138
- return_when=asyncio.FIRST_COMPLETED,
139
- )
140
-
141
- # Cancel the remaining task
142
- for task in pending:
143
- task.cancel()
144
-
 
 
 
 
 
 
 
 
 
145
  except Exception as e:
146
  print(f"WebSocket proxy error: {e}")
147
- await websocket.close(code=1011, reason=str(e))
 
 
 
148
 
149
  def start_server(self, server_id: str) -> viser.ViserServer:
150
  """Start a new Viser server and associate it with the given server ID.
 
27
  app: FastAPI,
28
  min_local_port: int = 8000,
29
  max_local_port: int = 9000,
30
+ max_message_size: int = 300 * 1024 * 1024, # 300MB default
31
  ) -> None:
32
  self._min_port = min_local_port
33
  self._max_port = max_local_port
 
83
  @app.websocket("/viser/{server_id}")
84
  async def websocket_proxy(websocket: WebSocket, server_id: str):
85
  """Proxy WebSocket connections to the appropriate Viser server."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  try:
87
+ await websocket.accept()
88
+
89
+ server = self._server_from_session_hash.get(server_id)
90
+ if server is None:
91
+ await websocket.close(code=1008, reason="Not Found")
92
+ return
93
+
94
+ # Determine target WebSocket URL
95
+ target_ws_url = f"ws://127.0.0.1:{server.get_port()}"
96
+
97
+ if not target_ws_url:
98
+ await websocket.close(code=1008, reason="Not Found")
99
+ return
100
+
101
+ try:
102
+ # Connect to the target WebSocket with increased message size and timeout
103
+ async with websockets.connect(
104
+ target_ws_url,
105
+ max_size=self._max_message_size,
106
+ ping_interval=30, # Send ping every 30 seconds
107
+ ping_timeout=10, # Wait 10 seconds for pong response
108
+ close_timeout=5, # Wait 5 seconds for close handshake
109
+ ) as ws_target:
110
+ # Create tasks for bidirectional communication
111
+ async def forward_to_target():
112
+ """Forward messages from the client to the target WebSocket."""
113
  try:
114
+ while True:
115
+ data = await websocket.receive_bytes()
116
+ await ws_target.send(data, text=False)
117
+ except WebSocketDisconnect:
118
+ try:
119
+ await ws_target.close()
120
+ except RuntimeError:
121
+ pass
122
+
123
+ async def forward_from_target():
124
+ """Forward messages from the target WebSocket to the client."""
125
  try:
126
+ while True:
127
+ data = await ws_target.recv(decode=False)
128
+ await websocket.send_bytes(data)
129
+ except websockets.exceptions.ConnectionClosed:
130
+ try:
131
+ await websocket.close()
132
+ except RuntimeError:
133
+ pass
134
+
135
+ # Run both forwarding tasks concurrently
136
+ forward_task = asyncio.create_task(forward_to_target())
137
+ backward_task = asyncio.create_task(forward_from_target())
138
+
139
+ # Wait for either task to complete (which means a connection was closed)
140
+ done, pending = await asyncio.wait(
141
+ [forward_task, backward_task],
142
+ return_when=asyncio.FIRST_COMPLETED,
143
+ )
144
+
145
+ # Cancel the remaining task
146
+ for task in pending:
147
+ task.cancel()
148
+
149
+ except websockets.exceptions.ConnectionClosedError as e:
150
+ print(f"WebSocket connection closed with error: {e}")
151
+ await websocket.close(code=1011, reason="Connection to target closed")
152
+
153
  except Exception as e:
154
  print(f"WebSocket proxy error: {e}")
155
+ try:
156
+ await websocket.close(code=1011, reason=str(e)[:120]) # Limit reason length
157
+ except:
158
+ pass # Already closed
159
 
160
  def start_server(self, server_id: str) -> viser.ViserServer:
161
  """Start a new Viser server and associate it with the given server ID.