jaothan committed on
Commit
aa43c4f
·
verified ·
1 Parent(s): 73913fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -109
app.py CHANGED
@@ -1,109 +1,109 @@
1
- import os
2
-
3
- import streamlit as st
4
- from langchain.chains import RetrievalQA
5
- from PyPDF2 import PdfReader
6
- from langchain.text_splitter import RecursiveCharacterTextSplitter
7
- from langchain.callbacks.base import BaseCallbackHandler
8
- from langchain.vectorstores.neo4j_vector import Neo4jVector
9
- from streamlit.logger import get_logger
10
- from chains import (
11
- load_embedding_model,
12
- load_llm,
13
- )
14
-
15
- url = os.getenv("NEO4J_URI")
16
- username = os.getenv("NEO4J_USERNAME")
17
- password = os.getenv("NEO4J_PASSWORD")
18
- ollama_base_url = os.getenv("OLLAMA_BASE_URL")
19
- embedding_model_name = os.getenv("EMBEDDING_MODEL", "SentenceTransformer" )
20
- llm_name = os.getenv("LLM", "llama2")
21
- url = os.getenv("NEO4J_URI")
22
-
23
- # Check if the required environment variables are set
24
- if not all([url, username, password,
25
- ollama_base_url]):
26
- st.write("The application requires some information before running.")
27
- with st.form("connection_form"):
28
- url = st.text_input("Enter NEO4J_URI",)
29
- username = st.text_input("Enter NEO4J_USERNAME")
30
- password = st.text_input("Enter NEO4J_PASSWORD", type="password")
31
- ollama_base_url = st.text_input("Enter OLLAMA_BASE_URL")
32
- st.markdown("Only enter the OPENAI_APIKEY to use OpenAI instead of Ollama. Leave blank to use Ollama.")
33
- openai_apikey = st.text_input("Enter OPENAI_API_KEY", type="password")
34
- submit_button = st.form_submit_button("Submit")
35
- if submit_button:
36
- if not all([url, username, password, ]):
37
- st.write("Enter the Neo4j information.")
38
- if not (ollama_base_url or openai_apikey):
39
- st.write("Enter the Ollama URL or OpenAI API Key.")
40
- if openai_apikey:
41
- llm_name = "gpt-3.5"
42
- os.environ['OPENAI_API_KEY'] = openai_apikey
43
-
44
- os.environ["NEO4J_URL"] = url
45
-
46
- logger = get_logger(__name__)
47
-
48
- embeddings, dimension = load_embedding_model(
49
- embedding_model_name, config={"ollama_base_url": ollama_base_url}, logger=logger
50
- )
51
-
52
-
53
- class StreamHandler(BaseCallbackHandler):
54
- def __init__(self, container, initial_text=""):
55
- self.container = container
56
- self.text = initial_text
57
-
58
- def on_llm_new_token(self, token: str, **kwargs) -> None:
59
- self.text += token
60
- self.container.markdown(self.text)
61
-
62
- llm = load_llm(llm_name, logger=logger, config={"ollama_base_url": ollama_base_url})
63
-
64
-
65
- def main():
66
- st.header("📄Chat with your pdf file")
67
-
68
- # upload a your pdf file
69
- pdf = st.file_uploader("Upload your PDF", type="pdf")
70
-
71
- if pdf is not None:
72
- pdf_reader = PdfReader(pdf)
73
-
74
- text = ""
75
- for page in pdf_reader.pages:
76
- text += page.extract_text()
77
-
78
- # langchain_textspliter
79
- text_splitter = RecursiveCharacterTextSplitter(
80
- chunk_size=1000, chunk_overlap=200, length_function=len
81
- )
82
-
83
- chunks = text_splitter.split_text(text=text)
84
-
85
- # Store the chunks part in db (vector)
86
- vectorstore = Neo4jVector.from_texts(
87
- chunks,
88
- url=url,
89
- username=username,
90
- password=password,
91
- embedding=embeddings,
92
- index_name="pdf_bot",
93
- node_label="PdfBotChunk",
94
- pre_delete_collection=True, # Delete existing PDF data
95
- )
96
- qa = RetrievalQA.from_chain_type(
97
- llm=llm, chain_type="stuff", retriever=vectorstore.as_retriever()
98
- )
99
-
100
- # Accept user questions/query
101
- query = st.text_input("Ask questions about your PDF file")
102
-
103
- if query:
104
- stream_handler = StreamHandler(st.empty())
105
- qa.run(query, callbacks=[stream_handler])
106
-
107
-
108
- if __name__ == "__main__":
109
- main()
 
