# Ai_Agent / pipelines.py
# Author: HighRoller
# Commit: 02b5c87 "Upload 14 files" (verified)
from imports import *
from hal_check import *
from prompt_templates import *
from api_json_to_doc import *
from mem_check import *
def dynamic_k(query):
    """Choose how many retrieved examples to include for *query*.

    The count grows with the query's whitespace-separated word count:
    up to 7 words -> 1, up to 15 words -> 2, longer -> 3.
    """
    n_words = len(query.split())
    for word_limit, examples in ((7, 1), (15, 2)):
        if n_words <= word_limit:
            return examples
    return 3
def _debug_dump(header, value):
    """Print *header*, then *value*, its type and its length, then a footer rule.

    Debug aid for inspecting intermediate chain outputs on stdout.
    """
    print(header)
    print(value)
    print(type(value))
    # ast.literal_eval can yield unsized values (e.g. an int); don't crash on them.
    print(len(value) if hasattr(value, "__len__") else "n/a")
    print("##################")


def _parse_with_schema_retry(raw, query, api_docs):
    """Parse chain output *raw* as a Python literal, reprompting once on failure.

    If ``ast.literal_eval`` rejects *raw*, a correction prompt is built with
    ``correction_if_wrong_schema`` from the parse error, ``reprompt_chain`` is
    asked for a new answer, and that answer is parsed once more. Exceptions
    raised by the correction/reprompt calls themselves propagate to the caller.

    Returns:
        ``(True, value)`` on success, or ``(False, None)`` if the retry also
        fails to parse. A flag is returned because ``None`` is itself a valid
        literal.
    """
    try:
        return True, ast.literal_eval(raw)
    except Exception as err:
        correction_prompt = correction_if_wrong_schema(err, raw)
    retried = reprompt_chain.run(QUERY=query, API_LIST=api_docs, CORRECTION_PROMPT=correction_prompt)
    try:
        return True, ast.literal_eval(retried)
    except Exception:
        return False, None


def _record_turn(query, response_text, is_follow_up):
    """Store this turn in the Streamlit session memory.

    Every turn becomes the new PREV_QUERY / PREV_RESPONSE. Follow-up turns are
    also appended ('.\\n'-separated) to the running PAST_QUERY / PAST_RESPONSE
    transcript.
    """
    state = st.session_state
    if is_follow_up:
        state.PAST_QUERY = state.PAST_QUERY + '.\n' + query
        state.PAST_RESPONSE = state.PAST_RESPONSE + '.\n' + response_text
    state.PREV_QUERY = query
    state.PREV_RESPONSE = response_text


