"""LangGraph RAG answer-generation nodes backed by a Langfuse-managed prompt.

Wires a structured-output chain (question + chat history + retrieved
context -> CitedAnswer) and exposes two graph nodes:

* ``generate`` — builds a cited answer from retrieved tool artifacts.
* ``trigger_ai_message_with_tool_call`` — emits an AIMessage carrying a
  ``retrieve`` tool call for the latest user question.
"""

import structlog
from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
from langchain_core.messages.tool import ToolCall
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
    ChatPromptTemplate,
    MessagesPlaceholder,
)
from langchain_core.runnables import RunnableParallel
from langgraph.graph import MessagesState
from langfuse import Langfuse
from pydantic import BaseModel

from tools.langfuse_client import get_langfuse_handler, get_langfuse_client
from conversation.citation_utils import (
    CitedAnswer,
    format_artifacts_to_string,
    embed_references,
)
from config import app_settings

logger = structlog.get_logger(__name__)

# Chat model configured from application settings, wrapped with structured
# output so every generation is parsed into a CitedAnswer (answer + sources).
llm = init_chat_model(
    app_settings.llm_model,
    model_provider=app_settings.model_provider,
    region_name=app_settings.llm_region,
    aws_access_key_id=app_settings.aws_access_key_id,
    aws_secret_access_key=app_settings.aws_secret_access_key,
)
structured_llm = llm.with_structured_output(CitedAnswer)

# Langfuse callback handler used to trace chain invocations.
langfuse_handler = get_langfuse_handler()


def get_rag_prompt_from_langfuse(
    prompt_name: str, prompt_label: str
) -> ChatPromptTemplate:
    """
    Get the prompt for the RAG-system via the Langfuse prompt management
    system and convert it to a Langchain prompt.

    Args:
        prompt_name (str): The name of the Langfuse prompt.
        prompt_label (str): The label of the Langfuse prompt.

    Returns:
        ChatPromptTemplate: Prompt template for chat model to use.
    """
    # Get Langfuse client
    langfuse = get_langfuse_client()

    # Get the requested labelled version of the prompt via Langfuse
    langfuse_prompt = langfuse.get_prompt(prompt_name, label=prompt_label)

    # structlog idiom: pass structured data as keyword args (positional
    # %-interpolation is not applied unless PositionalArgumentsFormatter
    # is configured), matching the logging style used elsewhere here.
    logger.info(
        "This is the loaded prompt from Langfuse",
        prompt=langfuse_prompt.prompt,
    )

    # Convert the Langfuse prompt to a Langchain prompt; attaching the
    # Langfuse prompt object in metadata links generations to the prompt
    # version in Langfuse tracing.
    langchain_prompt = ChatPromptTemplate.from_messages(
        langfuse_prompt.get_langchain_prompt(),
    )
    langchain_prompt.metadata = {"langfuse_prompt": langfuse_prompt}
    return langchain_prompt


# Input schema for the RAG chain.
class ChatHistory(BaseModel):
    # chat_history: prior conversation turns (AI and human messages only)
    chat_history: list[AIMessage | HumanMessage]
    # question: the user's current question
    question: str
    # context: retrieved documents, already formatted to a single string
    context: str


# Fan the single input dict out into the keys the prompt template expects.
_inputs = RunnableParallel(
    {
        "question": lambda x: x["question"],
        "chat_history": lambda x: x["chat_history"],
        "context": lambda x: x["context"],
    }
).with_types(input_type=ChatHistory)

# Current production version of the RAG prompt, fetched once at import time.
langchain_prompt = get_rag_prompt_from_langfuse(
    prompt_name="answer-question-with-context-and-msg-history-copy",
    prompt_label="production",
)

chain = _inputs | langchain_prompt | structured_llm


def generate(state: MessagesState) -> dict:
    """Generate a cited answer from the retrieved context in *state*.

    Collects the most recent run of ToolMessages (the retrieval results),
    formats their artifacts into a context string, and invokes the RAG
    chain with the question, chat history and context.

    Args:
        state (MessagesState): Graph state whose ``messages`` end with the
            tool messages produced by the retrieve step.

    Returns:
        dict: State update with ``messages`` (the assistant reply),
        ``llm-answer`` (raw answer text) and ``sources`` (citation list;
        empty when the model returned no sources).
    """
    # Collect the trailing run of ToolMessages (most recent retrieval);
    # stop at the first non-tool message and restore chronological order.
    recent_tool_messages = []
    for message in reversed(state["messages"]):
        if message.type == "tool":
            recent_tool_messages.append(message)
        else:
            break
    tool_messages = recent_tool_messages[::-1]

    # Flatten the retrieved artifacts into a single context string.
    all_artifacts = []
    for message in tool_messages:
        if message.artifact:
            all_artifacts.extend(message.artifact)
    docs_content = format_artifacts_to_string(all_artifacts)
    logger.info("Tool messages", context=docs_content)

    # Keep human/system turns and AI turns that are real replies
    # (drop AI messages that only carried tool calls).
    conversation_messages = [
        message
        for message in state["messages"]
        if message.type in ("human", "system")
        or (message.type == "ai" and not message.tool_calls)
    ]

    # NOTE(review): assumes the last conversation message is the user's
    # question — confirm a system message can never be last here.
    structured_response = chain.invoke(
        {
            "question": conversation_messages[-1].content,
            "chat_history": conversation_messages,
            "context": docs_content,
        },
        config={"callbacks": [langfuse_handler]},
    )

    if structured_response.sources:
        # Embed inline citation markers into the answer text.
        formatted_answer = embed_references(structured_response)
        main_answer = {"role": "assistant", "content": formatted_answer}
        # Fix: return the sources list itself rather than its string repr
        # (f"{...}"), so "sources" has a consistent list type in both
        # branches and downstream consumers need not parse a repr string.
        citations = structured_response.sources
    else:
        main_answer = {"role": "assistant", "content": structured_response.answer}
        citations = []

    return {
        "messages": main_answer,
        "llm-answer": structured_response.answer,
        "sources": citations,
    }


def trigger_ai_message_with_tool_call(state: MessagesState) -> dict:
    """
    Takes the last user message from the state and returns a state update
    containing an AIMessage with an example ``retrieve`` tool call.

    Args:
        state (dict): A dictionary with a 'messages' key containing a list
            of LangChain messages.

    Returns:
        dict: ``{"messages": [AIMessage]}`` where the AIMessage carries a
        ``retrieve`` tool call built from the last user message.
        (Annotation fixed: the function returns a state-update dict, not
        a bare AIMessage.)

    Raises:
        ValueError: If the state contains no HumanMessage.
    """
    # Filter for user messages
    user_messages = [
        msg for msg in state["messages"] if isinstance(msg, HumanMessage)
    ]
    if not user_messages:
        raise ValueError("No user messages found in the previous messages.")

    last_user_msg = user_messages[-1]

    # Build the retrieval tool call from the latest user question.
    tool_call = ToolCall(
        name="retrieve",
        args={"query": last_user_msg.content},
        id="tool_call_1",
    )

    # Construct the AIMessage with tool_calls
    ai_message = AIMessage(
        content="Calling the retrieve function...",
        tool_calls=[tool_call],
    )

    return {"messages": [ai_message]}