File size: 15,020 Bytes
5cdc76b
 
 
 
 
e64ebad
5cdc76b
319082b
5cdc76b
 
e64ebad
 
 
 
 
5cdc76b
 
 
 
 
e64ebad
962c8c7
 
e64ebad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
962c8c7
e64ebad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
962c8c7
e64ebad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
962c8c7
e64ebad
 
 
 
962c8c7
 
e64ebad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
962c8c7
5cdc76b
 
 
e64ebad
5cdc76b
962c8c7
 
 
5cdc76b
 
e64ebad
 
5cdc76b
e64ebad
 
 
 
 
 
 
 
 
5cdc76b
e64ebad
 
 
 
 
 
 
5cdc76b
962c8c7
e64ebad
962c8c7
e64ebad
 
962c8c7
 
 
e64ebad
 
 
 
962c8c7
 
e64ebad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5cdc76b
 
 
 
 
 
e64ebad
5cdc76b
e64ebad
5cdc76b
e64ebad
 
962c8c7
e64ebad
962c8c7
5cdc76b
e64ebad
 
 
0e61fb3
319082b
e64ebad
 
319082b
e64ebad
 
319082b
e64ebad
962c8c7
319082b
e64ebad
 
 
5cdc76b
e64ebad
 
962c8c7
e64ebad
 
5cdc76b
e64ebad
 
 
5cdc76b
e64ebad
 
 
 
 
 
 
5cdc76b
e64ebad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
962c8c7
e64ebad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
962c8c7
e64ebad
962c8c7
 
5cdc76b
e64ebad
 
 
 
 
 
 
0e61fb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
962c8c7
 
 
e64ebad
 
 
962c8c7
e64ebad
 
 
 
 
 
962c8c7
 
 
 
e64ebad
 
 
 
 
 
 
 
 
 
0e61fb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e64ebad
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
import torch
from torch.library import Library, impl
from typing import Optional, Union, Tuple
import numpy as np
from virtual_vram import VirtualVRAM
import warnings

# Process-wide flag: set to True by init_vgpu_backend() after its first run so
# subsequent calls skip backend/library registration.
VGPU_BACKEND_INITIALIZED: bool = False

def get_pytorch_version(version_string: Optional[str] = None) -> Tuple[int, int]:
    """Return the ``(major, minor)`` PyTorch version as a tuple of ints.

    Args:
        version_string: Version text to parse. Defaults to ``torch.__version__``;
            accepting it explicitly lets the parser be exercised independently
            of the installed PyTorch build.

    Returns:
        A ``(major, minor)`` tuple suitable for comparisons such as
        ``get_pytorch_version() >= (2, 0)``.

    Raises:
        ValueError: If the text does not start with ``<major>.<minor>``.
    """
    import re

    text = str(torch.__version__ if version_string is None else version_string)
    # Only the leading numeric components matter; this tolerates local and
    # pre-release suffixes such as "2.1.0+cu118", "2.1a0" or "2.4+cpu".
    match = re.match(r"\s*(\d+)\.(\d+)", text)
    if match is None:
        raise ValueError(f"Unrecognized PyTorch version string: {text!r}")
    return int(match.group(1)), int(match.group(2))

# torch.library.Library objects undo their operator definitions and kernel
# registrations when they are garbage-collected, so the libraries created by
# init_vgpu_backend() are retained here for the lifetime of the process.
_VGPU_LIBRARIES = []


def _rename_privateuse1_backend(backend_name):
    """Rename PyTorch's PrivateUse1 dispatch key to *backend_name*.

    Returns:
        True if the rename (and, when available, method generation) succeeded,
        False otherwise. Failures are reported through ``warnings``.
    """
    # The public helper lives in torch.utils; the private binding is a fallback.
    rename = getattr(torch.utils, "rename_privateuse1_backend", None)
    if rename is None:
        rename = getattr(torch._C, "_rename_privateuse1_backend", None)
    if rename is None:
        warnings.warn("Modern backend registration failed: Modern API not available", stacklevel=3)
        return False

    try:
        rename(backend_name)
        # Not present in every PyTorch release, so it is looked up dynamically.
        generate = getattr(torch.utils, "generate_methods_for_privateuse1_backend", None)
        if generate is not None:
            generate(
                for_tensor=True,
                for_module=True,
                for_packed_sequence=True,
                for_storage=True
            )
    except (AttributeError, RuntimeError) as e:
        warnings.warn(f"Modern backend registration failed: {e}", stacklevel=3)
        return False
    return True


