import asyncio
import hashlib
import json
import threading
import time
import uuid
from queue import Queue
from typing import Dict, Any, Optional

import numpy as np
import websockets
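
# WebSocketGPUStorage is a thread-safe singleton client for a remote GPU storage
# server reached over a WebSocket. Requests are serialized as JSON messages and
# cover tensor storage, component-state persistence, model loading, a simple
# key/value cache, and remote inference.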


class WebSocketGPUStorage:
    # Singleton instance
    _instance = None
    _lock = threading.Lock()

    def __new__(cls, url: str = "wss://factorst-intiv.hf.space:443/ws"):
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance._init_singleton(url)
            return cls._instance

    def _init_singleton(self, url: str):
        """Initialize the singleton instance."""
        if hasattr(self, 'initialized'):
            return
        self.url = url
        self.websocket = None
        self.connected = False
        self.message_queue = Queue()
        self.response_queues: Dict[str, Queue] = {}
        self.lock = threading.Lock()
        self._closing = False
        self._loop = None
        self.error_count = 0
        self.last_error_time = 0
        self.max_retries = 5
        self.tensor_registry: Dict[str, Dict[str, Any]] = {}  # Track tensor metadata
        self.model_registry: Dict[str, Dict[str, Any]] = {}  # Track loaded models
        self.resource_monitor = {
            'vram_used': 0,
            'active_tensors': 0,
            'loaded_models': set()
        }
        # Start the WebSocket connection in a separate daemon thread
        self.ws_thread = threading.Thread(target=self._run_websocket_loop, daemon=True)
        self.ws_thread.start()
        self.initialized = True

    def __init__(self, url: str = "wss://factorst-intiv.hf.space:443/ws"):
        """No-op: __new__ already returns the fully initialized singleton instance."""
        pass

    def _run_websocket_loop(self):
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)
        self._loop.run_until_complete(self._websocket_handler())

    async def _websocket_handler(self):
        while not self._closing:
            try:
                async with websockets.connect(self.url) as websocket:
                    self.websocket = websocket
                    self.connected = True
                    self.error_count = 0  # Reset error count on successful connection
                    print("Connected to GPU storage server")
                    while True:
                        # Handle outgoing messages
                        try:
                            while not self.message_queue.empty():
                                msg_id, operation = self.message_queue.get()
                                await websocket.send(json.dumps(operation))
                                # Wait for the response with a timeout
                                try:
                                    response = await asyncio.wait_for(websocket.recv(), timeout=30)
                                    response_data = json.loads(response)
                                    # Route the response to the corresponding queue
                                    if msg_id in self.response_queues:
                                        self.response_queues[msg_id].put(response_data)
                                except asyncio.TimeoutError:
                                    if msg_id in self.response_queues:
                                        self.response_queues[msg_id].put({
                                            "status": "error",
                                            "message": "Operation timed out"
                                        })
                                except Exception as e:
                                    if msg_id in self.response_queues:
                                        self.response_queues[msg_id].put({
                                            "status": "error",
                                            "message": f"Error processing response: {str(e)}"
                                        })
                        except Exception as e:
                            print(f"Error processing message: {str(e)}")
                        # Keep the connection alive with a heartbeat ping
                        try:
                            await websocket.ping()
                        except Exception:
                            break  # Reconnect on ping failure
                        await asyncio.sleep(0.001)  # Short sleep to avoid busy-waiting
            except Exception as e:
                print(f"WebSocket connection error: {e}")
                self.connected = False
                await asyncio.sleep(1)  # Wait before reconnecting

    def _send_operation(self, operation: Dict[str, Any]) -> Dict[str, Any]:
        if self._closing:
            return {"status": "error", "message": "WebSocket is closing"}
        if not self.wait_for_connection(timeout=10):
            return {"status": "error", "message": "Not connected to GPU storage server"}
        msg_id = uuid.uuid4().hex  # Unique id; avoids collisions between concurrent callers
        response_queue = Queue()
        with self.lock:
            self.response_queues[msg_id] = response_queue
        self.message_queue.put((msg_id, operation))
        try:
            # Wait for the response (extended timeout for large models)
            response = response_queue.get(timeout=30)
            if response.get("status") == "error" and "model_size" in operation:
                # Retry once for model loading operations
                self.message_queue.put((msg_id, operation))
                response = response_queue.get(timeout=30)
        except Exception as e:
            response = {"status": "error", "message": f"Operation failed: {str(e)}"}
        finally:
            with self.lock:
                if msg_id in self.response_queues:
                    del self.response_queues[msg_id]
        return response
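
    # Wire protocol (as used by the methods below): each request is a JSON object
    # with an 'operation' field ('vram', 'state', 'model', 'cache', 'inference', or
    # 'cleanup'), a 'type' field naming the action, and operation-specific payload
    # fields; the server is expected to reply with {'status': 'success'|'error', ...}.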

    def store_tensor(self, tensor_id: str, data: np.ndarray, model_size: Optional[int] = None) -> bool:
        try:
            if data is None:
                raise ValueError("Cannot store None tensor")
            # Calculate tensor metadata
            tensor_shape = data.shape
            tensor_dtype = str(data.dtype)
            tensor_size = data.nbytes
            operation = {
                'operation': 'vram',
                'type': 'write',
                'block_id': tensor_id,
                'data': data.tolist(),
                'model_size': model_size if model_size is not None else -1,  # -1 indicates unlimited
                'metadata': {
                    'shape': tensor_shape,
                    'dtype': tensor_dtype,
                    'size': tensor_size,
                    'timestamp': time.time()
                }
            }
            response = self._send_operation(operation)
            if response.get('status') == 'success':
                # Update the tensor registry and resource counters
                with self.lock:
                    self.tensor_registry[tensor_id] = {
                        'shape': tensor_shape,
                        'dtype': tensor_dtype,
                        'size': tensor_size,
                        'timestamp': time.time()
                    }
                    self.resource_monitor['vram_used'] += tensor_size
                    self.resource_monitor['active_tensors'] += 1
                return True
            else:
                print(f"Failed to store tensor {tensor_id}: {response.get('message', 'Unknown error')}")
                return False
        except Exception as e:
            print(f"Error storing tensor {tensor_id}: {str(e)}")
            return False

    def load_tensor(self, tensor_id: str) -> Optional[np.ndarray]:
        try:
            # Check the tensor registry first
            if tensor_id not in self.tensor_registry:
                print(f"Tensor {tensor_id} not registered in VRAM")
                return None
            operation = {
                'operation': 'vram',
                'type': 'read',
                'block_id': tensor_id,
                'expected_metadata': self.tensor_registry.get(tensor_id, {})
            }
            response = self._send_operation(operation)
            if response.get('status') == 'success':
                data = response.get('data')
                if data is None:
                    print(f"No data found for tensor {tensor_id}")
                    return None
                # Verify tensor metadata (normalize to tuples: JSON round-trips shapes as lists)
                metadata = response.get('metadata', {})
                expected_metadata = self.tensor_registry.get(tensor_id, {})
                expected_shape = tuple(expected_metadata.get('shape', ()))
                if tuple(metadata.get('shape', ())) != expected_shape:
                    print(f"Warning: Tensor {tensor_id} shape mismatch")
                try:
                    # Convert to a numpy array with the expected dtype
                    arr = np.array(data, dtype=np.dtype(expected_metadata.get('dtype', 'float32')))
                    if expected_shape and arr.shape != expected_shape:
                        arr = arr.reshape(expected_shape)
                    return arr
                except Exception as e:
                    print(f"Error converting tensor data: {str(e)}")
                    return None
            else:
                print(f"Failed to load tensor {tensor_id}: {response.get('message', 'Unknown error')}")
                return None
        except Exception as e:
            print(f"Error loading tensor {tensor_id}: {str(e)}")
            return None

    def store_state(self, component: str, state_id: str, state_data: Dict[str, Any]) -> bool:
        try:
            operation = {
                'operation': 'state',
                'type': 'save',
                'component': component,
                'state_id': state_id,
                'data': state_data,
                'timestamp': time.time()
            }
            response = self._send_operation(operation)
            if response.get('status') != 'success':
                print(f"Failed to store state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
                return False
            return True
        except Exception as e:
            print(f"Error storing state for {component}/{state_id}: {str(e)}")
            return False

    def load_state(self, component: str, state_id: str) -> Optional[Dict[str, Any]]:
        try:
            operation = {
                'operation': 'state',
                'type': 'load',
                'component': component,
                'state_id': state_id
            }
            response = self._send_operation(operation)
            if response.get('status') == 'success':
                data = response.get('data')
                if data is None:
                    print(f"No state found for {component}/{state_id}")
                    return None
                return data
            else:
                print(f"Failed to load state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
                return None
        except Exception as e:
            print(f"Error loading state for {component}/{state_id}: {str(e)}")
            return None

    def is_model_loaded(self, model_name: str) -> bool:
        """Check if a model is already loaded in VRAM."""
        return model_name in self.resource_monitor['loaded_models']

    def load_model(self, model_name: str, model_path: Optional[str] = None, model_data: Optional[Dict] = None) -> bool:
        """Load a model into VRAM if not already loaded."""
        try:
            # Check if the model is already loaded
            if self.is_model_loaded(model_name):
                print(f"Model {model_name} already loaded in VRAM")
                return True
            # Calculate the model hash if a path was provided
            model_hash = None
            if model_path:
                model_hash = self._calculate_model_hash(model_path)
            operation = {
                'operation': 'model',
                'type': 'load',
                'model_name': model_name,
                'model_hash': model_hash,
                'model_data': model_data
            }
            response = self._send_operation(operation)
            if response.get('status') == 'success':
                with self.lock:
                    self.model_registry[model_name] = {
                        'hash': model_hash,
                        'timestamp': time.time(),
                        'tensors': response.get('tensor_ids', [])
                    }
                    self.resource_monitor['loaded_models'].add(model_name)
                print(f"Successfully loaded model {model_name}")
                return True
            else:
                print(f"Failed to load model {model_name}: {response.get('message', 'Unknown error')}")
                return False
        except Exception as e:
            print(f"Error loading model {model_name}: {str(e)}")
            return False

    def _calculate_model_hash(self, model_path: str) -> str:
        """Calculate the SHA256 hash of a model file."""
        try:
            sha256_hash = hashlib.sha256()
            with open(model_path, "rb") as f:
                for byte_block in iter(lambda: f.read(4096), b""):
                    sha256_hash.update(byte_block)
            return sha256_hash.hexdigest()
        except Exception as e:
            print(f"Error calculating model hash: {str(e)}")
            return ""

    def cache_data(self, key: str, data: Any) -> bool:
        operation = {
            'operation': 'cache',
            'type': 'set',
            'key': key,
            'data': data
        }
        response = self._send_operation(operation)
        return response.get('status') == 'success'

    def get_cached_data(self, key: str) -> Optional[Any]:
        operation = {
            'operation': 'cache',
            'type': 'get',
            'key': key
        }
        response = self._send_operation(operation)
        if response.get('status') == 'success':
            return response['data']
        return None

    def wait_for_connection(self, timeout: float = 30.0) -> bool:
        """Wait for the WebSocket connection to be established."""
        start_time = time.time()
        while not self._closing and not self.connected:
            if time.time() - start_time > timeout:
                print("Connection timeout exceeded")
                return False
            time.sleep(0.1)
        return self.connected

    def is_connected(self) -> bool:
        """Check if the WebSocket connection is active."""
        return self.connected and not self._closing

    def get_connection_status(self) -> Dict[str, Any]:
        """Get detailed connection status."""
        return {
            "connected": self.connected,
            "closing": self._closing,
            "error_count": self.error_count,
            "url": self.url,
            "last_error_time": self.last_error_time,
            "loaded_models": list(self.resource_monitor['loaded_models'])
        }

    def start_inference(self, model_name: str, input_data: np.ndarray) -> Optional[Dict[str, Any]]:
        """Start inference with a loaded model."""
        try:
            if not self.is_model_loaded(model_name):
                print(f"Model {model_name} not loaded. Please load the model first.")
                return None
            operation = {
                'operation': 'inference',
                'type': 'run',
                'model_name': model_name,
                'input_data': input_data.tolist() if isinstance(input_data, np.ndarray) else input_data
            }
            response = self._send_operation(operation)
            if response.get('status') == 'success':
                return {
                    'output': np.array(response['output']) if 'output' in response else None,
                    'metrics': response.get('metrics', {}),
                    'model_info': self.model_registry.get(model_name, {})
                }
            else:
                print(f"Inference failed: {response.get('message', 'Unknown error')}")
                return None
        except Exception as e:
            print(f"Error during inference: {str(e)}")
            return None

    def close(self):
        """Close the WebSocket connection and clean up resources."""
        if not self._closing:
            self._closing = True
            if self.websocket and self._loop:
                async def cleanup():
                    try:
                        # Clean up registries
                        with self.lock:
                            self.tensor_registry.clear()
                            self.model_registry.clear()
                            self.resource_monitor['vram_used'] = 0
                            self.resource_monitor['active_tensors'] = 0
                            self.resource_monitor['loaded_models'].clear()
                        # Notify the server about the cleanup
                        if self.connected:
                            try:
                                await self.websocket.send(json.dumps({
                                    'operation': 'cleanup',
                                    'type': 'full'
                                }))
                            except Exception:
                                pass
                        await self.websocket.close()
                    except Exception as e:
                        print(f"Error during cleanup: {str(e)}")
                    finally:
                        self.connected = False
                if self._loop.is_running():
                    # Schedule the coroutine on the connection thread's loop (thread-safe)
                    asyncio.run_coroutine_threadsafe(cleanup(), self._loop)
                else:
                    asyncio.run(cleanup())

    async def aclose(self):
        """Asynchronously close the WebSocket connection."""
        if not self._closing:
            self._closing = True
            if self.websocket:
                try:
                    await self.websocket.close()
                except Exception:
                    pass
                finally:
                    self.connected = False

    def __del__(self):
        """Ensure cleanup on deletion."""
        try:
            self.close()
        except Exception:
            pass
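

# A minimal usage sketch (not part of the original module): the tensor id below is
# hypothetical, and it assumes the GPU storage server at the default URL is reachable.
if __name__ == "__main__":
    storage = WebSocketGPUStorage()  # Always returns the shared singleton instance
    if storage.wait_for_connection(timeout=10):
        weights = np.random.rand(4, 4).astype(np.float32)
        if storage.store_tensor("example_tensor", weights):
            restored = storage.load_tensor("example_tensor")
            print("Round-trip ok:", restored is not None and np.allclose(weights, restored))
        print(storage.get_connection_status())
    storage.close()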