# RAG_UI / app.py
# Gradio app implementing a simple RAG (retrieval-augmented generation) UI:
# upload PDF/JSON/TXT files, build a per-file FAISS vector index, and chat
# over the indexed content with an OpenAI LLM restricted to retrieved context.
import gradio as gr
import csv
import random
import os
import shutil
import json
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.core import (
VectorStoreIndex,
SimpleDirectoryReader,
StorageContext,
load_index_from_storage,
)
from llama_index.core.settings import Settings
import faiss
import numpy as np
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.core.node_parser import SimpleNodeParser, SentenceSplitter
from llama_index.core.schema import Document
from llama_index.core.schema import IndexNode
from llama_index.core import ServiceContext
from llama_index.core.query_engine.retriever_query_engine import RetrieverQueryEngine
from llama_index.embeddings.huggingface.base import HuggingFaceEmbedding
from llama_index.llms.openai import OpenAI
from transformers import BitsAndBytesConfig
from llama_index.core.prompts import PromptTemplate
import torch
import pandas as pd
import fitz
from transformers import pipeline
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer
# Silence HuggingFace tokenizers fork-parallelism warnings.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
# LLM used for answer synthesis; presumably expects OPENAI_API_KEY to be set
# in the environment (see the commented line above) — confirm at deploy time.
llm = OpenAI(temperature=0, model="gpt-4o-mini", max_tokens=512)
Settings.llm = llm
# On-disk layout: raw uploaded files, a JSON state file recording what was
# uploaded/indexed, and one directory per built index.
UPLOAD_DIR = "uploaded_files"
STATE_FILE = "uploaded_files_state.json"
PERSIST_DIR = "persisted_indexes"
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(PERSIST_DIR, exist_ok=True)
# NOTE(review): torch thread limiting left disabled — original author was
# unsure why it was there; confirm before re-enabling.
# torch.set_num_threads(1)
# torch.set_num_interop_threads(1)
def index_gen(file_path, index_name):
    """Build a FAISS-backed vector index for one file and persist it to disk.

    Args:
        file_path: Path to the source document (PDF/JSON/TXT).
        index_name: Base name for the per-file index directory and .faiss file.

    Returns:
        Path to the persisted ``<index_name>.faiss`` file on success,
        or ``None`` if any step failed (the error is printed).
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    try:
        # Load the document from disk.
        documents = SimpleDirectoryReader(input_files=[file_path]).load_data()

        # Embedding model + FAISS store. bge-small-en-v1.5 emits
        # 384-dimensional vectors, so the index dimension must match.
        embed_model = HuggingFaceEmbedding(
            model_name="BAAI/bge-small-en-v1.5", device=device
        )
        Settings.embed_model = embed_model
        embedding_dim = 384  # must match the embedding model's output size
        faiss_index = faiss.IndexFlatL2(embedding_dim)
        vector_store = FaissVectorStore(faiss_index=faiss_index)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        print(f"Number of documents to index: {len(documents)}.")

        # Split into sentence-level nodes and build the index.
        parser = SentenceSplitter()
        nodes = parser.get_nodes_from_documents(documents)
        index = VectorStoreIndex(nodes, storage_context=storage_context)
        print(f"Number of nodes generated:{len(nodes)}")

        # Persist both the llama-index storage and the raw FAISS index into a
        # dedicated per-file directory under PERSIST_DIR.
        index_directory = os.path.join(PERSIST_DIR, index_name)
        os.makedirs(index_directory, exist_ok=True)
        index_path = os.path.join(index_directory, f"{index_name}.faiss")
        index.storage_context.persist(index_directory)
        faiss.write_index(faiss_index, index_path)

        if not os.path.exists(index_path):
            raise FileNotFoundError(
                f"FAISS index file not created at path: {index_path}"
            )
        return index_path
    except Exception as e:
        # Best-effort: report the failure and signal it to the caller via None.
        print(f"Error in index_gen with file {file_path}: {str(e)}")
        return None
def save_uploaded_files_state(uploaded_files, indexed_files=None):
    """Persist the uploaded/indexed file sets to STATE_FILE as JSON.

    Args:
        uploaded_files: Iterable of server-side paths of uploaded files.
        indexed_files: Iterable of .faiss index paths, or ``None`` to keep
            whatever is already recorded on disk.

    BUG FIX: previously, calling this with only ``uploaded_files`` wrote a
    state file without an "indexed_files" key, silently erasing the record
    of already-indexed files. Now the existing entry is preserved.
    """
    try:
        state_file_json = {"uploaded_files": list(uploaded_files)}
        if indexed_files is None:
            # Caller did not supply indexed files — keep the ones on disk.
            _, indexed_files = load_uploaded_files_state()
        state_file_json["indexed_files"] = list(indexed_files)
        with open(STATE_FILE, "w") as f:
            json.dump(state_file_json, f, indent=4)
    except IOError as e:
        print(f"Error saving uploaded files state: {str(e)}")
def load_uploaded_files_state():
    """Read the persisted state file.

    Returns:
        A pair ``(uploaded_files, indexed_files)`` of sets of paths; both
        sets are empty when the state file is missing or unreadable.
    """
    try:
        if os.path.exists(STATE_FILE):
            with open(STATE_FILE, "r") as f:
                state_data = json.load(f)
            uploaded = set(state_data.get("uploaded_files", []))
            indexed = set(state_data.get("indexed_files", []))
            return uploaded, indexed
    except (IOError, json.JSONDecodeError) as e:
        print(f"Error loading uploaded files state: {str(e)}")
    # Missing file or read failure: start from a clean slate.
    return set(), set()
def save_file(file_path):
    """Copy an uploaded temp file into UPLOAD_DIR.

    Returns:
        The server-side destination path, or ``None`` if the copy failed.
    """
    destination = os.path.join(UPLOAD_DIR, os.path.basename(file_path))
    try:
        shutil.copy(file_path, destination)
    except (IOError, shutil.Error) as e:
        print(f"Error saving file {file_path}: {str(e)}")
        return None
    return destination
with gr.Blocks() as demo:
    gr.Markdown("## 📁 File Management & Chat Assistant")
    with gr.Tabs():
        # Tab 1: File Management — upload/list/index/delete files on the left,
        # chat assistant over the indexed files on the right.
        with gr.Tab("File Management"):
            with gr.Row():
                with gr.Column(scale=1):
                    file_upload = gr.File(
                        label="Upload PDF,JSON or TXT Files",
                        file_types=[".pdf", ".json", ".txt", "directory"],
                        file_count="multiple",
                        interactive=True,
                    )
                    # Read-only table mirroring the uploaded-files state.
                    file_table = gr.DataFrame(
                        headers=["Sr. No.", "File Name", "File Size"],
                        value=[],
                        interactive=False,
                        row_count=(4, "dynamic"),
                        wrap=True,
                        max_height=1000
                    )
                    # Choices are "N. filename" labels; handlers parse them back.
                    file_checkbox = gr.CheckboxGroup(
                        label="Select Files to Index/Delete", choices=[]
                    )
                    select_all_button = gr.Button("Select All")
                    index_button = gr.Button("Index Selected Files")
                    delete_button = gr.Button("Delete Selected Files")
                with gr.Column(scale=3):
                    message_box = gr.Markdown("")
                    chatbot = gr.Chatbot(label="LLM", type="messages")
                    with gr.Row():
                        chat_input = gr.Textbox(
                            show_label=False,
                            placeholder="Type your message here",
                            scale=8,
                        )
                        send_button = gr.Button("Send", scale=1)
        # Tab 2: Indexed Files — read-only view of persisted indexes.
        with gr.Tab("Indexed Files"):
            indexed_file_table = gr.DataFrame(
                headers=["Indexed File", "Size"],
                value=[],
                interactive=False,
                row_count=(4, "dynamic"),
            )
    # STATES: a (uploaded_files, indexed_files) tuple of sets, seeded from
    # the JSON state file on disk.
    uploaded_files_state = gr.State(load_uploaded_files_state())
@delete_button.click(
inputs=[file_checkbox, uploaded_files_state, file_upload],
outputs=[file_table, file_checkbox, uploaded_files_state, indexed_file_table],
)
def delete_files(selected_files, uploaded_files_state, file_upload):
print("deleting files...: ", selected_files, uploaded_files_state, file_upload)
uploaded_files, indexed_files = uploaded_files_state
if not selected_files or not uploaded_files:
return gr.update(), selected_files, (uploaded_files, indexed_files)
# default return
# return [[]], selected_files, uploaded_files_state
# "we" means with extension
selected_file_names_we = [file.split(". ")[1] for file in selected_files]
for file_name_we in selected_file_names_we:
file_path = os.path.join(UPLOAD_DIR, file_name_we)
index_name = file_name_we.split(".")[0]
index_directory = os.path.join(PERSIST_DIR, index_name)
index_path = os.path.join(index_directory, f'{index_name}.faiss')
print(file_name_we, file_path, index_name, index_directory, index_path)
try:
if os.path.exists(file_path):
os.remove(file_path)
uploaded_files.remove(file_path)
else:
gr.Error(f"Could not delete file (File not found): {file_path}", duration=3)
if os.path.exists(index_directory):
shutil.rmtree(index_directory)
indexed_files.remove(index_path)
else:
gr.Error(f"Could not delete index directory (Path not found): {index_directory}", duration=3)
except Exception as e:
gr.Error(f"Error deleting {file_name_we}: {str(e)}", duration=3)
save_uploaded_files_state(uploaded_files, indexed_files)
file_info, checkbox_options = [], []
for idx, file_path in enumerate(uploaded_files, start=1):
file_name = os.path.basename(file_path)
file_size = os.path.getsize(file_path)
file_info.append([idx, file_name, f"{round(file_size / 1024, 2)} KB"])
checkbox_options.append(f"{idx}. {file_name}")
indexed_file_display = [
[
os.path.basename(index_path).split(".")[0],
f"{round(os.path.getsize(index_path) / 1024, 2)} KB",
]
for index_path in indexed_files
]
return (
file_info,
gr.update(choices=checkbox_options, value=[]),
(uploaded_files, indexed_files),
indexed_file_display,
)
@chat_input.submit(
inputs=[chat_input, chatbot, uploaded_files_state],
outputs=[chat_input, chatbot],
)
@send_button.click(
inputs=[chat_input, chatbot, uploaded_files_state],
outputs=[chat_input, chatbot],
)
# Chat function with improved SQuAD matching
def chat_with_bot(user_input, chat_history, uploaded_files_state):
if not user_input:
return user_input, chat_history
_, indexed_files = uploaded_files_state
chat_history.append(
{
"role": "user",
"content": user_input,
}
)
response = "I do not have the answer. Please upload and index relevant files first."
file_with_answer = None
custom_prompt = PromptTemplate(
template=(
"Use the following context to answer the query. Do not use outside knowledge. "
"If the answer is not found in the context, respond with: 'I do not have the answer.'\n"
"Context: {context_str}\n"
"Query: {query_str}\n"
"Answer:"
)
)
if not index_files:
response = "No files have been indexed for answering this question."
try:
for index_path in indexed_files:
print('checking ', index_path)
file_name = os.path.basename(index_path)
index_name = file_name.split(".")[0]
if not os.path.exists(index_path):
print(f"FAISS index not found at {index_path}, skipping...")
continue
storage_context = None
try:
faiss_index = faiss.read_index(index_path)
embed_model = HuggingFaceEmbedding(
model_name="BAAI/bge-small-en-v1.5"
)
Settings.embed_model = embed_model
vector_store = FaissVectorStore(faiss_index=faiss_index)
storage_context = StorageContext.from_defaults(
persist_dir=f'{PERSIST_DIR}/{index_name}', vector_store=vector_store
)
except Exception as e:
raise RuntimeError(
f"Failed to load FAISS index at {index_path}: {str(e)}"
)
# print(get_global("embed_model"))
index = load_index_from_storage(storage_context)
print(f"Index loaded with {len(index.docstore.docs)} documents.")
retriever = index.as_retriever(similarity_top_k=10)
query_engine = RetrieverQueryEngine(retriever=retriever)
query_engine.update_prompts(
{"response_synthesizer:text_qa_template": custom_prompt}
)
# Query the index for the user input
query_result = query_engine.query(user_input)
print("query result: ", query_result)
if query_result.response.strip() != "I do not have the answer.":
response = f"{query_result.response} \n\n Source: {file_name}"
# response = f"Answer from indexed file '{file_name}': {query_result.response}"
file_with_answer = file_name
break
else:
response = "I do not have the answer."
except Exception as e:
response = f"Error querying the index: {str(e)}"
print(response)
chat_history.append(
{
"role": "assistant",
"content": response,
}
)
return gr.update(value=""), chat_history
@index_button.click(
inputs=[file_checkbox, uploaded_files_state, indexed_file_table],
outputs=[
file_checkbox,
uploaded_files_state,
indexed_file_table,
select_all_button,
],
)
def index_files(selected_files, uploaded_files_state, indexed_file_table):
uploaded_files, indexed_files = uploaded_files_state
print("indexing files...", selected_files, uploaded_files_state)
if not selected_files or not uploaded_files:
gr.Warning("Please select or upload files for indexing.", duration=3)
return (
selected_files,
uploaded_files_state,
indexed_file_table,
gr.update(),
)
files_to_index = []
for file in selected_files:
file_name_we = file.split(". ")[1]
file_path = os.path.join(UPLOAD_DIR, file_name_we)
index_name = file_name_we.split(".")[0]
index_directory = os.path.join(PERSIST_DIR, index_name)
index_path = os.path.join(index_directory, f'{index_name}.faiss')
if index_path not in indexed_files:
files_to_index.append(file_path)
else:
gr.Info(
f"File '{os.path.basename(file_path)}' is already indexed.",
duration=3,
)
for file_path in files_to_index:
try:
file_name = os.path.basename(file_path)
index_name = file_name.split(".")[0]
index_path = index_gen(file_path, index_name)
gr.Info(f"Successfully indexed: {file_name}", duration=3)
# Save indexed file info for persistence
# index_path = os.path.join(PERSIST_DIR, f"{index_name}.faiss")
indexed_files.add(index_path)
except Exception as e:
gr.Error(f"Error indexing {file_path}: {str(e)}", duration=3)
# Update the state with new indexed files
save_uploaded_files_state(uploaded_files, indexed_files)
# Convert indexed file info to display format
indexed_file_display = [
[
os.path.basename(index_path).split(".")[0],
f"{round(os.path.getsize(index_path) / 1024, 2)} KB",
]
for index_path in indexed_files
]
return (
gr.update(value=[]),
(uploaded_files, indexed_files),
indexed_file_display,
gr.update(value="Select All"),
)
@select_all_button.click(
inputs=[uploaded_files_state, select_all_button, file_checkbox],
outputs=[file_checkbox, select_all_button],
)
def select_all_checkbox(uploaded_files_state, select_all_button, file_checkbox):
uploaded_files, _ = uploaded_files_state
if not uploaded_files:
return file_checkbox, select_all_button
button_value = ""
if select_all_button == "Select All":
button_value = "Unselect All"
else:
button_value = "Select All"
checkbox_options = []
if not file_checkbox:
checkbox_options = [
f"{idx + 1}. {os.path.basename(file)}"
for idx, file in enumerate(uploaded_files)
]
return gr.update(value=checkbox_options), gr.update(value=button_value)
# Load initial state when app starts
@demo.load(
inputs=[uploaded_files_state],
outputs=[file_table, file_checkbox, uploaded_files_state, indexed_file_table],
)
def load_state_on_start(uploaded_files_state):
uploaded_files, indexed_files = load_uploaded_files_state()
print("demo loading...", uploaded_files, indexed_files)
# Populate uploaded files table and checkbox options
file_info = []
checkbox_options = []
for idx, server_file_path in enumerate(uploaded_files, start=1):
file_name = os.path.basename(server_file_path)
file_size = os.path.getsize(server_file_path)
file_info.append([idx, file_name, f"{round(file_size / 1024, 2)} KB"])
checkbox_options.append(f"{idx}. {file_name}")
# Populate indexed files table
indexed_file_display = [
[
os.path.basename(index_path).split(".")[0],
f"{round(os.path.getsize(index_path) / 1024, 2)} KB",
]
for index_path in indexed_files
]
return (
file_info,
gr.update(choices=checkbox_options),
(uploaded_files, indexed_files),
indexed_file_display,
)
@file_upload.upload(
inputs=[file_upload, uploaded_files_state],
outputs=[file_table, file_checkbox, file_upload, uploaded_files_state],
)
def upload_files(file_upload, uploaded_files_state):
uploaded_files, indexed_files = uploaded_files_state
for file_path in file_upload:
server_save_path = save_file(file_path)
if server_save_path:
uploaded_files.add(server_save_path)
save_uploaded_files_state(uploaded_files)
file_info = []
checkbox_options = []
for i, file_path in enumerate(uploaded_files, start=1):
file_name = os.path.basename(file_path)
file_size = os.path.getsize(file_path)
file_info.append([i, file_name, f"{round(file_size / 1024, 2)} KB"])
checkbox_options.append(f"{i}. {file_name}")
gr.Info("Successfully uploaded file(s).", duration=3)
return (
file_info,
gr.update(choices=checkbox_options),
[],
(uploaded_files, indexed_files),
)
if __name__ == "__main__":
    # Launch the Gradio server (blocking).
    demo.launch()