|
|
import asyncio |
|
|
|
|
|
from langchain_core.language_models.chat_models import BaseChatModel |
|
|
from langchain_core.messages import AIMessage |
|
|
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable |
|
|
from langgraph.graph import END, MessagesState, StateGraph |
|
|
from langgraph.types import StreamWriter |
|
|
|
|
|
from agents.bg_task_agent.task import Task |
|
|
from core import get_model, settings |
|
|
|
|
|
|
|
|
class AgentState(MessagesState, total=False):
    """State schema for the background-task agent graph.

    Extends ``MessagesState`` without declaring any keys of its own. The class
    is declared with ``total=False`` (PEP 589 "totality"), so any keys added to
    this class body are optional rather than required; keys inherited from
    ``MessagesState`` keep the totality they were originally declared with.

    documentation: https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality
    """
|
|
|
|
|
|
|
|
def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]:
    """Return a runnable that feeds a graph state's message list into ``model``.

    The pipeline has two stages:

    1. a ``RunnableLambda`` named ``"StateModifier"`` that reads
       ``state["messages"]`` from the incoming state, and
    2. ``model`` itself, which receives that message list.

    Args:
        model: Chat model to invoke on the extracted messages.

    Returns:
        The composed ``preprocessing | model`` runnable.
    """
    select_messages = RunnableLambda(lambda state: state["messages"], name="StateModifier")
    pipeline: RunnableSerializable[AgentState, AIMessage] = select_messages | model
    return pipeline
|
|
|
|
|
|
|
|
async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
    """Invoke the configured chat model on the conversation and return its reply.

    Args:
        state: Current graph state; its ``messages`` list is what ``wrap_model``
            passes to the chat model.
        config: Runnable config for this run. ``config["configurable"]["model"]``
            optionally selects the model; when it (or the whole
            ``"configurable"`` section) is absent, ``settings.DEFAULT_MODEL`` is
            used.

    Returns:
        A partial state update whose ``messages`` list holds the model response.
    """
    # RunnableConfig is a total=False TypedDict, so "configurable" is optional
    # (e.g. when this node is invoked directly). Fall back to an empty mapping
    # instead of raising KeyError; `or {}` also covers an explicit None.
    configurable = config.get("configurable") or {}
    model_name = configurable.get("model", settings.DEFAULT_MODEL)

    model_runnable = wrap_model(get_model(model_name))
    # Forward `config` so callbacks/tracing/streaming settings reach the model call.
    response = await model_runnable.ainvoke(state, config)

    # Returned as a one-element list; presumably merged into the existing
    # history by MessagesState's reducer — confirm against langgraph docs.
    return {"messages": [response]}
|
|
|
|
|
|
|
|
async def bg_task(state: AgentState, writer: StreamWriter) -> AgentState:
    """Simulate two staggered background tasks reporting progress via ``writer``.

    Two ``Task`` objects share the node's ``writer``. They are started,
    updated, and finished in a fixed order, with a 2-second pause between
    consecutive actions. Task 2 finishes with ``result="error"`` and task 1
    with ``result="success"``.

    Args:
        state: Current graph state (not read by this node).
        writer: Stream writer handed to each ``Task``; presumably used to emit
            custom stream events — confirm in ``agents.bg_task_agent.task``.

    Returns:
        An update with an empty ``messages`` list, i.e. this node contributes
        no messages of its own.
    """
    pause_seconds = 2
    first = Task("Simple task 1...", writer)
    second = Task("Simple task 2...", writer)

    # Ordered schedule of task actions; each one after the first is preceded
    # by a pause so the updates are spread out over time.
    schedule = (
        first.start,
        second.start,
        lambda: first.write_data(data={"status": "Still running..."}),
        lambda: second.finish(result="error", data={"output": 42}),
        lambda: first.finish(result="success", data={"output": 42}),
    )
    for position, action in enumerate(schedule):
        if position > 0:
            await asyncio.sleep(pause_seconds)
        action()

    return {"messages": []}
|
|
|
|
|
|
|
|
|
|
|
def _build_graph() -> StateGraph:
    """Assemble the uncompiled graph: entry -> ``bg_task`` -> ``model`` -> END."""
    graph = StateGraph(AgentState)
    graph.add_node("model", acall_model)
    graph.add_node("bg_task", bg_task)

    # Every run starts at the background-task node, then hands off to the model.
    graph.set_entry_point("bg_task")
    graph.add_edge("bg_task", "model")
    graph.add_edge("model", END)
    return graph


# `agent` is the uncompiled graph builder; `bg_task_agent` is its compiled form.
agent = _build_graph()
bg_task_agent = agent.compile()
|
|
|