JaMussCraft committed on
Commit
cdcd010
·
verified ·
1 Parent(s): fc854ff

Uploaded RAG UI gradio app

Browse files
Files changed (1) hide show
  1. app.py +556 -58
app.py CHANGED
@@ -1,64 +1,562 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
-
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- response = ""
29
-
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
-
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
59
- ],
60
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  if __name__ == "__main__":
64
  demo.launch()
 
1
  import gradio as gr
2
+ import csv
3
+ import random
4
+ import os
5
+ import shutil
6
+ import json
7
+ from llama_index.embeddings.openai import OpenAIEmbedding
8
+ from llama_index.core import (
9
+ VectorStoreIndex,
10
+ SimpleDirectoryReader,
11
+ StorageContext,
12
+ load_index_from_storage,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  )
14
+ from llama_index.core.settings import Settings
15
+ import faiss
16
+ import numpy as np
17
+ from llama_index.vector_stores.faiss import FaissVectorStore
18
+ from llama_index.core.node_parser import SimpleNodeParser, SentenceSplitter
19
+ from llama_index.core.schema import Document
20
+ from llama_index.core.schema import IndexNode
21
+ from llama_index.core import ServiceContext
22
+ from llama_index.core.query_engine.retriever_query_engine import RetrieverQueryEngine
23
+ from llama_index.embeddings.huggingface.base import HuggingFaceEmbedding
24
+ from llama_index.llms.openai import OpenAI
25
+ from transformers import BitsAndBytesConfig
26
+ from llama_index.core.prompts import PromptTemplate
27
+ import torch
28
+ import pandas as pd
29
+ import fitz
30
+ from transformers import pipeline
31
+ from sklearn.metrics.pairwise import cosine_similarity
32
+ from sklearn.feature_extraction.text import TfidfVectorizer
33
 
34
 
35
# Avoid tokenizer fork warnings/deadlocks from HuggingFace tokenizers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# LLM used by llama-index for answer synthesis (OPENAI_API_KEY must be set
# in the environment — presumably via the Space secrets; verify on deploy).
llm = OpenAI(temperature=0, model="gpt-4o-mini", max_tokens=512)
Settings.llm = llm

# Server-side locations for uploads, persisted UI state, and FAISS indexes.
UPLOAD_DIR = "uploaded_files"
STATE_FILE = "uploaded_files_state.json"
PERSIST_DIR = "persisted_indexes"

