| | import structlog |
| | from langchain.chat_models import init_chat_model |
| | from langchain_core.messages import SystemMessage, AIMessage, HumanMessage |
| | from langchain_core.messages.tool import ToolCall |
| | from langchain_core.output_parsers import StrOutputParser |
| | from langchain_core.prompts import ( |
| | ChatPromptTemplate, |
| | MessagesPlaceholder, |
| | ) |
| | from langchain_core.runnables import RunnableParallel |
| | from langgraph.graph import MessagesState |
| | from langfuse import Langfuse |
| |
|
| | from pydantic import BaseModel |
| |
|
| | from tools.langfuse_client import get_langfuse_handler, get_langfuse_client |
| | from conversation.citation_utils import CitedAnswer, format_artifacts_to_string, embed_references |
| |
|
| | from config import app_settings |
| |
|
# Module-level structlog logger bound to this module's name.
logger = structlog.get_logger(__name__)

# Chat model configured entirely from application settings (model name,
# provider, AWS region and credentials all come from config.app_settings).
llm = init_chat_model(
    app_settings.llm_model,
    model_provider=app_settings.model_provider,
    region_name=app_settings.llm_region,
    aws_access_key_id=app_settings.aws_access_key_id,
    aws_secret_access_key=app_settings.aws_secret_access_key,
)

# Same model, constrained to emit a CitedAnswer structure
# (answer text plus the sources it cites).
structured_llm = llm.with_structured_output(CitedAnswer)

# Shared Langfuse callback handler used to trace chain invocations.
langfuse_handler = get_langfuse_handler()
| |
|
def get_rag_prompt_from_langfuse(
    prompt_name: str,
    prompt_label: str,
) -> ChatPromptTemplate:
    """
    Get the prompt for the RAG-system via the Langfuse prompt management
    system and convert it to a Langchain prompt.

    Args:
        prompt_name (str): The name of the Langfuse prompt.
        prompt_label (str): The label of the Langfuse prompt.

    Returns:
        ChatPromptTemplate: Prompt template for the chat model to use, with
        the originating Langfuse prompt attached as metadata so traces can
        be linked back to the exact prompt version.
    """
    langfuse = get_langfuse_client()

    langfuse_prompt = langfuse.get_prompt(prompt_name, label=prompt_label)

    # structlog loggers take key-value pairs rather than %-style positional
    # arguments (consistent with the other log calls in this module).
    logger.info("Loaded prompt from Langfuse", prompt=langfuse_prompt.prompt)

    langchain_prompt = ChatPromptTemplate.from_messages(
        langfuse_prompt.get_langchain_prompt(),
    )
    # Attaching the Langfuse prompt lets the callback handler associate
    # generations with this prompt version in Langfuse.
    langchain_prompt.metadata = {"langfuse_prompt": langfuse_prompt}

    return langchain_prompt
| |
|
| | |
class ChatHistory(BaseModel):
    """Input schema for the RAG chain: prior turns, question, and context."""

    # Prior conversation turns (assistant and user messages only).
    chat_history: list[AIMessage | HumanMessage]
    # The user's current question.
    question: str
    # Retrieved document content the answer should be grounded in.
    context: str
|
| |
|
# Pass-through mapping: surface each expected input key unchanged so the
# prompt template receives question, chat_history, and context.
_inputs = RunnableParallel(
    {
        # Bind the key as a default argument to avoid late-binding closures.
        field: (lambda x, _key=field: x[_key])
        for field in ("question", "chat_history", "context")
    }
).with_types(input_type=ChatHistory)
| |
|
| | |
# Resolve the production prompt once at import time; a failure here
# surfaces at startup rather than on the first request.
langchain_prompt = get_rag_prompt_from_langfuse(
    prompt_name="answer-question-with-context-and-msg-history-copy",
    prompt_label="production"
)

# Full RAG chain: map inputs -> fill prompt -> structured (cited) answer.
chain = _inputs | langchain_prompt | structured_llm
| |
|
| |
|
def generate(state: MessagesState):
    """
    Generate a cited answer for the latest user question.

    Gathers the artifacts from the most recent uninterrupted run of tool
    messages, formats them into a context string, and invokes the RAG
    chain with the cleaned conversation history.

    Args:
        state (MessagesState): Graph state containing the message history.

    Returns:
        dict: State update with the assistant message ("messages"), the raw
        LLM answer ("llm-answer"), and the cited sources ("sources" — a
        list, empty when nothing was cited).
    """
    # Walk backwards to collect the trailing block of tool messages, i.e.
    # only the results of the most recent retrieval step.
    recent_tool_messages = []
    for message in reversed(state["messages"]):
        if message.type == "tool":
            recent_tool_messages.append(message)
        else:
            break
    tool_messages = recent_tool_messages[::-1]

    # Flatten every artifact attached to those tool messages.
    all_artifacts = []
    for message in tool_messages:
        if message.artifact:
            all_artifacts.extend(message.artifact)

    docs_content = format_artifacts_to_string(all_artifacts)

    logger.info("Tool messages", context=docs_content)

    # Keep human/system turns plus AI turns without tool calls, so the
    # prompt sees a clean conversational transcript.
    conversation_messages = [
        message
        for message in state["messages"]
        if message.type in ("human", "system")
        or (message.type == "ai" and not message.tool_calls)
    ]
    structured_response = chain.invoke({
        "question": conversation_messages[-1].content,
        "chat_history": conversation_messages,
        "context": docs_content,
    }, config={"callbacks": [langfuse_handler]})

    if structured_response.sources:
        formatted_answer = embed_references(structured_response)
        main_answer = {"role": "assistant", "content": formatted_answer}
        # Bug fix: this branch previously stringified the sources via an
        # f-string, so "sources" was a str here but a list in the branch
        # below. Return the list in both cases for a consistent type.
        citations = structured_response.sources
    else:
        main_answer = {"role": "assistant", "content": structured_response.answer}
        citations = []

    return {
        "messages": main_answer,
        "llm-answer": structured_response.answer,
        "sources": citations
    }
| |
|
def trigger_ai_message_with_tool_call(state: MessagesState) -> dict:
    """
    Build an AIMessage that triggers the 'retrieve' tool with the last
    user question as the query.

    Args:
        state (MessagesState): Graph state with a 'messages' key containing
            a list of LangChain messages.

    Returns:
        dict: State update of the form {"messages": [AIMessage]}, where the
        AIMessage carries a single 'retrieve' tool call. (The previous
        annotation claimed an AIMessage was returned directly, but the
        function has always returned a state-update dict.)

    Raises:
        ValueError: If the state contains no HumanMessage.
    """
    user_messages = [msg for msg in state["messages"] if isinstance(msg, HumanMessage)]

    if not user_messages:
        raise ValueError("No user messages found in the previous messages.")

    last_user_msg = user_messages[-1]

    # Hand the user's question to the retriever verbatim.
    tool_call = ToolCall(
        name="retrieve",
        args={"query": last_user_msg.content},
        id="tool_call_1"
    )

    ai_message = AIMessage(
        content="Calling the retrieve function...",
        tool_calls=[tool_call]
    )

    return {"messages": [ai_message]}