Factor Studios commited on
Commit
373ab21
·
verified ·
1 Parent(s): 5e36ec1

Update test_ai_integration.py

Browse files
Files changed (1) hide show
  1. test_ai_integration.py +63 -13
test_ai_integration.py CHANGED
@@ -111,19 +111,69 @@ def test_ai_integration():
111
  model_id = "microsoft/florence-2-large"
112
  print(f"Loading model {model_id} directly to WebSocket storage...")
113
 
114
- # Load model and processor directly to WebSocket storage
115
- model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
116
- processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
117
-
118
- # Store model in WebSocket storage without CPU intermediary
119
- ai_accelerator_for_loading.load_model(model_id, model, processor)
120
- print(f"Model '{model_id}' loaded successfully to WebSocket storage.")
121
- assert ai_accelerator_for_loading.has_model(model_id), "Model not found in WebSocket storage after loading."
122
-
123
- # Clear any CPU-side model data
124
- model = None
125
- import gc
126
- gc.collect()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  except Exception as e:
129
  print(f"Model loading test failed: {e}")
 
111
  model_id = "microsoft/florence-2-large"
112
  print(f"Loading model {model_id} directly to WebSocket storage...")
113
 
114
+ try:
115
+ # Load model and processor with proper error handling
116
+ model = AutoModelForCausalLM.from_pretrained(
117
+ model_id,
118
+ trust_remote_code=True,
119
+ device_map="auto", # Allow automatic device mapping
120
+ torch_dtype="auto" # Use appropriate dtype
121
+ )
122
+
123
+ processor = AutoProcessor.from_pretrained(
124
+ model_id,
125
+ trust_remote_code=True
126
+ )
127
+
128
+ # Calculate model size for proper VRAM allocation
129
+ model_size = sum(p.numel() * p.element_size() for p in model.parameters())
130
+ print(f"Model size: {model_size / (1024**3):.2f} GB")
131
+
132
+ # Store model in WebSocket storage with size information
133
+ ai_accelerator_for_loading.load_model(
134
+ model_id=model_id,
135
+ model=model,
136
+ processor=processor,
137
+ model_config={
138
+ "size_bytes": model_size,
139
+ "unlimited_vram": True,
140
+ "allow_resize": True
141
+ }
142
+ )
143
+
144
+ print(f"Model '{model_id}' loaded successfully to WebSocket storage.")
145
+ assert ai_accelerator_for_loading.has_model(model_id), "Model not found in WebSocket storage after loading."
146
+
147
+ # Store model parameters in components dict
148
+ components['model_id'] = model_id
149
+ components['model_size'] = model_size
150
+
151
+ # Clear any CPU-side model data
152
+ model = None
153
+ processor = None
154
+ import gc
155
+ gc.collect()
156
+
157
+ except Exception as e:
158
+ print(f"Detailed model loading error: {str(e)}")
159
+ print("Falling back to zero-copy tensor mode...")
160
+ # Try loading with zero-copy tensor mode
161
+ try:
162
+ ai_accelerator_for_loading.load_model(
163
+ model_id=model_id,
164
+ model=None, # Use zero-copy mode
165
+ processor=None,
166
+ model_config={
167
+ "zero_copy": True,
168
+ "unlimited_vram": True,
169
+ "allow_resize": True
170
+ }
171
+ )
172
+ components['model_id'] = model_id
173
+ print("Successfully loaded model in zero-copy mode")
174
+ except Exception as e2:
175
+ print(f"Zero-copy fallback also failed: {str(e2)}")
176
+ raise
177
 
178
  except Exception as e:
179
  print(f"Model loading test failed: {e}")