Factor Studios committed on
Commit
16d64f1
·
verified ·
1 Parent(s): d200bfd

Upload 21 files

Browse files
gpu_chip.py CHANGED
@@ -5,15 +5,18 @@ from typing import Dict, Any, List, Optional
5
  import time
6
 
7
  class GPUChip:
8
- def __init__(self, chip_id: int, num_sms: int = 108, vram_gb: int = 24):
9
  self.chip_id = chip_id
10
- self.storage = WebSocketGPUStorage()
11
- if not self.storage.wait_for_connection():
12
- raise RuntimeError("Could not connect to GPU storage server")
 
 
 
13
 
14
- # Initialize components
15
- self.vram = VirtualVRAM(vram_gb)
16
- self.sms = [StreamingMultiprocessor(i) for i in range(num_sms)]
17
 
18
  # Initialize chip state
19
  self.chip_state = {
 
5
  import time
6
 
7
  class GPUChip:
8
+ def __init__(self, chip_id: int, num_sms: int = 108, vram_gb: int = 24, storage=None):
9
  self.chip_id = chip_id
10
+ self.storage = storage
11
+ if self.storage is None:
12
+ from websocket_storage import WebSocketGPUStorage
13
+ self.storage = WebSocketGPUStorage()
14
+ if not self.storage.wait_for_connection():
15
+ raise RuntimeError("Could not connect to GPU storage server")
16
 
17
+ # Initialize components with shared storage
18
+ self.vram = VirtualVRAM(vram_gb, storage=self.storage)
19
+ self.sms = [StreamingMultiprocessor(i, storage=self.storage) for i in range(num_sms)]
20
 
21
  # Initialize chip state
22
  self.chip_state = {
multi_gpu_system.py CHANGED
@@ -5,13 +5,16 @@ import time
5
  import numpy as np
6
 
7
  class MultiGPUSystem:
8
- def __init__(self, num_gpus: int = 8):
9
- self.storage = WebSocketGPUStorage()
10
- if not self.storage.wait_for_connection():
11
- raise RuntimeError("Could not connect to GPU storage server")
 
 
 
12
 
13
- # Initialize GPUs
14
- self.gpus = [GPUChip(i) for i in range(num_gpus)]
15
 
16
  # Initialize system state
17
  self.system_state = {
 
5
  import numpy as np
6
 
7
  class MultiGPUSystem:
8
+ def __init__(self, num_gpus: int = 8, storage=None):
9
+ self.storage = storage
10
+ if self.storage is None:
11
+ from websocket_storage import WebSocketGPUStorage
12
+ self.storage = WebSocketGPUStorage()
13
+ if not self.storage.wait_for_connection():
14
+ raise RuntimeError("Could not connect to GPU storage server")
15
 
16
+ # Initialize GPUs with shared storage
17
+ self.gpus = [GPUChip(i, storage=self.storage) for i in range(num_gpus)]
18
 
19
  # Initialize system state
20
  self.system_state = {
streaming_multiprocessor.py CHANGED
@@ -4,12 +4,15 @@ from typing import Dict, Any, Optional, List
4
  import time
5
 
6
  class StreamingMultiprocessor:
7
- def __init__(self, sm_id: int, num_cores: int = 128):
8
  self.sm_id = sm_id
9
  self.num_cores = num_cores
10
- self.storage = WebSocketGPUStorage()
11
- if not self.storage.wait_for_connection():
12
- raise RuntimeError("Could not connect to GPU storage server")
 
 
 
13
 
14
  # Initialize SM state
15
  self.sm_state = {
 
4
  import time
5
 
6
  class StreamingMultiprocessor:
7
+ def __init__(self, sm_id: int, num_cores: int = 128, storage=None):
8
  self.sm_id = sm_id
9
  self.num_cores = num_cores
10
+ self.storage = storage
11
+ if self.storage is None:
12
+ from websocket_storage import WebSocketGPUStorage
13
+ self.storage = WebSocketGPUStorage()
14
+ if not self.storage.wait_for_connection():
15
+ raise RuntimeError("Could not connect to GPU storage server")
16
 
17
  # Initialize SM state
18
  self.sm_state = {
tensor_core.py CHANGED
@@ -23,14 +23,17 @@ class TensorCore:
23
  Pure virtual tensor core for matrix operations with zero CPU involvement.
24
  All operations happen in virtual space at electron speed with WebSocket-based storage.
25
  """
26
- def __init__(self, bits=2, memory_size=800*1024*1024*1024, bandwidth_tbps=10000, sm=None):
27
  from electron_speed import drift_velocity, TARGET_SWITCHES_PER_SEC
28
 
29
  self.bits = bits
30
  # WebSocket-based storage
31
- self.storage = WebSocketGPUStorage()
32
- if not self.storage.wait_for_connection():
33
- raise RuntimeError("Could not connect to GPU storage server")
 
 
 
34
 
35
  # Virtual memory space (WebSocket-backed)
36
  self.virtual_memory_map: Dict[str, str] = {} # Maps virtual addresses to tensor IDs
 
23
  Pure virtual tensor core for matrix operations with zero CPU involvement.
24
  All operations happen in virtual space at electron speed with WebSocket-based storage.
25
  """
26
+ def __init__(self, bits=2, memory_size=800*1024*1024*1024, bandwidth_tbps=10000, sm=None, storage=None):
27
  from electron_speed import drift_velocity, TARGET_SWITCHES_PER_SEC
28
 
29
  self.bits = bits
30
  # WebSocket-based storage
31
+ self.storage = storage
32
+ if self.storage is None:
33
+ from websocket_storage import WebSocketGPUStorage
34
+ self.storage = WebSocketGPUStorage()
35
+ if not self.storage.wait_for_connection():
36
+ raise RuntimeError("Could not connect to GPU storage server")
37
 
38
  # Virtual memory space (WebSocket-backed)
39
  self.virtual_memory_map: Dict[str, str] = {} # Maps virtual addresses to tensor IDs
test_ai_integration.py CHANGED
@@ -115,9 +115,8 @@ def test_ai_integration():
115
  chip_for_loading = Chip(chip_id=0, vram_size_gb=None, storage=storage) # Pass shared storage
116
  components['chips'].append(chip_for_loading)
117
 
118
- # Initialize VRAM with WebSocket storage
119
- vram = VirtualVRAM()
120
- vram.storage = storage # Share WebSocket connection
121
  components['vram'] = vram
122
 
123
  # Set up AI accelerator - note it already has the shared storage
 
115
  chip_for_loading = Chip(chip_id=0, vram_size_gb=None, storage=storage) # Pass shared storage
116
  components['chips'].append(chip_for_loading)
117
 
118
+ # Initialize VRAM with shared WebSocket storage
119
+ vram = VirtualVRAM(storage=storage) # Pass shared storage instance
 
120
  components['vram'] = vram
121
 
122
  # Set up AI accelerator - note it already has the shared storage
virtual_vram.py CHANGED
@@ -4,11 +4,14 @@ from typing import Dict, Any, Optional
4
  import time
5
 
6
  class VirtualVRAM:
7
- def __init__(self, size_gb: int = None):
8
  """Initialize virtual VRAM with unlimited storage capability"""
9
- self.storage = WebSocketGPUStorage()
10
- if not self.storage.wait_for_connection():
11
- raise RuntimeError("Could not connect to GPU storage server")
 
 
 
12
 
13
  # Initialize VRAM state with unlimited capacity
14
  self.vram_state = {
 
4
  import time
5
 
6
  class VirtualVRAM:
7
+ def __init__(self, size_gb: int = None, storage=None):
8
  """Initialize virtual VRAM with unlimited storage capability"""
9
+ self.storage = storage
10
+ if self.storage is None:
11
+ from websocket_storage import WebSocketGPUStorage
12
+ self.storage = WebSocketGPUStorage()
13
+ if not self.storage.wait_for_connection():
14
+ raise RuntimeError("Could not connect to GPU storage server")
15
 
16
  # Initialize VRAM state with unlimited capacity
17
  self.vram_state = {
websocket_storage.py CHANGED
@@ -1,435 +1,434 @@
1
- import websockets
2
- import json
3
- import numpy as np
4
- from typing import Dict, Any, Optional, Union
5
- import threading
6
- from queue import Queue
7
- import time
8
- import asyncio
9
-
10
- class WebSocketGPUStorage:
11
- def __init__(self, url: str = "wss://factorst-wbs1.hf.space/ws"): # Default to local WebSocket server
12
- self.url = url
13
- self.websocket = None
14
- self.connected = False
15
- self.message_queue = Queue()
16
- self.response_queues: Dict[str, Queue] = {}
17
- self.lock = threading.Lock()
18
- self._closing = False
19
- self._loop = None
20
- self.error_count = 0
21
- self.last_error_time = 0
22
- self.max_retries = 5
23
- self.tensor_registry: Dict[str, Dict[str, Any]] = {} # Track tensor metadata
24
- self.resource_monitor = {'vram_used': 0, 'active_tensors': 0}
25
- self.model_registry: Dict[str, Dict[str, Any]] = {} # Track loaded models
26
- self.resource_monitor = {
27
- 'vram_used': 0,
28
- 'active_tensors': 0,
29
- 'loaded_models': set()
30
- }
31
-
32
- # Start WebSocket connection in a separate thread
33
- self.ws_thread = threading.Thread(target=self._run_websocket_loop, daemon=True)
34
- self.ws_thread.start()
35
-
36
- def _run_websocket_loop(self):
37
- self._loop = asyncio.new_event_loop()
38
- asyncio.set_event_loop(self._loop)
39
- self._loop.run_until_complete(self._websocket_handler())
40
-
41
- async def _websocket_handler(self):
42
- while not self._closing:
43
- try:
44
- async with websockets.connect(self.url) as websocket:
45
- self.websocket = websocket
46
- self.connected = True
47
- self.error_count = 0 # Reset error count on successful connection
48
- print("Connected to GPU storage server")
49
-
50
- while True:
51
- # Handle outgoing messages
52
- try:
53
- while not self.message_queue.empty():
54
- msg_id, operation = self.message_queue.get()
55
- await websocket.send(json.dumps(operation))
56
-
57
- # Wait for response with timeout
58
- try:
59
- response = await asyncio.wait_for(websocket.recv(), timeout=30)
60
- response_data = json.loads(response)
61
-
62
- # Put response in corresponding queue
63
- if msg_id in self.response_queues:
64
- self.response_queues[msg_id].put(response_data)
65
- except asyncio.TimeoutError:
66
- if msg_id in self.response_queues:
67
- self.response_queues[msg_id].put({
68
- "status": "error",
69
- "message": "Operation timed out"
70
- })
71
- except Exception as e:
72
- if msg_id in self.response_queues:
73
- self.response_queues[msg_id].put({
74
- "status": "error",
75
- "message": f"Error processing response: {str(e)}"
76
- })
77
-
78
- except Exception as e:
79
- print(f"Error processing message: {str(e)}")
80
-
81
- # Keep connection alive with heartbeat
82
- try:
83
- await websocket.ping()
84
- except:
85
- break # Break inner loop on ping failure
86
-
87
- await asyncio.sleep(0.001) # 1ms sleep for electron-speed response
88
-
89
- except Exception as e:
90
- print(f"WebSocket connection error: {e}")
91
- self.connected = False
92
- await asyncio.sleep(1) # Wait before reconnecting
93
-
94
- def _send_operation(self, operation: Dict[str, Any]) -> Dict[str, Any]:
95
- if self._closing:
96
- return {"status": "error", "message": "WebSocket is closing"}
97
-
98
- if not self.wait_for_connection(timeout=10):
99
- return {"status": "error", "message": "Not connected to GPU storage server"}
100
-
101
- msg_id = str(time.time())
102
- response_queue = Queue()
103
-
104
- with self.lock:
105
- self.response_queues[msg_id] = response_queue
106
- self.message_queue.put((msg_id, operation))
107
-
108
- try:
109
- # Wait for response with configurable timeout
110
- response = response_queue.get(timeout=30) # Extended timeout for large models
111
- if response.get("status") == "error" and "model_size" in operation:
112
- # Retry once for model loading operations
113
- self.message_queue.put((msg_id, operation))
114
- response = response_queue.get(timeout=30)
115
- except Exception as e:
116
- response = {"status": "error", "message": f"Operation failed: {str(e)}"}
117
- finally:
118
- with self.lock:
119
- if msg_id in self.response_queues:
120
- del self.response_queues[msg_id]
121
-
122
- return response
123
-
124
- def store_tensor(self, tensor_id: str, data: np.ndarray, model_size: Optional[int] = None) -> bool:
125
- try:
126
- if data is None:
127
- raise ValueError("Cannot store None tensor")
128
-
129
- # Calculate tensor metadata
130
- tensor_shape = data.shape
131
- tensor_dtype = str(data.dtype)
132
- tensor_size = data.nbytes
133
-
134
- operation = {
135
- 'operation': 'vram',
136
- 'type': 'write',
137
- 'block_id': tensor_id,
138
- 'data': data.tolist(),
139
- 'model_size': model_size if model_size is not None else -1, # -1 indicates unlimited
140
- 'metadata': {
141
- 'shape': tensor_shape,
142
- 'dtype': tensor_dtype,
143
- 'size': tensor_size,
144
- 'timestamp': time.time()
145
- }
146
- }
147
-
148
- response = self._send_operation(operation)
149
- if response.get('status') == 'success':
150
- # Update tensor registry
151
- with self.lock:
152
- self.tensor_registry[tensor_id] = {
153
- 'shape': tensor_shape,
154
- 'dtype': tensor_dtype,
155
- 'size': tensor_size,
156
- 'timestamp': time.time()
157
- }
158
- self.resource_monitor['vram_used'] += tensor_size
159
- self.resource_monitor['active_tensors'] += 1
160
- return True
161
- else:
162
- print(f"Failed to store tensor {tensor_id}: {response.get('message', 'Unknown error')}")
163
- return False
164
- except Exception as e:
165
- print(f"Error storing tensor {tensor_id}: {str(e)}")
166
- return False
167
-
168
- def load_tensor(self, tensor_id: str) -> Optional[np.ndarray]:
169
- try:
170
- # Check tensor registry first
171
- if tensor_id not in self.tensor_registry:
172
- print(f"Tensor {tensor_id} not registered in VRAM")
173
- return None
174
-
175
- operation = {
176
- 'operation': 'vram',
177
- 'type': 'read',
178
- 'block_id': tensor_id,
179
- 'expected_metadata': self.tensor_registry.get(tensor_id, {})
180
- }
181
-
182
- response = self._send_operation(operation)
183
- if response.get('status') == 'success':
184
- data = response.get('data')
185
- if data is None:
186
- print(f"No data found for tensor {tensor_id}")
187
- return None
188
-
189
- # Verify tensor metadata
190
- metadata = response.get('metadata', {})
191
- expected_metadata = self.tensor_registry.get(tensor_id, {})
192
- if metadata.get('shape') != expected_metadata.get('shape'):
193
- print(f"Warning: Tensor {tensor_id} shape mismatch")
194
-
195
- try:
196
- # Convert to numpy array with correct dtype
197
- arr = np.array(data, dtype=np.dtype(expected_metadata.get('dtype', 'float32')))
198
- if arr.shape != expected_metadata.get('shape'):
199
- arr = arr.reshape(expected_metadata.get('shape'))
200
- return arr
201
- except Exception as e:
202
- print(f"Error converting tensor data: {str(e)}")
203
- return None
204
- else:
205
- print(f"Failed to load tensor {tensor_id}: {response.get('message', 'Unknown error')}")
206
- return None
207
- except Exception as e:
208
- print(f"Error loading tensor {tensor_id}: {str(e)}")
209
- return None
210
-
211
- def store_state(self, component: str, state_id: str, state_data: Dict[str, Any]) -> bool:
212
- try:
213
- operation = {
214
- 'operation': 'state',
215
- 'type': 'save',
216
- 'component': component,
217
- 'state_id': state_id,
218
- 'data': state_data,
219
- 'timestamp': time.time()
220
- }
221
-
222
- response = self._send_operation(operation)
223
- if response.get('status') != 'success':
224
- print(f"Failed to store state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
225
- return False
226
- return True
227
- except Exception as e:
228
- print(f"Error storing state for {component}/{state_id}: {str(e)}")
229
- return False
230
-
231
- def load_state(self, component: str, state_id: str) -> Optional[Dict[str, Any]]:
232
- try:
233
- operation = {
234
- 'operation': 'state',
235
- 'type': 'load',
236
- 'component': component,
237
- 'state_id': state_id
238
- }
239
-
240
- response = self._send_operation(operation)
241
- if response.get('status') == 'success':
242
- data = response.get('data')
243
- if data is None:
244
- print(f"No state found for {component}/{state_id}")
245
- return None
246
- return data
247
- else:
248
- print(f"Failed to load state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
249
- return None
250
- except Exception as e:
251
- print(f"Error loading state for {component}/{state_id}: {str(e)}")
252
- return None
253
-
254
- def is_model_loaded(self, model_name: str) -> bool:
255
- """Check if a model is already loaded in VRAM"""
256
- return model_name in self.resource_monitor['loaded_models']
257
-
258
- def load_model(self, model_name: str, model_path: Optional[str] = None, model_data: Optional[Dict] = None) -> bool:
259
- """Load a model into VRAM if not already loaded"""
260
- try:
261
- # Check if model is already loaded
262
- if self.is_model_loaded(model_name):
263
- print(f"Model {model_name} already loaded in VRAM")
264
- return True
265
-
266
- # Calculate model hash if path provided
267
- model_hash = None
268
- if model_path:
269
- model_hash = self._calculate_model_hash(model_path)
270
-
271
- operation = {
272
- 'operation': 'model',
273
- 'type': 'load',
274
- 'model_name': model_name,
275
- 'model_hash': model_hash,
276
- 'model_data': model_data
277
- }
278
-
279
- response = self._send_operation(operation)
280
- if response.get('status') == 'success':
281
- with self.lock:
282
- self.model_registry[model_name] = {
283
- 'hash': model_hash,
284
- 'timestamp': time.time(),
285
- 'tensors': response.get('tensor_ids', [])
286
- }
287
- self.resource_monitor['loaded_models'].add(model_name)
288
- print(f"Successfully loaded model {model_name}")
289
- return True
290
- else:
291
- print(f"Failed to load model {model_name}: {response.get('message', 'Unknown error')}")
292
- return False
293
- except Exception as e:
294
- print(f"Error loading model {model_name}: {str(e)}")
295
- return False
296
-
297
- def _calculate_model_hash(self, model_path: str) -> str:
298
- """Calculate SHA256 hash of model file"""
299
- try:
300
- sha256_hash = hashlib.sha256()
301
- with open(model_path, "rb") as f:
302
- for byte_block in iter(lambda: f.read(4096), b""):
303
- sha256_hash.update(byte_block)
304
- return sha256_hash.hexdigest()
305
- except Exception as e:
306
- print(f"Error calculating model hash: {str(e)}")
307
- return ""
308
-
309
- def cache_data(self, key: str, data: Any) -> bool:
310
- operation = {
311
- 'operation': 'cache',
312
- 'type': 'set',
313
- 'key': key,
314
- 'data': data
315
- }
316
-
317
- response = self._send_operation(operation)
318
- return response.get('status') == 'success'
319
-
320
- def get_cached_data(self, key: str) -> Optional[Any]:
321
- operation = {
322
- 'operation': 'cache',
323
- 'type': 'get',
324
- 'key': key
325
- }
326
-
327
- response = self._send_operation(operation)
328
- if response.get('status') == 'success':
329
- return response['data']
330
- return None
331
-
332
- def wait_for_connection(self, timeout: float = 30.0) -> bool:
333
- """Wait for WebSocket connection to be established"""
334
- start_time = time.time()
335
- while not self._closing and not self.connected:
336
- if time.time() - start_time > timeout:
337
- print("Connection timeout exceeded")
338
- return False
339
- time.sleep(0.1)
340
- return self.connected
341
-
342
- def is_connected(self) -> bool:
343
- """Check if WebSocket connection is active"""
344
- return self.connected and not self._closing
345
-
346
- def get_connection_status(self) -> Dict[str, Any]:
347
- """Get detailed connection status"""
348
- return {
349
- "connected": self.connected,
350
- "closing": self._closing,
351
- "error_count": self.error_count,
352
- "url": self.url,
353
- "last_error_time": self.last_error_time,
354
- "loaded_models": list(self.resource_monitor['loaded_models'])
355
- }
356
-
357
- def start_inference(self, model_name: str, input_data: np.ndarray) -> Optional[Dict[str, Any]]:
358
- """Start inference with a loaded model"""
359
- try:
360
- if not self.is_model_loaded(model_name):
361
- print(f"Model {model_name} not loaded. Please load the model first.")
362
- return None
363
-
364
- operation = {
365
- 'operation': 'inference',
366
- 'type': 'run',
367
- 'model_name': model_name,
368
- 'input_data': input_data.tolist() if isinstance(input_data, np.ndarray) else input_data
369
- }
370
-
371
- response = self._send_operation(operation)
372
- if response.get('status') == 'success':
373
- return {
374
- 'output': np.array(response['output']) if 'output' in response else None,
375
- 'metrics': response.get('metrics', {}),
376
- 'model_info': self.model_registry.get(model_name, {})
377
- }
378
- else:
379
- print(f"Inference failed: {response.get('message', 'Unknown error')}")
380
- return None
381
- except Exception as e:
382
- print(f"Error during inference: {str(e)}")
383
- return None
384
-
385
- def close(self):
386
- """Close WebSocket connection and cleanup resources."""
387
- if not self._closing:
388
- self._closing = True
389
- if self.websocket and self._loop:
390
- async def cleanup():
391
- try:
392
- # Clean up registries
393
- with self.lock:
394
- self.tensor_registry.clear()
395
- self.model_registry.clear()
396
- self.resource_monitor['vram_used'] = 0
397
- self.resource_monitor['active_tensors'] = 0
398
- self.resource_monitor['loaded_models'].clear()
399
-
400
- # Notify server about cleanup
401
- if self.connected:
402
- try:
403
- await self.websocket.send(json.dumps({
404
- 'operation': 'cleanup',
405
- 'type': 'full'
406
- }))
407
- except:
408
- pass
409
-
410
- await self.websocket.close()
411
- except Exception as e:
412
- print(f"Error during cleanup: {str(e)}")
413
- finally:
414
- self.connected = False
415
-
416
- if self._loop.is_running():
417
- self._loop.create_task(cleanup())
418
- else:
419
- asyncio.run(cleanup())
420
-
421
- async def aclose(self):
422
- """Asynchronously close WebSocket connection."""
423
- if not self._closing:
424
- self._closing = True
425
- if self.websocket:
426
- try:
427
- await self.websocket.close()
428
- except:
429
- pass
430
- finally:
431
- self.connected = False
432
-
433
- def __del__(self):
434
- """Ensure cleanup on deletion."""
435
- self.close()
 
import asyncio
import hashlib
import json
import threading
import time
from queue import Queue
from typing import Any, Dict, Optional, Union

import numpy as np
import websockets

class WebSocketGPUStorage:
    """Tensor/model/state storage client backed by a remote WebSocket server.

    A daemon thread runs a private asyncio event loop that keeps the
    connection alive, sends queued operations to the server, and routes each
    response to the one-shot queue registered under its message id.  All
    public methods are synchronous; shared registries are guarded by
    ``self.lock``.
    """

    def __init__(self, url: str = "wss://factorst-wbs1.hf.space/ws"):  # Default to local WebSocket server
        """Create the client and start connecting to *url* in the background.

        The connection is established asynchronously on a daemon thread;
        call :meth:`wait_for_connection` to block until it is up.
        """
        self.url = url
        self.websocket = None           # live connection object, owned by the ws thread
        self.connected = False
        self.message_queue = Queue()    # outgoing (msg_id, operation) pairs
        self.response_queues: Dict[str, Queue] = {}  # msg_id -> one-shot response queue
        self.lock = threading.Lock()
        self._closing = False
        self._loop = None               # asyncio loop created by the ws thread
        self.error_count = 0
        self.last_error_time = 0
        self.max_retries = 5
        self.tensor_registry: Dict[str, Dict[str, Any]] = {}  # tensor_id -> metadata
        self.model_registry: Dict[str, Dict[str, Any]] = {}   # model_name -> metadata
        # Single authoritative resource monitor.  (The original assigned this
        # attribute twice; the first, smaller dict was dead code and was
        # removed — only this superset version was ever observable.)
        self.resource_monitor = {
            'vram_used': 0,
            'active_tensors': 0,
            'loaded_models': set()
        }

        # Start WebSocket connection in a separate daemon thread.
        self.ws_thread = threading.Thread(target=self._run_websocket_loop, daemon=True)
        self.ws_thread.start()

    def _run_websocket_loop(self):
        """Thread target: run the connection handler on a fresh event loop."""
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)
        self._loop.run_until_complete(self._websocket_handler())

    async def _websocket_handler(self):
        """Maintain the connection, reconnecting until :meth:`close` is called."""
        while not self._closing:
            try:
                async with websockets.connect(self.url) as websocket:
                    self.websocket = websocket
                    self.connected = True
                    self.error_count = 0  # reset error count on successful connection
                    print("Connected to GPU storage server")

                    while True:
                        # Drain outgoing messages; each send waits for its reply.
                        try:
                            while not self.message_queue.empty():
                                msg_id, operation = self.message_queue.get()
                                await websocket.send(json.dumps(operation))

                                # Wait for the matching response with a timeout.
                                try:
                                    response = await asyncio.wait_for(websocket.recv(), timeout=30)
                                    response_data = json.loads(response)

                                    # Route the response to the waiting caller, if any.
                                    if msg_id in self.response_queues:
                                        self.response_queues[msg_id].put(response_data)
                                except asyncio.TimeoutError:
                                    if msg_id in self.response_queues:
                                        self.response_queues[msg_id].put({
                                            "status": "error",
                                            "message": "Operation timed out"
                                        })
                                except Exception as e:
                                    if msg_id in self.response_queues:
                                        self.response_queues[msg_id].put({
                                            "status": "error",
                                            "message": f"Error processing response: {str(e)}"
                                        })

                        except Exception as e:
                            print(f"Error processing message: {str(e)}")

                        # Keep the connection alive with a heartbeat ping.
                        try:
                            await websocket.ping()
                        except Exception:
                            break  # break inner loop on ping failure; outer loop reconnects

                        await asyncio.sleep(0.001)  # 1ms sleep for electron-speed response

            except Exception as e:
                print(f"WebSocket connection error: {e}")
                self.connected = False
                await asyncio.sleep(1)  # wait before reconnecting

    def _send_operation(self, operation: Dict[str, Any]) -> Dict[str, Any]:
        """Queue *operation* for the ws thread and block for its response.

        Returns the server's response dict, or a ``{"status": "error", ...}``
        dict on timeout / disconnection.  Operations carrying a ``model_size``
        key are retried once on error.
        """
        if self._closing:
            return {"status": "error", "message": "WebSocket is closing"}

        if not self.wait_for_connection(timeout=10):
            return {"status": "error", "message": "Not connected to GPU storage server"}

        # NOTE(review): a wall-clock msg id can collide under concurrent
        # callers; a uuid would be safer — kept for wire compatibility.
        msg_id = str(time.time())
        response_queue = Queue()

        with self.lock:
            self.response_queues[msg_id] = response_queue
        self.message_queue.put((msg_id, operation))

        try:
            # Wait for response (extended timeout for large models).
            response = response_queue.get(timeout=30)
            if response.get("status") == "error" and "model_size" in operation:
                # Retry once for model loading operations.
                self.message_queue.put((msg_id, operation))
                response = response_queue.get(timeout=30)
        except Exception as e:
            response = {"status": "error", "message": f"Operation failed: {str(e)}"}
        finally:
            with self.lock:
                if msg_id in self.response_queues:
                    del self.response_queues[msg_id]

        return response

    def store_tensor(self, tensor_id: str, data: np.ndarray, model_size: Optional[int] = None) -> bool:
        """Write *data* to server-side VRAM under *tensor_id*.

        Registers the tensor's metadata locally and updates the resource
        monitor on success.  Returns True on success, False otherwise.
        """
        try:
            if data is None:
                raise ValueError("Cannot store None tensor")

            # Calculate tensor metadata.
            tensor_shape = data.shape
            tensor_dtype = str(data.dtype)
            tensor_size = data.nbytes

            operation = {
                'operation': 'vram',
                'type': 'write',
                'block_id': tensor_id,
                'data': data.tolist(),
                'model_size': model_size if model_size is not None else -1,  # -1 indicates unlimited
                'metadata': {
                    'shape': tensor_shape,
                    'dtype': tensor_dtype,
                    'size': tensor_size,
                    'timestamp': time.time()
                }
            }

            response = self._send_operation(operation)
            if response.get('status') == 'success':
                # Update tensor registry and usage counters.
                with self.lock:
                    self.tensor_registry[tensor_id] = {
                        'shape': tensor_shape,
                        'dtype': tensor_dtype,
                        'size': tensor_size,
                        'timestamp': time.time()
                    }
                    self.resource_monitor['vram_used'] += tensor_size
                    self.resource_monitor['active_tensors'] += 1
                return True
            else:
                print(f"Failed to store tensor {tensor_id}: {response.get('message', 'Unknown error')}")
                return False
        except Exception as e:
            print(f"Error storing tensor {tensor_id}: {str(e)}")
            return False

    def load_tensor(self, tensor_id: str) -> Optional[np.ndarray]:
        """Read tensor *tensor_id* back from server-side VRAM.

        Returns the tensor as a numpy array with the registered dtype/shape,
        or None if the tensor is unknown, missing, or the read fails.
        """
        try:
            # Only tensors this client stored are known to the registry.
            if tensor_id not in self.tensor_registry:
                print(f"Tensor {tensor_id} not registered in VRAM")
                return None

            operation = {
                'operation': 'vram',
                'type': 'read',
                'block_id': tensor_id,
                'expected_metadata': self.tensor_registry.get(tensor_id, {})
            }

            response = self._send_operation(operation)
            if response.get('status') == 'success':
                data = response.get('data')
                if data is None:
                    print(f"No data found for tensor {tensor_id}")
                    return None

                # Verify tensor metadata.  JSON round-trips the shape as a
                # list while the registry holds a tuple, so normalize both
                # sides before comparing (the original comparison always
                # reported a mismatch for this reason).
                metadata = response.get('metadata', {})
                expected_metadata = self.tensor_registry.get(tensor_id, {})
                if tuple(metadata.get('shape') or ()) != tuple(expected_metadata.get('shape') or ()):
                    print(f"Warning: Tensor {tensor_id} shape mismatch")

                try:
                    # Convert to numpy array with the registered dtype.
                    arr = np.array(data, dtype=np.dtype(expected_metadata.get('dtype', 'float32')))
                    if arr.shape != expected_metadata.get('shape'):
                        arr = arr.reshape(expected_metadata.get('shape'))
                    return arr
                except Exception as e:
                    print(f"Error converting tensor data: {str(e)}")
                    return None
            else:
                print(f"Failed to load tensor {tensor_id}: {response.get('message', 'Unknown error')}")
                return None
        except Exception as e:
            print(f"Error loading tensor {tensor_id}: {str(e)}")
            return None

    def store_state(self, component: str, state_id: str, state_data: Dict[str, Any]) -> bool:
        """Persist *state_data* for ``component/state_id``; True on success."""
        try:
            operation = {
                'operation': 'state',
                'type': 'save',
                'component': component,
                'state_id': state_id,
                'data': state_data,
                'timestamp': time.time()
            }

            response = self._send_operation(operation)
            if response.get('status') != 'success':
                print(f"Failed to store state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
                return False
            return True
        except Exception as e:
            print(f"Error storing state for {component}/{state_id}: {str(e)}")
            return False

    def load_state(self, component: str, state_id: str) -> Optional[Dict[str, Any]]:
        """Fetch previously stored state for ``component/state_id`` or None."""
        try:
            operation = {
                'operation': 'state',
                'type': 'load',
                'component': component,
                'state_id': state_id
            }

            response = self._send_operation(operation)
            if response.get('status') == 'success':
                data = response.get('data')
                if data is None:
                    print(f"No state found for {component}/{state_id}")
                    return None
                return data
            else:
                print(f"Failed to load state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
                return None
        except Exception as e:
            print(f"Error loading state for {component}/{state_id}: {str(e)}")
            return None

    def is_model_loaded(self, model_name: str) -> bool:
        """Check if a model is already loaded in VRAM"""
        return model_name in self.resource_monitor['loaded_models']

    def load_model(self, model_name: str, model_path: Optional[str] = None, model_data: Optional[Dict] = None) -> bool:
        """Load a model into VRAM if not already loaded"""
        try:
            # Check if model is already loaded (idempotent fast path).
            if self.is_model_loaded(model_name):
                print(f"Model {model_name} already loaded in VRAM")
                return True

            # Calculate model hash if a local path is provided.
            model_hash = None
            if model_path:
                model_hash = self._calculate_model_hash(model_path)

            operation = {
                'operation': 'model',
                'type': 'load',
                'model_name': model_name,
                'model_hash': model_hash,
                'model_data': model_data
            }

            response = self._send_operation(operation)
            if response.get('status') == 'success':
                with self.lock:
                    self.model_registry[model_name] = {
                        'hash': model_hash,
                        'timestamp': time.time(),
                        'tensors': response.get('tensor_ids', [])
                    }
                    self.resource_monitor['loaded_models'].add(model_name)
                print(f"Successfully loaded model {model_name}")
                return True
            else:
                print(f"Failed to load model {model_name}: {response.get('message', 'Unknown error')}")
                return False
        except Exception as e:
            print(f"Error loading model {model_name}: {str(e)}")
            return False

    def _calculate_model_hash(self, model_path: str) -> str:
        """Calculate SHA256 hash of model file"""
        try:
            sha256_hash = hashlib.sha256()
            with open(model_path, "rb") as f:
                # Stream in 4 KiB blocks so huge model files are not read whole.
                for byte_block in iter(lambda: f.read(4096), b""):
                    sha256_hash.update(byte_block)
            return sha256_hash.hexdigest()
        except Exception as e:
            print(f"Error calculating model hash: {str(e)}")
            return ""

    def cache_data(self, key: str, data: Any) -> bool:
        """Store *data* in the server-side cache under *key*; True on success."""
        operation = {
            'operation': 'cache',
            'type': 'set',
            'key': key,
            'data': data
        }

        response = self._send_operation(operation)
        return response.get('status') == 'success'

    def get_cached_data(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, or None if absent / on error."""
        operation = {
            'operation': 'cache',
            'type': 'get',
            'key': key
        }

        response = self._send_operation(operation)
        if response.get('status') == 'success':
            return response['data']
        return None

    def wait_for_connection(self, timeout: float = 30.0) -> bool:
        """Wait for WebSocket connection to be established"""
        start_time = time.time()
        while not self._closing and not self.connected:
            if time.time() - start_time > timeout:
                print("Connection timeout exceeded")
                return False
            time.sleep(0.1)
        return self.connected

    def is_connected(self) -> bool:
        """Check if WebSocket connection is active"""
        return self.connected and not self._closing

    def get_connection_status(self) -> Dict[str, Any]:
        """Get detailed connection status"""
        return {
            "connected": self.connected,
            "closing": self._closing,
            "error_count": self.error_count,
            "url": self.url,
            "last_error_time": self.last_error_time,
            "loaded_models": list(self.resource_monitor['loaded_models'])
        }

    def start_inference(self, model_name: str, input_data: np.ndarray) -> Optional[Dict[str, Any]]:
        """Start inference with a loaded model"""
        try:
            if not self.is_model_loaded(model_name):
                print(f"Model {model_name} not loaded. Please load the model first.")
                return None

            operation = {
                'operation': 'inference',
                'type': 'run',
                'model_name': model_name,
                'input_data': input_data.tolist() if isinstance(input_data, np.ndarray) else input_data
            }

            response = self._send_operation(operation)
            if response.get('status') == 'success':
                return {
                    'output': np.array(response['output']) if 'output' in response else None,
                    'metrics': response.get('metrics', {}),
                    'model_info': self.model_registry.get(model_name, {})
                }
            else:
                print(f"Inference failed: {response.get('message', 'Unknown error')}")
                return None
        except Exception as e:
            print(f"Error during inference: {str(e)}")
            return None

    def close(self):
        """Close WebSocket connection and cleanup resources."""
        if not self._closing:
            self._closing = True
            if self.websocket and self._loop:
                async def cleanup():
                    try:
                        # Clean up local registries and counters.
                        with self.lock:
                            self.tensor_registry.clear()
                            self.model_registry.clear()
                            self.resource_monitor['vram_used'] = 0
                            self.resource_monitor['active_tensors'] = 0
                            self.resource_monitor['loaded_models'].clear()

                        # Best-effort: notify the server about the cleanup.
                        if self.connected:
                            try:
                                await self.websocket.send(json.dumps({
                                    'operation': 'cleanup',
                                    'type': 'full'
                                }))
                            except Exception:
                                pass

                        await self.websocket.close()
                    except Exception as e:
                        print(f"Error during cleanup: {str(e)}")
                    finally:
                        self.connected = False

                if self._loop.is_running():
                    self._loop.create_task(cleanup())
                else:
                    asyncio.run(cleanup())

    async def aclose(self):
        """Asynchronously close WebSocket connection."""
        if not self._closing:
            self._closing = True
            if self.websocket:
                try:
                    await self.websocket.close()
                except Exception:
                    pass
                finally:
                    self.connected = False

    def __del__(self):
        """Ensure cleanup on deletion."""
        # Swallow errors: __del__ may run during interpreter shutdown when
        # modules/attributes are already torn down.
        try:
            self.close()
        except Exception:
            pass