Factor Studios committed on
Commit
7274a2c
·
verified ·
1 Parent(s): c90a803

Update test_ai_integration_http.py

Browse files
Files changed (1) hide show
  1. test_ai_integration_http.py +11 -60
test_ai_integration_http.py CHANGED
@@ -12,7 +12,7 @@ from typing import Any, Optional
12
  import torch
13
  from torch import nn
14
  import torch.nn.functional as F
15
- from torch.overrides import TorchFunctionMode
16
  from PIL import Image
17
  from transformers import (
18
  AutoTokenizer,
@@ -23,67 +23,19 @@ from virtual_vram import VirtualVRAM
23
  from http_storage import HTTPGPUStorage
24
  from torch_vgpu import VGPUDevice, to_vgpu
25
 
26
- class VGPUMode(TorchFunctionMode):
27
- """Custom device mode for vGPU operations"""
28
-
29
- def __init__(self, vram, device_name="vgpu"):
30
- self.vram = vram
31
- self.device_name = device_name
32
- self.device = VGPUDevice(vram)
33
-
34
- def __torch_function__(
35
- self,
36
- func: Any,
37
- types: Any,
38
- args: Any = (),
39
- kwargs: Optional[dict] = None
40
- ):
41
- """Override torch functions to handle vGPU device operations"""
42
- kwargs = kwargs or {}
43
-
44
- # Handle tensor creation and device placement
45
- if func is torch.tensor or 'device' in kwargs:
46
- kwargs['device'] = f"{self.device_name}:0"
47
-
48
- # Handle tensor operations
49
- new_args = []
50
- for arg in args:
51
- if isinstance(arg, torch.Tensor):
52
- if not hasattr(arg, 'device') or not str(arg.device).startswith(self.device_name):
53
- arg = to_vgpu(arg, self.vram)
54
- new_args.append(arg)
55
-
56
- return func(*new_args, **kwargs)
57
-
58
- def __enter__(self):
59
- return self
60
-
61
- def __exit__(self, exc_type, exc_val, exc_tb):
62
- pass
63
-
64
- def register_vgpu_device():
65
- """Register vGPU as a custom device type using privateuse1 backend"""
66
  try:
67
- device_name = "vgpu"
68
-
69
- # Register device using privateuse1 backend
70
- torch._C._dispatch._rename_privateuse1_backend(device_name)
71
 
72
- def init_vgpu_mode(vram):
73
- # Create device mode with the registered device name
74
- mode = VGPUMode(vram, device_name)
75
- torch.set_mode(mode)
76
- return mode, torch.device(f"{device_name}:0")
77
-
78
- return init_vgpu_mode
79
 
80
  except Exception as e:
81
- logging.error(f"vGPU device registration failed: {str(e)}")
82
  raise
83
 
84
- # Register vGPU device
85
- register_vgpu_device()
86
-
87
  # Configure logging
88
  logging.basicConfig(
89
  level=logging.INFO,
@@ -149,10 +101,9 @@ def test_ai_integration_http():
149
  initial_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
150
  vram = VirtualVRAM(size_gb=None, storage=storage)
151
 
152
- # Initialize vGPU mode and register device
153
- init_vgpu_mode = register_vgpu_device()
154
- vgpu_mode, vgpu_device = init_vgpu_mode(vram)
155
- logger.info(f"vGPU mode initialized with device {vgpu_device}")
156
 
157
  # Load Florence model and processor
158
  model_name = "microsoft/florence-2-large"
 
12
  import torch
13
  from torch import nn
14
  import torch.nn.functional as F
15
+ from torch.utils._python_dispatch import TorchFunctionMode
16
  from PIL import Image
17
  from transformers import (
18
  AutoTokenizer,
 
23
  from http_storage import HTTPGPUStorage
24
  from torch_vgpu import VGPUDevice, to_vgpu
25
 
26
+ def setup_vgpu():
27
+ """Setup vGPU device and mode"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  try:
29
+ # Create device and get its mode
30
+ device = VGPUDevice()
31
+ mode = device.mode()
 
32
 
33
+ return mode, device
 
 
 
 
 
 
34
 
35
  except Exception as e:
36
+ logging.error(f"vGPU setup failed: {str(e)}")
37
  raise
38
 
 
 
 
39
  # Configure logging
40
  logging.basicConfig(
41
  level=logging.INFO,
 
101
  initial_mem = storage.get_used_memory() if hasattr(storage, 'get_used_memory') else 0
102
  vram = VirtualVRAM(size_gb=None, storage=storage)
103
 
104
+ # Initialize vGPU device and mode
105
+ vgpu_mode, vgpu_device = setup_vgpu()
106
+ logger.info(f"vGPU initialized with device {vgpu_device}")
 
107
 
108
  # Load Florence model and processor
109
  model_name = "microsoft/florence-2-large"