os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(PERSIST_DIR, exist_ok=True)
53
def index_gen(file_path, index_name):
    """Build a FAISS-backed vector index for one file and persist it to disk.

    Loads the document at *file_path*, embeds it with BAAI/bge-small-en-v1.5,
    stores the vectors in a flat-L2 FAISS index, and persists both the
    llama-index storage context and the raw FAISS index under
    ``PERSIST_DIR/<index_name>/``.

    Args:
        file_path: Path of the uploaded document to index.
        index_name: Basename (no extension) used for the index directory
            and the ``.faiss`` file inside it.

    Returns:
        Path of the written ``<index_name>.faiss`` file, or ``None`` if
        indexing failed (the error is printed; callers treat ``None`` as
        "indexing failed").
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    try:
        # Load document from file.
        documents = SimpleDirectoryReader(input_files=[file_path]).load_data()

        # Embedding model must match embedding_dim below (bge-small = 384 dims).
        embed_model = HuggingFaceEmbedding(
            model_name="BAAI/bge-small-en-v1.5", device=device
        )
        Settings.embed_model = embed_model
        embedding_dim = 384  # Ensure this matches the embedding model used

        faiss_index = faiss.IndexFlatL2(embedding_dim)
        vector_store = FaissVectorStore(faiss_index=faiss_index)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)

        print(f"Number of documents to index: {len(documents)}.")

        # Parse into sentence-level nodes and index them.
        parser = SentenceSplitter()
        nodes = parser.get_nodes_from_documents(documents)
        index = VectorStoreIndex(nodes, storage_context=storage_context)
        print(f"Number of nodes generated:{len(nodes)}")

        # One directory per file so each index can be deleted independently.
        index_directory = os.path.join(PERSIST_DIR, index_name)
        os.makedirs(index_directory, exist_ok=True)
        index_path = os.path.join(index_directory, f"{index_name}.faiss")

        # Persist both the llama-index storage context and the raw FAISS index;
        # chat loading needs both (storage_context + faiss.read_index).
        index.storage_context.persist(index_directory)
        faiss.write_index(faiss_index, index_path)

        if not os.path.exists(index_path):
            raise FileNotFoundError(
                f"FAISS index file not created at path: {index_path}"
            )

        return index_path

    except Exception as e:
        # Best-effort: swallow and report so one bad file doesn't kill the UI.
        print(f"Error in index_gen with file {file_path}: {str(e)}")
        return None
127
+
128
+
129
def save_uploaded_files_state(uploaded_files, indexed_files=None):
    """Persist the uploaded/indexed file sets to STATE_FILE as JSON.

    Args:
        uploaded_files: Iterable of server-side paths of uploaded files.
        indexed_files: Iterable of persisted ``.faiss`` index paths. When
            omitted, the previously saved indexed-file set is kept instead
            of being silently dropped from the state file.
    """
    try:
        if indexed_files is None:
            # BUGFIX: callers that only changed uploads (e.g. the upload
            # handler) used to wipe the indexed-file list already on disk.
            _, indexed_files = load_uploaded_files_state()

        state_file_json = {
            "uploaded_files": list(uploaded_files),
            "indexed_files": list(indexed_files),
        }

        with open(STATE_FILE, "w") as f:
            json.dump(state_file_json, f, indent=4)

    except IOError as e:
        print(f"Error saving uploaded files state: {str(e)}")
148
+
149
def load_uploaded_files_state():
    """Read the persisted upload/index state from STATE_FILE.

    Returns:
        Tuple ``(uploaded_files, indexed_files)`` of sets of paths; two
        empty sets when the state file is missing or unreadable.
    """
    try:
        if os.path.exists(STATE_FILE):
            with open(STATE_FILE, "r") as f:
                state_data = json.load(f)
            uploaded = set(state_data.get("uploaded_files", []))
            indexed = set(state_data.get("indexed_files", []))
            return uploaded, indexed

    except (IOError, json.JSONDecodeError) as e:
        print(f"Error loading uploaded files state: {str(e)}")

    # Fresh start: no state file or it could not be parsed.
    return set(), set()
163
+
164
def save_file(file_path):
    """Copy an uploaded temp file into UPLOAD_DIR.

    Args:
        file_path: Source path of the uploaded file (gradio temp path).

    Returns:
        Destination path inside UPLOAD_DIR, or ``None`` on copy failure.
    """
    try:
        destination = os.path.join(UPLOAD_DIR, os.path.basename(file_path))
        shutil.copy(file_path, destination)
    except (IOError, shutil.Error) as e:
        print(f"Error saving file {file_path}: {str(e)}")
        return None
    return destination
175
+
176
with gr.Blocks() as demo:
    gr.Markdown("## 📁 File Management & Chat Assistant")

    with gr.Tabs():
        # Tab 1: File Management
        with gr.Tab("File Management"):
            with gr.Row():
                with gr.Column(scale=1):
                    file_upload = gr.File(
                        label="Upload PDF,JSON or TXT Files",
                        file_types=[".pdf", ".json", ".txt", "directory"],
                        file_count="multiple",
                        interactive=True,
                    )
                    file_table = gr.DataFrame(
                        headers=["Sr. No.", "File Name", "File Size"],
                        value=[],
                        interactive=False,
                        row_count=(4, "dynamic"),
                        wrap=True,
                        max_height=1000,
                    )
                    file_checkbox = gr.CheckboxGroup(
                        label="Select Files to Index/Delete", choices=[]
                    )
                    select_all_button = gr.Button("Select All")
                    index_button = gr.Button("Index Selected Files")
                    delete_button = gr.Button("Delete Selected Files")

                with gr.Column(scale=3):
                    message_box = gr.Markdown("")
                    chatbot = gr.Chatbot(label="LLM", type="messages")

                    with gr.Row():
                        chat_input = gr.Textbox(
                            show_label=False,
                            placeholder="Type your message here",
                            scale=8,
                        )
                        send_button = gr.Button("Send", scale=1)

        # Tab 2: Indexed Files
        with gr.Tab("Indexed Files"):
            indexed_file_table = gr.DataFrame(
                headers=["Indexed File", "Size"],
                value=[],
                interactive=False,
                row_count=(4, "dynamic"),
            )

    # STATE: tuple (uploaded_files: set[str], indexed_files: set[str]).
    uploaded_files_state = gr.State(load_uploaded_files_state())

    def _file_rows_and_choices(uploaded_files):
        # Build the file-table rows and the matching "N. name.ext" checkbox
        # labels (handlers rely on that exact label format when deleting).
        file_info, checkbox_options = [], []
        for idx, file_path in enumerate(uploaded_files, start=1):
            file_name = os.path.basename(file_path)
            file_size = os.path.getsize(file_path)
            file_info.append([idx, file_name, f"{round(file_size / 1024, 2)} KB"])
            checkbox_options.append(f"{idx}. {file_name}")
        return file_info, checkbox_options

    def _indexed_rows(indexed_files):
        # Build the indexed-files table rows from persisted .faiss paths.
        return [
            [
                os.path.basename(index_path).split(".")[0],
                f"{round(os.path.getsize(index_path) / 1024, 2)} KB",
            ]
            for index_path in indexed_files
        ]

    @delete_button.click(
        inputs=[file_checkbox, uploaded_files_state, file_upload],
        outputs=[file_table, file_checkbox, uploaded_files_state, indexed_file_table],
    )
    def delete_files(selected_files, uploaded_files_state, file_upload):
        """Delete the selected uploads and their persisted index directories."""
        print("deleting files...: ", selected_files, uploaded_files_state, file_upload)

        uploaded_files, indexed_files = uploaded_files_state

        if not selected_files or not uploaded_files:
            # BUGFIX: must return one value per declared output (4, not 3).
            return (
                gr.update(),
                gr.update(),
                (uploaded_files, indexed_files),
                gr.update(),
            )

        # "we" means with extension: checkbox labels look like "3. name.pdf".
        selected_file_names_we = [file.split(". ")[1] for file in selected_files]

        for file_name_we in selected_file_names_we:
            file_path = os.path.join(UPLOAD_DIR, file_name_we)
            index_name = file_name_we.split(".")[0]
            index_directory = os.path.join(PERSIST_DIR, index_name)
            index_path = os.path.join(index_directory, f'{index_name}.faiss')
            print(file_name_we, file_path, index_name, index_directory, index_path)

            try:
                if os.path.exists(file_path):
                    os.remove(file_path)
                    uploaded_files.remove(file_path)
                else:
                    # BUGFIX: gr.Error is an exception class — constructing it
                    # without raising showed nothing; raising would abort the
                    # loop. gr.Warning displays a toast when called.
                    gr.Warning(f"Could not delete file (File not found): {file_path}", duration=3)

                if os.path.exists(index_directory):
                    shutil.rmtree(index_directory)
                    # discard(): the file may never have been indexed.
                    indexed_files.discard(index_path)
                else:
                    gr.Warning(f"Could not delete index directory (Path not found): {index_directory}", duration=3)

            except Exception as e:
                gr.Warning(f"Error deleting {file_name_we}: {str(e)}", duration=3)

        save_uploaded_files_state(uploaded_files, indexed_files)

        file_info, checkbox_options = _file_rows_and_choices(uploaded_files)

        return (
            file_info,
            gr.update(choices=checkbox_options, value=[]),
            (uploaded_files, indexed_files),
            _indexed_rows(indexed_files),
        )

    @chat_input.submit(
        inputs=[chat_input, chatbot, uploaded_files_state],
        outputs=[chat_input, chatbot],
    )
    @send_button.click(
        inputs=[chat_input, chatbot, uploaded_files_state],
        outputs=[chat_input, chatbot],
    )
    def chat_with_bot(user_input, chat_history, uploaded_files_state):
        """Answer strictly from the indexed files (RAG), trying each index in
        turn and stopping at the first one that yields an answer."""
        if not user_input:
            return user_input, chat_history

        _, indexed_files = uploaded_files_state

        chat_history.append(
            {
                "role": "user",
                "content": user_input,
            }
        )

        response = "I do not have the answer. Please upload and index relevant files first."
        file_with_answer = None
        # Prompt forbids outside knowledge so the "no answer" sentinel below
        # is reliable enough to decide whether to try the next index.
        custom_prompt = PromptTemplate(
            template=(
                "Use the following context to answer the query. Do not use outside knowledge. "
                "If the answer is not found in the context, respond with: 'I do not have the answer.'\n"
                "Context: {context_str}\n"
                "Query: {query_str}\n"
                "Answer:"
            )
        )

        # BUGFIX: original referenced undefined name `index_files` (NameError).
        if not indexed_files:
            response = "No files have been indexed for answering this question."

        try:
            for index_path in indexed_files:
                print('checking ', index_path)
                file_name = os.path.basename(index_path)
                index_name = file_name.split(".")[0]

                if not os.path.exists(index_path):
                    print(f"FAISS index not found at {index_path}, skipping...")
                    continue

                storage_context = None
                try:
                    faiss_index = faiss.read_index(index_path)
                    embed_model = HuggingFaceEmbedding(
                        model_name="BAAI/bge-small-en-v1.5"
                    )
                    Settings.embed_model = embed_model

                    vector_store = FaissVectorStore(faiss_index=faiss_index)
                    storage_context = StorageContext.from_defaults(
                        persist_dir=f'{PERSIST_DIR}/{index_name}', vector_store=vector_store
                    )

                except Exception as e:
                    raise RuntimeError(
                        f"Failed to load FAISS index at {index_path}: {str(e)}"
                    )

                index = load_index_from_storage(storage_context)
                print(f"Index loaded with {len(index.docstore.docs)} documents.")

                retriever = index.as_retriever(similarity_top_k=10)
                query_engine = RetrieverQueryEngine(retriever=retriever)
                query_engine.update_prompts(
                    {"response_synthesizer:text_qa_template": custom_prompt}
                )

                # Query this file's index for the user input.
                query_result = query_engine.query(user_input)
                print("query result: ", query_result)

                if query_result.response.strip() != "I do not have the answer.":
                    response = f"{query_result.response} \n\n Source: {file_name}"
                    file_with_answer = file_name
                    break

                else:
                    response = "I do not have the answer."

        except Exception as e:
            response = f"Error querying the index: {str(e)}"
            print(response)

        chat_history.append(
            {
                "role": "assistant",
                "content": response,
            }
        )

        return gr.update(value=""), chat_history

    @index_button.click(
        inputs=[file_checkbox, uploaded_files_state, indexed_file_table],
        outputs=[
            file_checkbox,
            uploaded_files_state,
            indexed_file_table,
            select_all_button,
        ],
    )
    def index_files(selected_files, uploaded_files_state, indexed_file_table):
        """Index the selected uploaded files, skipping already-indexed ones."""
        uploaded_files, indexed_files = uploaded_files_state
        print("indexing files...", selected_files, uploaded_files_state)

        if not selected_files or not uploaded_files:
            gr.Warning("Please select or upload files for indexing.", duration=3)
            return (
                selected_files,
                uploaded_files_state,
                indexed_file_table,
                gr.update(),
            )

        files_to_index = []
        for file in selected_files:
            file_name_we = file.split(". ")[1]
            file_path = os.path.join(UPLOAD_DIR, file_name_we)
            index_name = file_name_we.split(".")[0]
            index_directory = os.path.join(PERSIST_DIR, index_name)
            index_path = os.path.join(index_directory, f'{index_name}.faiss')

            if index_path not in indexed_files:
                files_to_index.append(file_path)
            else:
                gr.Info(
                    f"File '{os.path.basename(file_path)}' is already indexed.",
                    duration=3,
                )

        for file_path in files_to_index:
            try:
                file_name = os.path.basename(file_path)
                index_name = file_name.split(".")[0]
                index_path = index_gen(file_path, index_name)

                # BUGFIX: index_gen returns None on failure; adding None to
                # indexed_files crashed os.path.getsize later.
                if index_path is None:
                    raise RuntimeError("index generation failed")

                gr.Info(f"Successfully indexed: {file_name}", duration=3)
                indexed_files.add(index_path)

            except Exception as e:
                # BUGFIX: bare gr.Error(...) constructs an exception without
                # raising it — use a warning toast so the user actually sees it.
                gr.Warning(f"Error indexing {file_path}: {str(e)}", duration=3)

        # Update the persisted state with the new indexed files.
        save_uploaded_files_state(uploaded_files, indexed_files)

        return (
            gr.update(value=[]),
            (uploaded_files, indexed_files),
            _indexed_rows(indexed_files),
            gr.update(value="Select All"),
        )

    @select_all_button.click(
        inputs=[uploaded_files_state, select_all_button, file_checkbox],
        outputs=[file_checkbox, select_all_button],
    )
    def select_all_checkbox(uploaded_files_state, select_all_button, file_checkbox):
        """Toggle between selecting every file and clearing the selection."""
        uploaded_files, _ = uploaded_files_state

        if not uploaded_files:
            return file_checkbox, select_all_button

        # The button's current label decides the toggle direction.
        if select_all_button == "Select All":
            button_value = "Unselect All"
        else:
            button_value = "Select All"

        checkbox_options = []
        if not file_checkbox:
            checkbox_options = [
                f"{idx + 1}. {os.path.basename(file)}"
                for idx, file in enumerate(uploaded_files)
            ]

        return gr.update(value=checkbox_options), gr.update(value=button_value)

    # Load initial state when app starts.
    @demo.load(
        inputs=[uploaded_files_state],
        outputs=[file_table, file_checkbox, uploaded_files_state, indexed_file_table],
    )
    def load_state_on_start(uploaded_files_state):
        """Re-read persisted state from disk and populate both tables."""
        uploaded_files, indexed_files = load_uploaded_files_state()

        print("demo loading...", uploaded_files, indexed_files)

        file_info, checkbox_options = _file_rows_and_choices(uploaded_files)

        return (
            file_info,
            gr.update(choices=checkbox_options),
            (uploaded_files, indexed_files),
            _indexed_rows(indexed_files),
        )

    @file_upload.upload(
        inputs=[file_upload, uploaded_files_state],
        outputs=[file_table, file_checkbox, file_upload, uploaded_files_state],
    )
    def upload_files(file_upload, uploaded_files_state):
        """Copy newly uploaded files into UPLOAD_DIR and refresh the UI."""
        uploaded_files, indexed_files = uploaded_files_state

        for file_path in file_upload:
            server_save_path = save_file(file_path)
            if server_save_path:
                uploaded_files.add(server_save_path)

        # BUGFIX: pass indexed_files too, so saving the upload list does not
        # drop the indexed-file set from the persisted state file.
        save_uploaded_files_state(uploaded_files, indexed_files)

        file_info, checkbox_options = _file_rows_and_choices(uploaded_files)

        gr.Info("Successfully uploaded file(s).", duration=3)

        return (
            file_info,
            gr.update(choices=checkbox_options),
            [],
            (uploaded_files, indexed_files),
        )

if __name__ == "__main__":
    demo.launch()