File size: 5,373 Bytes
627ec3c
 
 
bf0eea7
627ec3c
 
 
 
 
 
 
920b409
 
627ec3c
 
920b409
aedb86c
599b085
627ec3c
 
 
 
 
 
7e508e0
77a5434
 
 
627ec3c
 
aa85c3a
 
5fea2a2
 
627ec3c
0db0599
 
 
 
 
 
920b409
0db0599
 
 
920b409
0db0599
 
 
627ec3c
0db0599
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
627ec3c
 
 
 
 
72c6414
627ec3c
 
 
 
 
 
 
 
 
 
0db0599
 
 
 
 
 
aa85c3a
627ec3c
 
 
 
 
 
 
 
 
 
 
 
 
aa85c3a
 
 
 
aedb86c
 
aa85c3a
627ec3c
 
 
 
 
 
 
 
aa85c3a
627ec3c
 
 
fbe3391
aa85c3a
 
fbe3391
 
a664189
aa85c3a
fbe3391
aa85c3a
8a8026b
aa85c3a
 
a664189
aa85c3a
bf0eea7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import structlog
from langchain.chat_models import init_chat_model
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
from langchain_core.messages.tool import ToolCall
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import (
    ChatPromptTemplate,
    MessagesPlaceholder,
)
from langchain_core.runnables import RunnableParallel
from langgraph.graph import MessagesState
from langfuse import Langfuse

from pydantic import BaseModel

from tools.langfuse_client import get_langfuse_handler, get_langfuse_client
from conversation.citation_utils import CitedAnswer, format_artifacts_to_string, embed_references

from config import app_settings

# Module-level logger; structlog binds key/value context per log call.
logger = structlog.get_logger(__name__)

# Initialise the chat model once at import time from application settings
# (model name, provider, AWS region and credentials — presumably Bedrock,
# given the AWS parameters; confirm against app_settings).
llm = init_chat_model(
    app_settings.llm_model,
    model_provider=app_settings.model_provider,
    region_name=app_settings.llm_region,
    aws_access_key_id=app_settings.aws_access_key_id,
    aws_secret_access_key=app_settings.aws_secret_access_key,
)

# Wrap the model so responses are parsed into the CitedAnswer schema.
structured_llm = llm.with_structured_output(CitedAnswer)

# Get Langfuse Callback handler
langfuse_handler = get_langfuse_handler()

def get_rag_prompt_from_langfuse(
        prompt_name: str,
        prompt_label: str
) -> ChatPromptTemplate:
    """
    Get the prompt for the RAG-system via the Langfuse prompt management system and convert it to a Langchain prompt.

    Args:
        prompt_name (str): The name of the Langfuse prompt.
        prompt_label (str): The label of the Langfuse prompt (e.g. "production").

    Returns:
        ChatPromptTemplate: Prompt template for chat model to use.
    """

    # Get Langfuse client
    langfuse = get_langfuse_client()

    # Get the prompt version currently carrying the requested label
    langfuse_prompt = langfuse.get_prompt(prompt_name, label=prompt_label)

    # Log via structlog keyword binding (consistent with the rest of this
    # module); the previous %-style positional argument is only rendered
    # when a PositionalArgumentsFormatter processor is configured.
    logger.info("Loaded prompt from Langfuse", prompt=langfuse_prompt.prompt)

    # Convert Langfuse prompt to Langchain prompt; attaching the Langfuse
    # prompt object in metadata links generations back to the prompt version.
    langchain_prompt = ChatPromptTemplate.from_messages(
        langfuse_prompt.get_langchain_prompt(),
    )
    langchain_prompt.metadata = {"langfuse_prompt": langfuse_prompt}

    return langchain_prompt

# User input
class ChatHistory(BaseModel):
    """Input schema for the RAG chain: prior turns, current question, context."""
    # Prior conversation turns (assistant and user messages only).
    chat_history: list[AIMessage | HumanMessage]
    # The user's current question.
    question: str
    # Retrieved document context, already rendered to a single string.
    context: str


