Factor Studios committed on
Commit
a07258a
·
verified ·
1 Parent(s): 97c652e

Update ai.py

Browse files
Files changed (1) hide show
  1. ai.py +81 -56
ai.py CHANGED
@@ -49,6 +49,75 @@ class AIAccelerator:
49
  )
50
  self.tensor_cores_initialized = False
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def initialize_tensor_cores(self):
53
  """Initialize tensor cores and verify they're ready for computation"""
54
  if self.tensor_cores_initialized:
@@ -499,41 +568,22 @@ class AIAccelerator:
499
  total_ops = total_params * batch_size * ops_per_param
500
  return (total_ops / inference_time) / 1e12 # Convert to TFLOPS
501
 
502
- def load_model(self, model_id: str, model: Any = None, processor: Any = None, model_config: Dict[str, Any] = None):
503
  """Loads a model directly into WebSocket storage without CPU intermediary."""
504
  try:
505
- # Use provided config or create default
506
- config = model_config or {}
507
-
508
- if model is not None:
509
- # Extract model metadata if model is provided
510
- model_info = {
511
- "architecture": model.__class__.__name__,
512
- "processor": processor.__class__.__name__ if processor else "None",
513
- "config": model.config.to_dict() if hasattr(model, "config") else {},
514
- "model_config": {
515
- k: str(v) if not isinstance(v, (bool, int, float, str, list, dict)) else v
516
- for k, v in config.items()
517
- } # Ensure config is JSON serializable
518
- }
519
- else:
520
- # Use provided config for zero-copy mode
521
- model_info = {
522
- "architecture": "ZeroCopy",
523
- "processor": "None",
524
- "config": {},
525
- "model_config": {
526
- k: str(v) if not isinstance(v, (bool, int, float, str, list, dict)) else v
527
- for k, v in config.items()
528
- } # Ensure config is JSON serializable
529
- }
530
 
531
  # Store model state in WebSocket storage
532
  self.storage.store_state(f"models/{model_id}", "info", model_info)
533
 
534
- # Map weight tensors directly to WebSocket storage if model is provided
535
- model_weights = {}
536
- if model is not None and hasattr(model, "state_dict"):
537
 
538
  for name, param in model.state_dict().items():
539
  tensor_id = f"{model_id}/weights/{name}"
@@ -554,37 +604,12 @@ class AIAccelerator:
554
  self.storage.store_state(f"models/{model_id}", "state", model)
555
  self.model_registry[model_id] = tensor_id
556
 
557
- # Store processor if provided
558
- if processor is not None:
559
- self.tokenizer_registry[model_id] = processor
560
-
561
  self.model_loaded = True
562
  print(f"Model '{model_id}' loaded into WebSocket storage")
563
-
564
- # Additional setup for zero-copy mode
565
- if model_config and model_config.get("zero_copy"):
566
- # Register empty tensors for zero-copy mode
567
- tensor_id = f"{model_id}/zero_copy"
568
- self.model_registry[model_id] = {
569
- "mode": "zero_copy",
570
- "config": model_config,
571
- "tensor_id": tensor_id
572
- }
573
-
574
- return True
575
-
576
  except Exception as e:
577
  print(f"Error loading model into WebSocket storage: {str(e)}")
578
- if "is not JSON serializable" in str(e):
579
- print("Attempting to serialize config differently...")
580
- try:
581
- # Try again with string conversion for non-serializable types
582
- if model_config:
583
- model_config = {k: str(v) for k, v in model_config.items()}
584
- return self.load_model(model_id, model, processor, model_config)
585
- except Exception as e2:
586
- print(f"Second attempt also failed: {e2}")
587
- return False
588
 
589
  def has_model(self, model_id: str) -> bool:
590
  """Checks if a model is loaded in the accelerator's registry."""
 
49
  )
50
  self.tensor_cores_initialized = False
51
 
