Spaces:
Sleeping
Sleeping
| import asyncio | |
| import websockets | |
| import json | |
| import os | |
| from pathlib import Path | |
| import uuid | |
| import time | |
| from typing import Dict, Any, Optional | |
| import numpy as np | |
| from fastapi import FastAPI, WebSocket | |
| from fastapi.responses import HTMLResponse | |
| from datetime import datetime | |
# Create the FastAPI application instance served by uvicorn (see the
# "server:app" import string in the __main__ block).
# NOTE(review): no route decorators or add_*_route calls are visible in this
# file for handle_index / handle_files / websocket_endpoint — confirm where
# those handlers are registered on `app`.
app = FastAPI()
class VirtualGPUServer:
    """Disk-backed store for VRAM blocks, component state and cache entries.

    Data is written under ``<module directory>/storage`` in three
    sub-directories (``vram_blocks``, ``gpu_state``, ``cache``) and mirrored
    in in-memory dictionaries so repeated reads can skip the disk. The
    instance also tracks active WebSocket sessions and a global operation
    counter used by :meth:`get_stats`.
    """

    def __init__(self):
        """Create the storage directories and initialise caches and counters."""
        self.base_path = Path(__file__).parent / "storage"
        self.vram_path = self.base_path / "vram_blocks"
        self.state_path = self.base_path / "gpu_state"
        self.cache_path = self.base_path / "cache"

        # parents=True also creates base_path; exist_ok=True makes start-up
        # safe when the directories are already there.
        for directory in (self.vram_path, self.state_path, self.cache_path):
            directory.mkdir(parents=True, exist_ok=True)

        # In-memory caches for faster access
        self.vram_cache: Dict[str, Any] = {}
        self.state_cache: Dict[str, Any] = {}
        self.memory_cache: Dict[str, Any] = {}

        # Active connections and sessions, both keyed by session id
        self.active_connections: Dict[str, "WebSocket"] = {}
        self.active_sessions: Dict[str, Dict[str, Any]] = {}
        self.heartbeat_interval = 5  # seconds
        self.connection_timeout = 30  # seconds

        # Performance monitoring
        self.ops_counter = 0
        self.start_time = time.time()

    @staticmethod
    def _check_identifier(value, label: str) -> None:
        """Validate a client-supplied id that is used as a path component.

        Raises:
            ValueError: if the value is missing, is not a string/number, or
                could escape the storage directory (path separators, ``.``,
                ``..`` or NUL bytes).
        """
        if value is None or value == "":
            raise ValueError(f"Missing {label}")
        if not isinstance(value, (str, int, float)):
            raise ValueError(f"Invalid {label}: must be a string or number")
        text = str(value)
        if text in (".", "..") or any(ch in text for ch in ("/", "\\", "\x00")):
            raise ValueError(f"Invalid {label}: path separators are not allowed")

    def _make_json_serializable(self, obj):
        """Recursively convert ``obj`` into JSON-compatible Python values.

        dicts, lists and tuples are converted element by element (tuples
        become lists); NumPy arrays/scalars go through ``tolist()``; ``Path``
        and ``UUID`` become strings; objects with a ``__dict__`` are converted
        via that dict; anything else falls back to ``str(obj)``.
        """
        primitives = (int, float, str, bool, type(None))
        # NumPy first: some NumPy scalars subclass float/str and should still
        # be converted to plain Python values.
        if isinstance(obj, (np.ndarray, np.generic)):
            return obj.tolist()
        if isinstance(obj, primitives):
            return obj
        if isinstance(obj, dict):
            # json.dump only accepts primitive keys, so stringify the rest.
            return {
                (k if isinstance(k, primitives) else str(k)):
                    self._make_json_serializable(v)
                for k, v in obj.items()
            }
        if isinstance(obj, (list, tuple)):
            # Tuples are walked element by element too, so nested
            # non-serializable values inside them are converted as well.
            return [self._make_json_serializable(item) for item in obj]
        if isinstance(obj, (Path, uuid.UUID)):
            return str(obj)
        if hasattr(obj, '__dict__'):
            # Handle custom objects by converting their attribute dict
            return self._make_json_serializable(obj.__dict__)
        # Convert any other type to its string representation
        return str(obj)

    async def monitor_connection(self, websocket: "WebSocket", session_id: str):
        """Poll a session until it times out, loses its socket or is removed.

        Every ``heartbeat_interval`` seconds the session's ``last_ping`` is
        compared with ``connection_timeout``. A stale session with
        ``keep_alive`` set is sent a ping (and its ``last_ping`` refreshed);
        a stale session without it ends the loop. On exit the session is
        cleaned up only if a final ping cannot be delivered.
        """
        try:
            while session_id in self.active_connections:
                session = self.active_sessions.get(session_id)
                if not session:
                    break

                now = time.time()
                if now - session.get('last_ping', 0) > self.connection_timeout:
                    if not session.get('keep_alive', False):
                        # Connection timed out and keep-alive is disabled
                        break
                    try:
                        await websocket.send_json({"type": "ping"})
                    except Exception as e:
                        print(f"Connection lost for session {session_id}: {e}")
                        break
                    session['last_ping'] = now

                await asyncio.sleep(self.heartbeat_interval)
        finally:
            # Only disconnect if the connection is truly dead
            if session_id in self.active_connections:
                try:
                    alive = await self.check_connection_alive(websocket)
                except Exception:
                    alive = False
                if not alive:
                    await self.handle_disconnect(session_id)

    async def check_connection_alive(self, websocket: "WebSocket") -> bool:
        """Return True if a ping can be sent on ``websocket``, else False."""
        try:
            await websocket.send_json({"type": "ping"})
            return True
        except Exception:
            return False

    async def handle_disconnect(self, session_id: str):
        """Remove a session and its connection, saving any pending state.

        If the session dict holds a truthy ``pending_state`` it is saved
        through :meth:`handle_state_operation` under component ``session``.
        """
        # Remove both registry entries up front so the connection entry does
        # not outlive its session and a repeated call is a harmless no-op.
        session_data = self.active_sessions.pop(session_id, None)
        self.active_connections.pop(session_id, None)

        if session_data and session_data.get('pending_state'):
            await self.handle_state_operation({
                'type': 'save',
                'component': 'session',
                'state_id': session_id,
                'data': session_data['pending_state']
            })

    async def handle_vram_operation(self, operation: dict) -> dict:
        """Handle VRAM ``write``/``read`` operations.

        Args:
            operation: dict with ``type`` (``'write'`` or ``'read'``),
                ``block_id`` and, for writes, ``data``.

        Returns:
            A dict with ``status`` (``'success'``/``'error'``) plus either
            ``message`` or ``data`` and ``source`` (``'cache'``/``'disk'``).
            Errors are reported in the result rather than raised.
        """
        try:
            op_type = operation.get('type')
            if not op_type:
                raise ValueError("Missing operation type")

            block_id = operation.get('block_id')
            if not block_id:
                raise ValueError("Missing block_id")
            # block_id becomes a file name; reject anything that could
            # address a path outside vram_path.
            self._check_identifier(block_id, 'block_id')
            file_path = self.vram_path / f"{block_id}.npy"

            if op_type == 'write':
                data = operation.get('data')
                if data is None:
                    raise ValueError("Missing data for write operation")
                if isinstance(data, (dict, list)):
                    data = self._make_json_serializable(data)
                block = np.array(data)
                # NOTE(review): non-rectangular/dict data yields an object
                # array, which np.load refuses by default — confirm whether
                # such writes should be rejected here.
                np.save(file_path, block)
                self.vram_cache[block_id] = block
                return {'status': 'success', 'message': f'Block {block_id} written'}

            if op_type == 'read':
                cached = self.vram_cache.get(block_id)
                if cached is not None:
                    return {
                        'status': 'success',
                        'data': cached if isinstance(cached, list) else cached.tolist(),
                        'source': 'cache'
                    }
                if file_path.exists():
                    block = np.load(file_path)
                    self.vram_cache[block_id] = block
                    return {
                        'status': 'success',
                        'data': block.tolist(),
                        'source': 'disk'
                    }
                return {'status': 'error', 'message': 'Block not found'}

            return {'status': 'error', 'message': f'Unknown operation type: {op_type}'}
        except ValueError as e:
            return {'status': 'error', 'message': str(e)}
        except Exception as e:
            return {'status': 'error', 'message': f'Operation failed: {str(e)}'}

    async def handle_state_operation(self, operation: dict) -> dict:
        """Handle GPU state ``save``/``load`` operations.

        Args:
            operation: dict with ``type`` (``'save'`` or ``'load'``),
                ``component``, ``state_id`` and, for saves, ``data``.

        Returns:
            A status dict in the same shape as :meth:`handle_vram_operation`.
            Missing or unsafe ids and unknown operation types produce an
            error result instead of an exception or ``None``.
        """
        try:
            op_type = operation.get('type')
            if not op_type:
                raise ValueError("Missing operation type")

            component = operation.get('component')
            state_id = operation.get('state_id')
            # Both values become path components under state_path.
            self._check_identifier(component, 'component')
            self._check_identifier(state_id, 'state_id')

            file_path = self.state_path / str(component) / f"{state_id}.json"
            cache_key = f"{component}:{state_id}"

            if op_type == 'save':
                state_data = operation.get('data')
                # Serialise before touching the disk so a failure cannot
                # leave a truncated file behind.
                payload = json.dumps(state_data)
                file_path.parent.mkdir(parents=True, exist_ok=True)
                file_path.write_text(payload, encoding='utf-8')
                self.state_cache[cache_key] = state_data
                return {'status': 'success', 'message': f'State {state_id} saved'}

            if op_type == 'load':
                if cache_key in self.state_cache:
                    return {
                        'status': 'success',
                        'data': self.state_cache[cache_key],
                        'source': 'cache'
                    }
                if file_path.exists():
                    state_data = json.loads(file_path.read_text(encoding='utf-8'))
                    self.state_cache[cache_key] = state_data
                    return {
                        'status': 'success',
                        'data': state_data,
                        'source': 'disk'
                    }
                return {'status': 'error', 'message': 'State not found'}

            return {'status': 'error', 'message': f'Unknown operation type: {op_type}'}
        except ValueError as e:
            return {'status': 'error', 'message': str(e)}
        except Exception as e:
            return {'status': 'error', 'message': f'Operation failed: {str(e)}'}

    async def handle_cache_operation(self, operation: dict) -> dict:
        """Handle cache ``set``/``get`` operations.

        Args:
            operation: dict with ``type`` (``'set'`` or ``'get'``), ``key``
                and, for sets, ``data``.

        Returns:
            A status dict; ``source`` is ``'memory'`` or ``'disk'`` on a
            successful get. Missing or unsafe keys and unknown operation
            types produce an error result instead of an exception or
            ``None``.
        """
        try:
            op_type = operation.get('type')
            if not op_type:
                raise ValueError("Missing operation type")

            key = operation.get('key')
            # The key becomes a file name under cache_path.
            self._check_identifier(key, 'key')
            file_path = self.cache_path / f"{key}.json"

            if op_type == 'set':
                data = operation.get('data')
                # Persist to disk for recovery; serialise first so that
                # memory and disk are only updated with storable data.
                payload = json.dumps(data)
                file_path.write_text(payload, encoding='utf-8')
                self.memory_cache[key] = data
                return {'status': 'success', 'message': f'Cache key {key} set'}

            if op_type == 'get':
                if key in self.memory_cache:
                    return {
                        'status': 'success',
                        'data': self.memory_cache[key],
                        'source': 'memory'
                    }
                if file_path.exists():
                    data = json.loads(file_path.read_text(encoding='utf-8'))
                    self.memory_cache[key] = data
                    return {
                        'status': 'success',
                        'data': data,
                        'source': 'disk'
                    }
                return {'status': 'error', 'message': 'Cache key not found'}

            return {'status': 'error', 'message': f'Unknown operation type: {op_type}'}
        except ValueError as e:
            return {'status': 'error', 'message': str(e)}
        except Exception as e:
            return {'status': 'error', 'message': f'Operation failed: {str(e)}'}

    async def _route_operation(self, message) -> dict:
        """Dispatch a decoded message to the handler named by ``operation``."""
        if not isinstance(message, dict):
            return {'status': 'error', 'message': 'Message must be a JSON object'}

        operation_type = message.get('operation')
        if operation_type == 'vram':
            return await self.handle_vram_operation(message)
        if operation_type == 'state':
            return await self.handle_state_operation(message)
        if operation_type == 'cache':
            return await self.handle_cache_operation(message)
        return {'status': 'error', 'message': 'Unknown operation type'}

    async def handle_connection(self, websocket: "websockets.WebSocketServerProtocol"):
        """Serve one ``websockets``-library connection until it closes.

        Each text message is parsed as JSON, routed by its ``operation``
        field, and answered with a JSON-encoded status dict. Invalid JSON is
        answered with an error and does not count as an operation.
        """
        # Generate unique session ID
        session_id = str(uuid.uuid4())
        self.active_connections[session_id] = websocket
        self.active_sessions[session_id] = {
            'start_time': time.time(),
            'ops_count': 0
        }

        try:
            async for message in websocket:
                try:
                    data = json.loads(message)
                except json.JSONDecodeError:
                    await websocket.send(json.dumps({
                        'status': 'error',
                        'message': 'Invalid JSON'
                    }))
                    continue

                response = await self._route_operation(data)

                # Update statistics; the session may already have been
                # removed by handle_disconnect.
                self.ops_counter += 1
                session = self.active_sessions.get(session_id)
                if session is not None:
                    session['ops_count'] += 1

                await websocket.send(json.dumps(response))
        except websockets.exceptions.ConnectionClosed:
            pass
        finally:
            # pop() instead of del: the entries may already be gone.
            self.active_connections.pop(session_id, None)
            self.active_sessions.pop(session_id, None)

    def get_stats(self) -> dict:
        """Return uptime, operation counts and cache sizes as a dict."""
        uptime = time.time() - self.start_time
        ops_per_second = self.ops_counter / uptime if uptime > 0 else 0

        return {
            'uptime': uptime,
            'total_operations': self.ops_counter,
            'ops_per_second': ops_per_second,
            'active_connections': len(self.active_connections),
            'vram_cache_size': len(self.vram_cache),
            'state_cache_size': len(self.state_cache),
            'memory_cache_size': len(self.memory_cache)
        }
