subashpoudel's picture
next commit
076ac50
raw
history blame
6.54 kB
import requests
from langchain_core.messages import SystemMessage , HumanMessage , FunctionMessage
from .state import State
from .schemas import ResponseFormatter , CompareBodyFormatter, LatestMessageFormatter, ParameterFormatter, EndpointFormatter
from .prompts import chatbot_prompt , get_body_prompt , fetch_last_message_prompt , fetch_parameters_prompt, fetch_endpoint_prompt
from .utils import generate_api_knowledge , process_query, get_endpoint_info
from src.genai.utils.models_loader import llm_gpt
import numpy as np
from src.genai.utils.data_loader import api_knowledge_df, api_index
from src.genai.utils.models_loader import embedding_model
class FetchLastMessage:
    """Extracts and normalizes the user's most recent message from chat history.

    Uses the LLM with structured output (LatestMessageFormatter) to pull the
    latest user request out of the running conversation, then normalizes it
    via process_query before writing it into graph state.
    """
    def __init__(self):
        # Shared GPT model instance used for structured extraction.
        self.llm = llm_gpt
    def run (self, state:State):
        print('Message:',state['messages'])
        # Fix: trim the history BEFORE building the prompt. The original
        # trimmed after the LLM call, so the cap never reduced the tokens
        # actually sent to the model.
        # NOTE(review): mutating state['messages'] in place may not persist
        # into graph state unless the framework shares this dict — confirm
        # against the graph's message reducer; return a 'messages' key if a
        # persistent trim is intended.
        if len(state['messages'])>11:
            state["messages"] = state["messages"][-9:]
        template = fetch_last_message_prompt
        messages=[SystemMessage(content=template)]+state['messages']
        result = self.llm.with_structured_output(LatestMessageFormatter, method='function_calling').invoke(messages)
        latest = process_query(result.latest_message)
        print('Latest Message:', latest)
        return {
            'latest_message': latest
        }
class RetrievePossibleEndpoints:
    """Retrieves candidate API endpoints for the latest message via vector search.

    Embeds the user's latest message and queries the FAISS index over the
    endpoint knowledge table for the 5 nearest endpoints.
    """
    def __init__(self):
        self.df = api_knowledge_df   # endpoint knowledge table, one row per endpoint
        self.index = api_index       # FAISS index over endpoint embeddings
        self.results = []            # last search results; kept for backward compat
    def run(self,state:State):
        """Return the 5 nearest endpoints as {'possible_endpoints': [...]}.

        Bug fix: the original appended into the shared self.results instance
        attribute, so every subsequent run() accumulated stale endpoints from
        earlier queries. A fresh list is now built per call; self.results is
        reassigned so any external reader still sees the latest results.
        """
        query_embedding = np.array(embedding_model.embed_query(state['latest_message'])).reshape(1, -1).astype('float32')
        distances, indices = self.index.search(query_embedding, 5)
        results = []
        for idx in indices[0]:
            row = self.df.iloc[idx]
            print('Endpoint:',row['endpoint'])
            results.append(row['endpoint'])
        self.results = results  # keep attribute in sync for compatibility
        print('The possible endpoints are:', results)
        return {
            "possible_endpoints": results,
        }
class RetrieveExactEndpoint:
    """Picks the single best endpoint from the retrieved candidates.

    Feeds the candidate list and the user query to the LLM with structured
    output (EndpointFormatter), then looks up the chosen endpoint's method
    and required parameters.
    """
    def __init__(self):
        # Shared GPT model used for structured endpoint selection.
        self.llm = llm_gpt
    def run(self,state:State):
        candidates_msg = FunctionMessage(
            name='possible_endpoints',
            content=f'''The possible endpoints are: {state['possible_endpoints']}''',
        )
        query_msg = HumanMessage(content=f'''The user query is: {state['latest_message']}''')
        prompt = [SystemMessage(content=fetch_endpoint_prompt), candidates_msg, query_msg]
        selector = self.llm.with_structured_output(EndpointFormatter, method='function_calling')
        choice = selector.invoke(prompt)
        print('The exact endpoint is:', choice.endpoint)
        # Resolve HTTP method and required parameters for the chosen endpoint.
        info = get_endpoint_info(choice.endpoint)
        print('The endpoint info is:', info)
        return {
            "messages": [{"role": "assistant", "content": f'''The endpoint is: {choice.endpoint}'''}],
            "endpoint": choice.endpoint,
            "method": info['method'],
            "needed_parameters": info["parameters"],
        }
