import requests
from langchain_core.messages import SystemMessage , HumanMessage , FunctionMessage
from .state import State
from .schemas import ResponseFormatter , CompareBodyFormatter, LatestMessageFormatter, ParameterFormatter, EndpointFormatter
from .prompts import chatbot_prompt , get_body_prompt , fetch_last_message_prompt , fetch_parameters_prompt, fetch_endpoint_prompt
from .utils import generate_api_knowledge , process_query, get_endpoint_info
from src.genai.utils.models_loader import llm_gpt
import numpy as np
from src.genai.utils.data_loader import api_knowledge_df, api_index
from src.genai.utils.models_loader import embedding_model
class FetchLastMessage:
    """Extract the user's most recent message from the chat history via the LLM."""

    def __init__(self):
        # Shared GPT handle used for structured extraction.
        self.llm = llm_gpt

    def run(self, state: State):
        """Return {'latest_message': <cleaned latest user message>}.

        Also caps the in-state history at the 9 most recent messages once it
        grows past 11 entries.
        """
        print('Message:', state['messages'])
        prompt_messages = [SystemMessage(content=fetch_last_message_prompt)] + state['messages']
        extractor = self.llm.with_structured_output(LatestMessageFormatter, method='function_calling')
        extraction = extractor.invoke(prompt_messages)
        print('Latest Message:', process_query(extraction.latest_message))
        # Trim retained history after extraction so older turns don't grow unbounded.
        if len(state['messages']) > 11:
            state["messages"] = state["messages"][-9:]
        return {'latest_message': process_query(extraction.latest_message)}
class RetrievePossibleEndpoints:
    """Vector-search the API knowledge base for endpoints matching the latest message."""

    def __init__(self):
        self.df = api_knowledge_df  # endpoint metadata, rows aligned with the index
        self.index = api_index      # vector index searched with the query embedding

    def run(self, state: State):
        """Embed the latest user message and return the top-5 candidate endpoints.

        Bug fix: candidates were previously accumulated on ``self.results``, so
        every invocation of this (long-lived) node appended 5 more endpoints to
        the list returned by the previous run. A fresh local list is used now.
        """
        query_embedding = np.array(
            embedding_model.embed_query(state['latest_message'])
        ).reshape(1, -1).astype('float32')
        distances, indices = self.index.search(query_embedding, 5)
        results = []
        for idx in indices[0]:
            row = self.df.iloc[idx]
            print('Endpoint:', row['endpoint'])
            results.append(row['endpoint'])
        print('The possible endpoints are:', results)
        return {
            "possible_endpoints": results,
        }
class RetrieveExactEndpoint:
    """Ask the LLM to pick the single best endpoint from the retrieved candidates."""

    def __init__(self):
        self.llm = llm_gpt

    def run(self, state: State):
        """Select the exact endpoint, then look up its HTTP method and parameters."""
        conversation = [
            SystemMessage(content=fetch_endpoint_prompt),
            FunctionMessage(
                name='possible_endpoints',
                content=f'''The possible endpoints are: {state['possible_endpoints']}''',
            ),
            HumanMessage(content=f'''The user query is: {state['latest_message']}'''),
        ]
        chooser = self.llm.with_structured_output(EndpointFormatter, method='function_calling')
        choice = chooser.invoke(conversation)
        print('The exact endpoint is:', choice.endpoint)
        endpoint_info = get_endpoint_info(choice.endpoint)
        print('The endpoint info is:', endpoint_info)
        return {
            "messages": [{"role": "assistant", "content": f'''The endpoint is: {choice.endpoint}'''}],
            "endpoint": choice.endpoint,
            "method": endpoint_info['method'],
            "needed_parameters": endpoint_info["parameters"],
        }
