"""Register file implementation for virtual GPU memory subsystem."""

import threading
from typing import Dict, Any, List, Optional

import numpy as np


class RegisterFile:
    """Fixed-size bank of integer registers guarded by a single lock.

    Register contents live in a NumPy ``int32`` array. Every read, write and
    statistics update is performed while holding ``register_lock``.

    Written values are first masked to ``register_width`` bits and then
    stored as the two's-complement ``int32`` with the same low 32 bits.
    Consequently, with the default width of 32 a value whose bit 31 is set
    reads back as a negative number (``0xFFFFFFFF`` reads back as ``-1``),
    while narrower widths always read back as non-negative values.

    NOTE: the backing storage is 32 bits wide regardless of
    ``register_width``; for widths above 32 only the low 32 bits are kept.

    Attributes:
        num_registers: Number of registers in the file.
        register_width: Logical register width in bits, used for masking.
        registers: ``np.int32`` array holding the register contents.
        register_lock: Lock serialising access to registers and statistics.
        usage_stats: Counters keyed by ``"reads"``, ``"writes"`` and
            ``"conflicts"``.
    """

    # Bit width of one cell of the backing ``np.int32`` array.
    _STORAGE_BITS = 32

    def __init__(self, num_registers: int = 32, register_width: int = 32):
        """Create a zero-initialised register file.

        Args:
            num_registers: Number of registers to allocate.
            register_width: Logical width of each register in bits.
        """
        self.num_registers = num_registers
        self.register_width = register_width
        self.registers = np.zeros(num_registers, dtype=np.int32)
        self.register_lock = threading.Lock()
        self.usage_stats = {
            "reads": 0,
            "writes": 0,
            "conflicts": 0
        }

    def _is_valid_id(self, reg_id: int) -> bool:
        """Return True when ``reg_id`` indexes an existing register."""
        return 0 <= reg_id < self.num_registers

    def _encode(self, value: int) -> int:
        """Mask ``value`` to the register width and fold it into int32 range.

        Storing an integer outside ``[-2**31, 2**31 - 1]`` into an ``int32``
        array overflows (NumPy 2 raises ``OverflowError``), so the masked bit
        pattern is reinterpreted as a signed 32-bit value before storage.
        """
        value &= (1 << self.register_width) - 1
        value &= (1 << self._STORAGE_BITS) - 1
        if value >= 1 << (self._STORAGE_BITS - 1):
            value -= 1 << self._STORAGE_BITS
        return value

    def read_register(self, reg_id: int) -> int:
        """Read value from specified register.

        Args:
            reg_id: Index of the register to read.

        Returns:
            The stored value as a Python ``int``.

        Raises:
            ValueError: If ``reg_id`` is outside ``[0, num_registers)``.
        """
        if not self._is_valid_id(reg_id):
            raise ValueError(f"Invalid register ID: {reg_id}")

        with self.register_lock:
            self.usage_stats["reads"] += 1
            return int(self.registers[reg_id])

    def write_register(self, reg_id: int, value: int) -> None:
        """Write value to specified register.

        Args:
            reg_id: Index of the register to write.
            value: Integer to store; masked to ``register_width`` bits.

        Raises:
            ValueError: If ``reg_id`` is outside ``[0, num_registers)``.
        """
        if not self._is_valid_id(reg_id):
            raise ValueError(f"Invalid register ID: {reg_id}")

        # Encode before taking the lock so the critical section stays short.
        encoded = self._encode(value)

        with self.register_lock:
            self.usage_stats["writes"] += 1
            self.registers[reg_id] = encoded

    def bulk_read(self, reg_ids: List[int]) -> List[int]:
        """Read multiple registers at once.

        Args:
            reg_ids: Register indices to read, in the desired output order.

        Returns:
            The stored values as Python ``int`` objects, in the same order
            as ``reg_ids``.

        Raises:
            ValueError: If any index is outside ``[0, num_registers)``.
        """
        if not all(self._is_valid_id(rid) for rid in reg_ids):
            raise ValueError("Invalid register ID in bulk read")

        with self.register_lock:
            self.usage_stats["reads"] += len(reg_ids)
            return [int(self.registers[rid]) for rid in reg_ids]

    def bulk_write(self, reg_data: Dict[int, int]) -> None:
        """Write to multiple registers at once.

        All indices are validated and all values encoded before anything is
        stored, so an invalid index leaves the register file untouched.

        Args:
            reg_data: Mapping of register index to the value to store; each
                value is masked to ``register_width`` bits.

        Raises:
            ValueError: If any index is outside ``[0, num_registers)``.
        """
        if not all(self._is_valid_id(rid) for rid in reg_data):
            raise ValueError("Invalid register ID in bulk write")

        encoded = {rid: self._encode(val) for rid, val in reg_data.items()}

        with self.register_lock:
            self.usage_stats["writes"] += len(encoded)
            for rid, val in encoded.items():
                self.registers[rid] = val

    def clear_registers(self) -> None:
        """Reset all registers to zero (usage statistics are not reset)."""
        with self.register_lock:
            self.registers.fill(0)

    def get_usage_stats(self) -> Dict[str, int]:
        """Get register file usage statistics.

        Returns:
            A copy of the counters, so callers cannot mutate internal state.
        """
        with self.register_lock:
            return dict(self.usage_stats)

    def check_conflicts(self, read_regs: List[int], write_regs: List[int]) -> bool:
        """Check for read/write conflicts between register sets.

        A conflict exists when at least one register appears in both lists.
        Each call that detects a conflict increments the ``"conflicts"``
        counter by one, regardless of how many registers overlap.

        Args:
            read_regs: Registers that will be read.
            write_regs: Registers that will be written.

        Returns:
            True if the two lists share any register, otherwise False.
        """
        if set(read_regs).isdisjoint(write_regs):
            return False

        # Take the lock like every other statistics update does.
        with self.register_lock:
            self.usage_stats["conflicts"] += 1
        return True