|
|
|
|
|
""" |
|
|
Comprehensive inference test for SAM3 endpoint |
|
|
Tests multiple images and saves detailed results with visualizations |
|
|
""" |
|
|
|
|
|
import requests |
|
|
import base64 |
|
|
import json |
|
|
from PIL import Image, ImageDraw, ImageFont |
|
|
import io |
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
from datetime import datetime |
|
|
import sys |
|
|
|
|
|
|
|
|
# Hugging Face Inference Endpoint that serves the SAM3 model under test.
ENDPOINT_URL = (
    "https://p6irm2x7y9mwp4l4.us-east-1.aws.endpoints.huggingface.cloud"
)

# Text prompts sent to the endpoint as the ``parameters.classes`` field.
CLASSES = ["Pothole", "Road crack", "Road"]

# Where input images are read from and where per-image results are written.
TEST_IMAGES_DIR = Path("assets/test_images")
OUTPUT_DIR = Path(".cache/test/inference")

# Overlay colour per class as an (R, G, B, A) tuple; every class shares the
# same alpha component.
_OVERLAY_ALPHA = 128
COLORS = {
    label: (*rgb, _OVERLAY_ALPHA)
    for label, rgb in (
        ("Pothole", (255, 0, 0)),  # red
        ("Road crack", (255, 255, 0)),  # yellow
        ("Road", (0, 0, 255)),  # blue
    )
}
|
|
|
|
|
def ensure_output_dir(image_name):
    """Return ``OUTPUT_DIR / image_name``, creating it (and parents) if needed.

    Args:
        image_name: Name of the per-image sub-directory.

    Returns:
        The ``Path`` of the (now existing) directory.
    """
    target = OUTPUT_DIR.joinpath(image_name)
    target.mkdir(exist_ok=True, parents=True)
    return target
|
|
|
|
|
def save_request_data(output_path, image_path, classes):
    """Write request metadata to ``output_path / "request.json"`` and return it.

    The record contains the current timestamp (ISO format), the endpoint URL,
    the image path and file name, and the class prompts being requested.

    Args:
        output_path: Directory ``Path`` to write into.
        image_path: ``Path`` of the image being tested.
        classes: List of class prompts sent to the endpoint.

    Returns:
        The metadata dict that was written.
    """
    payload = dict(
        timestamp=datetime.now().isoformat(),
        endpoint=ENDPOINT_URL,
        image_path=str(image_path),
        image_name=image_path.name,
        classes=classes,
    )
    (output_path / "request.json").write_text(json.dumps(payload, indent=2))
    return payload
|
|
|
|
|
def save_response_data(output_path, results, status_code, elapsed_time):
    """Persist the endpoint response in two forms and return the summary.

    ``response.json`` holds a compact summary: timestamp, status code, elapsed
    time, result count, and per-result label/score/decoded mask size (0 when a
    result has no ``mask`` key). ``full_results.json`` holds ``results``
    verbatim, including the base64-encoded masks.

    Args:
        output_path: Directory ``Path`` to write into.
        results: List of result dicts from the endpoint (``label``, ``score``,
            optional base64 ``mask``).
        status_code: HTTP status code of the response.
        elapsed_time: Request duration in seconds.

    Returns:
        The summary dict written to ``response.json``.
    """
    summaries = [
        {
            "label": item["label"],
            "score": item["score"],
            "mask_size_bytes": (
                len(base64.b64decode(item["mask"])) if "mask" in item else 0
            ),
        }
        for item in results
    ]

    response_data = {
        "timestamp": datetime.now().isoformat(),
        "status_code": status_code,
        "elapsed_time_seconds": elapsed_time,
        "results_count": len(results),
        "results": summaries,
    }

    outputs = (
        ("response.json", response_data),
        ("full_results.json", results),
    )
    for filename, payload in outputs:
        with open(output_path / filename, "w") as handle:
            json.dump(payload, handle, indent=2)

    return response_data
