Dakhoo commited on
Commit
f3f7425
·
1 Parent(s): 12e0acb

fixed uploader type

Browse files
Files changed (4) hide show
  1. .flake8 +1 -0
  2. .gitignore +163 -0
  3. .pre-commit-config.yaml +59 -0
  4. app.py +331 -2
.flake8 ADDED
@@ -0,0 +1 @@
 
 
1
+ ignore = E501
.gitignore ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ tempdir/*
2
+ hf_model/*
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # poetry
101
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105
+ #poetry.lock
106
+
107
+ # pdm
108
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109
+ #pdm.lock
110
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111
+ # in version control.
112
+ # https://pdm.fming.dev/#use-with-ide
113
+ .pdm.toml
114
+
115
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116
+ __pypackages__/
117
+
118
+ # Celery stuff
119
+ celerybeat-schedule
120
+ celerybeat.pid
121
+
122
+ # SageMath parsed files
123
+ *.sage.py
124
+
125
+ # Environments
126
+ .env
127
+ .venv
128
+ env/
129
+ venv/
130
+ ENV/
131
+ env.bak/
132
+ venv.bak/
133
+
134
+ # Spyder project settings
135
+ .spyderproject
136
+ .spyproject
137
+
138
+ # Rope project settings
139
+ .ropeproject
140
+
141
+ # mkdocs documentation
142
+ /site
143
+
144
+ # mypy
145
+ .mypy_cache/
146
+ .dmypy.json
147
+ dmypy.json
148
+
149
+ # Pyre type checker
150
+ .pyre/
151
+
152
+ # pytype static type analyzer
153
+ .pytype/
154
+
155
+ # Cython debug symbols
156
+ cython_debug/
157
+
158
+ # PyCharm
159
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
162
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
+ #.idea/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/psf/black
3
+ rev: 23.3.0
4
+ hooks:
5
+ - id: black
6
+
7
+ - repo: https://github.com/pycqa/isort
8
+ rev: 5.12.0
9
+ hooks:
10
+ - id: isort
11
+ args: ["--profile", "black"]
12
+
13
+ - repo: https://github.com/pycqa/flake8
14
+ rev: 6.0.0
15
+ hooks:
16
+ - id: flake8
17
+ exclude: .*/tests|^sandbox
18
+ additional_dependencies: [flake8-docstrings]
19
+ args:
20
+ [
21
+ "--max-line-length=88",
22
+ "--extend-ignore=E203,W503",
23
+ "--docstring-convention",
24
+ "google",
25
+ ]
26
+
27
+ - repo: https://github.com/pre-commit/pre-commit-hooks
28
+ rev: v4.4.0
29
+ hooks:
30
+ - id: requirements-txt-fixer
31
+ files: .*/requirements.*\.txt$
32
+ - id: check-json
33
+ exclude: '^data/.*'
34
+ - id: check-yaml
35
+ exclude: '^applications/.*/charts/.*\.yaml$'
36
+ - id: check-added-large-files
37
+ - id: check-merge-conflict
38
+
39
+ - repo: https://github.com/pre-commit/mirrors-mypy
40
+ rev: v1.3.0
41
+ hooks:
42
+ - id: mypy
43
+ args: [--ignore-missing-imports, --disallow-untyped-defs, --install-types, --non-interactive]
44
+ exclude: .*/tests|^sandbox
45
+
46
+ - repo: local
47
+ hooks:
48
+ - id: hadolint
49
+ name: hadolint
50
+ entry: hadolint/hadolint:v2.12.1-beta hadolint --ignore DL3008 --no-color
51
+ language: docker_image
52
+ types: [file, dockerfile]
53
+
54
+
55
+ - repo: https://github.com/sqlfluff/sqlfluff
56
+ rev: 2.1.1
57
+ hooks:
58
+ - id: sqlfluff-lint
59
+ - id: sqlfluff-fix
app.py CHANGED
@@ -1,4 +1,333 @@
 
 
 
 
 
 
 
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
1
+ """This is a public module. It should have a docstring."""
2
+ import itertools
3
+ import os
4
+ import random
5
+ from typing import Any, List, Tuple
6
+
7
+ import openai
8
  import streamlit as st
