File size: 1,716 Bytes
031378e
744b763
16d5a75
 
744b763
16d5a75
 
 
744b763
 
16d5a75
 
 
 
 
744b763
 
16d5a75
 
 
 
 
744b763
 
 
 
 
 
031378e
744b763
 
16d5a75
 
744b763
16d5a75
 
 
031378e
744b763
 
 
16d5a75
 
 
744b763
16d5a75
 
 
744b763
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from typing import TypedDict, Optional
from langchain_core.messages import AnyMessage, ToolMessage
from langgraph.graph.message import add_messages
from typing import Sequence, Annotated
from src.agents.custom_chatbot.prompt import get_custom_chatbot_chains
from src.utils.logger import logger
import re
from src.config.mongo import bot_crud


class State(TypedDict):
    """Graph state shared by the custom-chatbot builder nodes in this module."""

    # Conversation history. The add_messages reducer merges node updates into
    # the existing list instead of replacing it.
    messages: Annotated[Sequence[AnyMessage], add_messages]
    # NOTE(review): not read by the nodes in this file; presumably the step
    # budget consumed by the agent runtime — confirm against the graph setup.
    remaining_steps: int
    # System prompt text returned by create_prompt and persisted by save_prompt.
    prompt: Optional[str]
    # Bot name taken from the collection tool call's "name" argument.
    name: Optional[str]
    # Model identifier forwarded to get_custom_chatbot_chains.
    model_name: Optional[str]


def get_info_collection(messages):
    """Return the ``name`` and ``info`` arguments of the first tool call.

    Locates the first ``ToolMessage`` in ``messages`` and reads the arguments
    of the tool call it answers from the message immediately preceding it.

    Args:
        messages: Conversation history (LangChain messages); ``None`` is
            treated as empty.

    Returns:
        A ``(name, info)`` tuple; each element defaults to ``""`` when the
        corresponding argument is missing from the tool call.

    Raises:
        ValueError: if there is no ``ToolMessage``, or the message before the
            first one carries no tool calls.
    """
    for idx, message in enumerate(messages or []):
        if not isinstance(message, ToolMessage):
            continue
        # At idx == 0, messages[idx - 1] would wrap around to the last message.
        previous = messages[idx - 1] if idx > 0 else None
        tool_calls = getattr(previous, "tool_calls", None)
        if not tool_calls:
            break
        # Prefer the call this ToolMessage answers; fall back to the first one.
        call_id = getattr(message, "tool_call_id", None)
        call = next(
            (tc for tc in tool_calls if tc.get("id") == call_id),
            tool_calls[0],
        )
        args = call.get("args") or {}
        return args.get("name", ""), args.get("info", "")
    raise ValueError("no tool call with collected info found in messages")


async def collection_info_agent(state: State):
    """Run the info-collection chain selected by ``state["model_name"]``.

    The second chain returned by ``get_custom_chatbot_chains`` is awaited
    with the whole state as its input, and its result is returned unchanged.
    """
    _system_chain, info_chain = get_custom_chatbot_chains(state.get("model_name"))
    return await info_chain.ainvoke(state)


async def create_prompt(state: State):
    """Generate a system prompt from the information collected by the agent.

    Reads the tool-call arguments from the message history, feeds ``info`` to
    the prompt-creation chain for the selected model, and returns a partial
    state update holding the generated ``prompt`` text and the bot ``name``.
    """
    name, info = get_info_collection(state.get("messages"))
    logger.info(f"create_prompt {info}")
    system_chain, _info_chain = get_custom_chatbot_chains(state.get("model_name"))
    response = await system_chain.ainvoke({"info": info})
    return {"prompt": response.content, "name": name}


async def save_prompt(state: State):
    """Persist the generated system prompt as a new bot record.

    When the prompt text contains a Markdown code fence, only the body of the
    first fenced block is stored. An optional info string on the fence's
    opening line (e.g. ``markdown``) is dropped, as is surrounding whitespace.
    Text without a fence is stored unchanged.

    Args:
        state: Graph state; reads ``state["prompt"]`` (set by
            ``create_prompt``) and ``state["name"]``.
    """
    prompt = state["prompt"]
    fenced = re.search(r"```(.*?)```", prompt, re.DOTALL)
    if fenced:
        body = fenced.group(1)
        first_line, newline, rest = body.partition("\n")
        # In "```markdown\n...```" the first line is a language tag, not prompt
        # text. Drop it only when it is empty or a single bare token, so that
        # a real sentence written right after the backticks is kept.
        if newline and re.fullmatch(r"[\w.+-]*", first_line.strip()):
            body = rest
        prompt = body.strip()
    name = state["name"]
    await bot_crud.create({"name": name, "prompt": prompt, "tools": []})