# Author: Julia Ostheimer
# Move trigger_ai_message_with_tool_call to conversation/generate.py
# Commit: bf0eea7
import structlog
from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
from langchain_core.messages.tool import ToolCall
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
ChatPromptTemplate,
MessagesPlaceholder,
)
from langchain_core.runnables import RunnableParallel
from langgraph.graph import MessagesState
from langfuse import Langfuse
from pydantic import BaseModel
from tools.langfuse_client import get_langfuse_handler, get_langfuse_client
from conversation.citation_utils import CitedAnswer, format_artifacts_to_string, embed_references
from config import app_settings
# Module-level logger (structlog; emits key-value event context).
logger = structlog.get_logger(__name__)

# Chat model initialized once at import time from application settings.
# NOTE(review): credentials/region suggest an AWS-hosted provider — confirm.
llm = init_chat_model(
    app_settings.llm_model,
    model_provider=app_settings.model_provider,
    region_name=app_settings.llm_region,
    aws_access_key_id=app_settings.aws_access_key_id,
    aws_secret_access_key=app_settings.aws_secret_access_key,
)

# Wrap the model so its responses are parsed into the CitedAnswer schema.
structured_llm = llm.with_structured_output(CitedAnswer)

# Get Langfuse Callback handler
langfuse_handler = get_langfuse_handler()
def get_rag_prompt_from_langfuse(
    prompt_name: str,
    prompt_label: str
) -> ChatPromptTemplate:
    """
    Get the prompt for the RAG-system via the Langfuse prompt management system
    and convert it to a Langchain prompt.

    Args:
        prompt_name (str): The name of the Langfuse prompt.
        prompt_label (str): The label of the Langfuse prompt (e.g. "production").

    Returns:
        ChatPromptTemplate: Prompt template for the chat model, with the
        originating Langfuse prompt attached in ``metadata`` so generations
        can be linked back to the prompt version.
    """
    # Get Langfuse client
    langfuse = get_langfuse_client()
    # Get the requested version of the prompt via Langfuse
    langfuse_prompt = langfuse.get_prompt(prompt_name, label=prompt_label)
    # FIX: use structlog's key-value style (consistent with the rest of this
    # module) instead of stdlib-style "%s" positional interpolation.
    logger.info("Loaded prompt from Langfuse", prompt=langfuse_prompt.prompt)
    # Convert Langfuse prompt to Langchain prompt
    langchain_prompt = ChatPromptTemplate.from_messages(
        langfuse_prompt.get_langchain_prompt(),
    )
    # Attach the Langfuse prompt so callback handlers can link traces
    # to this exact prompt version.
    langchain_prompt.metadata = {"langfuse_prompt": langfuse_prompt}
    return langchain_prompt
# User input
class ChatHistory(BaseModel):
    """Input schema for the RAG chain: prior turns, current question, retrieved context."""
    # Prior conversation turns (AI and human messages only).
    chat_history: list[AIMessage | HumanMessage]
    # The latest user question to be answered.
    question: str
    # Retrieved documents already formatted into a single string.
    context: str
# Pass-through mapping that lifts the three ChatHistory fields into the
# keys the prompt template expects.
_inputs = RunnableParallel(
    {
        "question": lambda x: x["question"],
        "chat_history": lambda x: x["chat_history"],
        "context": lambda x: x["context"]
    }
).with_types(input_type=ChatHistory)

# Get current production version of RAG prompt via Langfuse
langchain_prompt = get_rag_prompt_from_langfuse(
    prompt_name="answer-question-with-context-and-msg-history-copy",
    prompt_label="production"
)

# Full RAG pipeline: input mapping -> prompt -> structured (CitedAnswer) LLM.
chain = _inputs | langchain_prompt | structured_llm
def generate(state: MessagesState):
    """
    Generate the answer for the current conversation turn.

    Collects the artifacts from the most recent run of tool messages,
    formats them into a context string, and invokes the RAG chain with
    the latest user question and the cleaned conversation history.

    Args:
        state (MessagesState): Graph state containing the message history.

    Returns:
        dict: Keys "messages" (assistant reply dict), "llm-answer" (raw
        answer text), and "sources" (citation string; empty string when
        the response carries no sources).
    """
    # Collect the trailing run of generated ToolMessages (stops at the
    # first non-tool message when walking backwards).
    recent_tool_messages = []
    for message in reversed(state["messages"]):
        if message.type == "tool":
            recent_tool_messages.append(message)
        else:
            break
    tool_messages = recent_tool_messages[::-1]  # restore chronological order

    # Gather retrieved artifacts and format them into the prompt context.
    all_artifacts = []
    for message in tool_messages:
        if message.artifact:
            all_artifacts.extend(message.artifact)
    docs_content = format_artifacts_to_string(all_artifacts)
    logger.info("Tool messages", context=docs_content)

    # Keep only human/system messages and AI messages without tool calls
    # so the model sees a clean conversational history.
    conversation_messages = [
        message
        for message in state["messages"]
        if message.type in ("human", "system")
        or (message.type == "ai" and not message.tool_calls)
    ]

    # Last conversational message is treated as the current question.
    structured_response = chain.invoke({
        "question": conversation_messages[-1].content,
        "chat_history": conversation_messages,
        "context": docs_content,
    }, config={"callbacks": [langfuse_handler]})

    if structured_response.sources:
        # Embed inline references into the answer text.
        formatted_answer = embed_references(structured_response)
        main_answer = {"role": "assistant", "content": formatted_answer}
        citations = f"{structured_response.sources}"
    else:
        main_answer = {"role": "assistant", "content": structured_response.answer}
        # FIX: was `[]` (a list) while the branch above produces a str —
        # keep the "sources" value consistently a string.
        citations = ""
    return {
        "messages": main_answer,
        "llm-answer": structured_response.answer,
        "sources": citations
    }
def trigger_ai_message_with_tool_call(state: MessagesState) -> dict:
    """
    Build an AIMessage that calls the "retrieve" tool with the last user message.

    Args:
        state (MessagesState): Graph state with a 'messages' list of LangChain messages.

    Returns:
        dict: ``{"messages": [AIMessage]}`` where the AIMessage carries a single
        tool call to "retrieve" using the user's question as the query.
        (FIX: the annotation/docstring previously claimed a bare AIMessage was
        returned, but the function has always returned this state-update dict.)

    Raises:
        ValueError: If the state contains no HumanMessage.
    """
    # Filter for user messages
    user_messages = [msg for msg in state["messages"] if isinstance(msg, HumanMessage)]
    if not user_messages:
        raise ValueError("No user messages found in the previous messages.")
    last_user_msg = user_messages[-1]

    # Fixed id is acceptable: exactly one tool call is emitted per message.
    tool_call = ToolCall(
        name="retrieve",
        args={"query": last_user_msg.content},
        id="tool_call_1"
    )
    # Construct the AIMessage with tool_calls
    ai_message = AIMessage(
        content="Calling the retrieve function...",
        tool_calls=[tool_call]
    )
    return {"messages": [ai_message]}