Spaces:
Runtime error
Runtime error
File size: 5,926 Bytes
520d6cf e41bfbc 520d6cf e41bfbc 520d6cf e41bfbc 520d6cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
from websocket_storage import WebSocketGPUStorage
import numpy as np
from typing import Dict, Any, Optional
import time
class VirtualVRAM:
    """Bookkeeping for a virtual VRAM pool whose data lives in a storage object.

    The instance tracks allocations in ``self.vram_state`` (a plain dict with
    the keys ``total_size``, ``allocated``, ``blocks``, ``memory_map`` and
    ``is_unlimited``) and pushes that dict to the storage object after every
    mutation via :meth:`store_vram_state`.

    The storage object must provide ``wait_for_connection``, ``store_state``,
    ``store_tensor`` and ``load_tensor``; those are the only methods this
    class calls on it.
    """

    def __init__(self, size_gb: Optional[int] = None, storage=None):
        """Initialize virtual VRAM, optionally capped at ``size_gb`` GiB.

        Args:
            size_gb: Capacity in GiB (1024**3 bytes each). ``None`` means the
                pool is unlimited and ``total_size`` is ``float('inf')``.
            storage: Storage object to use. When ``None`` a
                ``WebSocketGPUStorage`` is created and must connect.

        Raises:
            RuntimeError: If the internally created storage cannot connect,
                or if the initial state cannot be stored.
        """
        # Kept as an attribute so the constructor argument stays readable
        # from outside; the byte-level figure lives in ``vram_state``.
        self.size_gb = size_gb
        self.storage = storage
        if self.storage is None:
            from websocket_storage import WebSocketGPUStorage
            self.storage = WebSocketGPUStorage()
            if not self.storage.wait_for_connection():
                raise RuntimeError("Could not connect to GPU storage server")

        unlimited = size_gb is None
        self.vram_state = {
            # Convert GB to bytes, or use infinity for an unlimited pool.
            "total_size": float('inf') if unlimited else size_gb * 1024 * 1024 * 1024,
            "allocated": 0,
            "blocks": {},
            "memory_map": {},
            "is_unlimited": unlimited,
        }
        self.store_vram_state()

    @property
    def total_size(self) -> float:
        """Total VRAM size in bytes (``float('inf')`` when unlimited)."""
        return self.vram_state["total_size"]

    @property
    def available_size(self) -> float:
        """Unallocated VRAM in bytes (``float('inf')`` when unlimited)."""
        if self.vram_state["is_unlimited"]:
            return float('inf')
        return self.vram_state["total_size"] - self.vram_state["allocated"]

    def _serializable_state(self) -> Dict[str, Any]:
        """Return a copy of the top-level state with infinity made JSON-safe."""
        total = self.vram_state["total_size"]
        if isinstance(total, float) and total == float('inf'):
            # Infinity is not valid JSON, so it is sent as the string 'inf'.
            total = str(total)
        return {
            "total_size": total,
            "allocated": self.vram_state["allocated"],
            "blocks": self.vram_state["blocks"],
            "memory_map": self.vram_state["memory_map"],
            "is_unlimited": self.vram_state["is_unlimited"],
        }

    def store_vram_state(self, max_retries=3):
        """Store VRAM state in the storage object, retrying on failure.

        Each attempt waits up to 5 seconds for a connection and then calls
        ``store_state("vram", "state", ...)``. Failed attempts are separated
        by a one-second pause; no pause follows the final attempt.

        Args:
            max_retries: Maximum number of attempts.

        Returns:
            ``True`` once the storage reports success.

        Raises:
            RuntimeError: If no attempt succeeds. The last exception raised
                by the storage, if any, is attached as the cause.
        """
        last_error = None
        for attempt in range(max_retries):
            try:
                if not self.storage.wait_for_connection(timeout=5):
                    print(f"Waiting for WebSocket connection (attempt {attempt + 1}/{max_retries})")
                elif self.storage.store_state("vram", "state", self._serializable_state()):
                    return True
                else:
                    print(f"Failed to store VRAM state (attempt {attempt + 1}/{max_retries})")
            except Exception as e:
                # Best-effort retry: remember the error so the final failure
                # can report what actually went wrong.
                last_error = e
                print(f"Error storing VRAM state (attempt {attempt + 1}/{max_retries}): {str(e)}")
            if attempt + 1 < max_retries:
                time.sleep(1)
        raise RuntimeError("Failed to store VRAM state after multiple attempts") from last_error

    def allocate_block(self, size: int, block_id: Optional[str] = None) -> str:
        """Allocate a block of VRAM and return its id.

        Args:
            size: Block size in bytes; must not be negative.
            block_id: Id to use. When ``None`` an id of the form
                ``block_<time_ns>`` is generated. Passing an id that is
                already allocated replaces that block, and only the size
                difference is charged against the pool.

        Raises:
            ValueError: If ``size`` is negative.
            MemoryError: If the pool cannot hold the block.
            RuntimeError: If the updated state cannot be stored.
        """
        if size < 0:
            raise ValueError(f"Block size must be non-negative, got {size}")

        blocks = self.vram_state["blocks"]
        if block_id is None:
            base_id = f"block_{time.time_ns()}"
            block_id = base_id
            suffix = 0
            # A coarse clock can return the same timestamp twice; never hand
            # out an id that would silently replace a live block.
            while block_id in blocks:
                suffix += 1
                block_id = f"{base_id}_{suffix}"

        # Replacing an existing block releases its old size, so that size
        # must not be counted against the pool a second time.
        previous = blocks.get(block_id)
        reclaimed = previous["size"] if previous is not None else 0
        if self.vram_state["allocated"] - reclaimed + size > self.vram_state["total_size"]:
            raise MemoryError("Not enough VRAM available")

        now = time.time_ns()
        blocks[block_id] = {
            "size": size,
            "allocated_at": now,
            "last_accessed": now,
        }
        self.vram_state["allocated"] += size - reclaimed
        # Store updated state
        self.store_vram_state()
        return block_id

    def free_block(self, block_id: str):
        """Free a block of VRAM; unknown ids are ignored.

        NOTE(review): entries in ``memory_map`` that point at the freed block
        are left in place.
        """
        blocks = self.vram_state["blocks"]
        if block_id not in blocks:
            return
        self.vram_state["allocated"] -= blocks[block_id]["size"]
        del blocks[block_id]
        self.store_vram_state()
        # Remove block data from storage (presumably the storage treats a
        # ``None`` tensor as a deletion -- confirm against its implementation).
        self.storage.store_tensor(block_id, None)

    def _touch_block(self, block_id: str):
        """Validate ``block_id``, refresh its access time and store the state."""
        blocks = self.vram_state["blocks"]
        if block_id not in blocks:
            raise ValueError(f"Block {block_id} not allocated")
        blocks[block_id]["last_accessed"] = time.time_ns()
        self.store_vram_state()

    def write_block(self, block_id: str, data: np.ndarray):
        """Write data to a VRAM block and return the storage's result.

        Raises:
            ValueError: If ``block_id`` is not allocated.
        """
        self._touch_block(block_id)
        return self.storage.store_tensor(block_id, data)

    def read_block(self, block_id: str) -> Optional[np.ndarray]:
        """Read data from a VRAM block via the storage's ``load_tensor``.

        Raises:
            ValueError: If ``block_id`` is not allocated.
        """
        self._touch_block(block_id)
        return self.storage.load_tensor(block_id)

    def map_address(self, virtual_addr: str, block_id: str):
        """Map a virtual address to a block id and store the state.

        The block id is recorded as given; it is not checked against the
        allocated blocks.
        """
        self.vram_state["memory_map"][virtual_addr] = block_id
        self.store_vram_state()

    def get_block_from_address(self, virtual_addr: str) -> Optional[str]:
        """Return the block id mapped to ``virtual_addr``, or ``None``."""
        return self.vram_state["memory_map"].get(virtual_addr)

    def get_stats(self) -> Dict[str, Any]:
        """Return usage statistics, with sizes in GiB (1024**3 bytes).

        ``total_gb`` and ``free_gb`` are ``float('inf')`` for an unlimited
        pool. All figures are derived from ``vram_state``.
        """
        bytes_per_gb = 1024 * 1024 * 1024
        total = self.vram_state["total_size"]
        used = self.vram_state["allocated"]
        return {
            "total_gb": total / bytes_per_gb,
            "used_gb": used / bytes_per_gb,
            "free_gb": (total - used) / bytes_per_gb,
            "num_blocks": len(self.vram_state["blocks"]),
            "mappings": len(self.vram_state["memory_map"]),
        }
|