Factor Studios committed on
Commit
9be71ce
·
verified ·
1 Parent(s): 172ea54

Update torch_vgpu.py

Browse files
Files changed (1) hide show
  1. torch_vgpu.py +183 -182
torch_vgpu.py CHANGED
@@ -1,182 +1,183 @@
1
- """
2
- Custom PyTorch device implementation that routes operations through our virtual GPU.
3
- """
4
- import torch
5
- from torch.library import Library, impl
6
- from typing import Optional, Union, Tuple
7
- import numpy as np
8
- from virtual_vram import VirtualVRAM
9
-
10
- # Global flag for backend initialization
11
- VGPU_BACKEND_INITIALIZED = False
12
-
13
- def init_vgpu_backend():
14
- """Initialize the vGPU backend. Must be called before creating any VGPUDevice instances."""
15
- global VGPU_BACKEND_INITIALIZED
16
- try:
17
- if not VGPU_BACKEND_INITIALIZED:
18
- # First rename the backend
19
- torch.utils.rename_privateuse1_backend("vgpu")
20
-
21
- # Then generate all the necessary methods
22
- torch.utils.generate_methods_for_privateuse1_backend(
23
- for_tensor=True,
24
- for_module=True,
25
- for_packed_sequence=True,
26
- for_storage=True
27
- )
28
-
29
- # Register our custom library
30
- lib = Library("vgpu", "DEF")
31
- lib.define("custom_op(Tensor self) -> Tensor")
32
-
33
- @impl("vgpu", "custom_op", "Tensor")
34
- def custom_op_impl(tensor):
35
- return tensor.clone()
36
-
37
- VGPU_BACKEND_INITIALIZED = True
38
-
39
- return VGPU_BACKEND_INITIALIZED
40
- except Exception as e:
41
- print(f"Backend initialization warning: {e}")
42
- return False
43
-
44
- class VGPUStorage(torch.Storage):
45
- """Custom storage class that uses our virtual VRAM"""
46
-
47
- def __init__(self, *args, **kwargs):
48
- super().__init__(*args, **kwargs)
49
- self.vram = kwargs.get('vram')
50
- if not self.vram:
51
- from virtual_vram import VirtualVRAM
52
- self.vram = VirtualVRAM()
53
- self.tensor_id = kwargs.get('tensor_id', f"tensor_{id(self)}")
54
-
55
- def _new_shared(self, size):
56
- return VGPUStorage(size, vram=self.vram)
57
-
58
- class VGPUTensor:
59
- """Tensor implementation that uses vGPU for computations"""
60
- @staticmethod
61
- def __new__(cls, elem):
62
- return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
63
-
64
- class VGPUDevice:
65
- """
66
- Custom PyTorch device implementation that routes operations through vGPU.
67
- Usage:
68
- vgpu = VGPUDevice()
69
- with vgpu.mode():
70
- tensor = torch.randn(2, 3) # Will be on vGPU
71
- """
72
- _VGPU_INSTANCES = {} # Class-level dict to track instances
73
-
74
- def __init__(self, vram: Optional[VirtualVRAM] = None):
75
- self.vram = vram or VirtualVRAM()
76
- self.tensor_cores = None # Will be initialized when needed
77
- self.device_name = "vgpu" # Both internal and user-facing name
78
- self._register_device()
79
-
80
- def _register_device(self):
81
- """Register vGPU device using PyTorch's device system"""
82
- try:
83
- if not VGPU_BACKEND_INITIALIZED:
84
- raise RuntimeError("VGPU backend not properly initialized")
85
-
86
- # Create device with explicit index
87
- self._device = torch.device("vgpu")
88
-
89
- # Store this instance for reuse
90
- VGPUDevice._VGPU_INSTANCES[self.device_name] = self
91
-
92
- # Define custom operations for the device
93
- class VGPUAllocator:
94
- def __init__(self, vram, device):
95
- self.vram = vram
96
- self.device = device
97
-
98
- def __call__(self, size, dtype=None, device=None):
99
- # Create tensor on CPU first
100
- cpu_tensor = torch.empty(size, dtype=dtype, device='cpu')
101
- # Move to vGPU storage
102
- return to_vgpu(cpu_tensor, self.vram)
103
-
104
- # Set up allocator
105
- self._allocator = VGPUAllocator(self.vram, self._device)
106
-
107
- except Exception as e:
108
- raise RuntimeError(f"Failed to register vGPU device: {str(e)}")
109
-
110
- @property
111
- def type(self):
112
- return self.internal_name
113
-
114
- def __str__(self):
115
- return f"{self.internal_name}:0"
116
-
117
- def __repr__(self):
118
- return f"vgpu(device='{self.internal_name}:0')"
119
-
120
- def device(self):
121
- """Get the PyTorch device object that maps to our vGPU"""
122
- return self._device # Return the already created device object
123
-
124
- def mode(self):
125
- """Get a context manager for vGPU operations"""
126
- return torch.device(self._device)
127
-
128
- def _init_tensor_cores(self):
129
- if self.tensor_cores is None:
130
- from tensor_core import TensorCoreArray
131
- self.tensor_cores = TensorCoreArray()
132
-
133
- def _to_vram(self, tensor: torch.Tensor) -> str:
134
- """Store tensor data in virtual VRAM"""
135
- tensor_id = f"tensor_{id(tensor)}"
136
- data = tensor.detach().cpu().numpy()
137
- self.vram.storage.store_tensor(tensor_id, data)
138
- return tensor_id
139
-
140
- def _from_vram(self, tensor_id: str) -> torch.Tensor:
141
- """Retrieve tensor data from virtual VRAM"""
142
- data = self.vram.storage.load_tensor(tensor_id)
143
- return torch.from_numpy(data)
144
-
145
- def matmul(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
146
- """Matrix multiplication using tensor cores"""
147
- self._init_tensor_cores()
148
-
149
- # Store inputs in VRAM
150
- a_id = self._to_vram(a)
151
- b_id = self._to_vram(b)
152
-
153
- # Perform matmul using tensor cores
154
- result = self.tensor_cores.matmul(
155
- self.vram.storage.load_tensor(a_id),
156
- self.vram.storage.load_tensor(b_id)
157
- )
158
-
159
- # Create new tensor with result
160
- return torch.from_numpy(result)
161
-
162
- def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.Tensor:
163
- """Move a tensor to vGPU device"""
164
- if not isinstance(tensor, torch.Tensor):
165
- tensor = torch.tensor(tensor)
166
-
167
- # Get or create vGPU device
168
- if not VGPUDevice._VGPU_INSTANCES:
169
- device = VGPUDevice(vram)
170
- else:
171
- device = next(iter(VGPUDevice._VGPU_INSTANCES.values()))
172
- if vram is not None:
173
- device.vram = vram
174
-
175
- # Move data to vRAM
176
- tensor_id = device._to_vram(tensor)
177
- result = device._from_vram(tensor_id)
178
- result.requires_grad = tensor.requires_grad
179
-
180
- # Set the device using the internal name
181
- result.data = result.data.to(device._device)
182
- return result
 
 
1
+ """
2
+ Custom PyTorch device implementation that routes operations through our virtual GPU.
3
+ """
4
+ import torch
5
+ from torch.library import Library, impl
6
+ from typing import Optional, Union, Tuple
7
+ import numpy as np
8
+ from virtual_vram import VirtualVRAM
9
+
10
+ # Global flag for backend initialization
11
+ VGPU_BACKEND_INITIALIZED = False
12
+
13
+ def init_vgpu_backend():
14
+ """Initialize the vGPU backend. Must be called before creating any VGPUDevice instances."""
15
+ global VGPU_BACKEND_INITIALIZED
16
+ try:
17
+ if not VGPU_BACKEND_INITIALIZED:
18
+ # First define the library with privateuseone
19
+ lib = Library("privateuseone", "DEF")
20
+ lib.define("custom_op(Tensor self) -> Tensor")
21
+
22
+ @impl("privateuseone", "custom_op", "Tensor")
23
+ def custom_op_impl(tensor):
24
+ return tensor.clone()
25
+
26
+ # Then rename the backend
27
+ torch.utils.rename_privateuse1_backend("vgpu")
28
+
29
+ # Finally generate all the necessary methods
30
+ torch.utils.generate_methods_for_privateuse1_backend(
31
+ for_tensor=True,
32
+ for_module=True,
33
+ for_packed_sequence=True,
34
+ for_storage=True
35
+ )
36
+
37
+ VGPU_BACKEND_INITIALIZED = True
38
+
39
+ return VGPU_BACKEND_INITIALIZED
40
+ except Exception as e:
41
+ print(f"Backend initialization warning: {e}")
42
+ return False
43
+
44
+ class VGPUStorage(torch.Storage):
45
+ """Custom storage class that uses our virtual VRAM"""
46
+
47
+ def __init__(self, *args, **kwargs):
48
+ super().__init__(*args, **kwargs)
49
+ self.vram = kwargs.get('vram')
50
+ if not self.vram:
51
+ from virtual_vram import VirtualVRAM
52
+ self.vram = VirtualVRAM()
53
+ self.tensor_id = kwargs.get('tensor_id', f"tensor_{id(self)}")
54
+
55
+ def _new_shared(self, size):
56
+ return VGPUStorage(size, vram=self.vram)
57
+
58
+ class VGPUTensor:
59
+ """Tensor implementation that uses vGPU for computations"""
60
+ @staticmethod
61
+ def __new__(cls, elem):
62
+ return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
63
+
64
+ class VGPUDevice:
65
+ """
66
+ Custom PyTorch device implementation that routes operations through vGPU.
67
+ Usage:
68
+ vgpu = VGPUDevice()
69
+ with vgpu.mode():
70
+ tensor = torch.randn(2, 3) # Will be on vGPU
71
+ """
72
+ _VGPU_INSTANCES = {} # Class-level dict to track instances
73
+
74
+ def __init__(self, vram: Optional[VirtualVRAM] = None):
75
+ self.vram = vram or VirtualVRAM()
76
+ self.tensor_cores = None # Will be initialized when needed
77
+ self.device_name = "privateuseone" # Use original name first
78
+ self._register_device()
79
+ self.device_name = "vgpu" # Then switch to renamed backend
80
+
81
+ def _register_device(self):
82
+ """Register vGPU device using PyTorch's device system"""
83
+ try:
84
+ if not VGPU_BACKEND_INITIALIZED:
85
+ raise RuntimeError("VGPU backend not properly initialized")
86
+
87
+ # Create device with explicit index
88
+ self._device = torch.device("vgpu")
89
+
90
+ # Store this instance for reuse
91
+ VGPUDevice._VGPU_INSTANCES[self.device_name] = self
92
+
93
+ # Define custom operations for the device
94
+ class VGPUAllocator:
95
+ def __init__(self, vram, device):
96
+ self.vram = vram
97
+ self.device = device
98
+
99
+ def __call__(self, size, dtype=None, device=None):
100
+ # Create tensor on CPU first
101
+ cpu_tensor = torch.empty(size, dtype=dtype, device='cpu')
102
+ # Move to vGPU storage
103
+ return to_vgpu(cpu_tensor, self.vram)
104
+
105
+ # Set up allocator
106
+ self._allocator = VGPUAllocator(self.vram, self._device)
107
+
108
+ except Exception as e:
109
+ raise RuntimeError(f"Failed to register vGPU device: {str(e)}")
110
+
111
+ @property
112
+ def type(self):
113
+ return self.internal_name
114
+
115
+ def __str__(self):
116
+ return f"{self.internal_name}:0"
117
+
118
+ def __repr__(self):
119
+ return f"vgpu(device='{self.internal_name}:0')"
120
+
121
+ def device(self):
122
+ """Get the PyTorch device object that maps to our vGPU"""
123
+ return self._device # Return the already created device object
124
+
125
+ def mode(self):
126
+ """Get a context manager for vGPU operations"""
127
+ return torch.device(self._device)
128
+
129
+ def _init_tensor_cores(self):
130
+ if self.tensor_cores is None:
131
+ from tensor_core import TensorCoreArray
132
+ self.tensor_cores = TensorCoreArray()
133
+
134
+ def _to_vram(self, tensor: torch.Tensor) -> str:
135
+ """Store tensor data in virtual VRAM"""
136
+ tensor_id = f"tensor_{id(tensor)}"
137
+ data = tensor.detach().cpu().numpy()
138
+ self.vram.storage.store_tensor(tensor_id, data)
139
+ return tensor_id
140
+
141
+ def _from_vram(self, tensor_id: str) -> torch.Tensor:
142
+ """Retrieve tensor data from virtual VRAM"""
143
+ data = self.vram.storage.load_tensor(tensor_id)
144
+ return torch.from_numpy(data)
145
+
146
+ def matmul(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
147
+ """Matrix multiplication using tensor cores"""
148
+ self._init_tensor_cores()
149
+
150
+ # Store inputs in VRAM
151
+ a_id = self._to_vram(a)
152
+ b_id = self._to_vram(b)
153
+
154
+ # Perform matmul using tensor cores
155
+ result = self.tensor_cores.matmul(
156
+ self.vram.storage.load_tensor(a_id),
157
+ self.vram.storage.load_tensor(b_id)
158
+ )
159
+
160
+ # Create new tensor with result
161
+ return torch.from_numpy(result)
162
+
163
+ def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.Tensor:
164
+ """Move a tensor to vGPU device"""
165
+ if not isinstance(tensor, torch.Tensor):
166
+ tensor = torch.tensor(tensor)
167
+
168
+ # Get or create vGPU device
169
+ if not VGPUDevice._VGPU_INSTANCES:
170
+ device = VGPUDevice(vram)
171
+ else:
172
+ device = next(iter(VGPUDevice._VGPU_INSTANCES.values()))
173
+ if vram is not None:
174
+ device.vram = vram
175
+
176
+ # Move data to vRAM
177
+ tensor_id = device._to_vram(tensor)
178
+ result = device._from_vram(tensor_id)
179
+ result.requires_grad = tensor.requires_grad
180
+
181
+ # Set the device using the internal name
182
+ result.data = result.data.to(device._device)
183
+ return result