Spaces:
Build error
Build error
| import os | |
| import re | |
| from typing import Annotated | |
| from typing_extensions import TypedDict | |
| from langchain_groq import ChatGroq | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.messages import SystemMessage, HumanMessage | |
| from langchain_community.graphs import Neo4jGraph | |
| from langgraph.graph import StateGraph | |
| from langgraph.graph import add_messages | |
| from ki_gen.prompts import PLAN_GEN_PROMPT, PLAN_MODIFICATION_PROMPT | |
| from ki_gen.data_retriever import build_data_retriever_graph | |
| from ki_gen.data_processor import build_data_processor_graph | |
| from ki_gen.utils import ConfigSchema, State, HumanValidationState, DocProcessorState, DocRetrieverState | |
| from langgraph.checkpoint.sqlite import SqliteSaver | |
| ########################################################################## | |
| ###### NODES DEFINITION ###### | |
| ########################################################################## | |
def validate_node(state: State):
    """
    Node that injects the plan-validation instruction into the conversation.

    The incoming state is not read. The returned update contains a single
    HumanMessage asking the model either to produce an updated, concise plan
    focused on Key Issues or to answer "My plan is correct".
    """
    validation_text = (
        "System : You only need to focus on Key Issues, no need to focus on solutions or stakeholders yet and your plan should be concise.\n"
        'If needed, give me an updated plan to follow this instruction. If your plan already follows the instruction just say "My plan is correct".'
    )
    return {"messages": [HumanMessage(content=validation_text)]}
def error_chatbot_groq(error, model_name, query):
    """
    Fallback used when a Groq call fails: rotate GROQ_API_KEY and retry once.

    The candidate keys are read from the environment variables 'groq_api_key',
    'groq_api_key2' and 'groq_api_key3'. The key following the one currently
    stored in GROQ_API_KEY is selected (wrapping around to the first one);
    candidates that are not set are skipped. If no candidate is configured the
    environment is left untouched and the retry runs with the current key.

    Args:
        error: the exception that triggered the fallback. It is not used here
            and is kept so that existing callers keep working.
        model_name: name of the Groq model to instantiate for the retry. For
            backward compatibility an already-built LLM object is accepted as
            well; its 'model_name' attribute is used when present.
        query: the messages to send to the model.

    Returns:
        A state update {"messages": [...]} holding either the model answer or,
        if the retry fails too, a SystemMessage describing the failure.
    """
    # Rotation order of the configured keys; unset variables come back as None.
    candidate_keys = [
        os.getenv(name) for name in ("groq_api_key", "groq_api_key2", "groq_api_key3")
    ]
    current_key = os.environ.get("GROQ_API_KEY")

    # Start right after the key currently in use; unknown/unset key -> first one.
    start = 0
    if current_key is not None and current_key in candidate_keys:
        start = candidate_keys.index(current_key) + 1

    for offset in range(len(candidate_keys)):
        candidate = candidate_keys[(start + offset) % len(candidate_keys)]
        if candidate:
            # os.environ only accepts strings, hence the skip of unset keys.
            os.environ["GROQ_API_KEY"] = candidate
            break

    # Some callers hand over the LLM object instead of its name.
    # NOTE(review): relies on the object exposing a 'model_name' attribute - confirm.
    if not isinstance(model_name, str):
        model_name = getattr(model_name, "model_name", model_name)

    # Re-initialize the model *after* switching the key so the new key is picked up.
    try:
        llm_groq_retry = ChatGroq(model=model_name)
        return {"messages": [llm_groq_retry.invoke(query)]}
    except Exception as retry_error:
        print(f"Error during retry: {retry_error}")
        return {"messages": [SystemMessage(content=f"Failed to process after retry: {retry_error}")]}
# Wrappers to call LLMs on the state's messages field
def chatbot_llama(state: State):
    """
    Node that calls the Groq-hosted "llama3-70b-8192" model on the state's messages.

    On any failure the call is delegated to error_chatbot_groq, which switches
    the Groq API key and retries; its result is returned so the node always
    yields a state update.

    Returns:
        A state update {"messages": [<model answer>]}.
    """
    # Kept outside the try block so the name is available to the fallback even
    # when constructing the client itself is what failed.
    model_name = "llama3-70b-8192"
    try:
        llm_llama = ChatGroq(model=model_name)
        return {"messages": [llm_llama.invoke(state["messages"])]}
    except Exception as error:
        # The fallback expects the model *name* and its result must be returned.
        return error_chatbot_groq(error, model_name, state["messages"])
