Spaces:
Sleeping
Sleeping
Factor Studios
committed on
Upload 2 files
Browse files- ai.py +125 -28
- test_ai_integration.py +87 -29
ai.py
CHANGED
|
@@ -61,29 +61,60 @@ class AIAccelerator:
|
|
| 61 |
|
| 62 |
def _serialize_model_config(self, config: Any) -> dict:
|
| 63 |
"""Convert model config to a serializable format."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
if hasattr(config, '__dict__'):
|
| 65 |
-
# Convert object attributes to dict
|
| 66 |
config_dict = {}
|
| 67 |
for key, value in config.__dict__.items():
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
elif isinstance(value, dict):
|
| 73 |
-
config_dict[key] = {k: self._serialize_model_config(v) for k, v in value.items()}
|
| 74 |
-
elif hasattr(value, '__dict__'):
|
| 75 |
config_dict[key] = self._serialize_model_config(value)
|
| 76 |
-
|
| 77 |
-
|
|
|
|
| 78 |
return config_dict
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
return
|
| 83 |
-
|
| 84 |
-
return config
|
| 85 |
-
else:
|
| 86 |
-
return str(config) # Fallback to string representation
|
| 87 |
|
| 88 |
def store_model_state(self, model_name: str, model_info: Dict[str, Any]) -> bool:
|
| 89 |
"""Store model state in WebSocket storage with proper serialization."""
|
|
@@ -573,6 +604,20 @@ class AIAccelerator:
|
|
| 573 |
total_ops = total_params * batch_size * ops_per_param
|
| 574 |
return (total_ops / inference_time) / 1e12 # Convert to TFLOPS
|
| 575 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 576 |
def load_model(self, model_id: str, model: Any, processor: Any):
|
| 577 |
"""Loads a model directly into WebSocket storage without CPU intermediary."""
|
| 578 |
try:
|
|
@@ -586,24 +631,76 @@ class AIAccelerator:
|
|
| 586 |
self.model_loaded = True
|
| 587 |
return
|
| 588 |
|
| 589 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 590 |
try:
|
|
|
|
|
|
|
| 591 |
model_info = {
|
| 592 |
"architecture": model.__class__.__name__ if model else "Unknown",
|
| 593 |
"processor": processor.__class__.__name__ if processor else "Unknown",
|
| 594 |
-
"config":
|
| 595 |
}
|
| 596 |
except Exception as e:
|
| 597 |
-
print(f"Warning: Error serializing model
|
| 598 |
-
model_info = {
|
|
|
|
|
|
|
|
|
|
| 599 |
|
| 600 |
-
#
|
| 601 |
-
|
| 602 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 603 |
|
| 604 |
-
# Store model
|
| 605 |
-
if
|
| 606 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 607 |
|
| 608 |
# Map weight tensors directly to WebSocket storage
|
| 609 |
if model is not None and hasattr(model, "state_dict"):
|
|
|
|
| 61 |
|
| 62 |
def _serialize_model_config(self, config: Any) -> dict:
|
| 63 |
"""Convert model config to a serializable format."""
|
| 64 |
+
# Handle None case first
|
| 65 |
+
if config is None:
|
| 66 |
+
return None
|
| 67 |
+
|
| 68 |
+
# Handle Florence2LanguageConfig specifically
|
| 69 |
+
if config.__class__.__name__ == "Florence2LanguageConfig":
|
| 70 |
+
try:
|
| 71 |
+
return {
|
| 72 |
+
"type": "Florence2LanguageConfig",
|
| 73 |
+
"model_type": getattr(config, "model_type", ""),
|
| 74 |
+
"architectures": getattr(config, "architectures", []),
|
| 75 |
+
"hidden_size": getattr(config, "hidden_size", 0),
|
| 76 |
+
"num_attention_heads": getattr(config, "num_attention_heads", 0),
|
| 77 |
+
"num_hidden_layers": getattr(config, "num_hidden_layers", 0),
|
| 78 |
+
"intermediate_size": getattr(config, "intermediate_size", 0),
|
| 79 |
+
"max_position_embeddings": getattr(config, "max_position_embeddings", 0),
|
| 80 |
+
"layer_norm_eps": getattr(config, "layer_norm_eps", 1e-12),
|
| 81 |
+
"vocab_size": getattr(config, "vocab_size", 0)
|
| 82 |
+
}
|
| 83 |
+
except Exception as e:
|
| 84 |
+
print(f"Warning: Error serializing Florence2LanguageConfig: {e}")
|
| 85 |
+
return {"type": "Florence2LanguageConfig", "error": str(e)}
|
| 86 |
+
|
| 87 |
+
# Handle standard types
|
| 88 |
+
if isinstance(config, (int, float, str, bool)):
|
| 89 |
+
return config
|
| 90 |
+
|
| 91 |
+
# Handle lists and tuples
|
| 92 |
+
if isinstance(config, (list, tuple)):
|
| 93 |
+
return [self._serialize_model_config(item) for item in config]
|
| 94 |
+
|
| 95 |
+
# Handle dictionaries
|
| 96 |
+
if isinstance(config, dict):
|
| 97 |
+
return {k: self._serialize_model_config(v) for k, v in config.items()}
|
| 98 |
+
|
| 99 |
+
# Handle objects with __dict__
|
| 100 |
if hasattr(config, '__dict__'):
|
|
|
|
| 101 |
config_dict = {}
|
| 102 |
for key, value in config.__dict__.items():
|
| 103 |
+
try:
|
| 104 |
+
# Skip private attributes
|
| 105 |
+
if key.startswith('_'):
|
| 106 |
+
continue
|
|
|
|
|
|
|
|
|
|
| 107 |
config_dict[key] = self._serialize_model_config(value)
|
| 108 |
+
except Exception as e:
|
| 109 |
+
print(f"Warning: Error serializing attribute {key}: {e}")
|
| 110 |
+
config_dict[key] = str(value)
|
| 111 |
return config_dict
|
| 112 |
+
|
| 113 |
+
# Fallback: convert to string representation
|
| 114 |
+
try:
|
| 115 |
+
return str(config)
|
| 116 |
+
except Exception as e:
|
| 117 |
+
return f"<Unserializable object of type {type(config).__name__}: {str(e)}>"
|
|
|
|
|
|
|
| 118 |
|
| 119 |
def store_model_state(self, model_name: str, model_info: Dict[str, Any]) -> bool:
|
| 120 |
"""Store model state in WebSocket storage with proper serialization."""
|
|
|
|
| 604 |
total_ops = total_params * batch_size * ops_per_param
|
| 605 |
return (total_ops / inference_time) / 1e12 # Convert to TFLOPS
|
| 606 |
|
| 607 |
+
def _serialize_tensor(self, tensor: Any) -> np.ndarray:
|
| 608 |
+
"""Convert a PyTorch tensor to numpy array safely."""
|
| 609 |
+
try:
|
| 610 |
+
if hasattr(tensor, 'detach'):
|
| 611 |
+
tensor = tensor.detach()
|
| 612 |
+
if hasattr(tensor, 'cpu'):
|
| 613 |
+
tensor = tensor.cpu()
|
| 614 |
+
if hasattr(tensor, 'numpy'):
|
| 615 |
+
return tensor.numpy()
|
| 616 |
+
return np.array(tensor)
|
| 617 |
+
except Exception as e:
|
| 618 |
+
print(f"Warning: Error converting tensor to numpy: {e}")
|
| 619 |
+
return None
|
| 620 |
+
|
| 621 |
def load_model(self, model_id: str, model: Any, processor: Any):
|
| 622 |
"""Loads a model directly into WebSocket storage without CPU intermediary."""
|
| 623 |
try:
|
|
|
|
| 631 |
self.model_loaded = True
|
| 632 |
return
|
| 633 |
|
| 634 |
+
# Verify WebSocket connection first
|
| 635 |
+
if not self.storage or not self.storage.wait_for_connection():
|
| 636 |
+
raise RuntimeError("WebSocket connection not available")
|
| 637 |
+
|
| 638 |
+
# 1. Store model configuration
|
| 639 |
try:
|
| 640 |
+
config_dict = (self._serialize_model_config(model.config)
|
| 641 |
+
if hasattr(model, "config") else {})
|
| 642 |
model_info = {
|
| 643 |
"architecture": model.__class__.__name__ if model else "Unknown",
|
| 644 |
"processor": processor.__class__.__name__ if processor else "Unknown",
|
| 645 |
+
"config": config_dict
|
| 646 |
}
|
| 647 |
except Exception as e:
|
| 648 |
+
print(f"Warning: Error serializing model config: {e}")
|
| 649 |
+
model_info = {
|
| 650 |
+
"architecture": str(type(model).__name__),
|
| 651 |
+
"error": str(e)
|
| 652 |
+
}
|
| 653 |
|
| 654 |
+
# Store model info with retry
|
| 655 |
+
for attempt in range(3):
|
| 656 |
+
try:
|
| 657 |
+
if self.storage.store_state(f"models/{model_id}/info", "info", model_info):
|
| 658 |
+
break
|
| 659 |
+
print(f"Retrying model info storage, attempt {attempt + 1}")
|
| 660 |
+
time.sleep(1)
|
| 661 |
+
except Exception as e:
|
| 662 |
+
if attempt == 2:
|
| 663 |
+
raise RuntimeError(f"Failed to store model info: {e}")
|
| 664 |
|
| 665 |
+
# 2. Store model weights
|
| 666 |
+
if hasattr(model, "state_dict"):
|
| 667 |
+
weight_registry = {}
|
| 668 |
+
for name, param in model.state_dict().items():
|
| 669 |
+
# Convert tensor to numpy and store in chunks if needed
|
| 670 |
+
tensor_data = self._serialize_tensor(param)
|
| 671 |
+
if tensor_data is not None:
|
| 672 |
+
tensor_id = f"{model_id}/weights/{name}"
|
| 673 |
+
if tensor_data.nbytes > 1024*1024*1024: # If larger than 1GB
|
| 674 |
+
# Store large tensors in chunks
|
| 675 |
+
chunks = np.array_split(tensor_data,
|
| 676 |
+
max(1, tensor_data.nbytes // (512*1024*1024)))
|
| 677 |
+
chunk_ids = []
|
| 678 |
+
for i, chunk in enumerate(chunks):
|
| 679 |
+
chunk_id = f"{tensor_id}/chunk_{i}"
|
| 680 |
+
if self.storage.store_tensor(chunk_id, chunk):
|
| 681 |
+
chunk_ids.append(chunk_id)
|
| 682 |
+
weight_registry[name] = {
|
| 683 |
+
"type": "chunked",
|
| 684 |
+
"chunks": chunk_ids,
|
| 685 |
+
"shape": tensor_data.shape,
|
| 686 |
+
"dtype": str(tensor_data.dtype)
|
| 687 |
+
}
|
| 688 |
+
else:
|
| 689 |
+
# Store small tensors directly
|
| 690 |
+
if self.storage.store_tensor(tensor_id, tensor_data):
|
| 691 |
+
weight_registry[name] = {
|
| 692 |
+
"type": "direct",
|
| 693 |
+
"tensor_id": tensor_id,
|
| 694 |
+
"shape": tensor_data.shape,
|
| 695 |
+
"dtype": str(tensor_data.dtype)
|
| 696 |
+
}
|
| 697 |
+
|
| 698 |
+
# Store weight registry
|
| 699 |
+
self.storage.store_state(f"models/{model_id}/weights", "registry", weight_registry)
|
| 700 |
+
self.model_registry[model_id] = {
|
| 701 |
+
"weight_registry": weight_registry,
|
| 702 |
+
"websocket_mapped": True
|
| 703 |
+
}
|
| 704 |
|
| 705 |
# Map weight tensors directly to WebSocket storage
|
| 706 |
if model is not None and hasattr(model, "state_dict"):
|
test_ai_integration.py
CHANGED
|
@@ -33,32 +33,59 @@ def increase_file_limit():
|
|
| 33 |
|
| 34 |
# WebSocket connection manager with retry
|
| 35 |
@contextlib.contextmanager
|
| 36 |
-
def websocket_manager(max_retries=
|
| 37 |
storage = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
for attempt in range(max_retries):
|
| 39 |
try:
|
| 40 |
-
|
| 41 |
-
if storage.wait_for_connection(timeout=10.0):
|
| 42 |
logging.info("Successfully connected to GPU storage server")
|
| 43 |
break
|
| 44 |
else:
|
| 45 |
-
logging.warning(f"Connection attempt {attempt + 1} failed, retrying...")
|
| 46 |
-
if storage:
|
| 47 |
-
storage.close()
|
| 48 |
time.sleep(retry_delay)
|
| 49 |
except Exception as e:
|
|
|
|
| 50 |
logging.error(f"Connection attempt {attempt + 1} failed with error: {e}")
|
| 51 |
-
if storage:
|
| 52 |
-
storage.close()
|
| 53 |
-
if attempt == max_retries - 1:
|
| 54 |
-
raise RuntimeError(f"Could not connect to GPU storage server after {max_retries} attempts")
|
| 55 |
time.sleep(retry_delay)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
try:
|
|
|
|
|
|
|
| 58 |
yield storage
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
finally:
|
| 60 |
if storage:
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
# Cleanup handler
|
| 64 |
def cleanup_resources():
|
|
@@ -207,15 +234,33 @@ def test_ai_integration():
|
|
| 207 |
ai_accelerators = []
|
| 208 |
|
| 209 |
try:
|
| 210 |
-
#
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
with websocket_manager() as shared_storage:
|
| 214 |
-
if not shared_storage or not shared_storage.wait_for_connection():
|
| 215 |
-
raise RuntimeError("Could not establish WebSocket connection")
|
| 216 |
-
components['storage'] = shared_storage
|
| 217 |
|
| 218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
|
| 220 |
# Initialize high-performance chip array with WebSocket storage
|
| 221 |
total_sms = 0
|
|
@@ -246,16 +291,29 @@ def test_ai_integration():
|
|
| 246 |
ai_accelerator.storage = shared_storage # Ensure storage is set
|
| 247 |
ai_accelerators.append(ai_accelerator)
|
| 248 |
|
| 249 |
-
# Verify
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
|
| 260 |
# Track total processing units
|
| 261 |
total_sms += chip.num_sms
|
|
|
|
| 33 |
|
| 34 |
# WebSocket connection manager with retry
|
| 35 |
@contextlib.contextmanager
def websocket_manager(max_retries=5, retry_delay=2, timeout=30.0):
    """Yield a connected ``WebSocketGPUStorage`` and close it on exit.

    Up to ``max_retries`` connection attempts are made. Each attempt creates
    a fresh ``WebSocketGPUStorage`` (closing the previous one first) and waits
    for ``wait_for_connection(timeout=timeout)`` to report success.

    Once connected, keep-alive is switched on for the duration of the
    ``with`` body and switched off again before the storage is closed.

    An exception raised inside the ``with`` body is logged and re-raised.
    The body cannot be re-run from here: a ``@contextmanager`` generator may
    yield only once, and yielding again after an exception makes
    ``contextlib`` raise ``RuntimeError`` and hide the original error.

    Args:
        max_retries: Maximum number of connection attempts.
        retry_delay: Seconds to sleep between failed attempts.
        timeout: Timeout passed to ``wait_for_connection`` on each attempt.

    Yields:
        The connected storage object.

    Raises:
        RuntimeError: If no attempt succeeded (including ``max_retries <= 0``).
    """
    storage = None
    last_error = None

    def close_quietly(conn):
        # Best-effort close: a failure here must not mask the real error.
        try:
            conn.close()
        except Exception as e:
            logging.debug(f"Ignoring error while closing storage: {e}")

    def try_connect():
        nonlocal storage
        if storage:
            close_quietly(storage)
        storage = WebSocketGPUStorage()
        return storage.wait_for_connection(timeout=timeout)

    # Initial connection attempts
    for attempt in range(max_retries):
        try:
            if try_connect():
                logging.info("Successfully connected to GPU storage server")
                break
            logging.warning(f"Connection attempt {attempt + 1} failed, retrying in {retry_delay}s...")
        except Exception as e:
            last_error = str(e)
            logging.error(f"Connection attempt {attempt + 1} failed with error: {e}")

        # No point sleeping after the final attempt; we are about to give up.
        if attempt < max_retries - 1:
            time.sleep(retry_delay)
    else:
        # Loop finished without `break`: every attempt failed (or none ran).
        if storage:
            close_quietly(storage)
        error_msg = f"Could not connect to GPU storage server after {max_retries} attempts"
        if last_error:
            error_msg += f". Last error: {last_error}"
        raise RuntimeError(error_msg)

    try:
        # Set up keep-alive mechanism; if that fails, try one fresh
        # connection before handing the storage to the caller.
        try:
            storage.set_keep_alive(True)
        except Exception as e:
            logging.error(f"WebSocket operation failed: {e}")
            if not try_connect():
                raise
            logging.info("Successfully reconnected to GPU storage server")

        try:
            yield storage
        except Exception as e:
            logging.error(f"WebSocket operation failed: {e}")
            raise
    finally:
        if storage:
            try:
                storage.set_keep_alive(False)
            except Exception as e:
                logging.debug(f"Ignoring error while disabling keep-alive: {e}")
            close_quietly(storage)
|
| 89 |
|
| 90 |
# Cleanup handler
|
| 91 |
def cleanup_resources():
|
|
|
|
| 234 |
ai_accelerators = []
|
| 235 |
|
| 236 |
try:
|
| 237 |
+
# Try to reuse existing connection with verification
|
| 238 |
+
shared_storage = None
|
| 239 |
+
max_connection_attempts = 3
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
|
| 241 |
+
for attempt in range(max_connection_attempts):
|
| 242 |
+
try:
|
| 243 |
+
if (components['storage'] and
|
| 244 |
+
components['storage'].wait_for_connection(timeout=10.0)):
|
| 245 |
+
shared_storage = components['storage']
|
| 246 |
+
shared_storage.set_keep_alive(True) # Enable keep-alive
|
| 247 |
+
logging.info("Successfully reused existing WebSocket connection")
|
| 248 |
+
break
|
| 249 |
+
else:
|
| 250 |
+
logging.warning("Existing connection unavailable, creating new connection...")
|
| 251 |
+
with websocket_manager(timeout=30.0) as new_storage:
|
| 252 |
+
if new_storage and new_storage.wait_for_connection(timeout=10.0):
|
| 253 |
+
components['storage'] = new_storage
|
| 254 |
+
shared_storage = new_storage
|
| 255 |
+
shared_storage.set_keep_alive(True) # Enable keep-alive
|
| 256 |
+
logging.info("Successfully established new WebSocket connection")
|
| 257 |
+
break
|
| 258 |
+
except Exception as e:
|
| 259 |
+
logging.error(f"Connection attempt {attempt + 1} failed: {e}")
|
| 260 |
+
if attempt < max_connection_attempts - 1:
|
| 261 |
+
time.sleep(2)
|
| 262 |
+
continue
|
| 263 |
+
raise RuntimeError(f"Failed to establish WebSocket connection after {max_connection_attempts} attempts")
|
| 264 |
|
| 265 |
# Initialize high-performance chip array with WebSocket storage
|
| 266 |
total_sms = 0
|
|
|
|
| 291 |
ai_accelerator.storage = shared_storage # Ensure storage is set
|
| 292 |
ai_accelerators.append(ai_accelerator)
|
| 293 |
|
| 294 |
+
# Verify and potentially repair WebSocket connection
|
| 295 |
+
max_retry = 3
|
| 296 |
+
for retry in range(max_retry):
|
| 297 |
+
try:
|
| 298 |
+
if not shared_storage.wait_for_connection(timeout=5.0):
|
| 299 |
+
logging.warning(f"Connection check failed for chip {i}, attempt {retry + 1}")
|
| 300 |
+
shared_storage.reconnect() # Attempt to reconnect
|
| 301 |
+
time.sleep(1)
|
| 302 |
+
continue
|
| 303 |
+
|
| 304 |
+
# Load model weights from WebSocket storage (no CPU transfer)
|
| 305 |
+
ai_accelerator.load_model(model_id, None, None) # Model already in WebSocket storage
|
| 306 |
+
logging.info(f"Successfully initialized chip {i} with model")
|
| 307 |
+
break
|
| 308 |
+
|
| 309 |
+
except Exception as e:
|
| 310 |
+
if retry < max_retry - 1:
|
| 311 |
+
logging.warning(f"Error initializing chip {i}, attempt {retry + 1}: {e}")
|
| 312 |
+
time.sleep(1)
|
| 313 |
+
continue
|
| 314 |
+
else:
|
| 315 |
+
logging.error(f"Failed to initialize chip {i} after {max_retry} attempts: {e}")
|
| 316 |
+
raise
|
| 317 |
|
| 318 |
# Track total processing units
|
| 319 |
total_sms += chip.num_sms
|