Factor Studios committed on
Commit
d1c659d
·
verified ·
1 Parent(s): 16d64f1

Update websocket_storage.py

Browse files
Files changed (1) hide show
  1. websocket_storage.py +434 -434
websocket_storage.py CHANGED
@@ -1,434 +1,434 @@
1
- import websockets
2
- import json
3
- import numpy as np
4
- from typing import Dict, Any, Optional, Union
5
- import threading
6
- from queue import Queue
7
- import time
8
-
9
- class WebSocketGPUStorage:
10
- def __init__(self, url: str = "wss://factorst-wbs1.hf.space/ws"): # Default to local WebSocket server
11
- self.url = url
12
- self.websocket = None
13
- self.connected = False
14
- self.message_queue = Queue()
15
- self.response_queues: Dict[str, Queue] = {}
16
- self.lock = threading.Lock()
17
- self._closing = False
18
- self._loop = None
19
- self.error_count = 0
20
- self.last_error_time = 0
21
- self.max_retries = 5
22
- self.tensor_registry: Dict[str, Dict[str, Any]] = {} # Track tensor metadata
23
- self.resource_monitor = {'vram_used': 0, 'active_tensors': 0}
24
- self.model_registry: Dict[str, Dict[str, Any]] = {} # Track loaded models
25
- self.resource_monitor = {
26
- 'vram_used': 0,
27
- 'active_tensors': 0,
28
- 'loaded_models': set()
29
- }
30
-
31
- # Start WebSocket connection in a separate thread
32
- self.ws_thread = threading.Thread(target=self._run_websocket_loop, daemon=True)
33
- self.ws_thread.start()
34
-
35
- def _run_websocket_loop(self):
36
- self._loop = asyncio.new_event_loop()
37
- asyncio.set_event_loop(self._loop)
38
- self._loop.run_until_complete(self._websocket_handler())
39
-
40
- async def _websocket_handler(self):
41
- while not self._closing:
42
- try:
43
- async with websockets.connect(self.url) as websocket:
44
- self.websocket = websocket
45
- self.connected = True
46
- self.error_count = 0 # Reset error count on successful connection
47
- print("Connected to GPU storage server")
48
-
49
- while True:
50
- # Handle outgoing messages
51
- try:
52
- while not self.message_queue.empty():
53
- msg_id, operation = self.message_queue.get()
54
- await websocket.send(json.dumps(operation))
55
-
56
- # Wait for response with timeout
57
- try:
58
- response = await asyncio.wait_for(websocket.recv(), timeout=30)
59
- response_data = json.loads(response)
60
-
61
- # Put response in corresponding queue
62
- if msg_id in self.response_queues:
63
- self.response_queues[msg_id].put(response_data)
64
- except asyncio.TimeoutError:
65
- if msg_id in self.response_queues:
66
- self.response_queues[msg_id].put({
67
- "status": "error",
68
- "message": "Operation timed out"
69
- })
70
- except Exception as e:
71
- if msg_id in self.response_queues:
72
- self.response_queues[msg_id].put({
73
- "status": "error",
74
- "message": f"Error processing response: {str(e)}"
75
- })
76
-
77
- except Exception as e:
78
- print(f"Error processing message: {str(e)}")
79
-
80
- # Keep connection alive with heartbeat
81
- try:
82
- await websocket.ping()
83
- except:
84
- break # Break inner loop on ping failure
85
-
86
- await asyncio.sleep(0.001) # 1ms sleep for electron-speed response
87
-
88
- except Exception as e:
89
- print(f"WebSocket connection error: {e}")
90
- self.connected = False
91
- await asyncio.sleep(1) # Wait before reconnecting
92
-
93
- def _send_operation(self, operation: Dict[str, Any]) -> Dict[str, Any]:
94
- if self._closing:
95
- return {"status": "error", "message": "WebSocket is closing"}
96
-
97
- if not self.wait_for_connection(timeout=10):
98
- return {"status": "error", "message": "Not connected to GPU storage server"}
99
-
100
- msg_id = str(time.time())
101
- response_queue = Queue()
102
-
103
- with self.lock:
104
- self.response_queues[msg_id] = response_queue
105
- self.message_queue.put((msg_id, operation))
106
-
107
- try:
108
- # Wait for response with configurable timeout
109
- response = response_queue.get(timeout=30) # Extended timeout for large models
110
- if response.get("status") == "error" and "model_size" in operation:
111
- # Retry once for model loading operations
112
- self.message_queue.put((msg_id, operation))
113
- response = response_queue.get(timeout=30)
114
- except Exception as e:
115
- response = {"status": "error", "message": f"Operation failed: {str(e)}"}
116
- finally:
117
- with self.lock:
118
- if msg_id in self.response_queues:
119
- del self.response_queues[msg_id]
120
-
121
- return response
122
-
123
- def store_tensor(self, tensor_id: str, data: np.ndarray, model_size: Optional[int] = None) -> bool:
124
- try:
125
- if data is None:
126
- raise ValueError("Cannot store None tensor")
127
-
128
- # Calculate tensor metadata
129
- tensor_shape = data.shape
130
- tensor_dtype = str(data.dtype)
131
- tensor_size = data.nbytes
132
-
133
- operation = {
134
- 'operation': 'vram',
135
- 'type': 'write',
136
- 'block_id': tensor_id,
137
- 'data': data.tolist(),
138
- 'model_size': model_size if model_size is not None else -1, # -1 indicates unlimited
139
- 'metadata': {
140
- 'shape': tensor_shape,
141
- 'dtype': tensor_dtype,
142
- 'size': tensor_size,
143
- 'timestamp': time.time()
144
- }
145
- }
146
-
147
- response = self._send_operation(operation)
148
- if response.get('status') == 'success':
149
- # Update tensor registry
150
- with self.lock:
151
- self.tensor_registry[tensor_id] = {
152
- 'shape': tensor_shape,
153
- 'dtype': tensor_dtype,
154
- 'size': tensor_size,
155
- 'timestamp': time.time()
156
- }
157
- self.resource_monitor['vram_used'] += tensor_size
158
- self.resource_monitor['active_tensors'] += 1
159
- return True
160
- else:
161
- print(f"Failed to store tensor {tensor_id}: {response.get('message', 'Unknown error')}")
162
- return False
163
- except Exception as e:
164
- print(f"Error storing tensor {tensor_id}: {str(e)}")
165
- return False
166
-
167
- def load_tensor(self, tensor_id: str) -> Optional[np.ndarray]:
168
- try:
169
- # Check tensor registry first
170
- if tensor_id not in self.tensor_registry:
171
- print(f"Tensor {tensor_id} not registered in VRAM")
172
- return None
173
-
174
- operation = {
175
- 'operation': 'vram',
176
- 'type': 'read',
177
- 'block_id': tensor_id,
178
- 'expected_metadata': self.tensor_registry.get(tensor_id, {})
179
- }
180
-
181
- response = self._send_operation(operation)
182
- if response.get('status') == 'success':
183
- data = response.get('data')
184
- if data is None:
185
- print(f"No data found for tensor {tensor_id}")
186
- return None
187
-
188
- # Verify tensor metadata
189
- metadata = response.get('metadata', {})
190
- expected_metadata = self.tensor_registry.get(tensor_id, {})
191
- if metadata.get('shape') != expected_metadata.get('shape'):
192
- print(f"Warning: Tensor {tensor_id} shape mismatch")
193
-
194
- try:
195
- # Convert to numpy array with correct dtype
196
- arr = np.array(data, dtype=np.dtype(expected_metadata.get('dtype', 'float32')))
197
- if arr.shape != expected_metadata.get('shape'):
198
- arr = arr.reshape(expected_metadata.get('shape'))
199
- return arr
200
- except Exception as e:
201
- print(f"Error converting tensor data: {str(e)}")
202
- return None
203
- else:
204
- print(f"Failed to load tensor {tensor_id}: {response.get('message', 'Unknown error')}")
205
- return None
206
- except Exception as e:
207
- print(f"Error loading tensor {tensor_id}: {str(e)}")
208
- return None
209
-
210
- def store_state(self, component: str, state_id: str, state_data: Dict[str, Any]) -> bool:
211
- try:
212
- operation = {
213
- 'operation': 'state',
214
- 'type': 'save',
215
- 'component': component,
216
- 'state_id': state_id,
217
- 'data': state_data,
218
- 'timestamp': time.time()
219
- }
220
-
221
- response = self._send_operation(operation)
222
- if response.get('status') != 'success':
223
- print(f"Failed to store state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
224
- return False
225
- return True
226
- except Exception as e:
227
- print(f"Error storing state for {component}/{state_id}: {str(e)}")
228
- return False
229
-
230
- def load_state(self, component: str, state_id: str) -> Optional[Dict[str, Any]]:
231
- try:
232
- operation = {
233
- 'operation': 'state',
234
- 'type': 'load',
235
- 'component': component,
236
- 'state_id': state_id
237
- }
238
-
239
- response = self._send_operation(operation)
240
- if response.get('status') == 'success':
241
- data = response.get('data')
242
- if data is None:
243
- print(f"No state found for {component}/{state_id}")
244
- return None
245
- return data
246
- else:
247
- print(f"Failed to load state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
248
- return None
249
- except Exception as e:
250
- print(f"Error loading state for {component}/{state_id}: {str(e)}")
251
- return None
252
-
253
- def is_model_loaded(self, model_name: str) -> bool:
254
- """Check if a model is already loaded in VRAM"""
255
- return model_name in self.resource_monitor['loaded_models']
256
-
257
- def load_model(self, model_name: str, model_path: Optional[str] = None, model_data: Optional[Dict] = None) -> bool:
258
- """Load a model into VRAM if not already loaded"""
259
- try:
260
- # Check if model is already loaded
261
- if self.is_model_loaded(model_name):
262
- print(f"Model {model_name} already loaded in VRAM")
263
- return True
264
-
265
- # Calculate model hash if path provided
266
- model_hash = None
267
- if model_path:
268
- model_hash = self._calculate_model_hash(model_path)
269
-
270
- operation = {
271
- 'operation': 'model',
272
- 'type': 'load',
273
- 'model_name': model_name,
274
- 'model_hash': model_hash,
275
- 'model_data': model_data
276
- }
277
-
278
- response = self._send_operation(operation)
279
- if response.get('status') == 'success':
280
- with self.lock:
281
- self.model_registry[model_name] = {
282
- 'hash': model_hash,
283
- 'timestamp': time.time(),
284
- 'tensors': response.get('tensor_ids', [])
285
- }
286
- self.resource_monitor['loaded_models'].add(model_name)
287
- print(f"Successfully loaded model {model_name}")
288
- return True
289
- else:
290
- print(f"Failed to load model {model_name}: {response.get('message', 'Unknown error')}")
291
- return False
292
- except Exception as e:
293
- print(f"Error loading model {model_name}: {str(e)}")
294
- return False
295
-
296
- def _calculate_model_hash(self, model_path: str) -> str:
297
- """Calculate SHA256 hash of model file"""
298
- try:
299
- sha256_hash = hashlib.sha256()
300
- with open(model_path, "rb") as f:
301
- for byte_block in iter(lambda: f.read(4096), b""):
302
- sha256_hash.update(byte_block)
303
- return sha256_hash.hexdigest()
304
- except Exception as e:
305
- print(f"Error calculating model hash: {str(e)}")
306
- return ""
307
-
308
- def cache_data(self, key: str, data: Any) -> bool:
309
- operation = {
310
- 'operation': 'cache',
311
- 'type': 'set',
312
- 'key': key,
313
- 'data': data
314
- }
315
-
316
- response = self._send_operation(operation)
317
- return response.get('status') == 'success'
318
-
319
- def get_cached_data(self, key: str) -> Optional[Any]:
320
- operation = {
321
- 'operation': 'cache',
322
- 'type': 'get',
323
- 'key': key
324
- }
325
-
326
- response = self._send_operation(operation)
327
- if response.get('status') == 'success':
328
- return response['data']
329
- return None
330
-
331
- def wait_for_connection(self, timeout: float = 30.0) -> bool:
332
- """Wait for WebSocket connection to be established"""
333
- start_time = time.time()
334
- while not self._closing and not self.connected:
335
- if time.time() - start_time > timeout:
336
- print("Connection timeout exceeded")
337
- return False
338
- time.sleep(0.1)
339
- return self.connected
340
-
341
- def is_connected(self) -> bool:
342
- """Check if WebSocket connection is active"""
343
- return self.connected and not self._closing
344
-
345
- def get_connection_status(self) -> Dict[str, Any]:
346
- """Get detailed connection status"""
347
- return {
348
- "connected": self.connected,
349
- "closing": self._closing,
350
- "error_count": self.error_count,
351
- "url": self.url,
352
- "last_error_time": self.last_error_time,
353
- "loaded_models": list(self.resource_monitor['loaded_models'])
354
- }
355
-
356
- def start_inference(self, model_name: str, input_data: np.ndarray) -> Optional[Dict[str, Any]]:
357
- """Start inference with a loaded model"""
358
- try:
359
- if not self.is_model_loaded(model_name):
360
- print(f"Model {model_name} not loaded. Please load the model first.")
361
- return None
362
-
363
- operation = {
364
- 'operation': 'inference',
365
- 'type': 'run',
366
- 'model_name': model_name,
367
- 'input_data': input_data.tolist() if isinstance(input_data, np.ndarray) else input_data
368
- }
369
-
370
- response = self._send_operation(operation)
371
- if response.get('status') == 'success':
372
- return {
373
- 'output': np.array(response['output']) if 'output' in response else None,
374
- 'metrics': response.get('metrics', {}),
375
- 'model_info': self.model_registry.get(model_name, {})
376
- }
377
- else:
378
- print(f"Inference failed: {response.get('message', 'Unknown error')}")
379
- return None
380
- except Exception as e:
381
- print(f"Error during inference: {str(e)}")
382
- return None
383
-
384
- def close(self):
385
- """Close WebSocket connection and cleanup resources."""
386
- if not self._closing:
387
- self._closing = True
388
- if self.websocket and self._loop:
389
- async def cleanup():
390
- try:
391
- # Clean up registries
392
- with self.lock:
393
- self.tensor_registry.clear()
394
- self.model_registry.clear()
395
- self.resource_monitor['vram_used'] = 0
396
- self.resource_monitor['active_tensors'] = 0
397
- self.resource_monitor['loaded_models'].clear()
398
-
399
- # Notify server about cleanup
400
- if self.connected:
401
- try:
402
- await self.websocket.send(json.dumps({
403
- 'operation': 'cleanup',
404
- 'type': 'full'
405
- }))
406
- except:
407
- pass
408
-
409
- await self.websocket.close()
410
- except Exception as e:
411
- print(f"Error during cleanup: {str(e)}")
412
- finally:
413
- self.connected = False
414
-
415
- if self._loop.is_running():
416
- self._loop.create_task(cleanup())
417
- else:
418
- asyncio.run(cleanup())
419
-
420
- async def aclose(self):
421
- """Asynchronously close WebSocket connection."""
422
- if not self._closing:
423
- self._closing = True
424
- if self.websocket:
425
- try:
426
- await self.websocket.close()
427
- except:
428
- pass
429
- finally:
430
- self.connected = False
431
-
432
- def __del__(self):
433
- """Ensure cleanup on deletion."""
434
- self.close()
 
1
import asyncio
import hashlib
import json
import threading
import time
import uuid
from queue import Queue
from typing import Any, Dict, Optional, Union

import numpy as np
import websockets
9
class WebSocketGPUStorage:
    """Client that proxies tensor/model/state operations to a remote GPU
    storage server over a WebSocket, driven by a background thread."""

    def __init__(self, url: str = "wss://factorst-wbs1.hf.space/ws"):  # Default to local WebSocket server
        """Initialize bookkeeping and start the background connection thread.

        Args:
            url: WebSocket endpoint of the GPU storage server.
        """
        self.url = url
        self.websocket = None            # live connection, owned by the background thread
        self.connected = False
        self.message_queue = Queue()     # outgoing (msg_id, operation) pairs
        self.response_queues: Dict[str, Queue] = {}  # per-request reply mailboxes
        self.lock = threading.Lock()
        self._closing = False
        self._loop = None                # asyncio loop created by the background thread
        self.error_count = 0
        self.last_error_time = 0
        self.max_retries = 5
        self.tensor_registry: Dict[str, Dict[str, Any]] = {}  # Track tensor metadata
        self.model_registry: Dict[str, Dict[str, Any]] = {}   # Track loaded models
        # Fix: resource_monitor was assigned twice; the first two-key dict was
        # dead code, immediately overwritten by this full version.
        self.resource_monitor = {
            'vram_used': 0,
            'active_tensors': 0,
            'loaded_models': set()
        }

        # Start WebSocket connection in a separate thread
        self.ws_thread = threading.Thread(target=self._run_websocket_loop, daemon=True)
        self.ws_thread.start()
34
+
35
+ def _run_websocket_loop(self):
36
+ self._loop = asyncio.new_event_loop()
37
+ asyncio.set_event_loop(self._loop)
38
+ self._loop.run_until_complete(self._websocket_handler())
39
+
40
async def _websocket_handler(self):
    """Background coroutine: keep one connection open, pump queued requests,
    and route each reply into its per-request response queue.

    Loops until close()/aclose() sets ``self._closing``; on any connection
    failure it records the error and reconnects after a short delay.
    """
    while not self._closing:
        try:
            async with websockets.connect(self.url) as websocket:
                self.websocket = websocket
                self.connected = True
                self.error_count = 0  # Reset error count on successful connection
                print("Connected to GPU storage server")

                while True:
                    # Handle outgoing messages
                    try:
                        while not self.message_queue.empty():
                            msg_id, operation = self.message_queue.get()
                            await websocket.send(json.dumps(operation))

                            # Wait for response with timeout
                            try:
                                response = await asyncio.wait_for(websocket.recv(), timeout=30)
                                response_data = json.loads(response)

                                # Put response in corresponding queue
                                if msg_id in self.response_queues:
                                    self.response_queues[msg_id].put(response_data)
                            except asyncio.TimeoutError:
                                if msg_id in self.response_queues:
                                    self.response_queues[msg_id].put({
                                        "status": "error",
                                        "message": "Operation timed out"
                                    })
                            except Exception as e:
                                if msg_id in self.response_queues:
                                    self.response_queues[msg_id].put({
                                        "status": "error",
                                        "message": f"Error processing response: {str(e)}"
                                    })

                    except Exception as e:
                        print(f"Error processing message: {str(e)}")

                    # Keep connection alive with heartbeat.
                    # Fix: was a bare `except:` that also swallowed
                    # CancelledError/SystemExit; narrowed to Exception.
                    try:
                        await websocket.ping()
                    except Exception:
                        break  # Break inner loop on ping failure

                    await asyncio.sleep(0.001)  # short poll interval between queue sweeps

        except Exception as e:
            # Fix: record the failure so get_connection_status() reflects it
            # (error_count/last_error_time were previously never updated).
            self.error_count += 1
            self.last_error_time = time.time()
            print(f"WebSocket connection error: {e}")
            self.connected = False
            await asyncio.sleep(1)  # Wait before reconnecting
92
+
93
+ def _send_operation(self, operation: Dict[str, Any]) -> Dict[str, Any]:
94
+ if self._closing:
95
+ return {"status": "error", "message": "WebSocket is closing"}
96
+
97
+ if not self.wait_for_connection(timeout=10):
98
+ return {"status": "error", "message": "Not connected to GPU storage server"}
99
+
100
+ msg_id = str(time.time())
101
+ response_queue = Queue()
102
+
103
+ with self.lock:
104
+ self.response_queues[msg_id] = response_queue
105
+ self.message_queue.put((msg_id, operation))
106
+
107
+ try:
108
+ # Wait for response with configurable timeout
109
+ response = response_queue.get(timeout=30) # Extended timeout for large models
110
+ if response.get("status") == "error" and "model_size" in operation:
111
+ # Retry once for model loading operations
112
+ self.message_queue.put((msg_id, operation))
113
+ response = response_queue.get(timeout=30)
114
+ except Exception as e:
115
+ response = {"status": "error", "message": f"Operation failed: {str(e)}"}
116
+ finally:
117
+ with self.lock:
118
+ if msg_id in self.response_queues:
119
+ del self.response_queues[msg_id]
120
+
121
+ return response
122
+
123
def store_tensor(self, tensor_id: str, data: np.ndarray, model_size: Optional[int] = None) -> bool:
    """Upload *data* to server VRAM under *tensor_id*.

    Args:
        tensor_id: key the tensor is stored under.
        data: numpy array to store; must not be None.
        model_size: optional size hint forwarded to the server (-1 = unlimited).

    Returns:
        True on success; False on any failure (logged to stdout).
    """
    try:
        if data is None:
            raise ValueError("Cannot store None tensor")

        # Fix: JSON has no tuple type, so a tuple shape round-trips back as a
        # list and load_tensor()'s metadata comparison always warned. Record
        # the shape as a list on both sides.
        tensor_shape = list(data.shape)
        tensor_dtype = str(data.dtype)
        tensor_size = data.nbytes

        operation = {
            'operation': 'vram',
            'type': 'write',
            'block_id': tensor_id,
            'data': data.tolist(),
            'model_size': model_size if model_size is not None else -1,  # -1 indicates unlimited
            'metadata': {
                'shape': tensor_shape,
                'dtype': tensor_dtype,
                'size': tensor_size,
                'timestamp': time.time()
            }
        }

        response = self._send_operation(operation)
        if response.get('status') == 'success':
            # Update tensor registry and local resource accounting
            with self.lock:
                self.tensor_registry[tensor_id] = {
                    'shape': tensor_shape,
                    'dtype': tensor_dtype,
                    'size': tensor_size,
                    'timestamp': time.time()
                }
                self.resource_monitor['vram_used'] += tensor_size
                self.resource_monitor['active_tensors'] += 1
            return True
        else:
            print(f"Failed to store tensor {tensor_id}: {response.get('message', 'Unknown error')}")
            return False
    except Exception as e:
        print(f"Error storing tensor {tensor_id}: {str(e)}")
        return False
166
+
167
def load_tensor(self, tensor_id: str) -> Optional[np.ndarray]:
    """Fetch a tensor previously stored via store_tensor().

    Returns:
        The tensor as a numpy array, or None if it is unknown locally,
        missing on the server, or conversion fails (all logged to stdout).
    """
    try:
        # Check the local registry first: only tensors this client stored
        # have the metadata needed to reconstruct the array.
        if tensor_id not in self.tensor_registry:
            print(f"Tensor {tensor_id} not registered in VRAM")
            return None

        operation = {
            'operation': 'vram',
            'type': 'read',
            'block_id': tensor_id,
            'expected_metadata': self.tensor_registry.get(tensor_id, {})
        }

        response = self._send_operation(operation)
        if response.get('status') != 'success':
            print(f"Failed to load tensor {tensor_id}: {response.get('message', 'Unknown error')}")
            return None

        data = response.get('data')
        if data is None:
            print(f"No data found for tensor {tensor_id}")
            return None

        # Verify tensor metadata. Fix: normalize both sides to tuples —
        # the server echoes shapes as JSON lists, while older registries
        # may hold tuples, which made `!=` always warn.
        metadata = response.get('metadata', {})
        expected_metadata = self.tensor_registry.get(tensor_id, {})
        expected_shape = tuple(expected_metadata.get('shape') or ())
        if tuple(metadata.get('shape') or ()) != expected_shape:
            print(f"Warning: Tensor {tensor_id} shape mismatch")

        try:
            # Convert to numpy array with the dtype recorded at store time
            arr = np.array(data, dtype=np.dtype(expected_metadata.get('dtype', 'float32')))
            if expected_shape and arr.shape != expected_shape:
                arr = arr.reshape(expected_shape)
            return arr
        except Exception as e:
            print(f"Error converting tensor data: {str(e)}")
            return None
    except Exception as e:
        print(f"Error loading tensor {tensor_id}: {str(e)}")
        return None
209
+
210
def store_state(self, component: str, state_id: str, state_data: Dict[str, Any]) -> bool:
    """Persist *state_data* on the server under component/state_id.

    Returns True on success, False on failure (logged to stdout).
    """
    try:
        payload = {
            'operation': 'state',
            'type': 'save',
            'component': component,
            'state_id': state_id,
            'data': state_data,
            'timestamp': time.time(),
        }
        reply = self._send_operation(payload)
        if reply.get('status') == 'success':
            return True
        print(f"Failed to store state for {component}/{state_id}: {reply.get('message', 'Unknown error')}")
        return False
    except Exception as e:
        print(f"Error storing state for {component}/{state_id}: {str(e)}")
        return False
229
+
230
def load_state(self, component: str, state_id: str) -> Optional[Dict[str, Any]]:
    """Fetch state previously saved with store_state().

    Returns the stored dict, or None when missing or on failure
    (both cases are logged to stdout).
    """
    try:
        reply = self._send_operation({
            'operation': 'state',
            'type': 'load',
            'component': component,
            'state_id': state_id,
        })
        if reply.get('status') != 'success':
            print(f"Failed to load state for {component}/{state_id}: {reply.get('message', 'Unknown error')}")
            return None
        data = reply.get('data')
        if data is None:
            print(f"No state found for {component}/{state_id}")
        return data
    except Exception as e:
        print(f"Error loading state for {component}/{state_id}: {str(e)}")
        return None
252
+
253
def is_model_loaded(self, model_name: str) -> bool:
    """Check if a model is already loaded in VRAM"""
    loaded = self.resource_monitor['loaded_models']
    return model_name in loaded
256
+
257
def load_model(self, model_name: str, model_path: Optional[str] = None, model_data: Optional[Dict] = None) -> bool:
    """Load a model into VRAM if not already loaded"""
    try:
        # Already resident? Nothing to do.
        if self.is_model_loaded(model_name):
            print(f"Model {model_name} already loaded in VRAM")
            return True

        # Hash the model file when a path is supplied (for cache keying).
        model_hash = self._calculate_model_hash(model_path) if model_path else None

        reply = self._send_operation({
            'operation': 'model',
            'type': 'load',
            'model_name': model_name,
            'model_hash': model_hash,
            'model_data': model_data,
        })
        if reply.get('status') != 'success':
            print(f"Failed to load model {model_name}: {reply.get('message', 'Unknown error')}")
            return False

        # Record the model locally so is_model_loaded() answers True.
        with self.lock:
            self.model_registry[model_name] = {
                'hash': model_hash,
                'timestamp': time.time(),
                'tensors': reply.get('tensor_ids', []),
            }
            self.resource_monitor['loaded_models'].add(model_name)
        print(f"Successfully loaded model {model_name}")
        return True
    except Exception as e:
        print(f"Error loading model {model_name}: {str(e)}")
        return False
295
+
296
+ def _calculate_model_hash(self, model_path: str) -> str:
297
+ """Calculate SHA256 hash of model file"""
298
+ try:
299
+ sha256_hash = hashlib.sha256()
300
+ with open(model_path, "rb") as f:
301
+ for byte_block in iter(lambda: f.read(4096), b""):
302
+ sha256_hash.update(byte_block)
303
+ return sha256_hash.hexdigest()
304
+ except Exception as e:
305
+ print(f"Error calculating model hash: {str(e)}")
306
+ return ""
307
+
308
def cache_data(self, key: str, data: Any) -> bool:
    """Store *data* in the server-side cache under *key*; True on success."""
    reply = self._send_operation({
        'operation': 'cache',
        'type': 'set',
        'key': key,
        'data': data,
    })
    return reply.get('status') == 'success'
318
+
319
def get_cached_data(self, key: str) -> Optional[Any]:
    """Fetch *key* from the server-side cache.

    Returns:
        The cached value, or None on a cache miss or failed request.
        Fix: uses ``.get('data')`` so a success reply that lacks a
        'data' field degrades to None instead of raising KeyError.
    """
    operation = {
        'operation': 'cache',
        'type': 'get',
        'key': key
    }

    response = self._send_operation(operation)
    if response.get('status') == 'success':
        return response.get('data')
    return None
330
+
331
def wait_for_connection(self, timeout: float = 30.0) -> bool:
    """Block until the background thread connects or *timeout* seconds pass.

    Returns the final value of ``self.connected``.
    """
    deadline = time.time() + timeout
    while not self._closing and not self.connected:
        if time.time() > deadline:
            print("Connection timeout exceeded")
            return False
        time.sleep(0.1)
    return self.connected
340
+
341
def is_connected(self) -> bool:
    """Check if WebSocket connection is active"""
    if self._closing:
        return False
    return self.connected
344
+
345
def get_connection_status(self) -> Dict[str, Any]:
    """Get detailed connection status as a plain dict snapshot."""
    status = {
        "connected": self.connected,
        "closing": self._closing,
        "error_count": self.error_count,
        "url": self.url,
        "last_error_time": self.last_error_time,
    }
    # Copy the set into a list so callers can't mutate internal state.
    status["loaded_models"] = list(self.resource_monitor['loaded_models'])
    return status
355
+
356
def start_inference(self, model_name: str, input_data: np.ndarray) -> Optional[Dict[str, Any]]:
    """Start inference with a loaded model"""
    try:
        # Guard: the model must already be resident on the server.
        if not self.is_model_loaded(model_name):
            print(f"Model {model_name} not loaded. Please load the model first.")
            return None

        payload = input_data.tolist() if isinstance(input_data, np.ndarray) else input_data
        reply = self._send_operation({
            'operation': 'inference',
            'type': 'run',
            'model_name': model_name,
            'input_data': payload,
        })
        if reply.get('status') != 'success':
            print(f"Inference failed: {reply.get('message', 'Unknown error')}")
            return None

        output = np.array(reply['output']) if 'output' in reply else None
        return {
            'output': output,
            'metrics': reply.get('metrics', {}),
            'model_info': self.model_registry.get(model_name, {}),
        }
    except Exception as e:
        print(f"Error during inference: {str(e)}")
        return None
383
+
384
def close(self):
    """Close WebSocket connection and cleanup resources."""
    if self._closing:
        return
    self._closing = True
    if not (self.websocket and self._loop):
        return

    async def _teardown():
        try:
            # Reset all client-side bookkeeping under the lock.
            with self.lock:
                self.tensor_registry.clear()
                self.model_registry.clear()
                self.resource_monitor['vram_used'] = 0
                self.resource_monitor['active_tensors'] = 0
                self.resource_monitor['loaded_models'].clear()

            # Best-effort: ask the server to release everything too.
            if self.connected:
                try:
                    await self.websocket.send(json.dumps({
                        'operation': 'cleanup',
                        'type': 'full'
                    }))
                except:
                    pass

            await self.websocket.close()
        except Exception as e:
            print(f"Error during cleanup: {str(e)}")
        finally:
            self.connected = False

    # Schedule on the background loop when it is alive; otherwise run inline.
    if self._loop.is_running():
        self._loop.create_task(_teardown())
    else:
        asyncio.run(_teardown())
419
+
420
async def aclose(self):
    """Asynchronously close WebSocket connection."""
    if self._closing:
        return
    self._closing = True
    if self.websocket:
        try:
            await self.websocket.close()
        except:
            pass
        finally:
            self.connected = False
431
+
432
+ def __del__(self):
433
+ """Ensure cleanup on deletion."""
434
+ self.close()