|
|
"""
|
|
|
Register file implementation for virtual GPU memory subsystem
|
|
|
"""
|
|
|
from typing import Dict, Any, List, Optional
|
|
|
import numpy as np
|
|
|
import threading
|
|
|
|
|
|
class RegisterFile:
    """Fixed-size bank of unsigned registers with simple usage accounting.

    Every stored value is truncated to its low ``register_width`` bits
    (negative inputs wrap two's-complement style), so register contents always
    lie in ``[0, 2**register_width)``. Register contents and the usage counters
    are only mutated while holding ``register_lock``.

    Args:
        num_registers: Number of addressable registers; valid IDs are
            ``0 .. num_registers - 1``.
        register_width: Width of each register in bits; must be non-negative.

    Raises:
        ValueError: If ``register_width`` is negative.
    """

    def __init__(self, num_registers: int = 32, register_width: int = 32):
        if register_width < 0:
            # A negative width would otherwise only fail later, inside every
            # write, as a "negative shift count" error.
            raise ValueError(f"Invalid register width: {register_width}")

        self.num_registers = num_registers
        self.register_width = register_width
        # Storage must be unsigned and wide enough for the full masked range
        # [0, 2**register_width). A signed int32 array cannot hold masked
        # values >= 2**31 (NumPy 2 raises OverflowError on such assignment).
        self.registers = np.zeros(
            num_registers, dtype=self._storage_dtype(register_width)
        )
        self.register_lock = threading.Lock()
        self.usage_stats = {
            "reads": 0,
            "writes": 0,
            "conflicts": 0,
        }

    @staticmethod
    def _storage_dtype(register_width: int):
        """Return a backing dtype able to hold ``register_width`` unsigned bits."""
        if register_width <= 32:
            return np.uint32
        if register_width <= 64:
            return np.uint64
        # Wider than any native integer type: store Python ints in an object array.
        return object

    def _is_valid_id(self, reg_id: int) -> bool:
        """Return True if ``reg_id`` addresses an existing register."""
        return 0 <= reg_id < self.num_registers

    def _mask(self, value: int) -> int:
        """Truncate ``value`` to the low ``register_width`` bits."""
        return value & ((1 << self.register_width) - 1)

    def read_register(self, reg_id: int) -> int:
        """Read value from specified register.

        Args:
            reg_id: Register index in ``[0, num_registers)``.

        Returns:
            The register's contents as a Python ``int`` in
            ``[0, 2**register_width)``.

        Raises:
            ValueError: If ``reg_id`` is out of range.
        """
        if not self._is_valid_id(reg_id):
            raise ValueError(f"Invalid register ID: {reg_id}")

        with self.register_lock:
            self.usage_stats["reads"] += 1
            return int(self.registers[reg_id])

    def write_register(self, reg_id: int, value: int) -> None:
        """Write value to specified register.

        The value is truncated to ``register_width`` bits before storing.

        Args:
            reg_id: Register index in ``[0, num_registers)``.
            value: Integer to store (masked to the register width).

        Raises:
            ValueError: If ``reg_id`` is out of range.
        """
        if not self._is_valid_id(reg_id):
            raise ValueError(f"Invalid register ID: {reg_id}")

        masked = self._mask(value)

        with self.register_lock:
            self.usage_stats["writes"] += 1
            self.registers[reg_id] = masked

    def bulk_read(self, reg_ids: List[int]) -> List[int]:
        """Read multiple registers at once.

        All IDs are validated before anything is read. The ``reads`` counter
        grows by the number of IDs requested.

        Args:
            reg_ids: Register indices to read, in the desired output order.

        Returns:
            Register contents as Python ints, in the same order as ``reg_ids``.

        Raises:
            ValueError: If any ID is out of range.
        """
        # Materialize once: the IDs are traversed twice (validate, then read),
        # which would silently misbehave for a one-shot iterable.
        ids = list(reg_ids)
        if not all(self._is_valid_id(rid) for rid in ids):
            raise ValueError("Invalid register ID in bulk read")

        with self.register_lock:
            self.usage_stats["reads"] += len(ids)
            return [int(self.registers[rid]) for rid in ids]

    def bulk_write(self, reg_data: Dict[int, int]) -> None:
        """Write to multiple registers at once.

        All IDs are validated before anything is written, so an invalid ID
        leaves every register unchanged.

        Args:
            reg_data: Mapping of register index to value; each value is masked
                to ``register_width`` bits.

        Raises:
            ValueError: If any ID is out of range.
        """
        if not all(self._is_valid_id(rid) for rid in reg_data):
            raise ValueError("Invalid register ID in bulk write")

        values = {rid: self._mask(val) for rid, val in reg_data.items()}

        with self.register_lock:
            self.usage_stats["writes"] += len(values)
            for rid, val in values.items():
                self.registers[rid] = val

    def clear_registers(self) -> None:
        """Reset all registers to zero (usage statistics are left untouched)."""
        with self.register_lock:
            self.registers.fill(0)

    def get_usage_stats(self) -> Dict[str, int]:
        """Return a snapshot copy of the usage counters.

        Returns:
            A new dict with ``"reads"``, ``"writes"`` and ``"conflicts"`` counts.
        """
        with self.register_lock:
            return dict(self.usage_stats)

    def check_conflicts(self, read_regs: List[int], write_regs: List[int]) -> bool:
        """Check for read/write conflicts between register sets.

        A conflict is any register ID present in both sets. The ``conflicts``
        counter is incremented once per call that finds at least one overlap,
        not once per overlapping register.

        Args:
            read_regs: Register IDs being read.
            write_regs: Register IDs being written.

        Returns:
            True if the two sets share any register ID, else False.
        """
        has_conflict = not set(read_regs).isdisjoint(write_regs)
        if has_conflict:
            # The counter is shared with the locked read/write paths, so it
            # must be updated under the same lock to avoid lost increments.
            with self.register_lock:
                self.usage_stats["conflicts"] += 1
        return has_conflict
|
|
|
|