# Pass the three prompt inputs through unchanged; with_types() only
# declares the expected input schema (ChatHistory) for validation/tooling.
_inputs = RunnableParallel(
    {
        "question": lambda x: x["question"],
        "chat_history": lambda x: x["chat_history"],
        "context": lambda x: x["context"]
    }
).with_types(input_type=ChatHistory)

# Get current production version of RAG prompt via Langfuse
# NOTE: this performs a network call to Langfuse at import time.
langchain_prompt = get_rag_prompt_from_langfuse(
        prompt_name="answer-question-with-context-and-msg-history-copy",
        prompt_label="production" 
)

# Full RAG chain: map inputs -> prompt -> structured (CitedAnswer) LLM.
chain = _inputs | langchain_prompt | structured_llm


def generate(state: MessagesState):
    """Generate an answer from the most recent retrieval results.

    Collects the trailing run of tool messages from the conversation state,
    renders their artifacts into a context string, and invokes the RAG chain
    to produce a structured, citation-aware answer.

    Args:
        state (MessagesState): Graph state containing the message history.

    Returns:
        dict: State update with the assistant message ("messages"), the raw
            LLM answer ("llm-answer"), and the formatted sources ("sources",
            an empty string when the response cites no sources).
    """
    # Collect the most recent uninterrupted run of tool messages: they are
    # appended last, so walk backwards until a non-tool message appears.
    recent_tool_messages = []
    for message in reversed(state["messages"]):
        if message.type != "tool":
            break
        recent_tool_messages.append(message)
    tool_messages = recent_tool_messages[::-1]

    # Gather retrieval artifacts and render them into one context string.
    all_artifacts = []
    for message in tool_messages:
        if message.artifact:
            all_artifacts.extend(message.artifact)

    docs_content = format_artifacts_to_string(all_artifacts)

    logger.info("Tool messages", context=docs_content)

    # Keep only human/system messages and AI messages without tool calls,
    # so the chat history passed to the prompt excludes tool plumbing.
    conversation_messages = [
        message
        for message in state["messages"]
        if message.type in ("human", "system")
        or (message.type == "ai" and not message.tool_calls)
    ]
    structured_response = chain.invoke({
        "question": conversation_messages[-1].content,
        "chat_history": conversation_messages,
        "context": docs_content,
    }, config={"callbacks": [langfuse_handler]})

    if structured_response.sources:
        # Embed inline reference markers into the answer text.
        formatted_answer = embed_references(structured_response)
        main_answer = {"role": "assistant", "content": formatted_answer}
        citations = f"{structured_response.sources}"
    else:
        main_answer = {"role": "assistant", "content": structured_response.answer}
        # Bug fix: this branch previously set `citations = []`, giving
        # "sources" an inconsistent type (str when sources exist, list when
        # not). Use an empty string so consumers always receive a str; both
        # values are falsy, so truthiness checks behave the same.
        citations = ""

    return {
        "messages": main_answer,
        "llm-answer": structured_response.answer,
        "sources": citations,
    }

def trigger_ai_message_with_tool_call(state: MessagesState) -> dict:
    """
    Take the last user message from the state and return a state update
    containing an AIMessage that calls the `retrieve` tool with that
    message's content as the query.

    Args:
        state (MessagesState): Graph state with a 'messages' key containing
            a list of LangChain messages.

    Returns:
        dict: State update of the form {"messages": [AIMessage]}, where the
            AIMessage carries a single `retrieve` tool call. (The previous
            annotation claimed `AIMessage`, but the function has always
            returned this dict.)

    Raises:
        ValueError: If the state contains no HumanMessage.
    """

    # Filter for user messages
    user_messages = [msg for msg in state["messages"] if isinstance(msg, HumanMessage)]

    if not user_messages:
        raise ValueError("No user messages found in the previous messages.")

    last_user_msg = user_messages[-1]

    # A fixed id is sufficient here: exactly one tool call is issued per turn.
    tool_call = ToolCall(
        name="retrieve",
        args={"query": last_user_msg.content},
        id="tool_call_1"
    )

    # Construct the AIMessage with tool_calls
    ai_message = AIMessage(
        content="Calling the retrieve function...",
        tool_calls=[tool_call]
    )

    return {"messages": [ai_message]}