Factor Studios commited on
Commit
5cdc76b
·
verified ·
1 Parent(s): e4541c8

Update torch_vgpu.py

Browse files
Files changed (1) hide show
  1. torch_vgpu.py +201 -200
torch_vgpu.py CHANGED
@@ -1,200 +1,201 @@
1
- import torch
2
- from torch.library import Library, impl
3
- from typing import Optional, Union, Tuple
4
- import numpy as np
5
- from virtual_vram import VirtualVRAM
6
-
7
- # Global variables for backend state
8
- VGPU_BACKEND_INITIALIZED = False
9
- CURRENT_VRAM = None # Global reference to current vRAM manager
10
-
11
- def set_current_vram(vram):
12
- """Set the current vRAM manager globally"""
13
- global CURRENT_VRAM
14
- CURRENT_VRAM = vram
15
-
16
- def get_current_vram():
17
- """Get the current vRAM manager"""
18
- return CURRENT_VRAM
19
-
20
- def init_vgpu_backend():
21
- """Initialize the vGPU backend. Must be called before creating any VGPUDevice instances."""
22
- global VGPU_BACKEND_INITIALIZED
23
- try:
24
- if not VGPU_BACKEND_INITIALIZED:
25
- # First define our core library
26
- lib = Library("vgpu", "DEF")
27
- lib.define("custom_allocate(Device? device) -> Tensor")
28
- lib.define("custom_to_cpu(Tensor self) -> Tensor")
29
- lib.define("custom_from_cpu(Tensor self) -> Tensor")
30
-
31
- # Then implement the operations
32
- impl_lib = Library("vgpu", "IMPL", "PrivateUse1")
33
-
34
- @impl(impl_lib, "custom_allocate")
35
- def custom_allocate(device=None):
36
- return torch.empty((), device="cpu")
37
-
38
- @impl(impl_lib, "custom_to_cpu")
39
- def custom_to_cpu(tensor):
40
- return tensor.clone()
41
-
42
- @impl(impl_lib, "custom_from_cpu")
43
- def custom_from_cpu(tensor):
44
- return tensor.clone()
45
-
46
- # Generate all methods for our backend
47
- torch.utils.generate_methods_for_privateuse1_backend(
48
- for_tensor=True,
49
- for_module=True,
50
- for_packed_sequence=True,
51
- for_storage=True
52
- )
53
-
54
- VGPU_BACKEND_INITIALIZED = True
55
-
56
- return VGPU_BACKEND_INITIALIZED
57
- except Exception as e:
58
- print(f"Backend initialization warning: {e}")
59
- return False
60
-
61
- class VGPUStorage(torch.Storage):
62
- """Custom storage class that uses our virtual VRAM"""
63
-
64
- def __init__(self, *args, **kwargs):
65
- super().__init__(*args, **kwargs)
66
- self.vram = kwargs.get("vram")
67
- if not self.vram:
68
- from virtual_vram import VirtualVRAM
69
- self.vram = VirtualVRAM()
70
- self.tensor_id = kwargs.get("tensor_id", f"tensor_{id(self)}")
71
-
72
- def _new_shared(self, size):
73
- return VGPUStorage(size, vram=self.vram)
74
-
75
- class VGPUTensor:
76
- """Tensor implementation that uses vGPU for computations"""
77
- @staticmethod
78
- def __new__(cls, elem):
79
- return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
80
-
81
- class VGPUDevice:
82
- """
83
- Custom PyTorch device implementation that routes operations through vGPU.
84
- Usage:
85
- vgpu = VGPUDevice()
86
- with vgpu.mode():
87
- tensor = torch.randn(2, 3) # Will be on vGPU
88
- """
89
- _VGPU_INSTANCES = {} # Class-level dict to track instances
90
-
91
- def __init__(self, vram: Optional[VirtualVRAM] = None):
92
- self.vram = vram or VirtualVRAM()
93
- self.tensor_cores = None # Will be initialized when needed
94
- self.device_name = "privateuseone" # Our registered device type
95
- set_current_vram(self.vram) # Set up global vRAM reference
96
- self._register_device()
97
-
98
- def _register_device(self):
99
- """Register vGPU device using PyTorch's device system"""
100
- try:
101
- if not VGPU_BACKEND_INITIALIZED:
102
- raise RuntimeError("VGPU backend not properly initialized")
103
-
104
- # Create device using our registered device type
105
- self._device = torch.device(self.device_name)
106
-
107
- # Store this instance for reuse
108
- VGPUDevice._VGPU_INSTANCES[self.device_name] = self
109
-
110
- # Define custom operations for the device
111
- class VGPUAllocator:
112
- def __init__(self, vram, device):
113
- self.vram = vram
114
- self.device = device
115
-
116
- def __call__(self, size, dtype=None, device=None):
117
- # Create tensor on CPU first
118
- cpu_tensor = torch.empty(size, dtype=dtype, device='cpu')
119
- # Move to vGPU storage
120
- return to_vgpu(cpu_tensor, self.vram)
121
-
122
- # Set up allocator
123
- self._allocator = VGPUAllocator(self.vram, self._device)
124
-
125
- except Exception as e:
126
- raise RuntimeError(f"Failed to register vGPU device: {str(e)}")
127
-
128
- @property
129
- def type(self):
130
- return self.internal_name
131
-
132
- def __str__(self):
133
- return f"{self.internal_name}:0"
134
-
135
- def __repr__(self):
136
- return f"vgpu(device='{self.internal_name}:0')"
137
-
138
- def device(self):
139
- """Get the PyTorch device object that maps to our vGPU"""
140
- return self._device # Return the already created device object
141
-
142
- def mode(self):
143
- """Get a context manager for vGPU operations"""
144
- return torch.device(self._device)
145
-
146
- def _init_tensor_cores(self):
147
- if self.tensor_cores is None:
148
- from tensor_core import TensorCoreArray
149
- self.tensor_cores = TensorCoreArray()
150
-
151
- def _to_vram(self, tensor: torch.Tensor) -> str:
152
- """Store tensor data in virtual VRAM"""
153
- tensor_id = f"tensor_{id(tensor)}"
154
- data = tensor.detach().cpu().numpy()
155
- self.vram.storage.store_tensor(tensor_id, data)
156
- return tensor_id
157
-
158
- def _from_vram(self, tensor_id: str) -> torch.Tensor:
159
- """Retrieve tensor data from virtual VRAM"""
160
- data = self.vram.storage.load_tensor(tensor_id)
161
- return torch.from_numpy(data)
162
-
163
- def matmul(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
164
- """Matrix multiplication using tensor cores"""
165
- self._init_tensor_cores()
166
-
167
- # Store inputs in VRAM
168
- a_id = self._to_vram(a)
169
- b_id = self._to_vram(b)
170
-
171
- # Perform matmul using tensor cores
172
- result = self.tensor_cores.matmul(
173
- self.vram.storage.load_tensor(a_id),
174
- self.vram.storage.load_tensor(b_id)
175
- )
176
-
177
- # Create new tensor with result
178
- return torch.from_numpy(result)
179
-
180
- def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.Tensor:
181
- """Move a tensor to vGPU device"""
182
- if not isinstance(tensor, torch.Tensor):
183
- tensor = torch.tensor(tensor)
184
-
185
- # Get or create vGPU device
186
- if not VGPUDevice._VGPU_INSTANCES:
187
- device = VGPUDevice(vram)
188
- else:
189
- device = next(iter(VGPUDevice._VGPU_INSTANCES.values()))
190
- if vram is not None:
191
- device.vram = vram
192
-
193
- # Move data to vRAM
194
- tensor_id = device._to_vram(tensor)
195
- result = device._from_vram(tensor_id)
196
- result.requires_grad = tensor.requires_grad
197
-
198
- # Set the device using the internal name
199
- result.data = result.data.to(device._device)
200
- return result
 
 
1
+ import torch
2
+ from torch.library import Library, impl
3
+ from typing import Optional, Union, Tuple
4
+ import numpy as np
5
+ from virtual_vram import VirtualVRAM
6
+
7
+ # Global variables for backend state
8
+ VGPU_BACKEND_INITIALIZED = False
9
+ CURRENT_VRAM = None # Global reference to current vRAM manager
10
+
11
+ def set_current_vram(vram):
12
+ """Set the current vRAM manager globally"""
13
+ global CURRENT_VRAM
14
+ CURRENT_VRAM = vram
15
+
16
+ def get_current_vram():
17
+ """Get the current vRAM manager"""
18
+ return CURRENT_VRAM
19
+
20
+ def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.Tensor:
21
+ """Move a tensor to vGPU memory"""
22
+ if vram is None:
23
+ vram = get_current_vram()
24
+ if vram is None:
25
+ raise RuntimeError("No vRAM manager available. Initialize VGPUDevice first.")
26
+
27
+ # Get data and store in vRAM
28
+ cpu_data = tensor.detach().cpu().numpy()
29
+ tensor_id = f"tensor_{id(tensor)}"
30
+ vram.store(tensor_id, cpu_data)
31
+
32
+ # Create vGPU tensor
33
+ device = torch.device("privateuseone")
34
+ vgpu_storage = VGPUStorage(
35
+ cpu_data.size,
36
+ vram=vram,
37
+ tensor_id=tensor_id
38
+ )
39
+ vgpu_tensor = torch.tensor(
40
+ [],
41
+ device=device,
42
+ requires_grad=tensor.requires_grad
43
+ )
44
+ vgpu_tensor.set_(vgpu_storage)
45
+
46
+ return vgpu_tensor
47
+
48
+ def init_vgpu_backend():
49
+ """Initialize the vGPU backend. Must be called before creating any VGPUDevice instances."""
50
+ global VGPU_BACKEND_INITIALIZED
51
+ try:
52
+ if not VGPU_BACKEND_INITIALIZED:
53
+ # Create library for custom ops
54
+ lib = Library("vgpu", "DEF")
55
+ lib.define("custom_from_cpu(Tensor x) -> Tensor")
56
+
57
+ impl_lib = Library("vgpu", "IMPL")
58
+
59
+ @impl(impl_lib, "custom_from_cpu")
60
+ def custom_from_cpu(x):
61
+ """Copy tensor to our vGPU memory"""
62
+ return x.clone()
63
+
64
+ # Set initialization flag
65
+ VGPU_BACKEND_INITIALIZED = True
66
+
67
+ return VGPU_BACKEND_INITIALIZED
68
+ except Exception as e:
69
+ print(f"Backend initialization warning: {e}")
70
+ return False
71
+
72
+ class VGPUStorage(torch.Storage):
73
+ """Custom storage class that uses our virtual VRAM"""
74
+
75
+ def __init__(self, *args, **kwargs):
76
+ super().__init__(*args, **kwargs)
77
+ self.vram = kwargs.get("vram")
78
+ if not self.vram:
79
+ from virtual_vram import VirtualVRAM
80
+ self.vram = VirtualVRAM()
81
+ self.tensor_id = kwargs.get("tensor_id", f"tensor_{id(self)}")
82
+
83
+ def _new_shared(self, size):
84
+ return VGPUStorage(size, vram=self.vram)
85
+
86
+ class VGPUTensor:
87
+ """Tensor implementation that uses vGPU for computations"""
88
+ @staticmethod
89
+ def __new__(cls, elem):
90
+ return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
91
+
92
+ from contextlib import contextmanager
93
+
94
+ # Custom allocator for vGPU tensors
95
+ class VGPUAllocator:
96
+ def __init__(self, vram):
97
+ self.vram = vram
98
+
99
+ def __call__(self, size, dtype=None, device=None):
100
+ cpu_tensor = torch.empty(size, dtype=dtype, device='cpu')
101
+ return to_vgpu(cpu_tensor, self.vram)
102
+
103
+ class VGPUDevice:
104
+ """
105
+ Custom PyTorch device implementation that routes operations through vGPU.
106
+ Usage:
107
+ vgpu = VGPUDevice()
108
+ with vgpu.mode():
109
+ tensor = torch.randn(2, 3) # Will be on vGPU
110
+ """
111
+ _VGPU_INSTANCES = {} # Class-level dict to track instances
112
+
113
+ def __init__(self, vram: Optional[VirtualVRAM] = None):
114
+ """Initialize a vGPU device with optional vRAM manager"""
115
+ self.vram = vram or VirtualVRAM()
116
+ self.device_name = "privateuseone" # Our device type
117
+ self._init_device()
118
+
119
+ def _init_device(self):
120
+ """Initialize the device backend and settings"""
121
+ if not VGPU_BACKEND_INITIALIZED:
122
+ raise RuntimeError("VGPU backend not properly initialized")
123
+
124
+ # Setup device and global vRAM
125
+ self._device = torch.device(self.device_name)
126
+ set_current_vram(self.vram)
127
+
128
+ # Register instance
129
+ VGPUDevice._VGPU_INSTANCES[self.device_name] = self
130
+
131
+ # Setup allocator
132
+ self._allocator = VGPUAllocator(self.vram)
133
+
134
+ def device(self) -> torch.device:
135
+ """Get the PyTorch device object for this vGPU"""
136
+ return self._device
137
+
138
+ @contextmanager
139
+ def mode(self):
140
+ """Context manager for using this device as the default"""
141
+ prev_device = torch.device("cpu")
142
+ try:
143
+ prev_device = torch.cuda.current_device() if torch.cuda.is_available() else prev_device
144
+ torch.set_device(self._device)
145
+ yield
146
+ finally:
147
+ torch.set_device(prev_device)
148
+
149
+ def __str__(self):
150
+ """String representation of the device"""
151
+ return f"{self.device_name}:0"
152
+
153
+ def __repr__(self):
154
+ """Detailed string representation"""
155
+ return f"vgpu(device='{self.device_name}:0')"
156
+
157
+ return tensor_id
158
+
159
+ def _from_vram(self, tensor_id: str) -> torch.Tensor:
160
+ """Retrieve tensor data from virtual VRAM"""
161
+ data = self.vram.storage.load_tensor(tensor_id)
162
+ return torch.from_numpy(data)
163
+
164
+ def matmul(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
165
+ """Matrix multiplication using tensor cores"""
166
+ self._init_tensor_cores()
167
+
168
+ # Store inputs in VRAM
169
+ a_id = self._to_vram(a)
170
+ b_id = self._to_vram(b)
171
+
172
+ # Perform matmul using tensor cores
173
+ result = self.tensor_cores.matmul(
174
+ self.vram.storage.load_tensor(a_id),
175
+ self.vram.storage.load_tensor(b_id)
176
+ )
177
+
178
+ # Create new tensor with result
179
+ return torch.from_numpy(result)
180
+
181
+ def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.Tensor:
182
+ """Move a tensor to vGPU device"""
183
+ if not isinstance(tensor, torch.Tensor):
184
+ tensor = torch.tensor(tensor)
185
+
186
+ # Get or create vGPU device
187
+ if not VGPUDevice._VGPU_INSTANCES:
188
+ device = VGPUDevice(vram)
189
+ else:
190
+ device = next(iter(VGPUDevice._VGPU_INSTANCES.values()))
191
+ if vram is not None:
192
+ device.vram = vram
193
+
194
+ # Move data to vRAM
195
+ tensor_id = device._to_vram(tensor)
196
+ result = device._from_vram(tensor_id)
197
+ result.requires_grad = tensor.requires_grad
198
+
199
+ # Set the device using the internal name
200
+ result.data = result.data.to(device._device)
201
+ return result