Factor Studios committed: Update ai.py

ai.py CHANGED
```diff
@@ -49,6 +49,75 @@ class AIAccelerator:
         )
         self.tensor_cores_initialized = False
 
+        # Initialize model and tensor tracking
+        self.model_registry: Dict[str, Dict[str, Any]] = {}  # Track loaded models
+        self.tensor_registry: Dict[str, Dict[str, Any]] = {}  # Track tensor metadata
+        self.resource_monitor = {
+            'vram_used': 0,
+            'active_tensors': 0,
+            'loaded_models': set()
+        }
+
+    def _serialize_model_config(self, config: Any) -> dict:
+        """Convert model config to a serializable format."""
+        if hasattr(config, '__dict__'):
+            # Convert object attributes to dict
+            config_dict = {}
+            for key, value in config.__dict__.items():
+                if isinstance(value, (int, float, str, bool, type(None))):
+                    config_dict[key] = value
+                elif isinstance(value, (list, tuple)):
+                    config_dict[key] = [self._serialize_model_config(item) for item in value]
+                elif isinstance(value, dict):
+                    config_dict[key] = {k: self._serialize_model_config(v) for k, v in value.items()}
+                elif hasattr(value, '__dict__'):
+                    config_dict[key] = self._serialize_model_config(value)
+                else:
+                    config_dict[key] = str(value)  # Fallback to string representation
+            return config_dict
+        elif isinstance(config, (list, tuple)):
+            return [self._serialize_model_config(item) for item in config]
+        elif isinstance(config, dict):
+            return {k: self._serialize_model_config(v) for k, v in config.items()}
+        elif isinstance(config, (int, float, str, bool, type(None))):
+            return config
+        else:
+            return str(config)  # Fallback to string representation
+
+    def store_model_state(self, model_name: str, model_info: Dict[str, Any]) -> bool:
+        """Store model state in WebSocket storage with proper serialization."""
+        try:
+            # Convert any non-serializable parts of model_info
+            serializable_info = self._serialize_model_config(model_info)
+
+            # Store in model registry
+            self.model_registry[model_name] = serializable_info
+
+            # Save to storage
+            if self.storage:
+                # Store model info
+                info_success = self.storage.store_state(
+                    "models",
+                    f"{model_name}/info",
+                    serializable_info
+                )
+
+                # Store model state
+                state_success = self.storage.store_state(
+                    "models",
+                    f"{model_name}/state",
+                    {"loaded": True, "timestamp": time.time()}
+                )
+
+                if info_success and state_success:
+                    self.resource_monitor['loaded_models'].add(model_name)
+                    return True
+
+            return False
+        except Exception as e:
+            print(f"Error storing model state: {str(e)}")
+            return False
+
     def initialize_tensor_cores(self):
         """Initialize tensor cores and verify they're ready for computation"""
         if self.tensor_cores_initialized:
```
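For context on the two new helpers, here is a minimal usage sketch. It assumes an already-constructed `AIAccelerator` instance named `accel` whose `storage` backend implements the `store_state(scope, key, value)` call used above; `DummyConfig` is a hypothetical stand-in for a model config object, not part of ai.py.

```python
# Hedged sketch: `accel` is assumed to be an initialized AIAccelerator;
# DummyConfig is a hypothetical config object invented for illustration.
import json

class DummyConfig:
    def __init__(self):
        self.hidden_size = 768           # primitive -> kept as-is
        self.head_dims = (64, 64)        # tuple -> recursed into a list
        self.rope = {"base": 10000.0}    # dict -> recursed per value
        self.dtype = complex(1, 2)       # no branch matches -> str() fallback

serialized = accel._serialize_model_config(DummyConfig())
# {'hidden_size': 768, 'head_dims': [64, 64],
#  'rope': {'base': 10000.0}, 'dtype': '(1+2j)'}
print(json.dumps(serialized))  # now round-trips cleanly through JSON

# store_model_state() serializes, registers, and persists the info under
# the "models" scope with keys "<name>/info" and "<name>/state":
ok = accel.store_model_state("demo-model", {"architecture": "DemoNet",
                                            "config": DummyConfig()})
print(ok, "demo-model" in accel.resource_monitor['loaded_models'])
```

The recursion bottoms out at JSON primitives and falls back to `str()` for anything else, which is what makes the result safe to hand to `store_state()`.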
```diff
@@ -499,41 +568,22 @@ class AIAccelerator:
         total_ops = total_params * batch_size * ops_per_param
         return (total_ops / inference_time) / 1e12  # Convert to TFLOPS
 
-    def load_model(self, model_id: str, model: Any
+    def load_model(self, model_id: str, model: Any, processor: Any):
         """Loads a model directly into WebSocket storage without CPU intermediary."""
         try:
-            #
-
-
-
-
-
-                "architecture": model.__class__.__name__,
-                "processor": processor.__class__.__name__ if processor else "None",
-                "config": model.config.to_dict() if hasattr(model, "config") else {},
-                "model_config": {
-                    k: str(v) if not isinstance(v, (bool, int, float, str, list, dict)) else v
-                    for k, v in config.items()
-                }  # Ensure config is JSON serializable
-            }
-            else:
-                # Use provided config for zero-copy mode
-                model_info = {
-                    "architecture": "ZeroCopy",
-                    "processor": "None",
-                    "config": {},
-                    "model_config": {
-                        k: str(v) if not isinstance(v, (bool, int, float, str, list, dict)) else v
-                        for k, v in config.items()
-                    }  # Ensure config is JSON serializable
-                }
+            # Extract model metadata
+            model_info = {
+                "architecture": model.__class__.__name__,
+                "processor": processor.__class__.__name__,
+                "config": model.config.to_dict() if hasattr(model, "config") else {}
+            }
 
             # Store model state in WebSocket storage
             self.storage.store_state(f"models/{model_id}", "info", model_info)
 
-            # Map weight tensors directly to WebSocket storage
-
-
+            # Map weight tensors directly to WebSocket storage
+            if hasattr(model, "state_dict"):
+                model_weights = {}
 
             for name, param in model.state_dict().items():
                 tensor_id = f"{model_id}/weights/{name}"
```
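This hunk makes `processor` a required argument, dropping both the old `if processor else "None"` guard and the zero-copy/`model_config` branch. As a sketch of what the rewritten metadata block produces (the `TinyModel`/`TinyConfig`/`TinyTokenizer` classes are hypothetical stand-ins for a transformers-style model, invented for illustration):

```python
# Hedged sketch of the model_info dict built above, using lightweight
# stand-ins rather than a real model; TinyModel/TinyConfig/TinyTokenizer
# are not part of ai.py.
class TinyConfig:
    def to_dict(self):
        return {"hidden_size": 64, "num_layers": 2}

class TinyModel:
    config = TinyConfig()

class TinyTokenizer:
    pass

model, processor = TinyModel(), TinyTokenizer()
model_info = {
    "architecture": model.__class__.__name__,      # "TinyModel"
    "processor": processor.__class__.__name__,     # "TinyTokenizer"
    "config": model.config.to_dict() if hasattr(model, "config") else {},
}
print(model_info)
# {'architecture': 'TinyModel', 'processor': 'TinyTokenizer',
#  'config': {'hidden_size': 64, 'num_layers': 2}}
```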
```diff
@@ -554,37 +604,12 @@ class AIAccelerator:
             self.storage.store_state(f"models/{model_id}", "state", model)
             self.model_registry[model_id] = tensor_id
 
-
-            if processor is not None:
-                self.tokenizer_registry[model_id] = processor
-
+            self.tokenizer_registry[model_id] = processor
             self.model_loaded = True
             print(f"Model '{model_id}' loaded into WebSocket storage")
-
-            # Additional setup for zero-copy mode
-            if model_config and model_config.get("zero_copy"):
-                # Register empty tensors for zero-copy mode
-                tensor_id = f"{model_id}/zero_copy"
-                self.model_registry[model_id] = {
-                    "mode": "zero_copy",
-                    "config": model_config,
-                    "tensor_id": tensor_id
-                }
-
-            return True
-
         except Exception as e:
             print(f"Error loading model into WebSocket storage: {str(e)}")
-
-            print("Attempting to serialize config differently...")
-            try:
-                # Try again with string conversion for non-serializable types
-                if model_config:
-                    model_config = {k: str(v) for k, v in model_config.items()}
-                return self.load_model(model_id, model, processor, model_config)
-            except Exception as e2:
-                print(f"Second attempt also failed: {e2}")
-                return False
+            raise
 
     def has_model(self, model_id: str) -> bool:
         """Checks if a model is loaded in the accelerator's registry."""
```
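Because `load_model` now re-raises instead of retrying with a stringified config, failures propagate to the caller. A minimal caller-side sketch, again assuming an initialized `accel` and the hypothetical `TinyModel`/`TinyTokenizer` stand-ins from the previous snippet:

```python
# Hedged sketch: load errors now surface to the caller instead of
# triggering the removed retry-with-str()-config fallback.
try:
    accel.load_model("tiny-demo", TinyModel(), TinyTokenizer())
except Exception as exc:
    # e.g. the storage backend rejecting a non-serializable config value
    print(f"Load failed, model not registered: {exc}")
else:
    assert accel.has_model("tiny-demo")
```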