Factor Studios committed on
Commit
43464e3
·
verified ·
1 Parent(s): f55c75f

Upload 36 files

Browse files
ai.py CHANGED
@@ -1,8 +1,13 @@
 
1
  import numpy as np
2
  import time
3
  from typing import Dict, Any, Optional, Tuple, Union, List
4
  from enum import Enum
5
- from tensor_core import TensorCoreArray
 
 
 
 
6
 
7
  class VectorOperation(Enum):
8
  """Enumeration of supported vector operations."""
@@ -17,33 +22,23 @@ class VectorOperation(Enum):
17
 
18
 
19
  class AIAccelerator:
20
- """
21
- AI Accelerator that simulates GPU-based AI computations.
22
-
23
- This class leverages NumPy's optimized operations to simulate the parallel
24
- processing capabilities of the vGPU for AI workloads.
25
- """
26
-
27
  def __init__(self, vram=None, num_sms: int = 800, cores_per_sm: int = 222, storage=None):
28
- """Initialize AI Accelerator with electron-speed awareness and shared WebSocket storage."""
29
- from electron_speed import TARGET_SWITCHES_PER_SEC, TRANSISTORS_ON_CHIP, drift_velocity
30
-
31
- self.storage = storage # Use the shared storage instance
32
- if self.storage is None:
33
- from websocket_storage import WebSocketGPUStorage
34
- self.storage = WebSocketGPUStorage() # Only create new if not provided
35
- if not self.storage.wait_for_connection():
36
- raise RuntimeError("Could not connect to GPU storage server")
37
-
38
- self.vram = vram
39
  self.num_sms = num_sms
40
  self.cores_per_sm = cores_per_sm
41
  self.total_cores = num_sms * cores_per_sm
42
-
43
- # Configure for maximum parallel processing at electron speed
44
- total_tensor_cores = num_sms * cores_per_sm # Use ALL cores for tensor operations
 
 
45
  self.tensor_core_array = TensorCoreArray(
46
- num_tensor_cores=total_tensor_cores,
47
  bits=32,
48
  bandwidth_tbps=drift_velocity / 1e-12 # Bandwidth scaled to electron drift speed
49
  )
@@ -116,7 +111,7 @@ class AIAccelerator:
116
  except Exception as e:
117
  return f"<Unserializable object of type {type(config).__name__}: {str(e)}>"
118
 
119
- def store_model_state(self, model_name: str, model_info: Dict[str, Any]) -> bool:
120
  """Store model state in WebSocket storage with proper serialization."""
121
  try:
122
  # Convert any non-serializable parts of model_info
@@ -126,25 +121,14 @@ class AIAccelerator:
126
  self.model_registry[model_name] = serializable_info
127
 
128
  # Save to storage
129
- if self.storage:
130
- # Store model info
131
- info_success = self.storage.store_state(
132
- "models",
133
- f"{model_name}/info",
134
- serializable_info
135
- )
136
 
137
- # Store model state
138
- state_success = self.storage.store_state(
139
- "models",
140
- f"{model_name}/state",
141
- {"loaded": True, "timestamp": time.time()}
142
- )
143
-
144
- if info_success and state_success:
145
- self.resource_monitor['loaded_models'].add(model_name)
146
  return True
147
-
148
  return False
149
  except Exception as e:
150
  print(f"Error storing model state: {str(e)}")
@@ -209,14 +193,11 @@ class AIAccelerator:
209
  self.min_batch_size = 4
210
  self.dynamic_batching = True # Enable automatic batch size adjustment
211
 
212
- def set_vram(self, vram):
213
- """Set the VRAM reference."""
214
- self.vram = vram
215
-
216
  def allocate_matrix(self, shape: Tuple[int, ...], dtype=np.float32,
217
  name: Optional[str] = None) -> str:
218
  """Allocate a matrix in VRAM and return its ID."""
219
- if not self.vram:
220
  raise RuntimeError("VRAM not available")
221
 
222
  if name is None:
@@ -227,14 +208,14 @@ class AIAccelerator:
227
  matrix_data = np.zeros(shape, dtype=dtype)
228
 
229
  # Store in VRAM as a texture (reusing texture storage mechanism)
230
- matrix_id = self.vram.load_texture(matrix_data, name)
231
  self.matrix_registry[name] = matrix_id
232
 
233
  return name
234
 
235
  def load_matrix(self, matrix_data: np.ndarray, name: Optional[str] = None) -> str:
236
  """Load matrix data into VRAM and return its ID."""
237
- if not self.vram:
238
  raise RuntimeError("VRAM not available")
239
 
240
  if name is None:
@@ -242,18 +223,18 @@ class AIAccelerator:
242
  self.matrix_counter += 1
243
 
244
  # Store in VRAM
245
- matrix_id = self.vram.load_texture(matrix_data, name)
246
  self.matrix_registry[name] = matrix_id
247
 
248
  return name
249
 
250
  def get_matrix(self, matrix_id: str) -> Optional[np.ndarray]:
251
  """Retrieve matrix data from VRAM."""
252
- if not self.vram or matrix_id not in self.matrix_registry:
253
  return None
254
 
255
  vram_id = self.matrix_registry[matrix_id]
256
- return self.vram.get_texture(vram_id)
257
 
258
  def matrix_multiply(self, matrix_a_id: str, matrix_b_id: str,
259
  result_id: Optional[str] = None) -> Optional[str]:
@@ -801,3 +782,7 @@ class AIAccelerator:
801
  return None
802
 
803
 
 
 
 
 
 
1
+ import json
2
  import numpy as np
3
  import time
4
  from typing import Dict, Any, Optional, Tuple, Union, List
5
  from enum import Enum
6
+ from electron_speed import TARGET_SWITCHES_PER_SEC, TRANSISTORS_ON_CHIP, drift_velocity
7
+
8
+ from network_tensor_core import TensorCoreArray
9
+ from websocket_storage import WebSocketGPUStorage
10
+ from websocket_model_storage import WebSocketModelStorage
11
 
12
  class VectorOperation(Enum):
13
  """Enumeration of supported vector operations."""
 
22
 
23
 
24
  class AIAccelerator:
25
+ """AI Accelerator that leverages electron-speed physics for optimized AI inference and virtual GPU operations."""
26
+
 
 
 
 
 
27
  def __init__(self, vram=None, num_sms: int = 800, cores_per_sm: int = 222, storage=None):
28
+ self.gpu_storage = WebSocketGPUStorage("ws://localhost:7860/ws") # For tensor operations and general GPU state
29
+ self.model_storage = WebSocketModelStorage("ws://localhost:7860/ws/model") # For model upload/download
30
+
31
+ self.vram = self.gpu_storage # VRAM operations will go through gpu_storage
 
 
 
 
 
 
 
32
  self.num_sms = num_sms
33
  self.cores_per_sm = cores_per_sm
34
  self.total_cores = num_sms * cores_per_sm
35
+
36
+ async def connect_to_storage(self):
37
+ if not self.gpu_storage.wait_for_connection():
38
+ raise RuntimeError("Could not connect to GPU storage server")
39
+ await self.model_storage.connect()
40
  self.tensor_core_array = TensorCoreArray(
41
+ num_tensor_cores=self.total_cores,
42
  bits=32,
43
  bandwidth_tbps=drift_velocity / 1e-12 # Bandwidth scaled to electron drift speed
44
  )
 
111
  except Exception as e:
112
  return f"<Unserializable object of type {type(config).__name__}: {str(e)}>"
113
 
