Factor Studios commited on
Commit
e8c0748
·
verified ·
1 Parent(s): bd3e2ef

Update gpu_arch.py

Browse files
Files changed (1) hide show
  1. gpu_arch.py +5 -4
gpu_arch.py CHANGED
@@ -295,7 +295,7 @@ class GPUMemoryHierarchy:
295
 
296
 
297
  class Chip:
298
- def __init__(self, chip_id, num_sms=1500, vram_size_gb=16, db_path="gpu_state.db"):
299
  self.chip_id = chip_id
300
  self.db = GPUStateDB(db_path)
301
  # Handle unlimited VRAM case (when vram_size_gb is None)
@@ -303,9 +303,10 @@ class Chip:
303
  self.gpu_mem = GPUMemoryHierarchy(num_sms=num_sms, global_mem_size_bytes=global_mem_size_bytes, chip_id=chip_id, db=self.db)
304
  self.sm_ids = list(range(num_sms))
305
  self.connected_chips = []
306
- self.ai_accelerator = AIAccelerator() # Instantiate AIAccelerator
307
- self.custom_vram = CustomVRAM(self.gpu_mem.global_mem) # Create CustomVRAM instance
308
- self.ai_accelerator.set_vram(self.custom_vram) # Set VRAM for AIAccelerator
 
309
 
310
  def get_sm(self, sm_id):
311
  return StreamingMultiprocessor(sm_id, self.chip_id, self.db)
 
295
 
296
 
297
  class Chip:
298
+ def __init__(self, chip_id, num_sms=1500, vram_size_gb=16, db_path="gpu_state.db", storage=None):
299
  self.chip_id = chip_id
300
  self.db = GPUStateDB(db_path)
301
  # Handle unlimited VRAM case (when vram_size_gb is None)
 
303
  self.gpu_mem = GPUMemoryHierarchy(num_sms=num_sms, global_mem_size_bytes=global_mem_size_bytes, chip_id=chip_id, db=self.db)
304
  self.sm_ids = list(range(num_sms))
305
  self.connected_chips = []
306
+ self.storage = storage # Store shared WebSocket storage
307
+ self.ai_accelerator = AIAccelerator(storage=storage) # Pass shared storage to accelerator
308
+ self.custom_vram = CustomVRAM(self.gpu_mem.global_mem) # Create CustomVRAM instance
309
+ self.ai_accelerator.set_vram(self.custom_vram) # Set VRAM for AIAccelerator
310
 
311
  def get_sm(self, sm_id):
312
  return StreamingMultiprocessor(sm_id, self.chip_id, self.db)