import requests
import numpy as np

from langchain_core.messages import SystemMessage, HumanMessage, FunctionMessage

from .state import State
from .tools import RetrieverBackup
from .schemas import ParameterFormatter, EndpointFormatter
from .prompts import query_check_prompt, fetch_last_message_prompt, fetch_parameters_prompt, fetch_endpoint_prompt
from .utils import process_query, get_endpoint_info
from src.genai.utils.models_loader import llm_gpt, embedding_model
from src.genai.utils.data_loader import api_knowledge_df, api_index, caption_df, caption_index
from ..handlers import (
    compare,
    get_posting_time,
    get_peak_comment_hour,
    get_emoji_count,
    get_comment_quality,
    get_bot_and_diversity,
)


class FetchLastMessage:
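    """LangGraph node: condenses the running conversation into the user's latest standalone query and trims overly long message histories."""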
    def __init__(self):
        self.llm = llm_gpt

    def run(self, state: State):
        print('Message:', state['messages'])
        template = fetch_last_message_prompt
        messages = [SystemMessage(content=template)] + state['messages']
        result = self.llm.invoke(messages)
        print('Latest Message:', process_query(result.content))
        # Keep the history bounded: once it grows past 11 messages, keep only the last 9.
        if len(state['messages']) > 11:
            state['messages'] = state['messages'][-9:]
        return {
            'latest_message': process_query(result.content)
        }


class RetrievePossibleEndpoints:
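    """LangGraph node: embeds the latest query and searches the API knowledge vector index for the top-10 candidate endpoints."""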
    def __init__(self):
        self.df = api_knowledge_df
        self.index = api_index

    def run(self, state: State):
        print('Gone to retrieve possible endpoints')
        # Build a fresh result list on every call so candidates do not accumulate across runs.
        results = []
        query_embedding = np.array(embedding_model.embed_query(state['latest_message'])).reshape(1, -1).astype('float32')
        distances, indices = self.index.search(query_embedding, 10)
        for idx in indices[0]:
            row = self.df.iloc[idx]
            print('Endpoint:', row['endpoint'])
            results.append(row['endpoint'])
        print('The possible endpoints are:', results)
        return {
            "possible_endpoints": results,
        }


class RetrieveExactEndpoint:
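    """LangGraph node: has the LLM pick the single best endpoint from the shortlisted candidates (structured output via EndpointFormatter), then looks up its HTTP method and required parameters."""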
    def __init__(self):
        self.llm = llm_gpt

    def run(self, state: State):
        print('Gone to retrieve exact endpoint')
        messages = [
            SystemMessage(content=fetch_endpoint_prompt),
            FunctionMessage(name='possible_endpoints', content=f'''The possible endpoints are: {state['possible_endpoints']}'''),
            HumanMessage(content=f'''The user query is: {state['latest_message']}'''),
        ]
        result = self.llm.with_structured_output(EndpointFormatter, method='function_calling').invoke(messages)
        print('The exact endpoint is:', result.endpoint)
        endpoint_info = get_endpoint_info(result.endpoint)
        print('The endpoint info is:', endpoint_info)
        return {
            "messages": [{"role": "assistant", "content": f'''The endpoint is: {result.endpoint}'''}],
            "endpoint": result.endpoint,
            "method": endpoint_info['method'],
            "needed_parameters": endpoint_info["parameters"]
        }


class QueryCheckNode:
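    """LangGraph node: classifies the latest query into a query type (e.g. single-influencer vs. aggregate) using query_check_prompt."""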
    def __init__(self):
        self.llm = llm_gpt

    def run(self, state: State):
        try:
            print('Entered query checking')
            messages = [
                SystemMessage(content=query_check_prompt),
                HumanMessage(content=f'''The user query is: {state['latest_message']}'''),
            ]
            result = self.llm.invoke(messages)
            print(result.content)
            return {'query_type': result.content}
        except Exception as e:
            print('Error occurred:', e)
            return {'error_message': str(e)}


class FetchParametersNode:
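    """LangGraph node: fills in parameter values for the chosen endpoint via structured LLM output (ParameterFormatter); complex endpoints are left to their dedicated handlers in FetchDataNode."""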
    def __init__(self):
        self.llm = llm_gpt
        self.complex_endpoints = [
            '/api/v1/compare/',
            '/api/v1/engagement/posting-time-analysis',
            '/api/v1/audience/peak-comment-hour',
            '/api/v1/audience/emoji-count',
            '/api/v1/audience/comment-quality',
        ]

    def run(self, state: State):
        try:
            print('Entered parameter fetching')
            if state['endpoint'] not in self.complex_endpoints:
                template = fetch_parameters_prompt
                messages = [
                    SystemMessage(content=template),
                    HumanMessage(content=f'''The query is: {state['latest_message']}\n. The needed parameters: {str(state['needed_parameters'])}'''),
                ]
                result = self.llm.with_structured_output(ParameterFormatter, method='function_calling').invoke(messages)
                # Normalise string values; leave non-string values (lists, numbers) untouched.
                parameters_values = {k: (process_query(v) if isinstance(v, str) else v) for k, v in result.parameters_values.items()}
                # if 'single_influencer_query' in state['query_type']:
                #     print('The parameter values:', parameters_values)
                #     return {'parameters_values': parameters_values}
                # elif 'aggregate_query' in state['query_type']:
                #     parameters_values['influencer_username'] = ['divyadhakal_', 'munachiya', 'mydarlingfood', '_its.me.muskan_']
                #     print('The parameter values:', parameters_values)
                #     return {'parameters_values': parameters_values}
                print('The parameter values:', parameters_values)
                return {'parameters_values': parameters_values}
            # Complex endpoints are handled by their dedicated handlers in FetchDataNode,
            # so no parameters are extracted here.
            return {}
        except Exception as e:
            print('Error occurred:', e)
            return {'error_message': str(e)}


class FetchDataNode:
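    """LangGraph node: calls the REST API for the chosen endpoint, either through a dedicated handler (complex endpoints) or with a plain GET request, and stores the JSON response."""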
    def __init__(self):
        self.llm = llm_gpt
        self.base_url = 'https://reveltrends.vercel.app'
        self.headers = {
            "Authorization": "Bearer YOUR_API_KEY",  # replace with your API key if needed
            "Content-Type": "application/json"
        }
        self.endpoint_handlers = {
            '/api/v1/compare/': compare,
            '/api/v1/engagement/posting-time-analysis': get_posting_time,
            '/api/v1/audience/peak-comment-hour': get_peak_comment_hour,
            '/api/v1/audience/emoji-count': get_emoji_count,
            '/api/v1/audience/comment-quality': get_comment_quality,
            '/api/v1/audience/bot-and-diversity': get_bot_and_diversity,
        }

    def run(self, state: State):
        try:
            # Aggregate queries are currently disabled: every request is treated as a single-influencer query.
            state['query_type'] = 'single_influencer_query'
            print('Entered data fetching')
            url = f'''{self.base_url}{state['endpoint']}'''
            if state['endpoint'] in self.endpoint_handlers:
                print('Entered handler.')
                handler = self.endpoint_handlers[state['endpoint']]
                response = handler(state, llm_gpt, url)
                print('Returned by handler.')
                return {'response': response.json()}
            elif 'single_influencer_query' in state['query_type']:
                response = requests.get(url, params=state['parameters_values'], headers=self.headers)
                print('Data from api:', response)
                return {'response': response.json()}
            # elif 'aggregate_query' in state['query_type']:
            #     print('Entered aggregate query execution')
            #     print(state['parameters_values'])
            #     params = state["parameters_values"]
            #     if "influencer_username" in params and isinstance(params["influencer_username"], list):
            #         results = {}
            #         # Iterate through each influencer username
            #         for username in params["influencer_username"]:
            #             current_params = params.copy()
            #             current_params["influencer_username"] = username
            #             response = requests.get(url, params=current_params, headers=self.headers)
            #             results[username] = response.json()  # Store influencer-wise response
            #             print('Data from api:', response)
            #         return {"response": results}
        except Exception as e:
            print('Error occurred:', e)
            return {'error_message': str(e), 'response': 'No response'}


class BackupRetrievalNode:
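    """LangGraph node: retrieves fallback data for the latest query via RetrieverBackup when the API path fails."""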
    def __init__(self):
        self.llm = llm_gpt

    def run(self, state: State):
        retrieval = RetrieverBackup().retrieve(state['latest_message'])
        return {'backup_data': retrieval}


class BackupRoutingNode:
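    """Conditional router: returns 'execute_backup' when an error has been recorded in the state, otherwise 'go_on'."""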
    def __init__(self):
        pass

    def run(self, state: State):
        if state.get('error_message') is not None:
            return 'execute_backup'
        else:
            return 'go_on'
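
# Illustrative wiring sketch (not part of this module): one way these nodes could be
# assembled into a LangGraph StateGraph. The node names, edge layout, and the omission
# of QueryCheckNode are assumptions for the example only, not the project's actual graph.
#
# from langgraph.graph import StateGraph, END
#
# builder = StateGraph(State)
# builder.add_node('fetch_last_message', FetchLastMessage().run)
# builder.add_node('retrieve_possible_endpoints', RetrievePossibleEndpoints().run)
# builder.add_node('retrieve_exact_endpoint', RetrieveExactEndpoint().run)
# builder.add_node('fetch_parameters', FetchParametersNode().run)
# builder.add_node('fetch_data', FetchDataNode().run)
# builder.add_node('backup_retrieval', BackupRetrievalNode().run)
#
# builder.set_entry_point('fetch_last_message')
# builder.add_edge('fetch_last_message', 'retrieve_possible_endpoints')
# builder.add_edge('retrieve_possible_endpoints', 'retrieve_exact_endpoint')
# builder.add_edge('retrieve_exact_endpoint', 'fetch_parameters')
# builder.add_edge('fetch_parameters', 'fetch_data')
# # BackupRoutingNode.run returns 'execute_backup' or 'go_on', so it fits a conditional edge.
# builder.add_conditional_edges('fetch_data', BackupRoutingNode().run,
#                               {'execute_backup': 'backup_retrieval', 'go_on': END})
# builder.add_edge('backup_retrieval', END)
# graph = builder.compile()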