9
+ from langchain.agents import AgentExecutor, OpenAIFunctionsAgent
10
+ from langchain.agents.agent_toolkits import create_retriever_tool
11
+ from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
12
+ AgentTokenBufferMemory,
13
+ )
14
+ from langchain.callbacks import StreamlitCallbackHandler
15
+ from langchain.chains import QAGenerationChain
16
+ from langchain.chat_models import ChatOpenAI
17
+ from langchain.document_loaders import PyPDFLoader
18
+ from langchain.embeddings import HuggingFaceEmbeddings
19
+ from langchain.prompts import MessagesPlaceholder
20
+ from langchain.schema import AIMessage, HumanMessage, SystemMessage
21
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
22
+ from langchain.vectorstores import FAISS
23
+
24
+ st.set_page_config(page_title="PDF QA", page_icon="📚")
25
+
26
+ starter_message = "Ask me anything about the Doc!"
27
+
28
+
29
+ @st.cache_resource
30
+ def create_prompt(openai_api_key: str) -> Tuple[SystemMessage, ChatOpenAI]:
31
+ """Create prompt."""
32
+ try:
33
+ # Make your OpenAI API request here
34
+ llm = ChatOpenAI(
35
+ temperature=0,
36
+ model_name="gpt-3.5-turbo",
37
+ streaming=True,
38
+ openai_api_key=openai_api_key,
39
+ )
40
+ except openai.error.AuthenticationError as e:
41
+ # Handle timeout error, e.g. retry or log
42
+ print(f"Please check your API key and try again. : {e}")
43
+ pass
44
+
45
+ message = SystemMessage(
46
+ content=(
47
+ "You are a helpful chatbot who is tasked with answering questions about context given through uploaded documents." # noqa: E501 comment
48
+ "Unless otherwise explicitly stated, it is probably fair to assume that questions are about the context given." # noqa: E501 comment
49
+ "If there is any ambiguity, you probably assume they are about that." # noqa: E501 comment
50
+ )
51
+ )
52
+
53
+ prompt = OpenAIFunctionsAgent.create_prompt(
54
+ system_message=message,
55
+ extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
56
+ )
57
+
58
+ return prompt, llm
59
+
60
+
61
+ @st.cache_data
62
+ def save_file_locally(file: Any) -> str:
63
+ """Save uploaded files locally."""
64
+ doc_path = os.path.join("tempdir", file.name)
65
+ with open(doc_path, "wb") as f:
66
+ f.write(file.getbuffer())
67
+
68
+ return doc_path
69
+
70
+
71
+ @st.cache_data
72
+ def load_docs(files: List[Any], url: bool = False) -> str:
73
+ """Load and process the uploaded PDF files."""
74
+ if not url:
75
+ st.info("`Reading doc ...`")
76
+ documents = []
77
+ for file in files:
78
+ doc_path = save_file_locally(file)
79
+ pages = PyPDFLoader(doc_path)
80
+ documents.extend(pages.load())
81
+
82
+ return ",".join([doc.page_content for doc in documents])
83
+
84
+
85
+ @st.cache_data
86
+ def gen_embeddings() -> HuggingFaceEmbeddings:
87
+ """Generate embeddings for given model."""
88
+ embeddings = HuggingFaceEmbeddings(
89
+ cache_folder="hf_model"
90
+ ) # https://github.com/UKPLab/sentence-transformers/issues/1828
91
+ return embeddings
92
+
93
+
94
+ @st.cache_resource
95
+ def process_corpus(corpus: str, chunk_size: int = 1000, overlap: int = 50) -> List:
96
+ """Process text for Semantic Search."""
97
+ text_splitter = RecursiveCharacterTextSplitter(
98
+ chunk_size=chunk_size, chunk_overlap=overlap
99
+ )
100
+
101
+ texts = text_splitter.split_text(corpus)
102
+
103
+ # Display the number of text chunks
104
+ num_chunks = len(texts)
105
+ st.write(f"Number of text chunks: {num_chunks}")
106
+
107
+ # select embedding model
108
+ embeddings = gen_embeddings()
109
+
110
+ # create vectorstore
111
+ vectorstore = FAISS.from_texts(texts, embeddings).as_retriever(
112
+ search_kwargs={"k": 4}
113
+ )
114
+
115
+ # create retriever tool
116
+ tool = create_retriever_tool(
117
+ vectorstore,
118
+ "search_docs",
119
+ "Searches and returns documents using the context provided as a source, relevant to the user input question.", # noqa: E501 comment
120
+ )
121
+
122
+ tools = [tool]
123
+ return tools
124
+
125
+
126
+ @st.cache_data
127
+ def generate_agent_executer(text: str) -> List[AgentExecutor]:
128
+ """Generate the memory functionality."""
129
+ tools = process_corpus(text)
130
+
131
+ agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
132
+ # Synthwave
133
+
134
+ agent_executor = AgentExecutor(
135
+ agent=agent,
136
+ tools=tools,
137
+ verbose=True,
138
+ return_intermediate_steps=True,
139
+ )
140
+ return agent_executor
141
+
142
+
143
+ @st.cache_data
144
+ def generate_eval(raw_text: str, N: int, chunk: int) -> List:
145
+ """Generate the focusing functionality."""
146
+ # Generate N questions from context of chunk chars
147
+ # IN: text, N questions, chunk size to draw question from in the doc
148
+ # OUT: eval set as JSON list
149
+ # raw_text = ','.join(raw_text)
150
+ update = st.empty()
151
+ ques_update = st.empty()
152
+ update.info("`Generating sample questions ...`")
153
+ n = len(raw_text)
154
+ starting_indices = [random.randint(0, n - chunk) for _ in range(N)]
155
+ sub_sequences = [raw_text[i : i + chunk] for i in starting_indices]
156
+ chain = QAGenerationChain.from_llm(llm)
157
+ eval_set = []
158
+ for i, b in enumerate(sub_sequences):
159
+ try:
160
+ qa = chain.run(b)
161
+ eval_set.append(qa)
162
+ ques_update.info(f"Creating Question: {i+1}")
163
+ except ValueError:
164
+ st.warning(f"Error in generating Question: {i+1}...", icon="⚠️")
165
+ continue
166
+
167
+ eval_set_full = list(itertools.chain.from_iterable(eval_set))
168
+
169
+ update.empty()
170
+ ques_update.empty()
171
+
172
+ return eval_set_full
173
+
174
+
175
+ @st.cache_resource()
176
+ def gen_side_bar_qa(text: str) -> None:
177
+ """Generate responses from query."""
178
+ if text:
179
+ # Check if there are no generated question-answer pairs in the session state
180
+ if "eval_set" not in st.session_state:
181
+ # Use the generate_eval function to generate question-answer pairs
182
+ num_eval_questions = 5 # Number of question-answer pairs to generate
183
+ st.session_state.eval_set = generate_eval(text, num_eval_questions, 3000)
184
+
185
+ # Display the question-answer pairs in the sidebar with smaller text
186
+ for i, qa_pair in enumerate(st.session_state.eval_set):
187
+ st.sidebar.markdown(
188
+ f"""
189
+ <div class="css-card">
190
+ <span class="card-tag">Question {i + 1}</span>
191
+ <p style="font-size: 12px;">{qa_pair['question']}</p>
192
+ <p style="font-size: 12px;">{qa_pair['answer']}</p>
193
+ </div>
194
+ """,
195
+ unsafe_allow_html=True,
196
+ )
197
+ st.write("Ready to answer your questions.")
198
+
199
+
200
+ # Add custom CSS
201
+ st.markdown(
202
+ """
203
+ <style>
204
+ #MainMenu {visibility: hidden;
205
+ # }
206
+ footer {visibility: hidden;
207
+ }
208
+ .css-card {
209
+ border-radius: 0px;
210
+ padding: 30px 10px 10px 10px;
211
+ background-color: black;
212
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
213
+ margin-bottom: 10px;
214
+ font-family: "IBM Plex Sans", sans-serif;
215
+ }
216
+ .card-tag {
217
+ border-radius: 0px;
218
+ padding: 1px 5px 1px 5px;
219
+ margin-bottom: 10px;
220
+ position: absolute;
221
+ left: 0px;
222
+ top: 0px;
223
+ font-size: 0.6rem;
224
+ font-family: "IBM Plex Sans", sans-serif;
225
+ color: white;
226
+ background-color: green;
227
+ }
228
+ .css-zt5igj {left:0;
229
+ }
230
+ span.css-10trblm {margin-left:0;
231
+ }
232
+ div.css-1kyxreq {margin-top: -40px;
233
+ }
234
+ </style>
235
+ """,
236
+ unsafe_allow_html=True,
237
+ )
238
+
239
+ st.write(
240
+ """
241
+ <div style="display: flex; align-items: center; margin-left: 0;">
242
+ <h1 style="display: inline-block;">PDF GPT</h1>
243
+ <sup style="margin-left:5px;font-size:small; color: green;">beta</sup>
244
+ </div>
245
+ """,
246
+ unsafe_allow_html=True,
247
+ )
248
+
249
+ # Build sidebar
250
+ with st.sidebar:
251
+ openai_api_key = st.text_input(
252
+ "OpenAI API Key", key="api_key_openai", type="password"
253
+ )
254
+ if openai_api_key and openai_api_key.startswith("sk-"):
255
+ prompt, llm = create_prompt(openai_api_key)
256
+ memory = AgentTokenBufferMemory(llm=llm)
257
+ "[here OpenAI API key](https://platform.openai.com/account/api-keys)"
258
+ else:
259
+ st.info("Please add your correct OpenAI API key in the sidebar.")
260
+
261
+ # If there's no OpenAI API key, show a message and stop the app for rendering further
262
+ if not openai_api_key:
263
+ st.info("Please add your OpenAI API key in the sidebar.")
264
+ st.stop()
265
+
266
+ # Use RecursiveCharacterTextSplitter as the default and only text splitter
267
+ splitter_type = "RecursiveCharacterTextSplitter"
268
+
269
+ uploaded_files = st.file_uploader(
270
+ "Upload a PDF Document", type=["pdf"], accept_multiple_files=True
271
+ )
272
+
273
+ if uploaded_files:
274
+ # Check if last_uploaded_files is not in session_state or
275
+ # if uploaded_files are different from last_uploaded_files
276
+ if (
277
+ "last_uploaded_files" not in st.session_state
278
+ or st.session_state.last_uploaded_files != uploaded_files
279
+ ):
280
+ st.session_state.last_uploaded_files = uploaded_files
281
+ if "eval_set" in st.session_state:
282
+ del st.session_state["eval_set"]
283
+
284
+ # Load and process the uploaded PDF or TXT files.
285
+ raw_pdf_text = load_docs(uploaded_files)
286
+ st.success("Documents uploaded and processed.")
287
+
288
+ # # Question and answering
289
+ # user_question = st.text_input("Enter your question:")
290
+
291
+ # embeddings = gen_embeddings()
292
+ # gen_side_bar_qa(raw_pdf_text)
293
+
294
+ # memory, agent_executor = generate_memory_agent_executre(raw_pdf_text)
295
+ agent_executor = generate_agent_executer(raw_pdf_text)
296
+
297
+ if "messages" not in st.session_state or st.sidebar.button("Clear message history"):
298
+ st.session_state["messages"] = [AIMessage(content=starter_message)]
299
+
300
+ for msg in st.session_state.messages:
301
+ if isinstance(msg, AIMessage):
302
+ st.chat_message("assistant").write(msg.content)
303
+ elif isinstance(msg, HumanMessage):
304
+ st.chat_message("user").write(msg.content)
305
+ memory.chat_memory.add_message(msg)
306
+
307
+ if user_question := st.chat_input(placeholder=starter_message):
308
+ st.chat_message("user").write(user_question)
309
+
310
+ with st.chat_message("assistant"):
311
+ st_callback = StreamlitCallbackHandler(
312
+ st.container(),
313
+ expand_new_thoughts=True,
314
+ collapse_completed_thoughts=True,
315
+ thought_labeler=None,
316
+ )
317
+
318
+ response = agent_executor(
319
+ {"input": user_question, "history": st.session_state.messages},
320
+ callbacks=[st_callback],
321
+ include_run_info=True,
322
+ )
323
+ st.session_state.messages.append(AIMessage(content=response["output"]))
324
+
325
+ st.write(response["output"])
326
+
327
+ memory.save_context({"input": user_question}, response)
328
+
329
+ st.session_state["messages"] = memory.buffer
330
+
331
+ run_id = response["__run"].run_id
332
 
333
+ col_blank, col_text, col1, col2 = st.columns([10, 2, 1, 1])