# Uploaded by ABAO77 ("Upload 107 files", revision 16d5a75, verified).
from math import log
import os
from typing import TypedDict, Optional, List, Literal
from langchain_core.documents import Document
from langchain_core.tools import tool
from src.utils.helper import (
fake_token_counter,
convert_list_context_source_to_str,
convert_message,
)
from src.utils.logger import logger
from langchain_core.messages import trim_messages, AnyMessage
from .prompt import entry_chain, build_lesson_plan_chain
from .data import primary_level_format, high_school_level_format
from typing import Annotated, Sequence
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage
class State(TypedDict):
    """State dictionary received and partially updated by the node functions in this file.

    Each node function below takes a ``State`` and returns a dict containing
    only the keys it wants to update.
    """

    # Not read by the functions visible in this file; presumably the raw user
    # input -- TODO confirm against the caller.
    user_query: str | AnyMessage
    # Prior conversation; converted and trimmed by ``trim_history``.
    messages_history: list
    # Not referenced by the functions visible in this file.
    document_id_selected: Optional[List]
    # Lesson selection written by ``entry`` and ``change_lesson``.
    lesson_name: str
    subject_name: str
    # NOTE(review): annotated ``int``, but ``entry`` and ``change_lesson`` store
    # a ``str`` (or ``None``); ``build_lesson_plan`` converts it with ``int()``.
    class_number: int
    # ``.content`` of the response returned by ``entry_chain``.
    entry_response: str
    # ``build_lesson_plan`` stores a single-element list holding its latest
    # response; ``change_lesson`` reads the last element.
    build_lesson_plan_response: list[AnyMessage]
    # Text written by ``entry`` (when nothing was extracted) and by
    # ``build_lesson_plan``.
    final_response: str
    # Message list annotated with the ``add_messages`` function from
    # ``langgraph.graph.message``.
    messages: Annotated[Sequence[AnyMessage], add_messages]
    # NOTE(review): ``build_lesson_plan`` also writes a ``lesson_plan_format``
    # key into the state it passes to its chain; that key is not declared here.
def trim_history(state: State):
    """Trim the stored conversation history to the configured token budget.

    The history under ``messages_history`` is converted with
    ``convert_message`` and then cut down with ``trim_messages``, keeping the
    most recent messages. The budget is read from the ``HISTORY_TOKEN_LIMIT``
    environment variable (default 4000) and measured with
    ``fake_token_counter``.

    Args:
        state: Graph state; only ``messages_history`` is read.

    Returns:
        A state update of the form ``{"messages_history": [...]}``. The list is
        empty when there is no history or the conversion yields nothing.
    """
    stored_history = state.get("messages_history")
    if not stored_history:
        return {"messages_history": []}

    converted_history = convert_message(stored_history)
    if not converted_history:
        return {"messages_history": []}

    trim_options = {
        "strategy": "last",
        "token_counter": fake_token_counter,
        "max_tokens": int(os.getenv("HISTORY_TOKEN_LIMIT", 4000)),
        # The kept window must open on a human message and close on an AI one.
        "start_on": "human",
        "end_on": "ai",
        "include_system": False,
        "allow_partial": False,
    }
    trimmed_history = trim_messages(converted_history, **trim_options)
    return {"messages_history": trimmed_history}
def _entry_extractor_args(tool_calls):
    """Return the args of the first complete ``EntryExtractor`` tool call, else None."""
    required_keys = ("class_number", "subject_name", "lesson_name")
    for call in tool_calls or []:
        if call.get("name") != "EntryExtractor":
            continue
        args = call.get("args") or {}
        if all(key in args for key in required_keys):
            return args
    return None


async def entry(state: State):
    """Run the entry chain and pull the lesson selection out of its tool call.

    The response of ``entry_chain`` is searched for a tool call named
    ``EntryExtractor`` whose args contain ``class_number``, ``subject_name``
    and ``lesson_name``. The args are read from that call itself, wherever it
    sits in ``tool_calls``, rather than from whichever call happens to be
    first.

    Args:
        state: Graph state, passed unchanged to ``entry_chain.ainvoke``.

    Returns:
        A state update. When a complete ``EntryExtractor`` call with a numeric
        ``class_number`` is found, it holds ``entry_response`` plus the three
        extracted values as strings. Otherwise the three values are ``None``
        and ``final_response`` is set to the response content.
    """
    logger.info(f"Entry {state['messages']}")
    entry_response: AnyMessage = await entry_chain.ainvoke(state)
    logger.info(f"Entry response: {entry_response}")
    # Not every message type carries ``tool_calls``, hence the getattr.
    tool_calls = getattr(entry_response, "tool_calls", None)
    logger.info(f"Entry response tool_calls: {tool_calls}")

    extractor_args = _entry_extractor_args(tool_calls)
    if extractor_args is not None:
        try:
            # The model may return the class as a float (e.g. 5.0); normalise
            # it to an integer string.
            class_number = str(int(extractor_args["class_number"]))
        except (TypeError, ValueError, OverflowError):
            # A malformed class number is treated like a missing extraction
            # instead of crashing the node.
            logger.warning(
                "EntryExtractor returned an invalid class_number: "
                f"{extractor_args['class_number']!r}"
            )
        else:
            logger.info("Vô đây")  # Vietnamese: "entered here"
            return {
                "entry_response": entry_response.content,
                "class_number": class_number,
                "subject_name": str(extractor_args["subject_name"]),
                "lesson_name": str(extractor_args["lesson_name"]),
            }

    logger.info("không vô")  # Vietnamese: "did not enter"
    return {
        "entry_response": entry_response.content,
        "class_number": None,
        "subject_name": None,
        "lesson_name": None,
        "final_response": entry_response.content,
    }