def chatbot_mixtral(state: State):
    """
    Node that calls the Groq-hosted "deepseek-r1-distill-llama-70b" model on the state's messages.

    Despite the function name, the model instantiated here is the DeepSeek
    distill model. The state and the client object are printed to stdout
    before the call.

    Returns:
        A state update {"messages": [<model answer>]}.
    """
    print(state)
    client = ChatGroq(model="deepseek-r1-distill-llama-70b")
    print(client)
    answer = client.invoke(state["messages"])
    return {"messages": [answer]}
    # The key-rotation fallback is disabled for this wrapper:
    # except Exception as error:
    #     error_chatbot_groq(error,llm_mixtral,state["messages"])
def chatbot_openai(state: State):
    """
    Node that calls the 'gpt-4o' model through ChatOpenAI on the state's messages.

    The client is pointed at the base URL "https://llm.synapse.thalescloud.io/".

    Returns:
        A state update {"messages": [<model answer>]}.
    """
    client = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
    answer = client.invoke(state["messages"])
    return {"messages": [answer]}
# Maps a model name (as used in the configuration) to the node function wrapping it.
chatbots = {
    "gpt-4o": chatbot_openai,
    "deepseek-r1-distill-llama-70b": chatbot_mixtral,
    "llama3-70b-8192": chatbot_llama,
}
def parse_plan(state: State):
    """
    This node parses the generated plan and writes it in the 'store_plan' field of the state.

    The plan is read from the third-to-last message. The text following the
    first "Plan:\\n" marker is split on step numbers ("1.", "2.", ..., "10.", ...)
    and anything from "<END_OF_PLAN>" onwards is removed from the last step.

    Returns:
        A state update {"store_plan": [<step text>, ...]}; the list is empty
        when no numbered step follows the marker.

    Raises:
        IndexError: if there are fewer than three messages or the plan does
            not contain the "Plan:\\n" marker.
    """
    plan = state["messages"][-3].content
    plan_body = plan.split("Plan:\n")[1]

    # Raw string avoids the invalid "\d" escape; "\d+" keeps multi-digit step
    # numbers (10., 11., ...) from leaking a digit into the previous step.
    store_plan = re.split(r"\d+\.", plan_body)[1:]

    if store_plan:
        store_plan[-1] = store_plan[-1].split("<END_OF_PLAN>")[0]
    return {"store_plan": store_plan}
def detail_step(state: State, config: ConfigSchema):
    """
    This node updates the value of the 'current_plan_step' field and defines the query to be used for the data_retriever.

    The step index starts at 0 when 'current_plan_step' is missing or None and
    is incremented by one otherwise. The input state is not modified.

    When config["configurable"]["use_detailed_query"] is truthy, the model is
    asked (via get_detailed_query) what additional information it needs for
    the step, and its answer becomes the query; the model is taken from
    config["configurable"]["main_llm"] when set, otherwise the helper's default
    is used. Otherwise the raw plan step text is used as the query and
    'valid_docs' is reset to an empty list.

    Returns:
        A state update with 'current_plan_step' and 'query', plus either
        'messages' (detailed query) or 'valid_docs' (plain query).

    Raises:
        IndexError: if the new step index is outside 'store_plan'.
    """
    previous_step = state.get("current_plan_step")
    current_plan_step = 0 if previous_step is None else previous_step + 1
    step_text = state["store_plan"][current_plan_step]

    configurable = config["configurable"]
    if configurable.get("use_detailed_query"):
        prompt = HumanMessage(
            "Specify what additional information you need to proceed with the next step of your plan :\n"
            f"Step {current_plan_step + 1} : {step_text}"
        )
        # Only forward the model when one is configured, so that None does not
        # override get_detailed_query's default model.
        main_llm = configurable.get("main_llm")
        llm_kwargs = {} if main_llm is None else {"model": main_llm}
        query = get_detailed_query(context=state["messages"] + [prompt], **llm_kwargs)
        # NOTE(review): 'query' is the model's reply object here but a plain
        # string in the branch below - confirm the retriever handles both.
        return {"messages": [prompt, query], "current_plan_step": current_plan_step, 'query': query}

    return {"current_plan_step": current_plan_step, 'query': step_text, "valid_docs": []}
