from typing import Any, Dict, TypeVar, Union

from langchain_core.documents import Document
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.runnables import RunnableLambda
from langgraph.prebuilt import ToolNode

# State is declared as a TypeVar bound to a dict instead of importing the
# graph's State class, because that import would be circular.
State = TypeVar("State", bound=Dict[str, Any])


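def fake_token_counter(messages: Union[list[BaseMessage], BaseMessage]) -> int:
    """Approximate a token count by counting whitespace-separated words.

    This is a rough stand-in for a real tokenizer and assumes every message
    has string content.
    """
    if isinstance(messages, list):
        return sum(len(message.content.split()) for message in messages)
    return len(messages.content.split())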
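

# Sketch (hypothetical helper, not used elsewhere in this module): the same
# whitespace word-count approximation, extended to multimodal content where a
# message's content is a list of parts such as {"type": "text", "text": ...}.
# It assumes non-text parts (for example images) contribute zero "tokens".
def _approx_tokens_in_content(content) -> int:
    if isinstance(content, str):
        return len(content.split())
    # Count words only in parts that carry text; skip images and other parts.
    return sum(
        len(part.get("text", "").split())
        for part in content
        if isinstance(part, dict) and part.get("type") == "text"
    )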


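def convert_list_context_source_to_str(contexts: list[Document]) -> str:
    """Format documents as one string of numbered "Document index i" blocks.

    Each block holds the document's page content and is followed by a dashed
    separator line.
    """
    formatted_str = ""
    for i, context in enumerate(contexts):
        formatted_str += f"Document index {i}:\nContent: {context.page_content}\n"
        formatted_str += "----------------------------------------------\n\n"
    return formatted_str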
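

# Sketch (hypothetical helper): the same "Document index i" layout built with
# str.join instead of repeated += concatenation. It takes plain strings in
# place of Document objects so it can run without LangChain. The separator
# width here is illustrative and not guaranteed to match the function above.
def _format_context_strings(texts: list) -> str:
    separator = "-" * 46
    return "".join(
        f"Document index {i}:\nContent: {text}\n{separator}\n\n"
        for i, text in enumerate(texts)
    )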


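def convert_message(messages: list[dict]) -> list[BaseMessage]:
    """Convert serialized chat messages into LangChain message objects.

    Each input dict must have "type" and "content" keys. A "human" type
    becomes a HumanMessage; any other type is treated as an AIMessage.
    """
    list_message = []
    for message in messages:
        if message["type"] == "human":
            list_message.append(HumanMessage(content=message["content"]))
        else:
            list_message.append(AIMessage(content=message["content"]))
    return list_message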
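

# Sketch (hypothetical helper): the same branching as convert_message, but
# producing plain role dictionaries of the kind filter_image_messages takes.
# The "user"/"assistant" role names are an assumption (OpenAI-style roles);
# as in convert_message, any type other than "human" is treated as the model.
def _to_role_dicts(messages: list) -> list:
    return [
        {
            "role": "user" if message["type"] == "human" else "assistant",
            "content": message["content"],
        }
        for message in messages
    ]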


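def create_tool_node_with_fallback(tools: list):
    """Wrap a ToolNode so that tool exceptions are routed to handle_tool_error.

    Returns a runnable with fallbacks (not a dict). With exception_key="error",
    a raised exception is added to the input state under "error" before the
    fallback is invoked.
    """
    return ToolNode(tools).with_fallbacks(
        [RunnableLambda(handle_tool_error)], exception_key="error"
    )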


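def handle_tool_error(state: State) -> dict:
    """Fallback for the tool node: report the raised error back to the model.

    Emits one ToolMessage per pending tool call, so that every call issued by
    the model receives a response containing the error.
    """
    error = state.get("error")
    # The AI message whose tool calls triggered the failure.
    ai_message = state["build_lesson_plan_response"]
    return {
        "build_lesson_plan_response": [
            ToolMessage(
                content=f"Error: {repr(error)}\nPlease fix your mistakes.",
                tool_call_id=tc["id"],
            )
            for tc in ai_message.tool_calls
        ]
    }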
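

# Sketch (hypothetical helper, plain Python): a simplified model of what
# with_fallbacks(..., exception_key="error") does in
# create_tool_node_with_fallback. If the primary callable raises, the
# exception is added to the input dict under exception_key and the fallback
# is called with that dict. This shows the control flow only; it is not
# LangChain's implementation.
def _invoke_with_fallback(primary, fallback, state: dict, exception_key: str = "error"):
    try:
        return primary(state)
    except Exception as exc:
        # Pass the original input plus the caught exception to the fallback.
        return fallback({**state, exception_key: exc})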


def filter_image_messages(messages):
    """
    Removes image parts from a list of message dictionaries.

    A message whose 'content' is a list of parts has every part of type
    'image' dropped. If any parts remain, the message is kept with its content
    replaced by the 'text' of the first remaining part; if none remain, the
    message is dropped. Messages with non-list content are kept unchanged.

    Args:
        messages (list): A list of dictionaries, each representing a message with 'role' and 'content' keys.

    Returns:
        list: A new list of message dictionaries with image parts removed.
    """
    filtered_messages = []

    for message in messages:
        # Check if 'content' is a list (indicating multiple parts)
        if isinstance(message["content"], list):
            # Filter out parts that are of type 'image'
            filtered_content = [
                part for part in message["content"] if part.get("type") != "image"
            ]
            # Keep the message only if some parts remain, using the first part's text
            if filtered_content:
                filtered_messages.append(
                    {
                        "role": message["role"],
                        "content": filtered_content[0]["text"],
                    }
                )
        else:
            # If 'content' is not a list, simply add the message
            filtered_messages.append(message)

    return filtered_messages
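

# Sketch (hypothetical variant): unlike filter_image_messages, which keeps only
# the first remaining part's "text", this version keeps only parts of type
# "text" (so image parts under any type name are dropped) and joins all of
# their text. Joining with "\n" is an assumption about how multiple text
# parts should be combined.
def _strip_images_join_text(messages: list) -> list:
    result = []
    for message in messages:
        content = message["content"]
        if not isinstance(content, list):
            # String content has no image parts; keep the message as-is.
            result.append(message)
            continue
        texts = [
            part["text"]
            for part in content
            if part.get("type") == "text" and "text" in part
        ]
        # Drop messages that contained nothing but non-text parts.
        if texts:
            result.append({"role": message["role"], "content": "\n".join(texts)})
    return result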