Factor Studios committed on
Commit
eea87c5
·
verified ·
1 Parent(s): e7fb62d

Upload 21 files

Browse files
Files changed (4) hide show
  1. ai.py +1 -1
  2. network_vram_server.py +45 -0
  3. test_ai_integration.py +17 -46
  4. websocket_storage.py +43 -14
ai.py CHANGED
@@ -174,7 +174,7 @@ class AIAccelerator:
174
  if isinstance(test_input, list):
175
  test_input = np.array(test_input, dtype=np.float32)
176
 
177
- test_result = self.tensor_core_array.matmul(test_input, test_input)
178
  if test_result is None or not isinstance(test_result, (np.ndarray, list)) or len(test_result) == 0:
179
  raise RuntimeError("Tensor core test computation failed")
180
 
 
174
  if isinstance(test_input, list):
175
  test_input = np.array(test_input, dtype=np.float32)
176
 
177
+ test_result = self.tensor_core_array.matmul(test_input.tolist(), test_input.tolist())
178
  if test_result is None or not isinstance(test_result, (np.ndarray, list)) or len(test_result) == 0:
179
  raise RuntimeError("Tensor core test computation failed")
180
 
network_vram_server.py CHANGED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import asyncio
3
+ import websockets
4
+ import json
5
+
6
class VRAMServer:
    """In-memory key/value store for VRAM state, served over WebSocket.

    Clients send one JSON object per message of the form
    ``{"operation": "vram/state", "type": "write"|"read", "key": ..., "data": ...}``
    and receive exactly one JSON reply per message with a ``"status"`` field
    of ``"success"`` or ``"error"``.
    """

    def __init__(self):
        # key -> stored payload. Plain dict: the server runs on a single
        # asyncio event loop, so no locking is needed.
        self.vram_state = {}

    async def handler(self, websocket):
        """Process one client connection, one JSON operation per message.

        Malformed JSON or any internal failure produces an
        ``{"status": "error", ...}`` reply instead of dropping the connection.
        """
        async for message in websocket:
            try:
                operation = json.loads(message)
                op_type = operation.get("operation")

                if op_type == "vram/state":
                    state_type = operation.get("type")
                    key = operation.get("key")

                    if state_type == "write":
                        self.vram_state[key] = operation.get("data")
                        await websocket.send(json.dumps(
                            {"status": "success", "message": "State stored"}))
                    elif state_type == "read":
                        # Membership test (not a None check) so that a
                        # legitimately stored value of None is still readable
                        # instead of being reported as missing.
                        if key in self.vram_state:
                            await websocket.send(json.dumps(
                                {"status": "success", "data": self.vram_state[key]}))
                        else:
                            await websocket.send(json.dumps(
                                {"status": "error", "message": "State not found"}))
                    else:
                        await websocket.send(json.dumps(
                            {"status": "error", "message": "Unknown state operation type"}))
                else:
                    await websocket.send(json.dumps(
                        {"status": "error", "message": "Unknown operation"}))
            except Exception as e:
                # Report the failure to the client; keep serving later messages.
                await websocket.send(json.dumps({"status": "error", "message": str(e)}))
36
+
37
async def main():
    """Run the VRAM WebSocket server until cancelled.

    Binds on all interfaces, port 8765. Awaiting a bare Future keeps the
    server alive forever; cancellation (e.g. Ctrl-C) is the only way out.
    """
    server = VRAMServer()
    async with websockets.serve(server.handler, "0.0.0.0", 8765):
        await asyncio.Future()  # never resolves — serve until interrupted

if __name__ == "__main__":
    asyncio.run(main())
44
+
45
+
test_ai_integration.py CHANGED
@@ -184,28 +184,15 @@ def test_ai_integration():
184
  model_size = sum(p.numel() * p.element_size() for p in model.parameters())
185
  print(f"Model size: {model_size / (1024**3):.2f} GB")
186
 
187
- # Upload model weights directly to WebSocket storage
188
- print("Uploading model weights to WebSocket storage...")
189
- for name, param in model.state_dict().items():
190
- # Convert tensor to numpy and upload
191
- weight_data = param.cpu().numpy()
192
- storage.store_tensor(f"model_weights/{model_id}/{name}", weight_data)
193
-
194
- # Store minimal model info without serializing the config
195
- storage.store_state(f"models/{model_id}", "info", {
196
- "name": model_id,
197
- "size_bytes": model_size,
198
- "num_parameters": sum(p.numel() for p in model.parameters()),
199
- "weight_keys": list(model.state_dict().keys())
200
- })
201
-
202
- # Set model reference without serializing the full model
203
- ai_accelerator_for_loading.model_refs[model_id] = {
204
- "weight_prefix": f"model_weights/{model_id}",
205
- "size": model_size
206
- }
207
 
