Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3
"""
Test script for Wall Color Visualizer API

Interactive smoke test: always runs a health check against BASE_URL, then
prompts for a local image path and, if one is given and exists, runs the
segmentation and colour-application tests on it before printing a
pass/fail summary of every test that ran.
"""
import base64
import json
import os
from pathlib import Path

import requests
# Configuration
# Base URL of the API under test. Defaults to a local server; set the
# WALL_COLOR_API_URL environment variable to target another deployment.
# A trailing slash is stripped because every request in this script builds
# its URL as f"{BASE_URL}/<endpoint>".
BASE_URL = os.environ.get("WALL_COLOR_API_URL", "http://localhost:8000").rstrip("/")
def test_health(*, timeout=60):
    """Check the API's ``/health`` endpoint.

    Prints the HTTP status and the pretty-printed JSON response body.

    Args:
        timeout: Seconds to wait for the server. Without a timeout,
            ``requests`` waits indefinitely, so a server that accepts the
            connection but never answers would hang the whole suite.

    Returns:
        True if the endpoint answered with HTTP 200. False on any other
        status or on any error (connection failure, timeout, non-JSON body).
    """
    print("Testing health endpoint...")
    try:
        response = requests.get(f"{BASE_URL}/health", timeout=timeout)
        print(f"Status: {response.status_code}")
        print(f"Response: {json.dumps(response.json(), indent=2)}")
        return response.status_code == 200
    except Exception as e:
        # Deliberately broad: this is a reporting script, so any failure is
        # shown and counted as a failed test instead of aborting the run.
        print(f"Error: {e}")
        return False
def test_simple_segment(image_path):
    """Exercise the ``/simple-segment`` endpoint with a local image.

    Uploads *image_path* as the multipart field ``file``. Prints the status,
    the ``success`` flag, the mask count and the segmentation method.

    Args:
        image_path: Path to the image file to upload.

    Returns:
        True only if the server answered HTTP 200 *and* its ``success`` flag
        is truthy. False otherwise, including a missing file or any request
        or parsing error.
    """
    print(f"\nTesting simple segmentation with {image_path}...")
    # is_file() rather than exists(): a directory would pass exists() and
    # then fail with a confusing error inside open().
    if not Path(image_path).is_file():
        print(f"Error: Image file not found: {image_path}")
        return False
    try:
        # The upload must happen while the file handle is still open.
        with open(image_path, 'rb') as f:
            response = requests.post(
                f"{BASE_URL}/simple-segment",
                files={'file': f},
                timeout=60,
            )
        print(f"Status: {response.status_code}")
        if response.status_code != 200:
            print(f"Error: {response.text}")
            return False
        data = response.json()
        print(f"Success: {data['success']}")
        print(f"Number of masks: {data['num_masks']}")
        print(f"Method: {data.get('method', 'N/A')}")
        # HTTP 200 alone is not enough: honour the API's own success flag so
        # a reported failure is not counted as a passing test.
        return bool(data['success'])
    except Exception as e:
        print(f"Error: {e}")
        return False
def test_segment_automatic(image_path):
    """Exercise the ``/segment-automatic`` endpoint (requires SAM).

    Uploads *image_path* as the multipart field ``file``. Prints the status,
    the ``success`` flag and the mask count.

    Args:
        image_path: Path to the image file to upload.

    Returns:
        True only if the server answered HTTP 200 *and* its ``success`` flag
        is truthy. False otherwise, including a missing file or any request
        or parsing error.
    """
    print(f"\nTesting automatic segmentation with {image_path}...")
    # is_file() rather than exists(): a directory would pass exists() and
    # then fail with a confusing error inside open().
    if not Path(image_path).is_file():
        print(f"Error: Image file not found: {image_path}")
        return False
    try:
        # The upload must happen while the file handle is still open.
        with open(image_path, 'rb') as f:
            response = requests.post(
                f"{BASE_URL}/segment-automatic",
                files={'file': f},
                timeout=60,
            )
        print(f"Status: {response.status_code}")
        if response.status_code != 200:
            print(f"Error: {response.text}")
            return False
        data = response.json()
        print(f"Success: {data['success']}")
        print(f"Number of masks: {data['num_masks']}")
        # HTTP 200 alone is not enough: honour the API's own success flag so
        # a reported failure is not counted as a passing test.
        return bool(data['success'])
    except Exception as e:
        print(f"Error: {e}")
        return False
