Spaces:
Runtime error
Runtime error
| import asyncio | |
| import websockets | |
| import json | |
| import os | |
| from pathlib import Path | |
| import uuid | |
| import time | |
| import jwt | |
| from typing import Dict, Any, Optional, List | |
| import numpy as np | |
| from fastapi import FastAPI, WebSocket, HTTPException, Depends, Request, Response | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| from datetime import datetime, timedelta | |
| import hashlib | |
| import gzip | |
| import base64 | |
| import logging | |
| from pydantic import BaseModel | |
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# Create FastAPI instance with enhanced configuration
app = FastAPI(
    title="Virtual GPU Server",
    description="HTTP and WebSocket API for Virtual GPU v2",
    version="2.0.0"
)

# Add CORS middleware for cross-origin requests.
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# development-only setting; restrict allow_origins before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins for development
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# JWT configuration.
# The signing secret is taken from the JWT_SECRET environment variable; the
# previous hard-coded value remains only as a development fallback so that
# existing deployments keep working.
_DEFAULT_JWT_SECRET = "virtual_gpu_secret_key_2025"
JWT_SECRET = os.environ.get("JWT_SECRET", _DEFAULT_JWT_SECRET)
if JWT_SECRET == _DEFAULT_JWT_SECRET:
    logging.warning("JWT_SECRET is not set; using the built-in development secret")
JWT_ALGORITHM = "HS256"
JWT_EXPIRATION_HOURS = 24

# HTTP Bearer security scheme
security = HTTPBearer()
# Pydantic models for request/response validation
class SessionCreateRequest(BaseModel):
    """Request body for creating an HTTP session; both fields are optional."""

    client_id: Optional[str] = None
    resource_limits: Optional[Dict[str, Any]] = None
class SessionResponse(BaseModel):
    """Response returned after a session is created."""

    session_token: str
    session_id: str
    expires_at: datetime
class VRAMWriteRequest(BaseModel):
    """Request body for writing a VRAM block: the data plus optional metadata."""

    data: List[Any]
    metadata: Optional[Dict[str, Any]] = None
    model_size: Optional[int] = None
class VRAMResponse(BaseModel):
    """Response for VRAM read/write calls; only ``status`` is mandatory."""

    status: str
    message: Optional[str] = None
    data: Optional[List[Any]] = None
    metadata: Optional[Dict[str, Any]] = None
    source: Optional[str] = None
class StateRequest(BaseModel):
    """Request body for saving component state."""

    data: Dict[str, Any]
    timestamp: Optional[float] = None
class StateResponse(BaseModel):
    """Response for state save/load calls; only ``status`` is mandatory."""

    status: str
    message: Optional[str] = None
    data: Optional[Dict[str, Any]] = None
    source: Optional[str] = None
class CacheRequest(BaseModel):
    """Request body for setting a cache entry."""

    data: Any
    ttl: Optional[int] = None
class CacheResponse(BaseModel):
    """Response for cache get/set calls; only ``status`` is mandatory."""

    status: str
    message: Optional[str] = None
    data: Optional[Any] = None
    source: Optional[str] = None
class ModelLoadRequest(BaseModel):
    """Request body for loading a model; every field is optional at the schema level."""

    model_data: Optional[Dict[str, Any]] = None
    model_path: Optional[str] = None
    model_hash: Optional[str] = None
class ModelInferenceRequest(BaseModel):
    """Request body for running inference on a loaded model."""

    input_data: List[Any]
    batch_size: Optional[int] = None
class ErrorResponse(BaseModel):
    """Structured error payload carrying an error code and a request id."""

    status: str
    error_code: str
    message: str
    details: Optional[Dict[str, Any]] = None
    retry_after: Optional[int] = None
    request_id: str