|
|
|
|
|
def create_visualization(original_img, results, output_path):
    """Overlay the returned masks on the original image and save the artifacts.

    For each entry in ``results``, this function does the following:

    * decodes the base64 ``mask``;
    * converts it to greyscale (mode ``L``);
    * saves it as ``mask_<label>.png``, with spaces in the label replaced by
      underscores;
    * records coverage statistics;
    * composites it onto a transparent overlay in the label's ``COLORS``
      colour (grey for labels not in ``COLORS``).

    Afterwards ``visualization.png``, ``original.jpg`` and ``legend.png`` (via
    ``create_legend``) are written to ``output_path``.

    Args:
        original_img: PIL image that was sent to the endpoint.
        results: Sequence of dicts with ``label`` and base64 ``mask`` keys.
        output_path: Directory ``Path`` to write into.

    Returns:
        Dict mapping each label to ``coverage_percent`` (percentage of mask
        pixels > 0, rounded to 4 decimals), ``non_zero_pixels`` and
        ``total_pixels``. If several results share a label, the last one's
        stats (and mask file) win.
    """
    # Modes Pillow's JPEG writer accepts; anything else (e.g. RGBA or palette
    # PNGs/BMPs) must be converted before saving original.jpg, or save raises.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    overlay = Image.new("RGBA", original_img.size, (0, 0, 0, 0))
    mask_stats = {}

    for result in results:
        label = result["label"]
        mask_data = base64.b64decode(result["mask"])
        mask_img = Image.open(io.BytesIO(mask_data)).convert("L")

        # Save the mask exactly as returned, before any resizing below.
        mask_img.save(output_path / f"mask_{label.replace(' ', '_')}.png")

        # Coverage is measured on the mask as returned by the endpoint.
        pixels = np.array(mask_img)
        non_zero = int(np.count_nonzero(pixels))
        total = int(pixels.size)
        coverage = non_zero / total * 100 if total else 0.0
        mask_stats[label] = {
            "coverage_percent": round(coverage, 4),
            "non_zero_pixels": non_zero,
            "total_pixels": total,
        }

        color = COLORS.get(label, (128, 128, 128, 128))

        # alpha_composite requires identical sizes, and the returned mask is not
        # guaranteed to match the input resolution. Nearest-neighbour scaling
        # keeps the mask values unblended.
        if mask_img.size != original_img.size:
            mask_img = mask_img.resize(original_img.size, Image.NEAREST)

        # Paint the colour's own alpha wherever the mask is set, using the same
        # "> 0" rule as the coverage stats. Using the raw mask value as alpha
        # would make 0/255 masks fully opaque and ignore the alpha in COLORS.
        alpha_lut = [0] + [color[3]] * 255
        alpha = mask_img.point(alpha_lut)
        colored_mask = Image.new("RGBA", original_img.size, color)
        colored_mask.putalpha(alpha)
        overlay = Image.alpha_composite(overlay, colored_mask)

    result_img = Image.alpha_composite(original_img.convert("RGBA"), overlay)
    result_img.save(output_path / "visualization.png")

    if original_img.mode in jpeg_modes:
        original_for_jpeg = original_img
    else:
        original_for_jpeg = original_img.convert("RGB")
    original_for_jpeg.save(output_path / "original.jpg")

    create_legend(output_path, mask_stats)

    return mask_stats
|
|
|
|
|
def create_legend(output_path, mask_stats):
    """Render ``legend.png``: one colour swatch and coverage line per class.

    Every class in ``COLORS`` gets a row, even if it has no entry in
    ``mask_stats``; such classes are shown with 0.00% coverage.

    Args:
        output_path: Directory ``Path`` to write into.
        mask_stats: Mapping of label -> dict containing ``coverage_percent``.
    """
    header_height = 40
    row_height = 60
    canvas = Image.new(
        "RGB", (500, header_height + row_height * len(COLORS)), "white"
    )
    pen = ImageDraw.Draw(canvas)

    pen.text([10, 10], "Segmentation Results", fill="black")

    for row, (label, rgba) in enumerate(COLORS.items()):
        top = header_height + row * row_height
        # Swatch uses only the RGB part; the legend is an opaque RGB image.
        pen.rectangle([10, top, 40, top + 30], fill=rgba[:3])
        coverage = mask_stats.get(label, {"coverage_percent": 0})["coverage_percent"]
        pen.text([50, top + 5], f"{label}: {coverage:.2f}% coverage", fill="black")

    canvas.save(output_path / "legend.png")
|
|
|
|
|
def _write_error_file(output_path, error_data):
    """Write ``error_data`` as indented JSON to ``output_path / "error.json"``."""
    with open(output_path / "error.json", "w") as f:
        json.dump(error_data, f, indent=2)


