TeeA commited on
Commit
0d65f8a
·
1 Parent(s): eca560c

apply agent succesfully

Browse files
Files changed (1) hide show
  1. app.py +39 -14
app.py CHANGED
@@ -220,7 +220,7 @@ async def query_3D_object(
220
  current_obj_path: str,
221
  embedding_dict: dict,
222
  top_k: int = 4,
223
- method: Query3DObjectMethod = Query3DObjectMethod.SEMANTIC_SEARCH,
224
  ) -> List:
225
  if query == "":
226
  raise gr.Error("Query cannot be empty!")
@@ -270,28 +270,31 @@ async def query_3D_object_by_hybrid_search_method(
270
  @function_tool
271
  def query_3D_object_by_keyword_search(query: str, match_code: str, top_k: int = 4):
272
  logger.info("Datum Agent is running query_3D_object_by_keyword_search")
273
- logger.info(f"The 'match' function is:\n```{match_code}```")
274
-
275
- def match(metadata: dict) -> bool:
276
- """
277
- This function should be generated by the match_code provided.
278
- It will check if the metadata matches the query.
279
- """
280
- return True
281
 
 
 
282
  try:
283
- exec(match_code, globals())
 
 
 
284
  assert (
285
  "def match(metadata: dict) -> bool:" in match_code
286
  ), "The match function is not defined correctly."
287
  except Exception:
288
- raise gr.Error("Your query did not generate a valid match function.")
 
 
289
  matched_obj_paths = list(
290
  filter(
291
  lambda obj_path: match(embedding_dict[obj_path]["metadata_dictionary"]),
292
  embedding_dict,
293
  )
294
  )
 
 
 
295
 
296
  top_files = [x for x in matched_obj_paths[:top_k]]
297
  return top_files + [os.path.basename(x) for x in top_files]
@@ -393,14 +396,24 @@ Combine the `match` function with `query_3D_object_by_keyword_search` to filter
393
  @function_tool
394
  def query_3D_object_by_semantic_search(query: str, top_k: int = 4):
395
  logger.info("Datum Agent is running query_3D_object_by_semantic_search")
396
- return query_3D_object_by_semantic_search_method(
397
  query, current_obj_path, embedding_dict, top_k
398
  )
 
 
 
 
399
 
400
  @function_tool
401
- def search_3D_similarity_factory(selected_filepath: str, top_k: int = 4):
 
 
402
  logger.info("Datum Agent is running search_3D_similarity_factory")
403
- return search_3D_similarity(selected_filepath, embedding_dict, top_k)
 
 
 
 
404
 
405
  DATUM_AGENT_INSTRUCTIONS = """You are the Datum Agent: you retrieve the top-K most relevant 3D objects using three strategies.
406
  * Use `query_3D_object_by_semantic_search` for abstract or descriptive queries.
@@ -751,6 +764,13 @@ async def accumulate_and_embedding(input_files, file_list, embedding_dict):
751
  all_files = input_files
752
  new_files = input_files[len(file_list) :]
753
 
 
 
 
 
 
 
 
754
  # embedding
755
  for file_path in new_files:
756
  logger.info("Processing new upload file:", file_path)
@@ -786,6 +806,11 @@ async def accumulate_and_embedding(input_files, file_list, embedding_dict):
786
  embedding_dict[obj_path]["image_embedding"] = embeddings["image_embedding"]
787
  embedding_dict[obj_path]["text_embedding"] = text_embedding
788
 
 
 
 
 
 
789
  return all_files, gr.update(choices=all_files), embedding_dict
790
 
791
 
 
220
  current_obj_path: str,
221
  embedding_dict: dict,
222
  top_k: int = 4,
223
+ method: Query3DObjectMethod = Query3DObjectMethod.HYBRID_SEARCH,
224
  ) -> List:
225
  if query == "":
226
  raise gr.Error("Query cannot be empty!")
 
270
  @function_tool
271
  def query_3D_object_by_keyword_search(query: str, match_code: str, top_k: int = 4):
272
  logger.info("Datum Agent is running query_3D_object_by_keyword_search")
273
+ logger.info(f"The 'match' function is:\n```\n{match_code}\n```")
 
 
 
 
 
 
 
274
 
275
+ # !!!IMPORTANT, create a new individual execution context for the match function
276
+ exec_globals = {}
277
  try:
278
+ exec(match_code, exec_globals)
279
+ match = exec_globals[
280
+ "match"
281
+ ] # get the match function from the execution context
282
  assert (
283
  "def match(metadata: dict) -> bool:" in match_code
284
  ), "The match function is not defined correctly."
285
  except Exception:
286
+ raise gr.Error(
287
+ "Your query did not generate a valid match function. Try your query again."
288
+ )
289
  matched_obj_paths = list(
290
  filter(
291
  lambda obj_path: match(embedding_dict[obj_path]["metadata_dictionary"]),
292
  embedding_dict,
293
  )
294
  )
295
+ logger.info(
296
+ f"Found {len(matched_obj_paths)} matching objects for the query `{query}`:\n```{matched_obj_paths}```"
297
+ )
298
 
299
  top_files = [x for x in matched_obj_paths[:top_k]]
300
  return top_files + [os.path.basename(x) for x in top_files]
 
396
  @function_tool
397
  def query_3D_object_by_semantic_search(query: str, top_k: int = 4):
398
  logger.info("Datum Agent is running query_3D_object_by_semantic_search")
399
+ response = query_3D_object_by_semantic_search_method(
400
  query, current_obj_path, embedding_dict, top_k
401
  )
402
+ logger.info(
403
+ f"Found {len(response) // 2} matching objects for the query `{query}`:\n```{response[: len(response) // 2]}```"
404
+ )
405
+ return response
406
 
407
  @function_tool
408
+ def search_3D_similarity_factory(
409
+ query: str, selected_filepath: str, top_k: int = 4
410
+ ):
411
  logger.info("Datum Agent is running search_3D_similarity_factory")
412
+ response = search_3D_similarity(selected_filepath, embedding_dict, top_k)
413
+ logger.info(
414
+ f"Found {len(response) // 2} similar objects for the query `{query}`:\n```{response[: len(response) // 2]}```"
415
+ )
416
+ return response
417
 
418
  DATUM_AGENT_INSTRUCTIONS = """You are the Datum Agent: you retrieve the top-K most relevant 3D objects using three strategies.
419
  * Use `query_3D_object_by_semantic_search` for abstract or descriptive queries.
 
764
  all_files = input_files
765
  new_files = input_files[len(file_list) :]
766
 
767
+ # # forwarding
768
+ # if os.environ.get("ENVIRONMENT") == "local" and os.path.exists("embedding_dict.pt"):
769
+ # embedding_dict = torch.load(
770
+ # "embedding_dict.pt", map_location=torch.device("cpu")
771
+ # ) # load from local file
772
+ # return all_files, gr.update(choices=all_files), embedding_dict
773
+
774
  # embedding
775
  for file_path in new_files:
776
  logger.info("Processing new upload file:", file_path)
 
806
  embedding_dict[obj_path]["image_embedding"] = embeddings["image_embedding"]
807
  embedding_dict[obj_path]["text_embedding"] = text_embedding
808
 
809
+ # if os.environ.get("ENVIRONMENT") == "local":
810
+ # # save to local file
811
+ # torch.save(embedding_dict, "embedding_dict.pt")
812
+ # logger.info("Saved embedding_dict to local file.")
813
+
814
  return all_files, gr.update(choices=all_files), embedding_dict
815
 
816