def _register_vgpu_library(backend_name):
    """Define the vGPU operator schemas and their PrivateUse1 kernels.

    The schemas are defined in the *backend_name* namespace handed to
    ``Library`` (not in ``aten``). Every kernel computes on CPU tensors.
    Raises whatever ``Library``/``define``/``impl`` raise; on failure the
    libraries are not retained.
    """
    lib = Library(backend_name, "DEF")
    impl_lib = Library(backend_name, "IMPL", "PrivateUse1")

    # Define essential operations
    lib.define("empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor")
    lib.define("copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)")
    lib.define("add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor")
    lib.define("mm(Tensor self, Tensor mat2) -> Tensor")

    @impl(impl_lib, "empty.memory_format")
    def empty_memory_format(size, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None):
        dtype = dtype or torch.float32
        # Create on CPU; only the reported device is virtual.
        return torch.empty(size, dtype=dtype, device='cpu')

    @impl(impl_lib, "copy_")
    def copy_impl(self, src, non_blocking=False):
        if isinstance(src, torch.Tensor):
            self.data.copy_(src.cpu().data if hasattr(src, 'cpu') else src.data)
        return self

    @impl(impl_lib, "add.Tensor")
    def add_tensor(self, other, alpha=1):
        # Perform add on CPU then return result
        self_cpu = self.cpu() if hasattr(self, 'cpu') else self
        other_cpu = other.cpu() if hasattr(other, 'cpu') else other
        return torch.add(self_cpu, other_cpu, alpha=alpha)

    @impl(impl_lib, "mm")
    def mm_impl(self, mat2):
        # Perform matmul on CPU
        self_cpu = self.cpu() if hasattr(self, 'cpu') else self
        mat2_cpu = mat2.cpu() if hasattr(mat2, 'cpu') else mat2
        return torch.mm(self_cpu, mat2_cpu)

    # Retain only after everything registered successfully; if anything above
    # raised, the half-configured libraries are simply dropped.
    _VGPU_LIBRARIES.extend((lib, impl_lib))


def init_vgpu_backend():
    """Initialize the vGPU backend. Must be called before creating any VGPUDevice instances.

    Attempts to rename the PrivateUse1 dispatch key to ``"vgpu"`` and to
    register a small operator library for it. Either step may fail on a given
    PyTorch build; such failures are reported via ``warnings`` and the backend
    is still marked initialized (fallback mode).

    Returns:
        True once initialization has run (now or on an earlier call), False if
        an unexpected error aborted it.
    """
    global VGPU_BACKEND_INITIALIZED
    try:
        if VGPU_BACKEND_INITIALIZED:
            return True

        backend_name = "vgpu"
        if not _rename_privateuse1_backend(backend_name):
            warnings.warn(f"Using fallback registration method for PyTorch {torch.__version__}", stacklevel=2)

        try:
            _register_vgpu_library(backend_name)
        except Exception as e:
            # Best effort: the shim still works without the operator library.
            warnings.warn(f"Library registration warning: {e}", stacklevel=2)

        VGPU_BACKEND_INITIALIZED = True
        return True

    except Exception as e:
        warnings.warn(f"Backend initialization error: {e}", stacklevel=2)
        import traceback
        traceback.print_exc()
        return False

class VGPUDeviceMock:
    """Stand-in object exposing the ``type``/``index`` surface of a ``torch.device``.

    Equality and hashing are defined on the ``"<type>:<index>"`` text form, so
    a mock equals any object (device or plain string) that renders the same.
    """

    def __init__(self, device_name="vgpu", index=0):
        self.type = device_name  # mirrors torch.device.type
        self.index = index  # mirrors torch.device.index

    def __str__(self):
        return f"{self.type}:{self.index}"

    def __repr__(self):
        return f"device(type='{self.type}', index={self.index})"

    def __eq__(self, other):
        # No type restriction: every operand is compared through str().
        own_text = str(self)
        return own_text == str(other)

    def __hash__(self):
        # Consistent with __eq__: equal text implies equal hash.
        own_text = str(self)
        return hash(own_text)