class FetchParametersNode:
    """Extract values for a GET endpoint's query parameters from the user message."""

    def __init__(self):
        self.llm = llm_gpt

    def run(self, state: State):
        """Return cleaned parameter values for GET requests; empty dict otherwise."""
        print('Entered to fetch parameters')
        print(state['method'])
        # Guard clause: only GET endpoints carry query parameters here.
        if state['method'] != 'GET':
            return {
                'parameters_values': {}
            }
        print('Condition satisfied')
        conversation = [
            SystemMessage(content=fetch_parameters_prompt),
            HumanMessage(content=f'''The query is: {state['latest_message']}\n. The needed parameters: {str(state['needed_parameters'])}'''),
        ]
        print('messages:', conversation)
        extractor = self.llm.with_structured_output(ParameterFormatter, method='function_calling')
        extracted = extractor.invoke(conversation)
        cleaned = {name: process_query(raw) for name, raw in extracted.parameters_values.items()}
        print('The parameter values:', cleaned)
        return {
            'parameters_values': cleaned
        }
class ChatbotNode:
    """Main chat node: answers with a structured endpoint/method/parameters response."""

    def __init__(self):
        self.llm = llm_gpt

    def run(self, state: State):
        """Invoke the LLM over the trimmed history plus the live API knowledge base.

        Bug fix: the history was previously trimmed *after* the prompt list was
        assembled, so the 9-message cap never affected what was actually sent to
        the model. Trimming now happens before the prompt is built.
        """
        print('Message:', state['messages'])
        template = chatbot_prompt()
        knowledge_base = generate_api_knowledge('https://reveltrends.vercel.app')
        print('The knowledge base is:', knowledge_base)
        # Cap retained history at the 9 most recent messages once it exceeds 11,
        # BEFORE building the prompt so the cap is actually effective.
        if len(state['messages']) > 11:
            state["messages"] = state["messages"][-9:]
        print('Messages:', state['messages'])
        print(len(state['messages']))
        messages = [
            SystemMessage(content=template),
            FunctionMessage(name='analytics_chatbot', content=str(knowledge_base)),
        ] + state["messages"]
        result = self.llm.with_structured_output(ResponseFormatter, method='function_calling').invoke(messages)
        print('The result is:', result)
        return {
            "messages": [{"role": "assistant", "content": f'''The endpoint is: {result.endpoint}. The parameters are: {result.parameters}'''}],
            "endpoint": result.endpoint,
            "method": result.method,
            "parameters": result.parameters,
        }
class FetchDataNode:
    """Execute the resolved API call and return its JSON response."""

    def __init__(self):
        self.llm = llm_gpt
        self.base_url = 'https://reveltrends.vercel.app'
        # NOTE(review): placeholder bearer token — must be configured before
        # authenticated endpoints will work.
        self.headers = {
            "Authorization": "Bearer YOUR_API_KEY",  # replace with your API key if needed
            "Content-Type": "application/json"
        }

    def run(self, state: State):
        """Dispatch the HTTP request for the selected endpoint.

        Bug fix: when the method was not GET and the endpoint was not
        '/api/v1/compare/', ``response`` was never assigned and the final
        ``response.json()`` raised UnboundLocalError. That case now returns an
        explicit error payload instead of crashing. The LLM result is also no
        longer bound to the same ``response`` name as the HTTP response.
        """
        print('Entered to fetch data')
        url = f'''{self.base_url}{state['endpoint']}'''
        if state['method'] == 'GET':
            response = requests.get(url, params=state['parameters_values'], headers=self.headers)
        elif state['endpoint'] == '/api/v1/compare/':
            print('Condition satisfied')
            messages = [SystemMessage(content=get_body_prompt()),
                        HumanMessage(content=str(state['messages']))]
            # Ask the LLM to extract the compare-request body from the chat.
            body = llm_gpt.with_structured_output(CompareBodyFormatter, method='function_calling').invoke(messages)
            print('INF names response:', body)
            payload = {
                "usernames": list(map(process_query, body.names)),
                "freq": body.frequency
            }
            print('The payload is:', payload)
            headers = {
                "Content-Type": "application/json"
            }
            response = requests.post(url, json=payload, headers=headers)
        else:
            # Previously fell through with `response` unbound -> UnboundLocalError.
            return {'response': {'error': f'''Unsupported request: method={state['method']} endpoint={state['endpoint']}'''}}
        return {'response': response.json()}
|