TeeA committed on
Commit
eca560c
·
1 Parent(s): 182cad3

integrate ai agents

Browse files
Files changed (2) hide show
  1. app.py +252 -20
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import platform
3
  import re
4
  import subprocess # used to connect to FreeCAD via terminal sub process
@@ -12,6 +13,8 @@ import numpy as np
12
  import torch
13
  import torchvision.transforms.functional as TF
14
  import trimesh
 
 
15
  from llama_index.embeddings.clip import ClipEmbedding
16
  from llama_index.embeddings.openai import OpenAIEmbedding, OpenAIEmbeddingMode
17
  from loguru import logger
@@ -207,39 +210,260 @@ def search_3D_similarity(filepath: str, embedding_dict: dict, top_k: int = 4):
207
  ####################################################################################################################
208
  # Text-based Query
209
  ####################################################################################################################
210
-
211
-
212
- def query_3D_object(query: str, embedding_dict: dict, top_k: int = 4):
 
 
 
 
 
 
 
 
 
213
  if query == "":
214
  raise gr.Error("Query cannot be empty!")
215
  if len(embedding_dict) < 4:
216
  raise gr.Error("Require at least 4 3D files to query by features")
 
 
 
 
 
 
 
 
 
 
217
 
 
 
 
 
218
  features1 = np.array(text_embedding_model.get_text_embedding(text=query)).reshape(
219
  1, -1
220
  )
221
 
222
- # List to store (path, similarity)
223
  valid_items = [
224
  (fp, data["text_embedding"])
225
  for fp, data in embedding_dict.items()
226
  if "text_embedding" in data
227
  ]
228
  filepaths = [fp for fp, _ in valid_items]
229
- feature_matrix = np.array([feat for _, feat in valid_items]) # shape: (N, D)
230
- similarities = cosine_similarity(features1, feature_matrix)[0] # shape: (N,)
231
  scores = list(zip(filepaths, similarities))
232
-
233
- # Sort by similarity in descending order
234
  scores.sort(key=lambda x: x[1], reverse=True)
235
 
236
- if len(scores) < 4:
237
- scores.append(("", 0.0))
238
 
239
- # Return top_k results
240
- return [x[0] for x in scores[:top_k]] + [
241
- os.path.basename(x[0]) for x in scores[:top_k]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
 
244
 
245
  ####################################################################################################################
@@ -489,10 +713,14 @@ async def embedding_3d_object(obj_path: str) -> Dict[str, Any]:
489
 
490
  BASE_SAMPLE_DIR = "/Users/tridoan/Spartan/Datum/service-ai/poc/3D/gradio_cache/"
491
  sample_files = [
492
- # BASE_SAMPLE_DIR + "C5 Knuckle Object.obj",
493
- # BASE_SAMPLE_DIR + "NEMA 17 Stepper Motor 23mm-NEMA 17 Stepper Motor 23mm.obj",
494
- # BASE_SAMPLE_DIR + "TS6-THT_H-5.0.obj",
495
- # BASE_SAMPLE_DIR + "TS6-THT_H-11.0.obj"
 
 
 
 
496
  ]
497
 
498
 
@@ -546,10 +774,13 @@ async def accumulate_and_embedding(input_files, file_list, embedding_dict):
546
  + f".\n {'n' * 20}\nMetadata: "
547
  + metadata
548
  )
 
 
 
549
  # store embeddings and metadata
550
  embedding_dict[obj_path]["metadata"] = metadata
551
  embedding_dict[obj_path]["metadata_dictionary"] = normalize_metadata(
552
- metadata_aggregation.update(metadata_extraction) # type: ignore
553
  )
554
  embedding_dict[obj_path]["description"] = embeddings["description"]
555
  embedding_dict[obj_path]["image_embedding"] = embeddings["image_embedding"]
@@ -658,7 +889,7 @@ with gr.Blocks() as demo:
658
  # query button
