# WBS1 / server.py — Factor Studios
# Last change: "Update server.py" (commit ba864c5, verified), 12.6 kB
import asyncio
import websockets
import json
import os
from pathlib import Path
import uuid
import time
from typing import Dict, Any, Optional
import numpy as np
from fastapi import FastAPI, WebSocket
from fastapi.responses import HTMLResponse
from datetime import datetime
# The FastAPI application; the HTML status routes and the /ws endpoint below
# are registered on it via decorators.
app: FastAPI = FastAPI()
class VirtualGPUServer:
    """Storage backend for the virtual-GPU WebSocket API.

    Requests are routed to one of three namespaces. Each namespace has an
    in-memory dict cache in front of files under ``base_path``:

    * ``vram``  -> ``vram_blocks/<block_id>.npy`` (written with ``np.save``)
    * ``state`` -> ``gpu_state/<component>/<state_id>.json``
    * ``cache`` -> ``cache/<key>.json``

    Identifiers arrive from clients and become file names. Every one is
    therefore checked by ``_safe_name`` before any disk access, so a request
    such as ``block_id="../../x"`` cannot escape the storage directory.

    NOTE(review): the handlers do synchronous file I/O inside ``async``
    methods, which blocks the event loop while each read/write runs.
    """

    def __init__(self, base_path: Optional[Path] = None):
        """Create the storage directories and empty caches.

        Args:
            base_path: Root storage directory. Defaults to ``storage/`` next
                to this source file (the previous, hard-coded location).
        """
        if base_path is None:
            base_path = Path(__file__).parent / "storage"
        self.base_path = Path(base_path)
        self.vram_path = self.base_path / "vram_blocks"
        self.state_path = self.base_path / "gpu_state"
        self.cache_path = self.base_path / "cache"
        # Ensure all storage directories exist
        for directory in (self.vram_path, self.state_path, self.cache_path):
            directory.mkdir(parents=True, exist_ok=True)
        # In-memory caches, keyed by the identifier exactly as the client sent it
        self.vram_cache: Dict[str, Any] = {}
        self.state_cache: Dict[str, Any] = {}
        self.memory_cache: Dict[str, Any] = {}
        # Active connections and per-session bookkeeping, keyed by session UUID
        self.active_connections: Dict[str, WebSocket] = {}
        self.active_sessions: Dict[str, Dict[str, Any]] = {}
        # Performance monitoring
        self.ops_counter = 0
        self.start_time = time.time()

    @staticmethod
    def _safe_name(value: Any) -> Optional[str]:
        """Return ``value`` as a single file-name component, or None if unusable.

        Only ``str`` and ``int`` identifiers are accepted. The following are
        rejected so the resulting path stays inside its storage directory:
        empty names, ``.``, ``..``, absolute paths, anything containing a
        path separator, and NUL bytes.
        """
        if not isinstance(value, (str, int)):
            return None
        name = str(value)
        candidate = Path(name)
        if '\x00' in name or name == '..' or candidate.anchor or candidate.parts != (name,):
            return None
        return name

    async def handle_vram_operation(self, operation: dict) -> dict:
        """Handle VRAM ``write``/``read`` requests.

        ``operation`` keys: ``type`` ('write' or 'read'), ``block_id``, and,
        for writes, ``data`` (anything ``np.array`` can convert).

        Returns a dict whose ``status`` is 'success' or 'error'. Successful
        reads also carry ``data`` (nested lists) and ``source`` ('cache' or
        'disk').
        """
        op_type = operation.get('type')
        if op_type not in ('write', 'read'):
            return {'status': 'error', 'message': f'Unknown VRAM operation: {op_type}'}
        block_id = operation.get('block_id')
        name = self._safe_name(block_id)
        if name is None:
            return {'status': 'error', 'message': 'Invalid or missing block_id'}
        file_path = self.vram_path / f"{name}.npy"

        if op_type == 'write':
            try:
                array = np.array(operation.get('data'))
            except (ValueError, TypeError) as exc:
                # e.g. ragged nested lists, which recent NumPy refuses to convert
                return {'status': 'error', 'message': f'Invalid block data: {exc}'}
            np.save(file_path, array)
            self.vram_cache[block_id] = array
            return {'status': 'success', 'message': f'Block {block_id} written'}

        # op_type == 'read': try the in-memory cache first
        if block_id in self.vram_cache:
            cached = self.vram_cache[block_id]
            return {
                'status': 'success',
                'data': cached if isinstance(cached, list) else cached.tolist(),
                'source': 'cache'
            }
        if not file_path.exists():
            return {'status': 'error', 'message': 'Block not found'}
        try:
            # allow_pickle keeps its safe default (False): object arrays on
            # disk are refused instead of being unpickled.
            array = np.load(file_path)
        except (OSError, ValueError) as exc:
            return {'status': 'error', 'message': f'Failed to load block {block_id}: {exc}'}
        self.vram_cache[block_id] = array
        return {
            'status': 'success',
            'data': array.tolist(),
            'source': 'disk'
        }

    async def handle_state_operation(self, operation: dict) -> dict:
        """Handle GPU state ``save``/``load`` requests.

        ``operation`` keys: ``type`` ('save' or 'load'), ``component``,
        ``state_id``, and, for saves, ``data`` (any JSON-serialisable value).
        State is stored at ``gpu_state/<component>/<state_id>.json``.
        """
        op_type = operation.get('type')
        if op_type not in ('save', 'load'):
            return {'status': 'error', 'message': f'Unknown state operation: {op_type}'}
        component = self._safe_name(operation.get('component'))
        state_id = self._safe_name(operation.get('state_id'))
        if component is None or state_id is None:
            return {'status': 'error', 'message': 'Invalid or missing component/state_id'}
        file_path = self.state_path / component / f"{state_id}.json"
        cache_key = f"{component}:{state_id}"

        if op_type == 'save':
            state_data = operation.get('data')
            file_path.parent.mkdir(parents=True, exist_ok=True)
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(state_data, f)
            self.state_cache[cache_key] = state_data
            return {'status': 'success', 'message': f'State {state_id} saved'}

        # op_type == 'load': cache first, then disk
        if cache_key in self.state_cache:
            return {
                'status': 'success',
                'data': self.state_cache[cache_key],
                'source': 'cache'
            }
        if not file_path.exists():
            return {'status': 'error', 'message': 'State not found'}
        try:
            with open(file_path, encoding='utf-8') as f:
                state_data = json.load(f)
        except ValueError as exc:
            return {'status': 'error', 'message': f'Corrupt state {state_id}: {exc}'}
        self.state_cache[cache_key] = state_data
        return {
            'status': 'success',
            'data': state_data,
            'source': 'disk'
        }

    async def handle_cache_operation(self, operation: dict) -> dict:
        """Handle generic cache ``set``/``get`` requests.

        ``operation`` keys: ``type`` ('set' or 'get'), ``key``, and, for sets,
        ``data`` (any JSON-serialisable value). Values are kept in memory and
        also mirrored to ``cache/<key>.json`` so they survive a restart.
        """
        op_type = operation.get('type')
        if op_type not in ('set', 'get'):
            return {'status': 'error', 'message': f'Unknown cache operation: {op_type}'}
        key = operation.get('key')
        name = self._safe_name(key)
        if name is None:
            return {'status': 'error', 'message': 'Invalid or missing cache key'}
        file_path = self.cache_path / f"{name}.json"

        if op_type == 'set':
            data = operation.get('data')
            self.memory_cache[key] = data
            # Also persist to disk for recovery
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump(data, f)
            return {'status': 'success', 'message': f'Cache key {key} set'}

        # op_type == 'get': memory first, then disk
        if key in self.memory_cache:
            return {
                'status': 'success',
                'data': self.memory_cache[key],
                'source': 'memory'
            }
        if not file_path.exists():
            return {'status': 'error', 'message': 'Cache key not found'}
        try:
            with open(file_path, encoding='utf-8') as f:
                data = json.load(f)
        except ValueError as exc:
            return {'status': 'error', 'message': f'Corrupt cache entry {key}: {exc}'}
        self.memory_cache[key] = data
        return {
            'status': 'success',
            'data': data,
            'source': 'disk'
        }

    async def _dispatch(self, data: Any) -> dict:
        """Route one decoded request to its namespace handler.

        Always returns a response dict. Malformed requests and storage
        failures become ``status: 'error'`` replies rather than exceptions.
        """
        if not isinstance(data, dict):
            return {'status': 'error', 'message': 'Request must be a JSON object'}
        handler = {
            'vram': self.handle_vram_operation,
            'state': self.handle_state_operation,
            'cache': self.handle_cache_operation,
        }.get(data.get('operation'))
        if handler is None:
            return {'status': 'error', 'message': 'Unknown operation type'}
        try:
            return await handler(data)
        except OSError as exc:
            # Disk full, permission denied, etc.; report it and keep the session alive.
            return {'status': 'error', 'message': f'Storage error: {exc}'}

    async def handle_connection(self, websocket: "websockets.WebSocketServerProtocol"):
        """Serve one client connected through the ``websockets`` library.

        Each text message must be a JSON object; the reply is the JSON-encoded
        handler response. Invalid JSON gets an error reply and the loop
        continues. Session bookkeeping is removed when the client disconnects.
        """
        # Generate unique session ID
        session_id = str(uuid.uuid4())
        self.active_connections[session_id] = websocket
        self.active_sessions[session_id] = {
            'start_time': time.time(),
            'ops_count': 0
        }
        try:
            async for message in websocket:
                try:
                    data = json.loads(message)
                except ValueError:  # JSONDecodeError and UnicodeDecodeError
                    await websocket.send(json.dumps({
                        'status': 'error',
                        'message': 'Invalid JSON'
                    }))
                    continue
                response = await self._dispatch(data)
                # Update statistics
                self.ops_counter += 1
                self.active_sessions[session_id]['ops_count'] += 1
                await websocket.send(json.dumps(response))
        except websockets.exceptions.ConnectionClosed:
            pass
        finally:
            # Cleanup on disconnect
            self.active_connections.pop(session_id, None)
            self.active_sessions.pop(session_id, None)

    def get_stats(self) -> dict:
        """Return uptime, operation counters and cache sizes for the dashboard."""
        uptime = time.time() - self.start_time
        return {
            'uptime': uptime,
            'total_operations': self.ops_counter,
            'ops_per_second': self.ops_counter / uptime if uptime > 0 else 0,
            'active_connections': len(self.active_connections),
            'vram_cache_size': len(self.vram_cache),
            'state_cache_size': len(self.state_cache),
            'memory_cache_size': len(self.memory_cache)
        }