class VGPUTensor(torch.Tensor):
    """``torch.Tensor`` subclass whose storage lives on the CPU but reports a vGPU device.

    Instance attributes set by ``__new__``:
        _vgpu_device: object returned by :attr:`device` (None -> default mock).
        _vram: backing VirtualVRAM handle, or None.

    Results of ordinary torch operations on a VGPUTensor are re-wrapped as
    VGPUTensor by PyTorch's default ``__torch_function__`` without these
    attributes, so the accessors below tolerate their absence.
    """

    @staticmethod
    def __new__(cls, data, device=None, requires_grad=False, vram=None):
        if not isinstance(data, torch.Tensor):
            data = torch.as_tensor(data)

        # Create tensor on CPU but track vGPU device
        r = torch.Tensor._make_subclass(cls, data.cpu(), requires_grad)
        r._vgpu_device = device
        r._vram = vram
        return r

    @property
    def device(self):
        """Return the vGPU device (a default ``VGPUDeviceMock`` when none is recorded)."""
        return getattr(self, "_vgpu_device", None) or VGPUDeviceMock()

    def cpu(self):
        """Move tensor to CPU.

        Returns a plain ``torch.Tensor`` sharing this tensor's storage. The
        result is a new autograd leaf with the same ``requires_grad`` flag (it
        is not connected to this tensor's graph). Works for every dtype; the
        legacy ``torch.Tensor(t)`` constructor only accepts the default float
        dtype.
        """
        plain = self.as_subclass(torch.Tensor).detach()
        plain.requires_grad_(self.requires_grad)
        return plain

    @staticmethod
    def _targets_vgpu(target):
        """Return True if *target* names a vGPU device (string or device-like object)."""
        if isinstance(target, str):
            return "vgpu" in target
        device_type = getattr(target, "type", None)
        return isinstance(device_type, str) and "vgpu" in device_type

    def to(self, *args, **kwargs):
        """Handle device/dtype transfers.

        A vGPU target (``"vgpu..."`` string or device-like object whose ``type``
        contains ``"vgpu"``), given positionally or as ``device=``, returns
        ``self`` unchanged. Everything else (another device, a dtype, ...) is
        forwarded to ``torch.Tensor.to`` and yields a plain tensor.
        """
        target = kwargs["device"] if "device" in kwargs else (args[0] if args else None)
        if self._targets_vgpu(target):
            # Stay on vGPU
            return self
        # Convert to a plain Tensor first: ``self.data`` is re-wrapped as a
        # VGPUTensor, so calling .to() on it would re-enter this override.
        return self.as_subclass(torch.Tensor).to(*args, **kwargs)

