Spaces:
Sleeping
Sleeping
| from websocket_storage import WebSocketGPUStorage | |
| import numpy as np | |
| from typing import Dict, Any, Optional, List | |
| import time | |
class StreamingMultiprocessor:
    """Simulated streaming multiprocessor (SM) whose state lives in a storage backend.

    All bookkeeping (scheduled warps, shared-memory allocations, register-file
    entries) is kept in ``self.sm_state``. After every mutation it is mirrored
    to the backend via ``storage.store_state(f"sm_{sm_id}", "state", ...)``.
    Array payloads for shared memory and registers are persisted separately
    through ``storage.store_tensor`` / ``storage.load_tensor``.
    """

    def __init__(self, sm_id: int, num_cores: int = 128, storage=None):
        """Create the SM and persist its initial (empty) state.

        Args:
            sm_id: Identifier of this SM; used to key its stored state.
            num_cores: Core count; recorded in the state and reported by
                ``get_stats`` (not otherwise used by this class).
            storage: Backend exposing ``store_state``, ``store_tensor`` and
                ``load_tensor``. When ``None``, a ``WebSocketGPUStorage`` is
                created and must connect successfully.

        Raises:
            RuntimeError: If the internally created storage reports that it
                could not connect.
        """
        self.sm_id = sm_id
        self.num_cores = num_cores
        self.storage = storage
        if self.storage is None:
            # Deferred import: only needed when no backend is injected.
            from websocket_storage import WebSocketGPUStorage
            self.storage = WebSocketGPUStorage()
            if not self.storage.wait_for_connection():
                raise RuntimeError("Could not connect to GPU storage server")

        # Initialize SM state
        self.sm_state = {
            "sm_id": sm_id,
            "num_cores": num_cores,
            "active_warps": {},        # warp_id -> warp_state dict
            "shared_memory": {},       # shared_id -> allocation metadata
            "register_file": {},       # reg key -> access metadata
            "l1_cache": {},
            "warp_scheduler_state": {
                "active_warps": [],    # ordered ids of scheduled warps (no duplicates)
                "pending_warps": [],
                "completed_warps": [],
            },
        }
        self.store_sm_state()

    def store_sm_state(self):
        """Store SM state in WebSocket storage"""
        self.storage.store_state(f"sm_{self.sm_id}", "state", self.sm_state)

    def allocate_shared_memory(self, size: int, block_id: str) -> str:
        """Record a shared-memory allocation for ``block_id`` and return its id.

        The returned id embeds the allocation timestamp (nanoseconds), which is
        also stored as ``allocated_at`` so the two always agree.

        Raises:
            ValueError: If ``size`` is negative.
        """
        if size < 0:
            raise ValueError(f"Shared memory size must be non-negative, got {size}")
        # Read the clock once so the id and the recorded timestamp match.
        allocated_at = time.time_ns()
        shared_id = f"shared_{block_id}_{allocated_at}"
        self.sm_state["shared_memory"][shared_id] = {
            "size": size,
            "block_id": block_id,
            "allocated_at": allocated_at,
        }
        self.store_sm_state()
        return shared_id

    def write_shared_memory(self, shared_id: str, data: np.ndarray):
        """Persist ``data`` under an allocated shared-memory id.

        Returns whatever ``storage.store_tensor`` returns.

        Raises:
            ValueError: If ``shared_id`` was not allocated on this SM.
        """
        if shared_id not in self.sm_state["shared_memory"]:
            raise ValueError(f"Shared memory block {shared_id} not allocated")
        return self.storage.store_tensor(shared_id, data)

    def read_shared_memory(self, shared_id: str) -> Optional[np.ndarray]:
        """Load the data stored under an allocated shared-memory id.

        Raises:
            ValueError: If ``shared_id`` was not allocated on this SM.
        """
        if shared_id not in self.sm_state["shared_memory"]:
            raise ValueError(f"Shared memory block {shared_id} not allocated")
        return self.storage.load_tensor(shared_id)

    def schedule_warp(self, warp_id: str, warp_state: Dict[str, Any]):
        """Mark a warp as active and persist its state.

        Rescheduling an already-active warp replaces its stored state but does
        not add a second entry to the scheduler's active list.
        """
        scheduled = self.sm_state["warp_scheduler_state"]["active_warps"]
        # Avoid duplicates: complete_warp removes only one list entry, so a
        # duplicate would leave a completed warp looking active.
        if warp_id not in scheduled:
            scheduled.append(warp_id)
        self.sm_state["active_warps"][warp_id] = warp_state
        self.store_sm_state()
        # Store warp state
        self.storage.store_state(f"warp_{warp_id}", "state", warp_state)

    def complete_warp(self, warp_id: str):
        """Move an active warp to the completed list and persist its final state.

        Does nothing if ``warp_id`` is not currently active.
        """
        if warp_id not in self.sm_state["active_warps"]:
            return
        scheduler = self.sm_state["warp_scheduler_state"]
        if warp_id in scheduler["active_warps"]:
            scheduler["active_warps"].remove(warp_id)
        scheduler["completed_warps"].append(warp_id)
        warp_state = self.sm_state["active_warps"].pop(warp_id)
        self.store_sm_state()
        # Store completed state
        self.storage.store_state(f"warp_{warp_id}", "completed", warp_state)

    def write_register(self, warp_id: str, reg_id: str, data: np.ndarray):
        """Record a register entry for a warp and persist its data.

        Returns whatever ``storage.store_tensor`` returns.
        """
        reg_key = f"reg_{warp_id}_{reg_id}"
        self.sm_state["register_file"][reg_key] = {
            "warp_id": warp_id,
            "reg_id": reg_id,
            "last_accessed": time.time_ns(),
        }
        self.store_sm_state()
        return self.storage.store_tensor(reg_key, data)

    def read_register(self, warp_id: str, reg_id: str) -> Optional[np.ndarray]:
        """Load a register's data, or return ``None`` if it was never written.

        A successful lookup refreshes ``last_accessed`` and re-persists the
        SM state.
        """
        reg_key = f"reg_{warp_id}_{reg_id}"
        entry = self.sm_state["register_file"].get(reg_key)
        if entry is None:
            return None
        entry["last_accessed"] = time.time_ns()
        self.store_sm_state()
        return self.storage.load_tensor(reg_key)

    def get_stats(self) -> Dict[str, Any]:
        """Return counts summarizing this SM's current bookkeeping."""
        return {
            "sm_id": self.sm_id,
            "num_cores": self.num_cores,
            "active_warps": len(self.sm_state["active_warps"]),
            "shared_memory_blocks": len(self.sm_state["shared_memory"]),
            "register_file_entries": len(self.sm_state["register_file"]),
            "completed_warps": len(self.sm_state["warp_scheduler_state"]["completed_warps"]),
        }