# WBS1 / server.py (Factor Studios)
# Snapshot of commit ce4253e ("Upload 9 files", verified), 11.9 kB.
import asyncio
import html
import json
import logging
import os
import time
import uuid
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, Optional

import aiohttp
import numpy as np
import websockets
from aiohttp import web
class VirtualGPUServer:
    """WebSocket + HTTP server that persists "virtual GPU" data to local files.

    WebSocket clients send JSON objects whose ``operation`` field selects one
    of three stores.  Each store lives in a directory under ``storage/`` next
    to this file and is fronted by an in-memory dict:

    * ``vram``  -- NumPy arrays saved as ``vram_blocks/<block_id>.npy``
    * ``state`` -- JSON values saved as ``gpu_state/<component>/<state_id>.json``
    * ``cache`` -- JSON values saved as ``cache/<key>.json``

    An aiohttp application serves a status page (``/``) and a listing of the
    storage directory (``/files``).

    Every client-supplied identifier becomes part of a file name, so each one
    is checked by ``_checked_name`` to be a single, non-traversing path
    component before it touches the filesystem.

    NOTE(review): file I/O runs synchronously inside the async handlers, so a
    large read/write blocks the event loop for all clients while it runs.
    """

    # Unexpected errors while serving a message are logged here instead of
    # tearing down the client's WebSocket connection.
    _log = logging.getLogger(__name__)

    def __init__(self):
        self.base_path = Path(__file__).parent / "storage"
        self.vram_path = self.base_path / "vram_blocks"
        self.state_path = self.base_path / "gpu_state"
        self.cache_path = self.base_path / "cache"
        # Ensure all storage directories exist
        for directory in (self.vram_path, self.state_path, self.cache_path):
            directory.mkdir(parents=True, exist_ok=True)
        # In-memory caches in front of the on-disk stores.  Entries are never
        # evicted, so they grow with the number of distinct ids accessed.
        self.vram_cache: Dict[str, Any] = {}
        self.state_cache: Dict[str, Any] = {}
        self.memory_cache: Dict[str, Any] = {}
        # Active connections and per-session counters, keyed by session id
        self.active_connections: Dict[str, "websockets.WebSocketServerProtocol"] = {}
        self.active_sessions: Dict[str, Dict[str, Any]] = {}
        # Performance monitoring
        self.ops_counter = 0
        self.start_time = time.time()
        # HTTP app
        self.app = web.Application()
        self.app.router.add_get('/', self.handle_index)
        self.app.router.add_get('/files', self.handle_files)

    @staticmethod
    def _checked_name(value: Any, field: str) -> str:
        """Validate a client-supplied identifier used as a file/directory name.

        Args:
            value: raw value from the client message (any JSON scalar).
            field: name of the message field, used in the error message.

        Returns:
            ``str(value)``, guaranteed to be one plain path component.

        Raises:
            ValueError: if *value* is missing or empty, or is ``..``, or
                contains a NUL byte, or is not a single path component on this
                platform (e.g. contains a separator).  Such values could make a
                joined path escape the storage directory.
        """
        if value is None:
            raise ValueError(f"Missing {field}")
        name = str(value)
        # Path(name).name != name catches '.', separators and drive prefixes.
        if not name or name == ".." or "\x00" in name or Path(name).name != name:
            raise ValueError(f"Invalid {field}: {name!r}")
        return name

    async def handle_vram_operation(self, operation: dict) -> dict:
        """Handle VRAM block writes and reads.

        Args:
            operation: message with ``type`` (``'write'`` or ``'read'``),
                ``block_id`` and, for writes, ``data`` (anything ``np.array``
                turns into a non-object array, e.g. nested lists of numbers).

        Returns:
            Response dict with ``status`` ``'success'`` or ``'error'``.
            Successful reads include ``data`` (nested lists) and ``source``
            (``'cache'`` or ``'disk'``).
        """
        op_type = operation.get('type')
        if op_type not in ('write', 'read'):
            return {'status': 'error', 'message': f'Unknown vram operation type: {op_type!r}'}
        try:
            block_id = self._checked_name(operation.get('block_id'), 'block_id')
        except ValueError as exc:
            return {'status': 'error', 'message': str(exc)}
        file_path = self.vram_path / f"{block_id}.npy"

        if op_type == 'write':
            try:
                array = np.array(operation.get('data'))
            except ValueError as exc:  # e.g. ragged nested lists
                return {'status': 'error', 'message': f'Invalid block data: {exc}'}
            if array.dtype.hasobject:
                # np.load(allow_pickle=False) refuses object arrays, so such a
                # block could be written but never read back.
                return {'status': 'error', 'message': 'Invalid block data: values must be numeric or strings'}
            np.save(file_path, array)
            self.vram_cache[block_id] = array
            return {'status': 'success', 'message': f'Block {block_id} written'}

        # op_type == 'read': try cache first
        if block_id in self.vram_cache:
            cached = self.vram_cache[block_id]
            return {
                'status': 'success',
                'data': cached if isinstance(cached, list) else cached.tolist(),
                'source': 'cache'
            }
        if not file_path.exists():
            return {'status': 'error', 'message': 'Block not found'}
        array = np.load(file_path)
        self.vram_cache[block_id] = array
        return {'status': 'success', 'data': array.tolist(), 'source': 'disk'}

    async def handle_state_operation(self, operation: dict) -> dict:
        """Handle GPU component state saves and loads.

        Args:
            operation: message with ``type`` (``'save'`` or ``'load'``),
                ``component``, ``state_id`` and, for saves, ``data`` (any
                JSON value).

        Returns:
            Response dict with ``status`` ``'success'`` or ``'error'``.
            Successful loads include ``data`` and ``source``.
        """
        op_type = operation.get('type')
        if op_type not in ('save', 'load'):
            return {'status': 'error', 'message': f'Unknown state operation type: {op_type!r}'}
        try:
            component = self._checked_name(operation.get('component'), 'component')
            state_id = self._checked_name(operation.get('state_id'), 'state_id')
        except ValueError as exc:
            return {'status': 'error', 'message': str(exc)}
        file_path = self.state_path / component / f"{state_id}.json"
        cache_key = f"{component}:{state_id}"

        if op_type == 'save':
            state_data = operation.get('data')
            file_path.parent.mkdir(exist_ok=True)
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(state_data, f)
            self.state_cache[cache_key] = state_data
            return {'status': 'success', 'message': f'State {state_id} saved'}

        # op_type == 'load'; `in` check because a saved state may be null.
        if cache_key in self.state_cache:
            return {'status': 'success', 'data': self.state_cache[cache_key], 'source': 'cache'}
        if not file_path.exists():
            return {'status': 'error', 'message': 'State not found'}
        with open(file_path, encoding='utf-8') as f:
            state_data = json.load(f)
        self.state_cache[cache_key] = state_data
        return {'status': 'success', 'data': state_data, 'source': 'disk'}

    async def handle_cache_operation(self, operation: dict) -> dict:
        """Handle key/value cache sets and gets, persisted to disk for recovery.

        Args:
            operation: message with ``type`` (``'set'`` or ``'get'``), ``key``
                and, for sets, ``data`` (any JSON value).

        Returns:
            Response dict with ``status`` ``'success'`` or ``'error'``.
            Successful gets include ``data`` and ``source`` (``'memory'`` or
            ``'disk'``).
        """
        op_type = operation.get('type')
        if op_type not in ('set', 'get'):
            return {'status': 'error', 'message': f'Unknown cache operation type: {op_type!r}'}
        try:
            key = self._checked_name(operation.get('key'), 'key')
        except ValueError as exc:
            return {'status': 'error', 'message': str(exc)}
        file_path = self.cache_path / f"{key}.json"

        if op_type == 'set':
            data = operation.get('data')
            # Persist first so the memory cache only holds values on disk.
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(data, f)
            self.memory_cache[key] = data
            return {'status': 'success', 'message': f'Cache key {key} set'}

        # op_type == 'get'
        if key in self.memory_cache:
            return {'status': 'success', 'data': self.memory_cache[key], 'source': 'memory'}
        if not file_path.exists():
            return {'status': 'error', 'message': 'Cache key not found'}
        with open(file_path, encoding='utf-8') as f:
            data = json.load(f)
        self.memory_cache[key] = data
        return {'status': 'success', 'data': data, 'source': 'disk'}

    async def handle_connection(
        self,
        websocket: "websockets.WebSocketServerProtocol",
        path: Optional[str] = None,
    ):
        """Serve one WebSocket client until it disconnects.

        Each message must be a JSON object whose ``operation`` field is
        ``'vram'``, ``'state'`` or ``'cache'``.  The object is passed to the
        matching handler, and the handler's response dict is sent back as JSON.
        If handling one message fails, the error is logged and sent to the
        client as an error response; the connection stays open.

        Args:
            websocket: the client connection.
            path: request path.  Unused; it is accepted only so the handler
                also works with older ``websockets`` releases that call
                handlers as ``handler(websocket, path)``.
        """
        session_id = str(uuid.uuid4())
        self.active_connections[session_id] = websocket
        self.active_sessions[session_id] = {
            'start_time': time.time(),
            'ops_count': 0
        }
        handlers = {
            'vram': self.handle_vram_operation,
            'state': self.handle_state_operation,
            'cache': self.handle_cache_operation,
        }
        try:
            async for message in websocket:
                try:
                    data = json.loads(message)
                except ValueError:  # JSONDecodeError, or undecodable bytes
                    await websocket.send(json.dumps({
                        'status': 'error',
                        'message': 'Invalid JSON'
                    }))
                    continue

                if not isinstance(data, dict):
                    response = {'status': 'error', 'message': 'Message must be a JSON object'}
                else:
                    operation_type = data.get('operation')
                    handler = handlers.get(operation_type) if isinstance(operation_type, str) else None
                    if handler is None:
                        response = {'status': 'error', 'message': 'Unknown operation type'}
                    else:
                        try:
                            response = await handler(data)
                        except Exception:
                            # Boundary for per-message failures (I/O errors,
                            # corrupt files, ...): log and keep serving.
                            self._log.exception(
                                "Error handling %r operation (session %s)",
                                operation_type, session_id,
                            )
                            response = {'status': 'error', 'message': 'Internal server error'}

                # Update statistics
                self.ops_counter += 1
                self.active_sessions[session_id]['ops_count'] += 1
                await websocket.send(json.dumps(response))
        except websockets.exceptions.ConnectionClosed:
            pass
        finally:
            # Cleanup on disconnect
            self.active_connections.pop(session_id, None)
            self.active_sessions.pop(session_id, None)

    def get_stats(self) -> dict:
        """Return uptime, operation counts, connection count and cache sizes."""
        uptime = time.time() - self.start_time
        ops_per_second = self.ops_counter / uptime if uptime > 0 else 0
        return {
            'uptime': uptime,
            'total_operations': self.ops_counter,
            'ops_per_second': ops_per_second,
            'active_connections': len(self.active_connections),
            'vram_cache_size': len(self.vram_cache),
            'state_cache_size': len(self.state_cache),
            'memory_cache_size': len(self.memory_cache)
        }

    async def handle_index(self, request):
        """Serve the HTML status page (statistics plus an iframe of ``/files``)."""
        stats = self.get_stats()
        page = f"""
<!DOCTYPE html>
<html>
<head>
<title>Virtual GPU Server</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 40px; }}
table {{ border-collapse: collapse; width: 100%; margin-top: 20px; }}
th, td {{ padding: 12px; text-align: left; border-bottom: 1px solid #ddd; }}
th {{ background-color: #f2f2f2; }}
.stats {{ background-color: #f9f9f9; padding: 20px; border-radius: 5px; }}
</style>
</head>
<body>
<h1>Virtual GPU Server Status</h1>
<div class="stats">
<h2>Server Statistics</h2>
<ul>
<li>Uptime: {stats['uptime']:.2f} seconds</li>
<li>Total Operations: {stats['total_operations']}</li>
<li>Operations per Second: {stats['ops_per_second']:.2f}</li>
<li>Active Connections: {stats['active_connections']}</li>
<li>VRAM Cache Size: {stats['vram_cache_size']}</li>
<li>State Cache Size: {stats['state_cache_size']}</li>
<li>Memory Cache Size: {stats['memory_cache_size']}</li>
</ul>
</div>
<h2>Server Files</h2>
<iframe src="/files" style="width: 100%; height: 500px; border: none;"></iframe>
</body>
</html>
"""
        return web.Response(text=page, content_type='text/html')

    async def handle_files(self, request):
        """Serve an HTML table of every file under ``base_path``.

        WebSocket clients choose the file names (they come from block ids,
        state ids and cache keys), so each path is HTML-escaped before
        rendering.  Directories are walked in sorted order so the listing is
        stable.
        """
        def format_size(size):
            for unit in ['B', 'KB', 'MB', 'GB']:
                if size < 1024:
                    return f"{size:.2f} {unit}"
                size /= 1024
            return f"{size:.2f} TB"

        parts = ['<!DOCTYPE html><html><head>',
                 '<style>',
                 'body { font-family: Arial, sans-serif; margin: 20px; }',
                 'table { border-collapse: collapse; width: 100%; }',
                 'th, td { padding: 12px; text-align: left; border-bottom: 1px solid #ddd; }',
                 'th { background-color: #f2f2f2; }',
                 '</style></head><body>',
                 '<h2>Server Files</h2>',
                 '<table><tr><th>Path</th><th>Size</th><th>Last Modified</th></tr>']
        for root, dirs, files in os.walk(self.base_path):
            dirs.sort()  # in-place sort also fixes os.walk's descent order
            for file_name in sorted(files):
                full_path = Path(root) / file_name
                try:
                    stat = full_path.stat()
                except OSError:
                    # Removed (or made unreadable) after os.walk listed it.
                    continue
                rel_path = html.escape(str(full_path.relative_to(self.base_path)))
                size = format_size(stat.st_size)
                mtime = datetime.fromtimestamp(stat.st_mtime)
                parts.append(f'<tr><td>{rel_path}</td><td>{size}</td><td>{mtime}</td></tr>')
        parts.append('</table></body></html>')
        return web.Response(text='\n'.join(parts), content_type='text/html')
