Spaces:
Sleeping
Sleeping
| from typing import TypedDict, Optional | |
| from langchain_core.messages import AnyMessage, ToolMessage | |
| from langgraph.graph.message import add_messages | |
| from typing import Sequence, Annotated | |
| from src.agents.custom_chatbot.prompt import get_custom_chatbot_chains | |
| from src.utils.logger import logger | |
| import re | |
| from src.config.mongo import bot_crud | |
class State(TypedDict):
    """Graph state shared by the custom-chatbot nodes defined in this module.

    The nodes below read and write these keys:
    ``collection_info_agent`` receives the whole state, ``create_prompt``
    writes ``prompt`` and ``name``, and ``save_prompt`` reads them back.
    """

    # Conversation history. ``add_messages`` is the LangGraph reducer attached
    # to this key, so message updates returned by nodes are combined with the
    # existing history rather than replacing it.
    messages: Annotated[Sequence[AnyMessage], add_messages]
    # Not read in this module; presumably required by the agent returned from
    # get_custom_chatbot_chains (step budget) — TODO confirm.
    remaining_steps: int
    # System prompt text produced by create_prompt; consumed by save_prompt.
    prompt: Optional[str]
    # Bot name taken from the collection tool call's "name" argument.
    name: Optional[str]
    # Model identifier forwarded to get_custom_chatbot_chains.
    model_name: Optional[str]
def get_info_collection(messages):
    """Return the ``(name, info)`` arguments of the tool call answered first.

    Scans ``messages`` for the first ``ToolMessage`` and reads the first tool
    call on the message immediately preceding it (the message that issued the
    call). Missing ``name``/``info`` arguments default to ``""``.

    Args:
        messages: Sequence of chat messages (may be empty or ``None``).

    Returns:
        A ``(name, info)`` tuple. ``("", "")`` when there is no ToolMessage,
        when the ToolMessage has no predecessor, or when the predecessor
        carries no tool calls.
    """
    history = list(messages or [])
    for idx, message in enumerate(history):
        if not isinstance(message, ToolMessage):
            continue
        # A ToolMessage at position 0 has no caller; indexing idx - 1 would
        # silently wrap around to the last message.
        if idx == 0:
            break
        tool_calls = getattr(history[idx - 1], "tool_calls", None) or []
        if not tool_calls:
            break
        args = tool_calls[0].get("args") or {}
        return args.get("name", ""), args.get("info", "")
    # No usable tool call found: fall back to the same empty defaults used above.
    return "", ""
async def collection_info_agent(state: State):
    """Run the info-collection agent on the current graph state.

    The chains are built for ``state["model_name"]``; the second element of
    the pair returned by ``get_custom_chatbot_chains`` is the collection
    agent, which is awaited with the full state. Its result is returned
    unchanged.
    """
    selected_model = state.get("model_name")
    # Only the agent half of the (system_chain, agent) pair is needed here.
    _system_chain, info_agent = get_custom_chatbot_chains(selected_model)
    result = await info_agent.ainvoke(state)
    return result
async def create_prompt(state: State):
    """Generate a system prompt from the info captured by the collection tool.

    Extracts ``(name, info)`` from the message history, logs the info, and
    feeds it to the system-prompt chain built for ``state["model_name"]``.

    Returns:
        A state update ``{"prompt": <chain output .content>, "name": <name>}``.
    """
    bot_name, collected_info = get_info_collection(state.get("messages"))
    logger.info(f"create_prompt {collected_info}")
    # Only the system-prompt half of the (system_chain, agent) pair is used.
    system_chain, _info_agent = get_custom_chatbot_chains(state.get("model_name"))
    response = await system_chain.ainvoke({"info": collected_info})
    update = {"prompt": response.content, "name": bot_name}
    return update
async def save_prompt(state: State):
    """Persist the generated system prompt as a new bot record.

    If ``state["prompt"]`` contains a fenced block (```` ``` ````), only the
    body of the first fence is stored. An optional single-token info string on
    the opening fence (e.g. ``markdown``) is dropped, and surrounding
    whitespace is trimmed. Without a fence the prompt is stored verbatim.

    The record is written via ``bot_crud.create`` with an empty tool list.
    """
    prompt = state["prompt"]
    # The optional "[lang]\n" right after the opening fence belongs to the
    # fence, not the prompt. The previous pattern kept it ("markdown\nYou are…").
    # Without a trailing newline the group is skipped, so one-line fences such
    # as "```You are a bot```" still capture their full text.
    fenced = re.search(r"```(?:[ \t]*[\w+-]*[ \t]*\n)?(.*?)```", prompt, re.DOTALL)
    if fenced:
        prompt = fenced.group(1).strip()
    name = state["name"]
    await bot_crud.create({"name": name, "prompt": prompt, "tools": []})