from imports import *
from hal_check import *
from prompt_templates import *
from api_json_to_doc import *
from mem_check import *

# Sentinel returned by _parse_or_reprompt when both parse attempts fail, so
# the caller can tell "unparseable" apart from a legitimately falsy literal.
_PARSE_FAILED = object()


def dynamic_k(query):
    """Return how many retrieval examples to fetch for *query*.

    The count grows with the number of whitespace-separated words:
    up to 7 words -> 1, up to 15 words -> 2, anything longer -> 3.
    """
    word_count = len(query.split())
    if word_count <= 7:
        return 1
    if word_count <= 15:
        return 2
    return 3


def _debug_dump(header, value):
    """Print *header*, then the value, its type and its length, then a rule."""
    print(header)
    print(value)
    print(type(value))
    print(len(value))
    print("##################")


def _remember_turn(query, response_text, append_to_history):
    """Record the current turn in ``st.session_state``.

    PREV_QUERY / PREV_RESPONSE are always overwritten.  PAST_QUERY /
    PAST_RESPONSE are extended only when *append_to_history* is true.
    """
    state = st.session_state
    if append_to_history:
        state.PAST_QUERY = state.PAST_QUERY + '.\n' + query
        state.PAST_RESPONSE = state.PAST_RESPONSE + '.\n' + response_text
    state.PREV_QUERY = query
    state.PREV_RESPONSE = response_text


def _parse_or_reprompt(raw, query, api_docs):
    """Parse *raw* as a Python literal, reprompting once if parsing fails.

    Returns the parsed value, or ``_PARSE_FAILED`` when the reprompted text
    cannot be parsed either.  Errors raised while building the correction
    prompt or running the reprompt chain are deliberately not caught here.
    """
    try:
        return ast.literal_eval(raw)
    except Exception as err:
        correction_prompt = correction_if_wrong_schema(err, raw)
    repaired = reprompt_chain.run(
        QUERY=query, API_LIST=api_docs, CORRECTION_PROMPT=correction_prompt
    )
    try:
        return ast.literal_eval(repaired)
    except Exception:
        return _PARSE_FAILED