208
- print(f"Model weights uploaded successfully to WebSocket storage")
209
  assert ai_accelerator_for_loading.has_model(model_id), "Model not found in WebSocket storage after loading."
210
 
211
  # Store model parameters in components dict
@@ -254,6 +241,7 @@ def test_ai_integration():
254
  if (components['storage'] and
255
  components['storage'].wait_for_connection(timeout=10.0)):
256
  shared_storage = components['storage']
 
257
  logging.info("Successfully reused existing WebSocket connection")
258
  break
259
  else:
@@ -262,6 +250,7 @@ def test_ai_integration():
262
  if new_storage and new_storage.wait_for_connection(timeout=10.0):
263
  components['storage'] = new_storage
264
  shared_storage = new_storage
 
265
  logging.info("Successfully established new WebSocket connection")
266
  break
267
  except Exception as e:
@@ -384,31 +373,13 @@ def test_ai_integration():
384
  # Load image section from WebSocket storage
385
  tensor_id = f"input_image/{img_name}"
386
 
387
- # Load weights from WebSocket storage and run inference
388
- try:
389
- # Get model info
390
- model_info = accelerator.storage.load_state(f"models/{model_id}", "info")
391
- weight_prefix = f"model_weights/{model_id}"
392
-
393
- # Load input tensor
394
- input_tensor = accelerator.storage.load_tensor(tensor_id)
395
-
396
- # Run inference with direct weight access
397
- result = accelerator.inference_with_ws_weights(
398
- model_id=model_id,
399
- input_tensor=input_tensor,
400
- weight_prefix=weight_prefix
401
- )
402
-
403
- # Store result in WebSocket storage
404
- if result is not None:
405
- storage.store_tensor(f"results/chip_{i}/{img_name}", result)
406
- results.append(result)
407
- else:
408
- logging.error(f"Inference returned None for chip {i}")
409
- except Exception as e:
410
- logging.error(f"Inference failed on chip {i}: {str(e)}")
411
- raise
412
 
413
  elapsed = time.time() - start_time
414
 
 
184
  model_size = sum(p.numel() * p.element_size() for p in model.parameters())
185
  print(f"Model size: {model_size / (1024**3):.2f} GB")
186
 