def get_detailed_query(context : list, model : str = "deepseek-r1-distill-llama-70b"):
    """
    Helper for the detail_step node: invoke the chosen LLM on the given context.

    'gpt-4o' is served through ChatOpenAI (with the Synapse base URL); every
    other model name is passed to ChatGroq.

    Returns:
        Whatever the selected client's invoke() returns for the context.
    """
    llm = (
        ChatOpenAI(model=model, base_url="https://llm.synapse.thalescloud.io/")
        if model == 'gpt-4o'
        else ChatGroq(model=model)
    )
    return llm.invoke(context)
def concatenate_data(state: State):
    """
    Node that wraps the processed documents into a prompt for the current plan step.

    The string form of state['valid_docs'] is placed between the technical
    information banners, followed by the instruction to proceed with the
    current step (1-based number and step text from 'store_plan').

    Returns:
        A state update {"messages": [HumanMessage]} carrying that prompt.
    """
    docs_text = str(state["valid_docs"])
    step_index = state['current_plan_step']
    step_text = state['store_plan'][step_index]

    body = (
        "#########TECHNICAL INFORMATION ############\n"
        f"{docs_text}\n"
        "########END OF TECHNICAL INFORMATION#######\n"
        f"Using the information provided above, proceed with step {step_index + 1} of your plan :\n"
        f"{step_text}\n"
    )
    return {"messages": [HumanMessage(content=body)]}
def human_validation(state: HumanValidationState) -> HumanValidationState:
    """
    Placeholder node used as an interruption point for human validation.

    The incoming state is ignored; the update only sets 'process_steps' to an
    empty list.
    """
    update = {'process_steps': []}
    return update
def generate_ki(state: State):
    """
    Node that inserts the prompt starting Key Issues generation.

    The prompt quotes the plan step located one position after
    'current_plan_step' in 'store_plan'. The full state is printed to stdout
    first.

    Returns:
        A state update {"messages": [HumanMessage]} carrying that prompt.
    """
    print(f"THIS IS THE STATE FOR CURRENT PLAN STEP IN GENERATE_KI : {state}")
    # NOTE(review): the prompt always says "step 4" while the quoted step is
    # current_plan_step + 1; these only agree when current_plan_step == 2, and
    # the index is out of range when current_plan_step is the last step - confirm.
    ki_step = state['store_plan'][state['current_plan_step'] + 1]
    text = (
        "Using the information provided above, proceed with step 4 of your plan to provide the user with NEW and INNOVATIVE Key Issues :\n"
        f"{ki_step}"
    )
    return {"messages": [HumanMessage(content=text)]}
def detail_ki(state: State):
    """
    Node that inserts the final prompt, asking the model to detail the generated Key Issues.

    The prompt quotes the plan step located two positions after
    'current_plan_step' in 'store_plan'.

    Returns:
        A state update {"messages": [HumanMessage]} carrying that prompt.
    """
    # NOTE(review): "step 5" is hard-coded while the quoted step is
    # current_plan_step + 2 - these only agree when current_plan_step == 2.
    detail_step_text = state['store_plan'][state['current_plan_step'] + 2]
    text = (
        "Using the information provided above, proceed with step 5 of your plan to provide the user with NEW and INNOVATIVE Key Issues :\n"
        f"{detail_step_text}"
    )
    return {"messages": [HumanMessage(content=text)]}
| ########################################################################## | |
| ###### CONDITIONAL EDGE FUNCTIONS ###### | |
| ########################################################################## | |
def validate_plan(state: State):
    """
    Conditional edge: decide between parsing the plan and asking for validation again.

    Returns "parse" when the last message contains "My plan is correct";
    returns "validate" otherwise, including when the state has no 'messages'.
    """
    if "messages" not in state:
        return "validate"
    last_content = state["messages"][-1].content
    return "parse" if "My plan is correct" in last_content else "validate"
def next_plan_step(state: State, config: ConfigSchema):
    """
    Conditional edge: move on to Key Issues generation or detail the next plan step.

    Returns "generate_key_issues" when the current step index is 2 and the
    configured 'plan_method' is "modification", or when the current step is
    the last entry of 'store_plan'. Returns "detail_step" otherwise.
    """
    step = state["current_plan_step"]
    if step == 2 and config["configurable"].get('plan_method') == "modification":
        return "generate_key_issues"
    reached_last_step = step == len(state["store_plan"]) - 1
    return "generate_key_issues" if reached_last_step else "detail_step"
