# LangGraph wrappers: compile single-node graphs around RAG chains and agent callables.
from typing import Sequence

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
from typing_extensions import Annotated, TypedDict
class State(TypedDict):
    """Conversation state shared by all graph nodes."""

    # Latest user utterance for this turn.
    input: str
    # Running message history. The reducer must be the add_messages
    # *function* — annotating with the string "add_messages" (as the
    # original did) is ignored by LangGraph, so history was silently
    # overwritten instead of appended.
    chat_history: Annotated[Sequence[BaseMessage], add_messages]
    # Retrieved context used to ground the answer (may be empty).
    context: str
    # Model reply for the current turn.
    answer: str
| # For direct RAG chains | |
def build_graph(rag_chain):
    """Compile a single-node LangGraph that delegates one turn to *rag_chain*.

    Parameters
    ----------
    rag_chain :
        A runnable whose ``invoke`` accepts the full graph state and returns
        a mapping that should contain ``"answer"`` and may contain
        ``"context"``.

    Returns
    -------
    A compiled graph with one ``"model"`` node that records the exchange in
    ``chat_history`` and exposes ``context`` / ``answer``.
    """

    def call_model(state: State):
        response = rag_chain.invoke(state)
        # Use .get consistently so a chain that omits a key degrades to ""
        # instead of raising KeyError (the original indexed response["answer"]
        # directly, unlike build_graph_with_callable).
        answer = response.get("answer", "")
        return {
            "chat_history": [
                HumanMessage(state["input"]),
                AIMessage(answer),
            ],
            "context": response.get("context", ""),
            "answer": answer,
        }

    workflow = StateGraph(state_schema=State)
    workflow.add_node("model", call_model)
    workflow.add_edge(START, "model")
    # Persistence deliberately disabled; to enable cross-invocation memory,
    # compile with checkpointer=MemorySaver() instead.
    return workflow.compile()  # Stateless; relies on session memory only
| # For agent_chain.invoke | |
def build_graph_with_callable(call_fn):
    """Compile a one-node graph around a plain callable (e.g. ``agent_chain.invoke``).

    ``call_fn`` is invoked with ``{"input": <user text>}`` and is expected to
    return a mapping carrying the reply under ``"output"`` (agent style) or
    ``"answer"`` (RAG style); ``"context"`` is optional.
    """

    def call_model(state: State):
        result = call_fn({"input": state["input"]})
        # Agents report under "output"; RAG-style chains under "answer".
        reply = result.get("output", result.get("answer", ""))
        return {
            "chat_history": [
                HumanMessage(state["input"]),
                AIMessage(reply),
            ],
            "context": result.get("context", ""),
            "answer": reply,
        }

    graph = StateGraph(state_schema=State)
    graph.add_node("model", call_model)
    graph.add_edge(START, "model")
    return graph.compile()