File size: 3,356 Bytes
7a0c684
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""

Register file implementation for virtual GPU memory subsystem

"""
import operator
import threading
from typing import Dict, Any, List, Optional

import numpy as np

class RegisterFile:
    """Thread-safe, fixed-size file of integer registers.

    Every stored value is truncated to ``register_width`` bits, so registers
    hold unsigned integers in ``[0, 2**register_width - 1]``. All access to
    the register storage and to ``usage_stats`` goes through
    ``register_lock``.
    """

    def __init__(self, num_registers: int = 32, register_width: int = 32):
        """Create ``num_registers`` zero-initialised registers.

        Args:
            num_registers: Number of addressable registers (IDs ``0..n-1``).
            register_width: Bit width of each register; written values are
                masked to this many bits.

        Raises:
            ValueError: If ``register_width`` is less than 1.
        """
        if register_width < 1:
            raise ValueError(f"Invalid register width: {register_width}")
        self.num_registers = num_registers
        self.register_width = register_width
        # The backing dtype must be unsigned and wide enough for the masked
        # range: a signed int32 cannot hold e.g. 0xFFFFFFFF (the mask of -1).
        self.registers = np.zeros(
            num_registers, dtype=self._storage_dtype(register_width)
        )
        self.register_lock = threading.Lock()
        self.usage_stats = {
            "reads": 0,
            "writes": 0,
            "conflicts": 0
        }

    @staticmethod
    def _storage_dtype(register_width: int):
        """Return the narrowest unsigned dtype able to hold ``register_width`` bits."""
        if register_width <= 32:
            return np.uint32
        if register_width <= 64:
            return np.uint64
        # Wider registers fall back to arbitrary-precision Python ints.
        return object

    def _is_valid_id(self, reg_id: int) -> bool:
        """Return True if ``reg_id`` addresses an existing register."""
        return 0 <= reg_id < self.num_registers

    def _mask(self, value: int) -> int:
        """Truncate ``value`` to ``register_width`` bits (two's complement for negatives).

        ``operator.index`` turns NumPy integer scalars into Python ints first,
        so the mask cannot overflow the scalar's own dtype. Floats and other
        non-integers are still rejected with ``TypeError``.
        """
        return operator.index(value) & ((1 << self.register_width) - 1)

    def read_register(self, reg_id: int) -> int:
        """Return the value stored in register ``reg_id``.

        Raises:
            ValueError: If ``reg_id`` is out of range.
        """
        if not self._is_valid_id(reg_id):
            raise ValueError(f"Invalid register ID: {reg_id}")

        with self.register_lock:
            self.usage_stats["reads"] += 1
            return int(self.registers[reg_id])

    def write_register(self, reg_id: int, value: int) -> None:
        """Store ``value`` (masked to the register width) in register ``reg_id``.

        Raises:
            ValueError: If ``reg_id`` is out of range.
            TypeError: If ``value`` is not an integer.
        """
        if not self._is_valid_id(reg_id):
            raise ValueError(f"Invalid register ID: {reg_id}")

        # Mask before taking the lock so the critical section stays minimal.
        masked = self._mask(value)

        with self.register_lock:
            self.usage_stats["writes"] += 1
            self.registers[reg_id] = masked

    def bulk_read(self, reg_ids: List[int]) -> List[int]:
        """Read several registers under one lock acquisition.

        Returns the values in the same order as ``reg_ids``.

        Raises:
            ValueError: If any ID is out of range (nothing is read).
        """
        # Materialise once so one-shot iterators survive validation and len().
        reg_ids = list(reg_ids)
        if not all(self._is_valid_id(rid) for rid in reg_ids):
            raise ValueError("Invalid register ID in bulk read")

        with self.register_lock:
            self.usage_stats["reads"] += len(reg_ids)
            return [int(self.registers[rid]) for rid in reg_ids]

    def bulk_write(self, reg_data: Dict[int, int]) -> None:
        """Write several ``{reg_id: value}`` pairs under one lock acquisition.

        Validation and masking finish before any register is modified, so an
        invalid ID or value leaves the register file untouched.

        Raises:
            ValueError: If any ID is out of range.
            TypeError: If any value is not an integer.
        """
        if not all(self._is_valid_id(rid) for rid in reg_data):
            raise ValueError("Invalid register ID in bulk write")

        values = {rid: self._mask(val) for rid, val in reg_data.items()}

        with self.register_lock:
            self.usage_stats["writes"] += len(values)
            for rid, val in values.items():
                self.registers[rid] = val

    def clear_registers(self) -> None:
        """Reset all registers to zero (usage statistics are kept)."""
        with self.register_lock:
            self.registers.fill(0)

    def get_usage_stats(self) -> Dict[str, int]:
        """Return a snapshot copy of the read/write/conflict counters."""
        with self.register_lock:
            return dict(self.usage_stats)

    def check_conflicts(self, read_regs: List[int], write_regs: List[int]) -> bool:
        """Return True if any register appears in both ``read_regs`` and ``write_regs``.

        Each call that finds a conflict increments ``usage_stats["conflicts"]``
        by one.
        """
        if set(read_regs).isdisjoint(write_regs):
            return False

        # The stats dict is shared with the other accessors; update it under the lock.
        with self.register_lock:
            self.usage_stats["conflicts"] += 1
        return True