114
+ async def store_model_state(self, model_name: str, model_info: Dict[str, Any]) -> bool:
115
  """Store model state in WebSocket storage with proper serialization."""
116
  try:
117
  # Convert any non-serializable parts of model_info
 
121
  self.model_registry[model_name] = serializable_info
122
 
123
  # Save to storage
124
+ if self.model_storage:
125
+ # Convert model_info to JSON string for upload
126
+ model_data_str = json.dumps(serializable_info)
127
+ upload_success = await self.model_storage.upload_model(model_name, model_data_str)
 
 
 
128
 
129
+ if upload_success:
130
+ self.resource_monitor["loaded_models"].add(model_name)
 
 
 
 
 
 
 
131
  return True
 
132
  return False
133
  except Exception as e:
134
  print(f"Error storing model state: {str(e)}")
 
193
  self.min_batch_size = 4
194
  self.dynamic_batching = True # Enable automatic batch size adjustment
195
 
196
+
 
 
 
197
  def allocate_matrix(self, shape: Tuple[int, ...], dtype=np.float32,
198
  name: Optional[str] = None) -> str:
199
  """Allocate a matrix in VRAM and return its ID."""
200
+ if not self.gpu_storage:
201
  raise RuntimeError("VRAM not available")
202
 
203
  if name is None:
 
208
  matrix_data = np.zeros(shape, dtype=dtype)
209
 
210
  # Store in VRAM as a texture (reusing texture storage mechanism)
211
+ matrix_id = self.gpu_storage.load_texture(matrix_data, name)
212
  self.matrix_registry[name] = matrix_id
213
 
214
  return name
215
 
216
  def load_matrix(self, matrix_data: np.ndarray, name: Optional[str] = None) -> str:
217
  """Load matrix data into VRAM and return its ID."""
218
+ if not self.gpu_storage:
219
  raise RuntimeError("VRAM not available")
220
 
221
  if name is None:
 
223
  self.matrix_counter += 1
224
 
225
  # Store in VRAM
226
+ matrix_id = self.gpu_storage.load_texture(matrix_data, name)
227
  self.matrix_registry[name] = matrix_id
228
 
229
  return name
230
 
231
  def get_matrix(self, matrix_id: str) -> Optional[np.ndarray]:
232
  """Retrieve matrix data from VRAM."""
233
+ if not self.gpu_storage or matrix_id not in self.matrix_registry:
234
  return None
235
 
236
  vram_id = self.matrix_registry[matrix_id]
237
+ return self.gpu_storage.get_texture(vram_id)
238
 
239
  def matrix_multiply(self, matrix_a_id: str, matrix_b_id: str,
240
  result_id: Optional[str] = None) -> Optional[str]:
 
782
  return None
783
 
784
 
785
+
786
+
787
+
788
+
network_tensor_core.py CHANGED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import websockets
3
+ import json
4
+ import numpy as np
5
+ from typing import List, Any, Optional, Dict
6
+
7
class TensorCoreArray:
    """Simulated tensor-core array: NumPy matrix multiply plus WebSocket tensor transfer.

    The instance only records its configuration (core count, bit width,
    bandwidth figure); the arithmetic itself is delegated to NumPy. The two
    transfer coroutines open a fresh WebSocket connection per call and exchange
    JSON messages whose ``operation`` field is ``"tensor_data"``.
    """

    def __init__(self, num_tensor_cores: int, bits: int, bandwidth_tbps: float):
        """Record the configuration; the array starts out uninitialized."""
        self.num_tensor_cores = num_tensor_cores
        self.bits = bits
        self.bandwidth_tbps = bandwidth_tbps
        self.initialized = False

    def initialize(self):
        """Mark the array as ready for ``matmul`` and announce it on stdout."""
        print(f"Initializing {self.num_tensor_cores} tensor cores with {self.bits}-bit precision...")
        self.initialized = True

    def matmul(self, matrix_a: List[List[float]], matrix_b: List[List[float]]) -> List[List[float]]:
        """Multiply two 2-D matrices and return the product as nested lists.

        Raises:
            RuntimeError: if ``initialize()`` has not been called.
            ValueError: if either operand is not 2-D, or the inner dimensions
                do not match.
        """
        if not self.initialized:
            raise RuntimeError("Tensor cores not initialized. Call initialize() first.")

        np_a = np.asarray(matrix_a)
        np_b = np.asarray(matrix_b)

        # Check rank first: indexing shape[1] on a 1-D operand would otherwise
        # surface as an unrelated IndexError instead of a dimension error.
        if np_a.ndim != 2 or np_b.ndim != 2:
            raise ValueError("Matrix dimensions incompatible for multiplication")
        if np_a.shape[1] != np_b.shape[0]:
            raise ValueError("Matrix dimensions incompatible for multiplication")

        return np.matmul(np_a, np_b).tolist()

    async def send_tensor_data(self, uri: str, tensor_id: str, data: np.ndarray):
        """Send ``data`` to the server at ``uri`` under ``tensor_id``.

        ``data`` may be an ndarray or any array-like; it is serialized as
        nested JSON lists. Returns the server's reply decoded from JSON.
        """
        payload = {
            "operation": "tensor_data",
            "type": "send",
            "tensor_id": tensor_id,
            "data": np.asarray(data).tolist(),
        }
        async with websockets.connect(uri) as websocket:
            await websocket.send(json.dumps(payload))
            return json.loads(await websocket.recv())

    async def receive_tensor_data(self, uri: str, tensor_id: str) -> Optional[np.ndarray]:
        """Request the tensor stored under ``tensor_id`` from the server at ``uri``.

        Returns the tensor as an ndarray when the reply's ``status`` is
        ``"success"``, otherwise ``None``.
        """
        payload = {
            "operation": "tensor_data",
            "type": "receive",
            "tensor_id": tensor_id,
        }
        async with websockets.connect(uri) as websocket:
            await websocket.send(json.dumps(payload))
            response_data = json.loads(await websocket.recv())
        if response_data.get("status") == "success":
            return np.array(response_data["data"])
        return None

    def get_status(self) -> Dict[str, Any]:
        """Return the configuration and initialization flag as a plain dict."""
        return {
            "num_tensor_cores": self.num_tensor_cores,
            "bits": self.bits,
            "bandwidth_tbps": self.bandwidth_tbps,
            "initialized": self.initialized,
        }
64
+
65
if __name__ == "__main__":
    async def test_tensor_core_array():
        """Run a 2x2 matrix multiplication through TensorCoreArray and print the product."""
        cores = TensorCoreArray(num_tensor_cores=10, bits=32, bandwidth_tbps=1.0)
        cores.initialize()

        product = cores.matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]])
        print(f"Matrix multiplication result: {product}")

        # send_tensor_data / receive_tensor_data are not exercised here: they
        # need a running WebSocket server (e.g. ws://localhost:7860/ws).

    asyncio.run(test_tensor_core_array())