# Module-level instance used by every HTTP and WebSocket route handler below.
server: VirtualGPUServer = VirtualGPUServer()
@app.get("/", response_class=HTMLResponse)
async def handle_index():
    """Render the status dashboard.

    The page lists the counters returned by ``server.get_stats()`` and
    embeds the ``/files`` listing in an iframe.
    """
    snapshot = server.get_stats()
    # One <li> per statistic, newline-separated, spliced into the <ul> below.
    stat_lines = "\n".join([
        f"<li>Uptime: {snapshot['uptime']:.2f} seconds</li>",
        f"<li>Total Operations: {snapshot['total_operations']}</li>",
        f"<li>Operations per Second: {snapshot['ops_per_second']:.2f}</li>",
        f"<li>Active Connections: {snapshot['active_connections']}</li>",
        f"<li>VRAM Cache Size: {snapshot['vram_cache_size']}</li>",
        f"<li>State Cache Size: {snapshot['state_cache_size']}</li>",
        f"<li>Memory Cache Size: {snapshot['memory_cache_size']}</li>",
    ])
    page = f"""
<!DOCTYPE html>
<html>
<head>
<title>Virtual GPU Server</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 40px; }}
table {{ border-collapse: collapse; width: 100%; margin-top: 20px; }}
th, td {{ padding: 12px; text-align: left; border-bottom: 1px solid #ddd; }}
th {{ background-color: #f2f2f2; }}
.stats {{ background-color: #f9f9f9; padding: 20px; border-radius: 5px; }}
</style>
</head>
<body>
<h1>Virtual GPU Server Status</h1>
<div class="stats">
<h2>Server Statistics</h2>
<ul>
{stat_lines}
</ul>
</div>
<h2>Server Files</h2>
<iframe src="/files" style="width: 100%; height: 500px; border: none;"></iframe>
</body>
</html>
"""
    return HTMLResponse(content=page)
