|
|
|
|
|
""" |
|
|
SAM3 API Usage Example |
|
|
|
|
|
This example shows how to use the SAM3 text-prompted segmentation API |
|
|
for road defect detection. |
|
|
""" |
|
|
import requests |
|
|
import base64 |
|
|
from PIL import Image |
|
|
import io |
|
|
import os |
|
|
|
|
|
|
|
|
# URL that segment_image() sends its POST requests to. The hostname suggests a
# Hugging Face Inference Endpoint deployment; point this at your own
# deployment when running the example.
ENDPOINT_URL = "https://p6irm2x7y9mwp4l4.us-east-1.aws.endpoints.huggingface.cloud"
|
|
|
|
|
def segment_image(image_path, classes, *, endpoint_url=None, timeout=30):
    """
    Segment objects in an image using text prompts

    The image file is read as raw bytes, base64-encoded, and sent to the
    endpoint as a JSON POST body together with the requested class names.

    Args:
        image_path: Path to the image file
        classes: List of object classes to segment (e.g., ["pothole", "crack"])
        endpoint_url: URL to post to. When None (the default), the
            module-level ENDPOINT_URL is used.
        timeout: Number of seconds passed to requests.post as its timeout
            (default 30, the value previously hard-coded here).

    Returns:
        The decoded JSON body of the response; expected to be a list of
        dictionaries with 'label', 'mask' (base64), and 'score'

    Raises:
        OSError: If the image file cannot be opened or read.
        requests.exceptions.RequestException: If the request fails, times
            out, or the server answers with an error status.
    """
    with open(image_path, "rb") as image_file:
        # base64 output is pure ASCII, so it can be embedded in JSON as text.
        image_b64 = base64.b64encode(image_file.read()).decode("ascii")

    payload = {
        "inputs": image_b64,
        "parameters": {
            "classes": classes,
        },
    }

    response = requests.post(
        ENDPOINT_URL if endpoint_url is None else endpoint_url,
        json=payload,
        timeout=timeout,
    )

    # Turn 4xx/5xx answers into exceptions instead of returning an error body.
    response.raise_for_status()
    return response.json()
|
|
|
|
|
def save_masks(results, output_dir="output"):
    """
    Save segmentation masks as PNG files

    Each mask is written to ``mask_<label>.png`` inside ``output_dir``. When
    several results share a label, the first keeps that name and later ones
    get a numeric suffix (``mask_<label>_1.png``, ``mask_<label>_2.png``, ...)
    so that no mask overwrites another within one call.

    Args:
        results: API response (list of dictionaries); each entry must provide
            the keys 'label', 'score' and 'mask' (base64-encoded image data)
        output_dir: Directory to save masks (created if it does not exist)

    Raises:
        KeyError: If a result is missing one of the required keys.
        OSError: If the directory or a mask file cannot be written, or the
            decoded mask data cannot be opened as an image.
    """
    os.makedirs(output_dir, exist_ok=True)

    used_stems = set()

    for result in results:
        label = result["label"]
        score = result["score"]
        mask_b64 = result["mask"]

        # A path separator inside a label would redirect the file into a
        # (probably missing) subdirectory, so neutralise it in the filename.
        safe_label = str(label).replace("/", "_").replace("\\", "_")

        # Pick a filename not yet used in this call; the first occurrence
        # of a label keeps the plain "mask_<label>" name.
        base_stem = f"mask_{safe_label}"
        stem = base_stem
        suffix = 0
        while stem in used_stems:
            suffix += 1
            stem = f"{base_stem}_{suffix}"
        used_stems.add(stem)

        output_path = os.path.join(output_dir, f"{stem}.png")

        mask_bytes = base64.b64decode(mask_b64)
        # The context manager releases the image's resources after saving.
        with Image.open(io.BytesIO(mask_bytes)) as mask_image:
            mask_image.save(output_path)

        print(f"✓ Saved {label} mask: {output_path} (score: {score:.2f})")
|
|
|
|
|
def _segment_and_save(image_path, classes, output_dir):
    """Run one segmentation request, report the mask count and save the masks.

    Returns True on success. On a request failure or a file-system error
    (e.g. the input image does not exist, or a mask cannot be written) the
    error is printed and False is returned.
    """
    try:
        results = segment_image(image_path, classes)
        print(f"Found {len(results)} segmentation masks")
        print()

        save_masks(results, output_dir=output_dir)
    # OSError covers a missing/unreadable input image and mask write failures,
    # which would otherwise end the example with a traceback.
    except (requests.exceptions.RequestException, OSError) as e:
        print(f"Error: {e}")
        return False
    return True


def main():
    """Example: Road defect detection

    Runs two segmentation examples against the same test image: road defects
    first, then specific objects. If the first example fails, the second one
    is skipped.
    """
    image_path = "../test_images/test.jpg"

    print("Example 1: Road Defect Detection")
    print("=" * 50)

    classes = ["pothole", "crack", "patch", "debris"]

    print(f"Image: {image_path}")
    print(f"Classes: {classes}")
    print()

    if not _segment_and_save(image_path, classes, "defects_output"):
        return
    print()

    print("\nExample 2: Specific Object Segmentation")
    print("=" * 50)

    classes = ["asphalt", "yellow line"]

    print(f"Classes: {classes}")
    print()

    _segment_and_save(image_path, classes, "objects_output")
|
|
|
|
|
# Run the examples only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
|
|
|