Factor Studios committed on
Commit
5e36ec1
·
verified ·
1 Parent(s): 34c283d

Update websocket_storage.py

Browse files
Files changed (1) hide show
  1. websocket_storage.py +435 -434
websocket_storage.py CHANGED
@@ -1,434 +1,435 @@
1
- import websockets
2
- import json
3
- import numpy as np
4
- from typing import Dict, Any, Optional, Union
5
- import threading
6
- from queue import Queue
7
- import time
8
-
9
- class WebSocketGPUStorage:
10
- def __init__(self, url: str = "wss://factorst-wbs1.hf.space/ws"): # Default to local WebSocket server
11
- self.url = url
12
- self.websocket = None
13
- self.connected = False
14
- self.message_queue = Queue()
15
- self.response_queues: Dict[str, Queue] = {}
16
- self.lock = threading.Lock()
17
- self._closing = False
18
- self._loop = None
19
- self.error_count = 0
20
- self.last_error_time = 0
21
- self.max_retries = 5
22
- self.tensor_registry: Dict[str, Dict[str, Any]] = {} # Track tensor metadata
23
- self.resource_monitor = {'vram_used': 0, 'active_tensors': 0}
24
- self.model_registry: Dict[str, Dict[str, Any]] = {} # Track loaded models
25
- self.resource_monitor = {
26
- 'vram_used': 0,
27
- 'active_tensors': 0,
28
- 'loaded_models': set()
29
- }
30
-
31
- # Start WebSocket connection in a separate thread
32
- self.ws_thread = threading.Thread(target=self._run_websocket_loop, daemon=True)
33
- self.ws_thread.start()
34
-
35
- def _run_websocket_loop(self):
36
- self._loop = asyncio.new_event_loop()
37
- asyncio.set_event_loop(self._loop)
38
- self._loop.run_until_complete(self._websocket_handler())
39
-
40
- async def _websocket_handler(self):
41
- while not self._closing:
42
- try:
43
- async with websockets.connect(self.url) as websocket:
44
- self.websocket = websocket
45
- self.connected = True
46
- self.error_count = 0 # Reset error count on successful connection
47
- print("Connected to GPU storage server")
48
-
49
- while True:
50
- # Handle outgoing messages
51
- try:
52
- while not self.message_queue.empty():
53
- msg_id, operation = self.message_queue.get()
54
- await websocket.send(json.dumps(operation))
55
-
56
- # Wait for response with timeout
57
- try:
58
- response = await asyncio.wait_for(websocket.recv(), timeout=30)
59
- response_data = json.loads(response)
60
-
61
- # Put response in corresponding queue
62
- if msg_id in self.response_queues:
63
- self.response_queues[msg_id].put(response_data)
64
- except asyncio.TimeoutError:
65
- if msg_id in self.response_queues:
66
- self.response_queues[msg_id].put({
67
- "status": "error",
68
- "message": "Operation timed out"
69
- })
70
- except Exception as e:
71
- if msg_id in self.response_queues:
72
- self.response_queues[msg_id].put({
73
- "status": "error",
74
- "message": f"Error processing response: {str(e)}"
75
- })
76
-
77
- except Exception as e:
78
- print(f"Error processing message: {str(e)}")
79
-
80
- # Keep connection alive with heartbeat
81
- try:
82
- await websocket.ping()
83
- except:
84
- break # Break inner loop on ping failure
85
-
86
- await asyncio.sleep(0.001) # 1ms sleep for electron-speed response
87
-
88
- except Exception as e:
89
- print(f"WebSocket connection error: {e}")
90
- self.connected = False
91
- await asyncio.sleep(1) # Wait before reconnecting
92
-
93
- def _send_operation(self, operation: Dict[str, Any]) -> Dict[str, Any]:
94
- if self._closing:
95
- return {"status": "error", "message": "WebSocket is closing"}
96
-
97
- if not self.wait_for_connection(timeout=10):
98
- return {"status": "error", "message": "Not connected to GPU storage server"}
99
-
100
- msg_id = str(time.time())
101
- response_queue = Queue()
102
-
103
- with self.lock:
104
- self.response_queues[msg_id] = response_queue
105
- self.message_queue.put((msg_id, operation))
106
-
107
- try:
108
- # Wait for response with configurable timeout
109
- response = response_queue.get(timeout=30) # Extended timeout for large models
110
- if response.get("status") == "error" and "model_size" in operation:
111
- # Retry once for model loading operations
112
- self.message_queue.put((msg_id, operation))
113
- response = response_queue.get(timeout=30)
114
- except Exception as e:
115
- response = {"status": "error", "message": f"Operation failed: {str(e)}"}
116
- finally:
117
- with self.lock:
118
- if msg_id in self.response_queues:
119
- del self.response_queues[msg_id]
120
-
121
- return response
122
-
123
- def store_tensor(self, tensor_id: str, data: np.ndarray, model_size: Optional[int] = None) -> bool:
124
- try:
125
- if data is None:
126
- raise ValueError("Cannot store None tensor")
127
-
128
- # Calculate tensor metadata
129
- tensor_shape = data.shape
130
- tensor_dtype = str(data.dtype)
131
- tensor_size = data.nbytes
132
-
133
- operation = {
134
- 'operation': 'vram',
135
- 'type': 'write',
136
- 'block_id': tensor_id,
137
- 'data': data.tolist(),
138
- 'model_size': model_size if model_size is not None else -1, # -1 indicates unlimited
139
- 'metadata': {
140
- 'shape': tensor_shape,
141
- 'dtype': tensor_dtype,
142
- 'size': tensor_size,
143
- 'timestamp': time.time()
144
- }
145
- }
146
-
147
- response = self._send_operation(operation)
148
- if response.get('status') == 'success':
149
- # Update tensor registry
150
- with self.lock:
151
- self.tensor_registry[tensor_id] = {
152
- 'shape': tensor_shape,
153
- 'dtype': tensor_dtype,
154
- 'size': tensor_size,
155
- 'timestamp': time.time()
156
- }
157
- self.resource_monitor['vram_used'] += tensor_size
158
- self.resource_monitor['active_tensors'] += 1
159
- return True
160
- else:
161
- print(f"Failed to store tensor {tensor_id}: {response.get('message', 'Unknown error')}")
162
- return False
163
- except Exception as e:
164
- print(f"Error storing tensor {tensor_id}: {str(e)}")
165
- return False
166
-
167
- def load_tensor(self, tensor_id: str) -> Optional[np.ndarray]:
168
- try:
169
- # Check tensor registry first
170
- if tensor_id not in self.tensor_registry:
171
- print(f"Tensor {tensor_id} not registered in VRAM")
172
- return None
173
-
174
- operation = {
175
- 'operation': 'vram',
176
- 'type': 'read',
177
- 'block_id': tensor_id,
178
- 'expected_metadata': self.tensor_registry.get(tensor_id, {})
179
- }
180
-
181
- response = self._send_operation(operation)
182
- if response.get('status') == 'success':
183
- data = response.get('data')
184
- if data is None:
185
- print(f"No data found for tensor {tensor_id}")
186
- return None
187
-
188
- # Verify tensor metadata
189
- metadata = response.get('metadata', {})
190
- expected_metadata = self.tensor_registry.get(tensor_id, {})
191
- if metadata.get('shape') != expected_metadata.get('shape'):
192
- print(f"Warning: Tensor {tensor_id} shape mismatch")
193
-
194
- try:
195
- # Convert to numpy array with correct dtype
196
- arr = np.array(data, dtype=np.dtype(expected_metadata.get('dtype', 'float32')))
197
- if arr.shape != expected_metadata.get('shape'):
198
- arr = arr.reshape(expected_metadata.get('shape'))
199
- return arr
200
- except Exception as e:
201
- print(f"Error converting tensor data: {str(e)}")
202
- return None
203
- else:
204
- print(f"Failed to load tensor {tensor_id}: {response.get('message', 'Unknown error')}")
205
- return None
206
- except Exception as e:
207
- print(f"Error loading tensor {tensor_id}: {str(e)}")
208
- return None
209
-
210
- def store_state(self, component: str, state_id: str, state_data: Dict[str, Any]) -> bool:
211
- try:
212
- operation = {
213
- 'operation': 'state',
214
- 'type': 'save',
215
- 'component': component,
216
- 'state_id': state_id,
217
- 'data': state_data,
218
- 'timestamp': time.time()
219
- }
220
-
221
- response = self._send_operation(operation)
222
- if response.get('status') != 'success':
223
- print(f"Failed to store state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
224
- return False
225
- return True
226
- except Exception as e:
227
- print(f"Error storing state for {component}/{state_id}: {str(e)}")
228
- return False
229
-
230
- def load_state(self, component: str, state_id: str) -> Optional[Dict[str, Any]]:
231
- try:
232
- operation = {
233
- 'operation': 'state',
234
- 'type': 'load',
235
- 'component': component,
236
- 'state_id': state_id
237
- }
238
-
239
- response = self._send_operation(operation)
240
- if response.get('status') == 'success':
241
- data = response.get('data')
242
- if data is None:
243
- print(f"No state found for {component}/{state_id}")
244
- return None
245
- return data
246
- else:
247
- print(f"Failed to load state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
248
- return None
249
- except Exception as e:
250
- print(f"Error loading state for {component}/{state_id}: {str(e)}")
251
- return None
252
-
253
- def is_model_loaded(self, model_name: str) -> bool:
254
- """Check if a model is already loaded in VRAM"""
255
- return model_name in self.resource_monitor['loaded_models']
256
-
257
- def load_model(self, model_name: str, model_path: Optional[str] = None, model_data: Optional[Dict] = None) -> bool:
258
- """Load a model into VRAM if not already loaded"""
259
- try:
260
- # Check if model is already loaded
261
- if self.is_model_loaded(model_name):
262
- print(f"Model {model_name} already loaded in VRAM")
263
- return True
264
-
265
- # Calculate model hash if path provided
266
- model_hash = None
267
- if model_path:
268
- model_hash = self._calculate_model_hash(model_path)
269
-
270
- operation = {
271
- 'operation': 'model',
272
- 'type': 'load',
273
- 'model_name': model_name,
274
- 'model_hash': model_hash,
275
- 'model_data': model_data
276
- }
277
-
278
- response = self._send_operation(operation)
279
- if response.get('status') == 'success':
280
- with self.lock:
281
- self.model_registry[model_name] = {
282
- 'hash': model_hash,
283
- 'timestamp': time.time(),
284
- 'tensors': response.get('tensor_ids', [])
285
- }
286
- self.resource_monitor['loaded_models'].add(model_name)
287
- print(f"Successfully loaded model {model_name}")
288
- return True
289
- else:
290
- print(f"Failed to load model {model_name}: {response.get('message', 'Unknown error')}")
291
- return False
292
- except Exception as e:
293
- print(f"Error loading model {model_name}: {str(e)}")
294
- return False
295
-
296
- def _calculate_model_hash(self, model_path: str) -> str:
297
- """Calculate SHA256 hash of model file"""
298
- try:
299
- sha256_hash = hashlib.sha256()
300
- with open(model_path, "rb") as f:
301
- for byte_block in iter(lambda: f.read(4096), b""):
302
- sha256_hash.update(byte_block)
303
- return sha256_hash.hexdigest()
304
- except Exception as e:
305
- print(f"Error calculating model hash: {str(e)}")
306
- return ""
307
-
308
- def cache_data(self, key: str, data: Any) -> bool:
309
- operation = {
310
- 'operation': 'cache',
311
- 'type': 'set',
312
- 'key': key,
313
- 'data': data
314
- }
315
-
316
- response = self._send_operation(operation)
317
- return response.get('status') == 'success'
318
-
319
- def get_cached_data(self, key: str) -> Optional[Any]:
320
- operation = {
321
- 'operation': 'cache',
322
- 'type': 'get',
323
- 'key': key
324
- }
325
-
326
- response = self._send_operation(operation)
327
- if response.get('status') == 'success':
328
- return response['data']
329
- return None
330
-
331
- def wait_for_connection(self, timeout: float = 30.0) -> bool:
332
- """Wait for WebSocket connection to be established"""
333
- start_time = time.time()
334
- while not self._closing and not self.connected:
335
- if time.time() - start_time > timeout:
336
- print("Connection timeout exceeded")
337
- return False
338
- time.sleep(0.1)
339
- return self.connected
340
-
341
- def is_connected(self) -> bool:
342
- """Check if WebSocket connection is active"""
343
- return self.connected and not self._closing
344
-
345
- def get_connection_status(self) -> Dict[str, Any]:
346
- """Get detailed connection status"""
347
- return {
348
- "connected": self.connected,
349
- "closing": self._closing,
350
- "error_count": self.error_count,
351
- "url": self.url,
352
- "last_error_time": self.last_error_time,
353
- "loaded_models": list(self.resource_monitor['loaded_models'])
354
- }
355
-
356
- def start_inference(self, model_name: str, input_data: np.ndarray) -> Optional[Dict[str, Any]]:
357
- """Start inference with a loaded model"""
358
- try:
359
- if not self.is_model_loaded(model_name):
360
- print(f"Model {model_name} not loaded. Please load the model first.")
361
- return None
362
-
363
- operation = {
364
- 'operation': 'inference',
365
- 'type': 'run',
366
- 'model_name': model_name,
367
- 'input_data': input_data.tolist() if isinstance(input_data, np.ndarray) else input_data
368
- }
369
-
370
- response = self._send_operation(operation)
371
- if response.get('status') == 'success':
372
- return {
373
- 'output': np.array(response['output']) if 'output' in response else None,
374
- 'metrics': response.get('metrics', {}),
375
- 'model_info': self.model_registry.get(model_name, {})
376
- }
377
- else:
378
- print(f"Inference failed: {response.get('message', 'Unknown error')}")
379
- return None
380
- except Exception as e:
381
- print(f"Error during inference: {str(e)}")
382
- return None
383
-
384
- def close(self):
385
- """Close WebSocket connection and cleanup resources."""
386
- if not self._closing:
387
- self._closing = True
388
- if self.websocket and self._loop:
389
- async def cleanup():
390
- try:
391
- # Clean up registries
392
- with self.lock:
393
- self.tensor_registry.clear()
394
- self.model_registry.clear()
395
- self.resource_monitor['vram_used'] = 0
396
- self.resource_monitor['active_tensors'] = 0
397
- self.resource_monitor['loaded_models'].clear()
398
-
399
- # Notify server about cleanup
400
- if self.connected:
401
- try:
402
- await self.websocket.send(json.dumps({
403
- 'operation': 'cleanup',
404
- 'type': 'full'
405
- }))
406
- except:
407
- pass
408
-
409
- await self.websocket.close()
410
- except Exception as e:
411
- print(f"Error during cleanup: {str(e)}")
412
- finally:
413
- self.connected = False
414
-
415
- if self._loop.is_running():
416
- self._loop.create_task(cleanup())
417
- else:
418
- asyncio.run(cleanup())
419
-
420
- async def aclose(self):
421
- """Asynchronously close WebSocket connection."""
422
- if not self._closing:
423
- self._closing = True
424
- if self.websocket:
425
- try:
426
- await self.websocket.close()
427
- except:
428
- pass
429
- finally:
430
- self.connected = False
431
-
432
- def __del__(self):
433
- """Ensure cleanup on deletion."""
434
- self.close()
 
 
1
import asyncio
import hashlib
import json
import threading
import time
import uuid
from queue import Queue
from typing import Any, Dict, Optional, Union

import numpy as np
import websockets
9
+
10
class WebSocketGPUStorage:
    """Synchronous client for a remote GPU storage service reached over a WebSocket.

    A daemon thread owns an asyncio event loop that maintains the connection,
    drains a thread-safe outgoing queue, and routes each server response back
    to the per-request queue of the caller that issued it. Public methods
    (tensor/state/model/cache operations) block until a response arrives or
    times out, and report failures by printing and returning ``False``/``None``.
    """

    def __init__(self, url: str = "wss://factorst-wbs1.hf.space/ws"):  # Default to local WebSocket server
        """Create the client and start the background connection thread.

        Args:
            url: WebSocket endpoint of the GPU storage server.
        """
        self.url = url
        self.websocket = None
        self.connected = False
        self.message_queue = Queue()  # pending (msg_id, operation) pairs to send
        self.response_queues: Dict[str, Queue] = {}  # msg_id -> queue carrying its response
        self.lock = threading.Lock()  # guards response_queues and the registries
        self._closing = False
        self._loop = None  # event loop owned by the websocket thread
        self.error_count = 0
        self.last_error_time = 0
        self.max_retries = 5
        self.tensor_registry: Dict[str, Dict[str, Any]] = {}  # Track tensor metadata
        self.model_registry: Dict[str, Dict[str, Any]] = {}  # Track loaded models
        # Fix: the original assigned resource_monitor twice; only the complete
        # form (including 'loaded_models') is kept.
        self.resource_monitor = {
            'vram_used': 0,
            'active_tensors': 0,
            'loaded_models': set()
        }

        # Start WebSocket connection in a separate thread
        self.ws_thread = threading.Thread(target=self._run_websocket_loop, daemon=True)
        self.ws_thread.start()

    def _run_websocket_loop(self):
        """Thread target: create a private event loop and run the handler forever."""
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)
        self._loop.run_until_complete(self._websocket_handler())

    async def _websocket_handler(self):
        """Connect, pump outgoing messages, and reconnect on failure until closing."""
        while not self._closing:
            try:
                async with websockets.connect(self.url) as websocket:
                    self.websocket = websocket
                    self.connected = True
                    self.error_count = 0  # Reset error count on successful connection
                    print("Connected to GPU storage server")

                    while True:
                        # Handle outgoing messages
                        try:
                            while not self.message_queue.empty():
                                msg_id, operation = self.message_queue.get()
                                await websocket.send(json.dumps(operation))

                                # Wait for response with timeout
                                try:
                                    response = await asyncio.wait_for(websocket.recv(), timeout=30)
                                    response_data = json.loads(response)

                                    # Put response in corresponding queue
                                    if msg_id in self.response_queues:
                                        self.response_queues[msg_id].put(response_data)
                                except asyncio.TimeoutError:
                                    if msg_id in self.response_queues:
                                        self.response_queues[msg_id].put({
                                            "status": "error",
                                            "message": "Operation timed out"
                                        })
                                except Exception as e:
                                    if msg_id in self.response_queues:
                                        self.response_queues[msg_id].put({
                                            "status": "error",
                                            "message": f"Error processing response: {str(e)}"
                                        })

                        except Exception as e:
                            print(f"Error processing message: {str(e)}")

                        # Keep connection alive with heartbeat
                        try:
                            await websocket.ping()
                        except Exception:
                            # Narrowed from a bare except: let KeyboardInterrupt
                            # and SystemExit propagate.
                            break  # Break inner loop on ping failure

                        await asyncio.sleep(0.001)  # 1ms sleep for electron-speed response

            except Exception as e:
                print(f"WebSocket connection error: {e}")
                self.connected = False
                await asyncio.sleep(1)  # Wait before reconnecting

    def _send_operation(self, operation: Dict[str, Any]) -> Dict[str, Any]:
        """Queue *operation* for the websocket thread and block for its response.

        Returns the server's JSON response dict, or an error-shaped dict
        (``{"status": "error", "message": ...}``) on timeout/disconnect.
        """
        if self._closing:
            return {"status": "error", "message": "WebSocket is closing"}

        if not self.wait_for_connection(timeout=10):
            return {"status": "error", "message": "Not connected to GPU storage server"}

        # Fix: str(time.time()) can collide when two threads send within the
        # same clock tick, cross-wiring responses; a uuid is always unique.
        msg_id = uuid.uuid4().hex
        response_queue = Queue()

        with self.lock:
            self.response_queues[msg_id] = response_queue
        self.message_queue.put((msg_id, operation))

        try:
            # Wait for response with configurable timeout
            response = response_queue.get(timeout=30)  # Extended timeout for large models
            if response.get("status") == "error" and "model_size" in operation:
                # Retry once for model loading operations
                self.message_queue.put((msg_id, operation))
                response = response_queue.get(timeout=30)
        except Exception as e:
            response = {"status": "error", "message": f"Operation failed: {str(e)}"}
        finally:
            with self.lock:
                if msg_id in self.response_queues:
                    del self.response_queues[msg_id]

        return response

    def store_tensor(self, tensor_id: str, data: np.ndarray, model_size: Optional[int] = None) -> bool:
        """Write *data* to server VRAM under *tensor_id*; returns True on success."""
        try:
            if data is None:
                raise ValueError("Cannot store None tensor")

            # Calculate tensor metadata
            tensor_shape = tuple(data.shape)
            tensor_dtype = str(data.dtype)
            tensor_size = data.nbytes

            operation = {
                'operation': 'vram',
                'type': 'write',
                'block_id': tensor_id,
                'data': data.tolist(),
                'model_size': model_size if model_size is not None else -1,  # -1 indicates unlimited
                'metadata': {
                    'shape': tensor_shape,
                    'dtype': tensor_dtype,
                    'size': tensor_size,
                    'timestamp': time.time()
                }
            }

            response = self._send_operation(operation)
            if response.get('status') == 'success':
                # Update tensor registry
                with self.lock:
                    self.tensor_registry[tensor_id] = {
                        'shape': tensor_shape,
                        'dtype': tensor_dtype,
                        'size': tensor_size,
                        'timestamp': time.time()
                    }
                    self.resource_monitor['vram_used'] += tensor_size
                    self.resource_monitor['active_tensors'] += 1
                return True
            else:
                print(f"Failed to store tensor {tensor_id}: {response.get('message', 'Unknown error')}")
                return False
        except Exception as e:
            print(f"Error storing tensor {tensor_id}: {str(e)}")
            return False

    def load_tensor(self, tensor_id: str) -> Optional[np.ndarray]:
        """Read a previously stored tensor back from server VRAM, or None."""
        try:
            # Check tensor registry first
            if tensor_id not in self.tensor_registry:
                print(f"Tensor {tensor_id} not registered in VRAM")
                return None

            operation = {
                'operation': 'vram',
                'type': 'read',
                'block_id': tensor_id,
                'expected_metadata': self.tensor_registry.get(tensor_id, {})
            }

            response = self._send_operation(operation)
            if response.get('status') == 'success':
                data = response.get('data')
                if data is None:
                    print(f"No data found for tensor {tensor_id}")
                    return None

                # Verify tensor metadata. Fix: JSON turns tuples into lists,
                # so both sides are normalized to tuples before comparing;
                # the original warned spuriously on every load.
                metadata = response.get('metadata', {})
                expected_metadata = self.tensor_registry.get(tensor_id, {})
                expected_shape = tuple(expected_metadata.get('shape') or ())
                server_shape = metadata.get('shape')
                if server_shape is not None and tuple(server_shape) != expected_shape:
                    print(f"Warning: Tensor {tensor_id} shape mismatch")

                try:
                    # Convert to numpy array with correct dtype
                    arr = np.array(data, dtype=np.dtype(expected_metadata.get('dtype', 'float32')))
                    if expected_shape and arr.shape != expected_shape:
                        arr = arr.reshape(expected_shape)
                    return arr
                except Exception as e:
                    print(f"Error converting tensor data: {str(e)}")
                    return None
            else:
                print(f"Failed to load tensor {tensor_id}: {response.get('message', 'Unknown error')}")
                return None
        except Exception as e:
            print(f"Error loading tensor {tensor_id}: {str(e)}")
            return None

    def store_state(self, component: str, state_id: str, state_data: Dict[str, Any]) -> bool:
        """Persist a JSON-serializable state blob for *component*/*state_id*."""
        try:
            operation = {
                'operation': 'state',
                'type': 'save',
                'component': component,
                'state_id': state_id,
                'data': state_data,
                'timestamp': time.time()
            }

            response = self._send_operation(operation)
            if response.get('status') != 'success':
                print(f"Failed to store state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
                return False
            return True
        except Exception as e:
            print(f"Error storing state for {component}/{state_id}: {str(e)}")
            return False

    def load_state(self, component: str, state_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a previously stored state blob, or None if absent/failed."""
        try:
            operation = {
                'operation': 'state',
                'type': 'load',
                'component': component,
                'state_id': state_id
            }

            response = self._send_operation(operation)
            if response.get('status') == 'success':
                data = response.get('data')
                if data is None:
                    print(f"No state found for {component}/{state_id}")
                    return None
                return data
            else:
                print(f"Failed to load state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
                return None
        except Exception as e:
            print(f"Error loading state for {component}/{state_id}: {str(e)}")
            return None

    def is_model_loaded(self, model_name: str) -> bool:
        """Check if a model is already loaded in VRAM"""
        return model_name in self.resource_monitor['loaded_models']

    def load_model(self, model_name: str, model_path: Optional[str] = None, model_data: Optional[Dict] = None) -> bool:
        """Load a model into VRAM if not already loaded"""
        try:
            # Check if model is already loaded
            if self.is_model_loaded(model_name):
                print(f"Model {model_name} already loaded in VRAM")
                return True

            # Calculate model hash if path provided
            model_hash = None
            if model_path:
                model_hash = self._calculate_model_hash(model_path)

            operation = {
                'operation': 'model',
                'type': 'load',
                'model_name': model_name,
                'model_hash': model_hash,
                'model_data': model_data
            }

            response = self._send_operation(operation)
            if response.get('status') == 'success':
                with self.lock:
                    self.model_registry[model_name] = {
                        'hash': model_hash,
                        'timestamp': time.time(),
                        'tensors': response.get('tensor_ids', [])
                    }
                    self.resource_monitor['loaded_models'].add(model_name)
                print(f"Successfully loaded model {model_name}")
                return True
            else:
                print(f"Failed to load model {model_name}: {response.get('message', 'Unknown error')}")
                return False
        except Exception as e:
            print(f"Error loading model {model_name}: {str(e)}")
            return False

    def _calculate_model_hash(self, model_path: str) -> str:
        """Calculate SHA256 hash of model file"""
        # Fix: the original referenced hashlib without importing it anywhere
        # in the file, raising NameError on first use.
        try:
            sha256_hash = hashlib.sha256()
            with open(model_path, "rb") as f:
                for byte_block in iter(lambda: f.read(4096), b""):
                    sha256_hash.update(byte_block)
            return sha256_hash.hexdigest()
        except Exception as e:
            print(f"Error calculating model hash: {str(e)}")
            return ""

    def cache_data(self, key: str, data: Any) -> bool:
        """Store *data* under *key* in the server-side cache; True on success."""
        operation = {
            'operation': 'cache',
            'type': 'set',
            'key': key,
            'data': data
        }

        response = self._send_operation(operation)
        return response.get('status') == 'success'

    def get_cached_data(self, key: str) -> Optional[Any]:
        """Fetch cached data for *key*, or None when missing or on error."""
        operation = {
            'operation': 'cache',
            'type': 'get',
            'key': key
        }

        response = self._send_operation(operation)
        if response.get('status') == 'success':
            return response['data']
        return None

    def wait_for_connection(self, timeout: float = 30.0) -> bool:
        """Wait for WebSocket connection to be established"""
        start_time = time.time()
        while not self._closing and not self.connected:
            if time.time() - start_time > timeout:
                print("Connection timeout exceeded")
                return False
            time.sleep(0.1)
        return self.connected

    def is_connected(self) -> bool:
        """Check if WebSocket connection is active"""
        return self.connected and not self._closing

    def get_connection_status(self) -> Dict[str, Any]:
        """Get detailed connection status"""
        return {
            "connected": self.connected,
            "closing": self._closing,
            "error_count": self.error_count,
            "url": self.url,
            "last_error_time": self.last_error_time,
            "loaded_models": list(self.resource_monitor['loaded_models'])
        }

    def start_inference(self, model_name: str, input_data: np.ndarray) -> Optional[Dict[str, Any]]:
        """Start inference with a loaded model"""
        try:
            if not self.is_model_loaded(model_name):
                print(f"Model {model_name} not loaded. Please load the model first.")
                return None

            operation = {
                'operation': 'inference',
                'type': 'run',
                'model_name': model_name,
                'input_data': input_data.tolist() if isinstance(input_data, np.ndarray) else input_data
            }

            response = self._send_operation(operation)
            if response.get('status') == 'success':
                return {
                    'output': np.array(response['output']) if 'output' in response else None,
                    'metrics': response.get('metrics', {}),
                    'model_info': self.model_registry.get(model_name, {})
                }
            else:
                print(f"Inference failed: {response.get('message', 'Unknown error')}")
                return None
        except Exception as e:
            print(f"Error during inference: {str(e)}")
            return None

    def close(self):
        """Close WebSocket connection and cleanup resources."""
        if not self._closing:
            self._closing = True
            if self.websocket and self._loop:
                async def cleanup():
                    try:
                        # Clean up registries
                        with self.lock:
                            self.tensor_registry.clear()
                            self.model_registry.clear()
                            self.resource_monitor['vram_used'] = 0
                            self.resource_monitor['active_tensors'] = 0
                            self.resource_monitor['loaded_models'].clear()

                        # Notify server about cleanup
                        if self.connected:
                            try:
                                await self.websocket.send(json.dumps({
                                    'operation': 'cleanup',
                                    'type': 'full'
                                }))
                            except Exception:
                                pass  # best-effort notification only

                        await self.websocket.close()
                    except Exception as e:
                        print(f"Error during cleanup: {str(e)}")
                    finally:
                        self.connected = False

                if self._loop.is_running():
                    # Hand the coroutine to the loop owned by the ws thread.
                    self._loop.create_task(cleanup())
                else:
                    asyncio.run(cleanup())

    async def aclose(self):
        """Asynchronously close WebSocket connection."""
        if not self._closing:
            self._closing = True
            if self.websocket:
                try:
                    await self.websocket.close()
                except Exception:
                    pass
                finally:
                    self.connected = False

    def __del__(self):
        """Ensure cleanup on deletion."""
        # Fix: guard against exceptions during interpreter shutdown, when
        # module globals referenced by close() may already be torn down.
        try:
            self.close()
        except Exception:
            pass