from websocket_storage import WebSocketGPUStorage
import numpy as np
from typing import Dict, Any, Optional, List
import time


class StreamingMultiprocessor:
    """Simulated streaming multiprocessor (SM) whose state is persisted remotely.

    All bookkeeping (active warps, shared-memory allocations, register-file
    metadata and the warp scheduler lists) lives in ``self.sm_state`` and is
    pushed to the storage backend with ``store_state`` after every mutation.
    Tensor payloads (shared-memory and register contents) are stored
    separately through ``store_tensor`` / ``load_tensor``.
    """

    def __init__(self, sm_id: int, num_cores: int = 128, storage=None):
        """Create an SM and persist its initial (empty) state.

        Args:
            sm_id: Identifier used to namespace this SM's state key (``sm_<id>``).
            num_cores: Core count recorded in the state and reported by stats.
            storage: Backend exposing ``store_state``, ``store_tensor`` and
                ``load_tensor``. When ``None``, a new ``WebSocketGPUStorage`` is
                created and we block on ``wait_for_connection``.

        Raises:
            RuntimeError: If a freshly created storage backend fails to connect.
        """
        self.sm_id = sm_id
        self.num_cores = num_cores
        self.storage = storage
        if self.storage is None:
            # WebSocketGPUStorage is already imported at module level.
            self.storage = WebSocketGPUStorage()
            if not self.storage.wait_for_connection():
                raise RuntimeError("Could not connect to GPU storage server")

        # Initialize SM state
        self.sm_state = {
            "sm_id": sm_id,
            "num_cores": num_cores,
            "active_warps": {},       # warp_id -> warp_state dict
            "shared_memory": {},      # shared_id -> allocation metadata
            "register_file": {},      # reg key -> metadata (data lives in storage)
            "l1_cache": {},
            "warp_scheduler_state": {
                "active_warps": [],
                "pending_warps": [],
                "completed_warps": [],
            },
        }
        self.store_sm_state()

    def store_sm_state(self):
        """Store SM state in WebSocket storage"""
        self.storage.store_state(f"sm_{self.sm_id}", "state", self.sm_state)

    def allocate_shared_memory(self, size: int, block_id: str) -> str:
        """Allocate shared memory for a block and return its unique id.

        The id has the form ``shared_<block_id>_<ns>``. If that id is already
        taken (possible when the clock has coarse resolution and two
        allocations happen in the same tick), the numeric suffix is bumped until
        it is unique, so an existing allocation is never overwritten.

        Args:
            size: Requested size, recorded as metadata only (units not
                enforced here).
            block_id: Owning block identifier.

        Returns:
            The new shared-memory id, to pass to ``write_shared_memory`` /
            ``read_shared_memory``.
        """
        allocated_at = time.time_ns()
        stamp = allocated_at
        shared_id = f"shared_{block_id}_{stamp}"
        while shared_id in self.sm_state["shared_memory"]:
            stamp += 1
            shared_id = f"shared_{block_id}_{stamp}"
        self.sm_state["shared_memory"][shared_id] = {
            "size": size,
            "block_id": block_id,
            "allocated_at": allocated_at,
        }
        self.store_sm_state()
        return shared_id

    def write_shared_memory(self, shared_id: str, data: np.ndarray):
        """Write ``data`` to an allocated shared-memory block.

        Returns:
            Whatever ``storage.store_tensor`` returns.

        Raises:
            ValueError: If ``shared_id`` was not allocated on this SM.
        """
        if shared_id not in self.sm_state["shared_memory"]:
            raise ValueError(f"Shared memory block {shared_id} not allocated")
        return self.storage.store_tensor(shared_id, data)

    def read_shared_memory(self, shared_id: str) -> Optional[np.ndarray]:
        """Read an allocated shared-memory block from storage.

        Raises:
            ValueError: If ``shared_id`` was not allocated on this SM.
        """
        if shared_id not in self.sm_state["shared_memory"]:
            raise ValueError(f"Shared memory block {shared_id} not allocated")
        return self.storage.load_tensor(shared_id)

    def schedule_warp(self, warp_id: str, warp_state: Dict[str, Any]):
        """Schedule (or re-schedule) a warp for execution.

        Records the warp as active, persists the SM state, and stores the warp
        state under ``warp_<warp_id>``. Re-scheduling an already active warp
        replaces its state without adding a second scheduler entry.
        """
        scheduled = self.sm_state["warp_scheduler_state"]["active_warps"]
        # Keep the scheduler list in sync with the keyed active_warps dict:
        # complete_warp removes a single entry, so a duplicate would linger.
        if warp_id not in scheduled:
            scheduled.append(warp_id)
        self.sm_state["active_warps"][warp_id] = warp_state
        self.store_sm_state()
        # Store warp state
        self.storage.store_state(f"warp_{warp_id}", "state", warp_state)

    def complete_warp(self, warp_id: str):
        """Mark an active warp as completed; no-op if it is not active.

        Moves the warp from the active list to the completed list, persists the
        SM state, and stores the final warp state under the ``completed`` key.
        """
        if warp_id not in self.sm_state["active_warps"]:
            return
        scheduler = self.sm_state["warp_scheduler_state"]
        # Guarded so a list/dict mismatch cannot raise ValueError here.
        if warp_id in scheduler["active_warps"]:
            scheduler["active_warps"].remove(warp_id)
        scheduler["completed_warps"].append(warp_id)
        warp_state = self.sm_state["active_warps"].pop(warp_id)
        self.store_sm_state()
        # Store completed state
        self.storage.store_state(f"warp_{warp_id}", "completed", warp_state)

    def write_register(self, warp_id: str, reg_id: str, data: np.ndarray):
        """Write ``data`` to register ``reg_id`` of ``warp_id``.

        Register metadata (including ``last_accessed``) is kept in the SM state;
        the tensor itself is stored under ``reg_<warp_id>_<reg_id>``.

        Returns:
            Whatever ``storage.store_tensor`` returns.
        """
        reg_key = f"reg_{warp_id}_{reg_id}"
        self.sm_state["register_file"][reg_key] = {
            "warp_id": warp_id,
            "reg_id": reg_id,
            "last_accessed": time.time_ns(),
        }
        self.store_sm_state()
        return self.storage.store_tensor(reg_key, data)

    def read_register(self, warp_id: str, reg_id: str) -> Optional[np.ndarray]:
        """Read register ``reg_id`` of ``warp_id``.

        Returns:
            The stored tensor, or ``None`` if the register was never written
            on this SM. A successful read also refreshes ``last_accessed`` and
            persists the SM state.
        """
        reg_key = f"reg_{warp_id}_{reg_id}"
        if reg_key in self.sm_state["register_file"]:
            self.sm_state["register_file"][reg_key]["last_accessed"] = time.time_ns()
            self.store_sm_state()
            return self.storage.load_tensor(reg_key)
        return None

    def get_stats(self) -> Dict[str, Any]:
        """Return counts summarizing this SM's current state."""
        return {
            "sm_id": self.sm_id,
            "num_cores": self.num_cores,
            "active_warps": len(self.sm_state["active_warps"]),
            "shared_memory_blocks": len(self.sm_state["shared_memory"]),
            "register_file_entries": len(self.sm_state["register_file"]),
            "completed_warps": len(self.sm_state["warp_scheduler_state"]["completed_warps"]),
        }