def pipeline(query, API_LIST, available_tools, available_arguments, arg_allowed_values_dict, args_in_list_dict, vector_db):
    """Turn a user query into a validated tool-call response.

    Steps, in order:
      1. Fetch ``dynamic_k(query)`` examples from *vector_db* via
         ``max_marginal_relevance_search``.
      2. If a previous query exists in ``st.session_state``, run ``mem_chain``
         and ``verify_follow_up_query`` to classify the query as a follow-up.
      3. Draft an answer with ``query_chain`` (or ``query_chain_memory`` for
         follow-ups), then reformat it with ``format_chain``.
      4. Parse the formatted text with ``ast.literal_eval`` (one reprompt on
         failure).
      5. Repeatedly run ``find_hallucinations`` and reprompt with a
         correction until nothing is flagged or the reprompt budget is spent.
      6. Update PREV_*/PAST_* in ``st.session_state`` with the outcome.

    Args:
        query: The user's query string.
        API_LIST: API description passed through ``convert_json_to_doc``.
        available_tools, available_arguments, arg_allowed_values_dict,
        args_in_list_dict: Forwarded unchanged to ``find_hallucinations``.
        vector_db: Object exposing ``max_marginal_relevance_search(query, k=...)``
            whose results carry a ``page_content`` attribute.

    Returns:
        The parsed response on success; ``[]`` when drafting or parsing fails
        or when ``placeholder_check`` / ``unsolvable_check`` flag the result;
        the raw draft text when ``format_chain`` fails; or the first parsed
        response when validation or correction raises.
    """
    state = st.session_state
    print(f"PreviousQuery: {state.PREV_QUERY}")
    print(f"PreviousResponse: {state.PREV_RESPONSE}")
    print(f"PastQuery: {state.PAST_QUERY}")
    print(f"PastResponse: {state.PAST_RESPONSE}")

    API_LIST = convert_json_to_doc(API_LIST)
    done = False
    max_reprompts = 1
    cntr = 1

    num_examples = dynamic_k(query)
    docs = vector_db.max_marginal_relevance_search(query, k=num_examples)
    # Slice instead of indexing: the store may return fewer than k documents.
    RAG_examples = ''.join(f'{doc.page_content}\n' for doc in docs[:num_examples])

    # ---- Step 2: decide whether this query continues the conversation ----
    classification = False
    use_memory = False
    if state.PREV_QUERY != "":
        if state.PAST_QUERY == "NO PAST QUERIES":
            context_query = state.PREV_QUERY
        else:
            context_query = state.PAST_QUERY
        mem_resp = mem_chain.run(QUERY=query, PREV_QUERY=context_query)
        print(f"0. Memory output- {cntr} ##################")
        print(mem_resp)
        print("##################")
        classification = verify_follow_up_query(mem_resp)
        # NOTE(review): verify_follow_up_query is not visible here, so the
        # original equality comparisons are kept rather than truthiness.
        if classification == False:  # noqa: E712
            state.PAST_QUERY = "NO PAST QUERIES"
            state.PAST_RESPONSE = "NO PAST RESPONSES"
        else:
            if state.PAST_RESPONSE == "NO PAST RESPONSES":
                # First follow-up: seed the history from the previous turn.
                state.PAST_QUERY = state.PREV_QUERY
                state.PAST_RESPONSE = state.PREV_RESPONSE
            use_memory = True

    # ---- Step 3: draft the answer ----
    try:
        if use_memory:
            pseudo_code = query_chain_memory.run(
                QUERY=query,
                API_LIST=API_LIST,
                RAG=RAG_examples,
                PAST_QUERY=state.PAST_QUERY,
                PAST_RESPONSE=state.PAST_RESPONSE,
            )
        else:
            pseudo_code = query_chain.run(QUERY=query, API_LIST=API_LIST, RAG=RAG_examples)
    except Exception:
        # Every drafting failure yields an empty result.  The follow-up
        # branch used to return a name that was never assigned on that path.
        return []
    _debug_dump(f"1. Pseudo code output- {cntr} ##################", pseudo_code)

    try:
        # NOTE(review): the empty QUERY suggests format_chain takes the draft
        # from shared chain memory - confirm in prompt_templates.
        resp_formatted = format_chain.run(QUERY="")
    except Exception:
        # Fall back to the unformatted draft, whichever chain produced it.
        return pseudo_code
    _debug_dump(f"2. JSON string output- {cntr} ##################", resp_formatted)

    # ---- Step 4: parse ----
    print(f"Reprompt Number: {cntr} #################")
    print(f"Response formatted:{resp_formatted}")
    json_response = _parse_or_reprompt(resp_formatted, query, API_LIST)
    if json_response is _PARSE_FAILED:
        return []
    _debug_dump(f"3. JSON decoded output- {cntr} ##################", json_response)

    # ---- Step 5: hallucination check / correction loop ----
    json_response_init = json_response
    try:
        while not done:
            try:
                hallucinated_args, hallucinated_tools, hallucinated_args_values, hallucinated_args_values_prev = find_hallucinations(
                    json_response,
                    arg_allowed_values_dict,
                    available_tools,
                    available_arguments,
                    args_in_list_dict,
                )
            except Exception:
                return json_response_init
            print('##############')
            print(f'wrong stuff : {hallucinated_args}, {hallucinated_tools}, {hallucinated_args_values}, {hallucinated_args_values_prev}')
            print('#############')

            flagged = (
                len(hallucinated_args)
                + len(hallucinated_tools)
                + len(hallucinated_args_values)
                + len(hallucinated_args_values_prev)
            )
            if flagged == 0:
                _remember_turn(query, str(json_response), classification == True)  # noqa: E712
                return json_response

            # The flag is raised before the reprompt below, so the loop runs
            # one more correction round after the counter passes
            # max_reprompts, and that last result is not re-validated.
            if cntr > max_reprompts:
                done = True

            try:
                Correction_prompt = correction(
                    hallucinated_args,
                    hallucinated_args_values,
                    hallucinated_tools,
                    hallucinated_args_values_prev,
                    json_response,
                )
            except Exception:
                return json_response_init
            _debug_dump(f"4. Correction prompt- {cntr} ##################", Correction_prompt)

            corrected_text = reprompt_chain.run(
                QUERY=query, API_LIST=API_LIST, CORRECTION_PROMPT=Correction_prompt
            )
            json_response = _parse_or_reprompt(corrected_text, query, API_LIST)
            if json_response is _PARSE_FAILED:
                return []
            cntr += 1
            _debug_dump(f"4. JSON decoded output- {cntr} ##################", json_response)

        # Reprompt budget spent: reject responses flagged as placeholder or
        # unsolvable, recording an empty result for this turn.
        if placeholder_check(json_response) or unsolvable_check(json_response):
            _remember_turn(query, '[]', classification == True)  # noqa: E712
            return []

        _remember_turn(query, str(json_response), classification == True)  # noqa: E712
    except Exception:
        # Any unexpected failure during validation falls back to the first
        # successfully parsed response.
        return json_response_init
    return json_response