class VGPUDevice:
    """
    Factory for :class:`VGPUTensor` objects bound to one vGPU device.

    Tensors are materialised on the CPU, wrapped as ``VGPUTensor`` reporting
    this device, and mirrored into the device's ``VirtualVRAM`` store.

    Usage:
        vgpu = VGPUDevice()
        tensor = vgpu.tensor([1, 2, 3])  # Create tensor on vGPU
    """
    # Devices created so far, keyed by "<name>:<index>"; creating another
    # device with the same index replaces the earlier entry.
    _VGPU_INSTANCES = {}

    def __init__(self, vram: Optional[VirtualVRAM] = None, device_index: int = 0):
        # Initialize backend
        if not init_vgpu_backend():
            print("Warning: Backend initialization incomplete, using fallback mode")

        # Explicit None check so a caller-supplied store is never replaced
        # merely because it happens to be falsy.
        self.vram = vram if vram is not None else VirtualVRAM()
        self.tensor_cores = None  # placeholder; not used within this class
        self.device_name = "vgpu"
        self.device_index = device_index
        self._device = torch.device(f"{self.device_name}:{device_index}")

        # Store this instance
        VGPUDevice._VGPU_INSTANCES[f"{self.device_name}:{device_index}"] = self

        print(f"✓ vGPU device initialized: {self._device}")

    def device(self):
        """Get the device object"""
        return self._device

    def _wrap(self, cpu_tensor, requires_grad=False):
        """Wrap *cpu_tensor* as a VGPUTensor on this device and mirror it into vRAM."""
        result = VGPUTensor(cpu_tensor, device=self._device, requires_grad=requires_grad, vram=self.vram)
        self._to_vram(result)
        return result

    def _create(self, factory, size, kwargs):
        """Build ``factory(*size, **kwargs)`` on the CPU and wrap it for this device.

        ``device`` is ignored; ``requires_grad`` is applied to the wrapped
        tensor (``VGPUTensor`` construction would otherwise discard it).
        """
        kwargs = dict(kwargs)
        kwargs.pop('device', None)
        requires_grad = kwargs.pop('requires_grad', False)
        return self._wrap(factory(*size, **kwargs), requires_grad)

    def tensor(self, data, **kwargs):
        """Create a tensor on this vGPU device.

        ``data`` may be a tensor or anything ``torch.tensor`` accepts. A
        ``device`` keyword is ignored and ``requires_grad`` is applied to the
        result; remaining keywords (e.g. ``dtype``) go to ``torch.as_tensor``
        for tensor input or ``torch.tensor`` otherwise.
        """
        kwargs.pop('device', None)  # the target is always this vGPU device
        requires_grad = kwargs.pop('requires_grad', False)

        if isinstance(data, torch.Tensor):
            # as_tensor reuses the data and only converts when asked (e.g. dtype).
            cpu_tensor = torch.as_tensor(data, **kwargs)
        else:
            cpu_tensor = torch.tensor(data, **kwargs)
        return self._wrap(cpu_tensor, requires_grad)

    def randn(self, *size, **kwargs):
        """Create random tensor on vGPU"""
        return self._create(torch.randn, size, kwargs)

    def zeros(self, *size, **kwargs):
        """Create zero tensor on vGPU"""
        return self._create(torch.zeros, size, kwargs)

    def ones(self, *size, **kwargs):
        """Create ones tensor on vGPU"""
        return self._create(torch.ones, size, kwargs)

    def empty(self, *size, **kwargs):
        """Create empty tensor on vGPU"""
        return self._create(torch.empty, size, kwargs)

    def _to_vram(self, tensor):
        """Store tensor in vRAM.

        Only tensors carrying a truthy ``_vram`` handle are stored. The data is
        passed as a NumPy array to ``_vram.storage.store_tensor`` under a key
        derived from ``id(tensor)``, which is recorded as ``tensor._vram_id``.
        """
        vram = getattr(tensor, '_vram', None)
        if not vram:
            return
        tensor_id = f"tensor_{id(tensor)}"
        # Go through a plain Tensor so this works for every dtype regardless
        # of the subclass's own .cpu() override.
        data = tensor.detach().as_subclass(torch.Tensor).cpu().numpy()
        vram.storage.store_tensor(tensor_id, data)
        tensor._vram_id = tensor_id
        # NOTE(review): nothing here ever releases stored entries; if
        # VirtualVRAM requires explicit frees they will accumulate — confirm.

    def _from_vram(self, tensor):
        """Load tensor from vRAM.

        Returns ``torch.from_numpy`` of the value stored under
        ``tensor._vram_id`` (presumably the NumPy array given to
        ``store_tensor`` — confirm against VirtualVRAM), or ``tensor.cpu()``
        when the tensor was never stored.
        """
        if hasattr(tensor, '_vram_id') and hasattr(tensor, '_vram'):
            data = tensor._vram.storage.load_tensor(tensor._vram_id)
            return torch.from_numpy(data)
        return tensor.cpu()

    def __str__(self):
        return str(self._device)

    def __repr__(self):
        return f"VGPUDevice({self._device})"

# Convenience functions
def to_vgpu(tensor, vram=None):
    """Move tensor to vGPU.

    Args:
        tensor: A tensor (or tensor-like data accepted by ``VGPUTensor``).
        vram: Optional ``VirtualVRAM``. A registered device using this exact
            object is preferred; it is also used when no device exists yet.
            Otherwise the first registered device is used.

    Returns:
        *tensor* itself if it is already a ``VGPUTensor``; otherwise a new
        ``VGPUTensor`` bound to the chosen device and stored in its vRAM.
    """
    # Already on vGPU: return before any device gets created as a side effect.
    if isinstance(tensor, VGPUTensor):
        return tensor

    instances = VGPUDevice._VGPU_INSTANCES
    device = None
    if vram is not None:
        device = next((dev for dev in instances.values() if dev.vram is vram), None)
    if device is None:
        if instances:
            device = next(iter(instances.values()))
        else:
            device = VGPUDevice(vram)

    result = VGPUTensor(tensor, device=device.device(), vram=device.vram)
    device._to_vram(result)
    return result

# Handle on PyTorch's real device type. On a module reload torch.device is
# already our wrapper, so recover the original from its marker attribute
# instead of wrapping the wrapper.
_original_torch_device = getattr(torch.device, "_vgpu_original_device", torch.device)


class _VGPUDeviceMeta(type):
    """Metaclass keeping ``isinstance``/``issubclass`` against ``torch.device`` working.

    After ``torch.device`` is replaced by :class:`VGPUDeviceWrapper`, devices
    produced by PyTorch itself (e.g. ``tensor.device``) are still instances of
    the original type; these hooks accept them as well as vGPU wrappers.
    """

    def __instancecheck__(cls, instance):
        return isinstance(instance, _original_torch_device) or super().__instancecheck__(instance)

    def __subclasscheck__(cls, subclass):
        return issubclass(subclass, _original_torch_device) or super().__subclasscheck__(subclass)


