|
|
"""
|
|
|
Physics-inspired digital core model for virtual GPU v2.
|
|
|
Contains ThreadedCore class for massive parallel computation.
|
|
|
"""
|
|
|
|
|
|
from logic_gates import ControlUnit, ALU2Bit, RegisterFile2x2, SimpleMMU
|
|
|
import threading
|
|
|
from typing import List, Dict, Any
|
|
|
import numpy as np
|
|
|
from queue import Queue
|
|
|
import time
|
|
|
|
|
|
class ThreadState:
    """Per-thread execution state owned by a core.

    Attributes:
        thread_id: Identifier this state was created for.
        regfile: A private ``RegisterFile2x2`` instance for this thread.
        active: Starts as ``True``; ``ThreadedCore._execute_thread`` skips
            threads whose flag is falsy.
        barrier_count: Counter initialised to ``0`` (not updated by any
            code visible in this module).
        result_queue: ``Queue`` onto which execution results are pushed.
    """

    def __init__(self, thread_id: int, num_registers: int = 2, bits: int = 2):
        """Create the state for thread ``thread_id``.

        Args:
            thread_id: Identifier of the thread.
            num_registers: Accepted for interface symmetry; currently unused
                because the register file type is fixed to ``RegisterFile2x2``.
            bits: Accepted for interface symmetry; currently unused.
        """
        # Identity first, then the thread-private storage objects.
        self.thread_id = thread_id
        self.regfile = RegisterFile2x2()

        # Scheduling / bookkeeping flags.
        self.active = True
        self.barrier_count = 0

        # Results produced for this thread are handed over through a queue.
        self.result_queue = Queue()
|
|
|
|
|
|
class ThreadBlock:
    """A group of threads sharing one barrier and one shared-memory dict.

    Attributes:
        block_id: Identifier of this block.
        threads: Thread states belonging to the block (starts empty; the
            owner appends to it).
        barrier: ``threading.Barrier`` created for ``num_threads`` parties.
        shared_memory: Plain dict available to every thread of the block.
    """

    def __init__(self, block_id: int, num_threads: int = 32):
        """Set up an empty block.

        Args:
            block_id: Identifier of the block.
            num_threads: Number of parties the block barrier waits for.
        """
        self.block_id = block_id
        self.threads: List[ThreadState] = list()
        # One barrier per block: released once `num_threads` callers arrive.
        self.barrier = threading.Barrier(num_threads)
        self.shared_memory = dict()

    def synchronize(self):
        """Wait on the block barrier until all parties have arrived.

        The calling thread blocks inside ``Barrier.wait()``; nothing is
        returned.
        """
        block_barrier = self.barrier
        block_barrier.wait()