class VirtualGPUServer:
    """File-backed store for VRAM blocks, component state, cache entries and models.

    Each kind of data lives in its own sub-directory of ``storage/`` next to
    this file and is mirrored in an in-memory dict. The ``handle_*_operation``
    coroutines take an operation dict and always return a result dict with a
    ``'status'`` key (``'success'`` or ``'error'``).
    """

    def __init__(self):
        """Create the storage directories and initialise caches and counters."""
        self.base_path = Path(__file__).parent / "storage"
        self.vram_path = self.base_path / "vram_blocks"
        self.state_path = self.base_path / "gpu_state"
        self.cache_path = self.base_path / "cache"
        self.models_path = self.base_path / "models"
        # Ensure all storage directories exist
        for directory in (self.vram_path, self.state_path, self.cache_path, self.models_path):
            directory.mkdir(parents=True, exist_ok=True)
        # In-memory caches for faster access
        self.vram_cache: Dict[str, Any] = {}
        self.state_cache: Dict[str, Any] = {}
        self.memory_cache: Dict[str, Any] = {}
        self.model_cache: Dict[str, Any] = {}
        # Session management for HTTP API
        self.http_sessions: Dict[str, Dict[str, Any]] = {}
        # Active WebSocket connections and sessions (for backward compatibility)
        self.active_connections: Dict[str, WebSocket] = {}
        self.active_sessions: Dict[str, Dict[str, Any]] = {}
        self.heartbeat_interval = 5  # seconds
        self.connection_timeout = 30  # seconds
        # Performance monitoring
        self.ops_counter = 0
        self.start_time = time.time()
        self.request_counter = 0

    @staticmethod
    def _contained_path(base: Path, *parts) -> Path:
        """Join ``parts`` onto ``base`` and return the resolved path.

        Raises:
            ValueError: if a part is missing/empty or the result would fall
                outside ``base`` (e.g. ``..`` segments or an absolute path).
        """
        for part in parts:
            if part is None or str(part) == "":
                raise ValueError("Missing or empty storage identifier")
        root = base.resolve()
        candidate = root.joinpath(*(str(part) for part in parts)).resolve()
        if root not in candidate.parents:
            raise ValueError("Invalid identifier: resolves outside the storage directory")
        return candidate

    def _make_json_serializable(self, obj):
        """Recursively convert ``obj`` into JSON-serializable builtins.

        Dicts, lists and tuples are converted element by element (tuples
        become lists), NumPy arrays/scalars use ``tolist()``, ``Path`` and
        ``UUID`` become strings, objects with a ``__dict__`` are converted via
        that dict, and anything else unknown falls back to ``str(obj)``.
        """
        if isinstance(obj, dict):
            return {k: self._make_json_serializable(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            # Tuples must be walked too, otherwise nested NumPy values or
            # tuples inside them would be returned unconverted.
            return [self._make_json_serializable(item) for item in obj]
        if isinstance(obj, (np.ndarray, np.generic)):
            return obj.tolist()
        if isinstance(obj, (Path, uuid.UUID)):
            return str(obj)
        if hasattr(obj, '__dict__'):
            # Handle custom objects by converting their __dict__
            return self._make_json_serializable(obj.__dict__)
        if isinstance(obj, (int, float, str, bool, type(None))):
            return obj
        # Convert any other types to string representation
        return str(obj)

    def create_session_token(self, session_id: str, client_id: str = None, resource_limits: Dict[str, Any] = None) -> str:
        """Create a JWT session token valid for ``JWT_EXPIRATION_HOURS`` hours."""
        issued_at = time.time()
        payload = {
            "session_id": session_id,
            "client_id": client_id or "anonymous",
            "resource_limits": resource_limits or {},
            "created_at": issued_at,
            "expires_at": issued_at + (JWT_EXPIRATION_HOURS * 3600)
        }
        return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)

    def verify_session_token(self, token: str) -> Dict[str, Any]:
        """Verify and decode a JWT session token.

        Raises:
            HTTPException: 401 when the token cannot be decoded, carries no
                numeric ``expires_at`` claim, or has expired.
        """
        try:
            payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
        except jwt.InvalidTokenError:
            raise HTTPException(status_code=401, detail="Invalid session token")
        expires_at = payload.get("expires_at")
        if not isinstance(expires_at, (int, float)):
            # A validly signed token without the claim must not cause a KeyError/500.
            raise HTTPException(status_code=401, detail="Invalid session token")
        if expires_at < time.time():
            raise HTTPException(status_code=401, detail="Session token expired")
        return payload

    def generate_request_id(self) -> str:
        """Generate a request ID from the current time and a running counter."""
        self.request_counter += 1
        return f"req_{int(time.time())}_{self.request_counter}"

    def compress_data(self, data: bytes) -> bytes:
        """Compress data using gzip"""
        return gzip.compress(data)

    def decompress_data(self, data: bytes) -> bytes:
        """Decompress gzip data"""
        return gzip.decompress(data)

    def _read_block_metadata(self, block_id) -> dict:
        """Return the stored metadata dict of a VRAM block, or ``{}`` if none exists."""
        metadata_path = self._contained_path(self.vram_path, f"{block_id}_metadata.json")
        if not metadata_path.exists():
            return {}
        with open(metadata_path, 'r') as f:
            return json.load(f)

    async def handle_vram_operation(self, operation: dict) -> dict:
        """Handle a VRAM ``'write'`` or ``'read'`` operation.

        ``operation`` needs ``'type'`` and ``'block_id'``; a write also needs
        ``'data'`` and may carry ``'metadata'``. Blocks are stored as ``.npy``
        files with a JSON metadata side file. Failures are reported in the
        returned dict, never raised.
        """
        try:
            op_type = operation.get('type')
            if not op_type:
                raise ValueError("Missing operation type")
            block_id = operation.get('block_id')
            if not block_id:
                raise ValueError("Missing block_id")
            # Reject identifiers that would escape the VRAM directory.
            file_path = self._contained_path(self.vram_path, f"{block_id}.npy")

            if op_type == 'write':
                data = operation.get('data')
                if data and isinstance(data, (dict, list)):
                    data = self._make_json_serializable(data)
                if data is None:
                    raise ValueError("Missing data for write operation")
                array = np.array(data)
                np.save(file_path, array)
                self.vram_cache[block_id] = array
                # Store metadata
                metadata = operation.get('metadata', {})
                metadata_path = self._contained_path(self.vram_path, f"{block_id}_metadata.json")
                with open(metadata_path, 'w') as f:
                    json.dump(metadata, f)
                return {'status': 'success', 'message': f'Block {block_id} written'}

            if op_type == 'read':
                if block_id in self.vram_cache:
                    cached = self.vram_cache[block_id]
                    return {
                        'status': 'success',
                        'data': cached if isinstance(cached, list) else cached.tolist(),
                        'metadata': self._read_block_metadata(block_id),
                        'source': 'cache'
                    }
                if file_path.exists():
                    array = np.load(file_path)
                    self.vram_cache[block_id] = array
                    return {
                        'status': 'success',
                        'data': array.tolist(),
                        'metadata': self._read_block_metadata(block_id),
                        'source': 'disk'
                    }
                return {'status': 'error', 'message': 'Block not found'}

            return {'status': 'error', 'message': f'Unknown operation type: {op_type}'}
        except ValueError as e:
            return {'status': 'error', 'message': str(e)}
        except Exception as e:
            return {'status': 'error', 'message': f'Operation failed: {str(e)}'}

    async def handle_state_operation(self, operation: dict) -> dict:
        """Handle a component state ``'save'`` or ``'load'`` operation.

        ``operation`` needs ``'type'``, ``'component'`` and ``'state_id'``; a
        save also uses ``'data'``. State is stored as
        ``<component>/<state_id>.json``. Invalid identifiers and unknown types
        yield an error dict instead of an exception or ``None``.
        """
        op_type = operation.get('type')
        component = operation.get('component')
        state_id = operation.get('state_id')
        state_data = operation.get('data')
        try:
            if component is None:
                raise ValueError("Missing component")
            if state_id is None:
                raise ValueError("Missing state_id")
            file_path = self._contained_path(self.state_path, component, f"{state_id}.json")
        except ValueError as e:
            return {'status': 'error', 'message': str(e)}
        cache_key = f"{component}:{state_id}"

        if op_type == 'save':
            file_path.parent.mkdir(parents=True, exist_ok=True)
            with open(file_path, 'w') as f:
                json.dump(state_data, f)
            self.state_cache[cache_key] = state_data
            return {'status': 'success', 'message': f'State {state_id} saved'}

        if op_type == 'load':
            if cache_key in self.state_cache:
                return {
                    'status': 'success',
                    'data': self.state_cache[cache_key],
                    'source': 'cache'
                }
            if file_path.exists():
                with open(file_path) as f:
                    state_data = json.load(f)
                self.state_cache[cache_key] = state_data
                return {
                    'status': 'success',
                    'data': state_data,
                    'source': 'disk'
                }
            return {'status': 'error', 'message': 'State not found'}

        return {'status': 'error', 'message': f'Unknown operation type: {op_type}'}

    async def handle_cache_operation(self, operation: dict) -> dict:
        """Handle a cache ``'set'`` or ``'get'`` operation.

        ``operation`` needs ``'type'`` and ``'key'``; a set also uses
        ``'data'``. Entries are kept in memory and persisted as
        ``<key>.json``. Invalid keys and unknown types yield an error dict
        instead of an exception or ``None``.
        """
        op_type = operation.get('type')
        key = operation.get('key')
        data = operation.get('data')
        try:
            if key is None:
                raise ValueError("Missing key")
            file_path = self._contained_path(self.cache_path, f"{key}.json")
        except ValueError as e:
            return {'status': 'error', 'message': str(e)}

        if op_type == 'set':
            self.memory_cache[key] = data
            # Also persist to disk for recovery
            with open(file_path, 'w') as f:
                json.dump(data, f)
            return {'status': 'success', 'message': f'Cache key {key} set'}

        if op_type == 'get':
            if key in self.memory_cache:
                return {
                    'status': 'success',
                    'data': self.memory_cache[key],
                    'source': 'memory'
                }
            if file_path.exists():
                with open(file_path) as f:
                    data = json.load(f)
                self.memory_cache[key] = data
                return {
                    'status': 'success',
                    'data': data,
                    'source': 'disk'
                }
            return {'status': 'error', 'message': 'Cache key not found'}

        return {'status': 'error', 'message': f'Unknown operation type: {op_type}'}

    def get_stats(self) -> dict:
        """Return uptime, operation counters and the sizes of all caches."""
        uptime = time.time() - self.start_time
        ops_per_second = self.ops_counter / uptime if uptime > 0 else 0
        return {
            'uptime': uptime,
            'total_operations': self.ops_counter,
            'ops_per_second': ops_per_second,
            'active_connections': len(self.active_connections),
            'active_http_sessions': len(self.http_sessions),
            'vram_cache_size': len(self.vram_cache),
            'state_cache_size': len(self.state_cache),
            'memory_cache_size': len(self.memory_cache),
            'model_cache_size': len(self.model_cache)
        }
# Single shared server instance used by every endpoint below
server = VirtualGPUServer()


# Dependency that resolves the bearer token into its decoded session payload
def get_current_session(credentials: HTTPAuthorizationCredentials = Depends(security)) -> Dict[str, Any]:
    """Return the decoded JWT payload for the request's bearer credentials."""
    bearer_token = credentials.credentials
    return server.verify_session_token(bearer_token)
# HTTP API Endpoints
async def create_session(request: SessionCreateRequest):
    """Create a new HTTP session.

    Generates a fresh session id, issues a JWT for it, records the session in
    ``server.http_sessions`` and returns the token together with its expiry.
    """
    new_session_id = str(uuid.uuid4())
    resolved_client = request.client_id or "anonymous"

    # Issue the token before registering the session
    session_token = server.create_session_token(
        new_session_id, resolved_client, request.resource_limits
    )

    session_record = {
        'session_id': new_session_id,
        'client_id': resolved_client,
        'created_at': time.time(),
        'resource_limits': request.resource_limits or {},
        'ops_count': 0
    }
    server.http_sessions[new_session_id] = session_record

    return SessionResponse(
        session_token=session_token,
        session_id=new_session_id,
        expires_at=datetime.fromtimestamp(time.time() + (JWT_EXPIRATION_HOURS * 3600))
    )
async def write_vram_block(
    block_id: str,
    request: VRAMWriteRequest,
    session: Dict[str, Any] = Depends(get_current_session)
):
    """Write tensor data to VRAM block.

    Raises:
        HTTPException: 400 when the server reports the write as failed,
            500 on any unexpected error.
    """
    operation = {
        'operation': 'vram',
        'type': 'write',
        'block_id': block_id,
        'data': request.data,
        'metadata': request.metadata or {},
        'model_size': request.model_size
    }
    try:
        result = await server.handle_vram_operation(operation)
        server.ops_counter += 1
        if result['status'] != 'success':
            raise HTTPException(status_code=400, detail=result['message'])
        return VRAMResponse(
            status=result['status'],
            message=result['message']
        )
    except HTTPException:
        # Let the deliberate 400 through instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        request_id = server.generate_request_id()
        logging.error(f"[{request_id}] VRAM write operation failed: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"VRAM write operation failed: {str(e)}"
        )
async def read_vram_block(
    block_id: str,
    session: Dict[str, Any] = Depends(get_current_session)
):
    """Read tensor data from VRAM block.

    Raises:
        HTTPException: 404 when the server reports the read as failed,
            500 on any unexpected error.
    """
    operation = {
        'operation': 'vram',
        'type': 'read',
        'block_id': block_id
    }
    try:
        result = await server.handle_vram_operation(operation)
        server.ops_counter += 1
        if result['status'] != 'success':
            raise HTTPException(status_code=404, detail=result['message'])
        return VRAMResponse(
            status=result['status'],
            data=result.get('data'),
            metadata=result.get('metadata'),
            source=result.get('source')
        )
    except HTTPException:
        raise
    except Exception as e:
        # The request id was previously generated and discarded; log it so
        # the failure can be traced.
        request_id = server.generate_request_id()
        logging.error(f"[{request_id}] VRAM read operation failed: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"VRAM read operation failed: {str(e)}"
        )
async def delete_vram_block(
    block_id: str,
    session: Dict[str, Any] = Depends(get_current_session)
):
    """Delete tensor data from VRAM block.

    Drops the cached copy and removes the ``.npy`` and metadata files when
    they exist; reports success whether or not anything was present.
    """
    try:
        # Forget the in-memory copy, if any
        server.vram_cache.pop(block_id, None)

        # Remove the data file, then the metadata file
        on_disk = (
            server.vram_path / f"{block_id}.npy",
            server.vram_path / f"{block_id}_metadata.json",
        )
        for target in on_disk:
            if target.exists():
                target.unlink()

        server.ops_counter += 1
        return {"status": "success", "message": f"Block {block_id} deleted"}
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"VRAM delete operation failed: {str(e)}"
        )
