subashpoudel committed on
Commit
3e87e76
·
1 Parent(s): 53ffc0f

Enhanced retrieval tool

Browse files
Files changed (1) hide show
  1. my_agent/utils/tools.py +58 -30
my_agent/utils/tools.py CHANGED
@@ -11,46 +11,74 @@ from sklearn.metrics.pairwise import cosine_similarity
11
  import numpy as np
12
  from langchain_core.messages import SystemMessage
13
  import re
14
-
 
 
 
15
 
16
  os.environ['GROQ_API_KEY']=os.getenv('GROQ_API_KEY')
17
 
 
 
 
 
 
 
 
 
 
18
 
 
 
 
19
 
20
class StoryFormatter(BaseModel):
    """Always use this tool to structure your response to the user."""

    # Scene introduction: tone, what happens, key visuals and actions.
    story: str = Field(
        description="How to introduce the scene and set the tone. What is happening in the scene? Describe key visuals and actions",
    )
    # Narration / voiceover suggestions that accompany the visuals.
    narration: str = Field(
        description="Suggestions for narration or voiceover that complements the visuals.",
    )
    # Text overlays proposed for key moments (field name keeps its original casing).
    text_in_the_Video: str = Field(
        description="Propose important text overlays for key moments.",
    )
    # How one scene flows into the next.
    transitions: str = Field(
        description="Smooth transitions between scenes to maintain flow.",
    )
    # Mood and energy of the scenes.
    emotional_tone: str = Field(
        description="The mood and energy of the scenes (e.g., excitement, calm, tension, joy",
    )
    # Props, locations, sound effects or background music.
    key_visuals: str = Field(
        description="Important props, locations, sound effects, or background music to enhance the video.",
    )
28
-
29
-
30
# Structured output holding exactly four brainstorming topics for a story.
# NOTE: the "Brainstrom" spelling is the class's public name and is kept as-is.
class BrainstromTopicFormatter(BaseModel):
    topic1: str = Field(
        description="First brainstorming topic of the story",
    )
    topic2: str = Field(
        description="Second brainstorming topic of the story",
    )
    topic3: str = Field(
        description="Third brainstorming topic of the story",
    )
    topic4: str = Field(
        description="Fourth brainstorming topic of the story",
    )
35
 
36
# Argument schema with two string fields: the user query and the user's business details.
class QueryFormatter(BaseModel):
    # The user's query text.
    messages: str = Field(
        description="The user query",
    )
    # Free-text details about the user's business.
    business_details: str = Field(
        description="The details of the business of that user.",
    )
39
 
40
@tool(
    "influencer's data-retrieval-tool",
    args_schema=QueryFormatter,
    return_direct=False,
    description="Retrieve influencer-related data for a given query.",
)
def retrieve_tool(messages, business_details):
    '''Always invoke this tool once.'''
    print('The query for retrieval is:', messages)

    # The user message and the business details are concatenated and embedded
    # together as a single query string.
    query_text = str(messages) + str(business_details)
    query_vector = ST.encode(query_text)

    # Look up the two nearest stored examples on the "embeddings" column.
    dataset = load_influencer_data()
    _scores, nearest = dataset.get_nearest_examples("embeddings", query_vector, k=2)

    # Build one single-entry {username: agentic_story} dict per retrieved example.
    result = []
    for username, story in zip(nearest['username'], nearest['agentic_story']):
        result.append({username: story})
    print('Tool response:', result)

    return result
 
 
 
 
 
 
 
 
 
 
54
 
 
 
 
 
 
55
 
 
56
 
 
11
  import numpy as np
12
  from langchain_core.messages import SystemMessage
13
  import re
14
+ import faiss
15
+ import ast
16
+ import pandas as pd
17
+ from .validators import QueryFormatter
18
 
19
  os.environ['GROQ_API_KEY']=os.getenv('GROQ_API_KEY')
20
 
