MiguelCosta committed on
Commit
0032a40
·
1 Parent(s): 5ec5d7b

add gpt4, agent

Browse files
Files changed (2) hide show
  1. Dockerfile +3 -1
  2. app_one.py +143 -37
Dockerfile CHANGED
@@ -1,4 +1,4 @@
1
- FROM python:3.9
2
  RUN useradd -m -u 1000 user
3
  #USER user
4
  ENV HOME=/home/user \
@@ -13,6 +13,8 @@ RUN pip install qdrant-client
13
  RUN pip install langchain
14
  RUN pip install langchain-community
15
  RUN pip install langchain-openai
 
 
16
  COPY . .
17
 
18
  RUN chown -R user:user $HOME/app/Qdrant_db
 
1
+ FROM python:3.11
2
  RUN useradd -m -u 1000 user
3
  #USER user
4
  ENV HOME=/home/user \
 
13
  RUN pip install langchain
14
  RUN pip install langchain-community
15
  RUN pip install langchain-openai
16
+ RUN pip install duckduckgo-search==5.3.0b4
17
+ RUN pip install langgraph
18
  COPY . .
19
 
20
  RUN chown -R user:user $HOME/app/Qdrant_db
app_one.py CHANGED
@@ -1,14 +1,10 @@
1
-
2
-
3
-
4
-
5
- #from langchain.chat_models import ChatOpenAI
6
- #from langchain_community.chat_models import ChatOpenAI
7
  from langchain_openai import ChatOpenAI
8
  from langchain.prompts import ChatPromptTemplate
9
  from langchain.schema import StrOutputParser
10
  from langchain.schema.runnable import Runnable
 
11
  from langchain.schema.runnable.config import RunnableConfig
 
12
 
13
  from langchain_community.vectorstores import Qdrant
14
  from qdrant_client import QdrantClient, models
@@ -16,66 +12,176 @@ from langchain_openai.embeddings import OpenAIEmbeddings
16
 
17
  from langchain.retrievers import MultiQueryRetriever
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  from operator import itemgetter
20
 
21
  import chainlit as cl
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- model = ChatOpenAI(model="gpt-3.5-turbo", streaming=True)
25
 
26
 
 
 
 
 
27
  client = QdrantClient(path="Qdrant_db")
28
  embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
29
 
 
30
  collection_name = "AML_act"
31
  qdrant = Qdrant(client, collection_name, embedding_model)
32
 
33
  qdrant_retriever = qdrant.as_retriever()
34
  advanced_retriever = MultiQueryRetriever.from_llm(retriever=qdrant_retriever, llm=model)
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
 
38
  @cl.on_chat_start
39
  async def on_chat_start():
40
 
41
- RAG_PROMPT = """
42
-
43
- CONTEXT:
44
- {context}
45
-
46
- QUERY:
47
- {question}
48
-
49
- Answer the query above using the context provided. If you don't know the answer responde with: I don't know
50
- """
51
-
52
- rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
53
 
54
- runnable = (
55
- {"context": itemgetter("question") | advanced_retriever, "question": itemgetter("question")} | rag_prompt | model | StrOutputParser()
56
- )
57
-
58
-
59
  cl.user_session.set("runnable", runnable)
60
 
61
 
62
  @cl.on_message
63
  async def on_message(message: cl.Message):
64
 
65
- runnable = cl.user_session.get("runnable") # type: Runnable
66
- msg = cl.Message(content="")
67
-
68
  print("Query content----------", message.content)
69
 
70
- for chunk in await cl.make_async(runnable.stream)(
71
- {"question": message.content},
72
- config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
73
- ):
74
- await msg.stream_token(chunk)
75
 
76
- print("Answer content----------", msg.content)
77
-
78
- await msg.send()
79
-
80
- print("Answer content----------", msg.content)
81
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from langchain_openai import ChatOpenAI
2
  from langchain.prompts import ChatPromptTemplate
3
  from langchain.schema import StrOutputParser
4
  from langchain.schema.runnable import Runnable
5
+ from langchain.schema.runnable import RunnablePassthrough
6
  from langchain.schema.runnable.config import RunnableConfig
7
+ from langchain_core.messages import HumanMessage
8
 
9
  from langchain_community.vectorstores import Qdrant
10
  from qdrant_client import QdrantClient, models
 
12
 
13
  from langchain.retrievers import MultiQueryRetriever
14
 
15
+ # FROM THE LOADER
16
+ from langchain_community.tools.ddg_search import DuckDuckGoSearchRun
17
+ from langchain.tools.retriever import create_retriever_tool
18
+
19
+ from langchain_core.utils.function_calling import convert_to_openai_function
20
+ from langgraph.prebuilt import ToolExecutor
21
+
22
+ from typing import TypedDict, Annotated
23
+ from langgraph.graph.message import add_messages
24
+
25
+ from langgraph.prebuilt import ToolInvocation
26
+ import json
27
+ from langchain_core.messages import FunctionMessage
28
+
29
+ from langchain_core.messages import BaseMessage
30
+
31
+ from langgraph.graph import StateGraph, END
32
+
33
  from operator import itemgetter
34
 
35
  import chainlit as cl
36
 
37
+ import os
38
+ import getpass
39
+ from uuid import uuid4
40
+
41
+
42
+ #os.environ["LANGCHAIN_TRACING_V2"] = "true"
43
+ #os.environ["LANGCHAIN_PROJECT"] = f"AML-au - {uuid4().hex[0:8]}"
44
+ #os.environ["LANGCHAIN_API_KEY"] = getpass.getpass("LangSmith_API_Key: ")
45
+
46
+ # PROMPTS IN USE
47
+ duckduckgo_description = "Useful for when you need to answer questions about aml."
48
+
49
+ aml_act_retriever_description = "Searches and returns excerpts from the aml act."
50
+
51
+ agent_prompt = " Only conduct DuckDuckGo searches when asked about Anti Money Laundering (aml). "
52
+ # If the question is not about aml answer with: I don't know.
53
+
54
 
 
55
 
