ABAO77's picture
Upload 72 files
031378e verified
from typing import TypedDict, Optional
from langchain_core.messages import AnyMessage, ToolMessage
from langgraph.graph.message import add_messages
from typing import Sequence, Annotated
from src.agents.custom_chatbot.prompt import get_custom_chatbot_chains
from src.utils.logger import logger
import re
from src.config.mongo import bot_crud
class State(TypedDict):
    """Graph state passed between the custom-chatbot builder nodes in this module."""

    # Conversation history. The `add_messages` annotation is langgraph's
    # reducer metadata for merging message updates into this list.
    messages: Annotated[Sequence[AnyMessage], add_messages]
    # Not read or written by any function visible in this module.
    # NOTE(review): presumably the framework-managed step budget -- confirm.
    remaining_steps: int
    # System prompt text; written by `create_prompt`, read by `save_prompt`.
    prompt: Optional[str]
    # Bot name; written by `create_prompt`, read by `save_prompt`.
    name: Optional[str]
    # Model identifier forwarded to `get_custom_chatbot_chains`.
    model_name: Optional[str]
def get_info_collection(messages):
    """Return the ``name`` and ``info`` arguments of the collected-info tool call.

    Finds the first ``ToolMessage`` in ``messages`` and reads the arguments of
    the first tool call attached to the message immediately before it.

    Args:
        messages: Ordered sequence of conversation messages. May be empty or
            ``None``.

    Returns:
        A ``(name, info)`` tuple. Each element falls back to ``""`` when the
        corresponding argument is absent. ``("", "")`` is returned when there
        is no ``ToolMessage``, when it is the very first message, or when the
        preceding message carries no tool calls.
    """
    messages = messages or ()
    tool_idx = next(
        (
            position
            for position, message in enumerate(messages)
            if isinstance(message, ToolMessage)
        ),
        None,
    )
    # A ToolMessage at index 0 has no predecessor; indexing with -1 would
    # silently wrap around to the last message instead.
    if tool_idx is None or tool_idx == 0:
        return "", ""

    # Not every message type exposes `tool_calls`, and it may be empty.
    tool_calls = getattr(messages[tool_idx - 1], "tool_calls", None)
    if not tool_calls:
        return "", ""

    args = tool_calls[0].get("args") or {}
    return args.get("name", ""), args.get("info", "")
async def collection_info_agent(state: State):
    """Invoke the information-collection chain on the current graph state.

    The chain is the second element of the pair returned by
    ``get_custom_chatbot_chains`` for the state's ``model_name``.

    Args:
        state: Current graph state; ``model_name`` is read with ``.get`` and
            may therefore be absent.

    Returns:
        Whatever the chain's ``ainvoke`` call returns for ``state``.
    """
    # Use a distinct local name so the chain does not shadow this function.
    _unused, agent_chain = get_custom_chatbot_chains(state.get("model_name"))
    result = await agent_chain.ainvoke(state)
    return result
async def create_prompt(state: State):
    """Generate a system prompt from the information gathered in the conversation.

    Pulls the bot name and collected info out of the message history, logs the
    info, and feeds it to the first chain returned by
    ``get_custom_chatbot_chains``.

    Args:
        state: Current graph state; ``messages`` and ``model_name`` are read.

    Returns:
        A state update dict with ``prompt`` (the chain response's ``content``)
        and ``name`` (the collected bot name).
    """
    selected_model = state.get("model_name")
    bot_name, collected_info = get_info_collection(state.get("messages"))
    logger.info(f"create_prompt {collected_info}")

    system_chain, _unused = get_custom_chatbot_chains(selected_model)
    response = await system_chain.ainvoke({"info": collected_info})
    return {"prompt": response.content, "name": bot_name}
async def save_prompt(state: State):
    """Persist the generated prompt as a new bot record.

    When the prompt text contains a triple-backtick fenced section, only the
    contents of the first fence are stored; otherwise the whole text is stored.
    The record is created through ``bot_crud.create`` with an empty ``tools``
    list.

    Args:
        state: Current graph state; ``prompt`` and ``name`` must be present.
    """
    prompt_text = state["prompt"]

    # NOTE: the capture is everything between the fences, including any text
    # that directly follows the opening backticks.
    fenced = re.search(r"```(.*?)```", prompt_text, re.DOTALL)
    if fenced is not None:
        prompt_text = fenced.group(1)

    record = {"name": state["name"], "prompt": prompt_text, "tools": []}
    await bot_crud.create(record)