89
+
90
+
network_vram_server.py CHANGED
@@ -1,45 +0,0 @@
1
-
2
- import asyncio
3
- import websockets
4
- import json
5
-
6
- class VRAMServer:
7
- def __init__(self):
8
- self.vram_state = {}
9
-
10
- async def handler(self, websocket):
11
- async for message in websocket:
12
- try:
13
- operation = json.loads(message)
14
- op_type = operation.get("operation")
15
-
16
- if op_type == "vram/state":
17
- state_type = operation.get("type")
18
- key = operation.get("key")
19
-
20
- if state_type == "write":
21
- data = operation.get("data")
22
- self.vram_state[key] = data
23
- await websocket.send(json.dumps({"status": "success", "message": "State stored"}))
24
- elif state_type == "read":
25
- data = self.vram_state.get(key)
26
- if data is not None:
27
- await websocket.send(json.dumps({"status": "success", "data": data}))
28
- else:
29
- await websocket.send(json.dumps({"status": "error", "message": "State not found"}))
30
- else:
31
- await websocket.send(json.dumps({"status": "error", "message": "Unknown state operation type"}))
32
- else:
33
- await websocket.send(json.dumps({"status": "error", "message": "Unknown operation"}))
34
- except Exception as e:
35
- await websocket.send(json.dumps({"status": "error", "message": str(e)}))
36
-
37
- async def main():
38
- server = VRAMServer()
39
- async with websockets.serve(server.handler, "0.0.0.0", 8765):
40
- await asyncio.Future()
41
-
42
- if __name__ == "__main__":
43
- asyncio.run(main())
44
-
45
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test_ai.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import numpy as np
3
+ from ai import AIAccelerator
4
+
5
async def main():
    """Smoke-test AIAccelerator against a running WebSocket storage server.

    Connects the accelerator to storage, stores a small dummy model
    description, then asks the accelerator to initialize its tensor cores.
    Any exception is reported on stdout instead of being propagated.
    """
    print("\n--- Testing AIAccelerator with WebSocket Storage ---")
    try:
        accel = AIAccelerator()
        await accel.connect_to_storage()
        print("AIAccelerator initialized and connected successfully.")

        # Model upload round.
        name = "test_cnn_model"
        info = {"layers": 5, "neurons": 100, "type": "CNN"}
        print(f"Attempting to store model: {name}")
        stored = await accel.store_model_state(name, info)
        if not stored:
            print(f"Failed to store model \'{name}\'")
        else:
            print(f"Model \'{name}\' stored successfully.")

        # Tensor core initialization round.
        print("Attempting to initialize tensor cores...")
        cores_ready = accel.initialize_tensor_cores()
        if not cores_ready:
            print("Failed to initialize tensor cores.")
        else:
            print("Tensor cores initialized successfully.")
    except Exception as exc:
        print(f"An error occurred during AIAccelerator testing: {exc}")


if __name__ == "__main__":
    asyncio.run(main())
