# ARGOBot — src/state.py
# Last commit: "Update src/state.py" by mrmtaeb (16edc8c, verified; 1.7 kB).
from typing import Sequence

from typing_extensions import Annotated, TypedDict

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
class State(TypedDict):
    """State shared by the single-node graphs built in this module.

    ``chat_history`` uses LangGraph's ``add_messages`` reducer, so the
    messages a node returns are merged into the history supplied in the
    input state instead of replacing it.
    """

    # Current user input; each node records it as a HumanMessage.
    input: str
    # Conversation messages. Passing the reducer *function* (not the string
    # "add_messages") is what makes LangGraph append node updates here.
    chat_history: Annotated[Sequence[BaseMessage], add_messages]
    # Context string returned by the wrapped chain ("" when none is returned).
    context: str
    # Answer text produced for ``input``.
    answer: str
# For direct RAG chains
def build_graph(rag_chain):
    """Build a single-node graph that runs ``rag_chain`` once per invocation.

    The node passes the full graph state to ``rag_chain.invoke`` and records
    the exchange as a ``HumanMessage``/``AIMessage`` pair in ``chat_history``.

    Args:
        rag_chain: Object with an ``invoke`` method that accepts the state dict
            and returns a mapping with an ``"answer"`` key and, optionally,
            a ``"context"`` key.

    Returns:
        The compiled graph, compiled without a checkpointer.
    """

    def call_model(state: State):
        response = rag_chain.invoke(state)
        # Read the answer once so the history entry and the returned "answer"
        # field always agree, and a missing key yields "" instead of KeyError.
        answer = response.get("answer", "")
        return {
            "chat_history": [
                HumanMessage(state["input"]),
                AIMessage(answer),
            ],
            "context": response.get("context", ""),
            "answer": answer,
        }

    workflow = StateGraph(state_schema=State)
    workflow.add_node("model", call_model)
    workflow.add_edge(START, "model")
    # Deliberately compiled without a checkpointer (e.g. MemorySaver): the
    # graph is stateless and relies on the caller's session memory only.
    return workflow.compile()
# For plain callables such as agent_chain.invoke
def build_graph_with_callable(call_fn):
    """Build a single-node graph around a plain callable such as ``agent_chain.invoke``.

    Unlike ``build_graph``, only ``{"input": state["input"]}`` is passed to
    ``call_fn``; the ``chat_history`` in the state is not forwarded to it.

    Args:
        call_fn: Callable taking ``{"input": str}`` and returning a mapping
            whose result text is under ``"output"`` or, failing that,
            ``"answer"``; ``"context"`` is optional.

    Returns:
        The compiled graph (no checkpointer).
    """

    def call_model(state: State):
        response = call_fn({"input": state["input"]})
        # Prefer "output" (presumably what agent-style callables return) and
        # fall back to "answer"; computed once so both uses stay in sync.
        answer = response.get("output", response.get("answer", ""))
        return {
            "chat_history": [
                HumanMessage(state["input"]),
                AIMessage(answer),
            ],
            "context": response.get("context", ""),
            "answer": answer,
        }

    workflow = StateGraph(state_schema=State)
    workflow.add_node("model", call_model)
    workflow.add_edge(START, "model")
    return workflow.compile()