56
 
57
+ # Model used for the MultiQueryRetriever set with "some" temperature
58
+ model = ChatOpenAI(model="gpt-3.5-turbo", streaming=True) #temperature=0.7
59
+
60
+ # Create Qdrant vectorstore as a retreiver
61
  client = QdrantClient(path="Qdrant_db")
62
  embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
63
 
64
+ # Load collection from disk
65
  collection_name = "AML_act"
66
  qdrant = Qdrant(client, collection_name, embedding_model)
67
 
68
  qdrant_retriever = qdrant.as_retriever()
69
  advanced_retriever = MultiQueryRetriever.from_llm(retriever=qdrant_retriever, llm=model)
70
 
71
+ retreiver_tool = create_retriever_tool(
72
+ advanced_retriever,
73
+ "search_aml_act_retriever",
74
+ aml_act_retriever_description,
75
+ )
76
+
77
+ tool_belt = [DuckDuckGoSearchRun(description=duckduckgo_description +
78
+ "Input should be a search query."), retreiver_tool]
79
+
80
+ tool_executor = ToolExecutor(tool_belt)
81
+
82
+ model_aml = ChatOpenAI(model="gpt-4", temperature=0)
83
+
84
+ functions = [convert_to_openai_function(t) for t in tool_belt]
85
+ model_aml = model_aml.bind_functions(functions)
86
+
87
+ print(functions)
88
+
89
+ # BUILD THE GRAPH
90
+ class AgentState(TypedDict):
91
+ messages: Annotated[list, add_messages]
92
+
93
+ def call_model(state):
94
+ messages = state["messages"]
95
+ response = model_aml.invoke(messages)
96
+ return {"messages" : [response]}
97
+
98
+ def call_tool(state):
99
+ last_message = state["messages"][-1]
100
+
101
+ action = ToolInvocation(
102
+ tool=last_message.additional_kwargs["function_call"]["name"],
103
+ tool_input=json.loads(
104
+ last_message.additional_kwargs["function_call"]["arguments"]
105
+ )
106
+ )
107
+
108
+ response = tool_executor.invoke(action)
109
+
110
+ function_message = FunctionMessage(content=str(response), name=action.tool)
111
+
112
+ return {"messages" : [function_message]}
113
+
114
# Two-node agent loop: "agent" decides what to do, "action" executes the tool.
workflow = StateGraph(AgentState)
workflow.add_node("agent", call_model)
workflow.add_node("action", call_tool)
# Every run starts at the agent node.
workflow.set_entry_point("agent")
118
+
119
def should_continue(state):
    """Route after the agent node: 'continue' to execute a tool, 'end' to stop."""
    newest = state["messages"][-1]
    # The model signals a tool request by attaching a function_call payload.
    wants_tool = "function_call" in newest.additional_kwargs
    return "continue" if wants_tool else "end"
126
+
127
# After "agent": run the requested tool ("continue") or finish the run ("end").
workflow.add_conditional_edges(
    "agent",
    should_continue,
    {
        "continue" : "action",
        "end" : END
    }
)

# Tool output always flows back to the agent for the next decision.
workflow.add_edge("action", "agent")

# Compile the graph into a runnable app.
app = workflow.compile()
139
+
140
# Aux print, useful for debugging the agent trace.
def print_messages(messages):
    """Print a readable trace of an agent run: query, tool calls, tool replies, answers."""
    expecting_tool_reply = False
    seen_initial_query = False
    for msg in messages["messages"]:
        kwargs = msg.additional_kwargs
        if "function_call" in kwargs:
            fc = kwargs["function_call"]
            print()
            print(f'Tool Call - Name: {fc["name"]} + Query: {fc["arguments"]}')
            expecting_tool_reply = True
        elif expecting_tool_reply:
            print(f"Tool Response: {msg.content}")
            expecting_tool_reply = False
        elif not seen_initial_query:
            print(f"Initial Query: {msg.content}")
            print()
            seen_initial_query = True
        else:
            print()
            print(f"Agent Response: {msg.content}")
161
 
162
 
163
@cl.on_chat_start
async def on_chat_start():
    """Store the compiled agent graph in the user session when a chat opens."""
    # The compiled LangGraph app is used as-is; no extra output parser is chained.
    cl.user_session.set("runnable", app)
169
 
170
 
171
@cl.on_message
async def on_message(message: cl.Message):
    """Handle a user message: run the agent graph and send back its final answer.

    Fixes two issues in the original handler:
    * ``app.invoke`` is synchronous — calling it directly inside this async
      handler blocks Chainlit's event loop, so it is now dispatched through
      ``cl.make_async`` (as the pre-agent version of this handler already did).
    * The graph stored in the session by ``on_chat_start`` was being ignored;
      it is now used, with the module-level ``app`` as a fallback.
    """
    print("Query content----------", message.content)

    # Prepend the steering prompt so the agent only searches for AML topics.
    input_message = HumanMessage(content=(agent_prompt + message.content))

    # Prefer the graph stored at chat start; fall back to the module-level app.
    runnable = cl.user_session.get("runnable") or app

    # Run the blocking graph invocation off the event loop.
    response = await cl.make_async(runnable.invoke)({"messages": [input_message]})

    await cl.Message(content=response["messages"][-1].content).send()

    print_messages(response)  # debug trace of the full agent run

    print("Answer content----------", response["messages"][-1].content)