def pipeline(query, API_LIST, available_tools, available_arguments, arg_allowed_values_dict, args_in_list_dict, vector_db):
    """Turn a natural-language *query* into a parsed tool-call response via LLM chains.

    Steps:
      1. Retrieve ``dynamic_k(query)`` few-shot examples from *vector_db*
         using MMR search.
      2. If a previous turn exists, use ``mem_chain`` + ``verify_follow_up_query``
         to decide whether *query* is a follow-up. Then generate pseudo code with
         ``query_chain_memory`` (follow-up) or ``query_chain`` (otherwise).
      3. Convert it with ``format_chain`` and parse with ``ast.literal_eval``
         (one schema-correction reprompt if parsing fails).
      4. Check for hallucinated tools/arguments/values with
         ``find_hallucinations`` and reprompt with a correction prompt until
         clean or the reprompt budget runs out.
      5. Record the turn in ``st.session_state`` conversation memory.

    Args:
        query: the user's request.
        API_LIST: tool descriptions; converted with ``convert_json_to_doc``
            before being placed in prompts.
        available_tools, available_arguments, arg_allowed_values_dict,
        args_in_list_dict: passed through to ``find_hallucinations``.
        vector_db: object exposing ``max_marginal_relevance_search(query, k=...)``.

    Returns:
        - The parsed response on success (presumably a list of tool calls;
          ``[]`` is used as the empty answer).
        - ``[]`` when any of these happen:
          - a query chain fails;
          - output cannot be parsed even after the retry;
          - the final answer trips ``placeholder_check`` or ``unsolvable_check``.
        - The first parsed response if the correction loop raises.
        - The raw pseudo-code output if ``format_chain`` fails.
    """
    print(f"PreviousQuery: {st.session_state.PREV_QUERY}")
    print(f"PreviousResponse: {st.session_state.PREV_RESPONSE}")
    print(f"PastQuery: {st.session_state.PAST_QUERY}")
    print(f"PastResponse: {st.session_state.PAST_RESPONSE}")
    API_LIST = convert_json_to_doc(API_LIST)
    done = False
    max_reprompts = 1
    cntr = 1

    num_examples = dynamic_k(query)
    docs = vector_db.max_marginal_relevance_search(query, k=num_examples)
    # Use what the store actually returned; it may yield fewer than k documents.
    RAG_examples = ''.join(f'{doc.page_content}\n' for doc in docs[:num_examples])

    # --- Follow-up detection against conversation memory ---
    is_follow_up = False
    if st.session_state.PREV_QUERY != "":
        # Compare against the accumulated follow-up transcript when there is one.
        if st.session_state.PAST_QUERY == "NO PAST QUERIES":
            mem_resp = mem_chain.run(QUERY=query, PREV_QUERY=st.session_state.PREV_QUERY)
        else:
            mem_resp = mem_chain.run(QUERY=query, PREV_QUERY=st.session_state.PAST_QUERY)
        print(f"0. Memory output- {cntr} ##################")
        print(mem_resp)
        print("##################")
        # NOTE(review): assumed to return a bool; only its truthiness is used.
        is_follow_up = verify_follow_up_query(mem_resp)
        if not is_follow_up:
            # Unrelated to the previous turn: drop the accumulated transcript.
            st.session_state.PAST_QUERY = "NO PAST QUERIES"
            st.session_state.PAST_RESPONSE = "NO PAST RESPONSES"

    # --- Pseudo-code generation ---
    if is_follow_up:
        if st.session_state.PAST_RESPONSE == "NO PAST RESPONSES":
            # First follow-up in a chain: seed the transcript with the previous turn.
            st.session_state.PAST_QUERY = st.session_state.PREV_QUERY
            st.session_state.PAST_RESPONSE = st.session_state.PREV_RESPONSE
        try:
            pseudo_code = query_chain_memory.run(QUERY=query, API_LIST=API_LIST, RAG=RAG_examples, PAST_QUERY=st.session_state.PAST_QUERY, PAST_RESPONSE=st.session_state.PAST_RESPONSE)
        except Exception:
            # No earlier output exists on this path, so there is nothing to fall back to.
            return []
    else:
        try:
            pseudo_code = query_chain.run(QUERY=query, API_LIST=API_LIST, RAG=RAG_examples)
        except Exception:
            return []
    _debug_dump(f"1. Pseudo code output- {cntr} ##################", pseudo_code)

    # --- Formatting into a literal and parsing ---
    try:
        # Called with an empty QUERY; presumably format_chain reads the pseudo
        # code from shared chain memory — TODO confirm.
        resp_formatted = format_chain.run(QUERY="")
    except Exception:
        # Fall back to the unformatted pseudo code from whichever chain ran.
        return pseudo_code
    _debug_dump(f"2. JSON string output- {cntr} ##################", resp_formatted)
    print(f"Reprompt Number: {cntr} #################")
    print(f"Response formatted:{resp_formatted}")

    parsed_ok, json_response = _parse_with_schema_retry(resp_formatted, query, API_LIST)
    if not parsed_ok:
        return []
    _debug_dump(f"3. JSON decoded output- {cntr} ##################", json_response)
    json_response_init = json_response

    # --- Hallucination check / correction loop ---
    try:
        while not done:
            try:
                (hallucinated_args, hallucinated_tools, hallucinated_args_values,
                 hallucinated_args_values_prev) = find_hallucinations(json_response, arg_allowed_values_dict, available_tools, available_arguments, args_in_list_dict)
            except Exception:
                return json_response_init
            print('##############')
            print(f'wrong stuff : {hallucinated_args}, {hallucinated_tools}, {hallucinated_args_values}, {hallucinated_args_values_prev}')
            print('#############')
            n_problems = (len(hallucinated_args) + len(hallucinated_tools)
                          + len(hallucinated_args_values) + len(hallucinated_args_values_prev))
            if n_problems == 0:
                _record_turn(query, str(json_response), is_follow_up)
                return json_response
            # NOTE(review): cntr starts at 1 and the reprompt below still runs on
            # the iteration that sets done, so up to max_reprompts + 1 correction
            # reprompts are issued, and the last one is not hallucination-checked.
            # Confirm this is intended.
            if cntr > max_reprompts:
                done = True
            try:
                correction_prompt = correction(hallucinated_args, hallucinated_args_values, hallucinated_tools, hallucinated_args_values_prev, json_response)
            except Exception:
                return json_response_init
            _debug_dump(f"4. Correction prompt- {cntr} ##################", correction_prompt)
            reprompted = reprompt_chain.run(QUERY=query, API_LIST=API_LIST, CORRECTION_PROMPT=correction_prompt)
            parsed_ok, json_response = _parse_with_schema_retry(reprompted, query, API_LIST)
            if not parsed_ok:
                return []
            cntr += 1
            _debug_dump(f"4. JSON decoded output- {cntr} ##################", json_response)

        # Reprompt budget exhausted: reject answers with placeholders or that
        # are flagged unsolvable, remembering the turn as an empty answer.
        if placeholder_check(json_response) or unsolvable_check(json_response):
            _record_turn(query, '[]', is_follow_up)
            return []
        _record_turn(query, str(json_response), is_follow_up)
    except Exception:
        return json_response_init
    return json_response