INV / virtual_gpu_driver /src /memory /register_file.py
Fred808's picture
Upload 256 files
7a0c684 verified
"""
Register file implementation for virtual GPU memory subsystem
"""
import operator
import threading
from typing import Dict, Any, List, Optional

import numpy as np
class RegisterFile:
    """Fixed-size bank of registers guarded by a single lock.

    Every written value is truncated to ``register_width`` bits, so stored
    (and returned) values always lie in ``[0, 2**register_width)``.
    ``usage_stats`` counts register reads, register writes and detected
    read/write conflicts.
    """

    def __init__(self, num_registers: int = 32, register_width: int = 32):
        """Create ``num_registers`` zero-initialised registers.

        Args:
            num_registers: Number of addressable registers (valid IDs are
                ``0 .. num_registers - 1``).
            register_width: Width of each register in bits; must be >= 1.

        Raises:
            ValueError: If ``register_width`` is not positive (NumPy also
                raises ValueError for a negative ``num_registers``).
        """
        if register_width < 1:
            raise ValueError(f"Invalid register width: {register_width}")
        self.num_registers = num_registers
        self.register_width = register_width
        # Storage must be unsigned and at least as wide as the mask: a signed
        # int32 array cannot hold e.g. 0xFFFFFFFF (the masked form of -1),
        # which NumPy >= 2 rejects with OverflowError and older NumPy wraps.
        self.registers = np.zeros(
            num_registers, dtype=self._storage_dtype(register_width)
        )
        self.register_lock = threading.Lock()
        self.usage_stats = {
            "reads": 0,
            "writes": 0,
            "conflicts": 0,
        }

    @staticmethod
    def _storage_dtype(register_width: int):
        """Return a dtype able to hold any ``register_width``-bit unsigned value."""
        if register_width <= 32:
            return np.uint32
        if register_width <= 64:
            return np.uint64
        # Wider than any fixed-size NumPy integer: store Python ints.
        return object

    def _mask(self, value: int) -> int:
        """Truncate ``value`` to ``register_width`` bits, returning a Python int.

        ``operator.index`` accepts Python and NumPy integers (so the ``&``
        runs on arbitrary-precision ints) but rejects floats with TypeError.
        """
        return operator.index(value) & ((1 << self.register_width) - 1)

    def read_register(self, reg_id: int) -> int:
        """Return the value of register ``reg_id`` and count one read.

        Raises:
            ValueError: If ``reg_id`` is outside ``[0, num_registers)``.
        """
        if not 0 <= reg_id < self.num_registers:
            raise ValueError(f"Invalid register ID: {reg_id}")
        with self.register_lock:
            self.usage_stats["reads"] += 1
            return int(self.registers[reg_id])

    def write_register(self, reg_id: int, value: int) -> None:
        """Store ``value`` (truncated to ``register_width`` bits) in ``reg_id``.

        Raises:
            ValueError: If ``reg_id`` is outside ``[0, num_registers)``.
            TypeError: If ``value`` is not an integer.
        """
        if not 0 <= reg_id < self.num_registers:
            raise ValueError(f"Invalid register ID: {reg_id}")
        masked = self._mask(value)
        with self.register_lock:
            self.usage_stats["writes"] += 1
            self.registers[reg_id] = masked

    def bulk_read(self, reg_ids: List[int]) -> List[int]:
        """Read several registers under one lock acquisition.

        Returns the values in the same order as ``reg_ids``. Any iterable of
        IDs is accepted; it is materialised once so it can be both validated
        and read.

        Raises:
            ValueError: If any ID is outside ``[0, num_registers)``.
        """
        reg_ids = list(reg_ids)
        if not all(0 <= rid < self.num_registers for rid in reg_ids):
            raise ValueError("Invalid register ID in bulk read")
        with self.register_lock:
            self.usage_stats["reads"] += len(reg_ids)
            return [int(self.registers[rid]) for rid in reg_ids]

    def bulk_write(self, reg_data: Dict[int, int]) -> None:
        """Write several ``{reg_id: value}`` pairs under one lock acquisition.

        All IDs are validated and all values masked before any register is
        modified, so a bad entry leaves the register file untouched.

        Raises:
            ValueError: If any ID is outside ``[0, num_registers)``.
            TypeError: If any value is not an integer.
        """
        if not all(0 <= rid < self.num_registers for rid in reg_data):
            raise ValueError("Invalid register ID in bulk write")
        values = {rid: self._mask(val) for rid, val in reg_data.items()}
        with self.register_lock:
            self.usage_stats["writes"] += len(values)
            for rid, val in values.items():
                self.registers[rid] = val

    def clear_registers(self) -> None:
        """Reset every register to zero (usage statistics are not reset)."""
        with self.register_lock:
            self.registers.fill(0)

    def get_usage_stats(self) -> Dict[str, int]:
        """Return a snapshot copy of the usage counters."""
        with self.register_lock:
            return dict(self.usage_stats)

    def check_conflicts(self, read_regs: List[int], write_regs: List[int]) -> bool:
        """Return True if any register appears in both ``read_regs`` and ``write_regs``.

        Each call that finds at least one overlap increments
        ``usage_stats["conflicts"]`` by one (not once per overlapping ID).
        """
        if set(read_regs).isdisjoint(write_regs):
            return False
        # Same lock as every other stats update; ``+=`` on a dict entry is
        # not atomic across threads.
        with self.register_lock:
            self.usage_stats["conflicts"] += 1
        return True