async def save_state(
    component: str,
    state_id: str,
    request: StateRequest,
    session: Dict[str, Any] = Depends(get_current_session)
):
    """Save component state.

    Raises:
        HTTPException: 400 when the server reports the save as failed,
            500 on any unexpected error.
    """
    operation = {
        'operation': 'state',
        'type': 'save',
        'component': component,
        'state_id': state_id,
        'data': request.data
    }
    try:
        result = await server.handle_state_operation(operation)
        server.ops_counter += 1
        if result['status'] != 'success':
            raise HTTPException(status_code=400, detail=result['message'])
        return StateResponse(
            status=result['status'],
            message=result['message']
        )
    except HTTPException:
        # Let the deliberate 400 through instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"State save operation failed: {str(e)}"
        )
async def load_state(
    component: str,
    state_id: str,
    session: Dict[str, Any] = Depends(get_current_session)
):
    """Load component state.

    Responds with the stored data and its source, 404 when the server reports
    a failure, and 500 on an unexpected error.
    """
    try:
        outcome = await server.handle_state_operation({
            'operation': 'state',
            'type': 'load',
            'component': component,
            'state_id': state_id
        })
        server.ops_counter += 1
        if outcome['status'] != 'success':
            raise HTTPException(status_code=404, detail=outcome['message'])
        return StateResponse(
            status=outcome['status'],
            data=outcome.get('data'),
            source=outcome.get('source')
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"State load operation failed: {str(e)}"
        )