# Module-level singleton shared by the HTTP and WebSocket handlers below.
# Constructing it creates the storage directories next to this file.
server = VirtualGPUServer()
async def handle_index():
    """Render the status page: server statistics plus an embedded /files listing."""
    snapshot = server.get_stats()
    page = f"""
<!DOCTYPE html>
<html>
<head>
<title>Virtual GPU Server</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 40px; }}
table {{ border-collapse: collapse; width: 100%; margin-top: 20px; }}
th, td {{ padding: 12px; text-align: left; border-bottom: 1px solid #ddd; }}
th {{ background-color: #f2f2f2; }}
.stats {{ background-color: #f9f9f9; padding: 20px; border-radius: 5px; }}
</style>
</head>
<body>
<h1>Virtual GPU Server Status</h1>
<div class="stats">
<h2>Server Statistics</h2>
<ul>
<li>Uptime: {snapshot['uptime']:.2f} seconds</li>
<li>Total Operations: {snapshot['total_operations']}</li>
<li>Operations per Second: {snapshot['ops_per_second']:.2f}</li>
<li>Active Connections: {snapshot['active_connections']}</li>
<li>VRAM Cache Size: {snapshot['vram_cache_size']}</li>
<li>State Cache Size: {snapshot['state_cache_size']}</li>
<li>Memory Cache Size: {snapshot['memory_cache_size']}</li>
</ul>
</div>
<h2>Server Files</h2>
<iframe src="/files" style="width: 100%; height: 500px; border: none;"></iframe>
</body>
</html>
"""
    return HTMLResponse(content=page)
