File size: 4,763 Bytes
e9bc512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import itertools
import time
from typing import Dict, Any, List, Optional

from websocket_storage import WebSocketGPUStorage
from virtual_vram import VirtualVRAM
from streaming_multiprocessor import StreamingMultiprocessor

class GPUChip:
    """Simulated GPU chip built from a VirtualVRAM and a set of StreamingMultiprocessors.

    All components share one storage backend. The chip's own bookkeeping
    (PCIe transfers, power, memory-controller requests) lives in
    ``self.chip_state`` and is persisted with
    ``storage.store_state(f"chip_{chip_id}", "state", ...)`` after every mutation.
    """

    # Process-wide sequence appended to every generated ID. ``time.time_ns()``
    # alone is not unique: its effective resolution is platform dependent, so
    # two calls in quick succession can return the same value and the later
    # transfer/warp would overwrite the earlier one's bookkeeping.
    _id_sequence = itertools.count()

    def __init__(self, chip_id: int, num_sms: int = 108, vram_gb: int = 24, storage=None):
        """Create the chip's components and persist its initial state.

        Args:
            chip_id: Identifier used in the storage key ``chip_{chip_id}``.
            num_sms: Number of StreamingMultiprocessor instances to create.
            vram_gb: Passed to VirtualVRAM (presumably a size in GB, per the name).
            storage: Shared storage backend. When None, a new
                WebSocketGPUStorage is created and must connect successfully.

        Raises:
            RuntimeError: If a newly created storage fails to connect.
        """
        self.chip_id = chip_id
        self.storage = storage
        if self.storage is None:
            self.storage = WebSocketGPUStorage()
            if not self.storage.wait_for_connection():
                raise RuntimeError("Could not connect to GPU storage server")

        # Initialize components with shared storage
        self.vram = VirtualVRAM(vram_gb, storage=self.storage)
        self.sms = [StreamingMultiprocessor(i, storage=self.storage) for i in range(num_sms)]

        self.chip_state = self._initial_chip_state(chip_id, num_sms, vram_gb)
        self.store_chip_state()

    @staticmethod
    def _initial_chip_state(chip_id: int, num_sms: int, vram_gb: int) -> Dict[str, Any]:
        """Build the zeroed bookkeeping dict for a freshly created chip."""
        return {
            "chip_id": chip_id,
            "num_sms": num_sms,
            "vram_gb": vram_gb,
            "pcie_state": {
                "active_transfers": {},
                "bandwidth_usage": 0
            },
            "power_state": {
                "total_watts": 0,
                "sm_power": [0] * num_sms,
                "vram_power": 0
            },
            "memory_controller": {
                "active_requests": {},
                "bandwidth_usage": 0
            }
        }

    def _new_id(self, prefix: str) -> str:
        """Return a unique ID of the form ``{prefix}_{time_ns}_{sequence}``."""
        return f"{prefix}_{time.time_ns()}_{next(self._id_sequence)}"

    def store_chip_state(self):
        """Store chip state in WebSocket storage"""
        self.storage.store_state(f"chip_{self.chip_id}", "state", self.chip_state)

    def allocate_memory(self, size: int, virtual_addr: Optional[str] = None) -> str:
        """Allocate a VRAM block and optionally map a virtual address to it.

        Args:
            size: Size forwarded to ``vram.allocate_block``.
            virtual_addr: If truthy, mapped to the new block via ``vram.map_address``.

        Returns:
            The block ID returned by the VRAM.
        """
        block_id = self.vram.allocate_block(size)
        if virtual_addr:
            self.vram.map_address(virtual_addr, block_id)

        # Record the request on the memory controller. This class never
        # removes these entries.
        self.chip_state["memory_controller"]["active_requests"][block_id] = {
            "type": "allocation",
            "size": size,
            "timestamp": time.time_ns()
        }
        self.store_chip_state()

        return block_id

    def transfer_to_device(self, data: bytes, virtual_addr: Optional[str] = None) -> str:
        """Simulate a host-to-device PCIe transfer and store ``data`` in VRAM.

        The transfer is recorded in ``pcie_state.active_transfers`` before the
        copy. It is then marked ``completed`` on success, or ``failed`` if
        allocation or storage raises (the exception is re-raised).

        Returns:
            The block ID that now holds ``data``.
        """
        transfer_id = self._new_id("transfer")
        transfers = self.chip_state["pcie_state"]["active_transfers"]
        transfers[transfer_id] = {
            "direction": "to_device",
            "size": len(data),
            "timestamp": time.time_ns()
        }
        self.store_chip_state()

        try:
            block_id = self.allocate_memory(len(data), virtual_addr)
            self.storage.store_tensor(block_id, data)
        except Exception:
            # Don't leave the entry looking in-flight forever; record the
            # failure, persist it, and let the caller handle the error.
            transfers[transfer_id]["failed"] = True
            self.store_chip_state()
            raise

        transfers[transfer_id]["completed"] = True
        self.store_chip_state()

        return block_id

    def schedule_compute(self, sm_index: int, warp_state: Dict[str, Any]) -> str:
        """Schedule a warp on SM ``sm_index`` and bump that SM's simulated power.

        Returns:
            The generated warp ID.

        Raises:
            ValueError: If ``sm_index`` is outside ``[0, len(self.sms))``.
        """
        if not 0 <= sm_index < len(self.sms):
            raise ValueError(f"Invalid SM index: {sm_index}")

        warp_id = self._new_id("warp")
        self.sms[sm_index].schedule_warp(warp_id, warp_state)

        # Simulated power: +10 per scheduled warp (never decremented here).
        power = self.chip_state["power_state"]
        power["sm_power"][sm_index] += 10
        power["total_watts"] = sum(power["sm_power"])
        self.store_chip_state()

        return warp_id

    def get_stats(self) -> Dict[str, Any]:
        """Get comprehensive chip statistics.

        ``pcie.active_transfers`` counts only transfers that have neither
        completed nor failed. ``memory_controller.active_requests`` counts all
        recorded allocation requests.
        """
        pcie = self.chip_state["pcie_state"]
        power = self.chip_state["power_state"]
        mem_ctrl = self.chip_state["memory_controller"]
        in_flight = sum(
            1
            for transfer in pcie["active_transfers"].values()
            if not transfer.get("completed") and not transfer.get("failed")
        )
        return {
            "chip_id": self.chip_id,
            "vram": self.vram.get_stats(),
            "sms": [sm.get_stats() for sm in self.sms],
            "pcie": {
                "active_transfers": in_flight,
                "bandwidth_usage": pcie["bandwidth_usage"]
            },
            "power": {
                "total_watts": power["total_watts"],
                "vram_watts": power["vram_power"]
            },
            "memory_controller": {
                "active_requests": len(mem_ctrl["active_requests"]),
                "bandwidth_usage": mem_ctrl["bandwidth_usage"]
            }
        }