@app.get("/files", response_class=HTMLResponse)
async def handle_files():
    """Render an HTML table of every file under the server's storage directory.

    File names are built from client-supplied identifiers (block ids, cache
    keys, ...), so they are HTML-escaped before being embedded in the page.
    """
    from html import escape

    def format_size(size):
        """Format a byte count using 1024-based units (B, KB, MB, GB, TB)."""
        for unit in ['B', 'KB', 'MB', 'GB']:
            if size < 1024:
                return f"{size:.2f} {unit}"
            size /= 1024
        return f"{size:.2f} TB"

    parts = ['<!DOCTYPE html><html><head>',
             '<style>',
             'body { font-family: Arial, sans-serif; margin: 20px; }',
             'table { border-collapse: collapse; width: 100%; }',
             'th, td { padding: 12px; text-align: left; border-bottom: 1px solid #ddd; }',
             'th { background-color: #f2f2f2; }',
             '</style></head><body>',
             '<h2>Server Files</h2>',
             '<table><tr><th>Path</th><th>Size</th><th>Last Modified</th></tr>']
    for root, _, files in os.walk(server.base_path):
        for file in files:
            full_path = Path(root) / file
            try:
                info = full_path.stat()
            except OSError:
                # The file vanished (or became unreadable) after os.walk
                # listed it; leave it out rather than failing the page.
                continue
            rel_path = full_path.relative_to(server.base_path)
            size = format_size(info.st_size)
            mtime = datetime.fromtimestamp(info.st_mtime)
            parts.append(
                f'<tr><td>{escape(str(rel_path))}</td><td>{size}</td><td>{mtime}</td></tr>'
            )
    parts.append('</table></body></html>')
    return HTMLResponse(content='\n'.join(parts))
