subashpoudel commited on
Commit
1ce8b88
·
1 Parent(s): 8d2224f

Updated the tools

Browse files
Files changed (1) hide show
  1. my_agent/utils/tools.py +16 -15
my_agent/utils/tools.py CHANGED
@@ -7,6 +7,10 @@ import numpy as np
7
  from langchain_core.tools import tool
8
  from .data_loader import load_influencer_data
9
  from .models_loader import ST , llm
 
 
 
 
10
 
11
 
12
  os.environ['GROQ_API_KEY']=os.getenv('GROQ_API_KEY')
@@ -30,26 +34,23 @@ class BrainstromTopicFormatter(BaseModel):
30
  topic4:str=Field(description="Fourth brainstorming topic of the story")
31
 
32
  class QueryFormatter(BaseModel):
33
- idea:str = Field(description="Any idea or query about the business.")
34
  business_details: str = Field(description="The details of the business of that user.")
35
 
36
  @tool("influencer's data-retrieval-tool", args_schema=QueryFormatter, return_direct=False,description="Retrieve influencer-related data for a given query.")
37
- def retrieve_tool(idea, business_details):
38
-
39
- # """This tool is responsible for the retrieval of the influencer's data using semantic search by reading any **idea or query about the business** and the **business details of the user.**
40
- # But remember, the idea have to be valid first. Don't retrieve anything if the idea is invalid or it is like General Question Answering or follow up questions.
41
- # If you find the idea as invalid, write the value as "None" in the idea so that i can process it."""
42
-
43
- """This tool is responsible for the retrieval of the influencer's data using semantic search by reading any **idea or query about the business** and the **business details of the user.**
44
- ."""
45
-
46
-
47
- embedded_query = ST.encode(str(idea)+str(business_details)) # Embed each topic
48
  data = load_influencer_data()
49
- scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=3)
50
 
51
  # Construct a list of dictionaries for this topic
52
  result = [{user: story} for user, story in zip(retrieved_examples['username'], retrieved_examples['agentic_story'])]
53
- # result = [{u: {"story": s, "likes": l, "comments": c}} for u, s, l, c in zip(retrieved_examples['username'], retrieved_examples['agentic_story'], retrieved_examples['likes'], retrieved_examples['comments'])]
54
- print('The tool response:',result)
 
55
  return result
 
 
 
 
7
  from langchain_core.tools import tool
8
  from .data_loader import load_influencer_data
9
  from .models_loader import ST , llm
10
+ from sklearn.metrics.pairwise import cosine_similarity
11
+ import numpy as np
12
+ from langchain_core.messages import SystemMessage
13
+ import re
14
 
15
 
16
  os.environ['GROQ_API_KEY']=os.getenv('GROQ_API_KEY')
 
34
  topic4:str=Field(description="Fourth brainstorming topic of the story")
35
 
36
  class QueryFormatter(BaseModel):
37
+ messages:str = Field(description="The user query")
38
  business_details: str = Field(description="The details of the business of that user.")
39
 
40
  @tool("influencer's data-retrieval-tool", args_schema=QueryFormatter, return_direct=False,description="Retrieve influencer-related data for a given query.")
41
+ def retrieve_tool(messages, business_details):
42
+ '''Always invoke this tool once.'''
43
+ print('The query for retrieval is:',messages)
44
+ embedded_query = ST.encode(str(messages)+str(business_details)) # Embed each topic
 
 
 
 
 
 
 
45
  data = load_influencer_data()
46
+ scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=2)
47
 
48
  # Construct a list of dictionaries for this topic
49
  result = [{user: story} for user, story in zip(retrieved_examples['username'], retrieved_examples['agentic_story'])]
50
+ print('Tool response:',result)
51
+
52
+
53
  return result
54
+
55
+
56
+