Factor Studios committed · Commit 43ff7f0 · verified · 1 Parent(s): c7bb495

Upload 3 files

Files changed (2):
  1. ai.py +1 -1
  2. websocket_storage.py +455 -497
ai.py CHANGED
@@ -174,7 +174,7 @@ class AIAccelerator:
         if isinstance(test_input, list):
             test_input = np.array(test_input, dtype=np.float32)
 
-        test_result = self.tensor_core_array.matmul(test_input.tolist(), test_input.tolist())
+        test_result = self.tensor_core_array.matmul(test_input, test_input)
         if test_result is None or not isinstance(test_result, (np.ndarray, list)) or len(test_result) == 0:
             raise RuntimeError("Tensor core test computation failed")
 
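The one-line ai.py fix stops round-tripping the NumPy array through nested Python lists before the tensor core self-test. Below is a minimal sketch of the corrected call path; `TensorCoreArray` here is a hypothetical stand-in (the real class lives elsewhere in the repo), and only the `matmul(a, b)` signature implied by this diff is assumed:

```python
import numpy as np

class TensorCoreArray:
    # Hypothetical stand-in; assumes matmul accepts ndarray operands
    # directly, as the fixed line implies, and returns an ndarray.
    def matmul(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        return np.matmul(a, b)

tensor_core_array = TensorCoreArray()
test_input = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)

# Before the fix: matmul(test_input.tolist(), test_input.tolist())
# converted the array back into nested lists; passing the ndarray
# avoids that needless conversion.
test_result = tensor_core_array.matmul(test_input, test_input)
if test_result is None or not isinstance(test_result, (np.ndarray, list)) or len(test_result) == 0:
    raise RuntimeError("Tensor core test computation failed")
```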
websocket_storage.py CHANGED
@@ -1,497 +1,455 @@
-import websockets
-import json
-import numpy as np
-from typing import Dict, Any, Optional, Union
-import threading
-from queue import Queue
-import time
-import asyncio
-import hashlib
-import dataclasses
-
-def custom_json_serializer(obj):
-    if hasattr(obj, '__dict__'):
-        return obj.__dict__
-    elif hasattr(obj, '_asdict'):  # For namedtuples
-        return obj._asdict()
-    elif dataclasses.is_dataclass(obj):
-        return dataclasses.asdict(obj)
-    elif isinstance(obj, (np.ndarray, np.number)):
-        return obj.tolist()
-    elif isinstance(obj, set):
-        return list(obj)
-    raise TypeError(f'Object of type {type(obj)} is not JSON serializable')
-
-class WebSocketGPUStorage:
-    # Singleton instance
-    _instance = None
-    _lock = threading.Lock()
-
-    def __new__(cls, url: str = "wss://factorst-wbs1.hf.space/ws"):
-        with cls._lock:
-            if cls._instance is None:
-                cls._instance = super().__new__(cls)
-                cls._instance._init_singleton(url)
-            return cls._instance
-
-    def _init_singleton(self, url: str):
-        """Initialize the singleton instance"""
-        if hasattr(self, 'initialized'):
-            return
-
-        self.url = url
-        self.websocket = None
-        self.connected = False
-        self.message_queue = Queue()
-        self.response_queues: Dict[str, Queue] = {}
-        self.lock = threading.Lock()
-        self._closing = False
-        self._loop = None
-        self.error_count = 0
-        self.last_error_time = 0
-        self.max_retries = 5
-        self.tensor_registry: Dict[str, Dict[str, Any]] = {}  # Track tensor metadata
-        self.model_registry: Dict[str, Dict[str, Any]] = {}  # Track loaded models
-        self.resource_monitor = {
-            'vram_used': 0,
-            'active_tensors': 0,
-            'loaded_models': set()
-        }
-
-        # Start WebSocket connection in a separate thread
-        self.ws_thread = threading.Thread(target=self._run_websocket_loop, daemon=True)
-        self.ws_thread.start()
-        self.initialized = True
-
-    def __init__(self, url: str = "wss://factorst-wbs1.hf.space/ws"):
-        """This will actually just return the singleton instance"""
-        pass
-
-    def _run_websocket_loop(self):
-        self._loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(self._loop)
-        self._loop.run_until_complete(self._websocket_handler())
-
-    async def _websocket_handler(self):
-        while not self._closing:
-            try:
-                async with websockets.connect(self.url) as websocket:
-                    self.websocket = websocket
-                    self.connected = True
-                    self.error_count = 0  # Reset error count on successful connection
-                    print("Connected to GPU storage server")
-
-                    while True:
-                        # Handle outgoing messages
-                        try:
-                            while not self.message_queue.empty():
-                                msg_id, operation = self.message_queue.get()
-                                await websocket.send(json.dumps(operation, default=custom_json_serializer))
-
-                                # Wait for response with timeout
-                                try:
-                                    response = await asyncio.wait_for(websocket.recv(), timeout=30)
-                                    response_data = json.loads(response)
-
-                                    # Put response in corresponding queue
-                                    if msg_id in self.response_queues:
-                                        self.response_queues[msg_id].put(response_data)
-                                except asyncio.TimeoutError:
-                                    if msg_id in self.response_queues:
-                                        self.response_queues[msg_id].put({
-                                            "status": "error",
-                                            "message": "Operation timed out"
-                                        })
-                                except Exception as e:
-                                    if msg_id in self.response_queues:
-                                        self.response_queues[msg_id].put({
-                                            "status": "error",
-                                            "message": f"Error processing response: {str(e)}"
-                                        })
-
-                        except Exception as e:
-                            print(f"Error processing message: {str(e)}")
-
-                        # Keep connection alive with heartbeat
-                        try:
-                            await websocket.ping()
-                        except:
-                            break  # Break inner loop on ping failure
-
-                        await asyncio.sleep(0.001)  # 1ms sleep for electron-speed response
-
-            except Exception as e:
-                print(f"WebSocket connection error: {e}")
-                self.connected = False
-                await asyncio.sleep(1)  # Wait before reconnecting
-
-    def _send_operation(self, operation: Dict[str, Any]) -> Dict[str, Any]:
-        if self._closing:
-            return {"status": "error", "message": "WebSocket is closing"}
-
-        if not self.wait_for_connection(timeout=10):
-            return {"status": "error", "message": "Not connected to GPU storage server"}
-
-        msg_id = str(time.time())
-        response_queue = Queue()
-
-        with self.lock:
-            self.response_queues[msg_id] = response_queue
-            self.message_queue.put((msg_id, operation))
-
-        try:
-            # Wait for response with configurable timeout
-            response = response_queue.get(timeout=30)  # Extended timeout for large models
-            if response.get("status") == "error" and "model_size" in operation:
-                # Retry once for model loading operations
-                self.message_queue.put((msg_id, operation))
-                response = response_queue.get(timeout=30)
-        except Exception as e:
-            response = {"status": "error", "message": f"Operation failed: {str(e)}"}
-        finally:
-            with self.lock:
-                if msg_id in self.response_queues:
-                    del self.response_queues[msg_id]
-
-        return response
-
-    def store_tensor(self, tensor_id: str, data: np.ndarray, model_size: Optional[int] = None) -> bool:
-        try:
-            if data is None:
-                raise ValueError("Cannot store None tensor")
-
-            # Calculate tensor metadata
-            tensor_shape = data.shape
-            tensor_dtype = str(data.dtype)
-            tensor_size = data.nbytes
-
-            operation = {
-                'operation': 'vram',
-                'type': 'write',
-                'block_id': tensor_id,
-                'data': data.tolist(),
-                'model_size': model_size if model_size is not None else -1,  # -1 indicates unlimited
-                'metadata': {
-                    'shape': tensor_shape,
-                    'dtype': tensor_dtype,
-                    'size': tensor_size,
-                    'timestamp': time.time()
-                }
-            }
-
-            response = self._send_operation(operation)
-            if response.get('status') == 'success':
-                # Update tensor registry
-                with self.lock:
-                    self.tensor_registry[tensor_id] = {
-                        'shape': tensor_shape,
-                        'dtype': tensor_dtype,
-                        'size': tensor_size,
-                        'timestamp': time.time()
-                    }
-                    self.resource_monitor['vram_used'] += tensor_size
-                    self.resource_monitor['active_tensors'] += 1
-                return True
-            else:
-                print(f"Failed to store tensor {tensor_id}: {response.get('message', 'Unknown error')}")
-                return False
-        except Exception as e:
-            print(f"Error storing tensor {tensor_id}: {str(e)}")
-            return False
-
-    def load_tensor(self, tensor_id: str) -> Optional[np.ndarray]:
-        try:
-            # Check tensor registry first
-            if tensor_id not in self.tensor_registry:
-                print(f"Tensor {tensor_id} not registered in VRAM")
-                return None
-
-            operation = {
-                'operation': 'vram',
-                'type': 'read',
-                'block_id': tensor_id,
-                'expected_metadata': self.tensor_registry.get(tensor_id, {})
-            }
-
-            response = self._send_operation(operation)
-            if response.get('status') == 'success':
-                data = response.get('data')
-                if data is None:
-                    print(f"No data found for tensor {tensor_id}")
-                    return None
-
-                # Verify tensor metadata
-                metadata = response.get('metadata', {})
-                expected_metadata = self.tensor_registry.get(tensor_id, {})
-                if metadata.get('shape') != expected_metadata.get('shape'):
-                    print(f"Warning: Tensor {tensor_id} shape mismatch")
-
-                try:
-                    # Convert to numpy array with correct dtype
-                    arr = np.array(data, dtype=np.dtype(expected_metadata.get('dtype', 'float32')))
-                    if arr.shape != expected_metadata.get('shape'):
-                        arr = arr.reshape(expected_metadata.get('shape'))
-                    return arr
-                except Exception as e:
-                    print(f"Error converting tensor data: {str(e)}")
-                    return None
-            else:
-                print(f"Failed to load tensor {tensor_id}: {response.get('message', 'Unknown error')}")
-                return None
-        except Exception as e:
-            print(f"Error loading tensor {tensor_id}: {str(e)}")
-            return None
-
-    def store_state(self, component: str, state_id: str, state_data: Dict[str, Any]) -> bool:
-        try:
-            # Use memory-based state storage instead of file-based
-            state_key = f"{component}_{state_id}"
-
-            # Store state in memory
-            operation = {
-                'operation': 'state',
-                'type': 'save',
-                'component': component,
-                'state_id': state_id,
-                'data': state_data
-            }
-
-            response = self._send_operation(operation)
-            if response.get('status') != 'success':
-                error_msg = response.get('message', 'Unknown error')
-                if 'Permission denied' in error_msg:
-                    # Try memory-only fallback
-                    operation['storage_type'] = 'memory_only'
-                    response = self._send_operation(operation)
-                    if response.get('status') == 'success':
-                        return True
-                print(f"Failed to store state for {component}/{state_id}: {error_msg}")
-                return False
-            return True
-        except Exception as e:
-            print(f"Error storing state for {component}/{state_id}: {str(e)}")
-            return False
-
-    def load_state(self, component: str, state_id: str) -> Optional[Dict[str, Any]]:
-        try:
-            state_key = f"{component}_{state_id}"
-
-            # Try loading from memory first
-            operation = {
-                'operation': 'vram/state',
-                'type': 'read',
-                'key': state_key,
-                'metadata': {
-                    'component': component,
-                    'state_id': state_id,
-                    'storage_type': 'memory'
-                }
-            }
-
-            response = self._send_operation(operation)
-            if response.get('status') == 'success':
-                data = response.get('data')
-                if data is None:
-                    print(f"No state found for {component}/{state_id}")
-                    return None
-                return data
-            else:
-                error_msg = response.get('message', 'Unknown error')
-                if 'Permission denied' in error_msg:
-                    # Try memory-only fallback
-                    operation['storage_type'] = 'memory_only'
-                    response = self._send_operation(operation)
-                    if response.get('status') == 'success':
-                        return response.get('data')
-                print(f"Failed to load state for {component}/{state_id}: {error_msg}")
-                return None
-        except Exception as e:
-            print(f"Error loading state for {component}/{state_id}: {str(e)}")
-            return None
-
-    def is_model_loaded(self, model_name: str) -> bool:
-        """Check if a model is already loaded in VRAM"""
-        return model_name in self.resource_monitor['loaded_models']
-
-    def load_model(self, model_name: str, model_path: Optional[str] = None, model_data: Optional[Dict] = None) -> bool:
-        """Load a model into VRAM if not already loaded"""
-        try:
-            # Check if model is already loaded
-            if self.is_model_loaded(model_name):
-                print(f"Model {model_name} already loaded in VRAM")
-                return True
-
-            # Calculate model hash if path provided
-            model_hash = None
-            if model_path:
-                model_hash = self._calculate_model_hash(model_path)
-
-            operation = {
-                'operation': 'vram',
-                'type': 'write',
-                'block_id': f"model_{model_name}",
-                'data': model_data,
-                'metadata': {
-                    'hash': model_hash,
-                    'model_name': model_name,
-                    'type': 'model'
-                }
-            }
-
-            response = self._send_operation(operation)
-            if response.get('status') == 'success':
-                with self.lock:
-                    self.model_registry[model_name] = {
-                        'hash': model_hash,
-                        'timestamp': time.time(),
-                        'tensors': response.get('tensor_ids', [])
-                    }
-                    self.resource_monitor['loaded_models'].add(model_name)
-                print(f"Successfully loaded model {model_name}")
-                return True
-            else:
-                print(f"Failed to load model {model_name}: {response.get('message', 'Unknown error')}")
-                return False
-        except Exception as e:
-            print(f"Error loading model {model_name}: {str(e)}")
-            return False
-
-    def _calculate_model_hash(self, model_path: str) -> str:
-        """Calculate SHA256 hash of model file"""
-        try:
-            sha256_hash = hashlib.sha256()
-            with open(model_path, "rb") as f:
-                for byte_block in iter(lambda: f.read(4096), b""):
-                    sha256_hash.update(byte_block)
-            return sha256_hash.hexdigest()
-        except Exception as e:
-            print(f"Error calculating model hash: {str(e)}")
-            return ""
-
-    def cache_data(self, key: str, data: Any) -> bool:
-        operation = {
-            'operation': 'cache',
-            'type': 'set',
-            'key': key,
-            'data': data
-        }
-
-        response = self._send_operation(operation)
-        return response.get('status') == 'success'
-
-    def get_cached_data(self, key: str) -> Optional[Any]:
-        operation = {
-            'operation': 'cache',
-            'type': 'get',
-            'key': key
-        }
-
-        response = self._send_operation(operation)
-        if response.get('status') == 'success':
-            return response['data']
-        return None
-
-    def wait_for_connection(self, timeout: float = 30.0) -> bool:
-        """Wait for WebSocket connection to be established"""
-        start_time = time.time()
-        while not self._closing and not self.connected:
-            if time.time() - start_time > timeout:
-                print("Connection timeout exceeded")
-                return False
-            time.sleep(0.1)
-        return self.connected
-
-    def is_connected(self) -> bool:
-        """Check if WebSocket connection is active"""
-        return self.connected and not self._closing
-
-    def get_connection_status(self) -> Dict[str, Any]:
-        """Get detailed connection status"""
-        return {
-            "connected": self.connected,
-            "closing": self._closing,
-            "error_count": self.error_count,
-            "url": self.url,
-            "last_error_time": self.last_error_time,
-            "loaded_models": list(self.resource_monitor['loaded_models'])
-        }
-
-    def start_inference(self, model_name: str, input_data: np.ndarray) -> Optional[Dict[str, Any]]:
-        """Start inference with a loaded model"""
-        try:
-            if not self.is_model_loaded(model_name):
-                print(f"Model {model_name} not loaded. Please load the model first.")
-                return None
-
-            operation = {
-                'operation': 'inference',
-                'type': 'run',
-                'model_name': model_name,
-                'input_data': input_data.tolist() if isinstance(input_data, np.ndarray) else input_data
-            }
-
-            response = self._send_operation(operation)
-            if response.get('status') == 'success':
-                return {
-                    'output': np.array(response['output']) if 'output' in response else None,
-                    'metrics': response.get('metrics', {}),
-                    'model_info': self.model_registry.get(model_name, {})
-                }
-            else:
-                print(f"Inference failed: {response.get('message', 'Unknown error')}")
-                return None
-        except Exception as e:
-            print(f"Error during inference: {str(e)}")
-            return None
-
-    def close(self):
-        """Close WebSocket connection and cleanup resources."""
-        if not self._closing:
-            self._closing = True
-            if self.websocket and self._loop:
-                async def cleanup():
-                    try:
-                        # Clean up registries
-                        with self.lock:
-                            self.tensor_registry.clear()
-                            self.model_registry.clear()
-                            self.resource_monitor['vram_used'] = 0
-                            self.resource_monitor['active_tensors'] = 0
-                            self.resource_monitor['loaded_models'].clear()
-
-                        # Notify server about cleanup
-                        if self.connected:
-                            try:
-                                await self.websocket.send(json.dumps({
-                                    'operation': 'cleanup',
-                                    'type': 'full'
-                                }))
-                            except:
-                                pass
-
-                        await self.websocket.close()
-                    except Exception as e:
-                        print(f"Error during cleanup: {str(e)}")
-                    finally:
-                        self.connected = False
-
-                if self._loop.is_running():
-                    self._loop.create_task(cleanup())
-                else:
-                    asyncio.run(cleanup())
-
-    async def aclose(self):
-        """Asynchronously close WebSocket connection."""
-        if not self._closing:
-            self._closing = True
-            if self.websocket:
-                try:
-                    await self.websocket.close()
-                except:
-                    pass
-                finally:
-                    self.connected = False
-
-    def __del__(self):
-        """Ensure cleanup on deletion."""
-        self.close()
 
+import websockets
+import json
+import numpy as np
+from typing import Dict, Any, Optional, Union
+import threading
+from queue import Queue
+import time
+import asyncio
+import hashlib
+
+class WebSocketGPUStorage:
+    # Singleton instance
+    _instance = None
+    _lock = threading.Lock()
+
+    def __new__(cls, url: str = "wss://factorst-wbs1.hf.space/ws"):
+        with cls._lock:
+            if cls._instance is None:
+                cls._instance = super().__new__(cls)
+                cls._instance._init_singleton(url)
+            return cls._instance
+
+    def _init_singleton(self, url: str):
+        """Initialize the singleton instance"""
+        if hasattr(self, 'initialized'):
+            return
+
+        self.url = url
+        self.websocket = None
+        self.connected = False
+        self.message_queue = Queue()
+        self.response_queues: Dict[str, Queue] = {}
+        self.lock = threading.Lock()
+        self._closing = False
+        self._loop = None
+        self.error_count = 0
+        self.last_error_time = 0
+        self.max_retries = 5
+        self.tensor_registry: Dict[str, Dict[str, Any]] = {}  # Track tensor metadata
+        self.model_registry: Dict[str, Dict[str, Any]] = {}  # Track loaded models
+        self.resource_monitor = {
+            'vram_used': 0,
+            'active_tensors': 0,
+            'loaded_models': set()
+        }
+
+        # Start WebSocket connection in a separate thread
+        self.ws_thread = threading.Thread(target=self._run_websocket_loop, daemon=True)
+        self.ws_thread.start()
+        self.initialized = True
+
+    def __init__(self, url: str = "wss://factorst-wbs1.hf.space/ws"):
+        """This will actually just return the singleton instance"""
+        pass
+
+    def _run_websocket_loop(self):
+        self._loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(self._loop)
+        self._loop.run_until_complete(self._websocket_handler())
+
+    async def _websocket_handler(self):
+        while not self._closing:
+            try:
+                async with websockets.connect(self.url) as websocket:
+                    self.websocket = websocket
+                    self.connected = True
+                    self.error_count = 0  # Reset error count on successful connection
+                    print("Connected to GPU storage server")
+
+                    while True:
+                        # Handle outgoing messages
+                        try:
+                            while not self.message_queue.empty():
+                                msg_id, operation = self.message_queue.get()
+                                await websocket.send(json.dumps(operation))
+
+                                # Wait for response with timeout
+                                try:
+                                    response = await asyncio.wait_for(websocket.recv(), timeout=30)
+                                    response_data = json.loads(response)
+
+                                    # Put response in corresponding queue
+                                    if msg_id in self.response_queues:
+                                        self.response_queues[msg_id].put(response_data)
+                                except asyncio.TimeoutError:
+                                    if msg_id in self.response_queues:
+                                        self.response_queues[msg_id].put({
+                                            "status": "error",
+                                            "message": "Operation timed out"
+                                        })
+                                except Exception as e:
+                                    if msg_id in self.response_queues:
+                                        self.response_queues[msg_id].put({
+                                            "status": "error",
+                                            "message": f"Error processing response: {str(e)}"
+                                        })
+
+                        except Exception as e:
+                            print(f"Error processing message: {str(e)}")
+
+                        # Keep connection alive with heartbeat
+                        try:
+                            await websocket.ping()
+                        except:
+                            break  # Break inner loop on ping failure
+
+                        await asyncio.sleep(0.001)  # 1ms sleep for electron-speed response
+
+            except Exception as e:
+                print(f"WebSocket connection error: {e}")
+                self.connected = False
+                await asyncio.sleep(1)  # Wait before reconnecting
+
+    def _send_operation(self, operation: Dict[str, Any]) -> Dict[str, Any]:
+        if self._closing:
+            return {"status": "error", "message": "WebSocket is closing"}
+
+        if not self.wait_for_connection(timeout=10):
+            return {"status": "error", "message": "Not connected to GPU storage server"}
+
+        msg_id = str(time.time())
+        response_queue = Queue()
+
+        with self.lock:
+            self.response_queues[msg_id] = response_queue
+            self.message_queue.put((msg_id, operation))
+
+        try:
+            # Wait for response with configurable timeout
+            response = response_queue.get(timeout=30)  # Extended timeout for large models
+            if response.get("status") == "error" and "model_size" in operation:
+                # Retry once for model loading operations
+                self.message_queue.put((msg_id, operation))
+                response = response_queue.get(timeout=30)
+        except Exception as e:
+            response = {"status": "error", "message": f"Operation failed: {str(e)}"}
+        finally:
+            with self.lock:
+                if msg_id in self.response_queues:
+                    del self.response_queues[msg_id]
+
+        return response
+
+    def store_tensor(self, tensor_id: str, data: np.ndarray, model_size: Optional[int] = None) -> bool:
+        try:
+            if data is None:
+                raise ValueError("Cannot store None tensor")
+
+            # Calculate tensor metadata
+            tensor_shape = data.shape
+            tensor_dtype = str(data.dtype)
+            tensor_size = data.nbytes
+
+            operation = {
+                'operation': 'vram',
+                'type': 'write',
+                'block_id': tensor_id,
+                'data': data.tolist(),
+                'model_size': model_size if model_size is not None else -1,  # -1 indicates unlimited
+                'metadata': {
+                    'shape': tensor_shape,
+                    'dtype': tensor_dtype,
+                    'size': tensor_size,
+                    'timestamp': time.time()
+                }
+            }
+
+            response = self._send_operation(operation)
+            if response.get('status') == 'success':
+                # Update tensor registry
+                with self.lock:
+                    self.tensor_registry[tensor_id] = {
+                        'shape': tensor_shape,
+                        'dtype': tensor_dtype,
+                        'size': tensor_size,
+                        'timestamp': time.time()
+                    }
+                    self.resource_monitor['vram_used'] += tensor_size
+                    self.resource_monitor['active_tensors'] += 1
+                return True
+            else:
+                print(f"Failed to store tensor {tensor_id}: {response.get('message', 'Unknown error')}")
+                return False
+        except Exception as e:
+            print(f"Error storing tensor {tensor_id}: {str(e)}")
+            return False
+
+    def load_tensor(self, tensor_id: str) -> Optional[np.ndarray]:
+        try:
+            # Check tensor registry first
+            if tensor_id not in self.tensor_registry:
+                print(f"Tensor {tensor_id} not registered in VRAM")
+                return None
+
+            operation = {
+                'operation': 'vram',
+                'type': 'read',
+                'block_id': tensor_id,
+                'expected_metadata': self.tensor_registry.get(tensor_id, {})
+            }
+
+            response = self._send_operation(operation)
+            if response.get('status') == 'success':
+                data = response.get('data')
+                if data is None:
+                    print(f"No data found for tensor {tensor_id}")
+                    return None
+
+                # Verify tensor metadata
+                metadata = response.get('metadata', {})
+                expected_metadata = self.tensor_registry.get(tensor_id, {})
+                if metadata.get('shape') != expected_metadata.get('shape'):
+                    print(f"Warning: Tensor {tensor_id} shape mismatch")
+
+                try:
+                    # Convert to numpy array with correct dtype
+                    arr = np.array(data, dtype=np.dtype(expected_metadata.get('dtype', 'float32')))
+                    if arr.shape != expected_metadata.get('shape'):
+                        arr = arr.reshape(expected_metadata.get('shape'))
+                    return arr
+                except Exception as e:
+                    print(f"Error converting tensor data: {str(e)}")
+                    return None
+            else:
+                print(f"Failed to load tensor {tensor_id}: {response.get('message', 'Unknown error')}")
+                return None
+        except Exception as e:
+            print(f"Error loading tensor {tensor_id}: {str(e)}")
+            return None
+
+    def store_state(self, component: str, state_id: str, state_data: Dict[str, Any]) -> bool:
+        try:
+            operation = {
+                'operation': 'state',
+                'type': 'save',
+                'component': component,
+                'state_id': state_id,
+                'data': state_data,
+                'timestamp': time.time()
+            }
+
+            response = self._send_operation(operation)
+            if response.get('status') != 'success':
+                print(f"Failed to store state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
+                return False
+            return True
+        except Exception as e:
+            print(f"Error storing state for {component}/{state_id}: {str(e)}")
+            return False
+
+    def load_state(self, component: str, state_id: str) -> Optional[Dict[str, Any]]:
+        try:
+            operation = {
+                'operation': 'state',
+                'type': 'load',
+                'component': component,
+                'state_id': state_id
+            }
+
+            response = self._send_operation(operation)
+            if response.get('status') == 'success':
+                data = response.get('data')
+                if data is None:
+                    print(f"No state found for {component}/{state_id}")
+                    return None
+                return data
+            else:
+                print(f"Failed to load state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
+                return None
+        except Exception as e:
+            print(f"Error loading state for {component}/{state_id}: {str(e)}")
+            return None
+
+    def is_model_loaded(self, model_name: str) -> bool:
+        """Check if a model is already loaded in VRAM"""
+        return model_name in self.resource_monitor['loaded_models']
+
+    def load_model(self, model_name: str, model_path: Optional[str] = None, model_data: Optional[Dict] = None) -> bool:
+        """Load a model into VRAM if not already loaded"""
+        try:
+            # Check if model is already loaded
+            if self.is_model_loaded(model_name):
+                print(f"Model {model_name} already loaded in VRAM")
+                return True
+
+            # Calculate model hash if path provided
+            model_hash = None
+            if model_path:
+                model_hash = self._calculate_model_hash(model_path)
+
+            operation = {
+                'operation': 'model',
+                'type': 'load',
+                'model_name': model_name,
+                'model_hash': model_hash,
+                'model_data': model_data
+            }
+
+            response = self._send_operation(operation)
+            if response.get('status') == 'success':
+                with self.lock:
+                    self.model_registry[model_name] = {
+                        'hash': model_hash,
+                        'timestamp': time.time(),
+                        'tensors': response.get('tensor_ids', [])
+                    }
+                    self.resource_monitor['loaded_models'].add(model_name)
+                print(f"Successfully loaded model {model_name}")
+                return True
+            else:
+                print(f"Failed to load model {model_name}: {response.get('message', 'Unknown error')}")
+                return False
+        except Exception as e:
+            print(f"Error loading model {model_name}: {str(e)}")
+            return False
+
+    def _calculate_model_hash(self, model_path: str) -> str:
+        """Calculate SHA256 hash of model file"""
+        try:
+            sha256_hash = hashlib.sha256()
+            with open(model_path, "rb") as f:
+                for byte_block in iter(lambda: f.read(4096), b""):
+                    sha256_hash.update(byte_block)
+            return sha256_hash.hexdigest()
+        except Exception as e:
+            print(f"Error calculating model hash: {str(e)}")
+            return ""
+
+    def cache_data(self, key: str, data: Any) -> bool:
+        operation = {
+            'operation': 'cache',
+            'type': 'set',
+            'key': key,
+            'data': data
+        }
+
+        response = self._send_operation(operation)
+        return response.get('status') == 'success'
+
+    def get_cached_data(self, key: str) -> Optional[Any]:
+        operation = {
+            'operation': 'cache',
+            'type': 'get',
+            'key': key
+        }
+
+        response = self._send_operation(operation)
+        if response.get('status') == 'success':
+            return response['data']
+        return None
+
+    def wait_for_connection(self, timeout: float = 30.0) -> bool:
+        """Wait for WebSocket connection to be established"""
+        start_time = time.time()
+        while not self._closing and not self.connected:
+            if time.time() - start_time > timeout:
+                print("Connection timeout exceeded")
+                return False
+            time.sleep(0.1)
+        return self.connected
+
+    def is_connected(self) -> bool:
+        """Check if WebSocket connection is active"""
+        return self.connected and not self._closing
+
+    def get_connection_status(self) -> Dict[str, Any]:
+        """Get detailed connection status"""
+        return {
+            "connected": self.connected,
+            "closing": self._closing,
+            "error_count": self.error_count,
+            "url": self.url,
+            "last_error_time": self.last_error_time,
+            "loaded_models": list(self.resource_monitor['loaded_models'])
+        }
+
+    def start_inference(self, model_name: str, input_data: np.ndarray) -> Optional[Dict[str, Any]]:
+        """Start inference with a loaded model"""
+        try:
+            if not self.is_model_loaded(model_name):
+                print(f"Model {model_name} not loaded. Please load the model first.")
+                return None
+
+            operation = {
+                'operation': 'inference',
+                'type': 'run',
+                'model_name': model_name,
+                'input_data': input_data.tolist() if isinstance(input_data, np.ndarray) else input_data
+            }
+
+            response = self._send_operation(operation)
+            if response.get('status') == 'success':
+                return {
+                    'output': np.array(response['output']) if 'output' in response else None,
+                    'metrics': response.get('metrics', {}),
+                    'model_info': self.model_registry.get(model_name, {})
+                }
+            else:
+                print(f"Inference failed: {response.get('message', 'Unknown error')}")
+                return None
+        except Exception as e:
+            print(f"Error during inference: {str(e)}")
+            return None
+
+    def close(self):
+        """Close WebSocket connection and cleanup resources."""
+        if not self._closing:
+            self._closing = True
+            if self.websocket and self._loop:
+                async def cleanup():
+                    try:
+                        # Clean up registries
+                        with self.lock:
+                            self.tensor_registry.clear()
+                            self.model_registry.clear()
+                            self.resource_monitor['vram_used'] = 0
+                            self.resource_monitor['active_tensors'] = 0
+                            self.resource_monitor['loaded_models'].clear()
+
+                        # Notify server about cleanup
+                        if self.connected:
+                            try:
+                                await self.websocket.send(json.dumps({
+                                    'operation': 'cleanup',
+                                    'type': 'full'
+                                }))
+                            except:
+                                pass
+
+                        await self.websocket.close()
+                    except Exception as e:
+                        print(f"Error during cleanup: {str(e)}")
+                    finally:
+                        self.connected = False
+
+                if self._loop.is_running():
+                    self._loop.create_task(cleanup())
+                else:
+                    asyncio.run(cleanup())
+
+    async def aclose(self):
+        """Asynchronously close WebSocket connection."""
+        if not self._closing:
+            self._closing = True
+            if self.websocket:
+                try:
+                    await self.websocket.close()
+                except:
+                    pass
+                finally:
+                    self.connected = False
+
+    def __del__(self):
+        """Ensure cleanup on deletion."""
+        self.close()
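For reference, a hedged usage sketch of the rewritten client, assuming the default wss://factorst-wbs1.hf.space/ws endpoint is reachable and speaks the operation protocol above. Note that the rewrite drops custom_json_serializer and calls json.dumps(operation) with no default hook, so any payload you pass (for example to store_state or cache_data) must already be plain-JSON-serializable; the names below all come from the new file except the illustrative "demo" identifiers:

```python
import numpy as np
from websocket_storage import WebSocketGPUStorage

# Every construction returns the same singleton; the first call spawns
# the background WebSocket thread and starts connecting.
storage = WebSocketGPUStorage()

if storage.wait_for_connection(timeout=10):
    # store_tensor serializes via data.tolist() internally, so ndarray
    # inputs are fine here.
    weights = np.random.rand(4, 4).astype(np.float32)
    storage.store_tensor("demo_weights", weights)
    restored = storage.load_tensor("demo_weights")

    # store_state no longer has the old memory-only fallback path, and its
    # payload is JSON-encoded without a custom serializer, so stick to
    # plain dict/list/str/number values (tolist() any arrays yourself).
    storage.store_state("demo", "run_1", {"step": 1, "lr": 1e-3})
    print(storage.get_connection_status())

storage.close()
```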