Le Ngoc Anh committed on
Commit
4a9e510
·
1 Parent(s): 260c00b

commit name

Browse files
Files changed (5) hide show
  1. .gitattributes +0 -35
  2. README.md +0 -13
  3. app.py +198 -0
  4. condaenv.khtzjjyc.requirements.txt +12 -0
  5. requirements.txt +12 -0
.gitattributes DELETED
@@ -1,35 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md DELETED
@@ -1,13 +0,0 @@
1
- ---
2
- title: Test
3
- emoji: 🦀
4
- colorFrom: red
5
- colorTo: indigo
6
- sdk: streamlit
7
- sdk_version: 1.44.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from langchain_community.callbacks import StreamlitCallbackHandler
3
+ from langchain_core.runnables import RunnableConfig
4
+
5
+ from src import CFG
6
+ from src.query_expansion import build_multiple_queries_expansion_chain
7
+ from src.retrieval_qa import (
8
+ build_retrieval_qa,
9
+ build_base_retriever,
10
+ build_rerank_retriever,
11
+ build_compression_retriever,
12
+ )
13
+ from src.vectordb import build_vectordb, load_faiss, load_chroma
14
+ from streamlit_app.pdf_display import get_doc_highlighted, display_pdf
15
+ from streamlit_app.utils import load_base_embeddings, load_llm, load_reranker
16
+
17
st.set_page_config(page_title="Retrieval QA", layout="wide")

# Model handles loaded at module import. NOTE(review): Streamlit re-executes
# the script on every interaction, so these loaders presumably cache
# internally (e.g. @st.cache_resource in streamlit_app.utils) — verify.
LLM = load_llm()
BASE_EMBEDDINGS = load_base_embeddings()
RERANKER = load_reranker()
22
+
23
+
24
@st.cache_resource
def load_vectordb():
    """Load the configured vector store, cached across Streamlit reruns.

    Dispatches on ``CFG.VECTORDB_TYPE`` ("faiss" or "chroma"); raises
    NotImplementedError for any other value.
    """
    loaders = {
        "faiss": load_faiss,
        "chroma": load_chroma,
    }
    loader = loaders.get(CFG.VECTORDB_TYPE)
    if loader is None:
        raise NotImplementedError
    return loader(BASE_EMBEDDINGS)
31
+
32
+
33
@st.cache_resource
def load_retriever(_vectordb, retrieval_mode):
    """Build the retriever for the selected mode, cached by Streamlit.

    The leading underscore on ``_vectordb`` tells st.cache_resource not to
    hash that argument. Raises NotImplementedError for an unknown mode.
    """
    builders = {
        "Base": lambda: build_base_retriever(_vectordb),
        "Rerank": lambda: build_rerank_retriever(_vectordb, RERANKER),
        "Contextual compression": lambda: build_compression_retriever(
            _vectordb, BASE_EMBEDDINGS
        ),
    }
    builder = builders.get(retrieval_mode)
    if builder is None:
        raise NotImplementedError
    return builder()
