File size: 3,356 Bytes
7a0c684 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
"""
Register file implementation for virtual GPU memory subsystem
"""
from typing import Dict, Any, List, Optional
import numpy as np
import threading
class RegisterFile:
    """Fixed-size bank of registers with lock-guarded access and usage counters.

    Values written are truncated to the low ``register_width`` bits (so ``-1``
    becomes all ones) and are read back as non-negative Python ints.
    """

    def __init__(self, num_registers: int = 32, register_width: int = 32):
        """Create ``num_registers`` zeroed registers of ``register_width`` bits.

        Args:
            num_registers: number of addressable registers (IDs 0..n-1).
            register_width: bit width each written value is truncated to.

        Raises:
            ValueError: if ``register_width`` is not positive.
        """
        if register_width <= 0:
            raise ValueError(f"Invalid register width: {register_width}")
        self.num_registers = num_registers
        self.register_width = register_width
        # Unsigned storage wide enough for every masked value. A plain int32
        # array cannot hold masked values >= 2**31 (e.g. writing -1 yields
        # 0xFFFFFFFF), which overflows on assignment.
        self.registers = np.zeros(
            num_registers, dtype=self._storage_dtype(register_width)
        )
        self.register_lock = threading.Lock()
        self.usage_stats = {
            "reads": 0,
            "writes": 0,
            "conflicts": 0,
        }

    @staticmethod
    def _storage_dtype(width: int) -> Any:
        """Return the smallest unsigned NumPy dtype holding ``width`` bits, else object."""
        for bits, dtype in ((8, np.uint8), (16, np.uint16),
                            (32, np.uint32), (64, np.uint64)):
            if width <= bits:
                return dtype
        # Wider than 64 bits: fall back to Python ints stored in an object array.
        return object

    def _is_valid_id(self, reg_id: int) -> bool:
        """Return True when ``reg_id`` addresses an existing register."""
        return 0 <= reg_id < self.num_registers

    def _mask(self, value: int) -> int:
        """Truncate ``value`` to the low ``register_width`` bits."""
        # int() first so NumPy scalars are masked with unbounded Python ints
        # instead of being promoted (and possibly overflowing) in NumPy.
        return int(value) & ((1 << self.register_width) - 1)

    def read_register(self, reg_id: int) -> int:
        """Read value from specified register.

        Raises:
            ValueError: if ``reg_id`` is out of range.
        """
        if not self._is_valid_id(reg_id):
            raise ValueError(f"Invalid register ID: {reg_id}")
        with self.register_lock:
            self.usage_stats["reads"] += 1
            return int(self.registers[reg_id])

    def write_register(self, reg_id: int, value: int) -> None:
        """Write ``value`` (truncated to the register width) to a register.

        Raises:
            ValueError: if ``reg_id`` is out of range.
        """
        if not self._is_valid_id(reg_id):
            raise ValueError(f"Invalid register ID: {reg_id}")
        masked = self._mask(value)
        with self.register_lock:
            self.usage_stats["writes"] += 1
            self.registers[reg_id] = masked

    def bulk_read(self, reg_ids: List[int]) -> List[int]:
        """Read multiple registers under a single lock acquisition.

        Raises:
            ValueError: if any ID is out of range (nothing is read or counted).
        """
        if not all(self._is_valid_id(rid) for rid in reg_ids):
            raise ValueError("Invalid register ID in bulk read")
        with self.register_lock:
            self.usage_stats["reads"] += len(reg_ids)
            return [int(self.registers[rid]) for rid in reg_ids]

    def bulk_write(self, reg_data: Dict[int, int]) -> None:
        """Write ``{reg_id: value}`` pairs under a single lock acquisition.

        Raises:
            ValueError: if any ID is out of range (nothing is written or counted).
        """
        if not all(self._is_valid_id(rid) for rid in reg_data):
            raise ValueError("Invalid register ID in bulk write")
        values = {rid: self._mask(val) for rid, val in reg_data.items()}
        with self.register_lock:
            self.usage_stats["writes"] += len(values)
            for rid, val in values.items():
                self.registers[rid] = val

    def clear_registers(self) -> None:
        """Reset all registers to zero (usage counters are left unchanged)."""
        with self.register_lock:
            self.registers.fill(0)

    def get_usage_stats(self) -> Dict[str, int]:
        """Return a snapshot copy of the read/write/conflict counters."""
        with self.register_lock:
            return dict(self.usage_stats)

    def check_conflicts(self, read_regs: List[int], write_regs: List[int]) -> bool:
        """Return True if any register appears in both the read and write sets.

        Each detected conflict increments the ``"conflicts"`` counter.
        """
        if set(read_regs).isdisjoint(write_regs):
            return False
        # Counter update is guarded like every other usage_stats mutation.
        with self.register_lock:
            self.usage_stats["conflicts"] += 1
        return True
|