# Source: Hugging Face Space upload by daniel-crawford-dunedain
# Commit message: "Initial Space upload" (commit 40e9b28, verified)
# backend/utils/terrain_analyzer/road_detection_model/test.py
import json
import os
import sys
import time
from datetime import datetime
# Add the current directory to Python path to import local modules
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from dotenv import load_dotenv
# Populate os.environ from a local .env file (python-dotenv) before reading config.
load_dotenv()

# Hugging Face token used to authenticate against the Space; None when unset.
HF_TOKEN = os.getenv("HUGGING_FACE_INFERENCE_TOKEN")
def test_gradio_client():
    """Run one prediction against the deployed road-detection Gradio Space.

    The example image next to this script (``examples/example.png``) is
    base64-encoded and wrapped in a JSON payload ``{"image_b64": ...}``. The
    payload is sent to the Space's ``/predict`` endpoint. The response is
    parsed as JSON. If it carries a ``mask_base64`` entry, the decoded PNG is
    written to ``test_results/remote_pred_mask_<timestamp>.png``. Timing and
    the raw response are then recorded via ``save_remote_test_results``.

    Returns:
        The raw ``client.predict`` result on success. ``None`` if
        ``gradio_client`` is not installed, the example image is missing,
        the remote call or parsing fails, or the response has no mask.
    """
    print("\n🌐 Testing via Gradio Client")
    print("=" * 50)

    # Import on its own so that only a genuinely missing package is reported
    # as "not installed"; other errors surface as remote failures below.
    try:
        from gradio_client import Client
    except ImportError:
        print("❌ gradio_client not installed. Install with: pip install gradio-client")
        return None

    import base64

    space_url = "https://dunedain-ai-road-detection-model.hf.space"
    # Resolve paths from this file's location rather than the current working
    # directory, so the script works no matter where it is launched from.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    example_path = os.path.join(base_dir, "examples", "example.png")
    output_dir = os.path.join(base_dir, "test_results")

    try:
        # Validate the local input before opening a network connection.
        if not os.path.exists(example_path):
            print(f"❌ Example file not found: {example_path}")
            return None

        print(f"🔗 Connecting to: {space_url}")
        client = Client(space_url, hf_token=HF_TOKEN)

        # The Space expects a JSON string whose "image_b64" key holds the
        # base64-encoded image file bytes.
        with open(example_path, "rb") as f:
            image_b64 = base64.b64encode(f.read()).decode("utf-8")
        payload = json.dumps({"image_b64": image_b64})

        print("📤 Uploading image and running prediction...")
        # Time only the remote call, not local file reading/encoding.
        start_time = time.perf_counter()
        result = client.predict(payload, api_name="/predict")
        inference_time = time.perf_counter() - start_time
        print(f"⏱️ Remote inference time: {inference_time:.3f} seconds")

        # The endpoint is expected to return a JSON string with a
        # "mask_base64" key (base64-encoded PNG mask). An already-decoded
        # object is accepted too, in case the client deserializes it itself.
        if isinstance(result, (str, bytes, bytearray)):
            parsed = json.loads(result)
        else:
            parsed = result
        mask_b64 = None
        if isinstance(parsed, dict):
            print(f"📋 Response keys: {list(parsed)}")
            mask_b64 = parsed.get("mask_base64")

        if mask_b64:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            os.makedirs(output_dir, exist_ok=True)
            mask_path = os.path.join(
                output_dir, f"remote_pred_mask_{timestamp}.png"
            )
            with open(mask_path, "wb") as f:
                f.write(base64.b64decode(mask_b64))
            print(f"🖼️ Remote predicted mask saved: {mask_path}")
        else:
            print("❌ No mask_base64 found in result.")

        # Record timing and the raw response even when no mask came back,
        # so failed runs can still be inspected afterwards.
        save_remote_test_results(result, inference_time, space_url)

        # A response without a mask is a failed test, not a pass.
        return result if mask_b64 else None
    except Exception as e:  # test boundary: report the failure, don't crash
        print(f"❌ Remote test failed: {type(e).__name__}: {e}")
        return None
def save_remote_test_results(result, inference_time, space_url):
    """Write a JSON summary of one remote (Gradio) test run.

    The summary is written to ``test_results/remote_test_<timestamp>.json``
    in the directory containing this script.

    Args:
        result: Raw value returned by the Space. It is stored via ``str()``,
            so any type can be recorded.
        inference_time: Measured call duration in seconds. It is rounded to
            3 decimals in the output.
        space_url: URL of the Space that was queried.

    Returns:
        The path of the JSON file that was written.
    """
    # Anchor the output directory to this file instead of the current working
    # directory, so results land in the same place wherever the script runs.
    output_dir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "test_results"
    )
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    remote_stats = {
        "timestamp": timestamp,
        "test_type": "remote_gradio",
        "space_url": space_url,
        "inference_time_seconds": round(inference_time, 3),
        "result": str(result),
    }
    remote_path = os.path.join(output_dir, f"remote_test_{timestamp}.json")
    with open(remote_path, "w", encoding="utf-8") as f:
        json.dump(remote_stats, f, indent=2)
    print(f"🌐 Remote test results saved: {remote_path}")
    return remote_path
def main():
    """Run the road-detection test suite and print a pass/fail summary.

    The suite currently consists of a single remote test against the deployed
    Gradio Space (``test_gradio_client``).

    Returns:
        ``True`` if ``test_gradio_client`` returned a truthy result, otherwise
        ``False``. Callers that ignore the return value are unaffected.
    """
    print("🚀 Road Detection Model Test Suite")
    print("=" * 60)

    # Remote Gradio client test
    print("\n" + "=" * 60)
    remote_result = test_gradio_client()

    # Summary
    print("\n" + "=" * 60)
    print("🎯 Test Summary")
    print("=" * 60)
    passed = bool(remote_result)
    if passed:
        print("✅ Remote test: PASSED")
    else:
        print("❌ Remote test: FAILED")
    print("🏁 Test suite completed!")
    return passed
if __name__ == "__main__":
    # Run the suite only when executed as a script, not when imported.
    main()