"""Manual end-to-end check against the object server at SERVER_URL."""
import base64
import io
import json
import sys

import matplotlib.pyplot as plt
import numpy as np
import requests
from PIL import Image

# Config
SERVER_URL = "http://localhost:8000"
IMAGE_PATH = "test_images/plate1.png"
HOUSE_ID = "c8c5fdea-7138-44ea-9f02-7fdcd47ff8cf"
# Location path (outermost first) assigned to an object with no primary location.
LOCATION_HIERARCHY = ["Kitchen", "Left Upper Cabinet", "Middle Shelf"]
# NOTE(review): not referenced by the code shown here; presumably meant to be
# sent alongside LOCATION_HIERARCHY — confirm against the server API.
LOCATION_IMAGES = ["test_images/kitchen.png", "test_images/cabinets.png", "test_images/shelf.png"]
# Seconds to wait on location requests; without a timeout `requests` can block
# forever on an unresponsive server.
REQUEST_TIMEOUT = 60


def encode_image_to_base64(image_path):
    """Return the raw bytes of the file at *image_path* as a base64 string."""
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")


def decode_base64_to_image(base64_str):
    """Decode a base64-encoded image file into a PIL image."""
    return Image.open(io.BytesIO(base64.b64decode(base64_str)))


def display_image(title, img_pil):
    """Show *img_pil* in a new matplotlib figure titled *title*, axes hidden."""
    plt.figure()
    plt.title(title)
    plt.imshow(img_pil)
    plt.axis("off")
    plt.show()


def decode_base64_to_np(base64_str: str) -> np.ndarray:
    """Decode a base64-encoded image into an RGB numpy array."""
    image_bytes = base64.b64decode(base64_str)
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    return np.array(image)


def _stitch_horizontally(images):
    """Paste PIL *images* left to right onto one RGB canvas as tall as the tallest."""
    widths, heights = zip(*(img.size for img in images))
    stitched = Image.new("RGB", (sum(widths), max(heights)))
    x_offset = 0
    for img in images:
        stitched.paste(img, (x_offset, 0))
        x_offset += img.width
    return stitched


def check_and_set_primary_location(house_id, object_id, location_hierarchy=None):
    """Print an object's primary location, or assign one if it has none.

    Posts to ``/object/get_primary_location``. If the server answers 200 with a
    non-empty ``locations`` list, each location's ``name`` is printed and any
    ``image_base64`` images are shown side by side. Otherwise the object is
    assigned *location_hierarchy* via ``/object/set_primary_location``; the
    outcome is printed, not raised.

    Args:
        house_id: ID of the house the object belongs to.
        object_id: ID of the object to look up.
        location_hierarchy: Location names, outermost first, either as a
            sequence or as a ``"A > B > C"`` string. Defaults to
            ``LOCATION_HIERARCHY``.

    Raises:
        requests.RequestException: on connection failure or timeout.
    """
    if location_hierarchy is None:
        location_hierarchy = LOCATION_HIERARCHY

    # 1. Query current location
    loc_payload = {"house_id": house_id, "object_id": object_id, "include_images": True}
    loc_response = requests.post(
        f"{SERVER_URL}/object/get_primary_location",
        json=loc_payload,
        timeout=REQUEST_TIMEOUT,
    )
    # NOTE(review): every non-200 status (not only "not found") falls through
    # to assigning a new location below — confirm that is intended.
    locations = loc_response.json().get("locations") if loc_response.status_code == 200 else None

    if locations:
        print(f"Object {object_id} already has a primary location:")
        for loc in locations:
            print(f"- {loc['name']}")

        # Optional preview: the locations' images side by side, in server order.
        images = [
            Image.fromarray(decode_base64_to_np(loc["image_base64"]))
            for loc in locations
            if loc.get("image_base64")
        ]
        if images:
            display_image("Primary location", _stitch_horizontally(images))
        else:
            print("No images available to display.")
        return

    # 2. No location yet: use the given (or configured) hierarchy.
    print(f"No location found for object {object_id}.")
    if isinstance(location_hierarchy, str):
        # Accept the human-friendly "Kitchen > Shelf" form as well.
        location_list = [s.strip() for s in location_hierarchy.split(">")]
    else:
        location_list = list(location_hierarchy)

    # 3. Set primary location
    set_payload = {
        "house_id": house_id,
        "object_id": object_id,
        "location_hierarchy": location_list,
    }
    set_response = requests.post(
        f"{SERVER_URL}/object/set_primary_location",
        json=set_payload,
        timeout=REQUEST_TIMEOUT,
    )
    if set_response.status_code == 200:
        print(f"Primary location set for object {object_id}: {location_list}")
    else:
        print(f"Failed to set primary location: {set_response.text}")
# 1. Encode image
image_base64 = encode_image_to_base64(IMAGE_PATH)

# Per-request timeouts in seconds; `requests` otherwise waits forever. The mask
# call gets more headroom — presumably it runs model inference server-side.
MASK_TIMEOUT = 300
QUERY_TIMEOUT = 60

# 2. Prepare mask request with dummy points and prompt
mask_payload = {
    "image_base64": image_base64,
    # NOTE(review): example click; assumes pixel coordinates inside the image.
    "points": [{"x": 1000, "y": 1000}],
    "labels": [1],  # one label per point; presumably 1 = foreground
    "prompt": "plate",  # example prompt
    "return_raw_mask": True,
    "return_rgb_mask": True,
    "return_embeddings": True,
}

# 3. Call Mask API
response = requests.post(f"{SERVER_URL}/object/get_mask", json=mask_payload, timeout=MASK_TIMEOUT)
response.raise_for_status()
mask_response = response.json()

# 4. Display returned masks
if mask_response.get("raw_mask_base64"):
    mask_img = decode_base64_to_image(mask_response["raw_mask_base64"])
    display_image("Raw Mask", mask_img)

if mask_response.get("rgb_mask_base64"):
    rgb_mask_img = decode_base64_to_image(mask_response["rgb_mask_base64"])
    display_image("RGB Mask", rgb_mask_img)

# 5. Query vector DB using embedding
embedding = mask_response.get("embedding")
if not embedding:
    print("No embedding returned. Skipping vector DB query.")
    sys.exit(0)  # 0 for normal, or 1 for error

query_payload = {
    "embedding": embedding,
    "k": 5,  # presumably the number of nearest matches to return
}
query_response = requests.post(
    f"{SERVER_URL}/object/query_by_embedding", json=query_payload, timeout=QUERY_TIMEOUT
)
query_response.raise_for_status()
query_results = query_response.json()

# 6. Display results
print("Top Matches:")
for result in query_results:
    print(f"- Object ID: {result['object_id']}")
    print(f"  Aggregated Similarity: {result['aggregated_similarity']:.4f}")
    print(f"  Probability: {result['probability']:.4f}")
    print(f"  Descriptions: {result['descriptions']}\n")

# 7. Lookup object info for top match
if not query_results:
    # A string argument is printed to stderr and gives exit status 1.
    sys.exit("No matches returned by the vector DB query; nothing to look up.")
top_result = query_results[0]
check_and_set_primary_location(HOUSE_ID, top_result["object_id"])