Spaces:
Sleeping
Sleeping
Factor Studios
committed on
Update ai.py
Browse files
ai.py
CHANGED
|
@@ -49,9 +49,10 @@ class AIAccelerator:
|
|
| 49 |
)
|
| 50 |
self.tensor_cores_initialized = False
|
| 51 |
|
| 52 |
-
# Initialize model and
|
| 53 |
self.model_registry: Dict[str, Dict[str, Any]] = {} # Track loaded models
|
| 54 |
self.tensor_registry: Dict[str, Dict[str, Any]] = {} # Track tensor metadata
|
|
|
|
| 55 |
self.resource_monitor = {
|
| 56 |
'vram_used': 0,
|
| 57 |
'active_tensors': 0,
|
|
@@ -571,25 +572,45 @@ class AIAccelerator:
|
|
| 571 |
def load_model(self, model_id: str, model: Any, processor: Any):
|
| 572 |
"""Loads a model directly into WebSocket storage without CPU intermediary."""
|
| 573 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 574 |
# Extract model metadata
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 581 |
# Store model state in WebSocket storage
|
| 582 |
-
self.storage.store_state(f"models/{model_id}", "info", model_info)
|
|
|
|
| 583 |
|
| 584 |
# Map weight tensors directly to WebSocket storage
|
| 585 |
-
if hasattr(model, "state_dict"):
|
| 586 |
model_weights = {}
|
| 587 |
|
| 588 |
for name, param in model.state_dict().items():
|
| 589 |
tensor_id = f"{model_id}/weights/{name}"
|
| 590 |
|
| 591 |
# Store tensor directly in WebSocket storage
|
| 592 |
-
self.storage.store_tensor(tensor_id, param.detach().numpy())
|
|
|
|
| 593 |
model_weights[name] = tensor_id
|
| 594 |
|
| 595 |
# Store only WebSocket references
|
|
@@ -601,9 +622,11 @@ class AIAccelerator:
|
|
| 601 |
else:
|
| 602 |
# Store the entire model state in WebSocket storage
|
| 603 |
tensor_id = f"{model_id}/model_state"
|
| 604 |
-
self.storage.store_state(f"models/{model_id}", "state", model)
|
|
|
|
| 605 |
self.model_registry[model_id] = tensor_id
|
| 606 |
|
|
|
|
| 607 |
self.tokenizer_registry[model_id] = processor
|
| 608 |
self.model_loaded = True
|
| 609 |
print(f"Model '{model_id}' loaded into WebSocket storage")
|
|
@@ -675,5 +698,3 @@ class AIAccelerator:
|
|
| 675 |
except Exception as e:
|
| 676 |
print(f"[ERROR] WebSocket-based inference failed for idx={idx}: {e}")
|
| 677 |
return None
|
| 678 |
-
|
| 679 |
-
|
|
|
|
| 49 |
)
|
| 50 |
self.tensor_cores_initialized = False
|
| 51 |
|
| 52 |
+
# Initialize model, tensor, and tokenizer tracking
|
| 53 |
self.model_registry: Dict[str, Dict[str, Any]] = {} # Track loaded models
|
| 54 |
self.tensor_registry: Dict[str, Dict[str, Any]] = {} # Track tensor metadata
|
| 55 |
+
self.tokenizer_registry: Dict[str, Any] = {} # Track tokenizers
|
| 56 |
self.resource_monitor = {
|
| 57 |
'vram_used': 0,
|
| 58 |
'active_tensors': 0,
|
|
|
|
| 572 |
def load_model(self, model_id: str, model: Any, processor: Any):
|
| 573 |
"""Loads a model directly into WebSocket storage without CPU intermediary."""
|
| 574 |
try:
|
| 575 |
+
if model is None and processor is None:
|
| 576 |
+
# Zero-copy mode
|
| 577 |
+
self.model_registry[model_id] = {
|
| 578 |
+
"zero_copy": True,
|
| 579 |
+
"websocket_mapped": True
|
| 580 |
+
}
|
| 581 |
+
self.tokenizer_registry[model_id] = None
|
| 582 |
+
self.model_loaded = True
|
| 583 |
+
return
|
| 584 |
+
|
| 585 |
# Extract model metadata
|
| 586 |
+
try:
|
| 587 |
+
model_info = {
|
| 588 |
+
"architecture": model.__class__.__name__ if model else "Unknown",
|
| 589 |
+
"processor": processor.__class__.__name__ if processor else "Unknown",
|
| 590 |
+
"config": self._serialize_model_config(model.config) if hasattr(model, "config") else {}
|
| 591 |
+
}
|
| 592 |
+
except Exception as e:
|
| 593 |
+
print(f"Warning: Error serializing model metadata: {e}")
|
| 594 |
+
model_info = {"error": str(e)}
|
| 595 |
+
|
| 596 |
+
# Verify WebSocket connection
|
| 597 |
+
if not self.storage or not self.storage.wait_for_connection():
|
| 598 |
+
raise RuntimeError("WebSocket connection not available")
|
| 599 |
+
|
| 600 |
# Store model state in WebSocket storage
|
| 601 |
+
if not self.storage.store_state(f"models/{model_id}/info", "info", model_info):
|
| 602 |
+
raise RuntimeError("Failed to store model info")
|
| 603 |
|
| 604 |
# Map weight tensors directly to WebSocket storage
|
| 605 |
+
if model is not None and hasattr(model, "state_dict"):
|
| 606 |
model_weights = {}
|
| 607 |
|
| 608 |
for name, param in model.state_dict().items():
|
| 609 |
tensor_id = f"{model_id}/weights/{name}"
|
| 610 |
|
| 611 |
# Store tensor directly in WebSocket storage
|
| 612 |
+
if not self.storage.store_tensor(tensor_id, param.detach().numpy()):
|
| 613 |
+
raise RuntimeError(f"Failed to store tensor {name}")
|
| 614 |
model_weights[name] = tensor_id
|
| 615 |
|
| 616 |
# Store only WebSocket references
|
|
|
|
| 622 |
else:
|
| 623 |
# Store the entire model state in WebSocket storage
|
| 624 |
tensor_id = f"{model_id}/model_state"
|
| 625 |
+
if not self.storage.store_state(f"models/{model_id}/state", "state", model):
|
| 626 |
+
raise RuntimeError("Failed to store model state")
|
| 627 |
self.model_registry[model_id] = tensor_id
|
| 628 |
|
| 629 |
+
# Store tokenizer/processor
|
| 630 |
self.tokenizer_registry[model_id] = processor
|
| 631 |
self.model_loaded = True
|
| 632 |
print(f"Model '{model_id}' loaded into WebSocket storage")
|
|
|
|
| 698 |
except Exception as e:
|
| 699 |
print(f"[ERROR] WebSocket-based inference failed for idx={idx}: {e}")
|
| 700 |
return None
|
|
|
|
|
|