1
+ import os
2
+ !pip install sentence-transformers
3
+ import streamlit as st
4
+ from langchain.chains import RetrievalQA
5
+ from PyPDF2 import PdfReader
6
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
7
+ from langchain.callbacks.base import BaseCallbackHandler
8
+ from langchain.vectorstores.neo4j_vector import Neo4jVector
9
+ from streamlit.logger import get_logger
10
+ from chains import (
11
+ load_embedding_model,
12
+ load_llm,
13
+ )
14
+
15
+ url = os.getenv("NEO4J_URI")
16
+ username = os.getenv("NEO4J_USERNAME")
17
+ password = os.getenv("NEO4J_PASSWORD")
18
+ ollama_base_url = os.getenv("OLLAMA_BASE_URL")
19
+ embedding_model_name = os.getenv("EMBEDDING_MODEL", "SentenceTransformer" )
20
+ llm_name = os.getenv("LLM", "llama2")
21
+ url = os.getenv("NEO4J_URI")
22
+
23
+ # Check if the required environment variables are set
24
+ if not all([url, username, password,
25
+ ollama_base_url]):
26
+ st.write("The application requires some information before running.")
27
+ with st.form("connection_form"):
28
+ url = st.text_input("Enter NEO4J_URI",)
29
+ username = st.text_input("Enter NEO4J_USERNAME")
30
+ password = st.text_input("Enter NEO4J_PASSWORD", type="password")
31
+ ollama_base_url = st.text_input("Enter OLLAMA_BASE_URL")
32
+ st.markdown("Only enter the OPENAI_APIKEY to use OpenAI instead of Ollama. Leave blank to use Ollama.")
33
+ openai_apikey = st.text_input("Enter OPENAI_API_KEY", type="password")
34
+ submit_button = st.form_submit_button("Submit")
35
+ if submit_button:
36
+ if not all([url, username, password, ]):
37
+ st.write("Enter the Neo4j information.")
38
+ if not (ollama_base_url or openai_apikey):
39
+ st.write("Enter the Ollama URL or OpenAI API Key.")
40
+ if openai_apikey:
41
+ llm_name = "gpt-3.5"
42
+ os.environ['OPENAI_API_KEY'] = openai_apikey
43
+
44
+ os.environ["NEO4J_URL"] = url
45
+
46
+ logger = get_logger(__name__)
47
+
48
+ embeddings, dimension = load_embedding_model(
49
+ embedding_model_name, config={"ollama_base_url": ollama_base_url}, logger=logger
50
+ )
51
+
52
+
53
+ class StreamHandler(BaseCallbackHandler):
54
+ def __init__(self, container, initial_text=""):
55
+ self.container = container
56
+ self.text = initial_text
57
+
58
+ def on_llm_new_token(self, token: str, **kwargs) -> None:
59
+ self.text += token
60
+ self.container.markdown(self.text)
61
+
62
+ llm = load_llm(llm_name, logger=logger, config={"ollama_base_url": ollama_base_url})
63
+
64
+
65
+ def main():
66
+ st.header("📄Chat with your pdf file")
67
+
68
+ # Upload your PDF file
69
+ pdf = st.file_uploader("Upload your PDF", type="pdf")
70
+
71
+ if pdf is not None:
72
+ pdf_reader = PdfReader(pdf)
73
+
74
+ text = ""
75
+ for page in pdf_reader.pages:
76
+ text += page.extract_text()
77
+
78
+ # Split the extracted text into chunks with the LangChain text splitter
79
+ text_splitter = RecursiveCharacterTextSplitter(
80
+ chunk_size=1000, chunk_overlap=200, length_function=len
81
+ )
82
+
83
+ chunks = text_splitter.split_text(text=text)
84
+
85
+ # Store the chunks part in db (vector)
86
+ vectorstore = Neo4jVector.from_texts(
87
+ chunks,
88
+ url=url,
89
+ username=username,
90
+ password=password,
91
+ embedding=embeddings,
92
+ index_name="pdf_bot",
93
+ node_label="PdfBotChunk",
94
+ pre_delete_collection=True, # Delete existing PDF data
95
+ )
96
+ qa = RetrievalQA.from_chain_type(
97
+ llm=llm, chain_type="stuff", retriever=vectorstore.as_retriever()
98
+ )
99
+
100
+ # Accept user questions/query
101
+ query = st.text_input("Ask questions about your PDF file")
102
+
103
+ if query:
104
+ stream_handler = StreamHandler(st.empty())
105
+ qa.run(query, callbacks=[stream_handler])
106
+
107
+
108
+ if __name__ == "__main__":
109
+ main()