52
+ # Initialize model and tensor tracking
53
+ self.model_registry: Dict[str, Dict[str, Any]] = {} # Track loaded models
54
+ self.tensor_registry: Dict[str, Dict[str, Any]] = {} # Track tensor metadata
55
+ self.resource_monitor = {
56
+ 'vram_used': 0,
57
+ 'active_tensors': 0,
58
+ 'loaded_models': set()
59
+ }
60
+
61
+ def _serialize_model_config(self, config: Any) -> dict:
62
+ """Convert model config to a serializable format."""
63
+ if hasattr(config, '__dict__'):
64
+ # Convert object attributes to dict
65
+ config_dict = {}
66
+ for key, value in config.__dict__.items():
67
+ if isinstance(value, (int, float, str, bool, type(None))):
68
+ config_dict[key] = value
69
+ elif isinstance(value, (list, tuple)):
70
+ config_dict[key] = [self._serialize_model_config(item) for item in value]
71
+ elif isinstance(value, dict):
72
+ config_dict[key] = {k: self._serialize_model_config(v) for k, v in value.items()}
73
+ elif hasattr(value, '__dict__'):
74
+ config_dict[key] = self._serialize_model_config(value)
75
+ else:
76
+ config_dict[key] = str(value) # Fallback to string representation
77
+ return config_dict
78
+ elif isinstance(config, (list, tuple)):
79
+ return [self._serialize_model_config(item) for item in config]
80
+ elif isinstance(config, dict):
81
+ return {k: self._serialize_model_config(v) for k, v in config.items()}
82
+ elif isinstance(config, (int, float, str, bool, type(None))):
83
+ return config
84
+ else:
85
+ return str(config) # Fallback to string representation
86
+
87
def store_model_state(self, model_name: str, model_info: Dict[str, Any]) -> bool:
    """Persist a model's metadata to WebSocket storage.

    The info dict is first reduced to JSON-serializable primitives, then
    recorded in the in-memory registry and, when storage is available,
    written out as both an ``info`` and a ``state`` entry.

    Returns True only when both storage writes succeed; False otherwise
    (including when no storage backend is configured or on any error).
    """
    try:
        # Strip anything the storage layer couldn't serialize.
        info = self._serialize_model_config(model_info)

        # NOTE: the in-memory registry is updated even when persistence
        # is unavailable or fails below — mirrors the original behavior.
        self.model_registry[model_name] = info

        if not self.storage:
            return False

        stored_info = self.storage.store_state(
            "models",
            f"{model_name}/info",
            info,
        )
        stored_state = self.storage.store_state(
            "models",
            f"{model_name}/state",
            {"loaded": True, "timestamp": time.time()},
        )

        if not (stored_info and stored_state):
            return False

        self.resource_monitor['loaded_models'].add(model_name)
        return True
    except Exception as e:
        print(f"Error storing model state: {str(e)}")
        return False
120
+
121
  def initialize_tensor_cores(self):
122
  """Initialize tensor cores and verify they're ready for computation"""
123
  if self.tensor_cores_initialized:
 
568
  total_ops = total_params * batch_size * ops_per_param
569
  return (total_ops / inference_time) / 1e12 # Convert to TFLOPS
570
 
571
+ def load_model(self, model_id: str, model: Any, processor: Any):
572
  """Loads a model directly into WebSocket storage without CPU intermediary."""
573
  try:
574
+ # Extract model metadata
575
+ model_info = {
576
+ "architecture": model.__class__.__name__,
577
+ "processor": processor.__class__.__name__,
578
+ "config": model.config.to_dict() if hasattr(model, "config") else {}
579
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
580
 
581
  # Store model state in WebSocket storage
582
  self.storage.store_state(f"models/{model_id}", "info", model_info)
583
 
584
+ # Map weight tensors directly to WebSocket storage
585
+ if hasattr(model, "state_dict"):
586
+ model_weights = {}
587
 
588
  for name, param in model.state_dict().items():
589
  tensor_id = f"{model_id}/weights/{name}"
 
604
  self.storage.store_state(f"models/{model_id}", "state", model)
605
  self.model_registry[model_id] = tensor_id
606
 
607
+ self.tokenizer_registry[model_id] = processor
 
 
 
608
  self.model_loaded = True
609
  print(f"Model '{model_id}' loaded into WebSocket storage")
 
 
 
 
 
 
 
 
 
 
 
 
 
610
  except Exception as e:
611
  print(f"Error loading model into WebSocket storage: {str(e)}")
612
+ raise
 
 
 
 
 
 
 
 
 
613
 
614
  def has_model(self, model_id: str) -> bool:
615
  """Checks if a model is loaded in the accelerator's registry."""