File size: 6,468 Bytes
16d5a75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
from math import log
import os
from typing import TypedDict, Optional, List, Literal
from langchain_core.documents import Document
from langchain_core.tools import tool
from src.utils.helper import (
    fake_token_counter,
    convert_list_context_source_to_str,
    convert_message,
)
from src.utils.logger import logger
from langchain_core.messages import trim_messages, AnyMessage
from .prompt import entry_chain, build_lesson_plan_chain
from .data import primary_level_format, high_school_level_format
from typing import Annotated, Sequence
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage


class State(TypedDict):
    """Shared state dictionary handed to every node function in this module."""

    # Latest user input, either as plain text or as a message object.
    user_query: str | AnyMessage
    # Earlier conversation turns; trim_history() converts and trims this list.
    messages_history: list
    # Not read or written by any function visible in this module.
    document_id_selected: Optional[List]
    # Lesson identifiers filled in by entry() and change_lesson().
    lesson_name: str
    subject_name: str
    # NOTE(review): declared as int, but entry() and change_lesson() store a
    # str (or None) here and build_lesson_plan() converts it back with int().
    class_number: int
    # Content of the message returned by the entry chain.
    entry_response: str
    # build_lesson_plan() stores a one-element list holding the chain's reply;
    # change_lesson() reads the last element.
    build_lesson_plan_response: list[AnyMessage]
    # Text of the latest reply, set by entry() or build_lesson_plan().
    final_response: str
    # Running message list, annotated with langgraph's add_messages.
    # NOTE(review): build_lesson_plan() also writes a "lesson_plan_format" key
    # into the state dict that is not declared in this schema.
    messages: Annotated[Sequence[AnyMessage], add_messages]


def trim_history(state: State):
    """Return the stored chat history cut down to the configured token budget.

    The raw ``messages_history`` entry is converted with ``convert_message``
    and then trimmed with ``trim_messages``, keeping the most recent messages
    that fit in ``HISTORY_TOKEN_LIMIT`` tokens (4000 when the environment
    variable is unset). An absent or empty history yields an empty list.

    Args:
        state: Graph state; only ``messages_history`` is read.

    Returns:
        A dict with the single key ``messages_history``.
    """
    raw_history = state.get("messages_history")
    if not raw_history:
        return {"messages_history": []}

    converted = convert_message(raw_history)
    if not converted:
        return {"messages_history": []}

    token_budget = int(os.getenv("HISTORY_TOKEN_LIMIT", 4000))
    trimmed = trim_messages(
        converted,
        strategy="last",
        token_counter=fake_token_counter,
        max_tokens=token_budget,
        start_on="human",
        end_on="ai",
        include_system=False,
        allow_partial=False,
    )
    return {"messages_history": trimmed}


async def entry(state: State):
    """Run the entry chain and extract the requested lesson from its tool call.

    The chain's reply is searched for a tool call named ``EntryExtractor``.
    When that call carries ``class_number``, ``subject_name`` and
    ``lesson_name`` arguments, they are returned as strings (``class_number``
    is normalised through ``int`` first). Otherwise the lesson fields are set
    to ``None`` and the reply text is returned as ``final_response``.

    Args:
        state: Graph state, passed unchanged to ``entry_chain``.

    Returns:
        A partial state update containing ``entry_response`` and the three
        lesson fields, plus ``final_response`` when no lesson was extracted.
    """
    logger.info(f"Entry {state['messages']}")
    entry_response: AnyMessage = await entry_chain.ainvoke(state)
    logger.info(f"Entry response: {entry_response}")
    # Not every message type has tool_calls, so default to an empty list.
    tool_calls = getattr(entry_response, "tool_calls", None) or []
    logger.info(f"Entry response tool_calls: {tool_calls}")

    # Read the arguments from the EntryExtractor call itself, which is not
    # necessarily the first tool call in the reply.
    extractor_call = next(
        (call for call in tool_calls if call["name"] == "EntryExtractor"),
        None,
    )
    extractor_args = (
        extractor_call["args"]
        if extractor_call and "args" in extractor_call
        else None
    )
    required_keys = ("class_number", "subject_name", "lesson_name")

    if isinstance(extractor_args, dict) and all(
        key in extractor_args for key in required_keys
    ):
        try:
            class_number = str(int(extractor_args["class_number"]))
        except (TypeError, ValueError, OverflowError) as e:
            # A malformed grade from the model is treated as "nothing
            # extracted" instead of crashing the node.
            logger.warning(f"Invalid class_number in EntryExtractor call: {e}")
        else:
            # NOTE(review): this log text looks like mis-decoded Vietnamese
            # ("Vô đây") — confirm the file's encoding before changing it.
            logger.info("V么 膽芒y")
            return {
                "entry_response": entry_response.content,
                "class_number": class_number,
                "subject_name": str(extractor_args["subject_name"]),
                "lesson_name": str(extractor_args["lesson_name"]),
            }

    # NOTE(review): likewise appears to be mis-decoded "không vô".
    logger.info("kh么ng v么")
    return {
        "entry_response": entry_response.content,
        "class_number": None,
        "subject_name": None,
        "lesson_name": None,
        "final_response": entry_response.content,
    }


