# gpu_arch.py
# DB-backed multi-chip GPU architecture simulation.
import os
import uuid

from ai import AIAccelerator
from custom_vram import CustomVRAM
from gpu_state_db import GPUStateDB
from multicore import MultiCoreSystem
from vram.ram_controller import RAMController

class TensorCoreDB:
def __init__(self, tensor_core_id, sm_id, db):
self.tensor_core_id = tensor_core_id
self.sm_id = sm_id
self.db = db
def load_state(self):
state = self.db.load_state("tensor_core", "tensor_core_id", self.tensor_core_id)
return state or {}
def save_state(self, state):
self.db.save_state("tensor_core", "tensor_core_id", self.tensor_core_id, state)
    def matmul(self, A, B):
        state = self.load_state()
        # Standard matrix multiply: result[i][j] = sum_k A[i][k] * B[k][j]
        n, inner, m = len(A), len(B), len(B[0])
        result = [
            [sum(A[i][k] * B[k][j] for k in range(inner)) for j in range(m)]
            for i in range(n)
        ]
        state["last_result"] = result
        self.save_state(state)
        return result
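# Sanity check for TensorCoreDB.matmul (plain arithmetic, kept as a comment
# so it does not execute on import):
#   A = [[1.0, 2.0], [3.0, 4.0]]; B = [[5.0, 6.0], [7.0, 8.0]]
#   matmul(A, B) -> [[1*5 + 2*7, 1*6 + 2*8],
#                    [3*5 + 4*7, 3*6 + 4*8]] = [[19.0, 22.0], [43.0, 50.0]]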
class OpticalInterconnect:
    def __init__(self, bandwidth_tbps=800, latency_ns=1):
        self.bandwidth_tbps = bandwidth_tbps  # link bandwidth in terabytes per second (TB/s)
        self.latency_ns = latency_ns  # fixed link latency in nanoseconds
def transfer_time(self, data_size_bytes):
# Time = latency + (data_size / bandwidth)
bandwidth_bytes_per_s = self.bandwidth_tbps * 1e12
transfer_time_s = self.latency_ns * 1e-9 + (data_size_bytes / bandwidth_bytes_per_s)
return transfer_time_s
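# Worked example for transfer_time, matching the send_data call in __main__:
# moving 10 GiB over an 800 TB/s link takes
#   1e-9 s + (10 * 2**30 bytes) / (8e14 bytes/s) ~= 13.4 microseconds.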
class Thread:
def __init__(self, thread_id, core):
self.thread_id = thread_id
self.core = core
self.active = True
self.result = None
def run(self, a, b, cin, opcode, reg_sel):
if self.active:
self.result = self.core.step(a, b, cin, opcode, reg_sel)
return self.result
class Warp:
def __init__(self, warp_id, threads):
self.warp_id = warp_id
self.threads = threads # List of Thread objects
self.active = True
def run(self, a, b, cin, opcode, reg_sel):
# All threads in a warp execute in lockstep (SIMT)
return [thread.run(a, b, cin, opcode, reg_sel) for thread in self.threads if thread.active]
class WarpScheduler:
def __init__(self, warps):
self.warps = warps # List of Warp objects
self.schedule_ptr = 0
def schedule(self):
# Simple round-robin scheduler
if not self.warps:
return None
warp = self.warps[self.schedule_ptr]
self.schedule_ptr = (self.schedule_ptr + 1) % len(self.warps)
return warp
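# Round-robin sketch: with warps [w0, w1, w2], successive schedule() calls
# return w0, w1, w2, w0, ... Real GPU schedulers also skip stalled warps to
# hide memory latency; this model keeps every warp eligible.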
class SharedMemory:
def __init__(self, size):
self.size = size
self.mem = [0] * size
def read(self, addr):
return self.mem[addr % self.size]
def write(self, addr, value):
self.mem[addr % self.size] = value
def read_matrix(self, addr, n, m):
# Simulate reading an n x m matrix from shared memory
# For simplicity, treat addr as row offset
return [
[self.mem[(addr + i * m + j) % self.size] for j in range(m)]
for i in range(n)
]
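# Minimal usage sketch for SharedMemory (hypothetical values): element (i, j)
# of an n x m matrix lives at addr + i*m + j, so a row-major write reads back
# as a matrix:
#   smem = SharedMemory(16)
#   for i, v in enumerate([1, 2, 3, 4]):
#       smem.write(i, v)
#   smem.read_matrix(0, 2, 2)  # -> [[1, 2], [3, 4]]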
class L1Cache:
def __init__(self, size):
self.size = size
self.cache = [None] * size
def read(self, addr):
return self.cache[addr % self.size]
def write(self, addr, value):
self.cache[addr % self.size] = value
# GlobalMemory now uses RAMController and persists to .db
class GlobalMemory:
def __init__(self, size_bytes, db_path=None):
        if db_path is None:
            # Fresh backing file per instance so concurrent simulations do not collide
            db_path = os.path.join(os.path.dirname(__file__), f"global_mem_{uuid.uuid4().hex}.db")
self.size_bytes = size_bytes
self.ram = RAMController(size_bytes, db_path=db_path)
self.allocated_address = 0 # Simple allocation pointer
def read(self, addr, length=1):
data = self.ram.read(addr, length)
# Return as int for compatibility (simulate voltage)
if length == 1:
return int(data[0]) if data else 0
return [int(b) for b in data]
    def write(self, addr, value):
        # Accepts int, float, bytes/bytearray, or a list of byte values;
        # scalars are truncated to a single byte (0-255)
        if isinstance(value, (int, float)):
            data = bytes([int(value) & 0xFF])
elif isinstance(value, (bytes, bytearray)):
data = value
elif isinstance(value, list):
# Convert list of integers to bytes, assuming each integer is a byte value (0-255)
data = bytes(value)
else:
raise TypeError("Unsupported value type for write")
self.ram.write(addr, data)
def read_matrix(self, addr, n, m):
# Read n*m bytes and reshape
data = self.ram.read(addr, n * m)
return [list(data[i*m:(i+1)*m]) for i in range(n)]
def allocate_space(self, size_bytes: int) -> int:
"""Simulates allocating space in global memory."""
if self.allocated_address + size_bytes > self.size_bytes:
raise MemoryError("Out of global memory space")
allocated_addr = self.allocated_address
self.allocated_address += size_bytes
return allocated_addr
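# Usage sketch for GlobalMemory (assumes RAMController persists the read/write
# calls used above to its .db file; values are hypothetical):
#   gmem = GlobalMemory(1024)
#   base = gmem.allocate_space(4)   # bump-pointer allocator, first call returns 0
#   gmem.write(base, [1, 2, 3, 4])
#   gmem.read_matrix(base, 2, 2)    # -> [[1, 2], [3, 4]]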
# StreamingMultiprocessor now only loads state from DB as needed
class StreamingMultiprocessor:
def __init__(self, sm_id, chip_id, db: GPUStateDB, num_cores_per_sm=128, warps_per_sm=164, threads_per_warp=700, num_tensor_cores=8):
self.sm_id = sm_id
self.chip_id = chip_id
self.db = db
self.num_cores_per_sm = num_cores_per_sm
self.warps_per_sm = warps_per_sm
self.threads_per_warp = threads_per_warp
self.num_tensor_cores = num_tensor_cores
        self.global_mem = None  # Attached later via GPUMemoryHierarchy.add_sm
def load_state(self):
state = self.db.load_state("sm", "sm_id", self.sm_id)
return state or {}
def save_state(self, state):
self.db.save_state("sm", "sm_id", self.sm_id, state)
def attach_global_mem(self, global_mem):
self.global_mem = global_mem
def get_core(self, core_id):
return Core(core_id, self.sm_id, self.db)
def get_warp(self, warp_id):
return WarpDB(warp_id, self.sm_id, self.db)
def get_tensor_core(self, tensor_core_id):
return TensorCoreDB(tensor_core_id, self.sm_id, self.db)
    def run_next_warp(self, a, b, cin, opcode, reg_sel):
        # Demo scheduling: always dispatch warp 0; per-thread state is
        # loaded and saved inside ThreadDB.run
        warp = self.get_warp(0)
        return warp.run(a, b, cin, opcode, reg_sel)
def tensor_core_matmul(self, A, B, tensor_core_id=0):
tensor_core = self.get_tensor_core(tensor_core_id)
return tensor_core.matmul(A, B)
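# Note on the design above: StreamingMultiprocessor holds no live core, warp,
# or tensor core objects. get_core/get_warp/get_tensor_core build lightweight
# handles on demand, and all mutable state lives in GPUStateDB, so thousands
# of SMs can exist without thousands of resident Python objects.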
class Core:
def __init__(self, core_id, sm_id, db: GPUStateDB):
self.core_id = core_id
self.sm_id = sm_id
self.db = db
def load_state(self):
state = self.db.load_state("core", "core_id", self.core_id)
return state or {}
def save_state(self, state):
self.db.save_state("core", "core_id", self.core_id, state)
def step(self, a, b, cin, opcode, reg_sel):
state = self.load_state()
# Simulate a simple operation
state["last_result"] = (a[0] + b[0] + cin) if opcode == 0b10 else 0.0
self.save_state(state)
return state["last_result"]
class WarpDB:
def __init__(self, warp_id, sm_id, db: GPUStateDB, threads_per_warp=700):
self.warp_id = warp_id
self.sm_id = sm_id
self.db = db
self.threads_per_warp = threads_per_warp
def load_state(self):
state = self.db.load_state("warp", "warp_id", self.warp_id)
return state or {}
def save_state(self, state):
self.db.save_state("warp", "warp_id", self.warp_id, state)
def get_thread(self, thread_id):
return ThreadDB(thread_id, self.warp_id, self.db)
def run(self, a, b, cin, opcode, reg_sel):
# For demo, run only first thread
thread = self.get_thread(0)
result = thread.run(a, b, cin, opcode, reg_sel)
return [result]
class ThreadDB:
def __init__(self, thread_id, warp_id, db: GPUStateDB):
self.thread_id = thread_id
self.warp_id = warp_id
self.db = db
def load_state(self):
state = self.db.load_state("thread", "thread_id", self.thread_id)
return state or {}
def save_state(self, state):
self.db.save_state("thread", "thread_id", self.thread_id, state)
def run(self, a, b, cin, opcode, reg_sel):
state = self.load_state()
# Simulate a simple operation
state["result"] = (a[0] + b[0] + cin) if opcode == 0b10 else 0.0
self.save_state(state)
return state["result"]
class GPUMemoryHierarchy:
def __init__(self, num_sms, global_mem_size_bytes, chip_id, db: GPUStateDB):
self.global_mem = GlobalMemory(global_mem_size_bytes)
self.sm_ids = list(range(num_sms))
self.chip_id = chip_id
self.db = db
self.num_sms = num_sms
def add_sm(self, sm):
sm.attach_global_mem(self.global_mem)
def read_global(self, addr):
return self.global_mem.read(addr)
def write_global(self, addr, value):
self.global_mem.write(addr, value)
class Chip:
def __init__(self, chip_id, num_sms=1500, vram_size_gb=16, db_path="gpu_state.db"):
self.chip_id = chip_id
self.db = GPUStateDB(db_path)
global_mem_size_bytes = vram_size_gb * 1024 * 1024 * 1024
self.gpu_mem = GPUMemoryHierarchy(num_sms=num_sms, global_mem_size_bytes=global_mem_size_bytes, chip_id=chip_id, db=self.db)
self.sm_ids = list(range(num_sms))
self.connected_chips = []
self.ai_accelerator = AIAccelerator() # Instantiate AIAccelerator
self.custom_vram = CustomVRAM(self.gpu_mem.global_mem) # Create CustomVRAM instance
self.ai_accelerator.set_vram(self.custom_vram) # Set VRAM for AIAccelerator
    def get_sm(self, sm_id):
        sm = StreamingMultiprocessor(sm_id, self.chip_id, self.db)
        self.gpu_mem.add_sm(sm)  # attach this chip's global memory to the SM
        return sm
    def connect_chip(self, other_chip, interconnect):
        self.connected_chips.append((other_chip, interconnect))
    def send_data(self, other_chip, interconnect, data_size_bytes):
        # Simulate a chip-to-chip transfer over the given interconnect
        transfer_s = interconnect.transfer_time(data_size_bytes)
        print(f"Chip {self.chip_id} -> Chip {other_chip.chip_id}: "
              f"{data_size_bytes} bytes in {transfer_s * 1e6:.1f} us")
        return transfer_s
def close(self):
if hasattr(self, "db") and self.db:
self.db.close()
if hasattr(self, "gpu_mem") and hasattr(self.gpu_mem, "global_mem") and hasattr(self.gpu_mem.global_mem, "ram"):
self.gpu_mem.global_mem.ram.close()
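# Minimal single-chip sketch (assumes GPUStateDB and RAMController create
# their backing .db files on demand, as the constructors above imply):
#   chip = Chip(chip_id=0, num_sms=4, vram_size_gb=1, db_path="demo_state.db")
#   sm = chip.get_sm(0)
#   sm.tensor_core_matmul([[1.0]], [[2.0]])  # -> [[2.0]]
#   chip.close()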
if __name__ == "__main__":
print("\n--- Multi-Chip GPU Simulation (DB-backed) ---")
num_chips = 10
vram_size_gb = 16
chips = [Chip(
chip_id=i,
num_sms=100,
vram_size_gb=vram_size_gb,
db_path=f"gpu_state_chip_{i}.db"
) for i in range(num_chips)]
print(f"Total chips: {len(chips)}")
optical_link = OpticalInterconnect(bandwidth_tbps=800, latency_ns=1)
for i in range(num_chips):
chips[i].connect_chip(chips[(i+1)%num_chips], optical_link)
for chip in chips:
sm = chip.get_sm(0)
results = sm.run_next_warp([0.7, 0.0], [0.7, 0.7], 0.0, 0b10, 0)
print(f"Chip {chip.chip_id} SM 0 first thread result: {results[0] if results else None}")
# Example tensor core usage: matrix multiply on SM 0, tensor core 0
A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
tc_result = sm.tensor_core_matmul(A, B, tensor_core_id=0)
print(f"Chip {chip.chip_id} SM 0 tensor core 0 matmul result: {tc_result}")
print(f"Total SMs in first chip: {len(chips[0].sm_ids)}")
print(f"Global memory size in first chip: {chips[0].gpu_mem.global_mem.size_bytes} bytes (backed by .db)")
    chips[0].send_data(chips[1], optical_link, 10 * 1024**3)  # 10 GiB transfer
    for chip in chips:
        chip.close()  # flush each chip's state DB and VRAM backing file