Spaces:
Sleeping
Sleeping
File size: 7,969 Bytes
02b5c87 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
from imports import *
from hal_check import *
from prompt_templates import *
from api_json_to_doc import *
from mem_check import *
def dynamic_k(query: str) -> int:
    """Choose how many retrieval examples to fetch for *query*.

    The count grows with the number of whitespace-separated words in the
    query: up to 7 words -> 1 example, 8 to 15 words -> 2, more than 15 -> 3.

    Args:
        query: The user query text.

    Returns:
        The number of examples (1, 2 or 3).
    """
    word_count = len(query.split())
    if word_count <= 7:
        return 1
    if word_count <= 15:
        return 2
    return 3
def _print_stage(header, value):
    """Print a debug banner, then the value, its type and its length."""
    print(header)
    print(value)
    print(type(value))
    print(len(value))
    print("##################")


def _remember_turn(query, response_text, is_follow_up):
    """Store the finished turn in the Streamlit session state.

    PREV_QUERY / PREV_RESPONSE always become the current turn. When the query
    was classified as a follow-up, the turn is also appended to the running
    PAST_QUERY / PAST_RESPONSE history.
    """
    session = st.session_state
    if is_follow_up:
        session.PAST_QUERY = session.PAST_QUERY + '.\n' + query
        session.PAST_RESPONSE = session.PAST_RESPONSE + '.\n' + response_text
    session.PREV_QUERY = query
    session.PREV_RESPONSE = response_text


def _literal_or_reprompt(raw_text, query, API_LIST):
    """Parse *raw_text* as a Python literal, asking for one repair on failure.

    Returns ``(True, value)`` when either the original text or the repaired
    text parses, and ``(False, None)`` when both attempts fail. Errors raised
    while building or sending the repair prompt are not caught here, so the
    caller keeps control over how those are handled.
    """
    try:
        return True, ast.literal_eval(raw_text)
    except Exception as parse_error:
        correction_prompt = correction_if_wrong_schema(parse_error, raw_text)
    repaired = reprompt_chain.run(QUERY=query, API_LIST=API_LIST, CORRECTION_PROMPT=correction_prompt)
    try:
        return True, ast.literal_eval(repaired)
    except Exception:
        return False, None


