# NOTE: HuggingFace Spaces page residue ("Spaces: Sleeping") removed from this
# extracted source; the runnable module begins at the imports below.
| import gradio as gr | |
| import os | |
| import atexit | |
| import glob | |
| import shutil | |
| from src.config import ( | |
| DOCS_DIR, | |
| COLLECTION_NAME, | |
| EMBEDDING_MODEL_NAME, | |
| MILVUS_DB_PATH, | |
| ) | |
| from src.data_loader import load_data | |
| from src.embedding_generator import ( | |
| generate_document_embeddings, | |
| generate_query_embeddings, | |
| ) | |
| from src.vector_store import ( | |
| get_milvus_client, | |
| create_collection_if_not_exists, | |
| insert_data, | |
| search, | |
| ) | |
| from src.rag_pipeline import answer_question | |
| from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
# --- Module-level initialization ---
# Both objects are created at import time and shared by every request.
embedding_model = HuggingFaceEmbedding(
    model_name=EMBEDDING_MODEL_NAME,
    # The embedding repo may ship custom modeling code; required for such models.
    trust_remote_code=True,
    # Keep downloaded weights in a local cache folder next to the app.
    cache_folder=".hf_cache",
)
# Milvus Lite client backed by the local DB file configured in src.config.
milvus_client = get_milvus_client(MILVUS_DB_PATH)
| # --- Cleanup Function --- | |
| def cleanup_documents(): | |
| """Remove all files from the documents directory.""" | |
| print("Cleaning up uploaded documents...") | |
| files = glob.glob(os.path.join(DOCS_DIR, '*')) | |
| for f in files: | |
| if os.path.isfile(f): | |
| os.remove(f) | |
| print("Cleanup complete.") | |
# Register the cleanup function to run on interpreter exit, so uploaded
# documents do not persist between app runs.
atexit.register(cleanup_documents)
def index_documents(file_list):
    """Copy uploaded files into DOCS_DIR, embed them, and store in Milvus.

    Returns a human-readable status string for the Gradio status textbox.
    """
    if not file_list:
        return "No files to index."

    os.makedirs(DOCS_DIR, exist_ok=True)
    # Stage every upload inside the documents directory before loading.
    for upload in file_list:
        destination = os.path.join(DOCS_DIR, os.path.basename(upload.name))
        shutil.copy(upload.name, destination)

    texts = [doc.text for doc in load_data(DOCS_DIR)]
    if not texts:
        return "No documents found in the uploaded files."

    embeddings = generate_document_embeddings(texts, embedding_model)
    if not embeddings:
        return "Could not generate embeddings for the documents."

    # Binary vectors pack 8 dimensions per byte, hence the * 8.
    vector_dim = 8 * len(embeddings[0])
    create_collection_if_not_exists(milvus_client, COLLECTION_NAME, vector_dim)

    rows = []
    for text, vector in zip(texts, embeddings):
        rows.append({"context": text, "binary_vector": vector})
    insert_data(milvus_client, COLLECTION_NAME, rows)

    return f"Successfully indexed {len(texts)} documents."
def chat_interface(message, history):
    """Answer a chat message via the RAG pipeline; returns the reply text."""
    embedded_query = generate_query_embeddings(message, embedding_model)
    if not embedded_query:
        return "Sorry, I could not process your query."

    retrieved = search(milvus_client, COLLECTION_NAME, embedded_query)
    if not retrieved:
        return "I couldn't find any relevant information in the documents."

    return answer_question(message, retrieved)
# --- Gradio UI ---
# Two tabs: one for uploading/indexing documents, one for chatting over them.
with gr.Blocks() as demo:
    gr.Markdown("## RAG with Binary Quantization")
    with gr.Tab("Upload & Index"):
        file_input = gr.File(file_count="multiple", label="Upload Documents")
        index_button = gr.Button("Update Index")
        index_status = gr.Textbox(label="Indexing Status")
    with gr.Tab("Chat"):
        gr.ChatInterface(chat_interface)
    # Wire the button: index the uploaded files and report status in the textbox.
    index_button.click(
        fn=index_documents,
        inputs=[file_input],
        outputs=[index_status],
    )
if __name__ == "__main__":
    # Ensure the documents directory exists from the start
    os.makedirs(DOCS_DIR, exist_ok=True)
    # Launch the Gradio server (blocks until shutdown).
    demo.launch()