from websocket_storage import WebSocketGPUStorage
import numpy as np
from typing import Dict, Any, Optional
import time


class VirtualVRAM:
    """Virtual VRAM whose block bookkeeping and block data live in WebSocket storage.

    Capacity is unbounded: ``total_size`` is ``float('inf')``, so allocation never
    fails for lack of space. Every change to the bookkeeping state is pushed to
    the storage server via :meth:`store_vram_state`.
    """

    # Bytes per GiB, used when reporting statistics in gigabytes.
    _BYTES_PER_GB = 1024 * 1024 * 1024

    def __init__(self, size_gb: Optional[int] = None):
        """Initialize virtual VRAM with unlimited storage capability.

        Args:
            size_gb: Accepted for interface compatibility and recorded on the
                instance. It does NOT limit capacity; the pool is unlimited.

        Raises:
            RuntimeError: If the storage server connection cannot be established.
        """
        # Previously this argument was discarded, and get_stats() then failed with
        # AttributeError on self.size_gb. Keep it so the attribute always exists.
        self.size_gb = size_gb
        self.storage = WebSocketGPUStorage()
        # wait_for_connection() is treated as returning a truthy value on success.
        if not self.storage.wait_for_connection():
            raise RuntimeError("Could not connect to GPU storage server")

        # Initialize VRAM state with unlimited capacity
        self.vram_state = {
            "total_size": float('inf'),  # Unlimited size
            "allocated": 0,              # sum of "size" over all live blocks
            "blocks": {},                # block_id -> {size, allocated_at, last_accessed}
            "memory_map": {},            # virtual address -> block_id
            "is_unlimited": True,
        }
        self.store_vram_state()

    def store_vram_state(self):
        """Push the current bookkeeping state to WebSocket storage under ("vram", "state")."""
        self.storage.store_state("vram", "state", self.vram_state)

    def _generate_block_id(self, timestamp: int) -> str:
        """Return a ``block_<n>`` id, starting at ``timestamp``, that is not already in use."""
        # time.time_ns() can return the same value for back-to-back calls on
        # coarse clocks; bump the suffix until the id is free so an existing
        # block is never silently overwritten.
        candidate = timestamp
        while f"block_{candidate}" in self.vram_state["blocks"]:
            candidate += 1
        return f"block_{candidate}"

    def allocate_block(self, size: int, block_id: Optional[str] = None) -> str:
        """Allocate a block of VRAM.

        Args:
            size: Block size added to the ``allocated`` counter; must be >= 0.
            block_id: Explicit id to use. If it names an existing block, that
                block's entry is replaced and its old size is released first.
                If omitted, a unique ``block_<ns>`` id is generated.

        Returns:
            The id of the allocated block.

        Raises:
            ValueError: If ``size`` is negative.
            MemoryError: If the allocation would exceed ``total_size`` (not
                reachable while capacity is unlimited).
        """
        if size < 0:
            raise ValueError(f"Block size must be non-negative, got {size}")

        blocks = self.vram_state["blocks"]
        # Replacing an existing block must not double-count it in "allocated".
        previous = blocks.get(block_id) if block_id is not None else None
        previous_size = previous["size"] if previous is not None else 0

        if self.vram_state["allocated"] - previous_size + size > self.vram_state["total_size"]:
            raise MemoryError("Not enough VRAM available")

        now = time.time_ns()
        if block_id is None:
            block_id = self._generate_block_id(now)

        blocks[block_id] = {
            "size": size,
            "allocated_at": now,
            "last_accessed": now,
        }
        self.vram_state["allocated"] += size - previous_size

        # Store updated state
        self.store_vram_state()
        return block_id

    def free_block(self, block_id: str):
        """Free a block of VRAM; unknown ids are ignored.

        The block's size is returned to the pool, and any virtual addresses
        mapped to it are removed. The block's stored data is then overwritten
        with ``None`` (presumably clearing it server-side; depends on the
        storage implementation).
        """
        block = self.vram_state["blocks"].pop(block_id, None)
        if block is None:
            return

        self.vram_state["allocated"] -= block["size"]

        # Drop mappings that still point at the freed block so that
        # get_block_from_address() cannot return a dangling id.
        memory_map = self.vram_state["memory_map"]
        stale = [addr for addr, mapped in memory_map.items() if mapped == block_id]
        for addr in stale:
            del memory_map[addr]

        self.store_vram_state()

        # Remove block data from storage
        self.storage.store_tensor(block_id, None)

    def _touch_block(self, block_id: str):
        """Validate that ``block_id`` is allocated and refresh its ``last_accessed`` time.

        Raises:
            ValueError: If the block is not allocated.
        """
        if block_id not in self.vram_state["blocks"]:
            raise ValueError(f"Block {block_id} not allocated")
        self.vram_state["blocks"][block_id]["last_accessed"] = time.time_ns()
        self.store_vram_state()

    def write_block(self, block_id: str, data: np.ndarray):
        """Write data to a VRAM block.

        Returns:
            Whatever ``storage.store_tensor`` returns.

        Raises:
            ValueError: If the block is not allocated.
        """
        self._touch_block(block_id)
        return self.storage.store_tensor(block_id, data)

    def read_block(self, block_id: str) -> Optional[np.ndarray]:
        """Read data from a VRAM block via ``storage.load_tensor``.

        Raises:
            ValueError: If the block is not allocated.
        """
        self._touch_block(block_id)
        return self.storage.load_tensor(block_id)

    def map_address(self, virtual_addr: str, block_id: str):
        """Map a virtual address to a VRAM block.

        NOTE(review): ``block_id`` is not checked against allocated blocks, so an
        address can be mapped to an unallocated id. Left as-is for compatibility.
        """
        self.vram_state["memory_map"][virtual_addr] = block_id
        self.store_vram_state()

    def get_block_from_address(self, virtual_addr: str) -> Optional[str]:
        """Return the block id mapped to ``virtual_addr``, or None if unmapped."""
        return self.vram_state["memory_map"].get(virtual_addr)

    def get_stats(self) -> Dict[str, Any]:
        """Return VRAM statistics; sizes are in GiB (total/free are inf when unlimited)."""
        allocated = self.vram_state["allocated"]
        total = self.vram_state["total_size"]
        return {
            # Derived from total_size (like free_gb) rather than the never-set
            # size_gb argument, which made this method raise AttributeError.
            "total_gb": total / self._BYTES_PER_GB,
            "used_gb": allocated / self._BYTES_PER_GB,
            "free_gb": (total - allocated) / self._BYTES_PER_GB,
            "num_blocks": len(self.vram_state["blocks"]),
            "mappings": len(self.vram_state["memory_map"]),
        }