class VGPUDeviceWrapper(metaclass=_VGPUDeviceMeta):
    """Replacement for ``torch.device`` that also understands ``"vgpu[:N]"`` specs.

    * ``VGPUDeviceWrapper("vgpu:1")`` / ``VGPUDeviceWrapper("vgpu", 1)`` returns a
      wrapper with ``type == "vgpu"`` and ``index == 1`` (index defaults to 0).
    * Any other spec is forwarded to the real ``torch.device`` and a genuine
      ``torch.device`` is returned, so PyTorch APIs keep accepting it.

    ``torch.device`` is a C type that cannot be subclassed, hence a plain class
    plus a metaclass for isinstance compatibility. Wrapper instances are not
    real devices and are not accepted by PyTorch functions expecting one.
    """

    # Marker used to recover the real torch.device on module reload.
    _vgpu_original_device = _original_torch_device

    def __new__(cls, *args, **kwargs):
        spec = args[0] if args else kwargs.get("type")
        if isinstance(spec, str) and spec.startswith("vgpu"):
            device_name, sep, index_text = spec.partition(":")
            if sep:
                device_index = int(index_text)
            elif len(args) > 1:
                device_index = args[1]
            else:
                device_index = kwargs.get("index", 0)
            device_index = 0 if device_index is None else int(device_index)

            obj = object.__new__(cls)
            obj._vgpu_type = device_name
            obj._vgpu_index = device_index
            obj._is_vgpu = True
            return obj
        # Not ours: hand back a genuine torch.device (its __init__ is not re-run).
        return _original_torch_device(*args, **kwargs)

    @property
    def type(self):
        return self._vgpu_type

    @property
    def index(self):
        return self._vgpu_index

    def __str__(self):
        return f"{self._vgpu_type}:{self._vgpu_index}"

    def __repr__(self):
        return f"device(type='{self._vgpu_type}', index={self._vgpu_index})"

    def __eq__(self, other):
        # Compare by (type, index) against wrappers and real devices alike, so
        # a vGPU wrapper never equals e.g. a CPU device.
        if isinstance(other, (VGPUDeviceWrapper, _original_torch_device)):
            return (self.type, self.index) == (other.type, other.index)
        return NotImplemented

    def __hash__(self):
        return hash((self._vgpu_type, self._vgpu_index))

    def __reduce__(self):
        # Rebuild from the spec string so copy/deepcopy/pickle round-trip.
        return (type(self), (str(self),))


# Replace torch.device with our wrapper
torch.device = VGPUDeviceWrapper

# Example usage and testing
if __name__ == "__main__":
    print(f"PyTorch version: {torch.__version__}")
    
    # Test backend initialization
    if init_vgpu_backend():
        print("βœ“ vGPU backend initialized")
    else:
        print("! vGPU backend initialization incomplete, using fallback")
    
    # Create vGPU device
    try:
        vgpu = VGPUDevice()
        print(f"βœ“ vGPU device created: {vgpu}")
        
        # Test tensor creation
        x = vgpu.randn(2, 3)
        print(f"βœ“ Random tensor created on {x.device}: shape {x.shape}")
        
        y = vgpu.ones(3, 4) 
        print(f"βœ“ Ones tensor created on {y.device}: shape {y.shape}")
        
        # Test basic operations
        z = x.data @ y.data  # Matrix multiply on CPU data
        print(f"βœ“ Matrix multiplication result shape: {z.shape}")
        
        # Test device string parsing - use a safer approach
        try:
            device_str = torch.device("vgpu:0")
            print(f"βœ“ Device string parsing: {device_str}")
            print(f"βœ“ Device type check: isinstance(device_str, torch.device) = {isinstance(device_str, torch.device)}")
        except Exception as e:
            print(f"! Device string parsing issue: {e}")
        
        # Test compatibility with transformers-style isinstance checks
        cpu_device = torch.device("cpu")
        print(f"βœ“ CPU device isinstance check: {isinstance(cpu_device, torch.device)}")
        
        vgpu_device = torch.device("vgpu:0") 
        print(f"βœ“ vGPU device isinstance check: {isinstance(vgpu_device, torch.device)}")
        
        print(f"βœ“ Device compatibility tests passed")
        
    except Exception as e:
        print(f"βœ— Test failed: {e}")
        import traceback
        traceback.print_exc()