187
+ # Store model in WebSocket storage with size information
188
+ # Load model directly using AIAccelerator's load_model method
189
+ ai_accelerator_for_loading.load_model(
190
+ model_id=model_id,
191
+ model=model,
192
+ processor=processor
193
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
+ print(f"Model '{model_id}' loaded successfully to WebSocket storage.")
196
  assert ai_accelerator_for_loading.has_model(model_id), "Model not found in WebSocket storage after loading."
197
 
198
  # Store model parameters in components dict
 
241
  if (components['storage'] and
242
  components['storage'].wait_for_connection(timeout=10.0)):
243
  shared_storage = components['storage']
244
+ shared_storage.set_keep_alive(True) # Enable keep-alive
245
  logging.info("Successfully reused existing WebSocket connection")
246
  break
247
  else:
 
250
  if new_storage and new_storage.wait_for_connection(timeout=10.0):
251
  components['storage'] = new_storage
252
  shared_storage = new_storage
253
+ shared_storage.set_keep_alive(True) # Enable keep-alive
254
  logging.info("Successfully established new WebSocket connection")
255
  break
256
  except Exception as e:
 
373
  # Load image section from WebSocket storage
374
  tensor_id = f"input_image/{img_name}"
375
 
376
+ # Run inference using WebSocket-stored weights
377
+ result = accelerator.inference(model_id, tensor_id)
378
+
379
+ # Store result in WebSocket storage
380
+ if result is not None:
381
+ storage.store_tensor(f"results/chip_{i}/{img_name}", result)
382
+ results.append(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
 
384
  elapsed = time.time() - start_time
385
 
websocket_storage.py CHANGED
@@ -13,7 +13,7 @@ class WebSocketGPUStorage:
13
  _instance = None
14
  _lock = threading.Lock()
15
 
16
- def __new__(cls, url: str = "wss://factorst-wbs1.hf.space/ws"):
17
  with cls._lock:
18
  if cls._instance is None:
19
  cls._instance = super().__new__(cls)
@@ -49,7 +49,7 @@ class WebSocketGPUStorage:
49
  self.ws_thread.start()
50
  self.initialized = True
51
 
52
- def __init__(self, url: str = "wss://factorst-wbs1.hf.space/ws"):
53
  """This will actually just return the singleton instance"""
54
  pass
55
 
@@ -230,18 +230,33 @@ class WebSocketGPUStorage:
230
 
231
  def store_state(self, component: str, state_id: str, state_data: Dict[str, Any]) -> bool:
232
  try:
 
 
 
 
233
  operation = {
234
- 'operation': 'state',
235
- 'type': 'save',
236
- 'component': component,
237
- 'state_id': state_id,
238
  'data': state_data,
239
- 'timestamp': time.time()
 
 
 
 
 
240
  }
241
 
242
  response = self._send_operation(operation)
243
  if response.get('status') != 'success':
244
- print(f"Failed to store state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
 
 
 
 
 
 
 
245
  return False
246
  return True
247
  except Exception as e:
@@ -250,11 +265,18 @@ class WebSocketGPUStorage:
250
 
251
  def load_state(self, component: str, state_id: str) -> Optional[Dict[str, Any]]:
252
  try:
 
 
 
253
  operation = {
254
- 'operation': 'state',
255
- 'type': 'load',
256
- 'component': component,
257
- 'state_id': state_id
 
 
 
 
258
  }
259
 
260
  response = self._send_operation(operation)
@@ -265,7 +287,14 @@ class WebSocketGPUStorage:
265
  return None
266
  return data
267
  else:
268
- print(f"Failed to load state for {component}/{state_id}: {response.get('message', 'Unknown error')}")
 
 
 
 
 
 
 
269
  return None
270
  except Exception as e:
271
  print(f"Error loading state for {component}/{state_id}: {str(e)}")
@@ -290,7 +319,7 @@ class WebSocketGPUStorage:
290
 
291
  operation = {
292
  'operation': 'model',
293
- 'type': 'load',
294
  'model_name': model_name,
295
  'model_hash': model_hash,
296
  'model_data': model_data
 
13
  _instance = None
14
  _lock = threading.Lock()
15
 
16
+ def __new__(cls, url: str = "wss://8765-ie635qf2d79t3i1wada8c-fc2963e7.manusvm.computer/ws"):
17
  with cls._lock:
18
  if cls._instance is None:
19
  cls._instance = super().__new__(cls)
 
49
  self.ws_thread.start()
50
  self.initialized = True
51
 
52
+ def __init__(self, url: str = "wss://8765-ie635qf2d79t3i1wada8c-fc2963e7.manusvm.computer/ws"):
53
  """This will actually just return the singleton instance"""
54
  pass
55
 
 
230
 
231
  def store_state(self, component: str, state_id: str, state_data: Dict[str, Any]) -> bool:
232
  try:
233
+ # Use memory-based state storage instead of file-based
234
+ state_key = f"{component}_{state_id}"
235
+
236
+ # Store state in memory
237
  operation = {
238
+ 'operation': 'vram/state',
239
+ 'type': 'write',
240
+ 'key': state_key,
 
241
  'data': state_data,
242
+ 'timestamp': time.time(),
243
+ 'metadata': {
244
+ 'component': component,
245
+ 'state_id': state_id,
246
+ 'storage_type': 'memory'
247
+ }
248
  }
249
 
250
  response = self._send_operation(operation)
251
  if response.get('status') != 'success':
252
+ error_msg = response.get('message', 'Unknown error')
253
+ if 'Permission denied' in error_msg:
254
+ # Try memory-only fallback
255
+ operation['storage_type'] = 'memory_only'
256
+ response = self._send_operation(operation)
257
+ if response.get('status') == 'success':
258
+ return True
259
+ print(f"Failed to store state for {component}/{state_id}: {error_msg}")
260
  return False
261
  return True
262
  except Exception as e:
 
265
 
266
  def load_state(self, component: str, state_id: str) -> Optional[Dict[str, Any]]:
267
  try:
268
+ state_key = f"{component}_{state_id}"
269
+
270
+ # Try loading from memory first
271
  operation = {
272
+ 'operation': 'vram/state',
273
+ 'type': 'read',
274
+ 'key': state_key,
275
+ 'metadata': {
276
+ 'component': component,
277
+ 'state_id': state_id,
278
+ 'storage_type': 'memory'
279
+ }
280
  }
281
 
282
  response = self._send_operation(operation)
 
287
  return None
288
  return data
289
  else:
290
+ error_msg = response.get('message', 'Unknown error')
291
+ if 'Permission denied' in error_msg:
292
+ # Try memory-only fallback
293
+ operation['storage_type'] = 'memory_only'
294
+ response = self._send_operation(operation)
295
+ if response.get('status') == 'success':
296
+ return response.get('data')
297
+ print(f"Failed to load state for {component}/{state_id}: {error_msg}")
298
  return None
299
  except Exception as e:
300
  print(f"Error loading state for {component}/{state_id}: {str(e)}")
 
319
 
320
  operation = {
321
  'operation': 'model',
322
+ 'type': 'read',
323
  'model_name': model_name,
324
  'model_hash': model_hash,
325
  'model_data': model_data