Spaces:
Runtime error
Runtime error
| from websocket_storage import WebSocketGPUStorage | |
| from virtual_vram import VirtualVRAM | |
| from streaming_multiprocessor import StreamingMultiprocessor | |
| from typing import Dict, Any, List, Optional | |
| import time | |
class GPUChip:
    """Simulated GPU chip backed by a shared storage object.

    A chip owns one ``VirtualVRAM`` instance and a list of
    ``StreamingMultiprocessor`` instances, all constructed with the same
    ``storage`` object.  Bookkeeping (PCIe transfers, power figures and
    memory-controller requests) lives in the ``chip_state`` dict, which is
    written to storage under the key ``chip_<chip_id>`` after every change.

    Note: entries added to ``active_transfers`` and ``active_requests`` are
    never removed by this class, so both dicts act as a running log rather
    than a list of strictly in-flight operations.
    """

    def __init__(self, chip_id: int, num_sms: int = 108, vram_gb: int = 24, storage=None):
        """Create the chip, its VRAM and its SMs, then persist the initial state.

        Args:
            chip_id: Identifier used in the chip state and in the storage key.
            num_sms: Number of streaming multiprocessors to create.
            vram_gb: Size passed to ``VirtualVRAM``.
            storage: Storage backend shared by all components.  When ``None``
                a new ``WebSocketGPUStorage`` is created and awaited.

        Raises:
            RuntimeError: If a newly created storage backend reports that it
                could not connect.
        """
        self.chip_id = chip_id
        # Highest nanosecond value handed out by _next_unique_ns(); used to
        # keep generated transfer/warp ids distinct on this chip.
        self._last_id_ns = 0

        self.storage = storage
        if self.storage is None:
            # WebSocketGPUStorage is imported at module level.
            self.storage = WebSocketGPUStorage()
            if not self.storage.wait_for_connection():
                raise RuntimeError("Could not connect to GPU storage server")

        # Initialize components with shared storage
        self.vram = VirtualVRAM(vram_gb, storage=self.storage)
        self.sms = [
            StreamingMultiprocessor(i, storage=self.storage) for i in range(num_sms)
        ]

        # Initialize chip state
        self.chip_state = {
            "chip_id": chip_id,
            "num_sms": num_sms,
            "vram_gb": vram_gb,
            "pcie_state": {
                "active_transfers": {},
                "bandwidth_usage": 0,
            },
            "power_state": {
                "total_watts": 0,
                "sm_power": [0] * num_sms,
                "vram_power": 0,
            },
            "memory_controller": {
                "active_requests": {},
                "bandwidth_usage": 0,
            },
        }
        self.store_chip_state()

    def _next_unique_ns(self) -> int:
        """Return a nanosecond stamp strictly greater than any returned before.

        ``time.time_ns()`` alone can repeat (coarse clock resolution) or step
        backwards (wall-clock adjustment), which would make two ids built
        from it collide.  When that happens the previous value plus one is
        used instead, so ids stay unique per chip while keeping their
        ``<prefix>_<nanoseconds>`` shape.
        """
        now = time.time_ns()
        # getattr keeps this working for instances that skipped __init__
        # (e.g. a subclass that does not call super().__init__()).
        last = getattr(self, "_last_id_ns", 0)
        if now <= last:
            now = last + 1
        self._last_id_ns = now
        return now

    def store_chip_state(self):
        """Store chip state in WebSocket storage"""
        self.storage.store_state(f"chip_{self.chip_id}", "state", self.chip_state)

    def allocate_memory(self, size: int, virtual_addr: Optional[str] = None) -> str:
        """Allocate memory through VRAM and record the request.

        Args:
            size: Size passed to ``VirtualVRAM.allocate_block``.
            virtual_addr: Optional address to map onto the new block.  Any
                falsy value (``None`` or an empty string) skips the mapping.

        Returns:
            The block id returned by ``VirtualVRAM.allocate_block``.
        """
        block_id = self.vram.allocate_block(size)
        if virtual_addr:
            self.vram.map_address(virtual_addr, block_id)

        # Update memory controller state
        self.chip_state["memory_controller"]["active_requests"][block_id] = {
            "type": "allocation",
            "size": size,
            "timestamp": time.time_ns(),
        }
        self.store_chip_state()
        return block_id

    def transfer_to_device(self, data: bytes, virtual_addr: Optional[str] = None) -> str:
        """Transfer data to device through PCIe.

        Records the transfer, allocates a block of ``len(data)``, stores the
        payload under the block id, then marks the transfer as completed.
        The chip state is persisted both before and after the copy.

        Args:
            data: Payload to store.
            virtual_addr: Optional address forwarded to ``allocate_memory``.

        Returns:
            The id of the block that now holds ``data``.
        """
        transfers = self.chip_state["pcie_state"]["active_transfers"]

        # Simulate PCIe transfer; the id must not clash with an earlier
        # transfer, otherwise that transfer's record would be overwritten.
        transfer_id = f"transfer_{self._next_unique_ns()}"
        transfers[transfer_id] = {
            "direction": "to_device",
            "size": len(data),
            "timestamp": time.time_ns(),
        }
        self.store_chip_state()

        # Allocate and store in VRAM
        block_id = self.allocate_memory(len(data), virtual_addr)
        self.storage.store_tensor(block_id, data)

        # Update transfer state
        transfers[transfer_id]["completed"] = True
        self.store_chip_state()
        return block_id

    def schedule_compute(self, sm_index: int, warp_state: Dict[str, Any]) -> str:
        """Schedule computation on an SM.

        Args:
            sm_index: Index into the chip's SM list.
            warp_state: State handed to the SM's ``schedule_warp``.

        Returns:
            The generated warp id.

        Raises:
            ValueError: If ``sm_index`` is outside ``0 .. len(self.sms) - 1``.
        """
        if not 0 <= sm_index < len(self.sms):
            raise ValueError(f"Invalid SM index: {sm_index}")

        warp_id = f"warp_{self._next_unique_ns()}"
        self.sms[sm_index].schedule_warp(warp_id, warp_state)

        # Update power state
        power = self.chip_state["power_state"]
        power["sm_power"][sm_index] += 10  # Simulate power increase
        power["total_watts"] = sum(power["sm_power"])
        self.store_chip_state()
        return warp_id

    def get_stats(self) -> Dict[str, Any]:
        """Get comprehensive chip statistics.

        Returns:
            A dict with the chip id, the VRAM stats, one stats entry per SM,
            and summaries of the PCIe, power and memory-controller state.
            The ``active_transfers`` / ``active_requests`` figures are the
            number of recorded entries, including completed transfers.
        """
        pcie = self.chip_state["pcie_state"]
        power = self.chip_state["power_state"]
        controller = self.chip_state["memory_controller"]
        return {
            "chip_id": self.chip_id,
            "vram": self.vram.get_stats(),
            "sms": [sm.get_stats() for sm in self.sms],
            "pcie": {
                "active_transfers": len(pcie["active_transfers"]),
                "bandwidth_usage": pcie["bandwidth_usage"],
            },
            "power": {
                "total_watts": power["total_watts"],
                "vram_watts": power["vram_power"],
            },
            "memory_controller": {
                "active_requests": len(controller["active_requests"]),
                "bandwidth_usage": controller["bandwidth_usage"],
            },
        }