File size: 5,089 Bytes
0827021
 
b3de77b
0827021
b3de77b
 
 
 
 
0827021
 
 
 
 
b3de77b
0827021
 
b3de77b
0827021
 
 
 
 
 
 
 
 
 
 
 
 
b3de77b
0827021
 
 
 
b3de77b
0827021
 
 
 
 
b3de77b
0827021
 
 
 
2f81d82
 
 
 
 
 
b74ff81
 
 
 
 
 
 
 
 
 
2f81d82
 
 
 
7024dee
 
 
b74ff81
 
 
 
 
 
 
7024dee
 
 
 
0827021
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7024dee
 
 
0827021
 
 
 
 
 
 
 
b74ff81
0827021
 
 
 
2f81d82
 
b3de77b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import atexit
import glob
import os
import shutil

import gradio as gr
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

from src.config import COLLECTION_NAME, DOCS_DIR, EMBEDDING_MODEL_NAME, MILVUS_DB_PATH
from src.data_loader import load_data
from src.embedding_generator import (
    generate_document_embeddings,
    generate_query_embeddings,
)
from src.rag_pipeline import answer_question
from src.vector_store import (
    create_collection_if_not_exists,
    get_milvus_client,
    insert_data,
    search,
)

# Initialize models and clients once at import time; the functions below use
# these module-level instances.
# NOTE(review): trust_remote_code=True presumably allows code shipped with the
# model repository to run while loading — confirm EMBEDDING_MODEL_NAME is trusted.
embedding_model = HuggingFaceEmbedding(
    model_name=EMBEDDING_MODEL_NAME,
    trust_remote_code=True,
    cache_folder=".hf_cache",
)

# NOTE(review): later code checks `if milvus_client` before using it, so this
# call presumably may return a falsy value — confirm against get_milvus_client.
milvus_client = get_milvus_client(MILVUS_DB_PATH)


# --- Cleanup Function ---
def cleanup_documents():
    """Delete every regular file directly inside DOCS_DIR.

    Subdirectories (and anything else that is not a regular file) are left
    in place. Progress is reported on stdout.
    """
    print("Cleaning up uploaded documents...")
    pattern = os.path.join(DOCS_DIR, "*")
    for entry in glob.glob(pattern):
        # Only plain files are removed; skip directories and the like.
        if not os.path.isfile(entry):
            continue
        os.remove(entry)
    print("Cleanup complete.")


# Register the cleanup function so uploaded documents are removed at normal
# interpreter exit. atexit handlers are not run when the process is killed by
# an unhandled signal or terminated via os._exit().
atexit.register(cleanup_documents)


def reset_collection_if_no_docs():
    """Drop the Milvus collection at startup when DOCS_DIR holds no files.

    Creates DOCS_DIR if it is missing. Every failure is caught and printed,
    so this function never raises and always returns None.
    """
    try:
        os.makedirs(DOCS_DIR, exist_ok=True)
        entries = glob.glob(os.path.join(DOCS_DIR, "*"))
        # Nothing to do when documents exist or no client is available.
        if any(os.path.isfile(entry) for entry in entries) or not milvus_client:
            return

        # Avoid blocking on exit; wrap Milvus operations defensively
        try:
            if not milvus_client.has_collection(COLLECTION_NAME):
                return
            milvus_client.drop_collection(COLLECTION_NAME)
            print(
                f"No documents found. Dropped existing collection {COLLECTION_NAME}."
            )
        except Exception as inner_e:
            print(f"Skip dropping collection on startup due to error: {inner_e}")
    except Exception as e:
        print(f"Error resetting collection on startup: {e}")


def reset_index():
    """Drop the Milvus collection COLLECTION_NAME if it currently exists.

    Does nothing when no client is available. Errors are caught and printed
    rather than raised; the function always returns None.
    """
    try:
        if not milvus_client:
            return

        try:
            collection_exists = milvus_client.has_collection(COLLECTION_NAME)
            if collection_exists:
                milvus_client.drop_collection(COLLECTION_NAME)
                print(f"Dropped collection {COLLECTION_NAME}.")
        except Exception as inner_e:
            print(f"Skip dropping collection due to error: {inner_e}")
    except Exception as e:
        print(f"Error dropping collection during cleanup: {e}")