class FetchParametersNode:
    """Extracts parameter values for a GET endpoint from the user's query.

    For GET requests, asks the LLM (structured output via ParameterFormatter)
    to fill in the endpoint's needed parameters; otherwise returns an empty
    parameter map.
    """
    def __init__(self):
        # Shared GPT model used for structured parameter extraction.
        self.llm = llm_gpt
    def run(self , state:State):
        print('Entered to fetch parameters')
        print(state['method'])
        # Guard clause: non-GET requests build their payload elsewhere.
        if state['method'] != 'GET':
            return {
                'parameters_values': {}
            }
        print('Condition satisfied')
        prompt_messages = [
            SystemMessage(content=fetch_parameters_prompt),
            HumanMessage(content=f'''The query is: {state['latest_message']}\n. The needed parameters: {str(state['needed_parameters'])}'''),
        ]
        print('messages:', prompt_messages)
        extraction = self.llm.with_structured_output(ParameterFormatter, method='function_calling').invoke(prompt_messages)
        # Normalize every extracted value through process_query.
        cleaned = {name: process_query(raw) for name, raw in extraction.parameters_values.items()}
        print('The parameter values:', cleaned)
        return {
            'parameters_values': cleaned
        }
class ChatbotNode:
    """General chatbot node: answers using a freshly fetched API knowledge base.

    Builds a prompt from the chatbot system template plus the live knowledge
    base, runs the LLM with structured output (ResponseFormatter), and writes
    the chosen endpoint/method/parameters into state.
    """
    def __init__(self):
        # Shared GPT model instance.
        self.llm = llm_gpt
    def run(self, state:State):
        print('Message:',state['messages'])
        # Fix: trim the history BEFORE constructing the prompt. The original
        # trimmed after `messages` was built, so the cap never reduced what
        # was actually sent to the LLM.
        if len(state['messages'])>11:
            state["messages"] = state["messages"][-9:]
        template = chatbot_prompt()
        # Knowledge base is regenerated on every call — hits the live API.
        knowledge_base = generate_api_knowledge('https://reveltrends.vercel.app')
        print('The knowledge base is:', knowledge_base)
        messages = [SystemMessage(content=template),
                    FunctionMessage(name='analytics_chatbot',content=str(knowledge_base)),
                    ] + state["messages"]
        print('Messages:', state['messages'])
        print(len(state['messages']))
        result = self.llm.with_structured_output(ResponseFormatter, method='function_calling').invoke(messages)
        print('The result is:',result)
        return {
            "messages": [{"role": "assistant", "content": f'''The endpoint is: {result.endpoint}. The parameters are: {result.parameters}'''}],
            "endpoint": result.endpoint,
            "method": result.method,
            "parameters": result.parameters,
        }
class FetchDataNode:
    """Executes the resolved API call and returns the JSON response.

    GET requests are sent with the extracted parameter values; the special
    '/api/v1/compare/' endpoint builds a POST body via the LLM first.
    """
    # Hard cap on each HTTP request so a dead server cannot hang the graph.
    REQUEST_TIMEOUT = 30
    def __init__(self):
        self.llm = llm_gpt
        self.base_url = 'https://reveltrends.vercel.app'
        # NOTE(review): placeholder credential — load a real key from the
        # environment rather than hard-coding it in source.
        self.headers = {
            "Authorization": "Bearer YOUR_API_KEY",  # replace with your API key if needed
            "Content-Type": "application/json"
        }
    def run(self, state:State):
        print('Entered to fetch data')
        url = f'''{self.base_url}{state['endpoint']}'''
        if state['method'] == 'GET':
            response = requests.get(url, params=state['parameters_values'],
                                    headers=self.headers, timeout=self.REQUEST_TIMEOUT)
        elif state['endpoint'] == '/api/v1/compare/':
            print('Condition satisfied')
            messages = [SystemMessage(content=get_body_prompt()),
                        HumanMessage(content=str(state['messages']))]
            body = self.llm.with_structured_output(CompareBodyFormatter , method='function_calling').invoke(messages)
            print('INF names response:', body)
            payload = {
                "usernames": list(map(process_query, body.names)),
                "freq": body.frequency
            }
            print('The payload is:',payload)
            headers = {
                "Content-Type": "application/json"
            }
            response = requests.post(url, json=payload, headers=headers,
                                     timeout=self.REQUEST_TIMEOUT)
        else:
            # Fix: the original fell through with `response` unbound here and
            # crashed with NameError at response.json(); return an explicit
            # error payload instead.
            return {'response': {
                'error': f'''Unsupported request: method={state['method']} endpoint={state['endpoint']}'''
            }}
        return {'response':response.json()}