# object-memory / test_scripts / test_hud_api.py
# russ4stall
# fresh history
# 24f3fb6
import base64
import requests
import json
from PIL import Image
import numpy as np
import io
import sys
import matplotlib.pyplot as plt
# --- Configuration for the manual HUD API smoke test -----------------------
# Base URL of the API server every request below is sent to.
SERVER_URL = "http://localhost:8000"
# Query image that is base64-encoded and posted to /object/get_mask.
IMAGE_PATH = "test_images/plate1.png"
# House identifier sent with the primary-location requests.
HOUSE_ID = "c8c5fdea-7138-44ea-9f02-7fdcd47ff8cf"
# Location names used when an object has no primary location yet
# (apparently ordered broadest -> most specific; confirm with the API).
LOCATION_HIERARCHY = [
    "Kitchen",
    "Left Upper Cabinet",
    "Middle Shelf",
]
# NOTE(review): presumably one image per LOCATION_HIERARCHY entry; not
# referenced by any code in this file.
LOCATION_IMAGES = [
    "test_images/kitchen.png",
    "test_images/cabinets.png",
    "test_images/shelf.png",
]
def encode_image_to_base64(image_path):
    """Return the raw bytes of the file at *image_path* as a base64 text string.

    The file is read in binary mode, so any format works; the result is the
    UTF-8 decoding of the base64 bytes (pure ASCII), ready for a JSON payload.
    """
    with open(image_path, "rb") as image_file:
        file_bytes = image_file.read()
    encoded_bytes = base64.b64encode(file_bytes)
    return encoded_bytes.decode('utf-8')
def decode_base64_to_image(base64_str):
    """Decode a base64-encoded image and open it with PIL.

    The decoded bytes are wrapped in an in-memory buffer, so nothing touches
    the filesystem. No mode conversion is applied.
    """
    raw_bytes = base64.b64decode(base64_str)
    memory_buffer = io.BytesIO(raw_bytes)
    return Image.open(memory_buffer)
def display_image(title, img_pil):
    """Show *img_pil* in a new matplotlib figure titled *title*, with axes hidden.

    Blocks or returns according to the active backend's ``plt.show`` behavior.
    """
    plt.figure()
    current_axes = plt.gca()
    current_axes.set_title(title)
    plt.imshow(img_pil)
    current_axes.set_axis_off()
    plt.show()
def decode_base64_to_np(base64_str: str) -> np.ndarray:
    """Decode a base64-encoded image into a numpy array in RGB mode.

    The image is converted to "RGB" before conversion, so the array is
    presumably (height, width, 3) uint8 regardless of the source mode
    (e.g. RGBA or palette inputs lose their alpha/palette).
    """
    rgb_image = Image.open(io.BytesIO(base64.b64decode(base64_str))).convert("RGB")
    pixel_array = np.array(rgb_image)
    return pixel_array
def _stitch_images_horizontally(images):
    """Paste *images* left-to-right onto one new RGB canvas and return it.

    The canvas width is the sum of the image widths and its height is the
    tallest image's height; every image is pasted at y=0 (top-aligned), and
    any uncovered area keeps ``Image.new``'s default fill.
    """
    widths, heights = zip(*(img.size for img in images))
    stitched = Image.new('RGB', (sum(widths), max(heights)))
    x_offset = 0
    for img in images:
        stitched.paste(img, (x_offset, 0))
        x_offset += img.width
    return stitched


