Spaces:
Configuration error
| import base64 | |
| import requests | |
| import json | |
| from PIL import Image | |
| import numpy as np | |
| import io | |
| import sys | |
| import matplotlib.pyplot as plt | |
# ---------------------------------------------------------------------------
# Configuration for this test client
# ---------------------------------------------------------------------------
SERVER_URL = "http://localhost:8000"  # base URL of the object/mask service
IMAGE_PATH = "test_images/plate1.png"  # image sent to /object/get_mask
HOUSE_ID = "c8c5fdea-7138-44ea-9f02-7fdcd47ff8cf"  # house sent with location calls

# Location path (outermost first) assigned when the matched object has none.
LOCATION_HIERARCHY = [
    "Kitchen",
    "Left Upper Cabinet",
    "Middle Shelf",
]

# Presumably one picture per LOCATION_HIERARCHY level, in the same order.
# NOTE(review): not referenced by the visible code — confirm it is still needed.
LOCATION_IMAGES = [
    "test_images/kitchen.png",
    "test_images/cabinets.png",
    "test_images/shelf.png",
]
def encode_image_to_base64(image_path):
    """Return the base64 encoding of a file's raw bytes as a text string.

    Args:
        image_path: Path of the file to read; it is opened in binary mode, so
            any file type works, not only images.

    Returns:
        The base64-encoded contents, decoded to ``str`` via UTF-8.
    """
    with open(image_path, mode="rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode("utf-8")
def decode_base64_to_image(base64_str):
    """Decode a base64 string and open the resulting bytes as a PIL image.

    The bytes are wrapped in an in-memory buffer for ``Image.open``; no mode
    conversion is applied, so the image keeps whatever mode it was encoded in.
    """
    decoded_bytes = base64.b64decode(base64_str)
    in_memory_file = io.BytesIO(decoded_bytes)
    return Image.open(in_memory_file)
def display_image(title, img_pil):
    """Show ``img_pil`` in a fresh matplotlib figure titled ``title``, axes hidden.

    Ends with ``plt.show()``, whose blocking behavior depends on the active
    matplotlib backend.
    """
    _fig, axes = plt.subplots()
    axes.set_title(title)
    axes.imshow(img_pil)
    axes.axis("off")
    plt.show()
def decode_base64_to_np(base64_str: str) -> np.ndarray:
    """Decode a base64-encoded image into a NumPy pixel array.

    The image is forced to RGB mode first, so the array has shape
    (height, width, 3) regardless of the source image's mode.
    """
    pil_image = Image.open(io.BytesIO(base64.b64decode(base64_str)))
    rgb_image = pil_image.convert("RGB")
    return np.array(rgb_image)
def _stitch_horizontally(images):
    """Paste PIL ``images`` left-to-right onto one RGB canvas and return it.

    The canvas is as tall as the tallest image; space below shorter images is
    left at ``Image.new``'s default fill (black).
    """
    total_width = sum(img.width for img in images)
    max_height = max(img.height for img in images)
    stitched = Image.new("RGB", (total_width, max_height))
    x_offset = 0
    for img in images:
        stitched.paste(img, (x_offset, 0))
        x_offset += img.width
    return stitched


def check_and_set_primary_location(house_id, object_id, location_hierarchy=None, timeout=30):
    """Report an object's primary location, or assign one if it has none.

    Posts to ``/object/get_primary_location``. If the server answers HTTP 200
    with a non-empty ``locations`` list, each location name is printed and any
    ``image_base64`` previews are stitched left-to-right and displayed.
    Otherwise ``/object/set_primary_location`` is called with the hierarchy.

    Args:
        house_id: House identifier sent with both requests.
        object_id: Object identifier sent with both requests.
        location_hierarchy: Location names, outermost first. Defaults to the
            module-level ``LOCATION_HIERARCHY``.
        timeout: Seconds ``requests`` waits on each call before raising a
            timeout error; without it a stalled server hangs forever.

    Returns:
        None. Results are reported with ``print`` and matplotlib.
    """
    # 1. Query the current primary location.
    loc_payload = {"house_id": house_id, "object_id": object_id, "include_images": True}
    loc_response = requests.post(
        f"{SERVER_URL}/object/get_primary_location", json=loc_payload, timeout=timeout
    )

    locations = None
    if loc_response.status_code == 200:
        locations = loc_response.json().get("locations")
    else:
        # Don't report a failed lookup as "no location": surface the status,
        # then fall through to assigning one (same flow as before).
        print(
            f"Location lookup for object {object_id} returned HTTP "
            f"{loc_response.status_code}: {loc_response.text}"
        )

    if locations:
        print(f"Object {object_id} already has a primary location:")
        for loc in locations:
            print(f"- {loc.get('name', '<unnamed>')}")

        # Previews are optional per location; decode straight to RGB PIL images
        # (no NumPy round-trip needed) so they can be pasted onto an RGB canvas.
        images = [
            decode_base64_to_image(loc["image_base64"]).convert("RGB")
            for loc in locations
            if loc.get("image_base64")
        ]
        if images:
            # display_image opens a new figure, so the preview never draws onto
            # a figure left over from an earlier plot.
            display_image(f"Primary location of {object_id}", _stitch_horizontally(images))
        else:
            print("No images available to display.")
        return

    # 2. No location known: use the supplied (or default) hierarchy. There is
    #    no interactive prompt here.
    print(f"No location found for object {object_id}.")
    location_list = list(LOCATION_HIERARCHY if location_hierarchy is None else location_hierarchy)

    # 3. Set the primary location.
    set_payload = {
        "house_id": house_id,
        "object_id": object_id,
        "location_hierarchy": location_list,
    }
    set_response = requests.post(
        f"{SERVER_URL}/object/set_primary_location", json=set_payload, timeout=timeout
    )
    if set_response.status_code == 200:
        print(f"Primary location set for object {object_id}: {location_list}")
    else:
        print(f"Failed to set primary location: {set_response.text}")
# Seconds to wait on each HTTP call before requests raises a timeout error;
# without a timeout a stalled server would hang this script indefinitely.
REQUEST_TIMEOUT = 60

# 1. Encode the test image.
image_base64 = encode_image_to_base64(IMAGE_PATH)

# 2. Prepare the mask request with an example click point and text prompt.
mask_payload = {
    "image_base64": image_base64,
    # Example click in pixel coordinates.
    # NOTE(review): must fall inside IMAGE_PATH's bounds — adjust for small images.
    "points": [{"x": 1000, "y": 1000}],
    "labels": [1],  # presumably 1 marks a positive/foreground click — confirm with server
    "prompt": "plate",  # example prompt
    "return_raw_mask": True,
    "return_rgb_mask": True,
    "return_embeddings": True,
}

# 3. Call the mask API.
response = requests.post(
    f"{SERVER_URL}/object/get_mask", json=mask_payload, timeout=REQUEST_TIMEOUT
)
response.raise_for_status()
mask_response = response.json()

# 4. Display whichever masks were returned.
for mask_key, mask_title in (("raw_mask_base64", "Raw Mask"), ("rgb_mask_base64", "RGB Mask")):
    if mask_response.get(mask_key):
        display_image(mask_title, decode_base64_to_image(mask_response[mask_key]))

# 5. Query the vector DB using the returned embedding.
embedding = mask_response.get("embedding")
if not embedding:
    print("No embedding returned. Skipping vector DB query.")
    sys.exit(0)  # nothing to query is not treated as a failure

query_payload = {
    "embedding": embedding,
    "k": 5,  # presumably the number of nearest matches to return
}
query_response = requests.post(
    f"{SERVER_URL}/object/query_by_embedding", json=query_payload, timeout=REQUEST_TIMEOUT
)
query_response.raise_for_status()
query_results = query_response.json()

# 6. Display results.
print("Top Matches:")
for result in query_results:
    print(f"- Object ID: {result['object_id']}")
    print(f" Aggregated Similarity: {result['aggregated_similarity']:.4f}")
    print(f" Probability: {result['probability']:.4f}")
    print(f" Descriptions: {result['descriptions']}\n")

# 7. Look up (or assign) the primary location of the top match.
if not query_results:
    # Exit non-zero as before, but say why instead of failing silently.
    print("No matches returned; cannot look up a primary location.", file=sys.stderr)
    sys.exit(1)
check_and_set_primary_location(HOUSE_ID, query_results[0]["object_id"])