"""Virtual VRAM bookkeeping persisted through a WebSocket GPU storage object."""

import time
from typing import Any, Dict, Optional

import numpy as np

try:
    from websocket_storage import WebSocketGPUStorage
except ImportError:
    # The backend is only required when no ``storage`` object is injected;
    # ``VirtualVRAM.__init__`` re-attempts the import on that path so the
    # caller still sees the real ImportError.
    WebSocketGPUStorage = None

_BYTES_PER_GB = 1024 * 1024 * 1024


class VirtualVRAM:
    """Track virtual VRAM blocks and mirror the bookkeeping to a storage object.

    The storage object must provide ``wait_for_connection``, ``store_state``,
    ``store_tensor`` and ``load_tensor``; these are the only methods this
    class calls on it.

    State lives in ``self.vram_state`` with the keys ``total_size``,
    ``allocated``, ``blocks``, ``memory_map`` and ``is_unlimited``.
    """

    def __init__(self, size_gb: Optional[int] = None, storage=None):
        """Create the VRAM state and persist it once.

        Args:
            size_gb: Capacity in GB (1024**3 bytes each). ``None`` means
                unlimited capacity.
            storage: Storage object to use. When ``None``, a
                ``WebSocketGPUStorage`` is created and waited on.

        Raises:
            ValueError: If ``size_gb`` is negative.
            RuntimeError: If the default storage cannot connect, or the
                initial state cannot be stored.
        """
        if size_gb is not None and size_gb < 0:
            raise ValueError("size_gb must be non-negative")

        if storage is None:
            storage_cls = WebSocketGPUStorage
            if storage_cls is None:
                # Module-level import failed; retry so the real error surfaces.
                from websocket_storage import WebSocketGPUStorage as storage_cls
            storage = storage_cls()
            if not storage.wait_for_connection():
                raise RuntimeError("Could not connect to GPU storage server")
        self.storage = storage

        # Kept so callers can see the requested capacity (None == unlimited).
        self.size_gb = size_gb

        unlimited = size_gb is None
        self.vram_state = {
            # Convert GB to bytes, or use infinity for unlimited capacity.
            "total_size": float("inf") if unlimited else size_gb * _BYTES_PER_GB,
            "allocated": 0,
            "blocks": {},
            "memory_map": {},
            "is_unlimited": unlimited,
        }
        self.store_vram_state()

    @property
    def total_size(self) -> float:
        """Total VRAM size in bytes (``float('inf')`` when unlimited)."""
        return self.vram_state["total_size"]

    @property
    def available_size(self) -> float:
        """Unallocated VRAM in bytes (``float('inf')`` when unlimited)."""
        if self.vram_state["is_unlimited"]:
            return float("inf")
        return self.vram_state["total_size"] - self.vram_state["allocated"]

    def _serializable_state(self) -> Dict[str, Any]:
        """Return the state with an infinite ``total_size`` turned into a string."""
        total = self.vram_state["total_size"]
        if isinstance(total, float) and total == float("inf"):
            # Infinity is not valid JSON; it is stored as the string "inf".
            total = str(total)
        return {
            "total_size": total,
            "allocated": self.vram_state["allocated"],
            "blocks": self.vram_state["blocks"],
            "memory_map": self.vram_state["memory_map"],
            "is_unlimited": self.vram_state["is_unlimited"],
        }

    def store_vram_state(self, max_retries=3):
        """Persist the VRAM state through the storage object, with retries.

        Each attempt waits up to 5 seconds for a connection and then calls
        ``store_state("vram", "state", ...)``. Failed attempts are reported
        with ``print`` and followed by a one-second pause, except after the
        final attempt.

        Args:
            max_retries: Number of attempts before giving up.

        Returns:
            ``True`` once the state has been stored.

        Raises:
            RuntimeError: If every attempt fails. The last exception caught
                from the storage object, if any, is chained as the cause.
        """
        last_error = None
        for attempt in range(max_retries):
            try:
                if not self.storage.wait_for_connection(timeout=5):
                    print(f"Waiting for WebSocket connection (attempt {attempt + 1}/{max_retries})")
                elif self.storage.store_state("vram", "state", self._serializable_state()):
                    return True
                else:
                    print(f"Failed to store VRAM state (attempt {attempt + 1}/{max_retries})")
            except Exception as e:  # any storage failure is retried
                last_error = e
                print(f"Error storing VRAM state (attempt {attempt + 1}/{max_retries}): {str(e)}")
            # No point in pausing when there is no further attempt to wait for.
            if attempt + 1 < max_retries:
                time.sleep(1)
        raise RuntimeError("Failed to store VRAM state after multiple attempts") from last_error

    def allocate_block(self, size: int, block_id: Optional[str] = None) -> str:
        """Allocate a block of VRAM and persist the updated state.

        Args:
            size: Block size in bytes.
            block_id: Identifier to use. When ``None``, one is generated from
                ``time.time_ns()``. Passing an id that is already allocated
                replaces that block's bookkeeping; only the size difference
                is charged against the capacity.

        Returns:
            The id of the allocated block.

        Raises:
            ValueError: If ``size`` is negative.
            MemoryError: If the allocation would exceed the total size.
            RuntimeError: If the updated state cannot be stored.
        """
        if size < 0:
            raise ValueError("Block size must be non-negative")

        blocks = self.vram_state["blocks"]
        if block_id is None:
            base_id = f"block_{time.time_ns()}"
            block_id = base_id
            # Two allocations can land on the same clock tick; disambiguate
            # rather than silently overwriting the earlier block.
            suffix = 0
            while block_id in blocks:
                suffix += 1
                block_id = f"{base_id}_{suffix}"

        # Bytes given back when an existing id is being re-allocated.
        previous = blocks.get(block_id)
        reclaimed = previous["size"] if previous is not None else 0

        if self.vram_state["allocated"] - reclaimed + size > self.vram_state["total_size"]:
            raise MemoryError("Not enough VRAM available")

        now = time.time_ns()
        blocks[block_id] = {
            "size": size,
            "allocated_at": now,
            "last_accessed": now,
        }
        self.vram_state["allocated"] += size - reclaimed

        # Store updated state
        self.store_vram_state()
        return block_id

    def free_block(self, block_id: str):
        """Free a block of VRAM; unknown ids are ignored.

        The block's size is subtracted from the allocated total, the state is
        persisted, and ``None`` is stored under the block id in the storage
        object to drop its data.
        """
        blocks = self.vram_state["blocks"]
        if block_id not in blocks:
            return
        self.vram_state["allocated"] -= blocks[block_id]["size"]
        del blocks[block_id]
        self.store_vram_state()

        # Remove block data from storage
        self.storage.store_tensor(block_id, None)

    def _touch_block(self, block_id: str) -> None:
        """Validate ``block_id``, refresh its access time and persist the state."""
        blocks = self.vram_state["blocks"]
        if block_id not in blocks:
            raise ValueError(f"Block {block_id} not allocated")
        blocks[block_id]["last_accessed"] = time.time_ns()
        self.store_vram_state()

    def write_block(self, block_id: str, data: np.ndarray):
        """Write data to an allocated VRAM block.

        Returns:
            Whatever the storage object's ``store_tensor`` returns.

        Raises:
            ValueError: If ``block_id`` is not allocated.
        """
        self._touch_block(block_id)
        return self.storage.store_tensor(block_id, data)

    def read_block(self, block_id: str) -> Optional[np.ndarray]:
        """Read data from an allocated VRAM block.

        Returns:
            Whatever the storage object's ``load_tensor`` returns.

        Raises:
            ValueError: If ``block_id`` is not allocated.
        """
        self._touch_block(block_id)
        return self.storage.load_tensor(block_id)

    def map_address(self, virtual_addr: str, block_id: str):
        """Map a virtual address to a block id and persist the state.

        The block id is not checked against the allocated blocks.
        """
        self.vram_state["memory_map"][virtual_addr] = block_id
        self.store_vram_state()

    def get_block_from_address(self, virtual_addr: str) -> Optional[str]:
        """Return the block id mapped to ``virtual_addr``, or ``None``."""
        return self.vram_state["memory_map"].get(virtual_addr)

    def get_stats(self) -> Dict[str, Any]:
        """Return VRAM statistics.

        ``total_gb``, ``used_gb`` and ``free_gb`` are in units of 1024**3
        bytes; ``total_gb`` and ``free_gb`` are ``float('inf')`` when the
        capacity is unlimited. ``num_blocks`` and ``mappings`` are counts.
        """
        total = self.vram_state["total_size"]
        used = self.vram_state["allocated"]
        return {
            # Derived from the tracked byte total so it always agrees with
            # used_gb / free_gb.
            "total_gb": total / _BYTES_PER_GB,
            "used_gb": used / _BYTES_PER_GB,
            "free_gb": (total - used) / _BYTES_PER_GB,
            "num_blocks": len(self.vram_state["blocks"]),
            "mappings": len(self.vram_state["memory_map"]),
        }