kunalchamoli committed on
Commit 4fc9940 · verified · 1 Parent(s): 9082fb5

Update app.py

Files changed (1)
  1. app.py +132 -53
app.py CHANGED
@@ -1,13 +1,16 @@
import gradio as gr
import os

-
import string
import random
import requests
from bs4 import BeautifulSoup
from datetime import datetime

+ import wget
+ from langchain_community.document_loaders import UnstructuredMarkdownLoader
+ from langchain_community.document_loaders import UnstructuredURLLoader
+

from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -30,11 +33,7 @@ import accelerate


# default_persist_directory = './chroma_HF/'
- list_llm = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mistral-7B-Instruct-v0.1", \
-     "HuggingFaceH4/zephyr-7b-beta", "NousResearch/Llama-2-7b-chat-hf", \
-     "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct", "tiiuae/falcon-7b-instruct", \
-     "google/flan-t5-xxl"
- ]
+ list_llm = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mistral-7B-Instruct-v0.1"]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

# Load PDF document and create doc splits
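Note: list_llm_simple is what the LLM radio button displays; os.path.basename simply strips the organization prefix from each repo id. With the list trimmed to the two Mistral checkpoints, a quick sketch of what the UI will now show:

import os
list_llm = ["mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mistral-7B-Instruct-v0.1"]
print([os.path.basename(llm) for llm in list_llm])
# ['Mistral-7B-Instruct-v0.2', 'Mistral-7B-Instruct-v0.1']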
@@ -43,6 +42,7 @@ def load_doc(list_file_path, chunk_size, chunk_overlap):
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
+     print(pages)
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size = chunk_size,
        chunk_overlap = chunk_overlap)
@@ -50,10 +50,8 @@ def load_doc(list_file_path, chunk_size, chunk_overlap):
    return doc_splits

def convert_github_url_to_raw(url):
-     # Ensure the URL is a GitHub blob URL
-     if "github.com" in url and "/blob/" in url:
-         raw_url = url.replace("github.com", "raw.githubusercontent.com").replace("/blob", "")
-         response = requests.get(raw_url)
+     try:
+         response = requests.get(url)
        html_content = response.text
        # Step 2: Find the GitHub Icon and Extract the Link
        soup = BeautifulSoup(html_content, "html.parser")
@@ -61,24 +59,46 @@ def convert_github_url_to_raw(url):
        for a in soup.find_all('a', href=True):
            if "github.com" in a['href']: # Assuming the GitHub link contains "github.com"
                github_icon_link = a['href']
+                 print(github_icon_link)
                break
-         markdown_url = convert_github_url_to_raw(github_icon_link)
-         response = requests.get(markdown_url)
-         return response
-     else:
+         raw_url = github_icon_link.replace("github.com", "raw.githubusercontent.com").replace("/blob", "")
+         # final_response = requests.get(raw_url)
+         # content = final_response.text
+         return raw_url
+     except Exception as e:
+         print(e)
        return ''

def load_url(list_url_path, chunk_size, chunk_overlap):
-     texts = [convert_github_url_to_raw(x) for x in list_url_path]
+     urls = [convert_github_url_to_raw(x) for x in list_url_path]
+     files = [wget.download(x) for x in urls]
+     loaders = [UnstructuredMarkdownLoader(f'./{x}') for x in files]
    pages = []
-     for text in texts:
-         pages.extend(text)
+     for loader in loaders:
+         pages.extend(loader.load())
+     print(pages)
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size = chunk_size,
        chunk_overlap = chunk_overlap)
    doc_splits = text_splitter.split_documents(pages)
+     _ = [os.remove(f'./{x}') for x in files]
    return doc_splits

+ # def load_url(list_url_path, chunk_size, chunk_overlap):
+ #     texts = [convert_github_url_to_raw(x) for x in list_url_path]
+ #     pages = []
+ #     for text in texts:
+ #         pages.append(text)
+ #     print(f'length of pages is {len(pages)}')
+ #     text_splitter = RecursiveCharacterTextSplitter(
+ #         chunk_size = chunk_size,
+ #         chunk_overlap = chunk_overlap)
+ #     total_doc_splits = []
+ #     docs_ = text_splitter.create_documents(pages)
+ #     print(f"lenth of docs is {len(docs_)}")
+ #     return docs_
+
+
# Create vector database
def create_db(splits, collection_name):
    embedding = HuggingFaceEmbeddings()
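Note on the new URL path: convert_github_url_to_raw no longer expects a GitHub blob URL directly. It fetches the page the user entered (e.g. the default https://huggingface.co/blog/segmoe), scrapes the first anchor pointing at github.com, and rewrites that link to its raw.githubusercontent.com form; load_url then downloads the raw markdown with wget and parses it with UnstructuredMarkdownLoader. A minimal sketch of the rewrite step, using a hypothetical repository path:

# hypothetical blob link scraped from the page
link = "https://github.com/someorg/somerepo/blob/main/README.md"
raw_url = link.replace("github.com", "raw.githubusercontent.com").replace("/blob", "")
# raw_url == "https://raw.githubusercontent.com/someorg/somerepo/main/README.md"

Note that the helper returns '' on any exception, so a URL that cannot be resolved is still passed on to wget.download downstream.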
@@ -107,17 +127,11 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
    progress(0.5, desc="Initializing HF Hub...")
    # Use of trust_remote_code as model_kwargs
    # URL: https://github.com/langchain-ai/langchain/issues/6080
-     if llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
-         llm = HuggingFaceHub(
-             repo_id=llm_model,
-             model_kwargs={"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
-         )
-     else:
-         llm = HuggingFaceHub(
-             repo_id=llm_model,
-             # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
-             model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
-         )
+     llm = HuggingFaceHub(
+         repo_id=llm_model,
+         # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
+         model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
+     )

    progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(
@@ -144,30 +158,47 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
# Initialize database
def initialize_database(list_file_obj, input_urls, chunk_size, chunk_overlap, progress=gr.Progress()):
    # Create list of documents (when valid)
-     list_file_path = [x.name for x in list_file_obj if x is not None]
-     list_url = [x for x in input_urls if x is not None]
+     try:
+         list_file_path = [x.name for x in list_file_obj if x is not None]
+         # print(f'file paths are {list_file_path}')
+     except:
+         list_file_path = None
+     try:
+         list_url = [url.strip() for url in input_urls.split(',') if url.strip()]
+     except:
+         list_url = None

    # Create collection_name for vector database
    progress(0.1, desc="Creating collection...")
-     # collection_name = Path(list_file_path[0]).stem
-
-     # # Fix potential issues from naming convention
-     # collection_name = collection_name.replace(" ","-")
-     # collection_name = collection_name[:50]
    res = ''.join(random.choices(string.ascii_letters, k=10))
    collection_name = f"HuggingFace101_{res}"
    print('Collection name: ', collection_name)
    progress(0.25, desc="Loading document...")

    # Load document and create splits
-     doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
-     print(type(doc_splits))
+     if list_file_path is not None:
+         doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
+     else:
+         doc_splits = []
+     if list_url is not None:
+         url_splits = load_url(list_url, chunk_size, chunk_overlap)
+     else:
+         url_splits = []
+
+     # pdf_data_type = type(doc_splits)
+     # url_data_type = type(url_splits)

+     # print(pdf_data_type)
+     # print(url_data_type)
+     total_splits = []
+     total_splits.extend(doc_splits)
+     total_splits.extend(url_splits)
+     print(total_splits[0].metadata.keys())
    # Create or load vector database
    progress(0.5, desc="Generating vector database...")

    # global vector_db
-     vector_db = create_db(doc_splits, collection_name)
+     vector_db = create_db(total_splits, collection_name)
    progress(0.9, desc="Done!")
    return vector_db, collection_name, "Complete!"
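Note: input_urls arrives from the Gradio Textbox as a single string, so multiple URLs are expected comma-separated; the bare except clauses fall back to None only when an input is missing altogether (an empty URL box simply yields an empty list). A small sketch of the parsing, where the second URL is made up for illustration:

input_urls = "https://huggingface.co/blog/segmoe, https://example.com/another-post"
list_url = [url.strip() for url in input_urls.split(',') if url.strip()]
# ['https://huggingface.co/blog/segmoe', 'https://example.com/another-post']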
 
@@ -196,8 +227,12 @@ def conversation(qa_chain, message, history):
    response_source1 = response_sources[0].page_content.strip()
    response_source2 = response_sources[1].page_content.strip()
    # Langchain sources are zero-based
-     response_source1_page = response_sources[0].metadata["page"] + 1
-     response_source2_page = response_sources[1].metadata["page"] + 1
+     try:
+         response_source1_page = response_sources[0].metadata["page"] + 1
+         response_source2_page = response_sources[1].metadata["page"] + 1
+     except:
+         response_source1_page = response_sources[0].metadata['source']
+         response_source2_page = response_sources[1].metadata['source']
    # print ('chat response: ', response_answer)
    # print('DB source', response_sources)
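Note: the try/except exists because the two source types carry different metadata; splits produced by PyPDFLoader normally include a zero-based "page" key, while splits coming from the markdown/URL path typically expose only "source". An illustrative sketch of the two shapes (field values are made up):

pdf_meta = {"source": "paper.pdf", "page": 3}    # PDF split: page number available
md_meta = {"source": "./README.md"}              # markdown split: no "page" key
label = md_meta.get("page", md_meta["source"])   # dict.get would avoid the bare except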
 
@@ -215,7 +250,6 @@ def upload_file(file_obj):
    # initialize_database(file_path, progress)
    return list_file_path

-
def demo():
    with gr.Blocks(theme="base") as demo:
        vector_db = gr.State()
@@ -231,15 +265,38 @@ def demo():
        """)
        with gr.Tab("Step 1 - Document pre-processing"):
            with gr.Row():
-                 document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
-                 input_url = gr.Textbox(label="Or enter a URL", placeholder="https://example.com")
+                 document = gr.Files(height=100,
+                     file_count="multiple",
+                     file_types=["pdf"],
+                     interactive=True,
+                     label="Upload your PDF documents (single or multiple)")
+                 input_url = gr.Textbox(label="Or Enter a URL",
+                     value="https://huggingface.co/blog/segmoe",
+                     placeholder="Enter URLs separated by commas"
+                     )
            with gr.Row():
-                 db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value = "ChromaDB", type="index", info="Choose your vector database")
+                 db_btn = gr.Radio(["ChromaDB"],
+                     label="Vector database type",
+                     value = "ChromaDB",
+                     type="index",
+                     info="Choose your vector database")
            with gr.Accordion("Advanced options - Document text splitter", open=False):
                with gr.Row():
-                     slider_chunk_size = gr.Slider(minimum = 100, maximum = 1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
+                     slider_chunk_size = gr.Slider(minimum = 100,
+                         maximum = 1000,
+                         value=600,
+                         step=20,
+                         label="Chunk size",
+                         info="Chunk size",
+                         interactive=True)
                with gr.Row():
-                     slider_chunk_overlap = gr.Slider(minimum = 10, maximum = 200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
+                     slider_chunk_overlap = gr.Slider(minimum = 10,
+                         maximum = 200,
+                         value=40,
+                         step=10,
+                         label="Chunk overlap",
+                         info="Chunk overlap",
+                         interactive=True)
                with gr.Row():
                    db_progress = gr.Textbox(label="Vector database initialization", value="None")
                with gr.Row():
@@ -247,15 +304,36 @@ def demo():

        with gr.Tab("Step 2 - QA chain initialization"):
            with gr.Row():
-                 llm_btn = gr.Radio(list_llm_simple, \
-                     label="LLM models", value = list_llm_simple[0], type="index", info="Choose your LLM model")
+                 llm_btn = gr.Radio(list_llm_simple,
+                     label="LLM models",
+                     value = list_llm_simple[0],
+                     type="index",
+                     info="Choose your LLM model")
            with gr.Accordion("Advanced options - LLM model", open=False):
                with gr.Row():
-                     slider_temperature = gr.Slider(minimum = 0.0, maximum = 1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
+                     slider_temperature = gr.Slider(minimum = 0.0,
+                         maximum = 1.0,
+                         value=0.7,
+                         step=0.1,
+                         label="Temperature",
+                         info="Model temperature",
+                         interactive=True)
                with gr.Row():
-                     slider_maxtokens = gr.Slider(minimum = 224, maximum = 4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
+                     slider_maxtokens = gr.Slider(minimum = 224,
+                         maximum = 4096,
+                         value=1024,
+                         step=32,
+                         label="Max Tokens",
+                         info="Model max tokens",
+                         interactive=True)
                with gr.Row():
-                     slider_topk = gr.Slider(minimum = 1, maximum = 10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
+                     slider_topk = gr.Slider(minimum = 1,
+                         maximum = 10,
+                         value=3,
+                         step=1,
+                         label="top-k samples",
+                         info="Model top-k samples",
+                         interactive=True)
                with gr.Row():
                    llm_progress = gr.Textbox(value="None",label="QA chain initialization")
                with gr.Row():
@@ -281,6 +359,7 @@ def demo():
        db_btn.click(initialize_database, \
            inputs=[document, input_url, slider_chunk_size, slider_chunk_overlap], \
            outputs=[vector_db, collection_name, db_progress])
+
        qachain_btn.click(initialize_LLM, \
            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
            outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0], \
@@ -291,11 +370,11 @@ def demo():
        # Chatbot events
        msg.submit(conversation, \
            inputs=[qa_chain, msg, chatbot], \
-             outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page], \
+             outputs=[qa_chain, msg, chatbot], \
            queue=False)
        submit_btn.click(conversation, \
            inputs=[qa_chain, msg, chatbot], \
-             outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page], \
+             outputs=[qa_chain, msg, chatbot], \
            queue=False)
        clear_btn.click(lambda:[None,"",0,"",0], \
            inputs=None, \
  inputs=None, \
 