Spaces:
Sleeping
Sleeping
Factor Studios
committed on
Upload test_ai_integration.py
Browse files- test_ai_integration.py +46 -15
test_ai_integration.py
CHANGED
|
@@ -184,15 +184,28 @@ def test_ai_integration():
|
|
| 184 |
model_size = sum(p.numel() * p.element_size() for p in model.parameters())
|
| 185 |
print(f"Model size: {model_size / (1024**3):.2f} GB")
|
| 186 |
|
| 187 |
-
#
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
assert ai_accelerator_for_loading.has_model(model_id), "Model not found in WebSocket storage after loading."
|
| 197 |
|
| 198 |
# Store model parameters in components dict
|
|
@@ -371,13 +384,31 @@ def test_ai_integration():
|
|
| 371 |
# Load image section from WebSocket storage
|
| 372 |
tensor_id = f"input_image/{img_name}"
|
| 373 |
|
| 374 |
-
#
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
|
| 382 |
elapsed = time.time() - start_time
|
| 383 |
|
|
|
|
| 184 |
model_size = sum(p.numel() * p.element_size() for p in model.parameters())
|
| 185 |
print(f"Model size: {model_size / (1024**3):.2f} GB")
|
| 186 |
|
| 187 |
+
# Upload model weights directly to WebSocket storage
|
| 188 |
+
print("Uploading model weights to WebSocket storage...")
|
| 189 |
+
for name, param in model.state_dict().items():
|
| 190 |
+
# Convert tensor to numpy and upload
|
| 191 |
+
weight_data = param.cpu().numpy()
|
| 192 |
+
storage.store_tensor(f"model_weights/{model_id}/{name}", weight_data)
|
| 193 |
+
|
| 194 |
+
# Store minimal model info without serializing the config
|
| 195 |
+
storage.store_state(f"models/{model_id}", "info", {
|
| 196 |
+
"name": model_id,
|
| 197 |
+
"size_bytes": model_size,
|
| 198 |
+
"num_parameters": sum(p.numel() for p in model.parameters()),
|
| 199 |
+
"weight_keys": list(model.state_dict().keys())
|
| 200 |
+
})
|
| 201 |
|
| 202 |
+
# Set model reference without serializing the full model
|
| 203 |
+
ai_accelerator_for_loading.model_refs[model_id] = {
|
| 204 |
+
"weight_prefix": f"model_weights/{model_id}",
|
| 205 |
+
"size": model_size
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
print(f"Model weights uploaded successfully to WebSocket storage")
|
| 209 |
assert ai_accelerator_for_loading.has_model(model_id), "Model not found in WebSocket storage after loading."
|
| 210 |
|
| 211 |
# Store model parameters in components dict
|
|
|
|
| 384 |
# Load image section from WebSocket storage
|
| 385 |
tensor_id = f"input_image/{img_name}"
|
| 386 |
|
| 387 |
+
# Load weights from WebSocket storage and run inference
|
| 388 |
+
try:
|
| 389 |
+
# Get model info
|
| 390 |
+
model_info = accelerator.storage.load_state(f"models/{model_id}", "info")
|
| 391 |
+
weight_prefix = f"model_weights/{model_id}"
|
| 392 |
+
|
| 393 |
+
# Load input tensor
|
| 394 |
+
input_tensor = accelerator.storage.load_tensor(tensor_id)
|
| 395 |
+
|
| 396 |
+
# Run inference with direct weight access
|
| 397 |
+
result = accelerator.inference_with_ws_weights(
|
| 398 |
+
model_id=model_id,
|
| 399 |
+
input_tensor=input_tensor,
|
| 400 |
+
weight_prefix=weight_prefix
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
# Store result in WebSocket storage
|
| 404 |
+
if result is not None:
|
| 405 |
+
storage.store_tensor(f"results/chip_{i}/{img_name}", result)
|
| 406 |
+
results.append(result)
|
| 407 |
+
else:
|
| 408 |
+
logging.error(f"Inference returned None for chip {i}")
|
| 409 |
+
except Exception as e:
|
| 410 |
+
logging.error(f"Inference failed on chip {i}: {str(e)}")
|
| 411 |
+
raise
|
| 412 |
|
| 413 |
elapsed = time.time() - start_time
|
| 414 |
|