Spaces:
Sleeping
Sleeping
File size: 4,763 Bytes
e9bc512 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
from websocket_storage import WebSocketGPUStorage
from virtual_vram import VirtualVRAM
from streaming_multiprocessor import StreamingMultiprocessor
from typing import Dict, Any, List, Optional
import time
class GPUChip:
    """Simulated GPU chip tying together VRAM, SMs and bookkeeping state.

    The chip owns one ``VirtualVRAM`` and a list of
    ``StreamingMultiprocessor`` objects, all sharing a single storage
    backend.  PCIe, power and memory-controller bookkeeping lives in
    ``self.chip_state`` and is written to the storage backend after every
    mutation via ``store_chip_state``.
    """

    # Simulated power cost (watts) added to an SM for every scheduled warp.
    SM_WATTS_PER_WARP = 10

    # Last stamp handed out by ``_next_stamp_ns``; shadowed per instance on
    # first use so every chip keeps its own sequence.
    _last_stamp_ns = 0

    def __init__(self, chip_id: int, num_sms: int = 108, vram_gb: int = 24, storage=None):
        """Create the chip, its components and its initial persisted state.

        Args:
            chip_id: Identifier used to build the storage key of this chip.
            num_sms: Number of streaming multiprocessors to create.
            vram_gb: Size handed to ``VirtualVRAM``.
            storage: Storage backend shared by all components.  When
                ``None`` a ``WebSocketGPUStorage`` is created and must
                report a connection.

        Raises:
            RuntimeError: If the freshly created storage backend does not
                report a connection.
        """
        self.chip_id = chip_id
        self.storage = storage
        if self.storage is None:
            # WebSocketGPUStorage is already imported at module level.
            self.storage = WebSocketGPUStorage()
            if not self.storage.wait_for_connection():
                raise RuntimeError("Could not connect to GPU storage server")

        # Initialize components with shared storage
        self.vram = VirtualVRAM(vram_gb, storage=self.storage)
        self.sms = [
            StreamingMultiprocessor(i, storage=self.storage)
            for i in range(num_sms)
        ]

        # Initialize chip state
        self.chip_state = {
            "chip_id": chip_id,
            "num_sms": num_sms,
            "vram_gb": vram_gb,
            "pcie_state": {
                "active_transfers": {},
                "bandwidth_usage": 0,
            },
            "power_state": {
                "total_watts": 0,
                "sm_power": [0] * num_sms,
                "vram_power": 0,
            },
            "memory_controller": {
                "active_requests": {},
                "bandwidth_usage": 0,
            },
        }
        self.store_chip_state()

    def _next_stamp_ns(self) -> int:
        """Return a nanosecond stamp that is strictly increasing per chip.

        ``time.time_ns()`` may return the same value for two calls made
        within one clock tick; IDs built directly from it could then
        collide.  Bumping past the previously issued stamp keeps the
        ``<prefix>_<int>`` ID format while guaranteeing uniqueness.
        """
        stamp = max(time.time_ns(), self._last_stamp_ns + 1)
        self._last_stamp_ns = stamp
        return stamp

    def store_chip_state(self):
        """Store chip state in WebSocket storage"""
        self.storage.store_state(f"chip_{self.chip_id}", "state", self.chip_state)

    def allocate_memory(self, size: int, virtual_addr: Optional[str] = None) -> str:
        """Allocate a VRAM block and record the request.

        Args:
            size: Size passed to ``VirtualVRAM.allocate_block``.
            virtual_addr: Optional address to map onto the new block; a
                falsy value (``None`` or ``""``) skips the mapping.

        Returns:
            The block id returned by the VRAM allocator.
        """
        block_id = self.vram.allocate_block(size)
        if virtual_addr:
            self.vram.map_address(virtual_addr, block_id)

        # Update memory controller state
        requests = self.chip_state["memory_controller"]["active_requests"]
        requests[block_id] = {
            "type": "allocation",
            "size": size,
            "timestamp": time.time_ns(),
        }
        self.store_chip_state()
        return block_id

    def transfer_to_device(self, data: bytes, virtual_addr: Optional[str] = None) -> str:
        """Record a simulated PCIe transfer and store ``data`` in VRAM.

        Args:
            data: Payload to store; its length is the allocation size.
            virtual_addr: Optional address forwarded to ``allocate_memory``.

        Returns:
            The id of the block that now holds ``data``.
        """
        # Simulate PCIe transfer
        transfers = self.chip_state["pcie_state"]["active_transfers"]
        transfer_id = f"transfer_{self._next_stamp_ns()}"
        transfers[transfer_id] = {
            "direction": "to_device",
            "size": len(data),
            "timestamp": time.time_ns(),
        }
        self.store_chip_state()

        # Allocate and store in VRAM
        block_id = self.allocate_memory(len(data), virtual_addr)
        self.storage.store_tensor(block_id, data)

        # Update transfer state
        transfers[transfer_id]["completed"] = True
        self.store_chip_state()
        return block_id

    def schedule_compute(self, sm_index: int, warp_state: Dict[str, Any]) -> str:
        """Schedule a warp on the SM at ``sm_index`` and update power state.

        Args:
            sm_index: Index into ``self.sms``.
            warp_state: State handed to ``schedule_warp`` unchanged.

        Returns:
            The id generated for the scheduled warp.

        Raises:
            ValueError: If ``sm_index`` is outside ``0 .. len(self.sms) - 1``.
        """
        if not 0 <= sm_index < len(self.sms):
            raise ValueError(f"Invalid SM index: {sm_index}")

        warp_id = f"warp_{self._next_stamp_ns()}"
        self.sms[sm_index].schedule_warp(warp_id, warp_state)

        # Update power state; the total covers every SM plus the VRAM share.
        power = self.chip_state["power_state"]
        power["sm_power"][sm_index] += self.SM_WATTS_PER_WARP
        power["total_watts"] = sum(power["sm_power"]) + power["vram_power"]
        self.store_chip_state()
        return warp_id

    def get_stats(self) -> Dict[str, Any]:
        """Get comprehensive chip statistics.

        Returns:
            A dict with the chip id, the VRAM and per-SM stats reported by
            the components, and counts / usage figures taken from
            ``self.chip_state``.
        """
        pcie = self.chip_state["pcie_state"]
        power = self.chip_state["power_state"]
        controller = self.chip_state["memory_controller"]
        return {
            "chip_id": self.chip_id,
            "vram": self.vram.get_stats(),
            "sms": [sm.get_stats() for sm in self.sms],
            "pcie": {
                "active_transfers": len(pcie["active_transfers"]),
                "bandwidth_usage": pcie["bandwidth_usage"],
            },
            "power": {
                "total_watts": power["total_watts"],
                "vram_watts": power["vram_power"],
            },
            "memory_controller": {
                "active_requests": len(controller["active_requests"]),
                "bandwidth_usage": controller["bandwidth_usage"],
            },
        }
|