Factor Studios committed on
Commit
1a63d3c
·
verified ·
1 Parent(s): 6da4ca2

Update test_ai_integration_http.py

Browse files
Files changed (1) hide show
  1. test_ai_integration_http.py +9 -15
test_ai_integration_http.py CHANGED
@@ -12,12 +12,11 @@ from typing import Any, Optional
12
  import torch
13
  from torch import nn
14
  import torch.nn.functional as F
15
- from torch.overrides import TorchFunctionMode
16
  from PIL import Image
17
  from transformers import (
18
- AutoTokenizer,
19
- AutoModelForCausalLM,
20
  AutoProcessor,
 
21
  AutoConfig
22
  )
23
  from virtual_vram import VirtualVRAM
@@ -114,26 +113,21 @@ def test_ai_integration_http():
114
  logger.info(f"Loading {model_name}")
115
 
116
  try:
117
- # Load processor with direct configuration
118
  processor = AutoProcessor.from_pretrained(
119
  model_name,
120
- trust_remote_code=True,
121
- return_tensors="pt"
122
  )
123
  status['processor_loaded'] = True
124
 
125
- # Load model with vision config
126
- from transformers import AutoConfig
127
- config = AutoConfig.from_pretrained(
128
- model_name,
129
- trust_remote_code=True,
130
- torch_dtype=torch.float32 # Use float32 for better compatibility
131
- )
132
 
133
- model = AutoModelForCausalLM.from_pretrained(
 
134
  model_name,
135
- config=config,
136
  trust_remote_code=True,
 
137
  device_map=None # Don't auto-map devices
138
  )
139
  status['model_loaded'] = True
 
12
  import torch
13
  from torch import nn
14
  import torch.nn.functional as F
15
+ from torch.utils._python_dispatch import TorchFunctionMode
16
  from PIL import Image
17
  from transformers import (
 
 
18
  AutoProcessor,
19
+ AutoModel,
20
  AutoConfig
21
  )
22
  from virtual_vram import VirtualVRAM
 
113
  logger.info(f"Loading {model_name}")
114
 
115
  try:
116
+ # Load processor first
117
  processor = AutoProcessor.from_pretrained(
118
  model_name,
119
+ trust_remote_code=True
 
120
  )
121
  status['processor_loaded'] = True
122
 
123
+ # Get config and model class from the processor
124
+ config = processor.config
 
 
 
 
 
125
 
126
+ # Load the model as a general vision model
127
+ model = AutoModel.from_pretrained(
128
  model_name,
 
129
  trust_remote_code=True,
130
+ torch_dtype=torch.float32, # Use float32 for better compatibility
131
  device_map=None # Don't auto-map devices
132
  )
133
  status['model_loaded'] = True