Spaces:
Configuration error
Configuration error
File size: 5,380 Bytes
24f3fb6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import base64
import requests
import json
from PIL import Image
import numpy as np
import io
import sys
import matplotlib.pyplot as plt
# --- Configuration -----------------------------------------------------------
# Base URL of the API server every request below is sent to.
SERVER_URL = "http://localhost:8000"
# Query image that is base64-encoded and sent to the mask endpoint.
IMAGE_PATH = "test_images/plate1.png"
# House identifier passed to the primary-location lookup/update calls.
HOUSE_ID = "c8c5fdea-7138-44ea-9f02-7fdcd47ff8cf"
# Location path (outermost first) assigned when an object has no primary location.
LOCATION_HIERARCHY = [
    "Kitchen",
    "Left Upper Cabinet",
    "Middle Shelf",
]
# NOTE(review): presumably one reference image per LOCATION_HIERARCHY level;
# not referenced by any code visible in this file — confirm whether it is still needed.
LOCATION_IMAGES = [
    "test_images/kitchen.png",
    "test_images/cabinets.png",
    "test_images/shelf.png",
]
def encode_image_to_base64(image_path):
    """Return the raw bytes of the file at *image_path* as a base64 text string."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
def decode_base64_to_image(base64_str):
    """Decode a base64 string and open the resulting bytes as a PIL image."""
    image_bytes = base64.b64decode(base64_str)
    buffer = io.BytesIO(image_bytes)
    return Image.open(buffer)
def display_image(title, img_pil):
    """Show *img_pil* in its own matplotlib figure titled *title*, with axes hidden."""
    plt.figure()  # fresh figure so successive calls don't draw onto the same one
    plt.title(title)
    plt.imshow(img_pil)
    plt.axis('off')
    plt.show()
def decode_base64_to_np(base64_str: str) -> np.ndarray:
    """Decode a base64-encoded image into a numpy array after converting it to RGB."""
    pil_image = Image.open(io.BytesIO(base64.b64decode(base64_str)))
    rgb_image = pil_image.convert("RGB")
    return np.array(rgb_image)
def _stitch_horizontally(images):
    """Paste *images* left-to-right onto one RGB canvas as tall as the tallest input."""
    widths, heights = zip(*(img.size for img in images))
    stitched = Image.new("RGB", (sum(widths), max(heights)))
    x_offset = 0
    for img in images:
        stitched.paste(img, (x_offset, 0))
        x_offset += img.width
    return stitched


def _preview_locations(locations):
    """Print each location's name, then show all attached images side by side."""
    for loc in locations:
        print(f"- {loc['name']}")
    # Locations without an "image_base64" entry are simply skipped in the preview.
    images = [
        Image.fromarray(decode_base64_to_np(loc["image_base64"]))
        for loc in locations
        if loc.get("image_base64")
    ]
    if not images:
        print("No images available to display.")
        return
    plt.figure()  # own figure, like display_image, so earlier plots aren't reused
    plt.imshow(_stitch_horizontally(images))
    plt.axis('off')
    plt.show()


def check_and_set_primary_location(house_id, object_id, *, timeout=60):
    """Show an object's primary location, or assign LOCATION_HIERARCHY if it has none.

    Args:
        house_id: House identifier sent with both API calls.
        object_id: Object whose primary location is queried / set.
        timeout: Seconds to wait for each HTTP request; without it ``requests``
            can block indefinitely on an unresponsive server.

    Returns:
        None. Results are reported via ``print`` and a matplotlib preview.

    Raises:
        requests.exceptions.RequestException: on connection failure or timeout.
    """
    # 1. Query the current location (with images, for the preview).
    loc_payload = {"house_id": house_id, "object_id": object_id, "include_images": True}
    loc_response = requests.post(
        f"{SERVER_URL}/object/get_primary_location", json=loc_payload, timeout=timeout
    )
    # Parse the body once and only when the request succeeded.
    locations = (
        loc_response.json().get("locations") if loc_response.status_code == 200 else None
    )
    if locations:
        print(f"Object {object_id} already has a primary location:")
        _preview_locations(locations)
        return

    if loc_response.status_code != 200:
        # Any non-200 is treated as "no location" (as before), but make it visible
        # so a server error isn't silently followed by an overwrite.
        print(
            f"Location lookup returned HTTP {loc_response.status_code}; "
            "treating as no location."
        )

    # 2. No location yet: use the configured hierarchy (no interactive prompt).
    print(f"No location found for object {object_id}.")
    location_list = list(LOCATION_HIERARCHY)  # copy so the module constant isn't aliased

    # 3. Set primary location.
    set_payload = {
        "house_id": house_id,
        "object_id": object_id,
        "location_hierarchy": location_list,
    }
    set_response = requests.post(
        f"{SERVER_URL}/object/set_primary_location", json=set_payload, timeout=timeout
    )
    if set_response.status_code == 200:
        print(f"Primary location set for object {object_id}: {location_list}")
    else:
        print(f"Failed to set primary location: {set_response.text}")
# Seconds to wait for each API call; without a timeout `requests` can block forever.
REQUEST_TIMEOUT = 120

# 1. Encode image
image_base64 = encode_image_to_base64(IMAGE_PATH)

# 2. Prepare mask request with a sample click point and text prompt
mask_payload = {
    "image_base64": image_base64,
    # NOTE(review): presumably pixel coordinates — must fall inside IMAGE_PATH's bounds.
    "points": [{"x": 1000, "y": 1000}],  # example click
    "labels": [1],
    "prompt": "plate",  # example prompt
    "return_raw_mask": True,
    "return_rgb_mask": True,
    "return_embeddings": True,
}

# 3. Call Mask API
response = requests.post(
    f"{SERVER_URL}/object/get_mask", json=mask_payload, timeout=REQUEST_TIMEOUT
)
response.raise_for_status()
mask_response = response.json()

# 4. Display whichever masks were returned
for mask_key, mask_title in (("raw_mask_base64", "Raw Mask"), ("rgb_mask_base64", "RGB Mask")):
    if mask_response.get(mask_key):
        display_image(mask_title, decode_base64_to_image(mask_response[mask_key]))

# 5. Query vector DB using embedding
embedding = mask_response.get("embedding")
if not embedding:
    print("No embedding returned. Skipping vector DB query.")
    sys.exit(0)  # 0 for normal, or 1 for error

query_payload = {
    "embedding": embedding,
    "k": 5,
}
query_response = requests.post(
    f"{SERVER_URL}/object/query_by_embedding", json=query_payload, timeout=REQUEST_TIMEOUT
)
query_response.raise_for_status()
query_results = query_response.json()

# 6. Display results (exit with an explanation rather than silently when empty)
if not query_results:
    print("No matches returned by the vector DB.")
    sys.exit(1)

print("Top Matches:")
for result in query_results:
    print(f"- Object ID: {result['object_id']}")
    print(f" Aggregated Similarity: {result['aggregated_similarity']:.4f}")
    print(f" Probability: {result['probability']:.4f}")
    print(f" Descriptions: {result['descriptions']}\n")

# 7. Lookup object info for top match
check_and_set_primary_location(HOUSE_ID, query_results[0]["object_id"])
|