File size: 4,703 Bytes
f768714
 
ad21633
 
e64e86b
 
ad21633
 
 
 
 
41bbc28
 
ad21633
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41bbc28
ad21633
 
 
 
f768714
ad21633
f768714
 
 
 
 
 
 
ad21633
 
41bbc28
ad21633
f768714
 
 
 
 
 
ad21633
f768714
 
ad21633
f768714
41bbc28
 
ad21633
41bbc28
 
f768714
41bbc28
 
ad21633
f768714
41bbc28
ad21633
 
 
 
 
 
 
f768714
41bbc28
 
f768714
 
41bbc28
f768714
 
ad21633
 
f768714
 
 
 
41bbc28
ad21633
f768714
 
 
 
ad21633
f768714
 
41bbc28
f768714
41bbc28
 
f768714
41bbc28
ad21633
41bbc28
ad21633
41bbc28
ad21633
 
 
f768714
 
 
 
ad21633
 
 
41bbc28
f768714
 
41bbc28
f768714
ad21633
f768714
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
import gradio as gr

# Classic & Community Imports
from langchain_classic.chains import ConversationalRetrievalChain
from langchain_classic.memory import ConversationBufferMemory
from langchain_groq import ChatGroq
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.prompts import PromptTemplate

# --- 1. SETUP API & SYSTEM PROMPT ---
# Hugging Face uses os.getenv for secrets
# NOTE: api_key is None when the GROQ_API secret is unset; process_files()
# checks this and surfaces a status message instead of failing later.
api_key = os.getenv("GROQ_API")

# Prompt that restricts the LLM to the retrieved context only ("closed-book"
# answering). {context} and {question} are filled in by the retrieval chain.
STRICT_PROMPT_TEMPLATE = """You are a strict document-based assistant. 
Use the following pieces of context to answer the user's question. 

RESTRICTIONS:
1. ONLY use the information provided in the context below.
2. If the answer is not contained within the context, specifically say: "I'm sorry, but the provided documents do not contain information to answer this question."
3. Do NOT use your own outside knowledge.

Context:
{context}

Question: {question}
Helpful Answer:"""

# Wrapped as a PromptTemplate so it can be injected into the chain via
# combine_docs_chain_kwargs in process_files().
STRICT_PROMPT = PromptTemplate(
    template=STRICT_PROMPT_TEMPLATE, 
    input_variables=["context", "question"]
)

# --- 2. LOADING LOGIC ---
def load_any(path: str):
    """Load one document from *path*, dispatching on its (case-insensitive) extension.

    Supports .pdf, .txt (read as UTF-8) and .docx. Any other extension
    yields an empty list so unsupported uploads are skipped silently.
    """
    # Lazy lambdas: a loader class is only touched when its suffix matches.
    dispatch = (
        (".pdf", lambda: PyPDFLoader(path).load()),
        (".txt", lambda: TextLoader(path, encoding="utf-8").load()),
        (".docx", lambda: Docx2txtLoader(path).load()),
    )
    lowered = path.lower()
    for suffix, loader in dispatch:
        if lowered.endswith(suffix):
            return loader()
    return []

# --- 3. HYBRID PROCESSING ---
def process_files(files, response_length):
    """Build a strict, hybrid-retrieval conversational chain from the uploads.

    Parameters
    ----------
    files : list of Gradio file objects (each exposing a ``.name`` path), or None.
    response_length : slider value; max completion tokens for the LLM.

    Returns
    -------
    (chain, status) : ``chain`` is a ConversationalRetrievalChain on success,
    ``None`` on failure; ``status`` is the message shown in the UI.
    """
    # Both uploads and the GROQ_API secret are mandatory.
    if not files or not api_key:
        return None, "⚠️ Missing files or GROQ_API key in Secrets."

    try:
        docs = []
        for file_obj in files:
            docs.extend(load_any(file_obj.name))
        
        splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
        chunks = splitter.split_documents(docs)

        # Guard: every upload was unsupported or empty. Without this check,
        # FAISS.from_documents raises a cryptic error on an empty list which
        # would be swallowed into the generic "❌ Error" message below.
        if not chunks:
            return None, "⚠️ No readable text found. Upload .pdf, .txt or .docx files."

        # Hybrid retrieval: dense (FAISS over MiniLM embeddings) + sparse
        # (BM25 keyword matching), merged 50/50 by the ensemble below.
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        faiss_db = FAISS.from_documents(chunks, embeddings)
        faiss_retriever = faiss_db.as_retriever(search_kwargs={"k": 3})
        
        bm25_retriever = BM25Retriever.from_documents(chunks)
        bm25_retriever.k = 3

        ensemble_retriever = EnsembleRetriever(
            retrievers=[faiss_retriever, bm25_retriever],
            weights=[0.5, 0.5]
        )

        # temperature=0 keeps answers deterministic and document-bound,
        # matching the strict prompt's intent.
        llm = ChatGroq(
            groq_api_key=api_key, 
            model="llama-3.3-70b-versatile", 
            temperature=0,
            max_tokens=int(response_length)
        )
        
        # output_key="answer" is required because return_source_documents=True
        # makes the chain emit multiple keys; memory must know which to store.
        memory = ConversationBufferMemory(
            memory_key="chat_history", 
            return_messages=True, 
            output_key="answer"
        )
        
        chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=ensemble_retriever,
            combine_docs_chain_kwargs={"prompt": STRICT_PROMPT},
            memory=memory,
            return_source_documents=True,
            output_key="answer"
        )
        
        return chain, f"✅ Knowledge base built! Max answer length: {response_length} tokens."

    except Exception as e:
        # Surface any build failure in the status box rather than crashing the UI.
        return None, f"❌ Error: {str(e)}"

# --- 4. CHAT FUNCTION ---
def chat_function(message, history, chain):
    """Answer *message* via the retrieval chain, appending a source footer.

    Parameters
    ----------
    message : user question string.
    history : Gradio chat history (unused — the chain keeps its own memory).
    chain : ConversationalRetrievalChain built by process_files(), or None.

    Returns the answer text, followed by a "Sources used" footer when the
    chain actually retrieved documents.
    """
    if not chain:
        return "⚠️ Build the chatbot first."
    
    res = chain.invoke({"question": message})
    answer = res["answer"]
    
    # Deduplicate source filenames while preserving first-seen order.
    # (The previous list(set(...)) shuffled the order run to run.)
    sources = list(dict.fromkeys(
        os.path.basename(d.metadata.get("source", "unknown"))
        for d in res.get("source_documents", [])
    ))
    if not sources:
        # Nothing was retrieved — don't append an empty footer.
        return answer

    return answer + "\n\n----- \n**Sources used:** " + ", ".join(sources)

# --- 5. UI BUILDING ---
# Assemble the two-column Gradio app: controls on the left, chat on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🛡️ Strict Hybrid Multi-RAG")
    # Session-scoped holder for the built chain; None until "Build" succeeds.
    chain_state = gr.State(None)
    
    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(file_count="multiple", label="1. Upload Documents")
            len_slider = gr.Slider(minimum=100, maximum=4000, value=1000, step=100, label="2. Response Length")
            build_btn = gr.Button("3. Build Restricted Chatbot", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)
        
        with gr.Column(scale=2):
            # chain_state is passed to chat_function as its third argument.
            gr.ChatInterface(fn=chat_function, additional_inputs=[chain_state])

    # Building returns (chain, status_message) → (chain_state, status box).
    build_btn.click(process_files, inputs=[file_input, len_slider], outputs=[chain_state, status])

if __name__ == "__main__":
    demo.launch()