async def set_cache(
    key: str,
    request: CacheRequest,
    session: Dict[str, Any] = Depends(get_current_session)
):
    """Set cache value.

    Raises:
        HTTPException: 400 when the server reports the set as failed,
            500 on any unexpected error.
    """
    operation = {
        'operation': 'cache',
        'type': 'set',
        'key': key,
        'data': request.data
    }
    try:
        result = await server.handle_cache_operation(operation)
        server.ops_counter += 1
        if result['status'] != 'success':
            raise HTTPException(status_code=400, detail=result['message'])
        return CacheResponse(
            status=result['status'],
            message=result['message']
        )
    except HTTPException:
        # Let the deliberate 400 through instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Cache set operation failed: {str(e)}"
        )
async def get_cache(
    key: str,
    session: Dict[str, Any] = Depends(get_current_session)
):
    """Get cache value.

    Responds with the cached data and its source, 404 when the server reports
    a failure, and 500 on an unexpected error.
    """
    try:
        outcome = await server.handle_cache_operation({
            'operation': 'cache',
            'type': 'get',
            'key': key
        })
        server.ops_counter += 1
        if outcome['status'] != 'success':
            raise HTTPException(status_code=404, detail=outcome['message'])
        return CacheResponse(
            status=outcome['status'],
            data=outcome.get('data'),
            source=outcome.get('source')
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Cache get operation failed: {str(e)}"
        )
