# INV / core.py
# Fred808's picture
# Upload 256 files
# 7a0c684 verified
"""
Physics-inspired digital core model for virtual GPU v2.
Contains ThreadedCore class for massive parallel computation.
"""
import threading
import time
from queue import Empty, Queue
from typing import Any, Dict, List

import numpy as np

from logic_gates import ControlUnit, ALU2Bit, RegisterFile2x2, SimpleMMU
class ThreadState:
    """Per-thread bookkeeping for one simulated thread inside a core.

    Attributes set by the constructor:
        thread_id:     identifier handed in by the caller.
        regfile:       a private ``RegisterFile2x2`` instance for this thread.
        active:        starts as ``True``.
        barrier_count: starts at ``0``.
        result_queue:  an empty ``Queue`` for this thread's results.
    """

    def __init__(self, thread_id: int, num_registers: int = 2, bits: int = 2):
        """Create the state for thread ``thread_id``.

        ``num_registers`` and ``bits`` are accepted but not consulted here:
        the register file is always built as a ``RegisterFile2x2``.
        """
        # Identity and scheduling flags.
        self.thread_id = thread_id
        self.active = True
        self.barrier_count = 0

        # Thread-private resources: nothing here is shared between instances.
        self.regfile = RegisterFile2x2()
        self.result_queue = Queue()
class ThreadBlock:
    """Manages a group of threads that can be synchronized together.

    Attributes:
        block_id:      identifier handed in by the caller.
        threads:       list the owner fills with the block's thread states.
        barrier:       ``threading.Barrier`` sized to ``num_threads`` parties.
        shared_memory: plain dict available to the block's threads.
    """

    def __init__(self, block_id: int, num_threads: int = 32):
        """Create an empty block whose barrier expects ``num_threads`` parties.

        Raises:
            ValueError: if ``num_threads`` is less than 1.  A barrier with no
                parties could never be passed, so it is rejected up front
                instead of surfacing later as a hang.
        """
        if num_threads < 1:
            raise ValueError(f"num_threads must be >= 1, got {num_threads}")
        self.block_id = block_id
        self.threads: List[ThreadState] = []
        self.barrier = threading.Barrier(num_threads)
        self.shared_memory = {}

    def synchronize(self):
        """Block until every party of this block's barrier has arrived.

        Returns:
            The arrival index reported by ``threading.Barrier.wait()``
            (an int in ``range(num_threads)``, different for each waiter).
        """
        return self.barrier.wait()
