Spaces:
Runtime error
Runtime error
Factor Studios commited on
Update torch_vgpu.py
Browse files- torch_vgpu.py +32 -2
torch_vgpu.py
CHANGED
|
@@ -61,6 +61,33 @@ class VGPUDevice:
|
|
| 61 |
"""Get the PyTorch device object"""
|
| 62 |
return torch.device(str(self))
|
| 63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
def _init_tensor_cores(self):
|
| 65 |
if self.tensor_cores is None:
|
| 66 |
from tensor_core import TensorCoreArray
|
|
@@ -95,8 +122,11 @@ class VGPUDevice:
|
|
| 95 |
# Create new tensor with result
|
| 96 |
return torch.from_numpy(result)
|
| 97 |
|
| 98 |
-
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.Tensor:
|
| 102 |
"""Helper function to move tensors to vGPU"""
|
|
|
|
| 61 |
"""Get the PyTorch device object"""
|
| 62 |
return torch.device(str(self))
|
| 63 |
|
| 64 |
+
def mode(self):
|
| 65 |
+
"""Get a context manager for vGPU operations"""
|
| 66 |
+
from torch.utils._python_dispatch import TorchFunctionMode
|
| 67 |
+
|
| 68 |
+
class _VGPUMode(TorchFunctionMode):
|
| 69 |
+
def __init__(self, device):
|
| 70 |
+
self.device = device
|
| 71 |
+
|
| 72 |
+
def __torch_function__(self, func, types, args=(), kwargs=None):
|
| 73 |
+
kwargs = kwargs or {}
|
| 74 |
+
|
| 75 |
+
# Handle tensor creation and device placement
|
| 76 |
+
if func is torch.tensor or 'device' in kwargs:
|
| 77 |
+
kwargs['device'] = str(self.device)
|
| 78 |
+
|
| 79 |
+
# Handle tensor operations
|
| 80 |
+
new_args = []
|
| 81 |
+
for arg in args:
|
| 82 |
+
if isinstance(arg, torch.Tensor):
|
| 83 |
+
if not str(arg.device).startswith(self.device.device_name):
|
| 84 |
+
arg = to_vgpu(arg, self.device.vram)
|
| 85 |
+
new_args.append(arg)
|
| 86 |
+
|
| 87 |
+
return func(*new_args, **kwargs)
|
| 88 |
+
|
| 89 |
+
return _VGPUMode(self)
|
| 90 |
+
|
| 91 |
def _init_tensor_cores(self):
|
| 92 |
if self.tensor_cores is None:
|
| 93 |
from tensor_core import TensorCoreArray
|
|
|
|
| 122 |
# Create new tensor with result
|
| 123 |
return torch.from_numpy(result)
|
| 124 |
|
| 125 |
+
def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.Tensor:
|
| 126 |
+
"""Move a tensor to vGPU device"""
|
| 127 |
+
device = VGPUDevice(vram)
|
| 128 |
+
tensor_id = device._to_vram(tensor)
|
| 129 |
+
return VGPUTensor(device._from_vram(tensor_id))
|
| 130 |
|
| 131 |
def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.Tensor:
|
| 132 |
"""Helper function to move tensors to vGPU"""
|