async def main(host: str = "0.0.0.0", ws_port: int = 8765, http_port: int = 8080):
    """Start the WebSocket API and HTTP status servers and run until cancelled.

    Args:
        host: interface both servers bind to.  The default ``"0.0.0.0"`` binds
            all interfaces.  NOTE(review): the server has no authentication,
            so consider binding to ``"127.0.0.1"`` unless remote access is
            intended.
        ws_port: port for the WebSocket API.
        http_port: port for the HTTP status pages.
    """
    server = VirtualGPUServer()
    # Start WebSocket server
    websocket_server = await websockets.serve(server.handle_connection, host, ws_port)
    # Start HTTP server
    runner = web.AppRunner(server.app)
    await runner.setup()
    try:
        site = web.TCPSite(runner, host, http_port)
        await site.start()
        print("Virtual GPU Server running:")
        print(f"- WebSocket: ws://localhost:{ws_port}")
        print(f"- HTTP Interface: http://localhost:{http_port}")
        # Run until this task is cancelled (asyncio.run cancels it on Ctrl-C).
        await asyncio.Future()
    finally:
        # Release both listening sockets and close client connections.
        websocket_server.close()
        await websocket_server.wait_closed()
        await runner.cleanup()
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the server; exit without a traceback.
        print("Virtual GPU Server stopped")