File size: 5,380 Bytes
24f3fb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import base64
import requests
import json
from PIL import Image
import numpy as np
import io
import sys
import matplotlib.pyplot as plt

# --- Configuration -----------------------------------------------------------
# Base URL of the API server every request below is sent to.
SERVER_URL = "http://localhost:8000"
# Query image that is base64-encoded and sent to /object/get_mask.
IMAGE_PATH = "test_images/plate1.png"
# House identifier sent with the primary-location lookup/assignment calls.
HOUSE_ID = 'c8c5fdea-7138-44ea-9f02-7fdcd47ff8cf'

# Location path (outermost first) assigned to an object that has none.
LOCATION_HIERARCHY = ["Kitchen", "Left Upper Cabinet", "Middle Shelf"]
# Reference images, presumably one per LOCATION_HIERARCHY level; not used by
# the code in this file section — TODO confirm where they are consumed.
LOCATION_IMAGES = ["test_images/kitchen.png", "test_images/cabinets.png", "test_images/shelf.png"]

def encode_image_to_base64(image_path):
    """Return the raw bytes of the file at *image_path* as a base64 text string."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')

def decode_base64_to_image(base64_str):
    """Decode a base64 string and open the resulting bytes as a PIL image."""
    image_bytes = base64.b64decode(base64_str)
    buffer = io.BytesIO(image_bytes)
    return Image.open(buffer)

def display_image(title, img_pil):
    """Open a new matplotlib figure showing *img_pil* under *title*, axes hidden."""
    plt.figure()
    plt.imshow(img_pil)
    plt.title(title)
    plt.axis('off')
    plt.show()

def decode_base64_to_np(base64_str: str) -> np.ndarray:
    """Decode a base64-encoded image into a NumPy array of its RGB pixels.

    The image is converted to RGB first, so palette/greyscale/RGBA inputs all
    come back with RGB channels.
    """
    buffer = io.BytesIO(base64.b64decode(base64_str))
    rgb_image = Image.open(buffer).convert("RGB")
    pixels = np.array(rgb_image)
    return pixels

def _stitch_horizontally(images):
    """Paste *images* left to right onto one RGB canvas as tall as the tallest one."""
    widths, heights = zip(*(img.size for img in images))
    stitched = Image.new('RGB', (sum(widths), max(heights)))
    x_offset = 0
    for img in images:
        stitched.paste(img, (x_offset, 0))
        x_offset += img.width
    return stitched


def _show_existing_locations(object_id, locations):
    """Print the names of *locations* and preview any images they carry."""
    print(f"Object {object_id} already has a primary location:")
    for loc in locations:
        print(f"- {loc['name']}")

    # Each location may carry an optional base64-encoded image.
    images = [
        Image.fromarray(decode_base64_to_np(loc["image_base64"]))
        for loc in locations
        if loc.get("image_base64")
    ]
    if not images:
        print("No images available to display.")
        return

    stitched = _stitch_horizontally(images)
    # Start a fresh figure: otherwise imshow draws into whatever figure is
    # current (e.g. a previous mask preview that is still open).
    plt.figure()
    plt.imshow(stitched)
    plt.axis('off')
    plt.show()


def check_and_set_primary_location(house_id, object_id, location_hierarchy=None, timeout=120):
    """Show an object's primary location, or assign one if it has none.

    Queries ``/object/get_primary_location`` with images included. If the
    server answers 200 with a non-empty ``locations`` list, the location
    names are printed and the returned images are previewed side by side.
    Otherwise ``/object/set_primary_location`` is called with the hierarchy.

    Args:
        house_id: House identifier sent with both requests.
        object_id: Object identifier sent with both requests.
        location_hierarchy: Location names, outermost first, to assign when
            none exists — a list, or a ``"A > B > C"`` string. Defaults to
            the module-level ``LOCATION_HIERARCHY``.
        timeout: Seconds ``requests`` waits for each HTTP call before raising.
    """
    if location_hierarchy is None:
        location_hierarchy = LOCATION_HIERARCHY

    # 1. Query current location
    loc_payload = {"house_id": house_id, "object_id": object_id, "include_images": True}
    loc_response = requests.post(
        f"{SERVER_URL}/object/get_primary_location", json=loc_payload, timeout=timeout
    )

    if loc_response.status_code == 200:
        locations = loc_response.json().get("locations")  # parse the body once
        if locations:
            _show_existing_locations(object_id, locations)
            return
        print(f"No location found for object {object_id}.")
    else:
        # A failed lookup used to be indistinguishable from "no location".
        # NOTE(review): we still fall through and set a location, as before —
        # confirm whether the API signals "not found" with a non-200 status.
        print(f"No location found for object {object_id} "
              f"(lookup returned HTTP {loc_response.status_code}).")

    # 2. Build the location list to assign
    if isinstance(location_hierarchy, str):
        location_list = [part.strip() for part in location_hierarchy.split('>') if part.strip()]
    else:
        location_list = list(location_hierarchy)

    # 3. Set primary location
    set_payload = {
        "house_id": house_id,
        "object_id": object_id,
        "location_hierarchy": location_list
    }
    set_response = requests.post(
        f"{SERVER_URL}/object/set_primary_location", json=set_payload, timeout=timeout
    )
    if set_response.status_code == 200:
        print(f"Primary location set for object {object_id}: {location_list}")
    else:
        print(f"Failed to set primary location: {set_response.text}")


# Seconds to wait for each HTTP call; without a timeout, requests can block
# forever if the server stops responding.
REQUEST_TIMEOUT = 120

# 1. Encode image
image_base64 = encode_image_to_base64(IMAGE_PATH)

# 2. Prepare mask request with an example click point and text prompt
mask_payload = {
    "image_base64": image_base64,
    "points": [{"x": 1000, "y": 1000}],  # example click
    "labels": [1],
    "prompt": "plate",  # example prompt
    "return_raw_mask": True,
    "return_rgb_mask": True,
    "return_embeddings": True
}

# 3. Call Mask API
response = requests.post(f"{SERVER_URL}/object/get_mask", json=mask_payload, timeout=REQUEST_TIMEOUT)
response.raise_for_status()
mask_response = response.json()

# 4. Display returned masks (each one is optional in the response)
if mask_response.get("raw_mask_base64"):
    mask_img = decode_base64_to_image(mask_response["raw_mask_base64"])
    display_image("Raw Mask", mask_img)

if mask_response.get("rgb_mask_base64"):
    rgb_mask_img = decode_base64_to_image(mask_response["rgb_mask_base64"])
    display_image("RGB Mask", rgb_mask_img)

# 5. Query vector DB using embedding
embedding = mask_response.get("embedding")
if not embedding:
    print("No embedding returned. Skipping vector DB query.")
    sys.exit(0)  # nothing to query is a normal outcome, not an error

query_payload = {
    "embedding": embedding,
    "k": 5
}

query_response = requests.post(
    f"{SERVER_URL}/object/query_by_embedding", json=query_payload, timeout=REQUEST_TIMEOUT
)
query_response.raise_for_status()
query_results = query_response.json()

# 6. Display results; stop with a non-zero status if nothing matched
if not query_results:
    print("No matches found.")
    sys.exit(1)

print("Top Matches:")
for result in query_results:
    print(f"- Object ID: {result['object_id']}")
    print(f"  Aggregated Similarity: {result['aggregated_similarity']:.4f}")
    print(f"  Probability: {result['probability']:.4f}")
    print(f"  Descriptions: {result['descriptions']}\n")

# 7. Lookup object info for the first result
# (presumably the server returns matches best-first — TODO confirm).
top_result = query_results[0]
check_and_set_primary_location(HOUSE_ID, top_result["object_id"])