maclenn77 committed on
Commit
16d3fdf
·
unverified ·
1 Parent(s): ada0a19

Add a langchain agent (#10)

Browse files
Files changed (4) hide show
  1. app.py +30 -36
  2. src/agent.py +43 -0
  3. src/chroma_client.py +6 -5
  4. src/search.py +18 -0
app.py CHANGED
@@ -6,10 +6,15 @@ import fitz
6
  import streamlit as st
7
  import openai
8
  from dotenv import load_dotenv
 
 
9
  from src.chroma_client import ChromaDB
10
  import src.gui_messages as gm
11
  from src import settings
12
 
 
 
 
13
  load_dotenv()
14
 
15
 
@@ -44,27 +49,14 @@ with st.sidebar:
44
  chroma_db = ChromaDB(openai.api_key)
45
  openai_client, collection = settings.build(chroma_db)
46
 
47
- # Query ChromaDb
48
- query = st.text_input(
49
- "Query ChromaDb", value="", placeholder="Enter query", label_visibility="collapsed"
50
- )
51
- if st.button("Search"):
52
- results = collection.query(
53
- query_texts=[query],
54
- n_results=3,
55
- )
56
-
57
- for idx, result in enumerate(results["documents"][0]):
58
- st.markdown(
59
- result
60
- + "..."
61
- + "**Source:** "
62
- + results["metadatas"][0][idx]["source"]
63
- + " **Tokens:** "
64
- + str(results["metadatas"][0][idx]["num_tokens"])
65
- )
66
-
67
 
 
 
 
 
68
  pdf = st.file_uploader("Upload a file", type="pdf")
69
 
70
  if pdf is not None:
@@ -75,19 +67,11 @@ if pdf is not None:
75
  st.write(text[0:300] + "...")
76
  if st.button("Save chunks"):
77
  with st.spinner("Saving chunks..."):
78
- chunks = textwrap.wrap(text, 3000)
79
  for idx, chunk in enumerate(chunks):
80
  encoding = tiktoken.get_encoding("cl100k_base")
81
  num_tokens = len(encoding.encode(chunk))
82
- response = (
83
- openai_client.embeddings.create(
84
- input=chunk, model="text-embedding-ada-002"
85
- )
86
- .data[0]
87
- .embedding
88
- )
89
  collection.add(
90
- embeddings=[response],
91
  documents=[chunk],
92
  metadatas=[{"source": pdf.name, "num_tokens": num_tokens}],
93
  ids=[pdf.name + str(idx)],
@@ -95,11 +79,21 @@ if pdf is not None:
95
  else:
96
  st.write("Please upload a file of type: pdf")
97
 
98
- if st.button("Chroma data collection"):
99
- st.write(collection)
 
 
 
 
 
 
 
 
 
100
 
101
- if st.button("Delete Chroma Collection"):
102
- try:
103
- chroma_db.client.delete_collection(collection.name)
104
- except AttributeError:
105
- st.error("Collection erased.")
 
 
6
  import streamlit as st
7
  import openai
8
  from dotenv import load_dotenv
9
+ from langchain.chat_models import ChatOpenAI
10
+ from langchain.callbacks import StreamlitCallbackHandler
11
  from src.chroma_client import ChromaDB
12
  import src.gui_messages as gm
13
  from src import settings
14
 
15
+ from src.agent import PDFExplainer
16
+
17
+
18
  load_dotenv()
19
 
20
 
 
49
  chroma_db = ChromaDB(openai.api_key)
50
  openai_client, collection = settings.build(chroma_db)
51
 
52
+ # Create Agent
53
+ llm = ChatOpenAI(temperature=0.9, model="gpt-3.5-turbo-16k")
54
+ agent = PDFExplainer(llm, chroma_db).agent
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ # Main
57
+ st.title("PDF Explainer")
58
+ st.subheader("Create your knowledge base")
59
+ st.write("Upload PDF files that will help the AI Agent to understand your domain.")
60
  pdf = st.file_uploader("Upload a file", type="pdf")
61
 
62
  if pdf is not None:
 
67
  st.write(text[0:300] + "...")
68
  if st.button("Save chunks"):
69
  with st.spinner("Saving chunks..."):
70
+ chunks = textwrap.wrap(text, 1250)
71
  for idx, chunk in enumerate(chunks):
72
  encoding = tiktoken.get_encoding("cl100k_base")
73
  num_tokens = len(encoding.encode(chunk))
 
 
 
 
 
 
 
74
  collection.add(
 
75
  documents=[chunk],
76
  metadatas=[{"source": pdf.name, "num_tokens": num_tokens}],
77
  ids=[pdf.name + str(idx)],
 
79
  else:
80
  st.write("Please upload a file of type: pdf")
81
 
82
+ st.subheader("Search on your knowledge base")
83
+ # if st.button("Chroma data collection"):
84
+ # st.write(collection)
85
+
86
+ # if st.button("Delete Chroma Collection"):
87
+ # try:
88
+ # chroma_db.client.delete_collection(collection.name)
89
+ # except AttributeError:
90
+ # st.error("Collection erased.")
91
+
92
+ prompt = st.chat_input()
93
 
94
+ if prompt:
95
+ st.chat_message("user").write(prompt)
96
+ with st.chat_message("assistant"):
97
+ st_callback = StreamlitCallbackHandler(st.container())
98
+ response = agent.run(prompt, callbacks=[st_callback])
99
+ st.write(response)
src/agent.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """An Langchain Agent that uses ChromaDB as a query tool"""
2
+ from langchain.agents import AgentType, initialize_agent
3
+ from langchain.tools import Tool
4
+ from src.search import Search
5
+
6
+
7
class PDFExplainer:
    """A LangChain agent that uses ChromaDB as a query tool.

    The instance exposes ``tools`` (the list of tools handed to the agent),
    ``llm`` (the model given at construction) and ``agent`` (the object
    returned by ``initialize_agent``).
    """

    def __init__(self, llm, chroma_db):
        """Create the search tool and the default zero-shot ReAct agent.

        Args:
            llm: Language model passed to ``initialize_agent``.
            chroma_db: ChromaDB wrapper passed to ``Search``.
        """
        search = Search(chroma_db)

        # Remembered so ``replace_agent`` can rebuild the agent without the
        # caller having to supply the model a second time.
        self.llm = llm

        # ``handle_parsing_errors`` is deliberately NOT passed here: it is an
        # option of the agent executor (see ``_build_agent``), not of a Tool.
        self.tools = [
            Tool.from_function(
                func=search.run,
                name="Search DB",
                description="Useful when you need more context about a specific topic.",
            )
        ]

        self.agent = self._build_agent(AgentType.ZERO_SHOT_REACT_DESCRIPTION, llm)

    def _build_agent(self, agent, llm):
        """Return a new agent of type ``agent`` over the current ``self.tools``."""
        return initialize_agent(
            self.tools,
            llm,
            agent=agent,
            verbose=True,
            handle_parsing_errors=True,
        )

    def add_tools(self, tools: list[Tool]):
        """Append ``tools`` to ``self.tools``.

        This only extends the list; ``self.agent`` is not rebuilt here. Call
        ``replace_agent`` afterwards to create an agent from the updated list.
        """
        self.tools.extend(tools)

    def replace_agent(self, agent: AgentType, llm=None):
        """Replace ``self.agent`` with a new agent built from ``self.tools``.

        Args:
            agent: The agent type to initialize.
            llm: Language model for the new agent. When omitted, the model
                given to the constructor is reused.
        """
        self.agent = self._build_agent(agent, self.llm if llm is None else llm)
src/chroma_client.py CHANGED
@@ -1,6 +1,7 @@
1
  """A client for ChromaDB."""
2
  import chromadb
3
- from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
 
4
  import streamlit as st
5
 
6
 
@@ -24,11 +25,11 @@ class ChromaDB:
24
  def create_collection(self, name):
25
  """Create a Chroma collection."""
26
  try:
27
- embedding_function = OpenAIEmbeddingFunction(
28
- api_key=self.api_key, model_name="text-embedding-ada-002"
29
- )
30
  collection = self.client.get_or_create_collection(
31
- name=name, embedding_function=embedding_function
32
  )
33
  return collection
34
  except AttributeError:
 
1
  """A client for ChromaDB."""
2
  import chromadb
3
+
4
+ # from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
5
  import streamlit as st
6
 
7
 
 
25
  def create_collection(self, name):
26
  """Create a Chroma collection."""
27
  try:
28
+ # embedding_function = OpenAIEmbeddingFunction(
29
+ # api_key=self.api_key, model_name="text-embedding-ada-002"
30
+ # )
31
  collection = self.client.get_or_create_collection(
32
+ name=name # , embedding_function=embedding_function
33
  )
34
  return collection
35
  except AttributeError:
src/search.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Search Tool"""
2
+
3
+
4
class Search:
    """Search tool that queries a Chroma collection for matching documents."""

    def __init__(self, chroma_db, *, collection_name="pdf-explainer", n_results=3):
        """Initialize the search tool.

        Args:
            chroma_db: Object providing ``get_collection(name)``.
            collection_name: Name of the collection that ``run`` queries.
            n_results: Maximum number of documents requested per query.
        """
        self.chroma_db = chroma_db
        # Stored privately so the attribute does not shadow the
        # ``collection_name`` method below.
        self._collection_name = collection_name
        self._n_results = n_results

    def run(self, query: str):
        """Query the collection with ``query`` and return the matched documents.

        Returns:
            The documents found for ``query``, or an empty list when the
            result carries no documents.
        """
        collection = self.chroma_db.get_collection(self._collection_name)
        results = collection.query(query_texts=[query], n_results=self._n_results)
        # A single query text is sent, so only the first entry of
        # "documents" is relevant.
        documents = results["documents"]
        return documents[0] if documents else []

    def collection_name(self):
        """Return the name of the collection held by the ChromaDB wrapper."""
        # NOTE(review): this reads ``chroma_db.collection`` while ``run`` looks
        # the collection up by name - confirm both refer to the same one.
        return self.chroma_db.collection.name