import gradio as gr
import os
import atexit
import glob
import shutil
from src.config import (
DOCS_DIR,
COLLECTION_NAME,
EMBEDDING_MODEL_NAME,
MILVUS_DB_PATH,
)
from src.data_loader import load_data
from src.embedding_generator import (
generate_document_embeddings,
generate_query_embeddings,
)
from src.vector_store import (
get_milvus_client,
create_collection_if_not_exists,
insert_data,
search,
)
from src.rag_pipeline import answer_question
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
# Initialize models and clients
# Embedding model used for both documents and queries; cached locally so
# repeated startups on the same machine don't re-download the weights.
# NOTE(review): trust_remote_code=True executes code shipped with the model
# repo — acceptable only because EMBEDDING_MODEL_NAME is a fixed, trusted model.
embedding_model = HuggingFaceEmbedding(
    model_name=EMBEDDING_MODEL_NAME,
    trust_remote_code=True,
    cache_folder=".hf_cache",
)
# Single shared Milvus client for the whole app (local/lite DB file path).
milvus_client = get_milvus_client(MILVUS_DB_PATH)
# --- Cleanup Function ---
def cleanup_documents():
    """Remove all regular files from the documents directory.

    Registered via atexit, so it must never raise: an exception here would
    print a traceback during interpreter shutdown. Subdirectories are left
    untouched (only top-level files are uploaded by this app).
    """
    print("Cleaning up uploaded documents...")
    files = glob.glob(os.path.join(DOCS_DIR, '*'))
    for f in files:
        if os.path.isfile(f):
            try:
                os.remove(f)
            except OSError as e:
                # Best-effort cleanup: report and continue with remaining files.
                print(f"Could not remove {f}: {e}")
    print("Cleanup complete.")
# Register the cleanup function to run on exit
atexit.register(cleanup_documents)
def index_documents(file_list):
    """Copy uploaded files into DOCS_DIR, embed them, and index into Milvus.

    Returns a human-readable status string for the Gradio status textbox.
    """
    if not file_list:
        return "No files to index."

    os.makedirs(DOCS_DIR, exist_ok=True)

    # Stage every uploaded temp file into the documents directory under
    # its original basename.
    for uploaded in file_list:
        destination = os.path.join(DOCS_DIR, os.path.basename(uploaded.name))
        shutil.copy(uploaded.name, destination)

    texts = [doc.text for doc in load_data(DOCS_DIR)]
    if not texts:
        return "No documents found in the uploaded files."

    vectors = generate_document_embeddings(texts, embedding_model)
    if not vectors:
        return "Could not generate embeddings for the documents."

    # Binary vectors are packed 8 bits per byte; Milvus wants the bit count.
    bit_dim = len(vectors[0]) * 8
    create_collection_if_not_exists(milvus_client, COLLECTION_NAME, bit_dim)

    rows = [
        {"context": text, "binary_vector": vec}
        for text, vec in zip(texts, vectors)
    ]
    insert_data(milvus_client, COLLECTION_NAME, rows)
    return f"Successfully indexed {len(texts)} documents."
def chat_interface(message, history):
    """Answer a chat message via the RAG pipeline.

    The `history` argument is required by gr.ChatInterface but unused here:
    each turn is answered independently from the retrieved contexts.
    """
    query_vec = generate_query_embeddings(message, embedding_model)
    if not query_vec:
        return "Sorry, I could not process your query."

    retrieved = search(milvus_client, COLLECTION_NAME, query_vec)
    if not retrieved:
        return "I couldn't find any relevant information in the documents."

    return answer_question(message, retrieved)
# --- Gradio UI ---
# Two tabs: one for uploading + indexing documents, one for chatting over them.
with gr.Blocks() as demo:
    gr.Markdown("## RAG with Binary Quantization")
    with gr.Tab("Upload & Index"):
        file_input = gr.File(file_count="multiple", label="Upload Documents")
        index_button = gr.Button("Update Index")
        index_status = gr.Textbox(label="Indexing Status")
    with gr.Tab("Chat"):
        gr.ChatInterface(chat_interface)
    # Wire the index button to the indexing pipeline; status lands in the textbox.
    index_button.click(
        fn=index_documents,
        inputs=[file_input],
        outputs=[index_status],
    )
if __name__ == "__main__":
    # Ensure the documents directory exists from the start
    os.makedirs(DOCS_DIR, exist_ok=True)
    # Fixed stray trailing "|" artifact after launch() (extraction residue /
    # syntax error as previously written).
    demo.launch()