File size: 5,442 Bytes
e9bc512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from websocket_storage import WebSocketGPUStorage
import numpy as np
from typing import Dict, Any, Optional
import time

class VirtualVRAM:
    """Bookkeeping layer for a virtual VRAM backed by a remote storage object.

    The instance keeps an in-memory ``vram_state`` dictionary (capacity,
    bytes allocated, per-block metadata and a virtual-address map) and
    persists it through ``storage.store_state`` after every mutation. Block
    payloads are read and written through ``storage.load_tensor`` /
    ``storage.store_tensor``.

    Capacity is always unlimited (``total_size`` is ``float('inf')``);
    ``size_gb`` is only recorded as a nominal size for ``get_stats``.
    """

    # Number of bytes per gigabyte used when reporting statistics.
    _BYTES_PER_GB = 1024 * 1024 * 1024

    def __init__(self, size_gb: Optional[int] = None, storage=None):
        """Initialize virtual VRAM with unlimited storage capability.

        Args:
            size_gb: Nominal size in GB reported by ``get_stats``. It does
                not limit allocations.
            storage: Object providing ``wait_for_connection``,
                ``store_state``, ``store_tensor`` and ``load_tensor``. When
                omitted, a ``WebSocketGPUStorage`` is created and connected.

        Raises:
            RuntimeError: If a default storage cannot connect, or the initial
                state cannot be persisted.
        """
        # Recorded so get_stats() can report it; previously the argument was
        # dropped and get_stats() failed with AttributeError.
        self.size_gb = size_gb

        self.storage = storage
        if self.storage is None:
            # Imported lazily so the class can be used with an injected
            # storage even when the websocket backend is not importable.
            from websocket_storage import WebSocketGPUStorage
            self.storage = WebSocketGPUStorage()
            if not self.storage.wait_for_connection():
                raise RuntimeError("Could not connect to GPU storage server")

        # Initialize VRAM state with unlimited capacity
        self.vram_state = {
            "total_size": float('inf'),  # Unlimited size
            "allocated": 0,
            "blocks": {},
            "memory_map": {},
            "is_unlimited": True
        }
        self.store_vram_state()

    def _serializable_state(self) -> Dict[str, Any]:
        """Return a copy of the top-level state that is safe to send as JSON."""
        total_size = self.vram_state["total_size"]
        # Infinity is not valid JSON, so it is transmitted as the string "inf".
        if isinstance(total_size, float) and total_size == float('inf'):
            total_size = str(total_size)
        return {
            "total_size": total_size,
            "allocated": self.vram_state["allocated"],
            "blocks": self.vram_state["blocks"],
            "memory_map": self.vram_state["memory_map"],
            "is_unlimited": self.vram_state["is_unlimited"]
        }

    def store_vram_state(self, max_retries=3):
        """Store VRAM state in the storage backend with retry logic.

        Args:
            max_retries: Maximum number of attempts before giving up.

        Returns:
            True once the backend reports the state was stored.

        Raises:
            RuntimeError: If no attempt succeeds. The last exception raised
                by the backend, if any, is chained as the cause.
        """
        last_error = None
        for attempt in range(max_retries):
            try:
                # Wait for connection if needed
                if not self.storage.wait_for_connection(timeout=5):
                    print(f"Waiting for WebSocket connection (attempt {attempt + 1}/{max_retries})")
                elif self.storage.store_state("vram", "state", self._serializable_state()):
                    return True
                else:
                    print(f"Failed to store VRAM state (attempt {attempt + 1}/{max_retries})")
            except Exception as e:
                # Deliberately broad: any backend failure is retried, and the
                # last one is surfaced through exception chaining below.
                last_error = e
                print(f"Error storing VRAM state (attempt {attempt + 1}/{max_retries}): {str(e)}")

            # Back off before the next attempt, but do not delay the final
            # failure with a pointless sleep.
            if attempt + 1 < max_retries:
                time.sleep(1)

        raise RuntimeError("Failed to store VRAM state after multiple attempts") from last_error

    def allocate_block(self, size: int, block_id: Optional[str] = None) -> str:
        """Allocate a block of VRAM and persist the updated state.

        Args:
            size: Size of the block; added to the ``allocated`` counter.
            block_id: Identifier to use. When omitted, one is generated from
                the current time in nanoseconds. Reusing an existing ID
                replaces that block.

        Returns:
            The ID of the allocated block.

        Raises:
            ValueError: If ``size`` is negative.
            MemoryError: If the allocation would exceed ``total_size``.
            RuntimeError: If the state cannot be persisted.
        """
        if size < 0:
            raise ValueError(f"Block size must be non-negative, got {size}")

        if block_id is None:
            block_id = f"block_{time.time_ns()}"

        # When an existing ID is reused, its old size is released first so
        # the allocation counter does not count the block twice.
        previous = self.vram_state["blocks"].get(block_id)
        reclaimed = previous["size"] if previous is not None else 0
        new_allocated = self.vram_state["allocated"] - reclaimed + size

        if new_allocated > self.vram_state["total_size"]:
            raise MemoryError("Not enough VRAM available")

        now = time.time_ns()
        self.vram_state["blocks"][block_id] = {
            "size": size,
            "allocated_at": now,
            "last_accessed": now
        }
        self.vram_state["allocated"] = new_allocated

        # Store updated state
        self.store_vram_state()
        return block_id

    def free_block(self, block_id: str):
        """Free a block of VRAM; unknown block IDs are ignored.

        Virtual addresses mapped to the block are unmapped as well, so
        ``get_block_from_address`` never returns the ID of a freed block.

        Raises:
            RuntimeError: If the state cannot be persisted.
        """
        block = self.vram_state["blocks"].pop(block_id, None)
        if block is None:
            return

        self.vram_state["allocated"] -= block["size"]

        # Drop mappings that would otherwise dangle after the block is gone.
        memory_map = self.vram_state["memory_map"]
        stale_addrs = [addr for addr, mapped in memory_map.items() if mapped == block_id]
        for addr in stale_addrs:
            del memory_map[addr]

        self.store_vram_state()

        # Remove block data from storage
        # NOTE(review): presumably the backend treats a None tensor as a
        # deletion -- confirm against the storage implementation.
        self.storage.store_tensor(block_id, None)

    def _touch_block(self, block_id: str):
        """Refresh a block's ``last_accessed`` timestamp and persist the state.

        Raises:
            ValueError: If the block is not allocated.
        """
        if block_id not in self.vram_state["blocks"]:
            raise ValueError(f"Block {block_id} not allocated")

        self.vram_state["blocks"][block_id]["last_accessed"] = time.time_ns()
        self.store_vram_state()

    def write_block(self, block_id: str, data: np.ndarray):
        """Write data to a VRAM block.

        Returns:
            Whatever ``storage.store_tensor`` returns.

        Raises:
            ValueError: If the block is not allocated.
            RuntimeError: If the state cannot be persisted.
        """
        self._touch_block(block_id)
        return self.storage.store_tensor(block_id, data)

    def read_block(self, block_id: str) -> Optional[np.ndarray]:
        """Read data from a VRAM block.

        Returns:
            Whatever ``storage.load_tensor`` returns for the block.

        Raises:
            ValueError: If the block is not allocated.
            RuntimeError: If the state cannot be persisted.
        """
        self._touch_block(block_id)
        return self.storage.load_tensor(block_id)

    def map_address(self, virtual_addr: str, block_id: str):
        """Map a virtual address to a VRAM block and persist the state.

        The block ID is not checked against the allocated blocks.
        """
        self.vram_state["memory_map"][virtual_addr] = block_id
        self.store_vram_state()

    def get_block_from_address(self, virtual_addr: str) -> Optional[str]:
        """Get the block ID mapped to a virtual address, or None if unmapped."""
        return self.vram_state["memory_map"].get(virtual_addr)

    def get_stats(self) -> Dict[str, Any]:
        """Get VRAM statistics.

        Returns:
            A dictionary with ``total_gb`` (the nominal ``size_gb`` if one
            was given, otherwise ``total_size`` converted to GB),
            ``used_gb``, ``free_gb``, ``num_blocks`` and ``mappings``.
        """
        total_size = self.vram_state["total_size"]
        allocated = self.vram_state["allocated"]
        bytes_per_gb = self._BYTES_PER_GB

        total_gb = self.size_gb
        if total_gb is None:
            total_gb = total_size / bytes_per_gb

        return {
            "total_gb": total_gb,
            "used_gb": allocated / bytes_per_gb,
            "free_gb": (total_size - allocated) / bytes_per_gb,
            "num_blocks": len(self.vram_state["blocks"]),
            "mappings": len(self.vram_state["memory_map"])
        }