import logging
from typing import Any, Callable, Dict, Iterator, Optional

from langgraph.graph import StateGraph, END
from langgraph.graph.state import CompiledStateGraph

from app.rag_chain import get_qa_chain
from app.models.model import LLM as llm
from app.summary_chain import get_summary_chain
from app.translate_chain import get_translate_chain
from app.rewrite_chain import get_rewrite_chain

logger = logging.getLogger(__name__)


# --- OUTPUT COLLECTOR ---
# Module-level record of the latest outputs. Every tool writes its result here
# and ``run_agent`` resets it at the start of each run. NOTE(review): this dict
# is shared by all callers of the module, so concurrent runs overwrite each
# other's entries; ``run_agent`` therefore yields from its own per-run state.
output_collector: Dict[str, Any] = {
    "answer": None,
    "summary": None,
    "translation": None,
    "formal_email": None,
}


# --- TOOLS ---
def answer_with_rag_tool(input_text: str, input_lang: str, **kwargs) -> str:
    """Answer ``input_text`` with the RAG QA chain built for ``input_lang``.

    Extra keyword arguments are accepted and ignored. The chain's result is
    stored in ``output_collector["answer"]`` and returned unchanged.
    NOTE(review): the result type is whatever ``get_qa_chain(...).invoke``
    returns - presumably a string; confirm against ``app.rag_chain``.
    """
    chain = get_qa_chain(language=input_lang)
    result = chain.invoke({"input": input_text})
    output_collector["answer"] = result
    return result


def _invoke_prompt_chain(
    chain_factory: Callable[[str], Any],
    input_lang: str,
    output_key: str,
    **prompt_vars: Any,
) -> str:
    """Run the ``(model, prompt)`` pair from ``chain_factory`` and return the reply text.

    The reply's ``.content`` is also stored in ``output_collector[output_key]``.
    """
    model, prompt = chain_factory(input_lang)
    result = model.invoke(prompt.format_messages(**prompt_vars))
    # Return the text, not the message object: downstream nodes feed this
    # value straight into the next prompt.
    output_collector[output_key] = result.content
    return result.content


def summarize_tool(input_text: str, input_lang: str, **kwargs) -> str:
    """Summarize ``input_text`` using the summary chain for ``input_lang``.

    Extra keyword arguments are accepted and ignored. Returns the summary text
    and stores it in ``output_collector["summary"]``.
    """
    return _invoke_prompt_chain(
        get_summary_chain, input_lang, "summary", input=input_text
    )


def translate_tool(
    input_text: str, input_lang: str, target_lang: Optional[str] = None
) -> str:
    """Translate ``input_text`` into ``target_lang``.

    A falsy ``target_lang`` falls back to ``input_lang``. Returns the
    translated text and stores it in ``output_collector["translation"]``.
    """
    if not target_lang:
        target_lang = input_lang
    return _invoke_prompt_chain(
        get_translate_chain,
        input_lang,
        "translation",
        input=input_text,
        target_lang=target_lang,
    )


def rewrite_email_tool(
    input_text: str, input_lang: str, target_lang: Optional[str] = None
) -> str:
    """Rewrite ``input_text`` as a formal email in ``target_lang``.

    A falsy ``target_lang`` falls back to ``input_lang`` (same rule as
    ``translate_tool``) so the prompt never receives ``None``. Returns the
    email text and stores it in ``output_collector["formal_email"]``.
    """
    if not target_lang:
        target_lang = input_lang
    return _invoke_prompt_chain(
        get_rewrite_chain,
        input_lang,
        "formal_email",
        input=input_text,
        target_lang=target_lang,
    )


# --- LANGGRAPH STATE & NODES ---
class AgentState(Dict[str, Any]):
    # State schema handed to StateGraph; at runtime run_agent passes plain dicts.
    user_input: str  # the user's question
    input_lang: str  # language used to select the chains
    target_lang: str  # desired output language; may be None
    answer: str = None  # set by node_answer_by_RAG
    summary: str = None  # set by node_summarize
    translation: str = None  # set by node_translate
    formal_email: str = None  # set by node_rewrite