def pipeline(query, API_LIST, available_tools, available_arguments, arg_allowed_values_dict, args_in_list_dict, vector_db):
    """Turn a user query into a validated tool-call response.

    Steps, in order:
      1. Retrieve ``dynamic_k(query)`` examples from ``vector_db``.
      2. If a previous query exists in the session, ask ``mem_chain`` whether
         the new query is a follow-up and pick the plain or the memory-aware
         generation chain accordingly.
      3. Run ``format_chain`` and parse its output with ``ast.literal_eval``
         (one repair attempt through ``reprompt_chain``).
      4. Repeatedly check the parsed response with ``find_hallucinations`` and
         reprompt with a correction prompt until it is clean or the reprompt
         budget is used up.
      5. Reject placeholder / unsolvable answers and record the turn in
         ``st.session_state``.

    Args:
        query: The user query text.
        API_LIST: Tool descriptions; converted with ``convert_json_to_doc``
            before being handed to the chains.
        available_tools, available_arguments, arg_allowed_values_dict,
        args_in_list_dict: Passed through unchanged to ``find_hallucinations``.
        vector_db: Object exposing ``max_marginal_relevance_search(query, k=...)``
            whose results carry a ``page_content`` attribute.

    Returns:
        The parsed response on success; ``[]`` when generation or parsing
        fails or the answer is a placeholder / unsolvable; the raw generated
        text when the formatting chain fails; the first parsed response when
        an error occurs during hallucination correction.

    NOTE(review): the nesting of this function was reconstructed from a copy
    whose indentation had been lost. In particular the placeholder/unsolvable
    checks are assumed to run after the correction loop -- confirm against
    the original file.
    """
    session = st.session_state
    print(f"PreviousQuery: {session.PREV_QUERY}")
    print(f"PreviousResponse: {session.PREV_RESPONSE}")
    print(f"PastQuery: {session.PAST_QUERY}")
    print(f"PastResponse: {session.PAST_RESPONSE}")

    API_LIST = convert_json_to_doc(API_LIST)
    done = False
    max_reprompts = 1
    cntr = 1

    num_examples = dynamic_k(query)
    docs = vector_db.max_marginal_relevance_search(query, k=num_examples)
    # Slice instead of indexing so a store that returns fewer than k hits
    # does not raise IndexError.
    RAG_examples = ''.join(f'{doc.page_content}\n' for doc in docs[:num_examples])

    # verify_follow_up_query's return type is not visible here, so the
    # explicit ``== False`` / ``== True`` comparisons are kept on purpose.
    classification = False
    if session.PREV_QUERY != "":
        if session.PAST_QUERY == "NO PAST QUERIES":
            reference_query = session.PREV_QUERY
        else:
            reference_query = session.PAST_QUERY
        mem_resp = mem_chain.run(QUERY=query, PREV_QUERY=reference_query)
        print(f"0. Memory output- {cntr} ##################")
        print(mem_resp)
        print("##################")
        classification = verify_follow_up_query(mem_resp)

        if classification == False:  # noqa: E712
            # Not a follow-up: drop the accumulated history and start fresh.
            session.PAST_QUERY = "NO PAST QUERIES"
            session.PAST_RESPONSE = "NO PAST RESPONSES"
            try:
                pseudo_code = query_chain.run(QUERY=query, API_LIST=API_LIST, RAG=RAG_examples)
            except Exception:
                return []
        else:
            if session.PAST_RESPONSE == "NO PAST RESPONSES":
                # First follow-up: seed the history with the previous turn.
                session.PAST_QUERY = session.PREV_QUERY
                session.PAST_RESPONSE = session.PREV_RESPONSE
            try:
                pseudo_code = query_chain_memory.run(
                    QUERY=query,
                    API_LIST=API_LIST,
                    RAG=RAG_examples,
                    PAST_QUERY=session.PAST_QUERY,
                    PAST_RESPONSE=session.PAST_RESPONSE,
                )
            except Exception:
                # Nothing has been generated on this path, so fall back to an
                # empty result like the other generation branches.
                return []
    else:
        try:
            pseudo_code = query_chain.run(QUERY=query, API_LIST=API_LIST, RAG=RAG_examples)
        except Exception:
            return []
    _print_stage(f"1. Pseudo code output- {cntr} ##################", pseudo_code)

    is_follow_up = classification == True  # noqa: E712

    # NOTE(review): format_chain is called with an empty QUERY; presumably it
    # reads the generated text from shared chain memory -- confirm.
    try:
        resp_formatted = format_chain.run(QUERY="")
    except Exception:
        # Formatting failed: hand back the raw generated text, whichever
        # branch produced it.
        return pseudo_code
    _print_stage(f"2. JSON string output- {cntr} ##################", resp_formatted)
    print(f"Reprompt Number: {cntr} #################")
    print(f"Response formatted:{resp_formatted}")

    parsed_ok, json_response = _literal_or_reprompt(resp_formatted, query, API_LIST)
    if not parsed_ok:
        return []
    _print_stage(f"3. JSON decoded output- {cntr} ##################", json_response)

    # Kept as the fallback for any failure during hallucination correction.
    json_response_init = json_response
    try:
        while not done:
            try:
                (
                    hallucinated_args,
                    hallucinated_tools,
                    hallucinated_args_values,
                    hallucinated_args_values_prev,
                ) = find_hallucinations(
                    json_response,
                    arg_allowed_values_dict,
                    available_tools,
                    available_arguments,
                    args_in_list_dict,
                )
            except Exception:
                return json_response_init
            print('##############')
            print(f'wrong stuff : {hallucinated_args}, {hallucinated_tools}, {hallucinated_args_values}, {hallucinated_args_values_prev}')
            print('#############')

            hallucination_count = (
                len(hallucinated_args)
                + len(hallucinated_tools)
                + len(hallucinated_args_values)
                + len(hallucinated_args_values_prev)
            )
            if hallucination_count == 0:
                _remember_turn(query, str(json_response), is_follow_up)
                return json_response

            # Once the budget is exceeded this pass becomes the last one; the
            # correction below still runs for the current iteration.
            if cntr > max_reprompts:
                done = True

            try:
                Correction_prompt = correction(
                    hallucinated_args,
                    hallucinated_args_values,
                    hallucinated_tools,
                    hallucinated_args_values_prev,
                    json_response,
                )
            except Exception:
                return json_response_init
            _print_stage(f"4. Correction prompt- {cntr} ##################", Correction_prompt)

            reprompted = reprompt_chain.run(QUERY=query, API_LIST=API_LIST, CORRECTION_PROMPT=Correction_prompt)
            parsed_ok, json_response = _literal_or_reprompt(reprompted, query, API_LIST)
            if not parsed_ok:
                return []
            cntr += 1
            _print_stage(f"4. JSON decoded output- {cntr} ##################", json_response)

        # Reprompt budget exhausted: reject answers that are still
        # placeholders or flagged unsolvable, remembering the turn as empty.
        if placeholder_check(json_response) or unsolvable_check(json_response):
            _remember_turn(query, '[]', is_follow_up)
            return []

        _remember_turn(query, str(json_response), is_follow_up)
    except Exception:
        return json_response_init
    return json_response
|