659
  query_button.click(
660
  query_3D_object,
661
- [query_input, embedding_store],
662
  [
663
  model_q_1,
664
  model_q_2,
@@ -723,4 +954,5 @@ with gr.Blocks() as demo:
723
  )
724
 
725
  if __name__ == "__main__":
726
- demo.launch(share=True, debug=True)
 
 
1
  import os
2
+ from enum import Enum
3
  import platform
4
  import re
5
  import subprocess # used to connect to FreeCAD via terminal sub process
 
13
  import torch
14
  import torchvision.transforms.functional as TF
15
  import trimesh
16
+ import ast
17
+ from agents import Agent, Runner, function_tool
18
  from llama_index.embeddings.clip import ClipEmbedding
19
  from llama_index.embeddings.openai import OpenAIEmbedding, OpenAIEmbeddingMode
20
  from loguru import logger
 
210
  ####################################################################################################################
211
  # Text-based Query
212
  ####################################################################################################################
213
class Query3DObjectMethod(Enum):
    """Search strategies available to `query_3D_object`.

    HYBRID_SEARCH routes the query through the multi-agent pipeline;
    SEMANTIC_SEARCH ranks by plain text-embedding similarity.
    """

    HYBRID_SEARCH = "hybrid_search"
    SEMANTIC_SEARCH = "semantic_search"
216
+
217
+
218
async def query_3D_object(
    query: str,
    current_obj_path: str,
    embedding_dict: dict,
    top_k: int = 4,
    method: Query3DObjectMethod = Query3DObjectMethod.SEMANTIC_SEARCH,
) -> List:
    """Dispatch a natural-language query for 3D objects to the chosen strategy.

    Args:
        query: The user's natural-language query text.
        current_obj_path: Path of the 3D object currently shown in the UI.
        embedding_dict: Mapping of object path -> stored embeddings/metadata.
        top_k: Number of results to return.
        method: Which search strategy to run (hybrid agents or semantic).

    Returns:
        A list of ``top_k`` file paths followed by their ``top_k`` basenames.

    Raises:
        gr.Error: On an empty query, fewer than 4 indexed files, or an
            unsupported ``method`` value.
    """
    if query == "":
        raise gr.Error("Query cannot be empty!")
    # The UI renders 4 result slots, hence the minimum of 4 indexed files.
    if len(embedding_dict) < 4:
        raise gr.Error("Require at least 4 3D files to query by features")
    if method == Query3DObjectMethod.HYBRID_SEARCH:
        logger.info("Running query_3D_object_by_hybrid_search_method")
        return await query_3D_object_by_hybrid_search_method(
            query, current_obj_path, embedding_dict, top_k
        )
    if method == Query3DObjectMethod.SEMANTIC_SEARCH:
        logger.info("Running query_3D_object_by_semantic_search_method")
        return query_3D_object_by_semantic_search_method(
            query, current_obj_path, embedding_dict, top_k
        )
    # Previously an unrecognized method fell through and returned None
    # implicitly, which the Gradio outputs cannot unpack. Fail loudly instead.
    raise gr.Error(f"Unsupported query method: {method}")
239
 
240
+
241
def query_3D_object_by_semantic_search_method(
    query: str, current_obj_path: str, embedding_dict: dict, top_k: int = 4
) -> List:
    """Rank stored 3D objects by cosine similarity to the query text embedding.

    ``current_obj_path`` is unused here; it is kept for signature parity with
    the hybrid search method so both can be dispatched interchangeably.

    Args:
        query: Natural-language query text to embed.
        embedding_dict: Mapping of object path -> stored embeddings/metadata;
            only entries with a "text_embedding" key participate.
        top_k: Number of results to return.

    Returns:
        ``top_k`` file paths (best first, padded with "" when there are fewer
        candidates) followed by their ``top_k`` basenames.
    """
    # Only objects that already have a text embedding can be ranked.
    valid_items = [
        (fp, data["text_embedding"])
        for fp, data in embedding_dict.items()
        if "text_embedding" in data
    ]
    # Nothing to rank: return blank padding immediately. (Previously this
    # case crashed inside cosine_similarity on an empty feature matrix.)
    if not valid_items:
        return [""] * (2 * top_k)

    query_features = np.array(
        text_embedding_model.get_text_embedding(text=query)
    ).reshape(1, -1)

    filepaths = [fp for fp, _ in valid_items]
    feature_matrix = np.array([feat for _, feat in valid_items])  # shape: (N, D)
    similarities = cosine_similarity(query_features, feature_matrix)[0]  # (N,)
    scores = sorted(zip(filepaths, similarities), key=lambda x: x[1], reverse=True)

    # Pad with blanks so the UI always receives exactly top_k entries.
    if len(scores) < top_k:
        scores.extend([("", 0.0)] * (top_k - len(scores)))

    top_files = [fp for fp, _ in scores[:top_k]]
    return top_files + [os.path.basename(fp) for fp in top_files]
264
+
265
+
266
async def query_3D_object_by_hybrid_search_method(
    query: str, current_obj_path: str, embedding_dict: dict, top_k: int = 4
) -> List:
    """Answer a query via an agent that picks between semantic, similarity,
    and metadata-keyword search.

    A "Datum Agent" is given three tools (semantic search, visual similarity
    on the currently-displayed object, and a hand-off to a keyword-search
    agent that generates and executes a metadata ``match`` function) and its
    final tool output is parsed back into a flat result list.

    Args:
        query: Natural-language query text.
        current_obj_path: Path of the object currently shown in the UI.
        embedding_dict: Mapping of object path -> stored embeddings/metadata.
        top_k: Number of results to return.

    Returns:
        ``top_k`` file paths followed by their ``top_k`` basenames.

    Raises:
        gr.Error: When the agent produces no usable output or the generated
            match function is invalid.
    """

    # Keyword Search Agent tool: metadata filtering via generated code.
    @function_tool
    def query_3D_object_by_keyword_search(query: str, match_code: str, top_k: int = 4):
        logger.info("Datum Agent is running query_3D_object_by_keyword_search")
        logger.info(f"The 'match' function is:\n```{match_code}```")

        # Validate the shape of the generated code BEFORE executing it
        # (an `assert` was used here before, which is stripped under -O).
        if "def match(metadata: dict) -> bool:" not in match_code:
            raise gr.Error("Your query did not generate a valid match function.")

        # SECURITY: `match_code` is LLM-generated Python executed verbatim —
        # this trusts the model's output; consider sandboxing/AST-whitelisting.
        # exec() into an isolated namespace so the generated `match` is the one
        # actually called (the previous version exec'd into globals(), where it
        # was shadowed by a local stub that always returned True, so keyword
        # filtering never filtered anything) and so it cannot clobber module
        # state.
        namespace: dict = {}
        try:
            exec(match_code, namespace)
            match = namespace["match"]
        except Exception:
            raise gr.Error("Your query did not generate a valid match function.")

        matched_obj_paths = [
            obj_path
            for obj_path in embedding_dict
            if match(embedding_dict[obj_path]["metadata_dictionary"])
        ]

        top_files = matched_obj_paths[:top_k]
        # Pad so downstream always receives top_k paths + top_k basenames,
        # mirroring the semantic search method's padding behaviour.
        if len(top_files) < top_k:
            top_files += [""] * (top_k - len(top_files))
        return top_files + [os.path.basename(x) for x in top_files]

    METADATA_SCHEMA = """Schema of metadata_dictionary:
    - Volume: float
    - Surface_Area: float
    - Width: float
    - Height: float
    - Depth: float
    - Description: str
    - Description_Level: int
    - FileName: str
    - Created: str
    - Authors: str
    - Organizations: str
    - Preprocessor: str
    - OriginatingSystem: str
    - Authorization: str
    - Schema: str
    """

    QUERY_EXAMPLES = """Examples of natural language queries and their intended matching logic:

    ### Example 1: "width greater than 7"
    ```python
    def match(metadata: dict) -> bool:
        try:
            return float(metadata.get("Width", 0)) > 7
        except:
            return False
    ````

    ### Example 2: "description contains STEP"

    ```python
    def match(metadata: dict) -> bool:
        return "step" in str(metadata.get("Description", "")).lower()
    ```

    ### Example 3: "originating system is ASCON Math Kernel"

    ```python
    def match(metadata: dict) -> bool:
        return str(metadata.get("OriginatingSystem", "")).lower() == "ascon math kernel"
    ```

    ### Example 4: "volume < 200 and surface area > 300"

    ```python
    def match(metadata: dict) -> bool:
        try:
            return float(metadata.get("Volume", 0)) < 200 and float(metadata.get("Surface_Area", 0)) > 300
        except:
            return False
    ```

    ### Example 5: "schema contains 214"

    ```python
    def match(metadata: dict) -> bool:
        return "214" in str(metadata.get("Schema", ""))
    ```
    """

    MATCH_GEN_INSTRUCTION = """You are a Python code generator. Your job is to translate a natural language query into a function named `match(metadata: dict) -> bool`.

    Requirements:
    - Only use keys present in the schema.
    - Match strings case-insensitively.
    - For numerical comparisons, cast to float.
    - Combine conditions using logical `and`, `or` as inferred from natural language.
    - Handle missing keys by returning False.
    Return only the function code, nothing else.
    """

    @function_tool
    def get_prompt_to_generate_match_code(query: str) -> str:
        """
        Generate a prompt to create a match function based on the user's query.
        """
        return (
            METADATA_SCHEMA
            + QUERY_EXAMPLES
            + MATCH_GEN_INSTRUCTION
            + f"\nQuery: {query}\n"
        )

    KEYWORD_SEARCH_AGENT_INSTRUCTIONS = """You are a Keyword Search Agent specialized in metadata-based filtering.
    Given a natural language query from the user, you will automatically generate an executable `match` function based on the prompt provided by `get_prompt_to_generate_match_code`.
    Combine the `match` function with `query_3D_object_by_keyword_search` to filter the top-K matching 3D object paths."""

    keyword_search_agent = Agent(
        name="Keyword Search Agent",
        instructions=KEYWORD_SEARCH_AGENT_INSTRUCTIONS,
        tools=[get_prompt_to_generate_match_code, query_3D_object_by_keyword_search],
    )

    # Semantic search tool: closes over embedding_dict / current_obj_path.
    @function_tool
    def query_3D_object_by_semantic_search(query: str, top_k: int = 4):
        logger.info("Datum Agent is running query_3D_object_by_semantic_search")
        return query_3D_object_by_semantic_search_method(
            query, current_obj_path, embedding_dict, top_k
        )

    # Visual-similarity tool for "objects like the one on screen" queries.
    @function_tool
    def search_3D_similarity_factory(selected_filepath: str, top_k: int = 4):
        logger.info("Datum Agent is running search_3D_similarity_factory")
        return search_3D_similarity(selected_filepath, embedding_dict, top_k)

    DATUM_AGENT_INSTRUCTIONS = """You are the Datum Agent: you retrieve the top-K most relevant 3D objects using three strategies.
    * Use `query_3D_object_by_semantic_search` for abstract or descriptive queries.
    * Use `search_3D_similarity_factory` when the query mentions the object currently displayed on the screen and aims to find similar objects.
    * Use **Keyword Search Agent** for precise metadata constraints or comparative/filtering information in the query.
    Return only the final tuple of file paths and display names.
    """

    HANDOFF_DESCRIPTION = """Handing off to Datum Agent: you can perform semantic search, keyword-based filtering, or visual similarity search.
    If metadata filtering is required, delegate to the **Keyword Search Agent** by calling `get_prompt_to_generate_match_code`.
    """

    datum_agent = Agent(
        name="Datum Agent",
        handoff_description=HANDOFF_DESCRIPTION,
        instructions=DATUM_AGENT_INSTRUCTIONS,
        tools=[
            query_3D_object_by_semantic_search,
            search_3D_similarity_factory,
        ],
        handoffs=[keyword_search_agent],
    )

    # Prepare the prompt for the Datum Agent
    prompt_input = f"""An user is watching a 3D object and wants to query it.
    The query is: `{query}`.
    The current 3D object is `{current_obj_path}`.
    You need to find the most relevant 3D objects based on the query and return the top-k results.
    """

    ######################################################################
    # Run the agent to get the results
    ######################################################################
    response = await Runner.run(datum_agent, prompt_input)  # agent's final output

    # Filter the latest output with `function_call_output` type: the last tool
    # output is taken as the final file list.
    function_call_output_list = [
        item
        for item in response.to_input_list()
        if item.get("type") == "function_call_output"
    ]
    # Previously an empty list crashed with IndexError on [-1].
    if not function_call_output_list:
        raise gr.Error("Datum Agent did not return a valid list of file paths.")
    files_result = function_call_output_list[-1]
    logger.info(f"Datum Agent raw response: {files_result}")
    try:
        # literal_eval (not eval): only plain Python literals are accepted.
        result = ast.literal_eval(files_result.get("output", "[]"))  # type:ignore
    except Exception as e:
        logger.error(
            f"Datum Agent did not return a valid list of file paths due to {e}"
        )
        raise gr.Error("Datum Agent did not return a valid list of file paths.")
    if not isinstance(result, list):
        raise gr.Error("Datum Agent did not return a valid list of file paths.")
    # top_k paths + top_k basenames are expected (was a hard-coded 8 before).
    if len(result) < 2 * top_k:
        raise gr.Error("Datum Agent did not return enough results. Please try again.")
    return result
467
 
468
 
469
  ####################################################################################################################
 
713
 
714
  BASE_SAMPLE_DIR = "/Users/tridoan/Spartan/Datum/service-ai/poc/3D/gradio_cache/"
715
  sample_files = [
716
+ # BASE_SAMPLE_DIR + "C5 Knuckle Object.STEP",
717
+ # BASE_SAMPLE_DIR + "NEMA 17 Stepper Motor 23mm-NEMA 17 Stepper Motor 23mm.obj",
718
+ # BASE_SAMPLE_DIR + "TS6-THT_H-4.3.STEP",
719
+ # BASE_SAMPLE_DIR + "TS6-THT_H-5.0.STEP",
720
+ # BASE_SAMPLE_DIR + "TS6-THT_H-7.0.STEP",
721
+ # BASE_SAMPLE_DIR + "TS6-THT_H-7.3.STEP",
722
+ # BASE_SAMPLE_DIR + "TS6-THT_H-7.5.STEP",
723
+ # BASE_SAMPLE_DIR + "TS6-THT_H-11.0.STEP",
724
  ]
725
 
726
 
 
774
  + f".\n {'n' * 20}\nMetadata: "
775
  + metadata
776
  )
