mrmtaeb commited on
Commit
d3d1a7c
·
verified ·
1 Parent(s): e0fa626

Update src/state.py

Browse files
Files changed (1) hide show
  1. src/state.py +29 -1
src/state.py CHANGED
@@ -1 +1,29 @@
1
- # Placeholder for LangGraph state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langgraph.graph import START, StateGraph
2
+ from langgraph.checkpoint.memory import MemorySaver
3
+ from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
4
+ from typing import Sequence
5
+ from typing_extensions import TypedDict, Annotated
6
+
7
+ class State(TypedDict):
8
+ input: str
9
+ chat_history: Annotated[Sequence[BaseMessage], "add_messages"]
10
+ context: str
11
+ answer: str
12
+
13
+ def build_graph(rag_chain):
14
+ def call_model(state: State):
15
+ response = rag_chain.invoke(state)
16
+ return {
17
+ "chat_history": [
18
+ HumanMessage(state["input"]),
19
+ AIMessage(response["answer"]),
20
+ ],
21
+ "context": response.get("context", ""),
22
+ "answer": response.get("answer", "")
23
+ }
24
+
25
+ workflow = StateGraph(state_schema=State)
26
+ workflow.add_node("model", call_model)
27
+ workflow.add_edge(START, "model")
28
+ memory = MemorySaver()
29
+ return workflow.compile(checkpointer=memory)