Julia Ostheimer committed on
Commit
bf0eea7
·
1 Parent(s): a959e84

Move trigger_ai_message_with_tool_call to conversation/generate.py

Browse files
Files changed (2) hide show
  1. app.py +1 -37
  2. conversation/generate.py +36 -1
app.py CHANGED
@@ -5,8 +5,6 @@ from langchain.chat_models import init_chat_model
5
  from langchain_core.tools import tool
6
  from langchain_aws import BedrockEmbeddings
7
  from langchain_qdrant import QdrantVectorStore
8
- from langchain.schema import AIMessage, HumanMessage
9
- from langchain_core.messages.tool import ToolCall
10
 
11
  from langgraph.checkpoint.memory import MemorySaver
12
  from langgraph.graph import MessagesState, StateGraph, END
@@ -19,7 +17,7 @@ from qdrant_client.http.models import Distance, VectorParams, SparseVectorParams
19
 
20
  import logging_config as _
21
  # from conversation.main import graph
22
- from conversation.generate import generate
23
  from conversation.source_history import prettify_source_history, build_source_history_object
24
  from ingestion.main import ingest_document
25
  from tools.langfuse_client import get_langfuse_handler
@@ -73,40 +71,6 @@ def retrieve(query: str):
73
  )
74
  return serialized, retrieved_docs
75
 
76
- def trigger_ai_message_with_tool_call(state: MessagesState) -> AIMessage:
77
- """
78
- Takes the last user message from the state and returns an AIMessage
79
- with example tool_calls populated.
80
-
81
- Args:
82
- state (dict): A dictionary with a 'messages' key containing a list of LangChain messages.
83
-
84
- Returns:
85
- AIMessage: An AIMessage with tool_calls based on the last user message.
86
- """
87
-
88
- # Filter for user messages
89
- user_messages = [msg for msg in state["messages"] if isinstance(msg, HumanMessage)]
90
-
91
- if not user_messages:
92
- raise ValueError("No user messages found in the previous messages.")
93
-
94
- last_user_msg = user_messages[-1]
95
-
96
- tool_call = ToolCall(
97
- name="retrieve",
98
- args={"query": last_user_msg.content},
99
- id="tool_call_1"
100
- )
101
-
102
- # Construct the AIMessage with tool_calls
103
- ai_message = AIMessage(
104
- content="Calling the retrieve function...",
105
- tool_calls=[tool_call]
106
- )
107
-
108
- return {"messages": [ai_message]}
109
-
110
 
111
  graph_builder = StateGraph(MessagesState)
112
  memory = MemorySaver()
 
5
  from langchain_core.tools import tool
6
  from langchain_aws import BedrockEmbeddings
7
  from langchain_qdrant import QdrantVectorStore
 
 
8
 
9
  from langgraph.checkpoint.memory import MemorySaver
10
  from langgraph.graph import MessagesState, StateGraph, END
 
17
 
18
  import logging_config as _
19
  # from conversation.main import graph
20
+ from conversation.generate import generate, trigger_ai_message_with_tool_call
21
  from conversation.source_history import prettify_source_history, build_source_history_object
22
  from ingestion.main import ingest_document
23
  from tools.langfuse_client import get_langfuse_handler
 
71
  )
72
  return serialized, retrieved_docs
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  graph_builder = StateGraph(MessagesState)
76
  memory = MemorySaver()
conversation/generate.py CHANGED
@@ -1,6 +1,7 @@
1
  import structlog
2
  from langchain.chat_models import init_chat_model
3
  from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
 
4
  from langchain_core.output_parsers import StrOutputParser
5
  from langchain_core.prompts import (
6
  ChatPromptTemplate,
@@ -132,4 +133,38 @@ def generate(state: MessagesState):
132
  "messages": main_answer,
133
  "llm-answer": structured_response.answer,
134
  "sources": citations
135
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import structlog
2
  from langchain.chat_models import init_chat_model
3
  from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
4
+ from langchain_core.messages.tool import ToolCall
5
  from langchain_core.output_parsers import StrOutputParser
6
  from langchain_core.prompts import (
7
  ChatPromptTemplate,
 
133
  "messages": main_answer,
134
  "llm-answer": structured_response.answer,
135
  "sources": citations
136
+ }
137
+
138
+ def trigger_ai_message_with_tool_call(state: MessagesState) -> AIMessage:
139
+ """
140
+ Takes the last user message from the state and returns an AIMessage
141
+ with example tool_calls populated.
142
+
143
+ Args:
144
+ state (dict): A dictionary with a 'messages' key containing a list of LangChain messages.
145
+
146
+ Returns:
147
+ AIMessage: An AIMessage with tool_calls based on the last user message.
148
+ """
149
+
150
+ # Filter for user messages
151
+ user_messages = [msg for msg in state["messages"] if isinstance(msg, HumanMessage)]
152
+
153
+ if not user_messages:
154
+ raise ValueError("No user messages found in the previous messages.")
155
+
156
+ last_user_msg = user_messages[-1]
157
+
158
+ tool_call = ToolCall(
159
+ name="retrieve",
160
+ args={"query": last_user_msg.content},
161
+ id="tool_call_1"
162
+ )
163
+
164
+ # Construct the AIMessage with tool_calls
165
+ ai_message = AIMessage(
166
+ content="Calling the retrieve function...",
167
+ tool_calls=[tool_call]
168
+ )
169
+
170
+ return {"messages": [ai_message]}