Factor Studios commited on
Commit
fcd6b69
·
verified ·
1 Parent(s): 7274a2c

Update torch_vgpu.py

Browse files
Files changed (1) hide show
  1. torch_vgpu.py +32 -2
torch_vgpu.py CHANGED
@@ -61,6 +61,33 @@ class VGPUDevice:
61
  """Get the PyTorch device object"""
62
  return torch.device(str(self))
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  def _init_tensor_cores(self):
65
  if self.tensor_cores is None:
66
  from tensor_core import TensorCoreArray
@@ -95,8 +122,11 @@ class VGPUDevice:
95
  # Create new tensor with result
96
  return torch.from_numpy(result)
97
 
98
- # Register vGPU device type with PyTorch
99
- torch.backends.register_custom_device("vgpu", VGPUDevice)
 
 
 
100
 
101
  def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.Tensor:
102
  """Helper function to move tensors to vGPU"""
 
61
  """Get the PyTorch device object"""
62
  return torch.device(str(self))
63
 
64
+ def mode(self):
65
+ """Get a context manager for vGPU operations"""
66
+ from torch.utils._python_dispatch import TorchFunctionMode
67
+
68
+ class _VGPUMode(TorchFunctionMode):
69
+ def __init__(self, device):
70
+ self.device = device
71
+
72
+ def __torch_function__(self, func, types, args=(), kwargs=None):
73
+ kwargs = kwargs or {}
74
+
75
+ # Handle tensor creation and device placement
76
+ if func is torch.tensor or 'device' in kwargs:
77
+ kwargs['device'] = str(self.device)
78
+
79
+ # Handle tensor operations
80
+ new_args = []
81
+ for arg in args:
82
+ if isinstance(arg, torch.Tensor):
83
+ if not str(arg.device).startswith(self.device.device_name):
84
+ arg = to_vgpu(arg, self.device.vram)
85
+ new_args.append(arg)
86
+
87
+ return func(*new_args, **kwargs)
88
+
89
+ return _VGPUMode(self)
90
+
91
  def _init_tensor_cores(self):
92
  if self.tensor_cores is None:
93
  from tensor_core import TensorCoreArray
 
122
  # Create new tensor with result
123
  return torch.from_numpy(result)
124
 
125
+ def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.Tensor:
126
+ """Move a tensor to vGPU device"""
127
+ device = VGPUDevice(vram)
128
+ tensor_id = device._to_vram(tensor)
129
+ return VGPUTensor(device._from_vram(tensor_id))
130
 
131
  def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.Tensor:
132
  """Helper function to move tensors to vGPU"""