# subashpoudel's picture
# next
# e2721b4
import requests
from langchain_core.messages import SystemMessage , HumanMessage , FunctionMessage
from .state import State
from .tools import RetrieverBackup
from .schemas import ParameterFormatter, EndpointFormatter
from .prompts import query_check_prompt, fetch_last_message_prompt , fetch_parameters_prompt, fetch_endpoint_prompt
from .utils import process_query, get_endpoint_info
from src.genai.utils.models_loader import llm_gpt
import numpy as np
from src.genai.utils.data_loader import api_knowledge_df, api_index, caption_df , caption_index
from src.genai.utils.models_loader import embedding_model
from ..handlers import (
compare,
get_posting_time,
get_peak_comment_hour,
get_emoji_count,
get_comment_quality,
get_bot_and_diversity,
)
class FetchLastMessage:
    """Rewrites the conversation into a single standalone query for downstream nodes."""

    def __init__(self):
        # Shared GPT model used to condense the chat history.
        self.llm = llm_gpt

    def run(self, state: State):
        print('Message:', state['messages'])
        conversation = [SystemMessage(content=fetch_last_message_prompt)]
        conversation.extend(state['messages'])
        reply = self.llm.invoke(conversation)
        print('Latest Message:', process_query(reply.content))
        # Keep the history bounded so the context window does not grow forever.
        # NOTE(review): this mutates state in place instead of returning the
        # trimmed list; whether the trim persists depends on the graph runtime.
        if len(state['messages']) > 11:
            state["messages"] = state["messages"][-9:]
        return {
            'latest_message': process_query(reply.content)
        }
class RetrievePossibleEndpoints:
    """Vector-searches the API knowledge base for endpoints relevant to the query."""

    def __init__(self):
        self.df = api_knowledge_df   # endpoint metadata, row-aligned with the FAISS index
        self.index = api_index       # FAISS index over endpoint embeddings
        self.results = []            # last search's endpoints (kept for backward compat)

    def run(self, state: State):
        """Return the top-10 candidate endpoints for state['latest_message'].

        Bug fix: results were appended to the instance attribute across calls,
        so a reused node returned stale endpoints from earlier queries mixed
        with the new ones. The list is now rebuilt on every invocation.
        """
        print('Gone to retrieve possible endpoints')
        query_embedding = np.array(
            embedding_model.embed_query(state['latest_message'])
        ).reshape(1, -1).astype('float32')
        distances, indices = self.index.search(query_embedding, 10)
        results = []
        for idx in indices[0]:
            row = self.df.iloc[idx]
            print('Endpoint:', row['endpoint'])
            results.append(row['endpoint'])
        self.results = results  # refresh the cached copy instead of accumulating
        print('The possible endpoints are:', results)
        return {
            "possible_endpoints": results,
        }
class RetrieveExactEndpoint:
    """Asks the LLM to pick the single best endpoint from the retrieved candidates."""

    def __init__(self):
        self.llm = llm_gpt

    def run(self, state: State):
        print('Gone to retrieve exact endpoint')
        prompt = [
            SystemMessage(content=fetch_endpoint_prompt),
            FunctionMessage(
                name='possible_endpoints',
                content=f'''The possible endpoints are: {state['possible_endpoints']}''',
            ),
            HumanMessage(content=f'''The user query is: {state['latest_message']}'''),
        ]
        structured_llm = self.llm.with_structured_output(EndpointFormatter, method='function_calling')
        result = structured_llm.invoke(prompt)
        print('The exact endpoint is:', result.endpoint)
        # Look up the HTTP method and required parameters for the chosen endpoint.
        endpoint_info = get_endpoint_info(result.endpoint)
        print('The endpoint info is:', endpoint_info)
        return {
            "messages": [{"role": "assistant", "content": f'''The endpoint is: {result.endpoint}'''}],
            "endpoint": result.endpoint,
            "method": endpoint_info['method'],
            "needed_parameters": endpoint_info["parameters"],
        }
class QueryCheckNode:
    """Classifies the user's query type via the LLM; records failures in state."""

    def __init__(self):
        self.llm = llm_gpt

    def run(self, state: State):
        try:
            print('Entered to query checking')
            conversation = [
                SystemMessage(content=query_check_prompt),
                HumanMessage(content=f'''The user query is: {state['latest_message']}'''),
            ]
            verdict = self.llm.invoke(conversation)
            print(verdict.content)
            return {'query_type': verdict.content}
        except Exception as exc:
            # Store the failure so the routing node can divert to the backup path.
            print('Error occoured:', exc)
            return {'error_message': str(exc)}
