File size: 2,978 Bytes
52f7b28
21a58c6
 
 
52f7b28
5b06b42
 
52f7b28
5b06b42
 
 
52f7b28
21a58c6
 
 
 
5b06b42
52f7b28
 
 
 
 
 
 
 
 
 
 
 
21a58c6
5b06b42
52f7b28
 
 
5b06b42
21a58c6
5b06b42
52f7b28
21a58c6
52f7b28
 
21a58c6
 
5b06b42
21a58c6
5b06b42
 
21a58c6
 
 
 
 
5b06b42
 
 
21a58c6
 
52f7b28
 
 
 
 
 
5b06b42
21a58c6
 
5b06b42
 
21a58c6
 
5b06b42
 
21a58c6
 
 
52f7b28
5b06b42
 
 
21a58c6
52f7b28
 
 
 
21a58c6
52f7b28
5b06b42
52f7b28
5b06b42
21a58c6
52f7b28
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import gradio as gr
from transformers import pipeline

# specific imports to fix "ModuleNotFoundError"
from langchain.chains import RetrievalQA
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline

# ------------------ LOAD EMBEDDINGS ------------------
# Small, fast sentence-transformer used for both indexing and querying;
# must match the model used by the ingest script that built the index.
_EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
embeddings = HuggingFaceEmbeddings(model_name=_EMBEDDING_MODEL_NAME)

# ------------------ LOAD VECTOR STORE ------------------
# Guard against a missing index folder so the app can still start.
if os.path.exists("vectorstore/faiss_index"):
    # The deserialization flag is required because LangChain persists
    # FAISS indexes via pickle; only load indexes you created yourself.
    db = FAISS.load_local(
        "vectorstore/faiss_index",
        embeddings,
        allow_dangerous_deserialization=True
    )
else:
    print("❌ ERROR: 'vectorstore/faiss_index' folder not found.")
    print("   Please run your ingest/indexing script first to create the database.")
    # Fall back to a throwaway one-document index so the UI still comes up.
    db = FAISS.from_texts(["Empty index"], embeddings)

# ------------------ LOAD LLM ------------------
# microsoft/phi-2 is the generator model.
# WARNING: If the Space crashes with "OOM" (Out of Memory), change this to "google/flan-t5-small"
print("Loading Model...")

# Generation settings gathered in one place for easy tuning.
_generation_options = {
    "max_new_tokens": 256,  # modest budget keeps memory use down
    "temperature": 0.2,
    "do_sample": True,
    "truncation": True,
}
text_gen_pipeline = pipeline(
    "text-generation",
    model="microsoft/phi-2",
    **_generation_options,
)

# Wrap the raw transformers pipeline so LangChain can drive it.
llm = HuggingFacePipeline(pipeline=text_gen_pipeline)

# ------------------ RAG CHAIN ------------------
# Retrieve the top-3 matching chunks and "stuff" them into a single prompt.
_retriever = db.as_retriever(search_kwargs={"k": 3})
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=_retriever,
)

# ------------------ CHAT FUNCTION ------------------
def chat(user_message, history):
    """Answer *user_message* from the indexed documents and append the turn.

    Args:
        user_message: Raw text from the Gradio textbox.
        history: List of (user, bot) tuples shown in the Chatbot, or None
            on the very first turn.

    Returns:
        The updated history as a new list; the caller's list is not mutated.
    """
    # Work on a copy so Gradio state is never mutated in place, and
    # tolerate a None history (Gradio's initial component value).
    history = list(history or [])

    # Ignore blank / whitespace-only submissions.
    if not user_message.strip():
        return history

    try:
        # RetrievalQA.run() is deprecated; invoke() is the supported entry
        # point and returns a dict with the answer under "result".
        result = qa_chain.invoke({"query": user_message})
        answer = result["result"] if isinstance(result, dict) else result
    except Exception as e:
        # Surface backend failures in the chat instead of crashing the UI.
        answer = f"Error generating answer: {str(e)}"

    history.append((user_message, answer))
    return history

# ------------------ GRADIO UI ------------------
with gr.Blocks(title="Document RAG Chatbot") as demo:
    gr.Markdown(
        """
        # 📚 Document RAG Chatbot  
        Answers are generated **strictly from the provided documents** using Retrieval-Augmented Generation.
        """
    )

    # Conversation display plus a single question input.
    chatbot = gr.Chatbot(height=400)
    query = gr.Textbox(
        label="Ask a question",
        placeholder="Ask something from the documents..."
    )

    with gr.Row():
        submit_btn = gr.Button("Submit", variant="primary")
        clear_btn = gr.Button("Clear Chat")

    # Wire up the events. After a question is answered we also clear the
    # input textbox — previously the typed text lingered after submission.
    query.submit(chat, [query, chatbot], chatbot).then(lambda: "", None, query)
    submit_btn.click(chat, [query, chatbot], chatbot).then(lambda: "", None, query)
    clear_btn.click(lambda: [], None, chatbot)

# Start the Gradio server only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()