File size: 4,687 Bytes
f768714
 
e90b7cb
ad21633
e90b7cb
caee110
 
ee4f5d4
 
 
 
ad21633
 
ee4f5d4
 
 
 
 
 
 
 
41bbc28
ad21633
ee4f5d4
 
e90b7cb
ee4f5d4
ad21633
e90b7cb
 
 
ad21633
e7c8f2f
e90b7cb
 
ee4f5d4
ad21633
 
 
 
 
ee4f5d4
 
e90b7cb
ad21633
 
f768714
e90b7cb
f768714
b2a3594
 
ee4f5d4
b2a3594
ee4f5d4
b2a3594
ee4f5d4
f768714
 
e90b7cb
ad21633
e90b7cb
 
 
 
 
 
 
 
28c2469
e90b7cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28c2469
e90b7cb
 
 
28c2469
e90b7cb
f768714
ee4f5d4
e7c8f2f
ee4f5d4
e90b7cb
ee4f5d4
 
 
 
e90b7cb
ee4f5d4
 
e90b7cb
 
 
 
ee4f5d4
 
 
 
 
 
 
e90b7cb
 
de4638e
ee4f5d4
f768714
ee4f5d4
f768714
 
e90b7cb
 
e7c8f2f
e90b7cb
ee4f5d4
f768714
ee4f5d4
 
 
 
f768714
e7c8f2f
ee4f5d4
e7c8f2f
ee4f5d4
 
f768714
 
e90b7cb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import os
import gradio as gr
import traceback

# ---------------- LangChain (STABLE 0.1.x) ----------------
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.retrievers import EnsembleRetriever

# Providers
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings

# Community
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import (
    PyPDFLoader,
    TextLoader,
    Docx2txtLoader
)
from langchain_community.retrievers import BM25Retriever

from langchain_text_splitters import RecursiveCharacterTextSplitter

# ---------------- CONFIG ----------------
# API key is read from the environment secret "GROQ_API"; process_files
# returns an error status when it is unset.
GROQ_API_KEY = os.getenv("GROQ_API")

# Prompt that restricts the LLM to the retrieved context only, with a fixed
# refusal sentence when the answer is absent.  Injected into the chain via
# combine_docs_chain_kwargs in process_files.
STRICT_PROMPT = PromptTemplate(
    template="""
You are a strict document-based assistant.

Rules:
1. ONLY use the provided context.
2. If the answer is not in the context, say:
"I'm sorry, but the provided documents do not contain information to answer this question."

Context:
{context}

Question: {question}

Answer:
""",
    input_variables=["context", "question"]
)

# ---------------- FILE LOADER ----------------
def load_any(path: str):
    """Load one document from *path*, dispatching on its (case-insensitive)
    file extension.

    Supports .pdf, .txt (decoded as UTF-8), and .docx.  Any other
    extension returns an empty list rather than raising.
    """
    lowered = path.lower()
    if lowered.endswith(".docx"):
        return Docx2txtLoader(path).load()
    if lowered.endswith(".txt"):
        return TextLoader(path, encoding="utf-8").load()
    if lowered.endswith(".pdf"):
        return PyPDFLoader(path).load()
    return []

# ---------------- BUILD CHAIN ----------------
def process_files(files, response_length):
    """Build a ConversationalRetrievalChain over the uploaded documents.

    Parameters
    ----------
    files : list | None
        Uploaded file objects / paths from the Gradio File component.
    response_length : int | float
        Maximum number of tokens the LLM may generate.

    Returns
    -------
    tuple
        (chain, status_message) — chain is None whenever building fails.
    """
    if not files:
        return None, "❌ No files uploaded"
    if not GROQ_API_KEY:
        return None, "❌ GROQ_API secret not set"

    try:
        docs = []
        for f in files:
            # Gradio may hand back tempfile wrappers; str() yields the path.
            docs.extend(load_any(str(f)))

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            chunk_overlap=100
        )
        chunks = splitter.split_documents(docs)

        # Guard: FAISS.from_documents fails with an opaque error on an
        # empty corpus (e.g. all files had unsupported extensions).
        if not chunks:
            return None, "❌ No readable text found in the uploaded files"

        embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

        # Hybrid retrieval: dense (FAISS) + sparse keyword (BM25),
        # equally weighted, top-3 each.
        faiss_db = FAISS.from_documents(chunks, embeddings)
        faiss_retriever = faiss_db.as_retriever(search_kwargs={"k": 3})

        bm25 = BM25Retriever.from_documents(chunks)
        bm25.k = 3

        retriever = EnsembleRetriever(
            retrievers=[faiss_retriever, bm25],
            weights=[0.5, 0.5]
        )

        llm = ChatGroq(
            groq_api_key=GROQ_API_KEY,
            model="llama-3.3-70b-versatile",
            temperature=0,
            max_tokens=int(response_length)
        )

        memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer"
        )

        chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            memory=memory,
            combine_docs_chain_kwargs={"prompt": STRICT_PROMPT},
            return_source_documents=True,
            output_key="answer"
        )

        return chain, "✅ Chatbot built successfully"

    except Exception as e:
        # traceback is imported at module level — keep the full stack in
        # the server log while returning a short message to the UI.
        traceback.print_exc()
        return None, f"❌ {repr(e)}"


# ---------------- CHAT ----------------
def chat_function(message, history, chain):
    """Answer *message* with the retrieval chain, appending source filenames.

    Parameters
    ----------
    message : str
        The user's question.
    history : list
        Prior chat turns supplied by gr.ChatInterface.
    chain : object | None
        The chain built by process_files; None until the user clicks Build.

    Returns
    -------
    str
        The answer, with a de-duplicated, sorted source listing appended
        when source documents were returned.
    """
    if chain is None:
        return "⚠️ Build the chatbot first"

    result = chain.invoke({
        "question": message,
        "chat_history": history
    })

    answer = result["answer"]

    # De-duplicate source paths down to base filenames.
    sources = {
        os.path.basename(
            d.metadata.get("source", d.metadata.get("file_path", "unknown"))
        )
        for d in result.get("source_documents", [])
    }

    if sources:
        # sorted() makes the listing deterministic — bare set iteration
        # order is hash-randomized across interpreter runs.
        answer += "\n\n---\n**Sources:** " + ", ".join(sorted(sources))

    return answer

# ---------------- UI ----------------
# NOTE: the theme must be set on gr.Blocks — gr.Blocks.launch() has no
# `theme` parameter, so passing it there is either rejected or ignored.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("Multi-RAG Chatbot")

    # Holds the built chain between the Build click and subsequent chat turns.
    chain_state = gr.State(None)

    with gr.Row():
        with gr.Column(scale=1):
            files = gr.File(file_count="multiple", label="Upload Documents")
            tokens = gr.Slider(100, 4000, value=1000, step=100, label="Max Tokens")
            build = gr.Button("Build Chatbot", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)

        with gr.Column(scale=2):
            # ChatInterface calls fn(message, history, *additional_inputs).
            gr.ChatInterface(
                fn=chat_function,
                additional_inputs=[chain_state]
            )

    build.click(
        process_files,
        inputs=[files, tokens],
        outputs=[chain_state, status]
    )

if __name__ == "__main__":
    demo.launch()