File size: 3,938 Bytes
7a0c684
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from typing import Dict, Any, Optional
import numpy as np

class MemoryBlock:
    """Base class for simulated GPU memory blocks.

    The block is a fixed-size, zero-initialised ``bytearray`` managed by a
    bump allocator: ``allocate`` hands out consecutive offsets and never
    frees them. ``read`` and ``write`` are bounds-checked against ``size``.

    Attributes:
        size: Capacity of the block in bytes.
        data: Backing storage, ``size`` bytes long.
        offset: Next free byte for the bump allocator.
    """

    def __init__(self, size_bytes: int):
        self.size = size_bytes
        self.data = bytearray(size_bytes)
        self.offset = 0

    def _check_range(self, offset: int, length: int, operation: str) -> None:
        """Raise ValueError unless ``[offset, offset + length)`` lies in the block.

        Negative values must be rejected explicitly: Python slicing would
        otherwise reinterpret them relative to the end of ``data``, giving
        wrong-length reads and, for writes, inserting bytes that grow the
        buffer past ``size``.
        """
        if offset < 0 or length < 0:
            raise ValueError(f"{operation} operation has a negative offset or size")
        if offset + length > self.size:
            raise ValueError(f"{operation} operation exceeds memory block size")

    def allocate(self, size_bytes: int) -> Optional[int]:
        """Reserve ``size_bytes`` bytes and return their starting offset.

        Args:
            size_bytes: Number of bytes to reserve; must be non-negative.

        Returns:
            The offset of the reserved region, or ``None`` if the block does
            not have enough free space left.

        Raises:
            ValueError: If ``size_bytes`` is negative. Accepting it would
                move the allocation cursor backwards and hand out
                overlapping regions.
        """
        if size_bytes < 0:
            raise ValueError("Allocation size must be non-negative")
        if self.offset + size_bytes > self.size:
            return None
        start = self.offset
        self.offset += size_bytes
        return start

    def write(self, offset: int, data: bytes):
        """Copy ``data`` into the block starting at ``offset``.

        Raises:
            ValueError: If the destination range has a negative offset or
                extends past the end of the block.
        """
        self._check_range(offset, len(data), "Write")
        self.data[offset:offset + len(data)] = data

    def read(self, offset: int, size: int) -> bytes:
        """Return a copy of ``size`` bytes starting at ``offset``.

        Raises:
            ValueError: If the source range has a negative offset or size,
                or extends past the end of the block.
        """
        self._check_range(offset, size, "Read")
        return bytes(self.data[offset:offset + size])

class SharedMemory(MemoryBlock):
    """Represents shared memory accessible by all threads in a block."""
    def __init__(self, size_bytes: int = 48*1024):  # Default 48KB
        super().__init__(size_bytes)
        # Intended for synchronization; not read or written by the methods
        # of this class.
        self.locks: Dict[int, bool] = {}

    def atomic_add(self, offset: int, value: int) -> int:
        """Add ``value`` to the 32-bit word at ``offset`` and return the old value.

        The word is an unsigned little-endian 32-bit integer. The sum wraps
        modulo 2**32, as unsigned 32-bit arithmetic does on the device. This
        avoids ``int.to_bytes`` raising OverflowError when the result
        exceeds 0xFFFFFFFF or drops below zero.

        Note: the read-modify-write is not guarded by a lock, so it is only
        atomic with respect to callers that do not run concurrently.

        Args:
            offset: Byte offset of the 4-byte word.
            value: Amount to add; may be negative.

        Returns:
            The word's value before the addition.

        Raises:
            ValueError: If the 4-byte word does not lie inside the block.
        """
        old = int.from_bytes(self.read(offset, 4), 'little')
        new = (old + value) & 0xFFFFFFFF  # wrap to 32 bits
        self.write(offset, new.to_bytes(4, 'little'))
        return old

