File size: 4,681 Bytes
0a735c8
 
 
 
 
 
16d64f1
0a735c8
 
16d64f1
 
 
 
 
 
0a735c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from websocket_storage import WebSocketGPUStorage
import numpy as np
from typing import Dict, Any, Optional, List
import time

class StreamingMultiprocessor:
    """Simulated GPU streaming multiprocessor (SM) backed by external storage.

    Keeps a local ``sm_state`` dict tracking shared-memory allocations,
    register-file entries, and warp scheduler bookkeeping, and mirrors every
    mutation to the storage backend via ``store_sm_state``. Tensor payloads
    (shared memory, registers) are stored/loaded through the backend's
    ``store_tensor`` / ``load_tensor``; only metadata lives in ``sm_state``.
    """

    def __init__(self, sm_id: int, num_cores: int = 128, storage=None):
        """Initialize the SM and persist its initial state.

        Args:
            sm_id: Identifier for this SM.
            num_cores: Number of simulated cores (default 128).
            storage: Backend exposing ``store_state``, ``store_tensor`` and
                ``load_tensor``. When ``None``, a ``WebSocketGPUStorage`` is
                created and connected.

        Raises:
            RuntimeError: If the default storage backend cannot connect.
        """
        self.sm_id = sm_id
        self.num_cores = num_cores
        self.storage = storage
        if self.storage is None:
            # WebSocketGPUStorage is imported at module level; no need to
            # re-import it here.
            self.storage = WebSocketGPUStorage()
            if not self.storage.wait_for_connection():
                raise RuntimeError("Could not connect to GPU storage server")

        # Local mirror of this SM's state; every mutation is pushed to the
        # backend via store_sm_state().
        self.sm_state = {
            "sm_id": sm_id,
            "num_cores": num_cores,
            "active_warps": {},        # warp_id -> warp state dict
            "shared_memory": {},       # shared_id -> allocation metadata
            "register_file": {},       # reg key -> access metadata
            "warp_scheduler_state": {
                "active_warps": [],
                "pending_warps": [],
                "completed_warps": []
            }
        }
        self.store_sm_state()

    def store_sm_state(self):
        """Persist the current SM state to the storage backend."""
        self.storage.store_state(f"sm_{self.sm_id}", "state", self.sm_state)

    def allocate_shared_memory(self, size: int, block_id: str) -> str:
        """Allocate a shared-memory block and return its storage key.

        The key embeds ``time.time_ns()`` to stay unique across repeated
        allocations for the same block.
        """
        shared_id = f"shared_{block_id}_{time.time_ns()}"
        self.sm_state["shared_memory"][shared_id] = {
            "size": size,
            "block_id": block_id,
            "allocated_at": time.time_ns()
        }
        self.store_sm_state()
        return shared_id

    def write_shared_memory(self, shared_id: str, data: np.ndarray):
        """Write *data* to an allocated shared-memory block.

        Raises:
            ValueError: If ``shared_id`` was never allocated.
        """
        if shared_id not in self.sm_state["shared_memory"]:
            raise ValueError(f"Shared memory block {shared_id} not allocated")

        return self.storage.store_tensor(shared_id, data)

    def read_shared_memory(self, shared_id: str) -> Optional[np.ndarray]:
        """Read an allocated shared-memory block; returns the stored tensor.

        Raises:
            ValueError: If ``shared_id`` was never allocated.
        """
        if shared_id not in self.sm_state["shared_memory"]:
            raise ValueError(f"Shared memory block {shared_id} not allocated")

        return self.storage.load_tensor(shared_id)

    def schedule_warp(self, warp_id: str, warp_state: Dict[str, Any]):
        """Schedule a warp for execution and persist its state."""
        self.sm_state["warp_scheduler_state"]["active_warps"].append(warp_id)
        self.sm_state["active_warps"][warp_id] = warp_state
        self.store_sm_state()

        # Store warp state under its own namespace as well.
        self.storage.store_state(f"warp_{warp_id}", "state", warp_state)

    def complete_warp(self, warp_id: str):
        """Mark a warp as completed; a no-op if the warp is not active."""
        if warp_id not in self.sm_state["active_warps"]:
            return
        scheduler = self.sm_state["warp_scheduler_state"]
        # Guard against list/dict desync: the original unconditionally
        # called list.remove(), which raises ValueError if the scheduler
        # list and the active_warps dict ever disagree.
        if warp_id in scheduler["active_warps"]:
            scheduler["active_warps"].remove(warp_id)
        scheduler["completed_warps"].append(warp_id)
        warp_state = self.sm_state["active_warps"].pop(warp_id)
        self.store_sm_state()

        # Store completed state
        self.storage.store_state(f"warp_{warp_id}", "completed", warp_state)

    def write_register(self, warp_id: str, reg_id: str, data: np.ndarray):
        """Write *data* to a register; metadata records last access time."""
        reg_key = f"reg_{warp_id}_{reg_id}"
        self.sm_state["register_file"][reg_key] = {
            "warp_id": warp_id,
            "reg_id": reg_id,
            "last_accessed": time.time_ns()
        }
        self.store_sm_state()

        return self.storage.store_tensor(reg_key, data)

    def read_register(self, warp_id: str, reg_id: str) -> Optional[np.ndarray]:
        """Read a register; returns ``None`` if it was never written."""
        reg_key = f"reg_{warp_id}_{reg_id}"
        if reg_key in self.sm_state["register_file"]:
            self.sm_state["register_file"][reg_key]["last_accessed"] = time.time_ns()
            self.store_sm_state()
            return self.storage.load_tensor(reg_key)
        return None

    def get_stats(self) -> Dict[str, Any]:
        """Return summary counters for this SM's current state."""
        return {
            "sm_id": self.sm_id,
            "num_cores": self.num_cores,
            "active_warps": len(self.sm_state["active_warps"]),
            "shared_memory_blocks": len(self.sm_state["shared_memory"]),
            "register_file_entries": len(self.sm_state["register_file"]),
            "completed_warps": len(self.sm_state["warp_scheduler_state"]["completed_warps"])
        }