42
+
43
+
44
def init_sess_state():
    """Seed st.session_state with default values on first run.

    Existing keys are left untouched so state survives Streamlit reruns.
    """
    defaults = {
        "uploaded_filename": "",
        "last_form": [],
        "last_query": "",
        "last_response": {},
        "last_related": [],
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value
59
+
60
+
61
def doc_qa():
    """Main Streamlit page: upload a PDF, build/load a VectorDB, query it.

    Two modes: "Retrieval only" shows matching extracts plus LLM-expanded
    related queries; "Retrieval QA" streams an LLM answer with its sources.
    """
    init_sess_state()

    with st.sidebar:
        st.header("RAG with quantized LLM")
        st.info(f"LLM: `{CFG.LLM_PATH}`")
        st.info(f"Embeddings: `{CFG.EMBEDDINGS_PATH}`")
        st.info(f"Reranker: `{CFG.RERANKER_PATH}`")

        uploaded_file = st.file_uploader(
            "Upload a PDF and build VectorDB", type=["pdf"]
        )
        if st.button("Build VectorDB"):
            if uploaded_file is None:
                st.error("No PDF uploaded")
            else:
                # Persist the upload to disk so the indexer can read it.
                uploaded_filename = f"./data/{uploaded_file.name}"
                with open(uploaded_filename, "wb") as f:
                    f.write(uploaded_file.getvalue())
                with st.spinner("Building VectorDB..."):
                    build_vectordb(uploaded_filename)
                st.session_state.uploaded_filename = uploaded_filename

        if st.session_state.uploaded_filename != "":
            st.info(f"Current document: {st.session_state.uploaded_filename}")

    # vectordb stays None when loading fails; the original left the name
    # unbound in that case, causing a NameError on the first query.
    vectordb = None
    try:
        with st.status("Load VectorDB", expanded=False) as status:
            st.write("Loading VectorDB ...")
            vectordb = load_vectordb()
            status.update(
                label="Loading complete!", state="complete", expanded=False
            )

        st.success("Reading from existing VectorDB")
    except Exception as e:
        st.error(f"No existing VectorDB found: {e}")

    c0, c1 = st.columns(2)

    with c0.form("qa_form"):
        user_query = st.text_area("Your query")
        with st.expander("Settings"):
            mode = st.radio(
                "Mode",
                ["Retrieval only", "Retrieval QA"],
                index=1,
                help="""Retrieval only will output extracts related to your query immediately, \
while Retrieval QA will output an answer to your query and will take a while on CPU.""",
            )
            retrieval_mode = st.radio(
                "Retrieval method",
                ["Base", "Rerank", "Contextual compression"],
                index=1,
            )

        submitted = st.form_submit_button("Query")
        if submitted:
            if user_query == "":
                st.error("Please enter a query.")

    # Only run retrieval for a new, non-empty query (text_area inside a form
    # only changes on submit, so this also debounces reruns).
    if user_query != "" and (
        st.session_state.last_query != user_query
    ):
        if vectordb is None:
            st.error("No VectorDB available. Build one from the sidebar first.")
        else:
            st.session_state.last_query = user_query

            if mode == "Retrieval only":
                retriever = load_retriever(vectordb, retrieval_mode)
                with c0:
                    with st.spinner("Retrieving ..."):
                        relevant_docs = retriever.get_relevant_documents(user_query)

                        st.session_state.last_response = {
                            "query": user_query,
                            "source_documents": relevant_docs,
                        }

                chain = build_multiple_queries_expansion_chain(LLM)
                res = chain.invoke(user_query)
                st.session_state.last_related = [
                    x.strip() for x in res.split("\n") if x.strip()
                ]
            else:
                # Fix: the original referenced an undefined name `db` here,
                # crashing "Retrieval QA" mode with a NameError.
                retriever = load_retriever(vectordb, retrieval_mode)
                retrieval_qa = build_retrieval_qa(LLM, retriever)

                st_callback = StreamlitCallbackHandler(
                    parent_container=c0.container(),
                    expand_new_thoughts=True,
                    collapse_completed_thoughts=True,
                )
                st.session_state.last_response = retrieval_qa.invoke(
                    user_query, config=RunnableConfig(callbacks=[st_callback])
                )
                st_callback._complete_current_thought()
                # Drop related queries left over from a previous
                # "Retrieval only" run so they are not shown with QA output.
                st.session_state.last_related = []

    if st.session_state.last_response:
        with c0:
            st.warning(f"##### {st.session_state.last_query}")
            if st.session_state.last_response.get("result") is not None:
                st.success(st.session_state.last_response["result"])

            if st.session_state.last_related:
                st.write("#### Related")
                for r in st.session_state.last_related:
                    st.write(f"```\n{r}\n```")

        with c1:
            st.write("#### Sources")
            # .get guards against a response dict without source documents.
            source_documents = st.session_state.last_response.get(
                "source_documents", []
            )
            for row in source_documents:
                st.write("**Page {}**".format(row.metadata["page"] + 1))
                st.info(row.page_content.replace("$", r"\$"))

            # Display PDF
            st.write("---")
            _display_pdf_from_docs(source_documents)
177
+
178
+
179
def _display_pdf_from_docs(source_documents):
    """Let the user pick one source extract and render it highlighted in the PDF.

    Args:
        source_documents: documents carrying ``metadata["source"]`` (path to
            the original PDF) and ``page_content`` (the extract text).
    """
    if not source_documents:
        # st.radio over an empty option list returns None, which would make
        # the index below raise TypeError — bail out early instead.
        st.info("No source documents to display.")
        return

    n = len(source_documents)
    i = st.radio(
        "View in PDF", list(range(n)), format_func=lambda x: f"Extract {x + 1}"
    )
    row = source_documents[i]
    try:
        extracted_doc, page_nums = get_doc_highlighted(
            row.metadata["source"], row.page_content
        )
        if extracted_doc is None:
            st.error("No page found")
        else:
            display_pdf(extracted_doc, page_nums[0] + 1)
    except Exception as e:
        st.error(e)
195
+
196
+
197
# Entry point when executed directly (e.g. `streamlit run app.py`).
if __name__ == "__main__":
    doc_qa()
condaenv.khtzjjyc.requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ chromadb==0.4.22
2
+ ctransformers==0.2.27
3
+ faiss-cpu==1.7.4
4
+ fastapi==0.104.1
5
+ langchain==0.1.3
6
+ PyMuPDF==1.23.8
7
+ python-box==7.1.1
8
+ rank-bm25==0.2.2
9
+ sentence-transformers==2.2.2
10
+ simsimd==3.3.0
11
+ streamlit==1.30.0
12
+ umap-learn==0.5.5
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ chromadb==0.4.22
2
+ ctransformers==0.2.27
3
+ faiss-cpu==1.7.4
4
+ fastapi==0.104.1
5
+ langchain==0.1.3
6
+ PyMuPDF==1.23.8
7
+ python-box==7.1.1
8
+ rank-bm25==0.2.2
9
+ sentence-transformers==2.2.2
10
+ simsimd==3.3.0
11
+ streamlit==1.30.0
12
+ umap-learn==0.5.5