def check_and_set_primary_location(house_id, object_id, location_hierarchy=None, *, timeout=60):
    """Show an object's primary location if it has one; otherwise assign one.

    Posts to ``/object/get_primary_location`` (with images requested). When
    the reply is HTTP 200 with a non-empty ``locations`` list, each location
    name is printed, any returned ``image_base64`` previews are displayed
    side by side, and the function returns. In every other case the location
    hierarchy is posted to ``/object/set_primary_location`` and the outcome
    is printed.

    Args:
        house_id: House identifier sent in both request payloads.
        object_id: Object whose primary location is checked or set.
        location_hierarchy: Sequence of location names to assign when none
            exists. Defaults to the module-level ``LOCATION_HIERARCHY``.
        timeout: Seconds passed to each ``requests.post`` call, so an
            unresponsive server cannot hang the script indefinitely.

    Raises:
        requests.RequestException: on connection failures or timeouts.
    """
    # 1. Query the current primary location.
    loc_payload = {"house_id": house_id, "object_id": object_id, "include_images": True}
    loc_response = requests.post(
        f"{SERVER_URL}/object/get_primary_location", json=loc_payload, timeout=timeout
    )
    # Parse the body once; only a 200 reply is trusted to contain locations.
    locations = loc_response.json().get("locations") if loc_response.status_code == 200 else None

    if locations:
        print(f"Object {object_id} already has a primary location:")
        for loc in locations:
            print(f"- {loc['name']}")
        # Optional preview: decode each location image that was returned.
        images = [
            decode_base64_to_image(loc["image_base64"]).convert("RGB")
            for loc in locations
            if loc.get("image_base64")
        ]
        if images:
            display_image(f"Primary location of {object_id}", _stitch_images_horizontally(images))
        else:
            print("No images available to display.")
        return

    if loc_response.status_code != 200:
        # Surface failed lookups so a server error is not silently mistaken
        # for "this object has no location".
        print(f"Location lookup returned HTTP {loc_response.status_code}: {loc_response.text}")

    # 2. Pick the hierarchy to assign (caller-supplied or the configured default).
    print(f"No location found for object {object_id}.")
    if location_hierarchy is None:
        location_hierarchy = LOCATION_HIERARCHY
    location_list = list(location_hierarchy)  # copy so the payload never aliases the caller's list

    # 3. Set the primary location.
    set_payload = {
        "house_id": house_id,
        "object_id": object_id,
        "location_hierarchy": location_list,
    }
    set_response = requests.post(
        f"{SERVER_URL}/object/set_primary_location", json=set_payload, timeout=timeout
    )
    if set_response.status_code == 200:
        print(f"Primary location set for object {object_id}: {location_list}")
    else:
        print(f"Failed to set primary location: {set_response.text}")
def main():
    """Run the end-to-end HUD API smoke test against ``SERVER_URL``.

    Steps: encode ``IMAGE_PATH``; request masks and an embedding from
    ``/object/get_mask``; display any returned masks; query
    ``/object/query_by_embedding`` for the top 5 matches and print them;
    finally check or set the primary location of the best match.

    Returns:
        Process exit status: 0 on success or when no embedding is returned,
        1 when the similarity query returns no matches.

    Raises:
        requests.HTTPError: if the mask or query endpoint returns an error status.
        requests.RequestException: on connection failures or timeouts.
    """
    timeout = 120  # seconds; generous for inference, but a stalled server won't hang forever

    # 1. Encode image
    image_base64 = encode_image_to_base64(IMAGE_PATH)

    # 2. Prepare mask request with dummy points and prompt
    mask_payload = {
        "image_base64": image_base64,
        "points": [{"x": 1000, "y": 1000}],  # example click
        "labels": [1],
        "prompt": "plate",  # example prompt
        "return_raw_mask": True,
        "return_rgb_mask": True,
        "return_embeddings": True,
    }

    # 3. Call Mask API
    response = requests.post(f"{SERVER_URL}/object/get_mask", json=mask_payload, timeout=timeout)
    response.raise_for_status()
    mask_response = response.json()

    # 4. Display returned masks
    for key, title in (("raw_mask_base64", "Raw Mask"), ("rgb_mask_base64", "RGB Mask")):
        if mask_response.get(key):
            display_image(title, decode_base64_to_image(mask_response[key]))

    # 5. Query vector DB using embedding
    embedding = mask_response.get("embedding")
    if not embedding:
        print("No embedding returned. Skipping vector DB query.")
        return 0

    query_payload = {"embedding": embedding, "k": 5}
    query_response = requests.post(
        f"{SERVER_URL}/object/query_by_embedding", json=query_payload, timeout=timeout
    )
    query_response.raise_for_status()
    query_results = query_response.json()

    # 6. Display results
    print("Top Matches:")
    for result in query_results:
        print(f"- Object ID: {result['object_id']}")
        print(f"  Aggregated Similarity: {result['aggregated_similarity']:.4f}")
        print(f"  Probability: {result['probability']:.4f}")
        print(f"  Descriptions: {result['descriptions']}\n")

    # 7. Lookup object info for top match
    if not query_results:
        print("No matches returned; cannot look up a primary location.", file=sys.stderr)
        return 1
    check_and_set_primary_location(HOUSE_ID, query_results[0]["object_id"])
    return 0


# Guarded so importing this module (e.g. pytest collecting test_*.py files)
# does not fire network requests or call sys.exit.
if __name__ == "__main__":
    sys.exit(main())