async def handle_files():
    """Render an HTML table of every file under the server's storage root.

    Each row shows the path relative to ``server.base_path``, a
    human-readable size and the last-modified time.

    Returns:
        HTMLResponse containing the complete listing page.
    """
    # Local import: the file-level imports do not include the `html` module.
    from html import escape

    def format_size(size):
        """Format a byte count using B/KB/MB/GB/TB with two decimals."""
        for unit in ['B', 'KB', 'MB', 'GB']:
            if size < 1024:
                return f"{size:.2f} {unit}"
            size /= 1024
        return f"{size:.2f} TB"

    lines = ['<!DOCTYPE html><html><head>',
             '<style>',
             'body { font-family: Arial, sans-serif; margin: 20px; }',
             'table { border-collapse: collapse; width: 100%; }',
             'th, td { padding: 12px; text-align: left; border-bottom: 1px solid #ddd; }',
             'th { background-color: #f2f2f2; }',
             '</style></head><body>',
             '<h2>Server Files</h2>',
             '<table><tr><th>Path</th><th>Size</th><th>Last Modified</th></tr>']

    for root, _, files in os.walk(server.base_path):
        for file in files:
            full_path = Path(root) / file
            try:
                info = full_path.stat()
            except OSError:
                # The file disappeared (or became unreadable) between the
                # directory walk and the stat call; skip it rather than
                # failing the whole page.
                continue
            # File names are built from client-supplied ids (block ids, cache
            # keys, state ids), so escape them before embedding in HTML.
            rel_path = escape(str(full_path.relative_to(server.base_path)))
            size = format_size(info.st_size)
            mtime = datetime.fromtimestamp(info.st_mtime)
            lines.append(f'<tr><td>{rel_path}</td><td>{size}</td><td>{mtime}</td></tr>')

    lines.append('</table></body></html>')
    return HTMLResponse(content='\n'.join(lines))