class FetchParametersNode:
    """Extracts the chosen endpoint's parameter values from the user query via the LLM."""

    def __init__(self):
        self.llm = llm_gpt
        # Endpoints whose requests are built by dedicated handlers in
        # FetchDataNode, so no generic parameter extraction happens here.
        self.complex_endpoints = [
            '/api/v1/compare/',
            '/api/v1/engagement/posting-time-analysis',
            '/api/v1/audience/peak-comment-hour',
            '/api/v1/audience/emoji-count',
            '/api/v1/audience/comment-quality',
        ]

    def run(self, state: State):
        try:
            print('Entered to fetch parameters')
            if state['endpoint'] in self.complex_endpoints:
                # Handler-backed endpoints build their own request parameters.
                # Previously this path could reach an unbound `parameters_values`
                # (NameError -> spurious error_message), wrongly diverting these
                # queries to the backup retriever. Return an empty update instead.
                return {}
            messages = [
                SystemMessage(content=fetch_parameters_prompt),
                HumanMessage(content=f'''The query is: {state['latest_message']}\n. The needed parameters: {str(state['needed_parameters'])}'''),
            ]
            result = self.llm.with_structured_output(ParameterFormatter, method='function_calling').invoke(messages)
            # Normalise string values; non-strings (numbers, lists) pass through.
            parameters_values = {
                k: (process_query(v) if isinstance(v, str) else v)
                for k, v in result.parameters_values.items()
            }
            print('The parameter values:', parameters_values)
            return {'parameters_values': parameters_values}
        except Exception as e:
            print('Error occoured:', e)
            return {'error_message': str(e)}
class FetchDataNode:
    """Calls the RevelTrends API — via a bespoke handler when one exists, otherwise a plain GET."""

    def __init__(self):
        self.llm = llm_gpt
        self.base_url = 'https://reveltrends.vercel.app'
        self.headers = {
            "Authorization": "Bearer YOUR_API_KEY",  # replace with your API key if needed
            "Content-Type": "application/json"
        }
        # Endpoints whose requests need custom construction logic.
        self.endpoint_handlers = {
            '/api/v1/compare/': compare,
            '/api/v1/engagement/posting-time-analysis': get_posting_time,
            '/api/v1/audience/peak-comment-hour': get_peak_comment_hour,
            '/api/v1/audience/emoji-count': get_emoji_count,
            '/api/v1/audience/comment-quality': get_comment_quality,
            '/api/v1/audience/bot-and-diversity': get_bot_and_diversity
        }

    def run(self, state: State):
        try:
            # HACK: force the single-influencer path; aggregate handling is
            # currently disabled. TODO(review): restore QueryCheckNode's verdict
            # once aggregate queries are supported again.
            state['query_type'] = 'single_influencer_query'
            print('Entered to fetch data')
            url = f'''{self.base_url}{state['endpoint']}'''
            handler = self.endpoint_handlers.get(state['endpoint'])
            if handler is not None:
                print('Entered to handler.')
                response = handler(state, llm_gpt, url)
                print('Returned by handler.')
                return {'response': response.json()}
            if 'single_influencer_query' in state['query_type']:
                response = requests.get(url, params=state['parameters_values'], headers=self.headers)
                print('Data from api:', response)
                return {'response': response.json()}
            # Previously this fell off the end and returned None for any other
            # query type; surface an explicit error so routing stays well-defined.
            return {'error_message': f'''Unsupported query type: {state['query_type']}''', 'response': 'No response'}
        except Exception as e:
            print('Error occoured:', e)
            return {'error_message': str(e), 'response': 'No response'}
class BackupRetrievalNode:
    """Fallback node: answers from the local retriever when the API path failed."""

    def __init__(self):
        self.llm = llm_gpt

    def run(self, state: State):
        backup = RetrieverBackup()
        documents = backup.retrieve(state['latest_message'])
        return {'backup_data': documents}
class BackupRoutingNode:
    """Conditional edge: picks the backup path whenever an error was recorded."""

    def __init__(self):
        pass

    def run(self, state: State):
        # Any recorded error diverts the flow to the backup retrieval path.
        has_error = state.get('error_message') is not None
        return 'execute_backup' if has_error else 'go_on'