Factor Studios committed on
Commit
67791a9
·
verified ·
1 Parent(s): a1380c0

Update test_ai_integration_http.py

Browse files
Files changed (1) hide show
  1. test_ai_integration_http.py +58 -12
test_ai_integration_http.py CHANGED
@@ -7,10 +7,12 @@ import os
7
  import time
8
  from contextlib import contextmanager
9
  from io import BytesIO
 
10
 
11
  import torch
12
  from torch import nn
13
  import torch.nn.functional as F
 
14
  from PIL import Image
15
  from transformers import (
16
  AutoTokenizer,
@@ -21,19 +23,60 @@ from virtual_vram import VirtualVRAM
21
  from http_storage import HTTPGPUStorage
22
  from torch_vgpu import VGPUDevice, to_vgpu
23
 
24
- # Register vGPU device type
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def register_vgpu_device():
26
  """Register vGPU as a custom device type"""
27
  try:
28
- if hasattr(torch.backends, 'register_custom_device'):
 
29
  torch.backends.register_custom_device("vgpu", VGPUDevice)
30
- else:
31
- # Fallback: Add device type to torch._C
32
- if not hasattr(torch._C, "_vgpu_device"):
33
- torch._C._vgpu_device = VGPUDevice
34
- logger.info("Using fallback vGPU device registration")
 
 
 
 
35
  except Exception as e:
36
- logger.error(f"vGPU device registration failed: {str(e)}")
37
  raise
38
 
39
  # Register vGPU device
@@ -103,8 +146,11 @@ def test_ai_integration_http():
103
  # Initialize vRAM with monitoring
104
  initial_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
105
  vram = VirtualVRAM(size_gb=None, storage=storage)
106
- device = VGPUDevice(vram=vram)
107
- logger.info("vGPU device initialized with HTTP storage backend")
 
 
 
108
 
109
  # Load Florence model and processor
110
  model_name = "microsoft/florence-2-large"
@@ -136,9 +182,9 @@ def test_ai_integration_http():
136
  model.eval()
137
  status['model_on_vgpu'] = True
138
 
139
- # Verify model location
140
  for param in model.parameters():
141
- if not hasattr(param, 'device') or param.device != device:
142
  raise RuntimeError("Model not properly moved to vGPU")
143
 
144
  current_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
 
7
  import time
8
  from contextlib import contextmanager
9
  from io import BytesIO
10
+ from typing import Any, Optional
11
 
12
  import torch
13
  from torch import nn
14
  import torch.nn.functional as F
15
+ from torch.utils._python_dispatch import TorchFunctionMode
16
  from PIL import Image
17
  from transformers import (
18
  AutoTokenizer,
 
23
  from http_storage import HTTPGPUStorage
24
  from torch_vgpu import VGPUDevice, to_vgpu
25
 
26
class VGPUMode(torch.overrides.TorchFunctionMode):
    """Torch-function mode that routes intercepted torch calls to the vGPU.

    While the mode is active (``with VGPUMode(vram): ...``) each intercepted
    call has a ``device='vgpu'`` keyword replaced by this mode's
    ``VGPUDevice`` instance, and every positional tensor argument that has no
    ``device_type`` attribute is passed through ``to_vgpu``.

    ``__enter__`` and ``__exit__`` are deliberately inherited from
    ``TorchFunctionMode``: the base implementations are what push and pop the
    mode on PyTorch's mode stack, so replacing them with no-ops would leave
    the mode permanently inactive.

    Attributes:
        vram: The virtual VRAM object handed to ``to_vgpu``.
        device: ``VGPUDevice`` built from ``vram``; substituted for
            ``device='vgpu'``.
    """

    def __init__(self, vram):
        """Store ``vram`` and build the ``VGPUDevice`` used for placement."""
        super().__init__()
        self.vram = vram
        self.device = VGPUDevice(vram)

    def __torch_function__(
        self,
        func: Any,
        types: Any,
        args: Any = (),
        kwargs: Optional[dict] = None,
    ):
        """Rewrite device placement and tensor inputs, then call ``func``.

        Args:
            func: The torch callable being intercepted.
            types: Types that implement ``__torch_function__`` (unused).
            args: Positional arguments of the intercepted call.
            kwargs: Keyword arguments of the intercepted call, or ``None``.

        Returns:
            Whatever ``func`` returns for the rewritten arguments.
        """
        # Work on a copy so the caller's dict is never modified.
        call_kwargs = dict(kwargs) if kwargs else {}

        # Handle device placement: swap the 'vgpu' placeholder for the device.
        if call_kwargs.get('device') == 'vgpu':
            call_kwargs['device'] = self.device

        # Convert positional tensors that are not tagged with `device_type`.
        call_args = [
            to_vgpu(arg, self.vram)
            if isinstance(arg, torch.Tensor) and not hasattr(arg, 'device_type')
            else arg
            for arg in args
        ]

        return func(*call_args, **call_kwargs)
62
+
63
  def register_vgpu_device():
64
  """Register vGPU as a custom device type"""
65
  try:
66
+ # Initialize vGPU device type if not already registered
67
+ if not hasattr(torch._C, "_vgpu_device"):
68
  torch.backends.register_custom_device("vgpu", VGPUDevice)
69
+
70
+ # Create and enable vGPU mode
71
+ def init_vgpu_mode(vram):
72
+ mode = VGPUMode(vram)
73
+ torch.set_mode(mode)
74
+ return mode
75
+
76
+ return init_vgpu_mode
77
+
78
  except Exception as e:
79
+ logging.error(f"vGPU device registration failed: {str(e)}")
80
  raise
81
 
82
  # Register vGPU device
 
146
  # Initialize vRAM with monitoring
147
  initial_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
148
  vram = VirtualVRAM(size_gb=None, storage=storage)
149
+
150
+ # Initialize vGPU mode and register device
151
+ init_vgpu_mode = register_vgpu_device()
152
+ vgpu_mode = init_vgpu_mode(vram)
153
+ logger.info("vGPU mode initialized with HTTP storage backend")
154
 
155
  # Load Florence model and processor
156
  model_name = "microsoft/florence-2-large"
 
182
  model.eval()
183
  status['model_on_vgpu'] = True
184
 
185
+ # Verify model location and device mode
186
  for param in model.parameters():
187
+ if not hasattr(param, 'device') or not isinstance(param.device, VGPUDevice):
188
  raise RuntimeError("Model not properly moved to vGPU")
189
 
190
  current_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0