import asyncio
import hashlib
import json
import threading
import time
import uuid
from queue import Empty, Queue
from typing import Any, Dict, Optional, Union

import numpy as np
import websockets


class WebSocketGPUStorage:
    """Process-wide WebSocket client for a remote tensor/model storage server.

    The class is a singleton: the first construction opens a background
    daemon thread that owns an asyncio event loop and the WebSocket
    connection; every later construction returns that same instance (the
    ``url`` argument of later calls is ignored). Public methods are
    synchronous: they enqueue a JSON-serialisable operation, block until the
    background thread delivers the reply, and report failures by printing a
    message and returning ``False``/``None`` rather than raising.

    Once ``close()``/``aclose()`` has been called the instance stays closed;
    the singleton is not re-created.
    """

    # Singleton instance
    _instance = None
    _lock = threading.Lock()

    # Seconds to wait for a single server reply (both on the socket and in
    # the calling thread).
    _RESPONSE_TIMEOUT = 30.0
    # Minimum seconds between heartbeat pings on an otherwise healthy link.
    _HEARTBEAT_INTERVAL = 1.0

    def __new__(cls, url: str = "wss://factorst-wbs1.hf.space/ws"):
        with cls._lock:
            if cls._instance is None:
                # Publish the instance only after it is fully initialised so
                # a failure in _init_singleton cannot leave a half-built
                # singleton behind.
                instance = super().__new__(cls)
                instance._init_singleton(url)
                cls._instance = instance
            return cls._instance

    def _init_singleton(self, url: str):
        """Initialize the singleton instance and start the connection thread."""
        if hasattr(self, 'initialized'):
            return

        self.url = url
        self.websocket = None
        self.connected = False
        self.message_queue = Queue()
        self.response_queues: Dict[str, Queue] = {}
        self.lock = threading.Lock()
        self._closing = False
        self._loop = None
        self.error_count = 0
        self.last_error_time = 0
        # NOTE(review): max_retries is stored but nothing in this class
        # reads it; reconnection is retried indefinitely.
        self.max_retries = 5
        self.tensor_registry: Dict[str, Dict[str, Any]] = {}  # Track tensor metadata
        self.model_registry: Dict[str, Dict[str, Any]] = {}  # Track loaded models
        self.resource_monitor = {
            'vram_used': 0,
            'active_tensors': 0,
            'loaded_models': set()
        }

        # Start WebSocket connection in a separate thread
        self.ws_thread = threading.Thread(target=self._run_websocket_loop, daemon=True)
        self.ws_thread.start()

        self.initialized = True

    def __init__(self, url: str = "wss://factorst-wbs1.hf.space/ws"):
        """Do nothing: all set-up happens once, in ``__new__``."""
        pass

    # ------------------------------------------------------------------
    # Background connection thread
    # ------------------------------------------------------------------

    def _run_websocket_loop(self):
        """Thread target: create an event loop and run the handler on it."""
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)
        self._loop.run_until_complete(self._websocket_handler())

    async def _websocket_handler(self):
        """Connect, serve queued operations, and reconnect until closing."""
        while not self._closing:
            try:
                async with websockets.connect(self.url) as websocket:
                    self.websocket = websocket
                    self.connected = True
                    self.error_count = 0  # Reset error count on successful connection
                    print("Connected to GPU storage server")
                    await self._serve_connection(websocket)
            except Exception as e:
                self.error_count += 1
                self.last_error_time = time.time()
                print(f"WebSocket connection error: {e}")
            finally:
                # Runs for every way out of the session (error or ping
                # failure) so callers never see a stale "connected" flag.
                self.connected = False
                self.websocket = None

            if not self._closing:
                await asyncio.sleep(1)  # Wait before reconnecting

    async def _serve_connection(self, websocket: Any) -> None:
        """Pump queued operations over one open connection until it fails."""
        next_ping = 0.0
        while True:
            # Handle outgoing messages
            while True:
                try:
                    msg_id, operation = self.message_queue.get_nowait()
                except Empty:
                    break
                if not await self._exchange(websocket, msg_id, operation):
                    # Something went wrong on the wire: stop draining and
                    # probe the link right away instead of failing every
                    # remaining queued operation.
                    next_ping = 0.0
                    break

            # Keep connection alive with heartbeat, at most once per
            # _HEARTBEAT_INTERVAL rather than on every 1 ms loop pass.
            now = time.monotonic()
            if now >= next_ping:
                try:
                    await websocket.ping()
                except Exception:
                    break  # Leave the session; the handler reconnects
                next_ping = now + self._HEARTBEAT_INTERVAL

            await asyncio.sleep(0.001)  # Short poll interval for the queue

    async def _exchange(self, websocket: Any, msg_id: str,
                        operation: Dict[str, Any]) -> bool:
        """Send one operation and hand its reply to the waiting caller.

        Returns ``True`` when a reply was received and parsed, ``False`` when
        sending, receiving or decoding failed (the caller is still given an
        error reply in that case).
        """
        try:
            await websocket.send(json.dumps(operation))
        except Exception as e:
            print(f"Error processing message: {str(e)}")
            # Tell the caller immediately instead of letting it wait for
            # the full response timeout.
            self._deliver(msg_id, {
                "status": "error",
                "message": f"Error sending operation: {str(e)}"
            })
            return False

        # NOTE(review): replies carry no correlation id in this client; the
        # next frame received is assumed to answer the operation just sent.
        # A late reply to a timed-out operation would be attributed to the
        # following one -- confirm against the server protocol.
        try:
            raw = await asyncio.wait_for(websocket.recv(),
                                         timeout=self._RESPONSE_TIMEOUT)
            reply = json.loads(raw)
            healthy = True
        except asyncio.TimeoutError:
            reply = {"status": "error", "message": "Operation timed out"}
            healthy = False
        except Exception as e:
            reply = {
                "status": "error",
                "message": f"Error processing response: {str(e)}"
            }
            healthy = False

        self._deliver(msg_id, reply)
        return healthy

    def _deliver(self, msg_id: str, payload: Any) -> None:
        """Put ``payload`` on the caller's queue if it is still waiting."""
        with self.lock:
            waiter = self.response_queues.get(msg_id)
        if waiter is not None:
            waiter.put(payload)

    # ------------------------------------------------------------------
    # Request/response plumbing used by the public API
    # ------------------------------------------------------------------

    def _send_operation(self, operation: Dict[str, Any]) -> Dict[str, Any]:
        """Send ``operation`` and block until a reply or timeout.

        Always returns a dict; failures are reported as
        ``{"status": "error", "message": ...}``. Operations that contain a
        ``"model_size"`` key are retried once after an error reply.
        """
        if self._closing:
            return {"status": "error", "message": "WebSocket is closing"}

        if not self.wait_for_connection(timeout=10):
            return {"status": "error", "message": "Not connected to GPU storage server"}

        # A random id cannot collide between threads the way a timestamp can.
        msg_id = uuid.uuid4().hex
        response_queue = Queue()

        with self.lock:
            self.response_queues[msg_id] = response_queue

        try:
            self.message_queue.put((msg_id, operation))
            response = response_queue.get(timeout=self._RESPONSE_TIMEOUT)
            should_retry = (
                isinstance(response, dict)
                and response.get("status") == "error"
                and "model_size" in operation
            )
            if should_retry:
                # Retry once for operations that carry a model_size
                self.message_queue.put((msg_id, operation))
                response = response_queue.get(timeout=self._RESPONSE_TIMEOUT)
        except Empty:
            response = {
                "status": "error",
                "message": "Operation failed: timed out waiting for server response"
            }
        except Exception as e:
            response = {"status": "error", "message": f"Operation failed: {str(e)}"}
        finally:
            with self.lock:
                self.response_queues.pop(msg_id, None)

        if not isinstance(response, dict):
            # Valid JSON that is not an object; callers rely on .get().
            return {
                "status": "error",
                "message": "Operation failed: malformed response from server"
            }
        return response

    @staticmethod
    def _as_shape(value: Any) -> Optional[tuple]:
        """Normalise a shape given as tuple or JSON list; ``None`` if unusable."""
        if value is None:
            return None
        try:
            return tuple(int(dim) for dim in value)
        except (TypeError, ValueError):
            return None

    # ------------------------------------------------------------------
    # Tensors
    # ------------------------------------------------------------------

    def store_tensor(self, tensor_id: str, data: np.ndarray, model_size: Optional[int] = None) -> bool:
        """Upload ``data`` under ``tensor_id`` and record it in the registry.

        ``data`` may be an ndarray or anything ``np.asarray`` accepts.
        Returns ``True`` on a ``"success"`` reply, ``False`` otherwise
        (including when ``data`` is ``None``).
        """
        try:
            if data is None:
                raise ValueError("Cannot store None tensor")

            array = np.asarray(data)

            # Calculate tensor metadata
            tensor_shape = tuple(array.shape)
            tensor_dtype = str(array.dtype)
            tensor_size = array.nbytes

            operation = {
                'operation': 'vram',
                'type': 'write',
                'block_id': tensor_id,
                'data': array.tolist(),
                'model_size': model_size if model_size is not None else -1,  # -1 indicates unlimited
                'metadata': {
                    'shape': tensor_shape,
                    'dtype': tensor_dtype,
                    'size': tensor_size,
                    'timestamp': time.time()
                }
            }

            response = self._send_operation(operation)
            if response.get('status') != 'success':
                print(f"Failed to store tensor {tensor_id}: {response.get('message', 'Unknown error')}")
                return False

            # Update tensor registry
            with self.lock:
                previous = self.tensor_registry.get(tensor_id)
                self.tensor_registry[tensor_id] = {
                    'shape': tensor_shape,
                    'dtype': tensor_dtype,
                    'size': tensor_size,
                    'timestamp': time.time()
                }
                if previous is None:
                    self.resource_monitor['active_tensors'] += 1
                else:
                    # Overwriting an existing id: replace its contribution
                    # instead of counting the tensor twice.
                    self.resource_monitor['vram_used'] -= previous.get('size', 0)
                self.resource_monitor['vram_used'] += tensor_size
            return True

        except Exception as e:
            print(f"Error storing tensor {tensor_id}: {str(e)}")
            return False

    def load_tensor(self, tensor_id: str) -> Optional[np.ndarray]:
        """Fetch a previously stored tensor, or ``None`` on any failure.

        Only ids present in the local registry are requested. The result is
        built with the registered dtype and reshaped to the registered shape.
        """
        try:
            # Check tensor registry first (snapshot under the lock)
            with self.lock:
                entry = self.tensor_registry.get(tensor_id)
                expected_metadata = dict(entry) if entry is not None else None
            if expected_metadata is None:
                print(f"Tensor {tensor_id} not registered in VRAM")
                return None

            operation = {
                'operation': 'vram',
                'type': 'read',
                'block_id': tensor_id,
                'expected_metadata': expected_metadata
            }

            response = self._send_operation(operation)
            if response.get('status') != 'success':
                print(f"Failed to load tensor {tensor_id}: {response.get('message', 'Unknown error')}")
                return None

            data = response.get('data')
            if data is None:
                print(f"No data found for tensor {tensor_id}")
                return None

            # Verify tensor metadata. Shapes arrive as JSON lists while the
            # registry holds tuples, so compare them in a common form.
            metadata = response.get('metadata') or {}
            expected_shape = self._as_shape(expected_metadata.get('shape'))
            reported_shape = self._as_shape(metadata.get('shape'))
            if reported_shape is not None and reported_shape != expected_shape:
                print(f"Warning: Tensor {tensor_id} shape mismatch")

            try:
                # Convert to numpy array with correct dtype
                arr = np.array(data, dtype=np.dtype(expected_metadata.get('dtype', 'float32')))
                if expected_shape is not None and arr.shape != expected_shape:
                    arr = arr.reshape(expected_shape)
                return arr
            except (TypeError, ValueError, OverflowError) as e:
                print(f"Error converting tensor data: {str(e)}")
                return None

        except Exception as e:
            print(f"Error loading tensor {tensor_id}: {str(e)}")
            return None

    # ------------------------------------------------------------------
    # Component state
    # ------------------------------------------------------------------

    def store_state(self, component: str, state_id: str, state_data: Dict[str, Any]) -> bool:
        """Save ``state_data`` for ``component``/``state_id``; ``True`` on success."""
        try:
            operation = {
                'operation': 'state',
                'type': 'save',
                'component': component,
                'state_id': state_id,
                'data': state_data,
                'timestamp': time.time()
            }

            response = self._send_operation(operation)
            if response.get('status') != 'success':
                print(f"Failed to store state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
                return False
            return True

        except Exception as e:
            print(f"Error storing state for {component}/{state_id}: {str(e)}")
            return False

    def load_state(self, component: str, state_id: str) -> Optional[Dict[str, Any]]:
        """Return the saved state for ``component``/``state_id`` or ``None``."""
        try:
            operation = {
                'operation': 'state',
                'type': 'load',
                'component': component,
                'state_id': state_id
            }

            response = self._send_operation(operation)
            if response.get('status') != 'success':
                print(f"Failed to load state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
                return None

            data = response.get('data')
            if data is None:
                print(f"No state found for {component}/{state_id}")
                return None
            return data

        except Exception as e:
            print(f"Error loading state for {component}/{state_id}: {str(e)}")
            return None

    # ------------------------------------------------------------------
    # Models
    # ------------------------------------------------------------------

    def is_model_loaded(self, model_name: str) -> bool:
        """Check if a model is already recorded as loaded by this client."""
        return model_name in self.resource_monitor['loaded_models']

    def load_model(self, model_name: str, model_path: Optional[str] = None, model_data: Optional[Dict] = None) -> bool:
        """Ask the server to load a model unless it is already recorded.

        When ``model_path`` is given its SHA-256 digest is sent as
        ``model_hash`` (an empty string if the file could not be read).
        Returns ``True`` if the model is, or already was, loaded.
        """
        try:
            # Check if model is already loaded
            if self.is_model_loaded(model_name):
                print(f"Model {model_name} already loaded in VRAM")
                return True

            # Calculate model hash if path provided
            model_hash = None
            if model_path:
                model_hash = self._calculate_model_hash(model_path)

            operation = {
                'operation': 'model',
                'type': 'load',
                'model_name': model_name,
                'model_hash': model_hash,
                'model_data': model_data
            }

            response = self._send_operation(operation)
            if response.get('status') != 'success':
                print(f"Failed to load model {model_name}: {response.get('message', 'Unknown error')}")
                return False

            with self.lock:
                self.model_registry[model_name] = {
                    'hash': model_hash,
                    'timestamp': time.time(),
                    'tensors': response.get('tensor_ids', [])
                }
                self.resource_monitor['loaded_models'].add(model_name)
            print(f"Successfully loaded model {model_name}")
            return True

        except Exception as e:
            print(f"Error loading model {model_name}: {str(e)}")
            return False

    def _calculate_model_hash(self, model_path: str) -> str:
        """Return the SHA-256 hex digest of the file, or ``""`` on failure."""
        try:
            sha256_hash = hashlib.sha256()
            with open(model_path, "rb") as f:
                # Read in 64 KiB chunks so large files are never held in
                # memory at once.
                for byte_block in iter(lambda: f.read(65536), b""):
                    sha256_hash.update(byte_block)
            return sha256_hash.hexdigest()
        except (OSError, TypeError, ValueError) as e:
            print(f"Error calculating model hash: {str(e)}")
            return ""

    # ------------------------------------------------------------------
    # Cache
    # ------------------------------------------------------------------

    def cache_data(self, key: str, data: Any) -> bool:
        """Store ``data`` under ``key`` in the server cache; ``True`` on success."""
        operation = {
            'operation': 'cache',
            'type': 'set',
            'key': key,
            'data': data
        }
        response = self._send_operation(operation)
        return response.get('status') == 'success'

    def get_cached_data(self, key: str) -> Optional[Any]:
        """Return the cached value for ``key``, or ``None`` if unavailable."""
        operation = {
            'operation': 'cache',
            'type': 'get',
            'key': key
        }
        response = self._send_operation(operation)
        if response.get('status') == 'success':
            # .get(): a success reply without a "data" field yields None
            # instead of raising KeyError.
            return response.get('data')
        return None

    # ------------------------------------------------------------------
    # Connection status
    # ------------------------------------------------------------------

    def wait_for_connection(self, timeout: float = 30.0) -> bool:
        """Block until connected, closing, or ``timeout`` seconds elapse."""
        # Monotonic clock: unaffected by wall-clock adjustments.
        deadline = time.monotonic() + timeout
        while not self._closing and not self.connected:
            if time.monotonic() > deadline:
                print("Connection timeout exceeded")
                return False
            time.sleep(0.1)
        return self.connected

    def is_connected(self) -> bool:
        """Check if WebSocket connection is active"""
        return self.connected and not self._closing

    def get_connection_status(self) -> Dict[str, Any]:
        """Get detailed connection status"""
        with self.lock:
            # Copy under the lock so a concurrent load_model cannot resize
            # the set while it is being iterated.
            loaded_models = list(self.resource_monitor['loaded_models'])
        return {
            "connected": self.connected,
            "closing": self._closing,
            "error_count": self.error_count,
            "url": self.url,
            "last_error_time": self.last_error_time,
            "loaded_models": loaded_models
        }

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------

    def start_inference(self, model_name: str, input_data: np.ndarray) -> Optional[Dict[str, Any]]:
        """Run inference on a loaded model.

        Returns a dict with ``output`` (ndarray or ``None``), ``metrics`` and
        ``model_info``, or ``None`` if the model is not loaded or the request
        fails.
        """
        try:
            if not self.is_model_loaded(model_name):
                print(f"Model {model_name} not loaded. Please load the model first.")
                return None

            operation = {
                'operation': 'inference',
                'type': 'run',
                'model_name': model_name,
                'input_data': input_data.tolist() if isinstance(input_data, np.ndarray) else input_data
            }

            response = self._send_operation(operation)
            if response.get('status') != 'success':
                print(f"Inference failed: {response.get('message', 'Unknown error')}")
                return None

            return {
                'output': np.array(response['output']) if 'output' in response else None,
                'metrics': response.get('metrics', {}),
                'model_info': self.model_registry.get(model_name, {})
            }

        except Exception as e:
            print(f"Error during inference: {str(e)}")
            return None

    # ------------------------------------------------------------------
    # Shutdown
    # ------------------------------------------------------------------

    def close(self):
        """Close WebSocket connection and cleanup resources.

        Safe to call from any thread and more than once. Local registries
        are always cleared; the server is notified and the socket closed on
        a best-effort basis if the background loop is still running.
        """
        if self._closing:
            return
        self._closing = True

        # Clean up registries
        with self.lock:
            self.tensor_registry.clear()
            self.model_registry.clear()
            self.resource_monitor['vram_used'] = 0
            self.resource_monitor['active_tensors'] = 0
            self.resource_monitor['loaded_models'].clear()

        websocket = self.websocket
        loop = self._loop
        if websocket is None or loop is None or not loop.is_running():
            # The socket can only be driven from the loop that created it;
            # without that loop there is nothing left to close gracefully.
            self.connected = False
            return

        async def cleanup():
            try:
                # Notify server about cleanup
                if self.connected:
                    try:
                        await websocket.send(json.dumps({
                            'operation': 'cleanup',
                            'type': 'full'
                        }))
                    except Exception:
                        pass  # Best effort: the socket is closed regardless
                await websocket.close()
            except Exception as e:
                print(f"Error during cleanup: {str(e)}")
            finally:
                self.connected = False

        pending = cleanup()
        try:
            # The loop runs in another thread, so the coroutine must be
            # handed over through the thread-safe entry point.
            asyncio.run_coroutine_threadsafe(pending, loop)
        except RuntimeError:
            # Loop shut down between the check and the hand-over.
            pending.close()
            self.connected = False

    async def aclose(self):
        """Asynchronously close WebSocket connection.

        NOTE(review): this awaits the socket's ``close()`` on the caller's
        event loop, whereas the socket is created on the background thread's
        loop -- confirm that callers only use this from that loop.
        """
        if self._closing:
            return
        self._closing = True

        websocket = self.websocket
        if websocket is None:
            return
        try:
            await websocket.close()
        except Exception:
            # Best effort; unlike a bare except this lets task cancellation
            # propagate.
            pass
        finally:
            self.connected = False

    def __del__(self):
        """Ensure cleanup on deletion."""
        # The instance may be only partially constructed if initialisation
        # failed, in which case there is nothing to close.
        if not getattr(self, 'initialized', False):
            return
        try:
            self.close()
        except Exception:
            pass  # Exceptions raised in __del__ cannot be handled by anyone