|
|
from http_storage import LocalStorage
|
|
|
from virtual_vram import VirtualVRAM
|
|
|
from streaming_multiprocessor import StreamingMultiprocessor
|
|
|
from typing import Dict, Any, List, Optional, Tuple
|
|
|
import time
|
|
|
from config import get_db_url
|
|
|
|
|
|
class OpticalInterconnect:
    """Point-to-point link model described by a bandwidth and a fixed latency."""

    def __init__(self, bandwidth_tbps=800, latency_ns=1):
        """Record the link parameters.

        Args:
            bandwidth_tbps: Link bandwidth; multiplied by 1e12 when a transfer
                time is computed.
            latency_ns: Fixed per-transfer latency in nanoseconds.
        """
        self.latency_ns = latency_ns
        self.bandwidth_tbps = bandwidth_tbps

    def transfer_time(self, data_size_bytes: int) -> float:
        """Return the seconds needed to move ``data_size_bytes`` over the link.

        The result is the fixed latency plus the payload size divided by the
        link bandwidth.
        """
        # NOTE(review): "tbps" conventionally means terabits/s, yet the value
        # is used as terabytes/s here (no division by 8) -- confirm the unit.
        bytes_per_second = self.bandwidth_tbps * 1e12
        fixed_delay_s = self.latency_ns * 1e-9
        payload_time_s = data_size_bytes / bytes_per_second
        return fixed_delay_s + payload_time_s
