apply agent successfully
Browse files
app.py
CHANGED
|
@@ -220,7 +220,7 @@ async def query_3D_object(
|
|
| 220 |
current_obj_path: str,
|
| 221 |
embedding_dict: dict,
|
| 222 |
top_k: int = 4,
|
| 223 |
-
method: Query3DObjectMethod = Query3DObjectMethod.
|
| 224 |
) -> List:
|
| 225 |
if query == "":
|
| 226 |
raise gr.Error("Query cannot be empty!")
|
|
@@ -270,28 +270,31 @@ async def query_3D_object_by_hybrid_search_method(
|
|
| 270 |
@function_tool
|
| 271 |
def query_3D_object_by_keyword_search(query: str, match_code: str, top_k: int = 4):
|
| 272 |
logger.info("Datum Agent is running query_3D_object_by_keyword_search")
|
| 273 |
-
logger.info(f"The 'match' function is:\n```{match_code}```")
|
| 274 |
-
|
| 275 |
-
def match(metadata: dict) -> bool:
|
| 276 |
-
"""
|
| 277 |
-
This function should be generated by the match_code provided.
|
| 278 |
-
It will check if the metadata matches the query.
|
| 279 |
-
"""
|
| 280 |
-
return True
|
| 281 |
|
|
|
|
|
|
|
| 282 |
try:
|
| 283 |
-
exec(match_code,
|
|
|
|
|
|
|
|
|
|
| 284 |
assert (
|
| 285 |
"def match(metadata: dict) -> bool:" in match_code
|
| 286 |
), "The match function is not defined correctly."
|
| 287 |
except Exception:
|
| 288 |
-
raise gr.Error(
|
|
|
|
|
|
|
| 289 |
matched_obj_paths = list(
|
| 290 |
filter(
|
| 291 |
lambda obj_path: match(embedding_dict[obj_path]["metadata_dictionary"]),
|
| 292 |
embedding_dict,
|
| 293 |
)
|
| 294 |
)
|
|
|
|
|
|
|
|
|
|
| 295 |
|
| 296 |
top_files = [x for x in matched_obj_paths[:top_k]]
|
| 297 |
return top_files + [os.path.basename(x) for x in top_files]
|
|
@@ -393,14 +396,24 @@ Combine the `match` function with `query_3D_object_by_keyword_search` to filter
|
|
| 393 |
@function_tool
|
| 394 |
def query_3D_object_by_semantic_search(query: str, top_k: int = 4):
|
| 395 |
logger.info("Datum Agent is running query_3D_object_by_semantic_search")
|
| 396 |
-
|
| 397 |
query, current_obj_path, embedding_dict, top_k
|
| 398 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
|
| 400 |
@function_tool
|
| 401 |
-
def search_3D_similarity_factory(
|
|
|
|
|
|
|
| 402 |
logger.info("Datum Agent is running search_3D_similarity_factory")
|
| 403 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
|
| 405 |
DATUM_AGENT_INSTRUCTIONS = """You are the Datum Agent: you retrieve the top-K most relevant 3D objects using three strategies.
|
| 406 |
* Use `query_3D_object_by_semantic_search` for abstract or descriptive queries.
|
|
@@ -751,6 +764,13 @@ async def accumulate_and_embedding(input_files, file_list, embedding_dict):
|
|
| 751 |
all_files = input_files
|
| 752 |
new_files = input_files[len(file_list) :]
|
| 753 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 754 |
# embedding
|
| 755 |
for file_path in new_files:
|
| 756 |
logger.info("Processing new upload file:", file_path)
|
|
@@ -786,6 +806,11 @@ async def accumulate_and_embedding(input_files, file_list, embedding_dict):
|
|
| 786 |
embedding_dict[obj_path]["image_embedding"] = embeddings["image_embedding"]
|
| 787 |
embedding_dict[obj_path]["text_embedding"] = text_embedding
|
| 788 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 789 |
return all_files, gr.update(choices=all_files), embedding_dict
|
| 790 |
|
| 791 |
|
|
|
|
| 220 |
current_obj_path: str,
|
| 221 |
embedding_dict: dict,
|
| 222 |
top_k: int = 4,
|
| 223 |
+
method: Query3DObjectMethod = Query3DObjectMethod.HYBRID_SEARCH,
|
| 224 |
) -> List:
|
| 225 |
if query == "":
|
| 226 |
raise gr.Error("Query cannot be empty!")
|
|
|
|
| 270 |
@function_tool
|
| 271 |
def query_3D_object_by_keyword_search(query: str, match_code: str, top_k: int = 4):
|
| 272 |
logger.info("Datum Agent is running query_3D_object_by_keyword_search")
|
| 273 |
+
logger.info(f"The 'match' function is:\n```\n{match_code}\n```")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
|
| 275 |
+
# !!!IMPORTANT, create a new individual execution context for the match function
|
| 276 |
+
exec_globals = {}
|
| 277 |
try:
|
| 278 |
+
exec(match_code, exec_globals)
|
| 279 |
+
match = exec_globals[
|
| 280 |
+
"match"
|
| 281 |
+
] # get the match function from the execution context
|
| 282 |
assert (
|
| 283 |
"def match(metadata: dict) -> bool:" in match_code
|
| 284 |
), "The match function is not defined correctly."
|
| 285 |
except Exception:
|
| 286 |
+
raise gr.Error(
|
| 287 |
+
"Your query did not generate a valid match function. Try your query again."
|
| 288 |
+
)
|
| 289 |
matched_obj_paths = list(
|
| 290 |
filter(
|
| 291 |
lambda obj_path: match(embedding_dict[obj_path]["metadata_dictionary"]),
|
| 292 |
embedding_dict,
|
| 293 |
)
|
| 294 |
)
|
| 295 |
+
logger.info(
|
| 296 |
+
f"Found {len(matched_obj_paths)} matching objects for the query `{query}`:\n```{matched_obj_paths}```"
|
| 297 |
+
)
|
| 298 |
|
| 299 |
top_files = [x for x in matched_obj_paths[:top_k]]
|
| 300 |
return top_files + [os.path.basename(x) for x in top_files]
|
|
|
|
| 396 |
@function_tool
|
| 397 |
def query_3D_object_by_semantic_search(query: str, top_k: int = 4):
|
| 398 |
logger.info("Datum Agent is running query_3D_object_by_semantic_search")
|
| 399 |
+
response = query_3D_object_by_semantic_search_method(
|
| 400 |
query, current_obj_path, embedding_dict, top_k
|
| 401 |
)
|
| 402 |
+
logger.info(
|
| 403 |
+
f"Found {len(response) // 2} matching objects for the query `{query}`:\n```{response[: len(response) // 2]}```"
|
| 404 |
+
)
|
| 405 |
+
return response
|
| 406 |
|
| 407 |
@function_tool
|
| 408 |
+
def search_3D_similarity_factory(
|
| 409 |
+
query: str, selected_filepath: str, top_k: int = 4
|
| 410 |
+
):
|
| 411 |
logger.info("Datum Agent is running search_3D_similarity_factory")
|
| 412 |
+
response = search_3D_similarity(selected_filepath, embedding_dict, top_k)
|
| 413 |
+
logger.info(
|
| 414 |
+
f"Found {len(response) // 2} similar objects for the query `{query}`:\n```{response[: len(response) // 2]}```"
|
| 415 |
+
)
|
| 416 |
+
return response
|
| 417 |
|
| 418 |
DATUM_AGENT_INSTRUCTIONS = """You are the Datum Agent: you retrieve the top-K most relevant 3D objects using three strategies.
|
| 419 |
* Use `query_3D_object_by_semantic_search` for abstract or descriptive queries.
|
|
|
|
| 764 |
all_files = input_files
|
| 765 |
new_files = input_files[len(file_list) :]
|
| 766 |
|
| 767 |
+
# # forwarding
|
| 768 |
+
# if os.environ.get("ENVIRONMENT") == "local" and os.path.exists("embedding_dict.pt"):
|
| 769 |
+
# embedding_dict = torch.load(
|
| 770 |
+
# "embedding_dict.pt", map_location=torch.device("cpu")
|
| 771 |
+
# ) # load from local file
|
| 772 |
+
# return all_files, gr.update(choices=all_files), embedding_dict
|
| 773 |
+
|
| 774 |
# embedding
|
| 775 |
for file_path in new_files:
|
| 776 |
logger.info("Processing new upload file:", file_path)
|
|
|
|
| 806 |
embedding_dict[obj_path]["image_embedding"] = embeddings["image_embedding"]
|
| 807 |
embedding_dict[obj_path]["text_embedding"] = text_embedding
|
| 808 |
|
| 809 |
+
# if os.environ.get("ENVIRONMENT") == "local":
|
| 810 |
+
# # save to local file
|
| 811 |
+
# torch.save(embedding_dict, "embedding_dict.pt")
|
| 812 |
+
# logger.info("Saved embedding_dict to local file.")
|
| 813 |
+
|
| 814 |
return all_files, gr.update(choices=all_files), embedding_dict
|
| 815 |
|
| 816 |
|