Factor Studios committed on
Commit
ab3a38e
·
verified ·
1 Parent(s): 3f4680e

Update ai.py

Browse files
Files changed (1) hide show
  1. ai.py +34 -13
ai.py CHANGED
@@ -49,9 +49,10 @@ class AIAccelerator:
49
  )
50
  self.tensor_cores_initialized = False
51
 
52
- # Initialize model and tensor tracking
53
  self.model_registry: Dict[str, Dict[str, Any]] = {} # Track loaded models
54
  self.tensor_registry: Dict[str, Dict[str, Any]] = {} # Track tensor metadata
 
55
  self.resource_monitor = {
56
  'vram_used': 0,
57
  'active_tensors': 0,
@@ -571,25 +572,45 @@ class AIAccelerator:
571
  def load_model(self, model_id: str, model: Any, processor: Any):
572
  """Loads a model directly into WebSocket storage without CPU intermediary."""
573
  try:
 
 
 
 
 
 
 
 
 
 
574
  # Extract model metadata
575
- model_info = {
576
- "architecture": model.__class__.__name__,
577
- "processor": processor.__class__.__name__,
578
- "config": model.config.to_dict() if hasattr(model, "config") else {}
579
- }
580
-
 
 
 
 
 
 
 
 
581
  # Store model state in WebSocket storage
582
- self.storage.store_state(f"models/{model_id}", "info", model_info)
 
583
 
584
  # Map weight tensors directly to WebSocket storage
585
- if hasattr(model, "state_dict"):
586
  model_weights = {}
587
 
588
  for name, param in model.state_dict().items():
589
  tensor_id = f"{model_id}/weights/{name}"
590
 
591
  # Store tensor directly in WebSocket storage
592
- self.storage.store_tensor(tensor_id, param.detach().numpy())
 
593
  model_weights[name] = tensor_id
594
 
595
  # Store only WebSocket references
@@ -601,9 +622,11 @@ class AIAccelerator:
601
  else:
602
  # Store the entire model state in WebSocket storage
603
  tensor_id = f"{model_id}/model_state"
604
- self.storage.store_state(f"models/{model_id}", "state", model)
 
605
  self.model_registry[model_id] = tensor_id
606
 
 
607
  self.tokenizer_registry[model_id] = processor
608
  self.model_loaded = True
609
  print(f"Model '{model_id}' loaded into WebSocket storage")
@@ -675,5 +698,3 @@ class AIAccelerator:
675
  except Exception as e:
676
  print(f"[ERROR] WebSocket-based inference failed for idx={idx}: {e}")
677
  return None
678
-
679
-
 
49
  )
50
  self.tensor_cores_initialized = False
51
 
52
+ # Initialize model, tensor, and tokenizer tracking
53
  self.model_registry: Dict[str, Dict[str, Any]] = {} # Track loaded models
54
  self.tensor_registry: Dict[str, Dict[str, Any]] = {} # Track tensor metadata
55
+ self.tokenizer_registry: Dict[str, Any] = {} # Track tokenizers
56
  self.resource_monitor = {
57
  'vram_used': 0,
58
  'active_tensors': 0,
 
572
  def load_model(self, model_id: str, model: Any, processor: Any):
573
  """Loads a model directly into WebSocket storage without CPU intermediary."""
574
  try:
575
+ if model is None and processor is None:
576
+ # Zero-copy mode
577
+ self.model_registry[model_id] = {
578
+ "zero_copy": True,
579
+ "websocket_mapped": True
580
+ }
581
+ self.tokenizer_registry[model_id] = None
582
+ self.model_loaded = True
583
+ return
584
+
585
  # Extract model metadata
586
+ try:
587
+ model_info = {
588
+ "architecture": model.__class__.__name__ if model else "Unknown",
589
+ "processor": processor.__class__.__name__ if processor else "Unknown",
590
+ "config": self._serialize_model_config(model.config) if hasattr(model, "config") else {}
591
+ }
592
+ except Exception as e:
593
+ print(f"Warning: Error serializing model metadata: {e}")
594
+ model_info = {"error": str(e)}
595
+
596
+ # Verify WebSocket connection
597
+ if not self.storage or not self.storage.wait_for_connection():
598
+ raise RuntimeError("WebSocket connection not available")
599
+
600
  # Store model state in WebSocket storage
601
+ if not self.storage.store_state(f"models/{model_id}/info", "info", model_info):
602
+ raise RuntimeError("Failed to store model info")
603
 
604
  # Map weight tensors directly to WebSocket storage
605
+ if model is not None and hasattr(model, "state_dict"):
606
  model_weights = {}
607
 
608
  for name, param in model.state_dict().items():
609
  tensor_id = f"{model_id}/weights/{name}"
610
 
611
  # Store tensor directly in WebSocket storage
612
+ if not self.storage.store_tensor(tensor_id, param.detach().numpy()):
613
+ raise RuntimeError(f"Failed to store tensor {name}")
614
  model_weights[name] = tensor_id
615
 
616
  # Store only WebSocket references
 
622
  else:
623
  # Store the entire model state in WebSocket storage
624
  tensor_id = f"{model_id}/model_state"
625
+ if not self.storage.store_state(f"models/{model_id}/state", "state", model):
626
+ raise RuntimeError("Failed to store model state")
627
  self.model_registry[model_id] = tensor_id
628
 
629
+ # Store tokenizer/processor
630
  self.tokenizer_registry[model_id] = processor
631
  self.model_loaded = True
632
  print(f"Model '{model_id}' loaded into WebSocket storage")
 
698
  except Exception as e:
699
  print(f"[ERROR] WebSocket-based inference failed for idx={idx}: {e}")
700
  return None