|
|
|
|
|
|
class GPUChip:
    """Simulated GPU chip: a VRAM model, a list of SMs and bookkeeping state.

    All bookkeeping lives in the ``chip_state`` dict. Every method that
    mutates it finishes by calling ``store_chip_state()``, which writes the
    whole dict to the storage backend under ``chip_<chip_id>`` / ``"state"``.
    """

    def __init__(self, chip_id: int, num_sms: int = 108, vram_gb: int = 214, storage: Optional[Any] = None):
        """Build the chip and persist its initial state.

        Args:
            chip_id: Identifier used in the storage key and in link lookups.
            num_sms: Number of ``StreamingMultiprocessor`` objects to create.
            vram_gb: Size handed to ``VirtualVRAM``.
            storage: Storage backend shared with VRAM and the SMs. When
                ``None``, a ``LocalStorage`` is created from ``get_db_url()``.

        Raises:
            RuntimeError: If the default storage reports it is not connected.
        """
        self.chip_id = chip_id
        self.storage = storage
        if self.storage is None:
            from http_storage import LocalStorage

            self.storage = LocalStorage(db_url=get_db_url())
            if not self.storage.is_connected():
                raise RuntimeError("Could not connect to local storage")

        self.vram = VirtualVRAM(vram_gb, storage=self.storage)
        self.sms = [StreamingMultiprocessor(i, storage=self.storage) for i in range(num_sms)]

        # (chip, interconnect) pairs. Created up front so transfer_data() on a
        # chip with no links raises ValueError rather than AttributeError.
        self.connected_chips = []

        self.chip_state = {
            "chip_id": chip_id,
            "num_sms": num_sms,
            "vram_gb": vram_gb,
            "pcie_state": {
                "active_transfers": {},
                "bandwidth_usage": 0
            },
            "power_state": {
                "total_watts": 0,
                "sm_power": [0] * num_sms,
                "vram_power": 0
            },
            "memory_controller": {
                "active_requests": {},
                "bandwidth_usage": 0
            }
        }
        self.store_chip_state()

    def store_chip_state(self):
        """Write the current ``chip_state`` dict to the storage backend."""
        self.storage.store_state(f"chip_{self.chip_id}", "state", self.chip_state)

    def connect_chip(self, other_chip: 'GPUChip', interconnect: 'OpticalInterconnect') -> None:
        """Register a link to ``other_chip`` and persist it.

        Connecting to a chip that is already linked replaces the previous
        interconnect, so the pair used by ``transfer_data`` always matches the
        entry recorded in ``chip_state['connected_chips']``.

        Args:
            other_chip: The peer chip; only its ``chip_id`` is read here.
            interconnect: Link object providing ``bandwidth_tbps`` and
                ``latency_ns``.
        """
        # Tolerate instances whose __init__ was bypassed or overridden.
        if not hasattr(self, 'connected_chips'):
            self.connected_chips = []
        # The state key is still created lazily on the first connection.
        link_states = self.chip_state.setdefault('connected_chips', {})

        new_link = (other_chip, interconnect)
        for index, (chip, _) in enumerate(self.connected_chips):
            if chip.chip_id == other_chip.chip_id:
                # Re-connecting: drop the stale interconnect instead of
                # appending a duplicate that would never be reached.
                self.connected_chips[index] = new_link
                break
        else:
            self.connected_chips.append(new_link)

        link_states[other_chip.chip_id] = {
            'bandwidth_tbps': interconnect.bandwidth_tbps,
            'latency_ns': interconnect.latency_ns,
            'active': True
        }
        self.store_chip_state()

    def transfer_data(self, target_chip: 'GPUChip', data_size: int) -> float:
        """Record a transfer to ``target_chip`` and return its estimated time.

        Args:
            target_chip: Destination chip; matched by ``chip_id``.
            data_size: Payload size passed to the link's ``transfer_time``.

        Returns:
            The value returned by the interconnect's ``transfer_time``.

        Raises:
            ValueError: If no link to ``target_chip`` has been registered,
                including when no chip has been connected at all.
        """
        for chip, interconnect in getattr(self, 'connected_chips', ()):
            if chip.chip_id != target_chip.chip_id:
                continue
            transfer_time = interconnect.transfer_time(data_size)
            # Entries are keyed by the wall-clock time of the request.
            self.chip_state['pcie_state']['active_transfers'][str(time.time())] = {
                'target_chip': target_chip.chip_id,
                'size': data_size,
                'estimated_time': transfer_time
            }
            self.store_chip_state()
            return transfer_time
        raise ValueError(f"No connection found to chip {target_chip.chip_id}")

    def allocate_memory(self, size: int, virtual_addr: Optional[str] = None) -> str:
        """Allocate a VRAM block and record the request.

        Args:
            size: Size passed to ``vram.allocate_block``.
            virtual_addr: When truthy, mapped to the new block through
                ``vram.map_address``.

        Returns:
            The block id returned by ``vram.allocate_block``.
        """
        block_id = self.vram.allocate_block(size)
        if virtual_addr:
            self.vram.map_address(virtual_addr, block_id)

        self.chip_state["memory_controller"]["active_requests"][block_id] = {
            "type": "allocation",
            "size": size,
            "timestamp": time.time_ns()
        }
        self.store_chip_state()

        return block_id

    def transfer_to_device(self, data: bytes, virtual_addr: Optional[str] = None) -> str:
        """Copy ``data`` into a newly allocated block and track the transfer.

        The transfer is persisted before the copy starts and again once it is
        marked ``completed``.

        Args:
            data: Payload; its length is used as the allocation size.
            virtual_addr: Optional address forwarded to ``allocate_memory``.

        Returns:
            The id of the block holding the data.
        """
        transfer_id = f"transfer_{time.time_ns()}"
        transfers = self.chip_state["pcie_state"]["active_transfers"]
        transfers[transfer_id] = {
            "direction": "to_device",
            "size": len(data),
            "timestamp": time.time_ns()
        }
        self.store_chip_state()

        block_id = self.allocate_memory(len(data), virtual_addr)
        self.storage.store_tensor(block_id, data)

        transfers[transfer_id]["completed"] = True
        self.store_chip_state()

        return block_id

    def schedule_compute(self, sm_index: int, warp_state: Dict[str, Any]) -> str:
        """Schedule a warp on the SM at ``sm_index`` and update power figures.

        Args:
            sm_index: Position of the target SM in ``self.sms``.
            warp_state: State handed to the SM's ``schedule_warp``.

        Returns:
            The generated warp id (``warp_<time_ns>``).

        Raises:
            ValueError: If ``sm_index`` is outside ``0 .. len(self.sms) - 1``.
        """
        if not 0 <= sm_index < len(self.sms):
            raise ValueError(f"Invalid SM index: {sm_index}")

        warp_id = f"warp_{time.time_ns()}"
        self.sms[sm_index].schedule_warp(warp_id, warp_state)

        power = self.chip_state["power_state"]
        # Each scheduled warp adds a flat 10 to its SM; the total sums the
        # per-SM figures only (vram_power is not included).
        power["sm_power"][sm_index] += 10
        power["total_watts"] = sum(power["sm_power"])
        self.store_chip_state()

        return warp_id

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot combining VRAM, SM, PCIe, power and memory stats.

        The transfer and request counts are the sizes of the tracking dicts,
        so they include entries already flagged ``completed``.
        """
        pcie = self.chip_state["pcie_state"]
        power = self.chip_state["power_state"]
        controller = self.chip_state["memory_controller"]
        return {
            "chip_id": self.chip_id,
            "vram": self.vram.get_stats(),
            "sms": [sm.get_stats() for sm in self.sms],
            "pcie": {
                "active_transfers": len(pcie["active_transfers"]),
                "bandwidth_usage": pcie["bandwidth_usage"]
            },
            "power": {
                "total_watts": power["total_watts"],
                "vram_watts": power["vram_power"]
            },
            "memory_controller": {
                "active_requests": len(controller["active_requests"]),
                "bandwidth_usage": controller["bandwidth_usage"]
            }
        }
|
|
|
|