def test_apply_color(image_path, *, color_hex='#FF5733', opacity=0.8,
                     output_path='result_colored.png'):
    """Exercise ``/apply-color`` using a mask obtained from ``/simple-segment``.

    First segments *image_path* via ``/simple-segment``. It then asks
    ``/apply-color`` to paint the first returned mask. If the response carries
    ``result_base64``, the decoded bytes are written to *output_path*.

    Args:
        image_path: Path to the image file to upload.
        color_hex: Colour sent as ``color_hex`` (default is an orange-red).
        opacity: Value sent as ``opacity``.
        output_path: File the decoded result image is written to.

    Returns:
        True if colouring succeeded (HTTP 200 and a truthy ``success`` flag).
        False on a missing file, failed segmentation, no masks, or any error.
    """
    print("\nTesting color application...")
    if not Path(image_path).is_file():
        print(f"Error: Image file not found: {image_path}")
        return False
    try:
        # Step 1: segment the image to obtain a mask to paint.
        with open(image_path, 'rb') as f:
            seg_response = requests.post(
                f"{BASE_URL}/simple-segment",
                files={'file': f},
                timeout=60,
            )
        if seg_response.status_code != 200:
            print("Failed to get segmentation")
            print(f"Error: {seg_response.text}")
            return False
        seg_data = seg_response.json()
        # .get() so a response without 'masks' is reported clearly instead of
        # surfacing as a bare KeyError message.
        masks = seg_data.get('masks')
        if not masks:
            print("No masks found")
            return False

        # Step 2: apply the colour to the first mask.
        color_request = {
            'image_base64': seg_data['image_base64'],
            'mask_base64': masks[0]['mask_base64'],
            'color_hex': color_hex,
            'opacity': opacity,
        }
        response = requests.post(
            f"{BASE_URL}/apply-color",
            json=color_request,
            timeout=60,
        )
        print(f"Status: {response.status_code}")
        if response.status_code != 200:
            print(f"Error: {response.text}")
            return False
        data = response.json()
        print(f"Success: {data['success']}")
        if not data['success']:
            # Do not claim success or save output when the API reports failure.
            return False
        print("Color applied successfully!")
        if data.get('result_base64'):
            Path(output_path).write_bytes(base64.b64decode(data['result_base64']))
            print(f"Result saved to: {output_path}")
        return True
    except Exception as e:
        print(f"Error: {e}")
        return False
def main():
    """Run the suite interactively and print a pass/fail summary.

    Always runs the health check. Then prompts for an image path and, if one
    is given and exists, runs the segmentation and colour tests on it.

    Returns:
        Process exit status: 0 if every test that ran passed, 1 otherwise.
    """
    print("=" * 60)
    print("Wall Color Visualizer API Test Suite")
    print("=" * 60)
    results = {}

    # Test 1: Health check
    results['health'] = test_health()

    # Ask for test image
    print("\n" + "=" * 60)
    try:
        raw_path = input("Enter path to test image (or press Enter to skip): ")
    except EOFError:
        # Non-interactive stdin (CI, piped/closed input): treat as "skip"
        # instead of crashing before the summary is printed.
        print()
        raw_path = ""
    # Terminals often wrap drag-and-dropped paths in quotes; remove them.
    image_path = raw_path.strip().strip('"\'')

    if image_path and Path(image_path).exists():
        # Test 2: Simple segmentation
        results['simple_segment'] = test_simple_segment(image_path)
        # Test 3: Automatic segmentation (SAM)
        results['auto_segment'] = test_segment_automatic(image_path)
        # Test 4: Color application
        results['apply_color'] = test_apply_color(image_path)
    else:
        if image_path:
            # Distinguish a mistyped path from a deliberate skip.
            print(f"Image not found: {image_path}")
        print("Skipping image-based tests...")

    # Summary
    print("\n" + "=" * 60)
    print("Test Results Summary")
    print("=" * 60)
    for test_name, passed in results.items():
        status = "✓ PASSED" if passed else "✗ FAILED"
        print(f"{test_name:20} : {status}")
    total = len(results)
    passed = sum(bool(ok) for ok in results.values())
    print(f"\nTotal: {passed}/{total} tests passed")
    print("=" * 60)
    return 0 if passed == total else 1
if __name__ == "__main__":
    # Use main()'s return value as the process exit status so scripts/CI can
    # detect failures; SystemExit(None) still exits with status 0.
    raise SystemExit(main())