Factor Studios commited on
Commit
0bc19ec
·
verified ·
1 Parent(s): 4831394

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +103 -51
server.py CHANGED
@@ -17,6 +17,9 @@ app = FastAPI()
17
  class VirtualGPUServer:
18
  def __init__(self):
19
  self.base_path = Path(__file__).parent / "storage"
 
 
 
20
  self.vram_path = self.base_path / "vram_blocks"
21
  self.state_path = self.base_path / "gpu_state"
22
  self.cache_path = self.base_path / "cache"
@@ -66,30 +69,48 @@ class VirtualGPUServer:
66
  """Monitor connection health and handle reconnection"""
67
  try:
68
  while session_id in self.active_connections:
69
- try:
70
- await asyncio.wait_for(websocket.receive_text(), timeout=self.heartbeat_interval)
71
- except asyncio.TimeoutError:
72
- try:
73
- await websocket.send_json({"type": "ping"})
74
- except:
75
- print(f"Connection lost for session {session_id}")
76
- break
77
- except Exception:
78
  break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  await asyncio.sleep(self.heartbeat_interval)
80
  finally:
81
- await self.handle_disconnect(session_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  async def handle_disconnect(self, session_id: str):
84
  """Clean up resources when a client disconnects"""
85
- if session_id in self.active_connections:
86
- try:
87
- await self.active_connections[session_id].close()
88
- except:
89
- pass
90
- del self.active_connections[session_id]
91
  if session_id in self.active_sessions:
92
- # Save any pending state before removing session
93
  session_data = self.active_sessions[session_id]
94
  if session_data.get('pending_state'):
95
  await self.handle_state_operation({
@@ -361,45 +382,76 @@ async def handle_files():
361
  # WebSocket endpoint
362
  @app.websocket("/ws")
363
  async def websocket_endpoint(websocket: WebSocket):
364
- await websocket.accept()
365
- session_id = str(uuid.uuid4())
366
- server.active_connections[session_id] = websocket
367
- server.active_sessions[session_id] = {
368
- 'start_time': time.time(),
369
- 'ops_count': 0
370
- }
371
-
372
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
373
  while True:
374
- message = await websocket.receive_json()
375
-
376
- # Route operation to appropriate handler
377
- operation_type = message.get('operation')
378
- if operation_type == 'vram':
379
- response = await server.handle_vram_operation(message)
380
- elif operation_type == 'state':
381
- response = await server.handle_state_operation(message)
382
- elif operation_type == 'cache':
383
- response = await server.handle_cache_operation(message)
384
- else:
385
- response = {
386
- 'status': 'error',
387
- 'message': 'Unknown operation type'
388
- }
389
-
390
- # Update statistics
391
- server.ops_counter += 1
392
- server.active_sessions[session_id]['ops_count'] += 1
393
-
394
- # Send response
395
- await websocket.send_json(response)
 
 
 
 
 
 
396
 
 
 
 
 
 
 
 
 
 
 
 
 
397
  except Exception as e:
398
- print(f"WebSocket error: {e}")
399
  finally:
400
- # Cleanup on disconnect
401
- del server.active_connections[session_id]
402
- del server.active_sessions[session_id]
 
 
 
 
 
 
 
403
 
404
  # For running directly (development)
405
  if __name__ == "__main__":
 
17
  class VirtualGPUServer:
18
  def __init__(self):
19
  self.base_path = Path(__file__).parent / "storage"
20
+ if not self.base_path.exists():
21
+ self.base_path.mkdir(parents=True)
22
+
23
  self.vram_path = self.base_path / "vram_blocks"
24
  self.state_path = self.base_path / "gpu_state"
25
  self.cache_path = self.base_path / "cache"
 
69
  """Monitor connection health and handle reconnection"""
70
  try:
71
  while session_id in self.active_connections:
72
+ current_time = time.time()
73
+ session = self.active_sessions.get(session_id)
74
+
75
+ if not session:
 
 
 
 
 
76
  break
77
+
78
+ # Check if connection is still alive
79
+ last_ping = session.get('last_ping', 0)
80
+ if current_time - last_ping > self.connection_timeout:
81
+ if session.get('keep_alive', False):
82
+ try:
83
+ await websocket.send_json({"type": "ping"})
84
+ session['last_ping'] = current_time
85
+ except Exception as e:
86
+ print(f"Connection lost for session {session_id}: {e}")
87
+ break
88
+ else:
89
+ # Connection timed out and keep-alive is disabled
90
+ break
91
+
92
  await asyncio.sleep(self.heartbeat_interval)
93
  finally:
94
+ # Only disconnect if connection is truly dead
95
+ if session_id in self.active_connections:
96
+ try:
97
+ if not await self.check_connection_alive(websocket):
98
+ await self.handle_disconnect(session_id)
99
+ except:
100
+ await self.handle_disconnect(session_id)
101
+
102
+ async def check_connection_alive(self, websocket: WebSocket) -> bool:
103
+ """Check if a WebSocket connection is still alive"""
104
+ try:
105
+ await websocket.send_json({"type": "ping"})
106
+ return True
107
+ except Exception:
108
+ return False
109
 
110
  async def handle_disconnect(self, session_id: str):
111
  """Clean up resources when a client disconnects"""
 
 
 
 
 
 
112
  if session_id in self.active_sessions:
113
+ # Save any pending state before cleanup
114
  session_data = self.active_sessions[session_id]
115
  if session_data.get('pending_state'):
116
  await self.handle_state_operation({
 
382
  # WebSocket endpoint
383
  @app.websocket("/ws")
384
  async def websocket_endpoint(websocket: WebSocket):
385
+ session_id = None
 
 
 
 
 
 
 
386
  try:
387
+ await websocket.accept()
388
+ session_id = str(uuid.uuid4())
389
+ print(f"INFO: WebSocket connection opened, session: {session_id}")
390
+
391
+ # Initialize session with keep-alive enabled
392
+ server.active_connections[session_id] = websocket
393
+ server.active_sessions[session_id] = {
394
+ 'start_time': time.time(),
395
+ 'ops_count': 0,
396
+ 'keep_alive': True,
397
+ 'last_ping': time.time()
398
+ }
399
+
400
  while True:
401
+ try:
402
+ # Use a shorter timeout for more responsive connection management
403
+ message = await websocket.receive_json()
404
+ # Update last activity timestamp
405
+ server.active_sessions[session_id]['last_ping'] = time.time()
406
+
407
+ # Handle ping messages
408
+ if message.get('type') == 'ping':
409
+ await websocket.send_json({"type": "pong"})
410
+ continue
411
+
412
+ # Route operation to appropriate handler
413
+ operation_type = message.get('operation')
414
+ if operation_type == 'vram':
415
+ response = await server.handle_vram_operation(message)
416
+ elif operation_type == 'state':
417
+ response = await server.handle_state_operation(message)
418
+ elif operation_type == 'cache':
419
+ response = await server.handle_cache_operation(message)
420
+ else:
421
+ response = {
422
+ 'status': 'error',
423
+ 'message': 'Unknown operation type'
424
+ }
425
+
426
+ # Update statistics
427
+ server.ops_counter += 1
428
+ server.active_sessions[session_id]['ops_count'] += 1
429
 
430
+ # Send response
431
+ await websocket.send_json(response)
432
+
433
+ except asyncio.TimeoutError:
434
+ # Send ping on timeout
435
+ try:
436
+ await websocket.send_json({"type": "ping"})
437
+ except:
438
+ break # Connection lost
439
+
440
+ except websockets.exceptions.ConnectionClosed:
441
+ print(f"INFO: WebSocket connection closed normally, session: {session_id}")
442
  except Exception as e:
443
+ print(f"ERROR: WebSocket error in session {session_id}: {str(e)}")
444
  finally:
445
+ # Cleanup on disconnect, but only if we had a valid session
446
+ if session_id:
447
+ try:
448
+ if session_id in server.active_connections:
449
+ del server.active_connections[session_id]
450
+ if session_id in server.active_sessions:
451
+ del server.active_sessions[session_id]
452
+ print(f"INFO: Cleaned up session: {session_id}")
453
+ except Exception as cleanup_error:
454
+ print(f"WARNING: Error during session cleanup: {cleanup_error}")
455
 
456
  # For running directly (development)
457
  if __name__ == "__main__":