# object-memory / test_scripts / test_hud_api.py
# russ4stall · fresh history · 24f3fb6 · 5.38 kB
import base64
import requests
import json
from PIL import Image
import numpy as np
import io
import sys
import matplotlib.pyplot as plt
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Base URL of the API server the script talks to.
SERVER_URL = "http://localhost:8000"

# Image sent to /object/get_mask.
IMAGE_PATH = "test_images/plate1.png"

# House whose objects are looked up / assigned a primary location.
HOUSE_ID = "c8c5fdea-7138-44ea-9f02-7fdcd47ff8cf"

# Location path assigned when an object has no primary location,
# ordered from broadest ("Kitchen") to most specific ("Middle Shelf").
LOCATION_HIERARCHY = [
    "Kitchen",
    "Left Upper Cabinet",
    "Middle Shelf",
]

# NOTE(review): presumably one image per LOCATION_HIERARCHY level, in the same
# order; not referenced by the code shown here — confirm before relying on it.
LOCATION_IMAGES = [
    "test_images/kitchen.png",
    "test_images/cabinets.png",
    "test_images/shelf.png",
]
def encode_image_to_base64(image_path):
    """Return the raw bytes of the file at ``image_path`` as a base64 text string.

    The file is read in binary mode and not decoded as an image, so the
    encoded payload is exactly the file's contents.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
def decode_base64_to_image(base64_str):
    """Decode a base64-encoded image file into a PIL ``Image``.

    No mode conversion is applied; the image keeps whatever mode the encoded
    file specifies (``decode_base64_to_np`` is the RGB-normalizing variant).
    """
    payload = base64.b64decode(base64_str)
    buffer = io.BytesIO(payload)
    return Image.open(buffer)
def display_image(title, img_pil):
    """Show ``img_pil`` in a new matplotlib figure with ``title`` and hidden axes.

    Ends with ``plt.show()``; with a blocking backend this call returns only
    after the window is closed.
    """
    fig = plt.figure()
    ax = fig.gca()
    ax.imshow(img_pil)
    ax.set_title(title)
    ax.set_axis_off()
    plt.show()
def decode_base64_to_np(base64_str: str) -> np.ndarray:
    """Decode a base64-encoded image into a NumPy pixel array.

    The image is converted to PIL's "RGB" mode before conversion, so the
    array always has three channels regardless of the source mode.
    """
    rgb_image = decode_base64_to_image(base64_str).convert("RGB")
    return np.array(rgb_image)
def _normalize_location_hierarchy(location_hierarchy):
    """Return ``location_hierarchy`` as a new list of level names.

    A sequence is copied as-is; a string such as ``"Kitchen > Shelf"`` is
    split on ``>`` with surrounding whitespace and empty parts removed.
    """
    if isinstance(location_hierarchy, str):
        return [part.strip() for part in location_hierarchy.split(">") if part.strip()]
    return list(location_hierarchy)


def _show_location_images(locations):
    """Stitch the locations' images left-to-right and display them in a new figure.

    Locations without an ``image_base64`` entry are skipped; if none have one,
    a message is printed instead.
    """
    images = [
        decode_base64_to_image(loc["image_base64"]).convert("RGB")
        for loc in locations
        if loc.get("image_base64")
    ]
    if not images:
        print("No images available to display.")
        return

    # Canvas wide enough for all images side by side and as tall as the tallest;
    # shorter images leave the (black) background visible below them.
    stitched = Image.new("RGB", (sum(img.width for img in images), max(img.height for img in images)))
    x_offset = 0
    for img in images:
        stitched.paste(img, (x_offset, 0))
        x_offset += img.width
    display_image("Primary location", stitched)


def check_and_set_primary_location(house_id, object_id, location_hierarchy=None, *, timeout=60):
    """Print (and preview) an object's primary location, or set one if it has none.

    Args:
        house_id: House the object belongs to.
        object_id: Object to look up.
        location_hierarchy: Location to assign when none exists, either a list
            of level names or a ``"A > B > C"`` string. Defaults to
            ``LOCATION_HIERARCHY``.
        timeout: Seconds to wait for each HTTP request (``None`` waits forever).

    Outcomes are reported by printing; nothing is returned. Network failures
    raise ``requests`` exceptions (e.g. ``requests.Timeout``).
    """
    # 1. Query the current location (with images, for the preview).
    loc_payload = {"house_id": house_id, "object_id": object_id, "include_images": True}
    loc_response = requests.post(
        f"{SERVER_URL}/object/get_primary_location", json=loc_payload, timeout=timeout
    )
    locations = []
    if loc_response.status_code == 200:
        locations = loc_response.json().get("locations") or []  # parse the body once
    else:
        # NOTE(review): as before, any non-200 lookup is treated as "no location"
        # and a new one is set below. Confirm whether the API signals "not found"
        # this way or whether server errors should abort instead.
        print(f"Location lookup returned HTTP {loc_response.status_code}: {loc_response.text}")

    if locations:
        print(f"Object {object_id} already has a primary location:")
        for loc in locations:
            print(f"- {loc['name']}")
        _show_location_images(locations)
        return

    # 2. No location yet: use the supplied/configured hierarchy (no interactive prompt).
    print(f"No location found for object {object_id}.")
    location_list = _normalize_location_hierarchy(
        LOCATION_HIERARCHY if location_hierarchy is None else location_hierarchy
    )

    # 3. Set primary location.
    set_payload = {
        "house_id": house_id,
        "object_id": object_id,
        "location_hierarchy": location_list,
    }
    set_response = requests.post(
        f"{SERVER_URL}/object/set_primary_location", json=set_payload, timeout=timeout
    )
    if set_response.status_code == 200:
        print(f"Primary location set for object {object_id}: {location_list}")
    else:
        print(f"Failed to set primary location: {set_response.text}")
def main():
    """Run the end-to-end API check against ``SERVER_URL``.

    Sends ``IMAGE_PATH`` to /object/get_mask and displays the returned masks.
    Then queries /object/query_by_embedding with the returned embedding,
    prints the matches, and checks (or sets) the top match's primary location.

    Exits with status 0 when the server returns no embedding, and with
    status 1 (message on stderr) when the query returns no matches. HTTP
    errors from the mask and query calls propagate as ``requests.HTTPError``.
    """
    # Generous per-request timeout: mask inference may be slow, but a hung
    # server should not block the script forever.
    timeout_s = 300

    # 1. Encode image
    image_base64 = encode_image_to_base64(IMAGE_PATH)

    # 2. Prepare mask request with an example click point and text prompt
    mask_payload = {
        "image_base64": image_base64,
        "points": [{"x": 1000, "y": 1000}],  # example click
        "labels": [1],
        "prompt": "plate",  # example prompt
        "return_raw_mask": True,
        "return_rgb_mask": True,
        "return_embeddings": True,
    }

    # 3. Call Mask API
    response = requests.post(f"{SERVER_URL}/object/get_mask", json=mask_payload, timeout=timeout_s)
    response.raise_for_status()
    mask_response = response.json()

    # 4. Display returned masks
    if mask_response.get("raw_mask_base64"):
        display_image("Raw Mask", decode_base64_to_image(mask_response["raw_mask_base64"]))
    if mask_response.get("rgb_mask_base64"):
        display_image("RGB Mask", decode_base64_to_image(mask_response["rgb_mask_base64"]))

    # 5. Query vector DB using embedding
    embedding = mask_response.get("embedding")
    if not embedding:
        print("No embedding returned. Skipping vector DB query.")
        sys.exit(0)  # a missing embedding is treated as a normal skip, not a failure

    query_payload = {"embedding": embedding, "k": 5}
    query_response = requests.post(
        f"{SERVER_URL}/object/query_by_embedding", json=query_payload, timeout=timeout_s
    )
    query_response.raise_for_status()
    query_results = query_response.json()

    # 6. Display results
    print("Top Matches:")
    for result in query_results:
        print(f"- Object ID: {result['object_id']}")
        print(f" Aggregated Similarity: {result['aggregated_similarity']:.4f}")
        print(f" Probability: {result['probability']:.4f}")
        print(f" Descriptions: {result['descriptions']}\n")

    # 7. Lookup object info for top match
    if not query_results:
        # sys.exit(str) prints the message to stderr and exits with status 1.
        sys.exit("No matches returned by query_by_embedding; nothing to look up.")
    check_and_set_primary_location(HOUSE_ID, query_results[0]["object_id"])


# Guarded so that importing this module (e.g. pytest collecting test_*.py files)
# does not fire network requests or open plot windows.
if __name__ == "__main__":
    main()