subashpoudel's picture
Code refactoring
38cf703
raw
history blame
1.61 kB
from .prompts import tool_return_prompt , extract_user_reference_prompt
from langchain_core.messages import SystemMessage, HumanMessage
from src.genai.utils.models_loader import llm_gpt
from .state import ToolResponseFormatter, UserReferenceResponseFormatter
class ToolReturnNode:
    """Graph node that asks the LLM which tool fits the conversation so far.

    The model is called with ``tool_return_prompt`` as the system message
    followed by the (possibly trimmed) message history, and its answer is
    parsed into a ``ToolResponseFormatter``.
    """

    def __init__(self, llm=llm_gpt):
        """Keep a reference to the chat model used by :meth:`run`."""
        self.llm = llm

    def run(self, state):
        """Query the model and wrap its answer as an assistant message.

        Args:
            state: Mapping with a ``"messages"`` entry holding the
                conversation history.

        Returns:
            A dict of the form ``{"messages": [assistant_message]}`` where
            ``assistant_message`` is a role/content dict.
        """
        history = state["messages"]
        # Once the history grows past 23 entries, only the 18 most recent are
        # kept; the trimmed list is also written back into ``state``.
        if len(history) > 23:
            history = history[-18:]
            state["messages"] = history

        system_message = SystemMessage(content=tool_return_prompt)
        prompt = [system_message] + history

        structured_llm = self.llm.with_structured_output(ToolResponseFormatter)
        tool_choice = structured_llm.invoke(prompt)

        # NOTE(review): the whole structured response object is interpolated
        # here, not a single field -- confirm its string form is what
        # downstream consumers expect.
        reply = {
            'role': 'assistant',
            'content': f"The exact name of the tool is: {tool_choice}",
        }
        return {"messages": [reply]}
class ExtractUserReferenceNode:
    """Node for extracting video idea and story from user's messages.

    Only the most recent ``HumanMessage`` in the history is sent to the
    model, together with ``extract_user_reference_prompt`` as the system
    message. The answer is parsed into a ``UserReferenceResponseFormatter``.
    """

    def __init__(self, llm=llm_gpt):
        """Keep a reference to the chat model used by :meth:`run`."""
        self.llm = llm

    def run(self, state):
        """Extract the video idea and story from the latest human message.

        Args:
            state: Mapping with a ``'messages'`` entry holding the
                conversation history.

        Returns:
            A dict of the form ``{'messages': [assistant_message]}`` whose
            single role/content dict reports the ``video_idea`` and
            ``video_story`` fields of the structured response.

        Raises:
            ValueError: If the history contains no ``HumanMessage``, so
                there is nothing to extract from.
        """
        # Walk the history backwards so the first match is the newest one.
        latest_human_message = next(
            (msg for msg in reversed(state['messages']) if isinstance(msg, HumanMessage)),
            None,
        )
        # Fail early with a clear message instead of hitting an opaque
        # AttributeError on ``None.content`` below.
        if latest_human_message is None:
            raise ValueError(
                "ExtractUserReferenceNode requires at least one HumanMessage "
                "in state['messages'], but none was found."
            )

        template = [
            SystemMessage(content=extract_user_reference_prompt),
            HumanMessage(content=latest_human_message.content),
        ]
        response = self.llm.with_structured_output(UserReferenceResponseFormatter).invoke(template)
        return {'messages': [{
            'role': 'assistant',
            'content': f"The video idea is: {response.video_idea} and the video story is: {response.video_story}"
        }]}