# INV/gpu_chip.py
# Uploaded by Fred808 ("Upload 256 files"), commit 7a0c684.
from http_storage import LocalStorage
from virtual_vram import VirtualVRAM
from streaming_multiprocessor import StreamingMultiprocessor
from typing import Dict, Any, List, Optional, Tuple
import time
from config import get_db_url
class OpticalInterconnect:
    """Point-to-point optical link model used to estimate chip-to-chip copy time.

    Attributes:
        bandwidth_tbps: Link throughput. ``transfer_time`` scales it by 1e12 and
            divides a byte count by the result, i.e. it is treated as
            terabytes (10**12 bytes) per second despite the "bps" suffix.
        latency_ns: Fixed per-transfer latency in nanoseconds.
    """

    def __init__(self, bandwidth_tbps=800, latency_ns=1):
        self.latency_ns = latency_ns
        self.bandwidth_tbps = bandwidth_tbps

    def transfer_time(self, data_size_bytes: int) -> float:
        """Return seconds to move ``data_size_bytes``: fixed latency plus size / bandwidth."""
        fixed_latency_s = self.latency_ns * 1e-9
        bytes_per_second = self.bandwidth_tbps * 1e12
        return fixed_latency_s + data_size_bytes / bytes_per_second
class GPUChip:
    """Simulated GPU chip built from a VirtualVRAM and a list of StreamingMultiprocessors.

    All components share one storage backend. The chip keeps its own
    bookkeeping in ``self.chip_state`` and re-persists it through
    ``store_chip_state()`` after every mutation.
    """

    # Amount added to an SM's entry in power_state["sm_power"] each time a warp
    # is scheduled on it (simulated); total_watts is the sum of those entries.
    SM_POWER_STEP = 10

    def __init__(self, chip_id: int, num_sms: int = 108, vram_gb: int = 214, storage=None):
        """Build the chip's components and persist its initial state.

        Args:
            chip_id: Identifier; also namespaces the stored state as ``chip_<id>``.
            num_sms: Number of StreamingMultiprocessor instances to create.
            vram_gb: VRAM size forwarded to VirtualVRAM.
            storage: Shared storage backend. When None, a LocalStorage is
                created from ``get_db_url()``.

        Raises:
            RuntimeError: If a newly created LocalStorage reports it is not connected.
        """
        self.chip_id = chip_id
        self.storage = storage
        if self.storage is None:
            # LocalStorage is already imported at module level.
            self.storage = LocalStorage(db_url=get_db_url())
            if not self.storage.is_connected():
                raise RuntimeError("Could not connect to local storage")

        # Components share the chip's storage backend.
        self.vram = VirtualVRAM(vram_gb, storage=self.storage)
        self.sms = [StreamingMultiprocessor(i, storage=self.storage) for i in range(num_sms)]

        # Peer links registered via connect_chip(). Initialized here so that
        # transfer_data() raises ValueError (not AttributeError) before any link exists.
        self.connected_chips: List[Tuple['GPUChip', 'OpticalInterconnect']] = []

        self.chip_state = {
            "chip_id": chip_id,
            "num_sms": num_sms,
            "vram_gb": vram_gb,
            "pcie_state": {
                "active_transfers": {},
                "bandwidth_usage": 0,
            },
            "power_state": {
                "total_watts": 0,
                "sm_power": [0] * num_sms,
                "vram_power": 0,
            },
            "memory_controller": {
                "active_requests": {},
                "bandwidth_usage": 0,
            },
        }
        self.store_chip_state()

    def store_chip_state(self):
        """Persist ``chip_state`` under component ``chip_<id>``, key ``"state"``."""
        self.storage.store_state(f"chip_{self.chip_id}", "state", self.chip_state)

    def _new_transfer_id(self) -> str:
        """Return a ``transfer_<ns>`` key not already present in active_transfers."""
        active = self.chip_state["pcie_state"]["active_transfers"]
        base = f"transfer_{time.time_ns()}"
        transfer_id, suffix = base, 1
        # Two transfers in the same clock tick must not overwrite each other.
        while transfer_id in active:
            transfer_id = f"{base}_{suffix}"
            suffix += 1
        return transfer_id

    def connect_chip(self, other_chip: 'GPUChip', interconnect: 'OpticalInterconnect') -> None:
        """Register a link from this chip to ``other_chip`` and persist it.

        Only this chip's bookkeeping is updated; ``other_chip`` is not modified.
        """
        self.connected_chips.append((other_chip, interconnect))
        self.chip_state.setdefault("connected_chips", {})[other_chip.chip_id] = {
            'bandwidth_tbps': interconnect.bandwidth_tbps,
            'latency_ns': interconnect.latency_ns,
            'active': True,
        }
        self.store_chip_state()

    def transfer_data(self, target_chip: 'GPUChip', data_size: int) -> float:
        """Record a chip-to-chip transfer and return its estimated time in seconds.

        The first registered link whose peer has ``target_chip.chip_id`` is used.

        Raises:
            ValueError: If no link to ``target_chip`` has been registered.
        """
        for chip, interconnect in self.connected_chips:
            if chip.chip_id != target_chip.chip_id:
                continue
            transfer_time = interconnect.transfer_time(data_size)
            # Same key scheme as transfer_to_device (was str(time.time())).
            transfer_id = self._new_transfer_id()
            self.chip_state['pcie_state']['active_transfers'][transfer_id] = {
                'target_chip': target_chip.chip_id,
                'size': data_size,
                'estimated_time': transfer_time,
            }
            self.store_chip_state()
            return transfer_time
        raise ValueError(f"No connection found to chip {target_chip.chip_id}")

    def allocate_memory(self, size: int, virtual_addr: Optional[str] = None) -> str:
        """Allocate a VRAM block and record the request in the memory controller state.

        Args:
            size: Size passed to ``VirtualVRAM.allocate_block`` (presumably bytes;
                transfer_to_device passes ``len(data)``).
            virtual_addr: If truthy, mapped to the new block via ``map_address``.

        Returns:
            The block id returned by the VRAM.
        """
        block_id = self.vram.allocate_block(size)
        if virtual_addr:
            self.vram.map_address(virtual_addr, block_id)
        self.chip_state["memory_controller"]["active_requests"][block_id] = {
            "type": "allocation",
            "size": size,
            "timestamp": time.time_ns(),
        }
        self.store_chip_state()
        return block_id

    def transfer_to_device(self, data: bytes, virtual_addr: Optional[str] = None) -> str:
        """Simulate a host-to-device PCIe copy of ``data`` into newly allocated VRAM.

        A transfer record is written first. Then a block of ``len(data)`` is
        allocated (optionally mapped to ``virtual_addr``) and the bytes are stored
        via ``storage.store_tensor``. Finally the record is marked ``completed``.

        Returns:
            The VRAM block id holding the data.
        """
        transfer_id = self._new_transfer_id()
        self.chip_state["pcie_state"]["active_transfers"][transfer_id] = {
            "direction": "to_device",
            "size": len(data),
            "timestamp": time.time_ns(),
        }
        self.store_chip_state()

        block_id = self.allocate_memory(len(data), virtual_addr)
        self.storage.store_tensor(block_id, data)

        self.chip_state["pcie_state"]["active_transfers"][transfer_id]["completed"] = True
        self.store_chip_state()
        return block_id

    def schedule_compute(self, sm_index: int, warp_state: Dict[str, Any]) -> str:
        """Schedule a new warp on SM ``sm_index`` and bump its simulated power.

        Returns:
            The generated ``warp_<ns>`` id.

        Raises:
            ValueError: If ``sm_index`` is outside ``[0, len(self.sms))``.
        """
        if not 0 <= sm_index < len(self.sms):
            raise ValueError(f"Invalid SM index: {sm_index}")
        warp_id = f"warp_{time.time_ns()}"
        self.sms[sm_index].schedule_warp(warp_id, warp_state)
        power = self.chip_state["power_state"]
        power["sm_power"][sm_index] += self.SM_POWER_STEP
        power["total_watts"] = sum(power["sm_power"])
        self.store_chip_state()
        return warp_id

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of VRAM/SM stats plus counts and usage from chip_state.

        NOTE(review): this class never removes entries from active_transfers or
        active_requests, so those counts include every recorded operation.
        """
        pcie = self.chip_state["pcie_state"]
        power = self.chip_state["power_state"]
        memctl = self.chip_state["memory_controller"]
        return {
            "chip_id": self.chip_id,
            "vram": self.vram.get_stats(),
            "sms": [sm.get_stats() for sm in self.sms],
            "pcie": {
                "active_transfers": len(pcie["active_transfers"]),
                "bandwidth_usage": pcie["bandwidth_usage"],
            },
            "power": {
                "total_watts": power["total_watts"],
                "vram_watts": power["vram_power"],
            },
            "memory_controller": {
                "active_requests": len(memctl["active_requests"]),
                "bandwidth_usage": memctl["bandwidth_usage"],
            },
        }