def node_answer_by_RAG(state: AgentState) -> AgentState:
    """Answer ``state["user_input"]`` with RAG and store it under ``"answer"``."""
    state["answer"] = answer_with_rag_tool(state["user_input"], state["input_lang"])
    return state


def node_summarize(state: AgentState) -> AgentState:
    """Summarize ``state["answer"]`` and store it under ``"summary"``."""
    state["summary"] = summarize_tool(state["answer"], state["input_lang"])
    return state


def node_translate(state: AgentState) -> AgentState:
    """Translate the summary (or, failing that, the answer) into ``"translation"``."""
    text_to_translate = state.get("summary") or state["answer"]
    state["translation"] = translate_tool(
        text_to_translate, state["input_lang"], state.get("target_lang")
    )
    return state


def node_rewrite(state: AgentState) -> AgentState:
    """Rewrite the most refined text so far as a formal email (``"formal_email"``)."""
    # Prefer translation, else summary, else the raw answer.
    text_to_rewrite = (
        state.get("translation") or state.get("summary") or state["answer"]
    )
    state["formal_email"] = rewrite_email_tool(
        text_to_rewrite, state["input_lang"], state.get("target_lang")
    )
    return state


# Single source of truth mapping step names to node functions; used both for
# graph construction and for run_agent's manual execution plan.
_NODES: Dict[str, Callable[[AgentState], AgentState]] = {
    "answer_by_RAG": node_answer_by_RAG,
    "summarize": node_summarize,
    "translate": node_translate,
    "rewrite": node_rewrite,
}


# --- LANGGRAPH GRAPH DEFINITION ---
graph = StateGraph(AgentState)
graph.add_node("answer_by_RAG", node_answer_by_RAG)
graph.add_node("summarize", node_summarize)
graph.add_node("translate", node_translate)
graph.add_node("rewrite", node_rewrite)
graph.set_entry_point("answer_by_RAG")

# run_agent steps through the nodes manually, so no node-to-node edges are
# defined. Every node only gets an edge to END; invoking compiled_graph directly
# therefore runs just the entry node.
for _node_name in _NODES:
    graph.add_edge(_node_name, END)

compiled_graph: CompiledStateGraph = graph.compile()


# --- MAIN API CALL ---
def run_agent(
    user_input: str,
    input_lang: str = "Deutsch",
    target_lang: Optional[str] = None,
    do_summarize: bool = False,
    do_translate: bool = False,
    do_email: bool = False,
) -> Iterator[Dict[str, Any]]:
    """Run the answer -> summarize -> translate -> rewrite pipeline step by step.

    The RAG answer step always runs; the other steps run only when their flag
    is set, in that fixed order.

    Yields:
        After each executed step, a dict of the outputs produced so far (keys
        among ``"answer"``, ``"summary"``, ``"translation"``, ``"formal_email"``),
        omitting outputs that are ``None``. Values come from this run's own
        state, so concurrent runs do not see each other's results.
    """
    logger.debug(
        "run_agent input=%r input_lang=%s target_lang=%s "
        "summarize=%s translate=%s email=%s",
        user_input, input_lang, target_lang, do_summarize, do_translate, do_email,
    )
    # Kept for backward compatibility with code that reads the shared collector.
    for key in output_collector:
        output_collector[key] = None

    state: Dict[str, Any] = {
        "user_input": user_input,
        "input_lang": input_lang,
        "target_lang": target_lang,
    }

    execution_plan = ["answer_by_RAG"]
    if do_summarize:
        execution_plan.append("summarize")
    if do_translate:
        execution_plan.append("translate")
    if do_email:
        execution_plan.append("rewrite")

    for node_name in execution_plan:
        state = _NODES[node_name](state)
        yield {
            key: state[key]
            for key in output_collector
            if state.get(key) is not None
        }