subashpoudel commited on
Commit
ca75c57
·
1 Parent(s): e842da1

Added the retrieval tool node

Browse files
Files changed (1) hide show
  1. my_agent/utils/tools.py +21 -9
my_agent/utils/tools.py CHANGED
@@ -4,17 +4,13 @@ from dotenv import load_dotenv
4
  load_dotenv()
5
  import os
6
  import numpy as np
 
 
 
7
 
8
- os.environ['GROQ_API_KEY']=os.getenv('GROQ_API_KEY')
9
 
10
- llm = ChatGroq(
11
- model="llama3-8b-8192",
12
- temperature=0,
13
- max_tokens=None,
14
- timeout=None,
15
- max_retries=2,
16
 
17
- )
18
 
19
  class StoryFormatter(BaseModel):
20
  """Always use this tool to structure your response to the user."""
@@ -30,4 +26,20 @@ class BrainstromTopicFormatter(BaseModel):
30
  topic1:str=Field(description="First brainstorming topic of the story")
31
  topic2:str=Field(description="Second brainstorming topic of the story")
32
  topic3:str=Field(description="Third brainstorming topic of the story")
33
- topic4:str=Field(description="Fourth brainstorming topic of the story")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  load_dotenv()
5
  import os
6
  import numpy as np
7
+ from langchain_core.tools import tool
8
+ from .data_loader import load_influencer_data
9
+ from .models_loader import ST
10
 
 
11
 
12
+ os.environ['GROQ_API_KEY']=os.getenv('GROQ_API_KEY')
 
 
 
 
 
13
 
 
14
 
15
  class StoryFormatter(BaseModel):
16
  """Always use this tool to structure your response to the user."""
 
26
  topic1:str=Field(description="First brainstorming topic of the story")
27
  topic2:str=Field(description="Second brainstorming topic of the story")
28
  topic3:str=Field(description="Third brainstorming topic of the story")
29
+ topic4:str=Field(description="Fourth brainstorming topic of the story")
30
+
31
+ class QueryFormatter(BaseModel):
32
+ idea:str = Field(description="The video idea which the user wants to create.")
33
+ business_details: str = Field(description="The details of the business of that user.")
34
+
35
+ @tool("influencer's data-retrieval-tool", args_schema=QueryFormatter, return_direct=False,description="Retrieve influencer-related data for a given query.")
36
+ def retrieve_tool(idea, business_details):
37
+
38
+ """This tool is responsible for the retrieval of the influencer's data using semantic search by reading the video idea and the business details of the user. """
39
+ embedded_query = ST.encode(str(idea)+str(business_details)) # Embed each topic
40
+ data = load_influencer_data()
41
+ scores, retrieved_examples = data.get_nearest_examples("embeddings", embedded_query, k=1)
42
+
43
+ # Construct a list of dictionaries for this topic
44
+ result = [{user: story} for user, story in zip(retrieved_examples['username'], retrieved_examples['agentic_story'])]
45
+ return result