Factor Studios committed on
Commit
d7bc79c
·
verified ·
1 Parent(s): 1f02017

Update test_ai_integration_http.py

Browse files
Files changed (1) hide show
  1. test_ai_integration_http.py +28 -25
test_ai_integration_http.py CHANGED
@@ -12,7 +12,7 @@ from typing import Any, Optional
12
  import torch
13
  from torch import nn
14
  import torch.nn.functional as F
15
- from torch.overrides import TorchFunctionMode
16
  from PIL import Image
17
  from transformers import (
18
  AutoTokenizer,
@@ -26,11 +26,11 @@ from torch_vgpu import VGPUDevice, to_vgpu
26
  class VGPUMode(TorchFunctionMode):
27
  """Custom device mode for vGPU operations"""
28
 
29
- def __init__(self, vram):
30
  self.vram = vram
 
31
  self.device = VGPUDevice(vram)
32
 
33
- @torch.override
34
  def __torch_function__(
35
  self,
36
  func: Any,
@@ -41,15 +41,16 @@ class VGPUMode(TorchFunctionMode):
41
  """Override torch functions to handle vGPU device operations"""
42
  kwargs = kwargs or {}
43
 
44
- # Handle device placement
45
- if 'device' in kwargs and kwargs['device'] == 'vgpu':
46
- kwargs['device'] = self.device
47
 
48
- # Convert any tensor inputs to vGPU
49
  new_args = []
50
  for arg in args:
51
- if isinstance(arg, torch.Tensor) and not hasattr(arg, 'device_type'):
52
- arg = to_vgpu(arg, self.vram)
 
53
  new_args.append(arg)
54
 
55
  return func(*new_args, **kwargs)
@@ -61,17 +62,18 @@ class VGPUMode(TorchFunctionMode):
61
  pass
62
 
63
  def register_vgpu_device():
64
- """Register vGPU as a custom device type"""
65
  try:
66
- # Initialize vGPU device type if not already registered
67
- if not hasattr(torch._C, "_vgpu_device"):
68
- torch.backends.register_custom_device("vgpu", VGPUDevice)
69
-
70
- # Create and enable vGPU mode
71
  def init_vgpu_mode(vram):
72
- mode = VGPUMode(vram)
 
73
  torch.set_mode(mode)
74
- return mode
75
 
76
  return init_vgpu_mode
77
 
@@ -149,8 +151,8 @@ def test_ai_integration_http():
149
 
150
  # Initialize vGPU mode and register device
151
  init_vgpu_mode = register_vgpu_device()
152
- vgpu_mode = init_vgpu_mode(vram)
153
- logger.info("vGPU mode initialized with HTTP storage backend")
154
 
155
  # Load Florence model and processor
156
  model_name = "microsoft/florence-2-large"
@@ -183,12 +185,13 @@ def test_ai_integration_http():
183
  status['model_on_vgpu'] = True
184
 
185
  # Verify model location and device mode
186
- for param in model.parameters():
187
- if not hasattr(param, 'device') or not isinstance(param.device, VGPUDevice):
188
- raise RuntimeError("Model not properly moved to vGPU")
189
-
190
- current_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
191
- logger.info(f"Model memory usage: {(current_mem - initial_mem)/1e9:.2f} GB")
 
192
  except Exception as e:
193
  logger.error(f"Model transfer to vGPU failed: {str(e)}")
194
  raise
 
12
  import torch
13
  from torch import nn
14
  import torch.nn.functional as F
15
+ from torch.utils._python_dispatch import TorchFunctionMode
16
  from PIL import Image
17
  from transformers import (
18
  AutoTokenizer,
 
26
  class VGPUMode(TorchFunctionMode):
27
  """Custom device mode for vGPU operations"""
28
 
29
+ def __init__(self, vram, device_name="vgpu"):
30
  self.vram = vram
31
+ self.device_name = device_name
32
  self.device = VGPUDevice(vram)
33
 
 
34
  def __torch_function__(
35
  self,
36
  func: Any,
 
41
  """Override torch functions to handle vGPU device operations"""
42
  kwargs = kwargs or {}
43
 
44
+ # Handle tensor creation and device placement
45
+ if func is torch.tensor or 'device' in kwargs:
46
+ kwargs['device'] = f"{self.device_name}:0"
47
 
48
+ # Handle tensor operations
49
  new_args = []
50
  for arg in args:
51
+ if isinstance(arg, torch.Tensor):
52
+ if not hasattr(arg, 'device') or not str(arg.device).startswith(self.device_name):
53
+ arg = to_vgpu(arg, self.vram)
54
  new_args.append(arg)
55
 
56
  return func(*new_args, **kwargs)
 
62
  pass
63
 
64
  def register_vgpu_device():
65
+ """Register vGPU as a custom device type using privateuse1 backend"""
66
  try:
67
+ device_name = "vgpu"
68
+
69
+ # Register device using privateuse1 backend
70
+ torch._C._dispatch._rename_privateuse1_backend(device_name)
71
+
72
  def init_vgpu_mode(vram):
73
+ # Create device mode with the registered device name
74
+ mode = VGPUMode(vram, device_name)
75
  torch.set_mode(mode)
76
+ return mode, torch.device(f"{device_name}:0")
77
 
78
  return init_vgpu_mode
79
 
 
151
 
152
  # Initialize vGPU mode and register device
153
  init_vgpu_mode = register_vgpu_device()
154
+ vgpu_mode, vgpu_device = init_vgpu_mode(vram)
155
+ logger.info(f"vGPU mode initialized with device {vgpu_device}")
156
 
157
  # Load Florence model and processor
158
  model_name = "microsoft/florence-2-large"
 
185
  status['model_on_vgpu'] = True
186
 
187
  # Verify model location and device mode
188
+ with vgpu_mode:
189
+ for param in model.parameters():
190
+ if not str(param.device).startswith('vgpu'):
191
+ raise RuntimeError(f"Model parameter not on vGPU device. Found device: {param.device}")
192
+
193
+ current_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
194
+ logger.info(f"Model memory usage: {(current_mem - initial_mem)/1e9:.2f} GB")
195
  except Exception as e:
196
  logger.error(f"Model transfer to vGPU failed: {str(e)}")
197
  raise