33
+
34
+
websocket_model_storage.py CHANGED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import websockets
3
+ import json
4
+ import numpy as np
5
+
6
class WebSocketModelStorage:
    """Client that uploads and downloads model payloads over one WebSocket, in chunks.

    Every chunk travels as a JSON message with ``operation`` ``"vram"``, a
    ``type`` of ``"write"`` or ``"read"``, and a ``block_id`` of the form
    ``"<model_id>_<chunk_index>"``. Chunk bytes are encoded as a JSON list of
    integers.
    """

    def __init__(self, uri):
        """Remember the server ``uri``; no connection is opened until ``connect()``."""
        self.uri = uri
        self.websocket = None

    async def connect(self):
        """Open the WebSocket connection (no incoming message size limit)."""
        self.websocket = await websockets.connect(self.uri, max_size=None)

    async def disconnect(self):
        """Close the connection if one is open and forget the socket."""
        if self.websocket:
            try:
                await self.websocket.close()
            finally:
                # Drop the reference so later calls fail with a clear
                # "not connected" error instead of using a closed socket.
                self.websocket = None

    async def _request(self, payload):
        """Send one JSON request and return the decoded JSON reply.

        Raises:
            RuntimeError: if ``connect()`` has not been called.
        """
        if self.websocket is None:
            raise RuntimeError("Not connected to model storage; call connect() first")
        await self.websocket.send(json.dumps(payload))
        return json.loads(await self.websocket.recv())

    async def upload_model_chunk(self, model_id, chunk_id, chunk_data):
        """Write one chunk; ``chunk_data`` is an ndarray or a JSON-serializable value."""
        return await self._request({
            "operation": "vram",
            "type": "write",
            "block_id": f"{model_id}_{chunk_id}",
            "data": chunk_data.tolist() if isinstance(chunk_data, np.ndarray) else chunk_data,
        })

    async def download_model_chunk(self, model_id, chunk_id):
        """Read one chunk and return the server's decoded reply."""
        return await self._request({
            "operation": "vram",
            "type": "read",
            "block_id": f"{model_id}_{chunk_id}",
        })

    async def upload_model(self, model_id, model_data, chunk_size=1024*1024):  # 1MB chunk size
        """Upload ``model_data`` in chunks of at most ``chunk_size`` bytes.

        ``model_data`` may be an ndarray (sent as its raw bytes), a bytes-like
        object, or a string (sent UTF-8 encoded).

        Returns:
            True if every chunk was acknowledged with status ``"success"``,
            False as soon as one chunk is rejected.

        Raises:
            ValueError: if ``chunk_size`` is not positive.
            RuntimeError: if the storage is not connected.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")

        if isinstance(model_data, np.ndarray):
            model_data_bytes = model_data.tobytes()
        elif isinstance(model_data, (bytes, bytearray, memoryview)):
            model_data_bytes = bytes(model_data)
        else:
            model_data_bytes = model_data.encode("utf-8")

        total_size = len(model_data_bytes)
        num_chunks = (total_size + chunk_size - 1) // chunk_size

        print(f"Uploading model {model_id} in {num_chunks} chunks...")

        for i, start in enumerate(range(0, total_size, chunk_size)):
            # A list of byte values is what the JSON protocol carries.
            chunk_list = list(model_data_bytes[start:start + chunk_size])

            response = await self.upload_model_chunk(model_id, i, chunk_list)
            if response.get("status") != "success":
                print(f"Error uploading chunk {i}: {response.get('message')}")
                return False
            print(f"Uploaded chunk {i+1}/{num_chunks}")
        return True

    async def download_model(self, model_id, num_chunks, dtype=np.float32):
        """Download ``num_chunks`` chunks and reassemble the payload.

        Args:
            model_id: identifier used at upload time.
            num_chunks: number of chunks to fetch (indices 0..num_chunks-1).
            dtype: NumPy dtype used to interpret the bytes (default float32,
                matching the previous behavior). Pass ``None`` to get the raw
                ``bytes`` back, e.g. for payloads uploaded as text.

        Returns:
            The reassembled array (or bytes when ``dtype`` is None), or
            ``None`` if any chunk could not be downloaded.
        """
        print(f"Downloading model {model_id} with {num_chunks} chunks...")
        downloaded_chunks = []
        for i in range(num_chunks):
            response = await self.download_model_chunk(model_id, i)
            if response.get("status") != "success":
                print("Error downloading chunk " + str(i) + ": " + str(response.get("message")))
                return None
            downloaded_chunks.append(bytes(response["data"]))
            print(f"Downloaded chunk {i+1}/{num_chunks}")

        full_model_bytes = b"".join(downloaded_chunks)
        if dtype is None:
            return full_model_bytes
        return np.frombuffer(full_model_bytes, dtype=dtype)
80
+
81
async def main():
    """Round-trip a random float32 model through a local storage server and report the result.

    Requires a WebSocket server listening at ws://localhost:7860/ws. The
    connection is always closed, even if the upload or download raises.
    """
    uri = "ws://localhost:7860/ws"
    storage = WebSocketModelStorage(uri)
    await storage.connect()

    try:
        # 5 Mi float32 values = 20 MiB of payload.
        dummy_model_data = np.random.rand(1024 * 1024 * 5).astype(np.float32)
        model_id = "test_model_123"
        # Passed explicitly to upload_model so the chunk count computed here
        # cannot drift from the chunking actually used.
        chunk_size = 1024*1024
        total_size = dummy_model_data.nbytes
        num_chunks = (total_size + chunk_size - 1) // chunk_size
        success = await storage.upload_model(model_id, dummy_model_data, chunk_size)

        if success:
            print(f"Model {model_id} uploaded successfully.")
            # Test download
            downloaded_model = await storage.download_model(model_id, num_chunks)
            if downloaded_model is not None:
                print(f"Model {model_id} downloaded successfully. Shape: {downloaded_model.shape}")
                # Verify integrity (optional, for testing purposes)
                if np.array_equal(dummy_model_data, downloaded_model):
                    print("Downloaded model matches original.")
                else:
                    print("Downloaded model DOES NOT match original.")
            else:
                print(f"Model {model_id} download failed.")
        else:
            print(f"Model {model_id} upload failed.")
    finally:
        await storage.disconnect()


if __name__ == "__main__":
    asyncio.run(main())
114
+
115
+
websocket_storage.py CHANGED
@@ -1,455 +1,455 @@
1
- import websockets
2
- import json
3
- import numpy as np
4
- from typing import Dict, Any, Optional, Union
5
- import threading
6
- from queue import Queue
7
- import time
8
- import asyncio
9
- import hashlib
10
-
11
- class WebSocketGPUStorage:
12
- # Singleton instance
13
- _instance = None
14
- _lock = threading.Lock()
15
-
16
- def __new__(cls, url: str = "wss://factorst-wbs1.hf.space/ws"):
17
- with cls._lock:
18
- if cls._instance is None:
19
- cls._instance = super().__new__(cls)
20
- cls._instance._init_singleton(url)
21
- return cls._instance
22
-
23
- def _init_singleton(self, url: str):
24
- """Initialize the singleton instance"""
25
- if hasattr(self, 'initialized'):
26
- return
27
-
28
- self.url = url
29
- self.websocket = None
30
- self.connected = False
31
- self.message_queue = Queue()
32
- self.response_queues: Dict[str, Queue] = {}
33
- self.lock = threading.Lock()
34
- self._closing = False
35
- self._loop = None
36
- self.error_count = 0
37
- self.last_error_time = 0
38
- self.max_retries = 5
39
- self.tensor_registry: Dict[str, Dict[str, Any]] = {} # Track tensor metadata
40
- self.model_registry: Dict[str, Dict[str, Any]] = {} # Track loaded models
41
- self.resource_monitor = {
42
- 'vram_used': 0,
43
- 'active_tensors': 0,
44
- 'loaded_models': set()
45
- }
46
-
47
- # Start WebSocket connection in a separate thread
48
- self.ws_thread = threading.Thread(target=self._run_websocket_loop, daemon=True)
49
- self.ws_thread.start()
50
- self.initialized = True
51
-
52
- def __init__(self, url: str = "wss://factorst-wbs1.hf.space/ws"):
53
- """This will actually just return the singleton instance"""
54
- pass
55
-
56
- def _run_websocket_loop(self):
57
- self._loop = asyncio.new_event_loop()
58
- asyncio.set_event_loop(self._loop)
59
- self._loop.run_until_complete(self._websocket_handler())
60
-
61
- async def _websocket_handler(self):
62
- while not self._closing:
63
- try:
64
- async with websockets.connect(self.url) as websocket:
65
- self.websocket = websocket
66
- self.connected = True
67
- self.error_count = 0 # Reset error count on successful connection
68
- print("Connected to GPU storage server")
69
-
70
- while True:
71
- # Handle outgoing messages
72
- try:
73
- while not self.message_queue.empty():
74
- msg_id, operation = self.message_queue.get()
75
- await websocket.send(json.dumps(operation))
76
-
77
- # Wait for response with timeout
78
- try:
79
- response = await asyncio.wait_for(websocket.recv(), timeout=30)
80
- response_data = json.loads(response)
81
-
82
- # Put response in corresponding queue
83
- if msg_id in self.response_queues:
84
- self.response_queues[msg_id].put(response_data)
85
- except asyncio.TimeoutError:
86
- if msg_id in self.response_queues:
87
- self.response_queues[msg_id].put({
88
- "status": "error",
89
- "message": "Operation timed out"
90
- })
91
- except Exception as e:
92
- if msg_id in self.response_queues:
93
- self.response_queues[msg_id].put({
94
- "status": "error",
95
- "message": f"Error processing response: {str(e)}"
96
- })
97
-
98
- except Exception as e:
99
- print(f"Error processing message: {str(e)}")
100
-
101
- # Keep connection alive with heartbeat
102
- try:
103
- await websocket.ping()
104
- except:
105
- break # Break inner loop on ping failure
106
-
107
- await asyncio.sleep(0.001) # 1ms sleep for electron-speed response
108
-
109
- except Exception as e:
110
- print(f"WebSocket connection error: {e}")
111
- self.connected = False
112
- await asyncio.sleep(1) # Wait before reconnecting
113
-
114
- def _send_operation(self, operation: Dict[str, Any]) -> Dict[str, Any]:
115
- if self._closing:
116
- return {"status": "error", "message": "WebSocket is closing"}
117
-
118
- if not self.wait_for_connection(timeout=10):
119
- return {"status": "error", "message": "Not connected to GPU storage server"}
120
-
121
- msg_id = str(time.time())
122
- response_queue = Queue()
123
-
124
- with self.lock:
125
- self.response_queues[msg_id] = response_queue
126
- self.message_queue.put((msg_id, operation))
127
-
128
- try:
129
- # Wait for response with configurable timeout
130
- response = response_queue.get(timeout=30) # Extended timeout for large models
131
- if response.get("status") == "error" and "model_size" in operation:
132
- # Retry once for model loading operations
133
- self.message_queue.put((msg_id, operation))
134
- response = response_queue.get(timeout=30)
135
- except Exception as e:
136
- response = {"status": "error", "message": f"Operation failed: {str(e)}"}
137
- finally:
138
- with self.lock:
139
- if msg_id in self.response_queues:
140
- del self.response_queues[msg_id]
141
-
142
- return response
143
-
144
- def store_tensor(self, tensor_id: str, data: np.ndarray, model_size: Optional[int] = None) -> bool:
145
- try:
146
- if data is None:
147
- raise ValueError("Cannot store None tensor")
148
-
149
- # Calculate tensor metadata
150
- tensor_shape = data.shape
151
- tensor_dtype = str(data.dtype)
152
- tensor_size = data.nbytes
153
-
154
- operation = {
155
- 'operation': 'vram',
156
- 'type': 'write',
157
- 'block_id': tensor_id,
158
- 'data': data.tolist(),
159
- 'model_size': model_size if model_size is not None else -1, # -1 indicates unlimited
160
- 'metadata': {
161
- 'shape': tensor_shape,
162
- 'dtype': tensor_dtype,
163
- 'size': tensor_size,
164
- 'timestamp': time.time()
165
- }
166
- }
167
-
168
- response = self._send_operation(operation)
169
- if response.get('status') == 'success':
170
- # Update tensor registry
171
- with self.lock:
172
- self.tensor_registry[tensor_id] = {
173
- 'shape': tensor_shape,
174
- 'dtype': tensor_dtype,
175
- 'size': tensor_size,
176
- 'timestamp': time.time()
177
- }
178
- self.resource_monitor['vram_used'] += tensor_size
179
- self.resource_monitor['active_tensors'] += 1
180
- return True
181
- else:
182
- print(f"Failed to store tensor {tensor_id}: {response.get('message', 'Unknown error')}")
183
- return False
184
- except Exception as e:
185
- print(f"Error storing tensor {tensor_id}: {str(e)}")
186
- return False
187
-
188
- def load_tensor(self, tensor_id: str) -> Optional[np.ndarray]:
189
- try:
190
- # Check tensor registry first
191
- if tensor_id not in self.tensor_registry:
192
- print(f"Tensor {tensor_id} not registered in VRAM")
193
- return None
194
-
195
- operation = {
196
- 'operation': 'vram',
197
- 'type': 'read',
198
- 'block_id': tensor_id,
199
- 'expected_metadata': self.tensor_registry.get(tensor_id, {})
200
- }
201
-
202
- response = self._send_operation(operation)
203
- if response.get('status') == 'success':
204
- data = response.get('data')
205
- if data is None:
206
- print(f"No data found for tensor {tensor_id}")
207
- return None
208
-
209
- # Verify tensor metadata
210
- metadata = response.get('metadata', {})
211
- expected_metadata = self.tensor_registry.get(tensor_id, {})
212
- if metadata.get('shape') != expected_metadata.get('shape'):
213
- print(f"Warning: Tensor {tensor_id} shape mismatch")
214
-
215
- try:
216
- # Convert to numpy array with correct dtype
217
- arr = np.array(data, dtype=np.dtype(expected_metadata.get('dtype', 'float32')))
218
- if arr.shape != expected_metadata.get('shape'):
219
- arr = arr.reshape(expected_metadata.get('shape'))
220
- return arr
221
- except Exception as e:
222
- print(f"Error converting tensor data: {str(e)}")
223
- return None
224
- else:
225
- print(f"Failed to load tensor {tensor_id}: {response.get('message', 'Unknown error')}")
226
- return None
227
- except Exception as e:
228
- print(f"Error loading tensor {tensor_id}: {str(e)}")
229
- return None
230
-
231
- def store_state(self, component: str, state_id: str, state_data: Dict[str, Any]) -> bool:
232
- try:
233
- operation = {
234
- 'operation': 'state',
235
- 'type': 'save',
236
- 'component': component,
237
- 'state_id': state_id,
238
- 'data': state_data,
239
- 'timestamp': time.time()
240
- }
241
-
242
- response = self._send_operation(operation)
243
- if response.get('status') != 'success':
244
- print(f"Failed to store state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
245
- return False
246
- return True
247
- except Exception as e:
248
- print(f"Error storing state for {component}/{state_id}: {str(e)}")
249
- return False
250
-
251
- def load_state(self, component: str, state_id: str) -> Optional[Dict[str, Any]]:
252
- try:
253
- operation = {
254
- 'operation': 'state',
255
- 'type': 'load',
256
- 'component': component,
257
- 'state_id': state_id
258
- }
259
-
260
- response = self._send_operation(operation)
261
- if response.get('status') == 'success':
262
- data = response.get('data')
263
- if data is None:
264
- print(f"No state found for {component}/{state_id}")
265
- return None
266
- return data
267
- else:
268
- print(f"Failed to load state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
269
- return None
270
- except Exception as e:
271
- print(f"Error loading state for {component}/{state_id}: {str(e)}")
272
- return None
273
-
274
- def is_model_loaded(self, model_name: str) -> bool:
275
- """Check if a model is already loaded in VRAM"""
276
- return model_name in self.resource_monitor['loaded_models']
277
-
278
- def load_model(self, model_name: str, model_path: Optional[str] = None, model_data: Optional[Dict] = None) -> bool:
279
- """Load a model into VRAM if not already loaded"""
280
- try:
281
- # Check if model is already loaded
282
- if self.is_model_loaded(model_name):
283
- print(f"Model {model_name} already loaded in VRAM")
284
- return True
285
-
286
- # Calculate model hash if path provided
287
- model_hash = None
288
- if model_path:
289
- model_hash = self._calculate_model_hash(model_path)
290
-
291
- operation = {
292
- 'operation': 'model',
293
- 'type': 'load',
294
- 'model_name': model_name,
295
- 'model_hash': model_hash,
296
- 'model_data': model_data
297
- }
298
-
299
- response = self._send_operation(operation)
300
- if response.get('status') == 'success':
301
- with self.lock:
302
- self.model_registry[model_name] = {
303
- 'hash': model_hash,
304
- 'timestamp': time.time(),
305
- 'tensors': response.get('tensor_ids', [])
306
- }
307
- self.resource_monitor['loaded_models'].add(model_name)
308
- print(f"Successfully loaded model {model_name}")
309
- return True
310
- else:
311
- print(f"Failed to load model {model_name}: {response.get('message', 'Unknown error')}")
312
- return False
313
- except Exception as e:
314
- print(f"Error loading model {model_name}: {str(e)}")
315
- return False
316
-
317
- def _calculate_model_hash(self, model_path: str) -> str:
318
- """Calculate SHA256 hash of model file"""
319
- try:
320
- sha256_hash = hashlib.sha256()
321
- with open(model_path, "rb") as f:
322
- for byte_block in iter(lambda: f.read(4096), b""):
323
- sha256_hash.update(byte_block)
324
- return sha256_hash.hexdigest()
325
- except Exception as e:
326
- print(f"Error calculating model hash: {str(e)}")
327
- return ""
328
-
329
- def cache_data(self, key: str, data: Any) -> bool:
330
- operation = {
331
- 'operation': 'cache',
332
- 'type': 'set',
333
- 'key': key,
334
- 'data': data
335
- }
336
-
337
- response = self._send_operation(operation)
338
- return response.get('status') == 'success'
339
-
340
- def get_cached_data(self, key: str) -> Optional[Any]:
341
- operation = {
342
- 'operation': 'cache',
343
- 'type': 'get',
344
- 'key': key
345
- }
346
-
347
- response = self._send_operation(operation)
348
- if response.get('status') == 'success':
349
- return response['data']
350
- return None
351
-
352
- def wait_for_connection(self, timeout: float = 30.0) -> bool:
353
- """Wait for WebSocket connection to be established"""
354
- start_time = time.time()
355
- while not self._closing and not self.connected:
356
- if time.time() - start_time > timeout:
357
- print("Connection timeout exceeded")
358
- return False
359
- time.sleep(0.1)
360
- return self.connected
361
-
362
- def is_connected(self) -> bool:
363
- """Check if WebSocket connection is active"""
364
- return self.connected and not self._closing
365
-
366
def get_connection_status(self) -> Dict[str, Any]:
    """Return a snapshot of the connection state as a plain dictionary.

    Keys: ``connected``, ``closing``, ``error_count``, ``url``,
    ``last_error_time`` and ``loaded_models`` (a list copy of the names held
    in ``resource_monitor['loaded_models']``).
    """
    return dict(
        connected=self.connected,
        closing=self._closing,
        error_count=self.error_count,
        url=self.url,
        last_error_time=self.last_error_time,
        loaded_models=list(self.resource_monitor['loaded_models']),
    )
376
-
377
def start_inference(self, model_name: str, input_data: np.ndarray) -> Optional[Dict[str, Any]]:
    """Run inference on the server with a model that is already loaded.

    Returns a dict with ``output`` (NumPy array, or None when the reply has
    no ``output``), ``metrics`` and ``model_info``. Returns None - after
    printing the reason - when the model is not loaded, the server reports a
    failure, or any exception occurs.
    """
    try:
        if not self.is_model_loaded(model_name):
            print(f"Model {model_name} not loaded. Please load the model first.")
            return None

        # NumPy arrays are not JSON-serialisable; anything else goes as-is.
        payload = input_data
        if isinstance(input_data, np.ndarray):
            payload = input_data.tolist()

        reply = self._send_operation({
            'operation': 'inference',
            'type': 'run',
            'model_name': model_name,
            'input_data': payload,
        })

        if reply.get('status') != 'success':
            print(f"Inference failed: {reply.get('message', 'Unknown error')}")
            return None

        output = None
        if 'output' in reply:
            output = np.array(reply['output'])
        metrics = reply.get('metrics', {})
        model_info = self.model_registry.get(model_name, {})
        return {'output': output, 'metrics': metrics, 'model_info': model_info}
    except Exception as e:
        print(f"Error during inference: {str(e)}")
        return None
404
-
405
def close(self):
    """Close the WebSocket connection and reset local bookkeeping.

    Idempotent: only the first call does any work. The tensor/model
    registries and resource counters are cleared synchronously in the calling
    thread. If a socket and event loop exist, a best-effort ``cleanup``
    message is sent to the server and the socket is closed on that loop.
    The call does not wait for the network shutdown to finish when the loop
    is running in another thread.
    """
    if self._closing:
        return
    self._closing = True

    # Reset bookkeeping here rather than inside the coroutine: taking a
    # blocking threading.Lock on the event loop would stall it.
    with self.lock:
        self.tensor_registry.clear()
        self.model_registry.clear()
        self.resource_monitor['vram_used'] = 0
        self.resource_monitor['active_tensors'] = 0
        self.resource_monitor['loaded_models'].clear()

    websocket, loop = self.websocket, self._loop
    if websocket is None or loop is None:
        self.connected = False
        return

    async def cleanup():
        try:
            # Notify server about cleanup (best effort: the link may already
            # be broken, in which case there is nobody to notify).
            if self.connected:
                try:
                    await websocket.send(json.dumps({
                        'operation': 'cleanup',
                        'type': 'full'
                    }))
                except Exception:
                    pass

            await websocket.close()
        except Exception as e:
            print(f"Error during cleanup: {str(e)}")
        finally:
            self.connected = False

    if loop.is_running():
        # close() is normally called from a thread other than the one running
        # the loop; loop.create_task() is not thread-safe, so hand the
        # coroutine over with the thread-safe entry point instead.
        asyncio.run_coroutine_threadsafe(cleanup(), loop)
        return

    pending = cleanup()
    try:
        asyncio.run(pending)
    except RuntimeError as e:
        # asyncio.run() refuses to start inside a running event loop.
        pending.close()
        self.connected = False
        print(f"Error during cleanup: {str(e)}")
440
-
441
async def aclose(self):
    """Asynchronously close the WebSocket connection.

    Idempotent: only the first close request does any work. Errors raised by
    the socket while closing are ignored, and ``connected`` is cleared
    whenever a socket was present.
    """
    if self._closing:
        return
    self._closing = True

    if not self.websocket:
        return
    try:
        await self.websocket.close()
    except Exception:
        # Best effort. Deliberately not a bare ``except``: task cancellation
        # and KeyboardInterrupt must keep propagating.
        pass
    finally:
        self.connected = False
452
-
453
def __del__(self):
    """Ensure cleanup on deletion.

    A finalizer can run on a partially initialised object or during
    interpreter shutdown, so any failure inside ``close()`` is suppressed
    here instead of surfacing as an "Exception ignored in __del__" report.
    """
    try:
        self.close()
    except Exception:
        pass
 
1
+ import websockets
2
+ import json
3
+ import numpy as np
4
+ from typing import Dict, Any, Optional, Union
5
+ import threading
6
+ from queue import Queue
7
+ import time
8
+ import asyncio
9
+ import hashlib
10
+
11
class WebSocketGPUStorage:
    """Singleton client that talks to the GPU storage server over a WebSocket.

    The first construction creates the shared instance and starts a daemon
    thread that runs a private asyncio event loop; that loop owns the socket.
    Synchronous callers hand operations to the loop through ``message_queue``
    and block on a per-request reply queue (see ``_send_operation``).

    Besides the transport, the instance keeps local bookkeeping:
    ``tensor_registry`` (metadata of stored tensors), ``model_registry``
    (loaded models) and ``resource_monitor`` (byte/tensor/model counters).
    """

    # Singleton instance
    _instance = None
    _lock = threading.Lock()

    def __new__(cls, url: str = "ws://localhost:7860/ws"):
        """Return the shared instance, creating it on first use.

        ``url`` is only used by the call that actually creates the instance.
        """
        with cls._lock:
            if cls._instance is None:
                instance = super().__new__(cls)
                instance._init_singleton(url)
                # Publish only after initialisation succeeded so a failed
                # start never leaves a half-built singleton behind.
                cls._instance = instance
            return cls._instance

    def _init_singleton(self, url: str):
        """Initialize the singleton instance and start its WebSocket thread."""
        if hasattr(self, 'initialized'):
            return

        self.url = url
        self.websocket = None
        self.connected = False
        self.message_queue = Queue()
        self.response_queues: Dict[str, Queue] = {}
        self.lock = threading.Lock()
        self._closing = False
        self._loop = None
        self.error_count = 0
        self.last_error_time = 0
        self.max_retries = 5
        self.tensor_registry: Dict[str, Dict[str, Any]] = {}  # Track tensor metadata
        self.model_registry: Dict[str, Dict[str, Any]] = {}  # Track loaded models
        self.resource_monitor = {
            'vram_used': 0,
            'active_tensors': 0,
            'loaded_models': set()
        }

        # Start WebSocket connection in a separate thread
        self.ws_thread = threading.Thread(target=self._run_websocket_loop, daemon=True)
        self.ws_thread.start()
        self.initialized = True

    def __init__(self, url: str = "ws://localhost:7860/ws"):
        """No-op: all set-up happens once, in ``__new__``/``_init_singleton``."""
        pass

    def _release_singleton(self):
        """Forget this instance as the singleton so a later construction starts fresh."""
        cls = type(self)
        with cls._lock:
            if cls._instance is self:
                cls._instance = None

    def _run_websocket_loop(self):
        """Thread target: create this thread's event loop and run the handler on it."""
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)
        self._loop.run_until_complete(self._websocket_handler())

    async def _websocket_handler(self):
        """Connect, service the outgoing queue, and reconnect until closing.

        A connection error is counted in ``error_count``/``last_error_time``
        and followed by a one-second pause before the next attempt.
        """
        while not self._closing:
            try:
                async with websockets.connect(self.url) as websocket:
                    self.websocket = websocket
                    self.connected = True
                    self.error_count = 0  # Reset error count on successful connection
                    print("Connected to GPU storage server")

                    while True:
                        # Handle outgoing messages
                        try:
                            await self._flush_outgoing(websocket)
                        except Exception as e:
                            print(f"Error processing message: {str(e)}")

                        # Keep connection alive with heartbeat
                        try:
                            await websocket.ping()
                        except Exception:
                            break  # Break inner loop on ping failure

                        await asyncio.sleep(0.001)  # 1ms sleep for electron-speed response

            except Exception as e:
                self.connected = False
                self.error_count += 1
                self.last_error_time = time.time()
                print(f"WebSocket connection error: {e}")
                await asyncio.sleep(1)  # Wait before reconnecting
            else:
                # The heartbeat loop ended without an exception: the socket
                # is gone, so callers must not see a live connection.
                self.connected = False

    async def _flush_outgoing(self, websocket):
        """Send every queued operation and route each reply to its waiting caller.

        One reply is read per operation sent. A failed send, a reply that
        does not arrive within 30 s, or an unparsable reply is reported to
        the caller as an error reply rather than leaving it blocked.
        """
        while not self.message_queue.empty():
            msg_id, operation = self.message_queue.get_nowait()

            try:
                await websocket.send(json.dumps(operation))
            except Exception as e:
                self._deliver(msg_id, {
                    "status": "error",
                    "message": f"Error sending operation: {str(e)}"
                })
                continue

            # Wait for response with timeout
            try:
                response = await asyncio.wait_for(websocket.recv(), timeout=30)
                reply = json.loads(response)
            except asyncio.TimeoutError:
                reply = {
                    "status": "error",
                    "message": "Operation timed out"
                }
            except Exception as e:
                reply = {
                    "status": "error",
                    "message": f"Error processing response: {str(e)}"
                }
            self._deliver(msg_id, reply)

    def _deliver(self, msg_id: str, reply: Dict[str, Any]) -> None:
        """Put ``reply`` on the queue of the caller waiting for ``msg_id``, if any."""
        with self.lock:
            waiter = self.response_queues.get(msg_id)
        if waiter is not None:
            waiter.put(reply)

    def _send_operation(self, operation: Dict[str, Any]) -> Dict[str, Any]:
        """Send ``operation`` to the server and block until its reply arrives.

        Never raises: every failure is returned as
        ``{"status": "error", "message": ...}``. Operations that carry a
        ``model_size`` key are retried once after an error reply.
        """
        from queue import Empty
        import uuid

        if self._closing:
            return {"status": "error", "message": "WebSocket is closing"}

        if not self.wait_for_connection(timeout=10):
            return {"status": "error", "message": "Not connected to GPU storage server"}

        # A timestamp is not unique: two threads calling within one clock
        # tick would share an id and one of them would never get its reply.
        msg_id = uuid.uuid4().hex
        response_queue = Queue()

        with self.lock:
            self.response_queues[msg_id] = response_queue
            self.message_queue.put((msg_id, operation))

        try:
            # Wait for response with configurable timeout
            response = response_queue.get(timeout=30)  # Extended timeout for large models
            if response.get("status") == "error" and "model_size" in operation:
                # Retry once for model loading operations
                self.message_queue.put((msg_id, operation))
                response = response_queue.get(timeout=30)
        except Empty:
            response = {"status": "error", "message": "Operation failed: no response from server"}
        except Exception as e:
            response = {"status": "error", "message": f"Operation failed: {str(e)}"}
        finally:
            with self.lock:
                self.response_queues.pop(msg_id, None)

        return response

    def store_tensor(self, tensor_id: str, data: np.ndarray, model_size: Optional[int] = None) -> bool:
        """Write a tensor to server VRAM and record its metadata locally.

        Args:
            tensor_id: Block id under which the tensor is stored.
            data: Array (or array-like) to store; sent as nested lists.
            model_size: Forwarded to the server; ``None`` is sent as -1.

        Returns:
            True on success; False (after printing the reason) otherwise.
        """
        try:
            if data is None:
                raise ValueError("Cannot store None tensor")
            data = np.asarray(data)

            # Calculate tensor metadata
            tensor_shape = data.shape
            tensor_dtype = str(data.dtype)
            tensor_size = data.nbytes

            operation = {
                'operation': 'vram',
                'type': 'write',
                'block_id': tensor_id,
                'data': data.tolist(),
                'model_size': model_size if model_size is not None else -1,  # -1 indicates unlimited
                'metadata': {
                    'shape': tensor_shape,
                    'dtype': tensor_dtype,
                    'size': tensor_size,
                    'timestamp': time.time()
                }
            }

            response = self._send_operation(operation)
            if response.get('status') != 'success':
                print(f"Failed to store tensor {tensor_id}: {response.get('message', 'Unknown error')}")
                return False

            # Update tensor registry
            with self.lock:
                previous = self.tensor_registry.get(tensor_id)
                if previous is None:
                    self.resource_monitor['active_tensors'] += 1
                else:
                    # Overwriting an id replaces the old tensor: give its
                    # bytes back instead of counting the id twice.
                    self.resource_monitor['vram_used'] -= previous.get('size', 0)
                self.tensor_registry[tensor_id] = {
                    'shape': tensor_shape,
                    'dtype': tensor_dtype,
                    'size': tensor_size,
                    'timestamp': time.time()
                }
                self.resource_monitor['vram_used'] += tensor_size
            return True
        except Exception as e:
            print(f"Error storing tensor {tensor_id}: {str(e)}")
            return False

    def load_tensor(self, tensor_id: str) -> Optional[np.ndarray]:
        """Read a previously stored tensor back from server VRAM.

        Only ids present in ``tensor_registry`` can be loaded. The result is
        rebuilt with the registered dtype and shape.

        Returns:
            The tensor, or None (after printing the reason) on any failure.
        """
        try:
            # Check tensor registry first
            with self.lock:
                registered = self.tensor_registry.get(tensor_id)
            if registered is None:
                print(f"Tensor {tensor_id} not registered in VRAM")
                return None

            operation = {
                'operation': 'vram',
                'type': 'read',
                'block_id': tensor_id,
                'expected_metadata': registered
            }

            response = self._send_operation(operation)
            if response.get('status') != 'success':
                print(f"Failed to load tensor {tensor_id}: {response.get('message', 'Unknown error')}")
                return None

            data = response.get('data')
            if data is None:
                print(f"No data found for tensor {tensor_id}")
                return None

            # Verify tensor metadata. JSON has no tuples, so the server's
            # shape arrives as a list and must be normalised before it is
            # compared with the registered tuple.
            metadata = response.get('metadata') or {}
            reported_shape = metadata.get('shape')
            expected_shape = registered.get('shape')
            if expected_shape is not None:
                expected_shape = tuple(expected_shape)
            if reported_shape is not None and tuple(reported_shape) != expected_shape:
                print(f"Warning: Tensor {tensor_id} shape mismatch")

            try:
                # Convert to numpy array with correct dtype
                arr = np.array(data, dtype=np.dtype(registered.get('dtype', 'float32')))
                if expected_shape is not None and arr.shape != expected_shape:
                    arr = arr.reshape(expected_shape)
                return arr
            except Exception as e:
                print(f"Error converting tensor data: {str(e)}")
                return None
        except Exception as e:
            print(f"Error loading tensor {tensor_id}: {str(e)}")
            return None

    def store_state(self, component: str, state_id: str, state_data: Dict[str, Any]) -> bool:
        """Save ``state_data`` on the server under ``component``/``state_id``.

        Returns True on success; False (after printing the reason) otherwise.
        """
        try:
            operation = {
                'operation': 'state',
                'type': 'save',
                'component': component,
                'state_id': state_id,
                'data': state_data,
                'timestamp': time.time()
            }

            response = self._send_operation(operation)
            if response.get('status') != 'success':
                print(f"Failed to store state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
                return False
            return True
        except Exception as e:
            print(f"Error storing state for {component}/{state_id}: {str(e)}")
            return False

    def load_state(self, component: str, state_id: str) -> Optional[Dict[str, Any]]:
        """Load the state saved under ``component``/``state_id``.

        Returns the reply's ``data`` field, or None (after printing the
        reason) when the request fails or no state is stored.
        """
        try:
            operation = {
                'operation': 'state',
                'type': 'load',
                'component': component,
                'state_id': state_id
            }

            response = self._send_operation(operation)
            if response.get('status') != 'success':
                print(f"Failed to load state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
                return None

            data = response.get('data')
            if data is None:
                print(f"No state found for {component}/{state_id}")
                return None
            return data
        except Exception as e:
            print(f"Error loading state for {component}/{state_id}: {str(e)}")
            return None

    def is_model_loaded(self, model_name: str) -> bool:
        """Check if a model is already loaded in VRAM"""
        return model_name in self.resource_monitor['loaded_models']

    def load_model(self, model_name: str, model_path: Optional[str] = None, model_data: Optional[Dict] = None) -> bool:
        """Load a model into VRAM if not already loaded.

        Args:
            model_name: Name the model is registered under.
            model_path: Optional file path; when given, its SHA-256 digest is
                sent along as ``model_hash``.
            model_data: Optional payload forwarded to the server unchanged.

        Returns:
            True when the model is (or already was) loaded; False otherwise.
        """
        try:
            # Check if model is already loaded
            if self.is_model_loaded(model_name):
                print(f"Model {model_name} already loaded in VRAM")
                return True

            # Calculate model hash if path provided
            model_hash = None
            if model_path:
                model_hash = self._calculate_model_hash(model_path)

            operation = {
                'operation': 'model',
                'type': 'load',
                'model_name': model_name,
                'model_hash': model_hash,
                'model_data': model_data
            }

            response = self._send_operation(operation)
            if response.get('status') != 'success':
                print(f"Failed to load model {model_name}: {response.get('message', 'Unknown error')}")
                return False

            with self.lock:
                self.model_registry[model_name] = {
                    'hash': model_hash,
                    'timestamp': time.time(),
                    'tensors': response.get('tensor_ids', [])
                }
                self.resource_monitor['loaded_models'].add(model_name)
            print(f"Successfully loaded model {model_name}")
            return True
        except Exception as e:
            print(f"Error loading model {model_name}: {str(e)}")
            return False

    def _calculate_model_hash(self, model_path: str) -> str:
        """Calculate SHA256 hash of model file; returns "" when it cannot be read."""
        try:
            sha256_hash = hashlib.sha256()
            with open(model_path, "rb") as f:
                for byte_block in iter(lambda: f.read(4096), b""):
                    sha256_hash.update(byte_block)
            return sha256_hash.hexdigest()
        except Exception as e:
            print(f"Error calculating model hash: {str(e)}")
            return ""

    def cache_data(self, key: str, data: Any) -> bool:
        """Cache ``data`` under ``key`` on the server; True when it reports success."""
        operation = {
            'operation': 'cache',
            'type': 'set',
            'key': key,
            'data': data
        }

        response = self._send_operation(operation)
        return response.get('status') == 'success'

    def get_cached_data(self, key: str) -> Optional[Any]:
        """Return the value cached under ``key``, or None when unavailable."""
        operation = {
            'operation': 'cache',
            'type': 'get',
            'key': key
        }

        response = self._send_operation(operation)
        if response.get('status') != 'success':
            return None
        # A success reply without a 'data' field means "nothing cached";
        # indexing would raise KeyError instead.
        return response.get('data')

    def wait_for_connection(self, timeout: float = 30.0) -> bool:
        """Wait for WebSocket connection to be established.

        Polls every 0.1 s and gives up after ``timeout`` seconds or as soon
        as the client starts closing. Returns whether the socket is connected.
        """
        # Monotonic clock so wall-clock adjustments cannot distort the timeout.
        deadline = time.monotonic() + timeout
        while not self._closing and not self.connected:
            if time.monotonic() > deadline:
                print("Connection timeout exceeded")
                return False
            time.sleep(0.1)
        return self.connected

    def is_connected(self) -> bool:
        """Check if WebSocket connection is active"""
        return self.connected and not self._closing

    def get_connection_status(self) -> Dict[str, Any]:
        """Get detailed connection status"""
        return {
            "connected": self.connected,
            "closing": self._closing,
            "error_count": self.error_count,
            "url": self.url,
            "last_error_time": self.last_error_time,
            "loaded_models": list(self.resource_monitor['loaded_models'])
        }

    def start_inference(self, model_name: str, input_data: np.ndarray) -> Optional[Dict[str, Any]]:
        """Start inference with a loaded model.

        Returns a dict with ``output`` (array or None), ``metrics`` and
        ``model_info``; None (after printing the reason) on any failure.
        """
        try:
            if not self.is_model_loaded(model_name):
                print(f"Model {model_name} not loaded. Please load the model first.")
                return None

            operation = {
                'operation': 'inference',
                'type': 'run',
                'model_name': model_name,
                'input_data': input_data.tolist() if isinstance(input_data, np.ndarray) else input_data
            }

            response = self._send_operation(operation)
            if response.get('status') != 'success':
                print(f"Inference failed: {response.get('message', 'Unknown error')}")
                return None

            return {
                'output': np.array(response['output']) if 'output' in response else None,
                'metrics': response.get('metrics', {}),
                'model_info': self.model_registry.get(model_name, {})
            }
        except Exception as e:
            print(f"Error during inference: {str(e)}")
            return None

    def close(self):
        """Close WebSocket connection and cleanup resources.

        Idempotent. Local registries and counters are reset synchronously and
        the instance stops being the singleton, so a later construction
        creates a fresh client. When a socket exists, a best-effort
        ``cleanup`` message is sent and the socket is closed on its event
        loop; the call does not wait for that when the loop runs in another
        thread.
        """
        if self._closing:
            return
        self._closing = True
        self._release_singleton()

        # Clean up registries here, not inside the coroutine: a blocking
        # threading.Lock must not be taken on the event loop.
        with self.lock:
            self.tensor_registry.clear()
            self.model_registry.clear()
            self.resource_monitor['vram_used'] = 0
            self.resource_monitor['active_tensors'] = 0
            self.resource_monitor['loaded_models'].clear()

        websocket, loop = self.websocket, self._loop
        if websocket is None or loop is None:
            self.connected = False
            return

        async def cleanup():
            try:
                # Notify server about cleanup (best effort)
                if self.connected:
                    try:
                        await websocket.send(json.dumps({
                            'operation': 'cleanup',
                            'type': 'full'
                        }))
                    except Exception:
                        pass

                await websocket.close()
            except Exception as e:
                print(f"Error during cleanup: {str(e)}")
            finally:
                self.connected = False

        if loop.is_running():
            # The loop lives in ws_thread; loop.create_task() is not
            # thread-safe, so schedule through the thread-safe entry point.
            asyncio.run_coroutine_threadsafe(cleanup(), loop)
            return

        pending = cleanup()
        try:
            asyncio.run(pending)
        except RuntimeError as e:
            # asyncio.run() refuses to start inside a running event loop.
            pending.close()
            self.connected = False
            print(f"Error during cleanup: {str(e)}")

    async def aclose(self):
        """Asynchronously close WebSocket connection."""
        if self._closing:
            return
        self._closing = True
        self._release_singleton()

        if self.websocket is None:
            return
        try:
            await self.websocket.close()
        except Exception:
            # Best effort; cancellation and KeyboardInterrupt still propagate.
            pass
        finally:
            self.connected = False

    def __del__(self):
        """Ensure cleanup on deletion without ever raising from the finalizer."""
        if not getattr(self, 'initialized', False):
            return  # never fully constructed: nothing to release
        try:
            self.close()
        except Exception:
            pass