class L1Cache(MemoryBlock):
    """Simulated L1 cache holding a bounded number of cache lines.

    Line contents are placeholders: a miss installs a zero-filled line in
    place of a real fetch from L2. At most ``max(1, size // line_size)``
    lines are resident at once. When a miss would exceed that, the least
    recently accessed line is evicted, so memory use tracks the configured
    cache size instead of growing with every distinct address touched.
    """
    def __init__(self, size_bytes: int = 32*1024, line_size: int = 128):  # Default 32KB
        if line_size <= 0:
            raise ValueError("line_size must be positive")
        super().__init__(size_bytes)
        # Resident lines keyed by line-aligned address. Dict insertion order
        # doubles as recency order (oldest first) for LRU eviction.
        self.cache_lines: Dict[int, bytes] = {}
        self.line_size = line_size  # bytes per cache line (default 128)

    def load_line(self, address: int) -> bytes:
        """Return the cache line containing ``address``.

        Args:
            address: Any byte address; it is rounded down to a multiple of
                ``line_size``.

        Returns:
            The ``line_size``-byte contents of that line.
        """
        line_address = address - (address % self.line_size)
        if line_address in self.cache_lines:
            # Hit: pop so the re-insert below marks it most recently used.
            line = self.cache_lines.pop(line_address)
        else:
            # Miss: evict the oldest lines until there is room, then
            # simulate fetching from L2.
            capacity = max(1, self.size // self.line_size)
            while len(self.cache_lines) >= capacity:
                del self.cache_lines[next(iter(self.cache_lines))]
            line = bytes(self.line_size)
        self.cache_lines[line_address] = line
        return line

class L2Cache(MemoryBlock):
    """Simulated L2 cache holding a bounded number of cache lines.

    Line contents are placeholders: a miss installs a zero-filled line in
    place of a real fetch from global memory. At most
    ``max(1, size // line_size)`` lines are resident at once. When a miss
    would exceed that, the least recently accessed line is evicted, so
    memory use tracks the configured cache size instead of growing with
    every distinct address touched.
    """
    def __init__(self, size_bytes: int = 1024*1024, line_size: int = 256):  # Default 1MB
        if line_size <= 0:
            raise ValueError("line_size must be positive")
        super().__init__(size_bytes)
        # Resident lines keyed by line-aligned address. Dict insertion order
        # doubles as recency order (oldest first) for LRU eviction.
        self.cache_lines: Dict[int, bytes] = {}
        self.line_size = line_size  # bytes per cache line (default 256)

    def load_line(self, address: int) -> bytes:
        """Return the cache line containing ``address``.

        Args:
            address: Any byte address; it is rounded down to a multiple of
                ``line_size``.

        Returns:
            The ``line_size``-byte contents of that line.
        """
        line_address = address - (address % self.line_size)
        if line_address in self.cache_lines:
            # Hit: pop so the re-insert below marks it most recently used.
            line = self.cache_lines.pop(line_address)
        else:
            # Miss: evict the oldest lines until there is room, then
            # simulate fetching from global memory.
            capacity = max(1, self.size // self.line_size)
            while len(self.cache_lines) >= capacity:
                del self.cache_lines[next(iter(self.cache_lines))]
            line = bytes(self.line_size)
        self.cache_lines[line_address] = line
        return line

class RegisterFile:
    """Represents per-thread registers.

    Registers are handed out by a bump allocator and never freed. Only
    registers that have been allocated (index < ``used_registers``) may be
    read or written.
    """
    def __init__(self, num_registers: int = 255):  # Maximum registers per thread
        self.registers = [0] * num_registers
        self.used_registers = 0

    def allocate(self, num: int = 1) -> Optional[int]:
        """Reserve ``num`` consecutive registers and return the first index.

        Args:
            num: Number of registers to reserve; must be non-negative.

        Returns:
            Index of the first reserved register, or ``None`` if fewer than
            ``num`` registers remain.

        Raises:
            ValueError: If ``num`` is negative. Accepting it would move the
                allocation cursor backwards, re-issuing registers that are
                already in use and invalidating indices callers hold.
        """
        if num < 0:
            raise ValueError("Register allocation count must be non-negative")
        if self.used_registers + num > len(self.registers):
            return None
        first = self.used_registers
        self.used_registers += num
        return first

    def read(self, index: int) -> int:
        """Return the value of allocated register ``index``.

        Raises:
            IndexError: If ``index`` is negative or not yet allocated.
        """
        if 0 <= index < self.used_registers:
            return self.registers[index]
        raise IndexError("Register index out of bounds")

    def write(self, index: int, value: int):
        """Store ``value`` in allocated register ``index``.

        Raises:
            IndexError: If ``index`` is negative or not yet allocated.
        """
        if 0 <= index < self.used_registers:
            self.registers[index] = value
        else:
            raise IndexError("Register index out of bounds")