class ThreadedCore:
    """
    Simulates a massively parallel core with:
    - 700K hardware threads (the default ``num_threads``)
    - Shared control unit
    - Thread-local register files
    - Shared ALU with time-multiplexing
    - Thread synchronization capabilities

    Each simulated thread owns a ``ThreadState``; states are grouped into
    ``ThreadBlock`` objects holding at most ``threads_per_block`` threads.
    Use of the shared control unit and ALU is serialised by
    ``scheduler_lock``.
    """

    def __init__(self, num_threads: int = 700000, threads_per_block: int = 32,
                 bits: int = 2, num_registers: int = 2):
        """Build the shared units and one ``ThreadState`` per thread.

        Args:
            num_threads: number of simulated threads to create state for.
            threads_per_block: maximum number of threads in one block.
            bits: forwarded to ``SimpleMMU`` and to every ``ThreadState``.
            num_registers: forwarded to ``SimpleMMU`` and to every
                ``ThreadState``.

        Raises:
            ValueError: if ``threads_per_block`` is less than 1 (the block
                count could not be computed meaningfully).
        """
        if threads_per_block < 1:
            raise ValueError(
                f"threads_per_block must be >= 1, got {threads_per_block}"
            )

        self.control = ControlUnit()
        self.alu = ALU2Bit()  # Shared ALU
        self.mmu = SimpleMMU(num_registers=num_registers, bits=bits)
        self.clk = 0.7  # High voltage for clock
        self.bits = bits
        self.num_registers = num_registers  # Needed by _initialize_threads

        # Thread management
        self.num_threads = num_threads
        self.threads_per_block = threads_per_block
        # Ceiling division: a trailing partial block still counts as a block.
        self.num_blocks = (num_threads + threads_per_block - 1) // threads_per_block

        # Initialize thread blocks and states
        self.blocks: List[ThreadBlock] = []
        self.thread_states: Dict[int, ThreadState] = {}
        self._initialize_threads()

        # Thread scheduling
        self.scheduler_lock = threading.Lock()
        self.active_threads = set(range(num_threads))
        self.thread_pool = []  # Will hold thread objects

    def _initialize_threads(self):
        """Create every block and register the state of each of its threads."""
        for block_id in range(self.num_blocks):
            first_id = block_id * self.threads_per_block
            threads_in_block = min(
                self.threads_per_block,
                self.num_threads - first_id,
            )
            # Size the barrier to the threads that actually exist in this
            # block; a barrier sized to threads_per_block could never be
            # released by a partially filled last block.
            block = ThreadBlock(block_id, threads_in_block)
            for thread_id in range(first_id, first_id + threads_in_block):
                thread_state = ThreadState(
                    thread_id,
                    num_registers=self.num_registers,
                    bits=self.bits,
                )
                block.threads.append(thread_state)
                self.thread_states[thread_id] = thread_state
            self.blocks.append(block)

    def _execute_thread(self, thread_id: int, a, b, cin, opcode, reg_sel):
        """Execute one ALU operation on behalf of a single thread.

        Args:
            thread_id: key into ``thread_states``.
            a, b: operands; elements ``[0]`` and ``[1]`` are passed to the ALU.
            cin: carry-in passed to the ALU.
            opcode: value given to the control unit.
            reg_sel: register selector used for the write and the read-back.

        Returns:
            A dict with keys ``thread_id``, ``alu_result``, ``carry_out``,
            ``regfile_out`` and ``control``; the same dict is also put on the
            thread's ``result_queue``.  Returns ``None`` (and queues nothing)
            when the thread is not active.
        """
        thread_state = self.thread_states[thread_id]
        if not thread_state.active:
            return None

        # The control unit and ALU are shared by all threads, so their use is
        # serialised.  The register-file access is kept inside the critical
        # section as well.
        # NOTE(review): the register file is per-thread; it could move out of
        # the lock once RegisterFile2x2 is confirmed to hold no shared state.
        with self.scheduler_lock:
            # Set control signals
            self.control.set_opcode(opcode)
            ctrl = self.control.get_control_signals()

            # ALU operation (shared resource)
            (r0, r1), cout = self.alu.operate(
                a[0], a[1], b[0], b[1], cin, ctrl['alu_op']
            )

            # Write to thread-local register file, then read the value back
            thread_state.regfile.write(r0, r1, self.clk, reg_sel)
            regfile_out = thread_state.regfile.read(reg_sel)

        # Store result in thread's queue
        result = {
            'thread_id': thread_id,
            'alu_result': (r0, r1),
            'carry_out': cout,
            'regfile_out': regfile_out,
            'control': ctrl,
        }
        thread_state.result_queue.put(result)
        return result

    def execute_parallel(self, inputs: List[Dict[str, Any]]):
        """Run one operation per input, each on its own ``threading.Thread``.

        Input ``i`` drives thread ``i``; inputs beyond ``num_threads`` are
        ignored.  Each input dict must provide the keys ``a``, ``b``, ``cin``,
        ``opcode`` and ``reg_sel``.

        Returns:
            The result dicts in thread-id order.  Threads that queued no
            result (for example inactive ones) are simply absent.
        """
        launched = min(len(inputs), self.num_threads)
        workers = []
        try:
            for thread_id, inp in zip(range(launched), inputs):
                worker = threading.Thread(
                    target=self._execute_thread,
                    args=(thread_id, inp['a'], inp['b'], inp['cin'],
                          inp['opcode'], inp['reg_sel']),
                )
                worker.start()
                # Recorded only after a successful start so the join below
                # never touches a thread that was not started.
                workers.append(worker)
        finally:
            # Wait for every started worker, even if a later start failed.
            for worker in workers:
                worker.join()

        # Collect results
        results = []
        for thread_id in range(launched):
            thread_state = self.thread_states.get(thread_id)
            if thread_state is None:
                continue
            try:
                results.append(thread_state.result_queue.get_nowait())
            except Empty:
                # No result was queued for this thread; skip it.
                continue
        return results

    def synchronize_block(self, block_id: int):
        """Synchronize on one block's barrier; out-of-range ids are ignored."""
        if 0 <= block_id < len(self.blocks):
            self.blocks[block_id].synchronize()

    def barrier_all_threads(self):
        """Call ``synchronize()`` on every block, in block order.

        NOTE(review): each call waits on that block's barrier from the calling
        thread, so it returns only once the barrier's other parties have also
        arrived.  Invoked from a single thread while blocks expect more than
        one party, this waits indefinitely -- confirm the intended caller.
        """
        for block in self.blocks:
            block.synchronize()
if __name__ == "__main__":
    # Demo: build a full-size core, then drive a small slice of its threads.
    print("\n--- Threaded Core Simulation ---")
    # NOTE: this allocates state for all 700000 threads even though only the
    # first 1000 are exercised below.
    core = ThreadedCore(num_threads=700000, threads_per_block=32)

    # Example: Execute same operation across many threads
    inputs = [
        {'a': [0.7, 0.0], 'b': [0.7, 0.7], 'cin': 0.0, 'opcode': 0b10, 'reg_sel': 0}
        for _ in range(1000)  # Test with 1000 threads
    ]

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # wall-clock adjustments.
    start_time = time.perf_counter()
    results = core.execute_parallel(inputs)
    elapsed = time.perf_counter() - start_time

    print(f"Executed {len(results)} thread operations")
    if results:  # Avoid IndexError when no thread produced a result
        print(f"First thread result: {results[0]}")
    print(f"Execution time: {elapsed:.4f} seconds")