async def build_lesson_plan(state: State):
    """Prepare the conversation state and invoke the lesson-plan chain.

    Before the chain runs, the incoming ``state`` dict is adjusted in place:

    * If some message carries a ``ChangeLesson`` tool call and none carries an
      ``extract_lesson_content`` call, ``messages`` and ``messages_history``
      are both emptied.
    * If both kinds of call are present, only the messages carrying a
      ``ChangeLesson`` call are dropped from ``messages``.
    * ``lesson_plan_format`` is set to the high-school format when
      ``class_number`` is above 5, otherwise to the primary format.

    Args:
        state: Graph state; mutated as described and passed to the chain.

    Returns:
        A partial state update with the chain's reply under
        ``build_lesson_plan_response`` (one-element list), ``messages`` and
        ``final_response`` (the reply's content).
    """
    logger.info(f"build_lesson_plan {state['messages']}")

    def invokes_tool(msg, tool_name):
        # True when the message carries a tool call with the given name.
        return hasattr(msg, "tool_calls") and any(
            call["name"] == tool_name for call in msg.tool_calls
        )

    change_requested = any(
        invokes_tool(msg, "ChangeLesson") for msg in state["messages"]
    )
    content_extracted = any(
        invokes_tool(msg, "extract_lesson_content") for msg in state["messages"]
    )

    if change_requested:
        if content_extracted:
            # Keep the conversation but drop the lesson-change requests.
            state["messages"] = [
                msg
                for msg in state["messages"]
                if not invokes_tool(msg, "ChangeLesson")
            ]
        else:
            # Lesson changed and nothing extracted yet: start from scratch.
            state["messages"] = []
            state["messages_history"] = []

    state["lesson_plan_format"] = (
        high_school_level_format
        if int(state["class_number"]) > 5
        else primary_level_format
    )
    plan_message = await build_lesson_plan_chain.ainvoke(state)

    return {
        "build_lesson_plan_response": [plan_message],
        "messages": plan_message,
        "final_response": plan_message.content,
    }


def change_lesson(state: State):
    """Extract the newly requested lesson from the latest lesson-plan reply.

    The last entry of ``build_lesson_plan_response`` is inspected for tool
    calls. The call named ``ChangeLesson`` is preferred; if no call has that
    name, the first tool call is used. Its ``class_number``, ``subject_name``
    and ``lesson_name`` arguments are returned as strings (a float
    ``class_number`` such as ``7.0`` becomes ``"7"``).

    Args:
        state: Graph state; only ``build_lesson_plan_response`` is read.

    Returns:
        A dict with ``class_number``, ``subject_name`` and ``lesson_name``.
        All three are ``None`` when there is no reply, the reply has no tool
        calls, or the arguments are missing or malformed.
    """
    # Tolerate a missing or empty response list instead of raising.
    responses = state.get("build_lesson_plan_response") or []
    build_lesson_plan_response = responses[-1] if responses else None
    logger.info(f"Build lesson plan response: {build_lesson_plan_response}")

    tool_calls = getattr(build_lesson_plan_response, "tool_calls", None)
    if tool_calls:
        try:
            # Do not assume ChangeLesson is the first call; look it up by
            # name and only fall back to position when no call matches.
            tool_call = next(
                (call for call in tool_calls if call["name"] == "ChangeLesson"),
                tool_calls[0],
            )
            args = tool_call["args"]

            # Models may emit the grade as a float (e.g. 7.0); drop the ".0".
            class_number_value = args["class_number"]
            if isinstance(class_number_value, float):
                class_number = str(int(class_number_value))
            else:
                class_number = str(class_number_value)

            subject_name = str(args["subject_name"])
            lesson_name = str(args["lesson_name"])

            logger.info(
                f"Extracted lesson data: class={class_number}, subject={subject_name}, lesson={lesson_name}"
            )

            return {
                "class_number": class_number,
                "subject_name": subject_name,
                "lesson_name": lesson_name,
            }
        except (KeyError, IndexError, TypeError, ValueError, OverflowError) as e:
            # ValueError/OverflowError cover int() on NaN or infinite floats.
            logger.error(f"Error extracting lesson data: {e}")

    # Default values if extraction fails
    logger.warning("Could not extract lesson data from response")
    return {
        "class_number": None,
        "subject_name": None,
        "lesson_name": None,
    }