def test_image(image_path):
    """Send one image to the endpoint and save every artifact for it.

    All files go to ``OUTPUT_DIR/<image stem>``:

    * ``request.json`` is written before the call.
    * On HTTP 200: ``response.json``, ``full_results.json``, and the
      masks/visualizations from ``create_visualization``.
    * On a non-200 status or any exception: ``error.json``.

    Reading and decoding the image happen inside the ``try``. A corrupt or
    unreadable file is therefore recorded as a failure for this image instead
    of aborting the whole run.

    NOTE(review): the ``test_`` prefix means pytest will try to collect this
    function if the module is named ``test_*.py``; it needs an ``image_path``
    argument and is not a pytest test.

    Args:
        image_path: ``Path`` of the image to test.

    Returns:
        True if the endpoint returned 200 and all artifacts were written,
        otherwise False.
    """
    import time
    import traceback

    print(f"\n{'=' * 80}")
    print(f"Testing: {image_path.name}")
    print("=" * 80)

    output_path = ensure_output_dir(image_path.stem)

    try:
        image_data = image_path.read_bytes()
        image_b64 = base64.b64encode(image_data).decode()

        original_img = Image.open(io.BytesIO(image_data))
        print(f"Image size: {original_img.size}")
        print(f"Image mode: {original_img.mode}")

        save_request_data(output_path, image_path, CLASSES)

        print("\nCalling endpoint...")
        # perf_counter is monotonic, so the duration is unaffected by clock changes.
        start_time = time.perf_counter()
        response = requests.post(
            ENDPOINT_URL,
            json={"inputs": image_b64, "parameters": {"classes": CLASSES}},
            timeout=120,
        )
        elapsed_time = time.perf_counter() - start_time

        print(f"Response status: {response.status_code}")
        print(f"Response time: {elapsed_time:.2f}s")

        if response.status_code != 200:
            print(f"❌ Error: {response.status_code}")
            print(response.text)
            _write_error_file(output_path, {
                "timestamp": datetime.now().isoformat(),
                "status_code": response.status_code,
                "error": response.text,
                "elapsed_time_seconds": elapsed_time,
            })
            return False

        results = response.json()
        print(f"✅ Got {len(results)} segmentation results")

        save_response_data(output_path, results, response.status_code, elapsed_time)
        mask_stats = create_visualization(original_img, results, output_path)

        print("\nSegmentation Coverage:")
        for label, stats in mask_stats.items():
            print(f" • {label}: {stats['coverage_percent']:.2f}% ({stats['non_zero_pixels']:,} pixels)")

        print(f"\n✅ Results saved to: {output_path}")
        return True

    except Exception as e:
        # Per-image boundary: report and record the failure, then let the
        # caller continue with the next image.
        print(f"❌ Exception: {e}")
        traceback.print_exc()
        _write_error_file(output_path, {
            "timestamp": datetime.now().isoformat(),
            "exception": str(e),
            "traceback": traceback.format_exc(),
        })
        return False
|
|
|
|
|
def main():
    """Run the inference test over every image in ``TEST_IMAGES_DIR``.

    The function discovers images by file extension (case-insensitive) and
    runs ``test_image`` on each one. It then prints a summary and writes
    ``OUTPUT_DIR/summary.json``.

    It exits with status 1 if no images are found or any image fails.
    """
    from collections import Counter

    print("=" * 80)
    print("SAM3 Comprehensive Inference Test")
    print("=" * 80)
    print(f"Endpoint: {ENDPOINT_URL}")
    print(f"Classes: {', '.join(CLASSES)}")
    print(f"Test images directory: {TEST_IMAGES_DIR}")
    print(f"Output directory: {OUTPUT_DIR}")

    # Compare lower-cased suffixes so mixed-case names such as "photo.Jpg" are
    # found as well. Separate "*.jpg" / "*.JPG" globs miss those. Only regular
    # files are kept, so a directory named like an image is skipped.
    image_extensions = {".jpg", ".jpeg", ".png", ".bmp"}
    test_images = []
    if TEST_IMAGES_DIR.is_dir():
        test_images = sorted(
            path
            for path in TEST_IMAGES_DIR.iterdir()
            if path.is_file() and path.suffix.lower() in image_extensions
        )

    if not test_images:
        print(f"\n❌ No test images found in {TEST_IMAGES_DIR}")
        sys.exit(1)

    print(f"\nFound {len(test_images)} test image(s)")

    # Output directories are keyed by file stem, so e.g. "a.jpg" and "a.png"
    # would overwrite each other's results; make that visible.
    stem_counts = Counter(path.stem for path in test_images)
    shared_stems = sorted(stem for stem, count in stem_counts.items() if count > 1)
    if shared_stems:
        print(
            "⚠️  Images sharing a name will overwrite each other's results: "
            + ", ".join(shared_stems)
        )

    results_summary = []
    for image_path in test_images:
        success = test_image(image_path)
        results_summary.append({"image": image_path.name, "success": success})

    print("\n" + "=" * 80)
    print("Test Summary")
    print("=" * 80)

    successful = sum(1 for r in results_summary if r["success"])
    failed = len(results_summary) - successful

    print(f"Total: {len(results_summary)}")
    print(f"Successful: {successful}")
    print(f"Failed: {failed}")

    print("\nResults:")
    for result in results_summary:
        status = "✅" if result["success"] else "❌"
        print(f" {status} {result['image']}")

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    summary_path = OUTPUT_DIR / "summary.json"
    with open(summary_path, "w") as f:
        json.dump({
            "timestamp": datetime.now().isoformat(),
            "total": len(results_summary),
            "successful": successful,
            "failed": failed,
            "results": results_summary,
        }, f, indent=2)

    print(f"\nSummary saved to: {summary_path}")

    if failed > 0:
        sys.exit(1)
|
|
|
|
|
if __name__ == "__main__":
    # Run the suite only when executed as a script, not when imported.
    main()
|
|
|