File size: 5,926 Bytes
520d6cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e41bfbc
520d6cf
 
 
e41bfbc
520d6cf
 
 
e41bfbc
 
 
 
 
 
 
 
 
 
520d6cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from websocket_storage import WebSocketGPUStorage
import numpy as np
from typing import Dict, Any, Optional
import time

class VirtualVRAM:
    """Book-keeping layer for a virtual VRAM pool backed by a storage object.

    The class tracks allocated blocks and a virtual-address map in
    ``self.vram_state`` and persists that state through the ``storage``
    object after every mutation. Block payloads are written and read with
    ``storage.store_tensor`` / ``storage.load_tensor``.

    The storage object is expected to provide ``wait_for_connection``,
    ``store_state``, ``store_tensor`` and ``load_tensor``; these are the
    only methods this class calls on it.
    """

    # Number of bytes in one GB as used by this class (binary gigabyte).
    _BYTES_PER_GB = 1024 * 1024 * 1024

    def __init__(self, size_gb: Optional[int] = None, storage: Optional[Any] = None):
        """Initialize virtual VRAM, unlimited unless ``size_gb`` is given.

        Args:
            size_gb: Capacity in GB. ``None`` means unlimited capacity.
            storage: Storage backend to use. When ``None`` a
                ``WebSocketGPUStorage`` is created and connected.

        Raises:
            ValueError: If ``size_gb`` is negative.
            RuntimeError: If the default storage cannot connect, or the
                initial state cannot be persisted.
        """
        if size_gb is not None and size_gb < 0:
            raise ValueError("size_gb must be non-negative")

        self.storage = storage
        if self.storage is None:
            # Imported lazily so that callers supplying their own storage do
            # not need the websocket backend to be importable.
            from websocket_storage import WebSocketGPUStorage
            self.storage = WebSocketGPUStorage()
            if not self.storage.wait_for_connection():
                raise RuntimeError("Could not connect to GPU storage server")

        # Kept as an attribute because get_stats() historically referred to it.
        self.size_gb = size_gb

        is_unlimited = size_gb is None
        # int() keeps the byte count integral even for fractional size_gb.
        total_size = float('inf') if is_unlimited else int(size_gb * self._BYTES_PER_GB)

        self.vram_state = {
            "total_size": total_size,
            "allocated": 0,
            "blocks": {},
            "memory_map": {},
            "is_unlimited": is_unlimited
        }
        self.store_vram_state()

    @property
    def total_size(self) -> float:
        """Get total VRAM size in bytes (``float('inf')`` when unlimited)."""
        return self.vram_state["total_size"]

    @property
    def available_size(self) -> float:
        """Get available VRAM size in bytes (``float('inf')`` when unlimited)."""
        if self.vram_state["is_unlimited"]:
            return float('inf')
        return self.vram_state["total_size"] - self.vram_state["allocated"]

    def store_vram_state(self, max_retries=3):
        """Store VRAM state in the storage backend with retry logic.

        Args:
            max_retries: Maximum number of attempts before giving up.

        Returns:
            ``True`` once the backend reports the state as stored.

        Raises:
            RuntimeError: If no attempt succeeded. The last exception raised
                by the backend, if any, is attached as the cause.
        """
        last_error = None
        for attempt in range(max_retries):
            try:
                if not self.storage.wait_for_connection(timeout=5):
                    print(f"Waiting for WebSocket connection (attempt {attempt + 1}/{max_retries})")
                else:
                    total_size = self.vram_state["total_size"]
                    # Infinity is not valid JSON, so it is sent as the
                    # string "inf" instead.
                    if isinstance(total_size, float) and total_size == float('inf'):
                        total_size = str(total_size)
                    safe_state = {
                        "total_size": total_size,
                        "allocated": self.vram_state["allocated"],
                        "blocks": self.vram_state["blocks"],
                        "memory_map": self.vram_state["memory_map"],
                        "is_unlimited": self.vram_state["is_unlimited"]
                    }

                    if self.storage.store_state("vram", "state", safe_state):
                        return True

                    print(f"Failed to store VRAM state (attempt {attempt + 1}/{max_retries})")

            except Exception as e:
                # The backend's exception types are not visible from here, so
                # any failure is treated as retryable; the last one is kept
                # and chained into the final RuntimeError.
                last_error = e
                print(f"Error storing VRAM state (attempt {attempt + 1}/{max_retries}): {str(e)}")

            # Back off between attempts, but do not delay the final failure.
            if attempt < max_retries - 1:
                time.sleep(1)

        raise RuntimeError("Failed to store VRAM state after multiple attempts") from last_error

    def allocate_block(self, size: int, block_id: Optional[str] = None) -> str:
        """Allocate a block of VRAM.

        Passing the id of an existing block re-allocates it: the old size is
        released from the accounting before the new size is charged, so the
        ``allocated`` counter stays equal to the sum of the block sizes.

        Args:
            size: Block size in bytes; must be non-negative.
            block_id: Identifier to use. When ``None`` a unique
                ``block_<nanosecond timestamp>`` id is generated.

        Returns:
            The id of the allocated block.

        Raises:
            ValueError: If ``size`` is negative.
            MemoryError: If the allocation would exceed the total size.
            RuntimeError: If the updated state cannot be persisted.
        """
        if size < 0:
            raise ValueError("Block size must be non-negative")

        blocks = self.vram_state["blocks"]

        if block_id is None:
            # The clock may not advance between two calls, so bump the stamp
            # until the id is unused instead of reusing an existing block.
            stamp = time.time_ns()
            block_id = f"block_{stamp}"
            while block_id in blocks:
                stamp += 1
                block_id = f"block_{stamp}"

        existing = blocks.get(block_id)
        released = existing["size"] if existing is not None else 0

        if self.vram_state["allocated"] - released + size > self.vram_state["total_size"]:
            raise MemoryError("Not enough VRAM available")

        now = time.time_ns()
        blocks[block_id] = {
            "size": size,
            "allocated_at": now,
            "last_accessed": now
        }
        self.vram_state["allocated"] += size - released

        # Store updated state
        self.store_vram_state()
        return block_id

    def free_block(self, block_id: str):
        """Free a block of VRAM; unknown ids are ignored.

        Entries in the memory map that point at the block are left
        untouched.

        Args:
            block_id: Identifier of the block to release.
        """
        block = self.vram_state["blocks"].pop(block_id, None)
        if block is None:
            return

        self.vram_state["allocated"] -= block["size"]
        self.store_vram_state()

        # Remove block data from storage
        self.storage.store_tensor(block_id, None)

    def _touch_block(self, block_id: str):
        """Refresh a block's last-access time and persist the state."""
        blocks = self.vram_state["blocks"]
        if block_id not in blocks:
            raise ValueError(f"Block {block_id} not allocated")

        blocks[block_id]["last_accessed"] = time.time_ns()
        self.store_vram_state()

    def write_block(self, block_id: str, data: np.ndarray):
        """Write data to a VRAM block.

        Args:
            block_id: Identifier of an allocated block.
            data: Payload handed to ``storage.store_tensor``.

        Returns:
            Whatever ``storage.store_tensor`` returns.

        Raises:
            ValueError: If the block is not allocated.
        """
        self._touch_block(block_id)
        return self.storage.store_tensor(block_id, data)

    def read_block(self, block_id: str) -> Optional[np.ndarray]:
        """Read data from a VRAM block.

        Args:
            block_id: Identifier of an allocated block.

        Returns:
            Whatever ``storage.load_tensor`` returns for the block.

        Raises:
            ValueError: If the block is not allocated.
        """
        self._touch_block(block_id)
        return self.storage.load_tensor(block_id)

    def map_address(self, virtual_addr: str, block_id: str):
        """Map virtual address to VRAM block and persist the mapping."""
        self.vram_state["memory_map"][virtual_addr] = block_id
        self.store_vram_state()

    def get_block_from_address(self, virtual_addr: str) -> Optional[str]:
        """Get block ID from virtual address, or ``None`` when unmapped."""
        return self.vram_state["memory_map"].get(virtual_addr)

    def get_stats(self) -> Dict[str, Any]:
        """Get VRAM statistics.

        Returns:
            A dict with ``total_gb``, ``used_gb`` and ``free_gb`` (floats;
            ``total_gb`` and ``free_gb`` are ``float('inf')`` when the pool
            is unlimited), plus ``num_blocks`` and ``mappings`` counts.
        """
        state = self.vram_state
        gb = self._BYTES_PER_GB
        return {
            # Derived from the tracked byte count rather than an instance
            # attribute, so it always agrees with the other figures.
            "total_gb": state["total_size"] / gb,
            "used_gb": state["allocated"] / gb,
            "free_gb": self.available_size / gb,
            "num_blocks": len(state["blocks"]),
            "mappings": len(state["memory_map"])
        }