777
+ metadata_aggregation.update(
778
+ metadata_extraction
779
+ ) # !!! in-place function, return None
780
  # store embeddings and metadata
781
  embedding_dict[obj_path]["metadata"] = metadata
782
  embedding_dict[obj_path]["metadata_dictionary"] = normalize_metadata(
783
+ metadata_aggregation
784
  )
785
  embedding_dict[obj_path]["description"] = embeddings["description"]
786
  embedding_dict[obj_path]["image_embedding"] = embeddings["image_embedding"]
 
889
  # query button
890
  query_button.click(
891
  query_3D_object,
892
+ [query_input, model_render, embedding_store],
893
  [
894
  model_q_1,
895
  model_q_2,
 
954
  )
955
 
956
if __name__ == "__main__":
    # Decide whether to expose a public Gradio share link based on ENVIRONMENT.
    _env = os.environ.get("ENVIRONMENT", "dev")
    # NOTE(review): both "dev" and "prod" enable sharing, so the flag is only
    # off for unrecognized ENVIRONMENT values — confirm that is intended.
    # (`True if cond else False` simplified to the boolean expression itself.)
    demo.launch(share=_env in ("dev", "prod"))
requirements.txt CHANGED
@@ -16,4 +16,5 @@ numpy>=1.26.4,<2.0.0
16
  openai
17
  python-dotenv
18
  opencv-python
19
- Pillow
 
 
16
  openai
17
  python-dotenv
18
  opencv-python
19
+ Pillow
20
+ openai-agents