# WebSocket endpoint
async def websocket_endpoint(websocket: WebSocket):
    """Serve one FastAPI WebSocket session until the peer disconnects.

    Accepts the socket, registers a session on ``server`` with keep-alive
    enabled, then loops: each JSON message refreshes the session's
    ``last_ping``; ``{"type": "ping"}`` is answered with a pong; anything
    else is routed by its ``operation`` field (``vram``/``state``/``cache``)
    and the handler's status dict is sent back. Malformed messages and
    handler failures are answered with an error dict and do not close the
    connection. The session is always removed from the registries on exit.
    """
    # Local import: FastAPI/Starlette report a closed peer by raising
    # WebSocketDisconnect, which the file-level imports do not bring in.
    from fastapi import WebSocketDisconnect

    session_id = None
    try:
        await websocket.accept()
        session_id = str(uuid.uuid4())
        print(f"INFO: WebSocket connection opened, session: {session_id}")

        # Initialize session with keep-alive enabled
        session = {
            'start_time': time.time(),
            'ops_count': 0,
            'keep_alive': True,
            'last_ping': time.time()
        }
        server.active_connections[session_id] = websocket
        server.active_sessions[session_id] = session

        handlers = {
            'vram': server.handle_vram_operation,
            'state': server.handle_state_operation,
            'cache': server.handle_cache_operation,
        }

        while True:
            try:
                try:
                    message = await websocket.receive_json()
                except json.JSONDecodeError:
                    # Same reply the class's handle_connection gives.
                    await websocket.send_json({
                        'status': 'error',
                        'message': 'Invalid JSON'
                    })
                    continue

                # Update last activity timestamp
                session['last_ping'] = time.time()

                if not isinstance(message, dict):
                    await websocket.send_json({
                        'status': 'error',
                        'message': 'Message must be a JSON object'
                    })
                    continue

                # Handle ping messages
                if message.get('type') == 'ping':
                    await websocket.send_json({"type": "pong"})
                    continue

                # Route operation to appropriate handler
                handler = handlers.get(message.get('operation'))
                if handler is None:
                    response = {
                        'status': 'error',
                        'message': 'Unknown operation type'
                    }
                else:
                    try:
                        response = await handler(message)
                    except Exception as op_error:
                        # A failing operation should not tear down the
                        # whole connection.
                        response = {
                            'status': 'error',
                            'message': f'Operation failed: {op_error}'
                        }
                    if response is None:
                        # A handler that falls through without a result
                        # would otherwise be sent to the client as `null`.
                        response = {
                            'status': 'error',
                            'message': 'Unknown operation type'
                        }

                # Update statistics
                server.ops_counter += 1
                session['ops_count'] += 1

                # Send response
                await websocket.send_json(response)
            except asyncio.TimeoutError:
                # NOTE(review): receive_json() is awaited without a timeout,
                # so this branch looks unreachable today; it is kept so a
                # future wait_for() wrapper keeps its ping-on-idle behaviour.
                try:
                    await websocket.send_json({"type": "ping"})
                except Exception:
                    break  # Connection lost
    except (WebSocketDisconnect, websockets.exceptions.ConnectionClosed):
        print(f"INFO: WebSocket connection closed normally, session: {session_id}")
    except Exception as e:
        print(f"ERROR: WebSocket error in session {session_id}: {str(e)}")
    finally:
        # Cleanup on disconnect, but only if we had a valid session
        if session_id:
            server.active_connections.pop(session_id, None)
            server.active_sessions.pop(session_id, None)
            print(f"INFO: Cleaned up session: {session_id}")
# Development entry point: start uvicorn with auto-reload when this file is
# executed directly.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "server:app",
        host="0.0.0.0",
        port=7860,
        reload=True,
    )