Factor Studios committed on
Commit
65881ed
·
verified ·
1 Parent(s): cc47889

Upload 2 files

Browse files
Files changed (2) hide show
  1. ai.py +125 -28
  2. test_ai_integration.py +87 -29
ai.py CHANGED
@@ -61,29 +61,60 @@ class AIAccelerator:
61
 
62
  def _serialize_model_config(self, config: Any) -> dict:
63
  """Convert model config to a serializable format."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  if hasattr(config, '__dict__'):
65
- # Convert object attributes to dict
66
  config_dict = {}
67
  for key, value in config.__dict__.items():
68
- if isinstance(value, (int, float, str, bool, type(None))):
69
- config_dict[key] = value
70
- elif isinstance(value, (list, tuple)):
71
- config_dict[key] = [self._serialize_model_config(item) for item in value]
72
- elif isinstance(value, dict):
73
- config_dict[key] = {k: self._serialize_model_config(v) for k, v in value.items()}
74
- elif hasattr(value, '__dict__'):
75
  config_dict[key] = self._serialize_model_config(value)
76
- else:
77
- config_dict[key] = str(value) # Fallback to string representation
 
78
  return config_dict
79
- elif isinstance(config, (list, tuple)):
80
- return [self._serialize_model_config(item) for item in config]
81
- elif isinstance(config, dict):
82
- return {k: self._serialize_model_config(v) for k, v in config.items()}
83
- elif isinstance(config, (int, float, str, bool, type(None))):
84
- return config
85
- else:
86
- return str(config) # Fallback to string representation
87
 
88
  def store_model_state(self, model_name: str, model_info: Dict[str, Any]) -> bool:
89
  """Store model state in WebSocket storage with proper serialization."""
@@ -573,6 +604,20 @@ class AIAccelerator:
573
  total_ops = total_params * batch_size * ops_per_param
574
  return (total_ops / inference_time) / 1e12 # Convert to TFLOPS
575
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
576
  def load_model(self, model_id: str, model: Any, processor: Any):
577
  """Loads a model directly into WebSocket storage without CPU intermediary."""
578
  try:
@@ -586,24 +631,76 @@ class AIAccelerator:
586
  self.model_loaded = True
587
  return
588
 
589
- # Extract model metadata
 
 
 
 
590
  try:
 
 
591
  model_info = {
592
  "architecture": model.__class__.__name__ if model else "Unknown",
593
  "processor": processor.__class__.__name__ if processor else "Unknown",
594
- "config": self._serialize_model_config(model.config) if hasattr(model, "config") else {}
595
  }
596
  except Exception as e:
597
- print(f"Warning: Error serializing model metadata: {e}")
598
- model_info = {"error": str(e)}
 
 
 
599
 
600
- # Verify WebSocket connection
601
- if not self.storage or not self.storage.wait_for_connection():
602
- raise RuntimeError("WebSocket connection not available")
 
 
 
 
 
 
 
603
 
604
- # Store model state in WebSocket storage
605
- if not self.storage.store_state(f"models/{model_id}/info", "info", model_info):
606
- raise RuntimeError("Failed to store model info")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
607
 
608
  # Map weight tensors directly to WebSocket storage
609
  if model is not None and hasattr(model, "state_dict"):
 
61
 
62
  def _serialize_model_config(self, config: Any) -> dict:
63
  """Convert model config to a serializable format."""
64
+ # Handle None case first
65
+ if config is None:
66
+ return None
67
+
68
+ # Handle Florence2LanguageConfig specifically
69
+ if config.__class__.__name__ == "Florence2LanguageConfig":
70
+ try:
71
+ return {
72
+ "type": "Florence2LanguageConfig",
73
+ "model_type": getattr(config, "model_type", ""),
74
+ "architectures": getattr(config, "architectures", []),
75
+ "hidden_size": getattr(config, "hidden_size", 0),
76
+ "num_attention_heads": getattr(config, "num_attention_heads", 0),
77
+ "num_hidden_layers": getattr(config, "num_hidden_layers", 0),
78
+ "intermediate_size": getattr(config, "intermediate_size", 0),
79
+ "max_position_embeddings": getattr(config, "max_position_embeddings", 0),
80
+ "layer_norm_eps": getattr(config, "layer_norm_eps", 1e-12),
81
+ "vocab_size": getattr(config, "vocab_size", 0)
82
+ }
83
+ except Exception as e:
84
+ print(f"Warning: Error serializing Florence2LanguageConfig: {e}")
85
+ return {"type": "Florence2LanguageConfig", "error": str(e)}
86
+
87
+ # Handle standard types
88
+ if isinstance(config, (int, float, str, bool)):
89
+ return config
90
+
91
+ # Handle lists and tuples
92
+ if isinstance(config, (list, tuple)):
93
+ return [self._serialize_model_config(item) for item in config]
94
+
95
+ # Handle dictionaries
96
+ if isinstance(config, dict):
97
+ return {k: self._serialize_model_config(v) for k, v in config.items()}
98
+
99
+ # Handle objects with __dict__
100
  if hasattr(config, '__dict__'):
 
101
  config_dict = {}
102
  for key, value in config.__dict__.items():
103
+ try:
104
+ # Skip private attributes
105
+ if key.startswith('_'):
106
+ continue
 
 
 
107
  config_dict[key] = self._serialize_model_config(value)
108
+ except Exception as e:
109
+ print(f"Warning: Error serializing attribute {key}: {e}")
110
+ config_dict[key] = str(value)
111
  return config_dict
112
+
113
+ # Fallback: convert to string representation
114
+ try:
115
+ return str(config)
116
+ except Exception as e:
117
+ return f"<Unserializable object of type {type(config).__name__}: {str(e)}>"
 
 
118
 
119
  def store_model_state(self, model_name: str, model_info: Dict[str, Any]) -> bool:
120
  """Store model state in WebSocket storage with proper serialization."""
 
604
  total_ops = total_params * batch_size * ops_per_param
605
  return (total_ops / inference_time) / 1e12 # Convert to TFLOPS
606
 
607
+ def _serialize_tensor(self, tensor: Any) -> np.ndarray:
608
+ """Convert a PyTorch tensor to numpy array safely."""
609
+ try:
610
+ if hasattr(tensor, 'detach'):
611
+ tensor = tensor.detach()
612
+ if hasattr(tensor, 'cpu'):
613
+ tensor = tensor.cpu()
614
+ if hasattr(tensor, 'numpy'):
615
+ return tensor.numpy()
616
+ return np.array(tensor)
617
+ except Exception as e:
618
+ print(f"Warning: Error converting tensor to numpy: {e}")
619
+ return None
620
+
621
  def load_model(self, model_id: str, model: Any, processor: Any):
622
  """Loads a model directly into WebSocket storage without CPU intermediary."""
623
  try:
 
631
  self.model_loaded = True
632
  return
633
 
634
+ # Verify WebSocket connection first
635
+ if not self.storage or not self.storage.wait_for_connection():
636
+ raise RuntimeError("WebSocket connection not available")
637
+
638
+ # 1. Store model configuration
639
  try:
640
+ config_dict = (self._serialize_model_config(model.config)
641
+ if hasattr(model, "config") else {})
642
  model_info = {
643
  "architecture": model.__class__.__name__ if model else "Unknown",
644
  "processor": processor.__class__.__name__ if processor else "Unknown",
645
+ "config": config_dict
646
  }
647
  except Exception as e:
648
+ print(f"Warning: Error serializing model config: {e}")
649
+ model_info = {
650
+ "architecture": str(type(model).__name__),
651
+ "error": str(e)
652
+ }
653
 
654
+ # Store model info with retry
655
+ for attempt in range(3):
656
+ try:
657
+ if self.storage.store_state(f"models/{model_id}/info", "info", model_info):
658
+ break
659
+ print(f"Retrying model info storage, attempt {attempt + 1}")
660
+ time.sleep(1)
661
+ except Exception as e:
662
+ if attempt == 2:
663
+ raise RuntimeError(f"Failed to store model info: {e}")
664
 
665
+ # 2. Store model weights
666
+ if hasattr(model, "state_dict"):
667
+ weight_registry = {}
668
+ for name, param in model.state_dict().items():
669
+ # Convert tensor to numpy and store in chunks if needed
670
+ tensor_data = self._serialize_tensor(param)
671
+ if tensor_data is not None:
672
+ tensor_id = f"{model_id}/weights/{name}"
673
+ if tensor_data.nbytes > 1024*1024*1024: # If larger than 1GB
674
+ # Store large tensors in chunks
675
+ chunks = np.array_split(tensor_data,
676
+ max(1, tensor_data.nbytes // (512*1024*1024)))
677
+ chunk_ids = []
678
+ for i, chunk in enumerate(chunks):
679
+ chunk_id = f"{tensor_id}/chunk_{i}"
680
+ if self.storage.store_tensor(chunk_id, chunk):
681
+ chunk_ids.append(chunk_id)
682
+ weight_registry[name] = {
683
+ "type": "chunked",
684
+ "chunks": chunk_ids,
685
+ "shape": tensor_data.shape,
686
+ "dtype": str(tensor_data.dtype)
687
+ }
688
+ else:
689
+ # Store small tensors directly
690
+ if self.storage.store_tensor(tensor_id, tensor_data):
691
+ weight_registry[name] = {
692
+ "type": "direct",
693
+ "tensor_id": tensor_id,
694
+ "shape": tensor_data.shape,
695
+ "dtype": str(tensor_data.dtype)
696
+ }
697
+
698
+ # Store weight registry
699
+ self.storage.store_state(f"models/{model_id}/weights", "registry", weight_registry)
700
+ self.model_registry[model_id] = {
701
+ "weight_registry": weight_registry,
702
+ "websocket_mapped": True
703
+ }
704
 
705
  # Map weight tensors directly to WebSocket storage
706
  if model is not None and hasattr(model, "state_dict"):
test_ai_integration.py CHANGED
@@ -33,32 +33,59 @@ def increase_file_limit():
33
 
34
  # WebSocket connection manager with retry
35
  @contextlib.contextmanager
36
- def websocket_manager(max_retries=3, retry_delay=2):
37
  storage = None
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  for attempt in range(max_retries):
39
  try:
40
- storage = WebSocketGPUStorage()
41
- if storage.wait_for_connection(timeout=10.0):
42
  logging.info("Successfully connected to GPU storage server")
43
  break
44
  else:
45
- logging.warning(f"Connection attempt {attempt + 1} failed, retrying...")
46
- if storage:
47
- storage.close()
48
  time.sleep(retry_delay)
49
  except Exception as e:
 
50
  logging.error(f"Connection attempt {attempt + 1} failed with error: {e}")
51
- if storage:
52
- storage.close()
53
- if attempt == max_retries - 1:
54
- raise RuntimeError(f"Could not connect to GPU storage server after {max_retries} attempts")
55
  time.sleep(retry_delay)
 
 
 
 
 
 
56
 
57
  try:
 
 
58
  yield storage
 
 
 
 
 
 
 
 
59
  finally:
60
  if storage:
61
- storage.close() # Ensure connection is closed
 
 
 
 
62
 
63
  # Cleanup handler
64
  def cleanup_resources():
@@ -207,15 +234,33 @@ def test_ai_integration():
207
  ai_accelerators = []
208
 
209
  try:
210
- # Reuse the existing storage connection from the previous test
211
- if not components['storage'] or not components['storage'].wait_for_connection():
212
- # If connection lost, try to reconnect
213
- with websocket_manager() as shared_storage:
214
- if not shared_storage or not shared_storage.wait_for_connection():
215
- raise RuntimeError("Could not establish WebSocket connection")
216
- components['storage'] = shared_storage
217
 
218
- shared_storage = components['storage']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
  # Initialize high-performance chip array with WebSocket storage
221
  total_sms = 0
@@ -246,16 +291,29 @@ def test_ai_integration():
246
  ai_accelerator.storage = shared_storage # Ensure storage is set
247
  ai_accelerators.append(ai_accelerator)
248
 
249
- # Verify WebSocket connection before loading model
250
- if not shared_storage.wait_for_connection():
251
- raise RuntimeError(f"Lost WebSocket connection during chip {i} initialization")
252
-
253
- # Load model weights from WebSocket storage (no CPU transfer)
254
- try:
255
- ai_accelerator.load_model(model_id, None, None) # Model already in WebSocket storage
256
- except Exception as e:
257
- print(f"Warning: Failed to load model on chip {i}: {e}")
258
- continue
 
 
 
 
 
 
 
 
 
 
 
 
 
259
 
260
  # Track total processing units
261
  total_sms += chip.num_sms
 
33
 
# WebSocket connection manager with retry
@contextlib.contextmanager
def websocket_manager(max_retries=5, retry_delay=2, timeout=30.0):
    """Yield a connected ``WebSocketGPUStorage``, retrying the connection.

    Retries the initial connection up to ``max_retries`` times with
    ``retry_delay`` seconds between attempts, enables keep-alive for the
    duration of the managed block, and always closes the connection on
    exit.

    Raises:
        RuntimeError: if no connection could be established after
            ``max_retries`` attempts.
    """
    storage = None
    last_error = None

    def try_connect():
        # (Re)create the storage client, closing any stale instance first.
        nonlocal storage
        if storage:
            try:
                storage.close()
            except Exception:
                pass
        storage = WebSocketGPUStorage()
        return storage.wait_for_connection(timeout=timeout)

    # Initial connection attempts.
    connected = False
    for attempt in range(max_retries):
        try:
            if try_connect():
                logging.info("Successfully connected to GPU storage server")
                connected = True
                break
            logging.warning(f"Connection attempt {attempt + 1} failed, retrying in {retry_delay}s...")
        except Exception as e:
            last_error = str(e)
            logging.error(f"Connection attempt {attempt + 1} failed with error: {e}")
        # Only sleep when another attempt remains.
        if attempt < max_retries - 1:
            time.sleep(retry_delay)

    # BUGFIX: the previous `attempt == max_retries - 1` check raised even
    # when the connection succeeded on the final attempt; track success
    # explicitly instead.
    if not connected:
        error_msg = f"Could not connect to GPU storage server after {max_retries} attempts"
        if last_error:
            error_msg += f". Last error: {last_error}"
        raise RuntimeError(error_msg)

    try:
        # Keep the connection alive for the lifetime of the managed block.
        storage.set_keep_alive(True)
        yield storage
    except Exception as e:
        # BUGFIX: a @contextmanager generator must yield exactly once, so
        # re-yielding after a reconnect (as the old code did) raised
        # "generator didn't stop after throw()". Log and propagate instead.
        logging.error(f"WebSocket operation failed: {e}")
        raise
    finally:
        if storage:
            try:
                storage.set_keep_alive(False)
                storage.close()
            except Exception:
                pass
 
90
  # Cleanup handler
91
  def cleanup_resources():
 
234
  ai_accelerators = []
235
 
236
  try:
237
+ # Try to reuse existing connection with verification
238
+ shared_storage = None
239
+ max_connection_attempts = 3
 
 
 
 
240
 
241
+ for attempt in range(max_connection_attempts):
242
+ try:
243
+ if (components['storage'] and
244
+ components['storage'].wait_for_connection(timeout=10.0)):
245
+ shared_storage = components['storage']
246
+ shared_storage.set_keep_alive(True) # Enable keep-alive
247
+ logging.info("Successfully reused existing WebSocket connection")
248
+ break
249
+ else:
250
+ logging.warning("Existing connection unavailable, creating new connection...")
251
+ with websocket_manager(timeout=30.0) as new_storage:
252
+ if new_storage and new_storage.wait_for_connection(timeout=10.0):
253
+ components['storage'] = new_storage
254
+ shared_storage = new_storage
255
+ shared_storage.set_keep_alive(True) # Enable keep-alive
256
+ logging.info("Successfully established new WebSocket connection")
257
+ break
258
+ except Exception as e:
259
+ logging.error(f"Connection attempt {attempt + 1} failed: {e}")
260
+ if attempt < max_connection_attempts - 1:
261
+ time.sleep(2)
262
+ continue
263
+ raise RuntimeError(f"Failed to establish WebSocket connection after {max_connection_attempts} attempts")
264
 
265
  # Initialize high-performance chip array with WebSocket storage
266
  total_sms = 0
 
291
  ai_accelerator.storage = shared_storage # Ensure storage is set
292
  ai_accelerators.append(ai_accelerator)
293
 
294
+ # Verify and potentially repair WebSocket connection
295
+ max_retry = 3
296
+ for retry in range(max_retry):
297
+ try:
298
+ if not shared_storage.wait_for_connection(timeout=5.0):
299
+ logging.warning(f"Connection check failed for chip {i}, attempt {retry + 1}")
300
+ shared_storage.reconnect() # Attempt to reconnect
301
+ time.sleep(1)
302
+ continue
303
+
304
+ # Load model weights from WebSocket storage (no CPU transfer)
305
+ ai_accelerator.load_model(model_id, None, None) # Model already in WebSocket storage
306
+ logging.info(f"Successfully initialized chip {i} with model")
307
+ break
308
+
309
+ except Exception as e:
310
+ if retry < max_retry - 1:
311
+ logging.warning(f"Error initializing chip {i}, attempt {retry + 1}: {e}")
312
+ time.sleep(1)
313
+ continue
314
+ else:
315
+ logging.error(f"Failed to initialize chip {i} after {max_retry} attempts: {e}")
316
+ raise
317
 
318
  # Track total processing units
319
  total_sms += chip.num_sms