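# File: backend/src/agents/bg_task_agent/bg_task_agent.py
# Repository page residue kept for provenance: "anujjoshi3105's picture",
# "initial", "22dcdfd" (presumably the commit message and short hash).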
import asyncio
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.types import StreamWriter
from agents.bg_task_agent.task import Task
from core import get_model, settings
class AgentState(MessagesState, total=False):
    """Graph state for the background-task agent.

    Declares no keys of its own; everything comes from ``MessagesState``.
    The nodes in this module read and write only the ``"messages"`` key.

    ``total=False`` is the TypedDict totality flag from PEP 589.
    Documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality
    """
def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]:
    """Build a runnable that feeds only the state's messages to *model*.

    Args:
        model: Chat model placed at the end of the pipeline.

    Returns:
        A runnable that takes the agent state, extracts ``state["messages"]``
        and pipes that list into ``model``.
    """
    # The first stage narrows the state dict down to its message list; the
    # explicit name is what this step is registered under.
    messages_only = RunnableLambda(
        lambda graph_state: graph_state["messages"], name="StateModifier"
    )
    pipeline = messages_only | model
    return pipeline  # type: ignore[return-value]
async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
    """Invoke the configured chat model on the current state.

    Args:
        state: Current graph state; its ``"messages"`` entry is what the
            wrapped model receives.
        config: Runnable configuration. ``config["configurable"]["model"]``
            selects the model; ``settings.DEFAULT_MODEL`` is used when the
            ``"model"`` entry (or the whole ``"configurable"`` mapping) is
            absent.

    Returns:
        A state update whose ``"messages"`` list holds the model's response.
    """
    # ``configurable`` is an optional key of RunnableConfig, so avoid a
    # KeyError when the caller supplied a config without it.
    configurable = config.get("configurable") or {}
    m = get_model(configurable.get("model", settings.DEFAULT_MODEL))
    model_runnable = wrap_model(m)
    response = await model_runnable.ainvoke(state, config)

    # We return a list, because this will get added to the existing list
    return {"messages": [response]}
async def bg_task(state: AgentState, writer: StreamWriter) -> AgentState:
    """Run a scripted two-task demo with fixed pauses between steps.

    Two ``Task`` objects are created with *writer*, then started, updated and
    finished in a fixed order. The second task finishes with result
    ``"error"`` and the first with ``"success"``.

    Args:
        state: Current graph state; not read by this node.
        writer: Stream writer handed to each ``Task`` on construction.

    Returns:
        A state update with an empty ``"messages"`` list.
    """
    pause_seconds = 2
    first = Task("Simple task 1...", writer)
    second = Task("Simple task 2...", writer)

    # Ordered script: each step runs, then the node sleeps before the next.
    timeline = (
        first.start,
        second.start,
        lambda: first.write_data(data={"status": "Still running..."}),
        lambda: second.finish(result="error", data={"output": 42}),
    )
    for step in timeline:
        step()
        await asyncio.sleep(pause_seconds)

    # The last step is not followed by a pause.
    first.finish(result="success", data={"output": 42})
    return {"messages": []}
# Define the graph: two nodes over AgentState, run as bg_task -> model -> END.
agent = StateGraph(AgentState)
# Register the node callables under the names used by the edges below.
agent.add_node("model", acall_model)
agent.add_node("bg_task", bg_task)
# Execution starts at the background-task node, then hands off to the model.
agent.set_entry_point("bg_task")
agent.add_edge("bg_task", "model")
# The model node is the last step of a run.
agent.add_edge("model", END)
# Compiled graph exposed by this module.
bg_task_agent = agent.compile()