def detail_or_data_retriever(state: State, config: ConfigSchema):
    """
    Conditional edge: choose whether the retrieval query is first detailed by an LLM.

    Returns "chatbot_detail" when the 'use_detailed_query' option is truthy
    and "data_retriever" otherwise. The state is not read.
    """
    wants_detail = config["configurable"].get("use_detailed_query")
    return "chatbot_detail" if wants_detail else "data_retriever"
def retrieve_or_process(state: State):
    """
    Conditional edge: process the retrieved documents or go back to retrieval.

    Returns "process" when state['human_validated'] is truthy and "retrieve"
    otherwise.
    """
    return "process" if state['human_validated'] else "retrieve"
    # Disabled interactive variant, kept for reference:
    # while True:
    #     user_input = input(f"{len(state['valid_docs'])} were retreived. Do you want more documents (y/[n]) : ")
    #     if user_input.lower() == "y":
    #         return "retrieve"
    #     if not user_input or user_input.lower() == "n":
    #         return "process"
    #     print("Please answer with 'y' or 'n'.\n")
def build_planner_graph(memory, config):
    """
    Assemble and compile the planner graph.

    Args:
        memory: checkpointer handed to the two sub-graphs and to the compiled
            planner graph.
        config: mapping whose "main_llm" entry selects, through the module-level
            `chatbots` table, the node function used for the planner,
            step-execution, Key Issues and final chatbot nodes.

    Returns:
        The compiled graph. Execution starts at "chatbot_planner" and is
        interrupted after "parse", "chatbot_exec_step", "chatbot_final" and
        "data_retriever".
    """
    builder = StateGraph(State)
    retriever_subgraph = build_data_retriever_graph(memory)
    processor_subgraph = build_data_processor_graph(memory)

    # Every "main" chatbot node shares the same configured wrapper.
    main_chatbot = chatbots[config["main_llm"]]

    # Nodes (the query-detailing chatbot always uses the llama wrapper).
    builder.add_node("chatbot_planner", main_chatbot)
    builder.add_node("validate", validate_node)
    builder.add_node("chatbot_detail", chatbot_llama)
    builder.add_node("parse", parse_plan)
    builder.add_node("detail_step", detail_step)
    builder.add_node("data_retriever", retriever_subgraph, input=DocRetrieverState)
    builder.add_node("human_validation", human_validation)
    builder.add_node("data_processor", processor_subgraph, input=DocProcessorState)
    builder.add_node("concatenate_data", concatenate_data)
    builder.add_node("chatbot_exec_step", main_chatbot)
    builder.add_node("generate_ki", generate_ki)
    builder.add_node("chatbot_ki", main_chatbot)
    builder.add_node("detail_ki", detail_ki)
    builder.add_node("chatbot_final", main_chatbot)

    # Unconditional edges.
    plain_edges = (
        ("validate", "chatbot_planner"),
        ("parse", "detail_step"),
        ("chatbot_detail", "data_retriever"),
        ("data_retriever", "human_validation"),
        ("data_processor", "concatenate_data"),
        ("concatenate_data", "chatbot_exec_step"),
        ("generate_ki", "chatbot_ki"),
        ("chatbot_ki", "detail_ki"),
        ("detail_ki", "chatbot_final"),
        ("chatbot_final", "__end__"),
    )
    for source, target in plain_edges:
        builder.add_edge(source, target)

    # Conditional edges: (source node, routing function, route -> target node).
    routed_edges = (
        (
            "detail_step",
            detail_or_data_retriever,
            {"chatbot_detail": "chatbot_detail", "data_retriever": "data_retriever"},
        ),
        (
            "human_validation",
            retrieve_or_process,
            {"retrieve" : "data_retriever", "process" : "data_processor"},
        ),
        (
            "chatbot_planner",
            validate_plan,
            {"parse" : "parse", "validate": "validate"},
        ),
        (
            "chatbot_exec_step",
            next_plan_step,
            {"generate_key_issues" : "generate_ki", "detail_step": "detail_step"},
        ),
    )
    for source, router, routes in routed_edges:
        builder.add_conditional_edges(source, router, routes)

    builder.set_entry_point("chatbot_planner")

    return builder.compile(
        checkpointer=memory,
        interrupt_after=["parse", "chatbot_exec_step", "chatbot_final", "data_retriever"],
    )