|
|
|
|
|
|
class ThreadedCore:
    """
    Simulates a massively parallel core with:
    - 700K hardware threads (the default; configurable via ``num_threads``)
    - Shared control unit
    - Thread-local register files
    - Shared ALU with time-multiplexing
    - Thread synchronization capabilities
    """

    def __init__(self, num_threads: int = 700000, threads_per_block: int = 32, bits: int = 2, num_registers: int = 2):
        """Build the shared units and allocate per-thread state.

        Args:
            num_threads: Total number of simulated threads.
            threads_per_block: Maximum number of threads grouped in a block.
            bits: Forwarded to ``SimpleMMU`` and ``ThreadState``.
            num_registers: Forwarded to ``SimpleMMU`` and ``ThreadState``.

        Raises:
            ValueError: If ``threads_per_block`` is smaller than 1 (the block
                count cannot be computed otherwise).
        """
        if threads_per_block < 1:
            raise ValueError(
                f"threads_per_block must be >= 1, got {threads_per_block}"
            )

        # Units shared by every simulated thread.
        self.control = ControlUnit()
        self.alu = ALU2Bit()
        self.mmu = SimpleMMU(num_registers=num_registers, bits=bits)
        # Value passed as the clock argument to every register-file write.
        # NOTE(review): presumably the logic-high level -- confirm.
        self.clk = 0.7
        self.bits = bits
        self.num_registers = num_registers

        # Thread geometry: ceil(num_threads / threads_per_block) blocks.
        self.num_threads = num_threads
        self.threads_per_block = threads_per_block
        self.num_blocks = (num_threads + threads_per_block - 1) // threads_per_block

        # Blocks and a flat thread_id -> state lookup.
        self.blocks: List[ThreadBlock] = []
        self.thread_states: Dict[int, ThreadState] = {}
        self._initialize_threads()

        # Serialises access to the shared control unit and ALU.
        self.scheduler_lock = threading.Lock()
        # Bookkeeping containers; not consulted by the methods of this class.
        self.active_threads = set(range(num_threads))
        self.thread_pool = []

    def _initialize_threads(self):
        """Create every thread block and the state of each of its threads.

        The last block may be partial.  Its barrier is sized to the number of
        threads it really holds, so that a block-level ``synchronize`` can be
        released by the threads that actually exist.
        """
        for block_id in range(self.num_blocks):
            first_thread_id = block_id * self.threads_per_block
            threads_in_block = min(
                self.threads_per_block,
                self.num_threads - first_thread_id,
            )
            block = ThreadBlock(block_id, threads_in_block)

            for thread_id in range(first_thread_id, first_thread_id + threads_in_block):
                thread_state = ThreadState(
                    thread_id, num_registers=self.num_registers, bits=self.bits
                )
                block.threads.append(thread_state)
                self.thread_states[thread_id] = thread_state

            self.blocks.append(block)

    def _execute_thread(self, thread_id: int, a, b, cin, opcode, reg_sel):
        """Execute one operation on behalf of a single thread.

        Args:
            thread_id: Key into ``self.thread_states``.
            a: First operand; indexed as ``a[0]`` and ``a[1]``.
            b: Second operand; indexed as ``b[0]`` and ``b[1]``.
            cin: Carry-in forwarded to the ALU.
            opcode: Opcode loaded into the shared control unit.
            reg_sel: Register selector used for the write and read-back.

        Returns:
            A dict with keys ``thread_id``, ``alu_result``, ``carry_out``,
            ``regfile_out`` and ``control``, or ``None`` when the thread is
            not active.  The same dict is also put on the thread's
            ``result_queue``.
        """
        thread_state = self.thread_states[thread_id]
        if not thread_state.active:
            return None

        # The control unit and ALU are shared, so the whole
        # decode -> execute -> write-back sequence runs under the scheduler
        # lock; otherwise another thread could change the opcode in between.
        with self.scheduler_lock:
            self.control.set_opcode(opcode)
            ctrl = self.control.get_control_signals()

            (r0, r1), cout = self.alu.operate(a[0], a[1], b[0], b[1], cin, ctrl['alu_op'])

            # Register files are per-thread; write then read back.
            thread_state.regfile.write(r0, r1, self.clk, reg_sel)

            result = {
                'thread_id': thread_id,
                'alu_result': (r0, r1),
                'carry_out': cout,
                'regfile_out': thread_state.regfile.read(reg_sel),
                'control': ctrl
            }
            thread_state.result_queue.put(result)
            return result

    def execute_parallel(self, inputs: List[Dict[str, Any]]):
        """
        Execute operations across all threads in parallel

        inputs: List of operation inputs for each thread.  Entry ``i`` is run
            as thread ``i`` and must provide the keys ``'a'``, ``'b'``,
            ``'cin'``, ``'opcode'`` and ``'reg_sel'``.  Entries beyond
            ``num_threads`` are ignored.

        Returns the result dicts in thread-id order.  Threads that produced
        nothing (inactive, or whose worker failed) are simply absent.

        Note: one ``threading.Thread`` is started per input entry.
        """
        # Local import keeps this method self-contained: the module only
        # imports ``Queue`` from ``queue``.
        from queue import Empty

        workers = []
        for thread_id, inp in enumerate(inputs):
            if thread_id >= self.num_threads:
                break

            worker = threading.Thread(
                target=self._execute_thread,
                args=(thread_id, inp['a'], inp['b'], inp['cin'], inp['opcode'], inp['reg_sel'])
            )
            workers.append(worker)
            worker.start()

        # Wait until every worker has finished before draining the queues.
        for worker in workers:
            worker.join()

        results = []
        for thread_id in range(min(len(inputs), self.num_threads)):
            thread_state = self.thread_states.get(thread_id)
            if thread_state is None:
                continue
            try:
                results.append(thread_state.result_queue.get_nowait())
            except Empty:
                # Nothing was queued for this thread; skip it.
                continue

        return results

    def synchronize_block(self, block_id: int):
        """Synchronize all threads in a block

        Waits on the barrier of block ``block_id``; out-of-range ids are
        ignored.  The call blocks until as many callers as the barrier has
        parties have arrived, so it is meant to be invoked from each
        participating thread.
        """
        if 0 <= block_id < len(self.blocks):
            self.blocks[block_id].synchronize()

    def barrier_all_threads(self):
        """Global barrier synchronization across all threads

        Waits on every block's barrier in block order.  Each wait blocks the
        caller until that block's barrier is released.
        """
        for block in self.blocks:
            block.synchronize()
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: build a full-size core and run the same operation on the first
    # 1000 threads, reporting how long the parallel dispatch takes.
    print("\n--- Threaded Core Simulation ---")
    core = ThreadedCore(num_threads=700000, threads_per_block=32)

    # One identical operation per participating thread.
    # NOTE(review): 0.7 / 0.0 look like logic-high / logic-low levels
    # (0.7 matches ThreadedCore.clk) -- confirm against logic_gates.
    inputs = [
        {'a': [0.7, 0.0], 'b': [0.7, 0.7], 'cin': 0.0, 'opcode': 0b10, 'reg_sel': 0}
        for _ in range(1000)
    ]

    # perf_counter is monotonic and high-resolution, unlike the wall clock,
    # so the measured interval cannot be skewed by system clock adjustments.
    start_time = time.perf_counter()
    results = core.execute_parallel(inputs)
    elapsed = time.perf_counter() - start_time

    print(f"Executed {len(results)} thread operations")
    if results:
        print(f"First thread result: {results[0]}")
    else:
        # execute_parallel omits threads that produced nothing, so the list
        # may legitimately be empty.
        print("No thread produced a result")
    print(f"Execution time: {elapsed:.4f} seconds")
|
|
|
|