def sanitize_filename(name: str) -> str:
    """Return ``name`` with every ``/`` replaced by ``__`` for use as a file name."""
    segments = name.split('/')
    return '__'.join(segments)
async def load_model(
    model_name: str,
    request: ModelLoadRequest,
    session: Dict[str, Any] = Depends(get_current_session)
):
    """Load AI model.

    Validates the architecture configuration in ``request.model_data``,
    caches the model info under ``model_name`` and persists it as JSON using
    a sanitized file name.

    Raises:
        HTTPException: 400 when ``model_data`` is missing or lacks a required
            field, 500 on any unexpected error.
    """
    try:
        # Log the received model name for debugging
        logging.info(f"Received model load request for: {model_name}")
        # Get safe filename for storage
        safe_name = sanitize_filename(model_name)

        if not request.model_data:
            raise HTTPException(
                status_code=400,
                detail="model_data is required and must include architecture configuration"
            )
        model_data = request.model_data

        # Validate required model configuration
        required_fields = ['num_sms', 'tensor_cores_per_sm', 'cuda_cores_per_sm']
        missing_fields = [field for field in required_fields if field not in model_data]
        if missing_fields:
            raise HTTPException(
                status_code=400,
                detail=f"Missing required model configuration fields: {missing_fields}"
            )

        # Store model information with full configuration
        model_info = {
            'model_name': model_name,
            'model_data': model_data,
            'model_path': request.model_path,
            'model_hash': request.model_hash,
            'loaded_at': time.time(),
            'session_id': session['session_id'],
            'architecture': {
                'num_sms': model_data['num_sms'],
                'tensor_cores_per_sm': model_data['tensor_cores_per_sm'],
                'cuda_cores_per_sm': model_data['cuda_cores_per_sm'],
                'vram_allocation': model_data.get('vram_allocation', 'dynamic'),
                'compute_capability': model_data.get('compute_capability', '8.0')
            }
        }
        server.model_cache[model_name] = model_info

        # Store in persistent storage
        model_file = server.models_path / f"{safe_name}.json"
        model_data_file = server.models_path / f"{safe_name}.data"
        logging.info(f"Storing model info at: {model_file}")
        # Store metadata and configuration
        with open(model_file, 'w') as f:
            json.dump(model_info, f)
        # Store actual model data separately
        if model_data.get('weights') or model_data.get('parameters'):
            logging.info(f"Storing model data at: {model_data_file}")
            with open(model_data_file, 'w') as f:
                json.dump(model_data, f)

        server.ops_counter += 1
        return {
            "status": "success",
            "message": f"Model {model_name} loaded successfully",
            "model_info": {
                "name": model_name,
                "architecture": model_info['architecture'],
                "loaded_at": model_info['loaded_at']
            }
        }
    except HTTPException:
        # Validation errors above are 400s and must not be turned into 500s.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Model load operation failed: {str(e)}"
        )
