Factor Studios committed on
Commit
c90a803
·
verified ·
1 Parent(s): c4ad204

Update torch_vgpu.py

Browse files
Files changed (1) hide show
  1. torch_vgpu.py +104 -90
torch_vgpu.py CHANGED
@@ -1,90 +1,104 @@
1
- """
2
- Custom PyTorch device implementation that routes operations through our virtual GPU.
3
- """
4
- import torch
5
- from typing import Optional, Union, Tuple
6
- import numpy as np
7
- from virtual_vram import VirtualVRAM
8
-
9
- class VGPUStorage(torch.Storage):
10
- """Custom storage class that uses our virtual VRAM"""
11
-
12
- def __init__(self, *args, **kwargs):
13
- super().__init__(*args, **kwargs)
14
- self.vram = kwargs.get('vram')
15
- if not self.vram:
16
- from virtual_vram import VirtualVRAM
17
- self.vram = VirtualVRAM()
18
- self.tensor_id = kwargs.get('tensor_id', f"tensor_{id(self)}")
19
-
20
- def _new_shared(self, size):
21
- return VGPUStorage(size, vram=self.vram)
22
-
23
- class VGPUTensor:
24
- """Tensor implementation that uses vGPU for computations"""
25
- @staticmethod
26
- def __new__(cls, elem):
27
- return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
28
-
29
- class VGPUDevice:
30
- """
31
- Custom PyTorch device implementation that routes operations through vGPU.
32
- Usage:
33
- vgpu = VGPUDevice()
34
- tensor = torch.randn(2, 3).to(vgpu)
35
- """
36
- def __init__(self, vram: Optional[VirtualVRAM] = None):
37
- self.vram = vram or VirtualVRAM()
38
- self.tensor_cores = None # Will be initialized when needed
39
-
40
- @property
41
- def type(self):
42
- return 'vgpu'
43
-
44
- def __str__(self):
45
- return self.type
46
-
47
- def __repr__(self):
48
- return self.type
49
-
50
- def _init_tensor_cores(self):
51
- if self.tensor_cores is None:
52
- from tensor_core import TensorCoreArray
53
- self.tensor_cores = TensorCoreArray()
54
-
55
- def _to_vram(self, tensor: torch.Tensor) -> str:
56
- """Store tensor data in virtual VRAM"""
57
- tensor_id = f"tensor_{id(tensor)}"
58
- data = tensor.detach().cpu().numpy()
59
- self.vram.storage.store_tensor(tensor_id, data)
60
- return tensor_id
61
-
62
- def _from_vram(self, tensor_id: str) -> torch.Tensor:
63
- """Retrieve tensor data from virtual VRAM"""
64
- data = self.vram.storage.load_tensor(tensor_id)
65
- return torch.from_numpy(data)
66
-
67
- def matmul(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
68
- """Matrix multiplication using tensor cores"""
69
- self._init_tensor_cores()
70
-
71
- # Store inputs in VRAM
72
- a_id = self._to_vram(a)
73
- b_id = self._to_vram(b)
74
-
75
- # Perform matmul using tensor cores
76
- result = self.tensor_cores.matmul(
77
- self.vram.storage.load_tensor(a_id),
78
- self.vram.storage.load_tensor(b_id)
79
- )
80
-
81
- # Create new tensor with result
82
- return torch.from_numpy(result)
83
-
84
- # Register vGPU device type with PyTorch
85
- torch.backends.register_custom_device("vgpu", VGPUDevice)
86
-
87
- def to_vgpu(tensor: torch.Tensor, vram: Optional[VirtualVRAM] = None) -> torch.Tensor:
88
- """Helper function to move tensors to vGPU"""
89
- device = VGPUDevice(vram)
90
- return tensor.to(device=device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom PyTorch device implementation that routes operations through our virtual GPU.
3
+ """
4
+ import torch
5
+ from typing import Optional, Union, Tuple
6
+ import numpy as np
7
+ from virtual_vram import VirtualVRAM
8
+
9
+ class VGPUStorage(torch.Storage):
10
+ """Custom storage class that uses our virtual VRAM"""
11
+
12
+ def __init__(self, *args, **kwargs):
13
+ super().__init__(*args, **kwargs)
14
+ self.vram = kwargs.get('vram')
15
+ if not self.vram:
16
+ from virtual_vram import VirtualVRAM
17
+ self.vram = VirtualVRAM()
18
+ self.tensor_id = kwargs.get('tensor_id', f"tensor_{id(self)}")
19
+
20
+ def _new_shared(self, size):
21
+ return VGPUStorage(size, vram=self.vram)
22
+
23
class VGPUTensor(torch.Tensor):
    """``torch.Tensor`` subclass whose computations are routed to the vGPU.

    Inheriting from ``torch.Tensor`` is required: ``Tensor._make_subclass``
    rejects classes that are not Tensor subclasses, so the original
    plain-object class could never be constructed.
    """

    @staticmethod
    def __new__(cls, elem):
        # Wrap an existing tensor without copying its data, preserving
        # its autograd flag.
        return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
28
+
29
class VGPUDevice:
    """
    Custom PyTorch device that routes operations through the virtual GPU.

    Registration piggybacks on PyTorch's reserved ``privateuse1`` backend,
    which is renamed to ``vgpu`` the first time a device is constructed.

    Usage:
        vgpu = VGPUDevice()
        dev = vgpu.device()   # torch.device('vgpu:0')
    """

    def __init__(self, vram: Optional["VirtualVRAM"] = None):
        # Shared VRAM pool; a private one is created when none is supplied.
        self.vram = vram or VirtualVRAM()
        self.tensor_cores = None  # lazily built on first matmul
        self.device_name = "vgpu"
        self._register_device()

    def _register_device(self):
        """Expose the vGPU through the privateuse1 backend.

        Raises:
            RuntimeError: if the backend cannot be (and is not already)
                named ``vgpu``.
        """
        try:
            # Public API; the previously used
            # torch._C._dispatch._rename_privateuse1_backend path does not
            # exist, which made every construction fail.
            torch.utils.rename_privateuse1_backend(self.device_name)
        except Exception as e:
            # The rename is one-shot per process; constructing a second
            # VGPUDevice is fine as long as 'vgpu' already resolves as a
            # device type.
            try:
                torch.device(f"{self.device_name}:0")
            except Exception:
                raise RuntimeError(f"Failed to register vGPU device: {str(e)}")

    @property
    def type(self):
        # Mirrors torch.device.type ('vgpu').
        return self.device_name

    def __str__(self):
        return f"{self.device_name}:0"

    def __repr__(self):
        return f"{self.device_name}:0"

    def device(self):
        """Return the equivalent ``torch.device`` object ('vgpu:0')."""
        return torch.device(str(self))

    def _init_tensor_cores(self):
        # Deferred import/construction: tensor cores are only needed for
        # compute, not for plain device bookkeeping.
        if self.tensor_cores is None:
            from tensor_core import TensorCoreArray
            self.tensor_cores = TensorCoreArray()

    def _to_vram(self, tensor: torch.Tensor) -> str:
        """Stage a tensor's data in virtual VRAM and return its storage key.

        NOTE(review): keys derive from ``id(tensor)``, which CPython may
        reuse after garbage collection — collisions are possible for
        short-lived tensors.
        """
        tensor_id = f"tensor_{id(tensor)}"
        data = tensor.detach().cpu().numpy()
        self.vram.storage.store_tensor(tensor_id, data)
        return tensor_id

    def _from_vram(self, tensor_id: str) -> torch.Tensor:
        """Rebuild a CPU tensor from the VRAM entry stored under *tensor_id*."""
        data = self.vram.storage.load_tensor(tensor_id)
        return torch.from_numpy(data)

    def matmul(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Matrix-multiply *a* and *b* on the simulated tensor cores.

        Both operands are round-tripped through virtual VRAM so the
        simulation exercises the full memory path; the result comes back
        as a CPU tensor built from the tensor cores' numpy output.
        """
        self._init_tensor_cores()

        # Stage inputs in VRAM.
        a_id = self._to_vram(a)
        b_id = self._to_vram(b)

        # Compute on the (simulated) tensor cores.
        result = self.tensor_cores.matmul(
            self.vram.storage.load_tensor(a_id),
            self.vram.storage.load_tensor(b_id),
        )

        return torch.from_numpy(result)
97
+
98
# NOTE: device registration is handled inside VGPUDevice.__init__ (via the
# privateuse1 backend rename), so no module-level registration call is
# needed here.  The previous ``torch.backends.register_custom_device(...)``
# call referenced an API that does not exist in PyTorch and raised
# AttributeError as soon as this module was imported.


def to_vgpu(tensor: torch.Tensor, vram: Optional["VirtualVRAM"] = None) -> torch.Tensor:
    """Move *tensor* to the virtual GPU device.

    Args:
        tensor: any tensor.
        vram: optional pre-built VirtualVRAM pool; a fresh one is created
            when omitted.

    Returns:
        The tensor on the ``vgpu`` device.  NOTE(review): the actual copy
        requires the privateuse1 backend kernels to be implemented —
        confirm before relying on this at runtime.
    """
    device = VGPUDevice(vram)
    # ``Tensor.to`` needs a real torch.device (or string), not our wrapper
    # object, so convert first.
    return tensor.to(device=device.device())