# WebSocket endpoint
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one client over the FastAPI ``/ws`` WebSocket.

    Each text frame must hold a JSON object with an ``operation`` key
    ('vram', 'state' or 'cache'). The object is passed to the matching
    ``server.handle_*_operation`` method and its result is sent back as JSON.

    Invalid JSON, non-object payloads and handler failures get an error
    reply and the session continues. Invalid JSON is handled the same way
    ``VirtualGPUServer.handle_connection`` handles it. Session bookkeeping is
    removed when the client disconnects.
    """
    import traceback
    from fastapi import WebSocketDisconnect

    await websocket.accept()
    session_id = str(uuid.uuid4())
    server.active_connections[session_id] = websocket
    server.active_sessions[session_id] = {
        'start_time': time.time(),
        'ops_count': 0
    }
    handlers = {
        'vram': server.handle_vram_operation,
        'state': server.handle_state_operation,
        'cache': server.handle_cache_operation,
    }
    try:
        while True:
            raw = await websocket.receive_text()
            try:
                message = json.loads(raw)
            except ValueError:
                await websocket.send_json({
                    'status': 'error',
                    'message': 'Invalid JSON'
                })
                continue
            # Route operation to appropriate handler
            if not isinstance(message, dict):
                response = {
                    'status': 'error',
                    'message': 'Request must be a JSON object'
                }
            else:
                handler = handlers.get(message.get('operation'))
                if handler is None:
                    response = {
                        'status': 'error',
                        'message': 'Unknown operation type'
                    }
                else:
                    try:
                        response = await handler(message)
                    except Exception:
                        # Per-request boundary: log the failure and keep the
                        # session alive instead of dropping the connection.
                        traceback.print_exc()
                        response = {
                            'status': 'error',
                            'message': 'Internal server error'
                        }
            # Update statistics
            server.ops_counter += 1
            server.active_sessions[session_id]['ops_count'] += 1
            await websocket.send_json(response)
    except WebSocketDisconnect:
        pass  # normal client disconnect, not an error
    except Exception as e:
        print(f"WebSocket error: {e}")
    finally:
        # Cleanup on disconnect
        server.active_connections.pop(session_id, None)
        server.active_sessions.pop(session_id, None)
# For running directly (development)
if __name__ == "__main__":
    # Development entry point: auto-reload, listening on all interfaces.
    # NOTE(review): the import string assumes this module is importable as
    # `server` from the working directory.
    import uvicorn

    uvicorn.run(
        "server:app",
        host="0.0.0.0",
        port=7860,
        reload=True,
    )