async def run_inference(
    model_name: str,
    request: ModelInferenceRequest,
    session: Dict[str, Any] = Depends(get_current_session)
):
    """Run model inference.

    Looks the model up in the cache, falling back to its JSON file on disk,
    and returns a simulated result that echoes the input.

    Raises:
        HTTPException: 404 when the model is neither cached nor on disk,
            500 on any unexpected error.
    """
    try:
        logging.info(f"Running inference - Raw model name: {model_name}")
        # Use the file's sanitize_filename helper; the previously called
        # sanitize_model_name is not defined in the visible source, so every
        # request failed with a NameError.
        safe_name = sanitize_filename(model_name)
        logging.info(f"Running inference - Safe model name: {safe_name}")

        # Check if model is loaded; otherwise try the file system using the safe name
        if model_name not in server.model_cache:
            model_file = server.models_path / f"{safe_name}.json"
            if not model_file.exists():
                logging.error(f"Model {model_name} not found in cache or filesystem")
                raise HTTPException(status_code=404, detail=f"Model {model_name} not loaded")
            logging.info(f"Loading model info from file: {model_file}")
            with open(model_file) as f:
                model_info = json.load(f)
            server.model_cache[model_name] = model_info

        # Simulate inference processing
        # In a real implementation, this would invoke the actual model
        result = {
            "status": "success",
            "output": request.input_data,  # Echo input for now
            "metrics": {
                "inference_time": 0.1,
                "tokens_processed": len(request.input_data)
            },
            "model_info": server.model_cache[model_name]
        }
        server.ops_counter += 1
        logging.info(f"Inference completed successfully for model: {model_name}")
        return result
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Inference operation failed for {model_name}: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Inference operation failed: {str(e)}"
        )
async def get_model_status(
    model_name: str,
    session: Dict[str, Any] = Depends(get_current_session)
):
    """Get model status.

    Reports ``"loaded"`` with the model info when the model is cached or has
    a JSON file on disk (which is then cached), otherwise ``"not_loaded"``.
    """
    try:
        logging.info(f"Checking model status for: {model_name}")

        # Cache first
        cached_info = server.model_cache.get(model_name) if model_name in server.model_cache else None
        if model_name in server.model_cache:
            logging.info(f"Model {model_name} found in cache")
            return {"status": "loaded", "model_info": cached_info}

        # Then the file system, addressed by the sanitized name
        model_file = server.models_path / f"{sanitize_filename(model_name)}.json"
        if not model_file.exists():
            logging.info(f"Model {model_name} not found in cache or filesystem")
            return {
                "status": "not_loaded",
                "message": f"Model {model_name} is not loaded"
            }

        logging.info(f"Model file found: {model_file}")
        with open(model_file) as handle:
            stored_info = json.load(handle)
        # Remember it for the next lookup
        server.model_cache[model_name] = stored_info
        return {"status": "loaded", "model_info": stored_info}
    except Exception as e:
        logging.error(f"Model status check failed for {model_name}: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Model status check failed: {str(e)}"
        )