def _message_calls_tool(message, tool_name):
    """Return True when *message* has a tool call named *tool_name*."""
    if not hasattr(message, "tool_calls"):
        return False
    return any(call["name"] == tool_name for call in message.tool_calls)


async def build_lesson_plan(state: State):
    """Prepare the chain input and run the lesson-plan chain.

    Before the chain is invoked the incoming state is adjusted in place:

    * If some message calls ``ChangeLesson`` and none calls
      ``extract_lesson_content``, ``messages`` and ``messages_history`` are
      both emptied.
    * If both kinds of call are present, only the messages calling
      ``ChangeLesson`` are dropped from ``messages``.
    * ``lesson_plan_format`` is set to ``high_school_level_format`` when
      ``class_number`` is above 5, otherwise to ``primary_level_format``.

    Args:
        state: Graph state; mutated as described above and then passed to
            ``build_lesson_plan_chain.ainvoke``.

    Returns:
        A state update holding the chain response as a one-element list under
        ``build_lesson_plan_response``, the response itself under ``messages``
        and its content under ``final_response``.
    """
    logger.info(f"build_lesson_plan {state['messages']}")

    current_messages = state["messages"]
    change_requested = any(
        _message_calls_tool(item, "ChangeLesson") for item in current_messages
    )
    content_extracted = any(
        _message_calls_tool(item, "extract_lesson_content")
        for item in current_messages
    )

    if change_requested:
        if content_extracted:
            # Keep the conversation but drop the ChangeLesson requests.
            state["messages"] = [
                item
                for item in current_messages
                if not _message_calls_tool(item, "ChangeLesson")
            ]
        else:
            # Nothing extracted yet for the new lesson: start from scratch.
            state["messages"] = []
            state["messages_history"] = []

    above_grade_five = int(state["class_number"]) > 5
    state["lesson_plan_format"] = (
        high_school_level_format if above_grade_five else primary_level_format
    )

    chain_response = await build_lesson_plan_chain.ainvoke(state)
    return {
        "build_lesson_plan_response": [chain_response],
        "messages": chain_response,
        "final_response": chain_response.content,
    }
def change_lesson(state: State):
    """Read the new lesson selection from the latest lesson-plan response.

    The last entry of ``build_lesson_plan_response`` is inspected for tool
    calls. The call named ``ChangeLesson`` is preferred; if none has that
    name, the first tool call is used. Its ``class_number``, ``subject_name``
    and ``lesson_name`` args are returned as strings.

    Args:
        state: Graph state; only ``build_lesson_plan_response`` is read.

    Returns:
        A state update with ``class_number``, ``subject_name`` and
        ``lesson_name``. All three are ``None`` when there is no response, no
        tool call, or the args cannot be extracted.
    """
    no_selection = {
        "class_number": None,
        "subject_name": None,
        "lesson_name": None,
    }

    responses = state.get("build_lesson_plan_response") or []
    if not responses:
        # Nothing to inspect; avoid indexing into a missing or empty list.
        logger.warning("Could not extract lesson data from response")
        return no_selection

    build_lesson_plan_response = responses[-1]
    logger.info(f"Build lesson plan response: {build_lesson_plan_response}")

    tool_calls = getattr(build_lesson_plan_response, "tool_calls", None)
    if tool_calls:
        try:
            # Prefer the ChangeLesson call; fall back to the first call so the
            # behavior is unchanged when no call carries that name.
            tool_call = next(
                (call for call in tool_calls if call["name"] == "ChangeLesson"),
                tool_calls[0],
            )
            arguments = tool_call["args"]

            # The model may return the class as a float (e.g. 6.0).
            class_number_value = arguments["class_number"]
            if isinstance(class_number_value, float):
                class_number = str(int(class_number_value))
            else:
                class_number = str(class_number_value)

            subject_name = str(arguments["subject_name"])
            lesson_name = str(arguments["lesson_name"])
        except (
            KeyError,
            IndexError,
            TypeError,
            ValueError,
            OverflowError,
        ) as e:
            # ValueError/OverflowError cover NaN and infinite floats.
            logger.error(f"Error extracting lesson data: {e}")
        else:
            logger.info(
                f"Extracted lesson data: class={class_number}, subject={subject_name}, lesson={lesson_name}"
            )
            return {
                "class_number": class_number,
                "subject_name": subject_name,
                "lesson_name": lesson_name,
            }

    # Default values if extraction fails
    logger.warning("Could not extract lesson data from response")
    return no_selection