def index_documents(file_list):
    """Copy uploaded files into DOCS_DIR, embed their text, and insert it into Milvus.

    Args:
        file_list: Uploaded files as delivered by the Gradio ``File``
            component. Each item may be a path string / ``os.PathLike`` or a
            file wrapper object that exposes its path as ``.name``.

    Returns:
        A human-readable status message for the UI.
    """
    if not file_list:
        return "No files to index."

    os.makedirs(DOCS_DIR, exist_ok=True)

    # Copy (not move) each upload into the documents directory. Depending on
    # how the File component is configured, items arrive either as plain
    # paths or as objects carrying the path in `.name`; accept both.
    for file in file_list:
        if isinstance(file, (str, os.PathLike)):
            source_path = os.fspath(file)
        else:
            source_path = file.name
        shutil.copy(source_path, os.path.join(DOCS_DIR, os.path.basename(source_path)))

    # NOTE(review): this reads the whole directory, so files left over from
    # earlier uploads are presumably embedded and inserted again on every
    # call — confirm whether load_data/insert_data de-duplicate.
    docs = load_data(DOCS_DIR)
    documents = [doc.text for doc in docs]

    if not documents:
        return "No documents found in the uploaded files."

    binary_embeddings = generate_document_embeddings(documents, embedding_model)
    if not binary_embeddings:
        return "Could not generate embeddings for the documents."

    # The collection dimension is 8x the length of one embedding
    # (presumably each element is a packed byte of 8 bits — TODO confirm).
    dim = len(binary_embeddings[0]) * 8

    create_collection_if_not_exists(milvus_client, COLLECTION_NAME, dim)

    data_to_insert = [
        {"context": context, "binary_vector": binary_embedding}
        for context, binary_embedding in zip(documents, binary_embeddings)
    ]
    insert_data(milvus_client, COLLECTION_NAME, data_to_insert)

    return f"Successfully indexed {len(documents)} documents."


def chat_interface(message, history):
    """Answer a chat message using contexts retrieved from the vector store.

    Args:
        message: The user's question.
        history: Prior conversation turns supplied by the chat UI (unused).

    Returns:
        The generated answer, or a fallback message when the query cannot be
        embedded or no contexts are found.
    """
    embedded_query = generate_query_embeddings(message, embedding_model)
    if not embedded_query:
        return "Sorry, I could not process your query."

    retrieved = search(milvus_client, COLLECTION_NAME, embedded_query)
    if retrieved:
        return answer_question(message, retrieved)
    return "I couldn't find any relevant information in the documents."


# Gradio UI: one tab for uploading/indexing documents, one tab for chat.
with gr.Blocks() as demo:
    gr.Markdown("## RAG with Binary Quantization")

    with gr.Tab("Upload & Index"):
        file_input = gr.File(file_count="multiple", label="Upload Documents")
        index_button = gr.Button("Update Index")
        index_status = gr.Textbox(label="Indexing Status")

        reset_index_button = gr.Button("Reset Index")
        reset_index_status = gr.Textbox(label="Resetting Index Status")

    with gr.Tab("Chat"):
        gr.ChatInterface(chat_interface)

    def _reset_index_with_status():
        """Run reset_index() and return a message for the status textbox."""
        # reset_index() reports its outcome on stdout only, so the UI can
        # merely confirm that the request was executed.
        reset_index()
        return "Index reset requested. Check the server log for the outcome."

    # Event wiring sits at the Blocks level (after all components exist)
    # instead of being nested inside the Chat tab.
    index_button.click(
        fn=index_documents,
        inputs=[file_input],
        outputs=[index_status],
    )
    # Previously the reset status textbox was never connected to any event
    # and therefore stayed empty; route the handler's message into it.
    # api_name keeps the endpoint name the original binding exposed.
    reset_index_button.click(
        fn=_reset_index_with_status,
        inputs=[],
        outputs=[reset_index_status],
        api_name="reset_index",
    )

# Script entry point: prepare on-disk state, then start the Gradio app.
if __name__ == "__main__":
    # Ensure the documents directory exists from the start
    os.makedirs(DOCS_DIR, exist_ok=True)
    # Reset collection state if there are no documents at startup
    reset_collection_if_no_docs()
    # Launch the Gradio app defined above.
    demo.launch()