# Multi-chip coordination endpoints
async def transfer_between_chips(
    src_chip_id: int,
    dst_chip_id: int,
    request: dict,
    session: Dict[str, Any] = Depends(get_current_session)
):
    """Transfer data between GPU chips.

    Reads the block named by ``request['data_id']`` and writes a copy under
    ``<data_id>_chip_<dst_chip_id>``. Responds 400 without a ``data_id``, 404
    when the source read fails and 500 when the copy cannot be stored.
    """
    try:
        data_id = request.get('data_id')
        if not data_id:
            raise HTTPException(status_code=400, detail="Missing data_id")

        # Fetch the source block
        fetched = await server.handle_vram_operation({
            'operation': 'vram',
            'type': 'read',
            'block_id': data_id
        })
        if fetched.get('status') != 'success':
            raise HTTPException(status_code=404, detail=f"Source data {data_id} not found")

        # Store the copy under a destination-specific id
        new_data_id = f"{data_id}_chip_{dst_chip_id}"
        stored = await server.handle_vram_operation({
            'operation': 'vram',
            'type': 'write',
            'block_id': new_data_id,
            'data': fetched.get('data'),
            'metadata': fetched.get('metadata', {})
        })
        if stored.get('status') != 'success':
            raise HTTPException(status_code=500, detail="Failed to store transferred data")

        # Simulate cross-chip transfer
        summary = {
            "status": "success",
            "transfer_id": f"transfer_{time.time_ns()}",
            "src_chip": src_chip_id,
            "dst_chip": dst_chip_id,
            "data_id": data_id,
            "new_data_id": new_data_id
        }
        server.ops_counter += 1
        return summary
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Chip transfer failed: {str(e)}"
        )
async def create_sync_barrier(
    barrier_id: str,
    request: dict,
    session: Dict[str, Any] = Depends(get_current_session)
):
    """Create synchronization barrier.

    Registers the barrier in ``server.memory_cache`` under
    ``barrier_<barrier_id>`` with an arrival count of zero;
    ``num_participants`` defaults to 1.
    """
    try:
        expected = request.get('num_participants', 1)
        server.memory_cache[f"barrier_{barrier_id}"] = {
            'barrier_id': barrier_id,
            'num_participants': expected,
            'arrived_participants': 0,
            'created_at': time.time()
        }
        return {
            "status": "success",
            "barrier_id": barrier_id,
            "num_participants": expected
        }
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Barrier creation failed: {str(e)}"
        )
async def wait_sync_barrier(
    barrier_id: str,
    session: Dict[str, Any] = Depends(get_current_session)
):
    """Wait at synchronization barrier.

    Counts one arrival. Once arrivals reach the participant count the barrier
    is removed and ``"released"`` is reported; until then ``"waiting"`` with
    the current counts. Responds 404 for an unknown barrier.
    """
    try:
        barrier_key = f"barrier_{barrier_id}"
        if barrier_key not in server.memory_cache:
            raise HTTPException(status_code=404, detail="Barrier not found")

        barrier = server.memory_cache[barrier_key]
        barrier['arrived_participants'] += 1

        still_waiting = barrier['arrived_participants'] < barrier['num_participants']
        if still_waiting:
            return {
                "status": "waiting",
                "arrived": barrier['arrived_participants'],
                "total": barrier['num_participants']
            }

        # Everyone is here: drop the barrier
        del server.memory_cache[barrier_key]
        return {
            "status": "released",
            "message": "All participants arrived, barrier released"
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Barrier wait failed: {str(e)}"
        )
# Status pages and WebSocket endpoint preserved for backward compatibility
async def handle_index():
    """Render the HTML status page with API pointers and live server statistics."""
    stats = server.get_stats()
    page = f"""
    <!DOCTYPE html>
    <html>
    <head>
    <title>Virtual GPU Server v2.0</title>
    <style>
    body {{ font-family: Arial, sans-serif; margin: 40px; }}
    table {{ border-collapse: collapse; width: 100%; margin-top: 20px; }}
    th, td {{ padding: 12px; text-align: left; border-bottom: 1px solid #ddd; }}
    th {{ background-color: #f2f2f2; }}
    .stats {{ background-color: #f9f9f9; padding: 20px; border-radius: 5px; }}
    .api-info {{ background-color: #e8f4fd; padding: 20px; border-radius: 5px; margin-top: 20px; }}
    </style>
    </head>
    <body>
    <h1>Virtual GPU Server v2.0 Status</h1>
    <div class="api-info">
    <h2>API Information</h2>
    <p><strong>HTTP REST API:</strong> Available at /api/v1/</p>
    <p><strong>WebSocket API:</strong> Available at /ws (backward compatibility)</p>
    <p><strong>API Documentation:</strong> <a href="/docs">/docs</a></p>
    </div>
    <div class="stats">
    <h2>Server Statistics</h2>
    <ul>
    <li>Uptime: {stats['uptime']:.2f} seconds</li>
    <li>Total Operations: {stats['total_operations']}</li>
    <li>Operations per Second: {stats['ops_per_second']:.2f}</li>
    <li>Active WebSocket Connections: {stats['active_connections']}</li>
    <li>Active HTTP Sessions: {stats['active_http_sessions']}</li>
    <li>VRAM Cache Size: {stats['vram_cache_size']}</li>
    <li>State Cache Size: {stats['state_cache_size']}</li>
    <li>Memory Cache Size: {stats['memory_cache_size']}</li>
    <li>Model Cache Size: {stats['model_cache_size']}</li>
    </ul>
    </div>
    <h2>Server Files</h2>
    <iframe src="/files" style="width: 100%; height: 500px; border: none;"></iframe>
    </body>
    </html>
    """
    return HTMLResponse(content=page)
