subashpoudel committed on
Commit
563ce7c
·
1 Parent(s): 3e87e76

Added manual retrieval in utilities

Browse files
Files changed (1) hide show
  1. my_agent/utils/utils.py +99 -6
my_agent/utils/utils.py CHANGED
@@ -10,7 +10,13 @@ from huggingface_hub import InferenceClient
10
  from .prompts import story_to_prompt , final_story_prompt
11
  import os
12
  from langgraph.prebuilt import create_react_agent
13
-
 
 
 
 
 
 
14
 
15
 
16
 
@@ -78,9 +84,96 @@ def generate_image(final_story):
78
  image.save('image.png')
79
  print('*****************Image Saved*******************')
80
  return "Image Created"
81
- # try:
82
- # return image
83
- # except:
84
- # return 'Image created'
85
 
86
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  from .prompts import story_to_prompt , final_story_prompt
11
  import os
12
  from langgraph.prebuilt import create_react_agent
13
+ import pandas as pd
14
+ from datasets import load_dataset
15
+ import ast
16
+ import faiss
17
+ import re
18
+ import numpy as np
19
+ from .models_loader import ST
20
 
21
 
22
 
 
84
  image.save('image.png')
85
  print('*****************Image Saved*******************')
86
  return "Image Created"
 
 
 
 
87
 
88
+
89
def save_to_db(business_details):
    """Filter the TikTok story dataset by business details and save it as CSV.

    Loads the ``train`` split of ``subashdvorak/tiktok-agentic-story``, keeps
    every row in which any cell contains (case-insensitively) at least one
    search term taken from ``business_details``, and writes the matching rows
    to ``extracted_data.csv`` in the current working directory.

    Args:
        business_details: Mapping whose values are strings or lists. List
            items are converted with ``str``. Values of any other type are
            ignored, as are blank terms.

    Returns:
        None. The result is the ``extracted_data.csv`` file.
    """
    dataset = load_dataset("subashdvorak/tiktok-agentic-story")['train']
    df = pd.DataFrame(dataset)

    # 1. Flatten all business detail values into a set of lowercase terms.
    search_terms = set()
    for value in business_details.values():
        if isinstance(value, str):
            search_terms.add(value.lower())
        elif isinstance(value, list):
            search_terms.update(str(item).lower() for item in value)

    # A blank term is a substring of every cell, so keeping it would make the
    # filter match the entire dataset. Drop such terms.
    search_terms = {term for term in search_terms if term.strip()}

    # 2. A row matches when ANY cell contains ANY of the terms.
    def row_matches(row):
        for cell in row:
            # Lowercase each cell once instead of once per search term.
            cell_text = str(cell).lower()
            if any(term in cell_text for term in search_terms):
                return True
        return False

    # 3. Apply the row-wise filter. With no rows or no terms nothing can
    #    match, and DataFrame.apply(axis=1) on an empty frame does not yield
    #    a boolean Series, so short-circuit to an empty frame with the same
    #    columns.
    if df.empty or not search_terms:
        matched_df = df.iloc[0:0]
    else:
        matched_df = df[df.apply(row_matches, axis=1)]

    matched_df.to_csv('extracted_data.csv')
113
+
114
# Captures section 1 ("Story") up to the next numbered heading, then section 6
# ("Visible Texts or Brandings") up to the next numbered heading or the end.
_STORY_AND_BRANDING_RE = re.compile(
    r"\*\*1\. Story:\*\*(.*?)(?=\*\*\d+\.\s)"
    r".*?"
    r"\*\*6\. Visible Texts or Brandings:\*\*(.*?)(?=\*\*\d+\.\s|$)",
    re.DOTALL,
)


def _extract_story_and_branding(full_story):
    """Return only sections 1 and 6 of a numbered, markdown-style story."""
    # A missing story is read back from CSV as NaN (a float), which has no
    # ``replace`` method; report it the same way as an unparseable story.
    if not isinstance(full_story, str):
        return "Requested sections not found."

    # Normalise headings written without the trailing colon so that a single
    # pattern covers both spellings.
    full_story = full_story.replace('**6. Visible Texts or Brandings**', '**6. Visible Texts or Brandings:**')
    full_story = full_story.replace('**1. Story**', '**1. Story:**')

    match = _STORY_AND_BRANDING_RE.search(full_story)
    if not match:
        return "Requested sections not found."

    story_section = match.group(1).strip()
    branding_section = match.group(2).strip()
    return f"Story:\n{story_section}\n\nVisible Texts or Brandings:\n{branding_section}"


def manual_retrieval(messages, business_details, top_k=3):
    """Return the stored stories closest to the current conversation.

    Reads ``extracted_data.csv``, builds an exact L2 FAISS index over the
    ``embeddings`` column, encodes ``str(messages) + str(business_details)``
    with ``ST`` and formats the nearest rows.

    Args:
        messages: Conversation so far; only its ``str()`` form is used.
        business_details: Business description; only its ``str()`` form is
            used.
        top_k: Maximum number of neighbours to return. Defaults to 3 and is
            clamped to the number of rows available.

    Returns:
        ``str()`` of a list with one inner list per hit. Each inner list holds
        three strings: a headline (rank, username, likes, comments), the
        extracted story/branding sections, and the L2 distance. Returns
        ``'[]'`` when the CSV has no rows or ``top_k`` is not positive.

    Note:
        Assumes each ``embeddings`` cell is a Python-literal list of floats
        (parsed with ``ast.literal_eval``) whose length matches the output
        size of ``ST.encode`` -- TODO confirm against the dataset schema.
    """
    df = pd.read_csv('extracted_data.csv')

    # Never request more neighbours than there are rows: FAISS pads missing
    # results with index -1, which ``df.iloc`` would resolve to the LAST row.
    neighbour_count = min(int(top_k), len(df))
    if neighbour_count <= 0:
        return str([])

    # === Parse stored embeddings ===
    parsed = df['embeddings'].apply(
        lambda raw: ast.literal_eval(raw) if isinstance(raw, str) else raw
    )
    embeddings = np.vstack(parsed.values).astype('float32')

    # === Build FAISS index ===
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)

    # === Encode the query and search ===
    query_embedding = ST.encode(str(messages) + str(business_details)).reshape(1, -1).astype('float32')
    distances, indices = index.search(query_embedding, neighbour_count)

    # === Format results ===
    results = []
    for rank, (row_idx, distance) in enumerate(zip(indices[0], distances[0]), start=1):
        if row_idx < 0:
            # Defensive: -1 marks "no neighbour found" in FAISS output.
            continue
        row = df.iloc[row_idx]
        results.append([
            f"[{rank}]. The influencer name is: **{row['username']}** — Likes: **{row['likesCount']}**, Comments: **{row['commentCount']}**",
            f"The story of that particular video is:\n{_extract_story_and_branding(row['agentic_story'])}",
            f"Distance: {distance:.4f}",
        ])

    return str(results)
172
+
173
+
174
+
175
+
176
+
177
+
178
+
179
+