Spaces:
Runtime error
Runtime error
File size: 2,355 Bytes
2ff82ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
"""
Test for hyperrealistic multi-chip GPU system with full SM and tensor core realism.
"""
import time
from gpu_arch import Chip, OpticalInterconnect
def test_multi_chip_gpu(*, num_chips=2, num_sms=4, fill_limit=None):
    """Run tensor-core matmuls on every SM of a ring of GPU chips and print them.

    Builds ``num_chips`` ``Chip`` objects with ``num_sms`` SMs each and links
    them in a ring (chip i -> chip (i + 1) % num_chips) through one
    ``OpticalInterconnect``. It then seeds each SM's register file, shared
    memory and (when present) global memory with deterministic values. Finally
    it runs a tensor-core matmul with (2, 2) shapes sourced from each memory
    space and prints the results.

    Args:
        num_chips: Number of chips in the ring (default 2).
        num_sms: Number of SMs per chip (default 4).
        fill_limit: Optional cap on how many shared/global memory addresses
            are seeded per SM. ``None`` (the default) seeds every address, as
            the original test did. Pass a small number to keep runs fast when
            the simulated memories are large, since seeding is one Python
            ``write`` call per address.
    """
    print("\n=== Multi-Chip GPU System Full Test ===")
    chips = [Chip(chip_id=i, num_sms=num_sms) for i in range(num_chips)]
    print(f"Created {num_chips} chips, each with {num_sms} SMs.")

    # Connect chips in a ring topology. A single chip has no neighbour, so
    # don't link it to itself.
    # NOTE(review): one OpticalInterconnect instance is shared by every link;
    # confirm connect_chip does not expect a distinct link object per edge.
    optical_link = OpticalInterconnect(bandwidth_tbps=800, latency_ns=1)
    if num_chips > 1:
        for i, chip in enumerate(chips):
            chip.connect_chip(chips[(i + 1) % num_chips], optical_link)

    # Run tensor core matmul from all SMs on all chips
    for chip in chips:
        print(f"\n--- Chip {chip.chip_id} ---")
        for sm in chip.sms:
            _seed_sm_memory(sm, fill_limit)
            _run_tensor_core_matmuls(sm)
    print("\n=== Multi-Chip GPU System Test Complete ===")


def _fill_count(size, fill_limit):
    """Return ``size``, capped at ``fill_limit`` when a limit is given."""
    return size if fill_limit is None else min(size, fill_limit)


def _seed_sm_memory(sm, fill_limit):
    """Write deterministic values into an SM's registers, shared and global memory."""
    # Use each row's own length so an empty or ragged register file cannot
    # raise IndexError (the previous code used len(register_file[0]) for all rows).
    regs = sm.register_file
    for i, row in enumerate(regs):
        for j in range(len(row)):
            regs[i][j] = float(i + j)
    for addr in range(_fill_count(sm.shared_mem.size, fill_limit)):
        sm.shared_mem.write(addr, float(addr % 10))
    # NOTE(review): if global_mem is shared by all SMs of a chip, this
    # re-seeds it once per SM — confirm in gpu_arch.
    if sm.global_mem is not None:
        for addr in range(_fill_count(sm.global_mem.size_bytes, fill_limit)):
            sm.global_mem.write(addr, float(addr % 100))


def _run_tensor_core_matmuls(sm):
    """Run and print a tensor-core matmul sourced from each available memory space."""
    # NOTE(review): the arguments appear to be
    # (space_a, addr_a, space_b, addr_b, shape_a, shape_b); confirm in gpu_arch.
    reg_result = sm.tensor_core_matmul_from_memory('register', 0, 'register', 0, (2, 2), (2, 2))
    print(f"SM {sm.sm_id} tensor core matmul from registers: {reg_result}")
    shared_result = sm.tensor_core_matmul_from_memory('shared', 0, 'shared', 0, (2, 2), (2, 2))
    print(f"SM {sm.sm_id} tensor core matmul from shared memory: {shared_result}")
    # Seeding already treats a missing global memory as legitimate, so the
    # global-memory matmul must be skipped in that case instead of crashing.
    if sm.global_mem is None:
        print(f"SM {sm.sm_id} has no global memory; skipping global memory matmul.")
        return
    global_result = sm.tensor_core_matmul_from_memory('global', 0, 'global', 0, (2, 2), (2, 2))
    print(f"SM {sm.sm_id} tensor core matmul from global memory: {global_result}")
if __name__ == "__main__":
    # Measure elapsed time with perf_counter: it is monotonic and high
    # resolution, whereas time.time() follows the wall clock and can jump.
    start = time.perf_counter()
    test_multi_chip_gpu()
    print(f"Test runtime: {time.perf_counter() - start:.3f} seconds")
|