from websocket_storage import WebSocketGPUStorage
from virtual_vram import VirtualVRAM
from streaming_multiprocessor import StreamingMultiprocessor
from typing import Dict, Any, List, Optional
import time


class GPUChip:
    """Simulated GPU chip built from a VirtualVRAM and StreamingMultiprocessors.

    All components share one storage backend. The chip's own bookkeeping
    (PCIe transfers, power figures, memory-controller requests) lives in
    ``self.chip_state`` and is re-persisted to storage after every mutation.
    """

    # Last stamp handed out by _next_id_ns(). This is a class-level default so
    # that an instance created without __init__ still has a starting value.
    _last_id_ns: int = 0

    def __init__(self, chip_id: int, num_sms: int = 108, vram_gb: int = 24, storage=None):
        """Build the chip's components and persist its initial state.

        Args:
            chip_id: Identifier used in the storage key ``chip_<chip_id>``.
            num_sms: Number of StreamingMultiprocessor instances to create.
            vram_gb: Capacity value passed to VirtualVRAM.
            storage: Storage backend shared by every component. When None, a
                new WebSocketGPUStorage is created and must connect.

        Raises:
            RuntimeError: If a newly created storage backend fails to connect.
        """
        self.chip_id = chip_id
        if storage is None:
            storage = WebSocketGPUStorage()
            if not storage.wait_for_connection():
                raise RuntimeError("Could not connect to GPU storage server")
        self.storage = storage

        # Every component writes through the same storage backend.
        self.vram = VirtualVRAM(vram_gb, storage=self.storage)
        self.sms = [StreamingMultiprocessor(i, storage=self.storage) for i in range(num_sms)]

        self.chip_state = {
            "chip_id": chip_id,
            "num_sms": num_sms,
            "vram_gb": vram_gb,
            "pcie_state": {
                "active_transfers": {},
                "bandwidth_usage": 0,
            },
            "power_state": {
                "total_watts": 0,
                # One entry per SM, indexed the same way as self.sms.
                "sm_power": [0] * num_sms,
                "vram_power": 0,
            },
            "memory_controller": {
                "active_requests": {},
                "bandwidth_usage": 0,
            },
        }
        self.store_chip_state()

    def _next_id_ns(self) -> int:
        """Return a nanosecond stamp strictly greater than any this chip issued before.

        ``time.time_ns()`` can return the same value for back-to-back calls
        when the platform clock is coarse, and the wall clock can step
        backwards. Either case would give two transfers or warps the same ID.
        Clamping to ``last + 1`` keeps IDs unique per chip.
        """
        stamp = max(time.time_ns(), self._last_id_ns + 1)
        self._last_id_ns = stamp
        return stamp

    def store_chip_state(self):
        """Persist ``chip_state`` to storage under key ``chip_<chip_id>``, sub-key ``"state"``."""
        self.storage.store_state(f"chip_{self.chip_id}", "state", self.chip_state)

    def allocate_memory(self, size: int, virtual_addr: Optional[str] = None) -> str:
        """Allocate a VRAM block and optionally map a virtual address to it.

        Args:
            size: Size passed to ``VirtualVRAM.allocate_block``.
            virtual_addr: If given, mapped to the new block via ``map_address``.

        Returns:
            The block ID returned by ``VirtualVRAM.allocate_block``.
        """
        block_id = self.vram.allocate_block(size)
        if virtual_addr:
            self.vram.map_address(virtual_addr, block_id)

        # Record the allocation in the memory controller's request log.
        self.chip_state["memory_controller"]["active_requests"][block_id] = {
            "type": "allocation",
            "size": size,
            "timestamp": time.time_ns(),
        }
        self.store_chip_state()
        return block_id

    def transfer_to_device(self, data: bytes, virtual_addr: Optional[str] = None) -> str:
        """Simulate a host-to-device PCIe copy of ``data`` into newly allocated VRAM.

        A transfer record is added to ``pcie_state.active_transfers`` before the
        copy and marked ``completed`` once the data is stored. If allocation or
        storing raises, the record is marked ``failed`` and persisted, and the
        exception propagates.

        Args:
            data: Bytes to place on the device.
            virtual_addr: Optional virtual address to map to the new block.

        Returns:
            The ID of the VRAM block holding ``data``.
        """
        transfer_id = f"transfer_{self._next_id_ns()}"
        transfer = {
            "direction": "to_device",
            "size": len(data),
            "timestamp": time.time_ns(),
        }
        self.chip_state["pcie_state"]["active_transfers"][transfer_id] = transfer
        self.store_chip_state()

        try:
            block_id = self.allocate_memory(len(data), virtual_addr)
            self.storage.store_tensor(block_id, data)
        except Exception:
            # Persist the failure so this record is not reported as in flight forever.
            transfer["failed"] = True
            self.store_chip_state()
            raise

        transfer["completed"] = True
        self.store_chip_state()
        return block_id

    def schedule_compute(self, sm_index: int, warp_state: Dict[str, Any]) -> str:
        """Schedule a warp on SM ``sm_index`` and bump that SM's simulated power draw.

        Args:
            sm_index: Index into ``self.sms``; must satisfy ``0 <= sm_index < len(self.sms)``.
            warp_state: Passed unchanged to ``StreamingMultiprocessor.schedule_warp``.

        Returns:
            The generated warp ID (``warp_<int>``, unique per chip).

        Raises:
            ValueError: If ``sm_index`` is out of range.
        """
        if not 0 <= sm_index < len(self.sms):
            raise ValueError(f"Invalid SM index: {sm_index}")

        warp_id = f"warp_{self._next_id_ns()}"
        self.sms[sm_index].schedule_warp(warp_id, warp_state)

        power = self.chip_state["power_state"]
        power["sm_power"][sm_index] += 10  # fixed simulated increase per scheduled warp
        power["total_watts"] = sum(power["sm_power"])
        self.store_chip_state()
        return warp_id

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of VRAM, per-SM, PCIe, power and memory-controller stats.

        ``pcie.active_transfers`` counts only transfers that are neither
        completed nor failed. ``memory_controller.active_requests`` counts
        every recorded request, because those entries carry no completion marker.
        """
        pcie = self.chip_state["pcie_state"]
        power = self.chip_state["power_state"]
        mem_ctrl = self.chip_state["memory_controller"]

        in_flight = sum(
            1
            for record in pcie["active_transfers"].values()
            if not (record.get("completed") or record.get("failed"))
        )

        return {
            "chip_id": self.chip_id,
            "vram": self.vram.get_stats(),
            "sms": [sm.get_stats() for sm in self.sms],
            "pcie": {
                "active_transfers": in_flight,
                "bandwidth_usage": pcie["bandwidth_usage"],
            },
            "power": {
                "total_watts": power["total_watts"],
                "vram_watts": power["vram_power"],
            },
            "memory_controller": {
                "active_requests": len(mem_ctrl["active_requests"]),
                "bandwidth_usage": mem_ctrl["bandwidth_usage"],
            },
        }