Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
| # backend/utils/terrain_analyzer/road_detection_model/test.py | |
| import json | |
| import os | |
| import sys | |
| import time | |
| from datetime import datetime | |
# Put this file's own directory on sys.path so that sibling modules can be
# imported when the script is launched directly.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from dotenv import load_dotenv

# Pull variables from a .env file (when one is found) before reading the token.
load_dotenv()

# Token passed to the Gradio client; None when the variable is not set.
HF_TOKEN = os.getenv("HUGGING_FACE_INFERENCE_TOKEN")
def test_gradio_client():
    """Run one prediction against the deployed Gradio Space and record it.

    Reads the example image, sends it base64-encoded inside a JSON payload to
    the Space's ``/predict`` endpoint, writes the returned mask (when present)
    as a PNG into the test-results directory and stores a run summary through
    ``save_remote_test_results``. All file paths are relative to the current
    working directory.

    Returns:
        The raw value returned by ``client.predict`` when the response
        contains a ``mask_base64`` entry. ``None`` when ``gradio_client`` is
        not installed, the example image is missing, the response carries no
        mask, or any other error is raised along the way.
    """
    print("\nπ Testing via Gradio Client")
    print("=" * 50)
    try:
        import base64

        from gradio_client import Client

        space_url = "https://dunedain-ai-road-detection-model.hf.space"
        print(f"π Connecting to: {space_url}")
        client = Client(space_url, hf_token=HF_TOKEN)

        # Test with example file
        example_path = (
            "utils/terrain_analyzer/road_detection_model/examples/example.png"
        )
        if not os.path.exists(example_path):
            print(f"β Example file not found: {example_path}")
            return None

        print("π€ Uploading image and running prediction...")
        # The timer deliberately includes reading and encoding the image,
        # matching how the elapsed time has always been measured here.
        start_time = time.time()
        with open(example_path, "rb") as f:
            image_b64 = base64.b64encode(f.read()).decode("utf-8")
        # The endpoint takes a single JSON string carrying the encoded image.
        payload = json.dumps({"image_b64": image_b64})
        result = client.predict(
            payload,
            api_name="/predict",  # API endpoint
        )
        inference_time = time.time() - start_time
        print(f"β±οΈ Remote inference time: {inference_time:.3f} seconds")

        # The result is expected to be a JSON string with a "mask_base64" key
        # holding the base64-encoded PNG mask. Accept an already-decoded
        # mapping as well so json.loads is not handed a dict.
        if isinstance(result, (str, bytes, bytearray)):
            result_json = json.loads(result)
        else:
            result_json = result
        print(result_json.keys())
        mask_b64 = result_json.get("mask_base64")

        if mask_b64:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_dir = "utils/terrain_analyzer/road_detection_model/test_results"
            os.makedirs(output_dir, exist_ok=True)
            mask_path = os.path.join(output_dir, f"remote_pred_mask_{timestamp}.png")
            with open(mask_path, "wb") as f:
                f.write(base64.b64decode(mask_b64))
            print(f"πΌοΈ Remote predicted mask saved: {mask_path}")
        else:
            print("β No mask_base64 found in result.")

        # Save remote test results (also for mask-less responses, so the raw
        # reply can be inspected afterwards).
        save_remote_test_results(result, inference_time, space_url)

        # A response without a mask is a failed run: do not hand the caller a
        # truthy value that would be reported as a pass.
        return result if mask_b64 else None
    except ImportError:
        print("β gradio_client not installed. Install with: pip install gradio-client")
        return None
    except Exception as e:
        # Top-level boundary of the remote check: report the error and signal
        # failure to the caller instead of aborting the whole script.
        print(f"β Remote test failed: {e}")
        return None
def save_remote_test_results(result, inference_time, space_url):
    """Write a JSON summary of one remote run to the test-results directory.

    The file is named ``remote_test_<timestamp>.json`` and is created under
    ``utils/terrain_analyzer/road_detection_model/test_results`` relative to
    the current working directory; the directory is created when absent.

    Args:
        result: Value returned by the remote call; stored as ``str(result)``.
        inference_time: Elapsed seconds; rounded to three decimals on save.
        space_url: URL of the Space that served the request.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = "utils/terrain_analyzer/road_detection_model/test_results"
    os.makedirs(results_dir, exist_ok=True)

    # Summary of the remote run, keyed exactly as written to disk.
    report = dict(
        timestamp=stamp,
        test_type="remote_gradio",
        space_url=space_url,
        inference_time_seconds=round(inference_time, 3),
        result=str(result),
    )
    report_path = os.path.join(results_dir, f"remote_test_{stamp}.json")
    with open(report_path, "w") as out:
        out.write(json.dumps(report, indent=2))
    print(f"π Remote test results saved: {report_path}")
def main():
    """Drive the remote Gradio check and print a pass/fail summary."""
    heavy_rule = "=" * 60

    print("π Road Detection Model Test Suite")
    print(heavy_rule)

    # Remote Gradio client check
    print("\n" + heavy_rule)
    outcome = test_gradio_client()

    # Summary: any truthy outcome counts as a pass
    print("\n" + heavy_rule)
    print("π― Test Summary")
    print(heavy_rule)
    verdict = "β Remote test: PASSED" if outcome else "β Remote test: FAILED"
    print(verdict)
    print("π Test suite completed!")


if __name__ == "__main__":
    main()