async def handle_files():
    """Render an HTML table of every file under the server's storage directory.

    Each row shows the path relative to the storage root, a human-readable
    size and the last-modified time.
    """
    # Function-scope import; aliased so it cannot clash with local names.
    from html import escape as escape_html

    def format_size(size):
        """Format a byte count using B/KB/MB/GB/TB with two decimals."""
        for unit in ['B', 'KB', 'MB', 'GB']:
            if size < 1024:
                return f"{size:.2f} {unit}"
            size /= 1024
        return f"{size:.2f} TB"

    rows = ['<!DOCTYPE html><html><head>',
            '<style>',
            'body { font-family: Arial, sans-serif; margin: 20px; }',
            'table { border-collapse: collapse; width: 100%; }',
            'th, td { padding: 12px; text-align: left; border-bottom: 1px solid #ddd; }',
            'th { background-color: #f2f2f2; }',
            '</style></head><body>',
            '<h2>Server Files</h2>',
            '<table><tr><th>Path</th><th>Size</th><th>Last Modified</th></tr>']
    for root, _, files in os.walk(server.base_path):
        for file in files:
            full_path = Path(root) / file
            # File names are built from client-supplied ids (block ids, cache
            # keys, model names), so escape them before embedding in HTML.
            rel_path = escape_html(str(full_path.relative_to(server.base_path)))
            size = format_size(os.path.getsize(full_path))
            mtime = datetime.fromtimestamp(os.path.getmtime(full_path))
            rows.append(f'<tr><td>{rel_path}</td><td>{size}</td><td>{mtime}</td></tr>')
    rows.append('</table></body></html>')
    return HTMLResponse(content='\n'.join(rows))
# WebSocket endpoint (preserved for backward compatibility)
async def websocket_endpoint(websocket: WebSocket):
    """Serve VRAM, state and cache operations over one WebSocket connection.

    Each incoming JSON message is routed by its ``'operation'`` field and the
    handler's result dict is sent back. The connection is registered in the
    server's connection and session maps for its lifetime and always removed
    on exit.
    """
    # Function-scope import: only WebSocket is imported at file level.
    from fastapi import WebSocketDisconnect

    await websocket.accept()
    session_id = str(uuid.uuid4())
    server.active_connections[session_id] = websocket
    server.active_sessions[session_id] = {
        'start_time': time.time(),
        'ops_count': 0
    }
    unknown_operation = {'status': 'error', 'message': 'Unknown operation type'}
    try:
        while True:
            message = await websocket.receive_json()
            # A JSON array or scalar has no .get(); answer with an error
            # instead of dropping the connection.
            operation_type = message.get('operation') if isinstance(message, dict) else None
            # Route operation to appropriate handler
            if operation_type == 'vram':
                response = await server.handle_vram_operation(message)
            elif operation_type == 'state':
                response = await server.handle_state_operation(message)
            elif operation_type == 'cache':
                response = await server.handle_cache_operation(message)
            else:
                response = unknown_operation
            if response is None:
                # Guard against a handler that produced no result.
                response = unknown_operation
            # Update statistics
            server.ops_counter += 1
            server.active_sessions[session_id]['ops_count'] += 1
            # Send response
            await websocket.send_json(response)
    except WebSocketDisconnect:
        # A client closing the socket is the normal way out of the loop.
        logging.info(f"WebSocket session {session_id} disconnected")
    except Exception as e:
        logging.error(f"WebSocket error: {e}")
    finally:
        # Cleanup on disconnect
        server.active_connections.pop(session_id, None)
        server.active_sessions.pop(session_id, None)
async def get_status():
    """Get server status: a fixed payload confirming the server is running."""
    return {"status": "ok", "message": "Virtual GPU Server is running"}


# For running directly (development). Kept at the very end of the module so
# every definition above exists before the blocking uvicorn.run() call.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("virtual_gpu_server_http:app", host="0.0.0.0", port=7860, reload=True)