subashpoudel commited on
Commit
c398d2b
·
1 Parent(s): 34b6a10

retrieval updated in ideation

Browse files
src/genai/ideation_agent/utils/tools.py CHANGED
@@ -7,7 +7,7 @@ import ast
7
  import faiss
8
  import tiktoken
9
  from src.genai.utils.models_loader import embedding_model
10
- from src.genai.utils.load_embeddings import caption_embeddings , caption_index , caption_df
11
  from src.genai.utils.utils import clean_text
12
 
13
  @tool("influencers_data_retrieval_tool", args_schema=QueryFormatter, return_direct=False,description="Retrieve influencer-related data for a given query.")
@@ -55,3 +55,30 @@ def retrieve_tool(business_details):
55
  return encoding.decode(trimmed_response)
56
 
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import faiss
8
  import tiktoken
9
  from src.genai.utils.models_loader import embedding_model
10
+ from src.genai.utils.load_embeddings import caption_index , caption_df, ideas_index , ideas_df
11
  from src.genai.utils.utils import clean_text
12
 
13
  @tool("influencers_data_retrieval_tool", args_schema=QueryFormatter, return_direct=False,description="Retrieve influencer-related data for a given query.")
 
55
  return encoding.decode(trimmed_response)
56
 
57
 
58
+ @tool("imdb_movies_ideas_retrieval_tool", args_schema=QueryFormatter, return_direct=False,description="Retrieve imdb movies-related idea for a given query.")
59
+ def retrieve_tool(business_details):
60
+ '''
61
+ Always invoke this tool.
62
+ Retrieve the ideas of imdb_movies by semantic search of **business details**.
63
+ '''
64
+ query_embedding = np.array(embedding_model.embed_query(str(business_details))).reshape(1, -1).astype('float32')
65
+ faiss.normalize_L2(query_embedding)
66
+
67
+ top_k = 5
68
+ distances, indices = ideas_index.search(query_embedding, top_k)
69
+
70
+ outer_list = []
71
+ for rank, (idx, sim) in enumerate(indices[0], 1):
72
+ row = ideas_df.iloc[idx]
73
+ res = {
74
+ 'rank': rank,
75
+ 'idea': row['idea'],
76
+ }
77
+
78
+ inner_list = [
79
+ f"[{res['rank']}]. The retrieved idea is: **{res['idea']}\n**",
80
+ ]
81
+ outer_list.append(inner_list)
82
+
83
+ cleaned_response = clean_text(str(outer_list))
84
+ return str(cleaned_response)