21
def _extract_story_and_branding(full_story):
    """Return sections 1 (Story) and 6 (Visible Texts or Brandings) of a story.

    Args:
        full_story: Markdown-style text with numbered bold headings such as
            ``**1. Story:**`` and ``**6. Visible Texts or Brandings:**``.

    Returns:
        A string with both sections, or ``"Requested sections not found."``
        when the input is not a string or the headings cannot be matched.
    """
    # An empty CSV cell is read by pandas as NaN (a float), which has no
    # ``.replace``; treat any non-string value as "nothing to extract".
    if not isinstance(full_story, str):
        return "Requested sections not found."

    # Normalise headings written without the trailing colon so that a single
    # pattern covers both spellings.
    full_story = full_story.replace('**6. Visible Texts or Brandings**', '**6. Visible Texts or Brandings:**')
    full_story = full_story.replace('**1. Story**', '**1. Story:**')

    pattern = (
        r"\*\*1\. Story:\*\*(.*?)(?=\*\*\d+\.\s)"
        r".*?"
        r"\*\*6\. Visible Texts or Brandings:\*\*(.*?)(?=\*\*\d+\.\s|$)"
    )
    match = re.search(pattern, full_story, re.DOTALL)
    if not match:
        return "Requested sections not found."

    story_section = match.group(1).strip()
    branding_section = match.group(2).strip()
    return f"Story:\n{story_section}\n\nVisible Texts or Brandings:\n{branding_section}"


# NOTE(review): the tool name contains an apostrophe and spaces; some
# tool-calling APIs restrict names to [a-zA-Z0-9_-] -- confirm with the provider.
@tool("influencer's data-retrieval-tool", args_schema=QueryFormatter, return_direct=False, description="Retrieve influencer-related data for a given query.")
def retrieve_tool(messages, business_details):
    '''
    Always invoke this tool.
    Retrieve influencer's data by semantic search of **user messages** and the **business details**.

    Args:
        messages: The user query; converted with ``str()`` before embedding.
        business_details: Details of the user's business; converted with
            ``str()`` and appended to ``messages`` to form the search text.

    Returns:
        ``str()`` of a list with one inner list per hit (at most three). Each
        inner list holds three formatted strings: influencer name with like and
        comment counts, the extracted story/branding sections, and the distance
        reported by the FAISS search. Returns ``"[]"`` when the CSV has no rows.
    '''
    # NOTE(review): the CSV is re-read and re-indexed on every call; consider
    # caching if 'extracted_data.csv' is static at runtime.
    csv_path = 'extracted_data.csv'
    df = pd.read_csv(csv_path)
    if df.empty:
        # np.vstack raises on an empty sequence; there is nothing to search.
        return str([])

    # Embeddings are stored in the CSV as the text form of a Python list.
    df['embeddings'] = df['embeddings'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
    embeddings = np.vstack(df['embeddings'].values).astype('float32')

    # Exact (brute-force) L2 index over all stored embeddings.
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)

    query_embedding = ST.encode(str(messages)+str(business_details)).reshape(1, -1).astype('float32')

    # Never ask for more neighbours than there are rows: FAISS pads missing
    # results with label -1, and df.iloc[-1] would silently return the last row.
    top_k = min(3, index.ntotal)
    distances, indices = index.search(query_embedding, top_k)

    outer_list = []
    for rank, (idx, distance) in enumerate(zip(indices[0], distances[0]), start=1):
        if idx < 0:
            # Defensive: skip padding labels instead of indexing from the end.
            continue
        row = df.iloc[idx]
        outer_list.append([
            f"[{rank}]. The influencer name is: **{row['username']}** — Likes: **{row['likesCount']}**, Comments: **{row['commentCount']}**",
            f"The story of that particular video is:\n{_extract_story_and_branding(row['agentic_story'])}",
            f"Distance: {distance:.4f}",
        ])

    return str(outer_list)
84