Spaces:
Runtime error
Runtime error
Factor Studios
commited on
Update test_ai_integration_http.py
Browse files- test_ai_integration_http.py +9 -15
test_ai_integration_http.py
CHANGED
|
@@ -12,12 +12,11 @@ from typing import Any, Optional
|
|
| 12 |
import torch
|
| 13 |
from torch import nn
|
| 14 |
import torch.nn.functional as F
|
| 15 |
-
from torch.
|
| 16 |
from PIL import Image
|
| 17 |
from transformers import (
|
| 18 |
-
AutoTokenizer,
|
| 19 |
-
AutoModelForCausalLM,
|
| 20 |
AutoProcessor,
|
|
|
|
| 21 |
AutoConfig
|
| 22 |
)
|
| 23 |
from virtual_vram import VirtualVRAM
|
|
@@ -114,26 +113,21 @@ def test_ai_integration_http():
|
|
| 114 |
logger.info(f"Loading {model_name}")
|
| 115 |
|
| 116 |
try:
|
| 117 |
-
# Load processor
|
| 118 |
processor = AutoProcessor.from_pretrained(
|
| 119 |
model_name,
|
| 120 |
-
trust_remote_code=True
|
| 121 |
-
return_tensors="pt"
|
| 122 |
)
|
| 123 |
status['processor_loaded'] = True
|
| 124 |
|
| 125 |
-
#
|
| 126 |
-
|
| 127 |
-
config = AutoConfig.from_pretrained(
|
| 128 |
-
model_name,
|
| 129 |
-
trust_remote_code=True,
|
| 130 |
-
torch_dtype=torch.float32 # Use float32 for better compatibility
|
| 131 |
-
)
|
| 132 |
|
| 133 |
-
model
|
|
|
|
| 134 |
model_name,
|
| 135 |
-
config=config,
|
| 136 |
trust_remote_code=True,
|
|
|
|
| 137 |
device_map=None # Don't auto-map devices
|
| 138 |
)
|
| 139 |
status['model_loaded'] = True
|
|
|
|
| 12 |
import torch
|
| 13 |
from torch import nn
|
| 14 |
import torch.nn.functional as F
|
| 15 |
+
from torch.utils._python_dispatch import TorchFunctionMode
|
| 16 |
from PIL import Image
|
| 17 |
from transformers import (
|
|
|
|
|
|
|
| 18 |
AutoProcessor,
|
| 19 |
+
AutoModel,
|
| 20 |
AutoConfig
|
| 21 |
)
|
| 22 |
from virtual_vram import VirtualVRAM
|
|
|
|
| 113 |
logger.info(f"Loading {model_name}")
|
| 114 |
|
| 115 |
try:
|
| 116 |
+
# Load processor first
|
| 117 |
processor = AutoProcessor.from_pretrained(
|
| 118 |
model_name,
|
| 119 |
+
trust_remote_code=True
|
|
|
|
| 120 |
)
|
| 121 |
status['processor_loaded'] = True
|
| 122 |
|
| 123 |
+
# Get config and model class from the processor
|
| 124 |
+
config = processor.config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
|
| 126 |
+
# Load the model as a general vision model
|
| 127 |
+
model = AutoModel.from_pretrained(
|
| 128 |
model_name,
|
|
|
|
| 129 |
trust_remote_code=True,
|
| 130 |
+
torch_dtype=torch.float32, # Use float32 for better compatibility
|
| 131 |
device_map=None # Don't auto-map devices
|
| 132 |
)
|
| 133 |
status['model_loaded'] = True
|