Factor Studios commited on
Commit
edcd91f
·
verified ·
1 Parent(s): e0b5f17

Update torch_vgpu.py

Browse files
Files changed (1) hide show
  1. torch_vgpu.py +9 -2
torch_vgpu.py CHANGED
@@ -7,6 +7,13 @@ from typing import Optional, Union, Tuple
7
  import numpy as np
8
  from virtual_vram import VirtualVRAM
9
 
 
 
 
 
 
 
 
10
  class VGPUStorage(torch.Storage):
11
  """Custom storage class that uses our virtual VRAM"""
12
 
@@ -47,8 +54,8 @@ class VGPUDevice:
47
  def _register_device(self):
48
  """Register vGPU device using PyTorch's device system"""
49
  try:
50
- # Register device using privateuseone
51
- self._device = torch.device(f"{self.internal_name}:0")
52
 
53
  # Store this instance for reuse
54
  VGPUDevice._VGPU_INSTANCES[self.device_name] = self
 
7
  import numpy as np
8
  from virtual_vram import VirtualVRAM
9
 
10
+ # Register and rename privateuse1 backend to vgpu
11
+ try:
12
+ torch.utils.rename_privateuse1_backend("vgpu")
13
+ except (AttributeError, RuntimeError) as e:
14
+ # Fallback for older PyTorch versions or if already renamed
15
+ pass
16
+
17
  class VGPUStorage(torch.Storage):
18
  """Custom storage class that uses our virtual VRAM"""
19
 
 
54
  def _register_device(self):
55
  """Register vGPU device using PyTorch's device system"""
56
  try:
57
+ # Register device using privateuse1
58
+ self._device = torch.device(self.internal_name)
59
 
60
  # Store this instance for reuse
61
  VGPUDevice._VGPU_INSTANCES[self.device_name] = self