Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| # Classic & Community Imports | |
| from langchain_classic.chains import ConversationalRetrievalChain | |
| from langchain_classic.memory import ConversationBufferMemory | |
| from langchain_groq import ChatGroq | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from langchain_community.retrievers import BM25Retriever | |
| from langchain.retrievers import EnsembleRetriever | |
| from langchain.prompts import PromptTemplate | |
# --- 1. SETUP API & SYSTEM PROMPT ---
# Hugging Face Spaces exposes repository secrets as environment variables.
api_key = os.getenv("GROQ_API")

# Prompt that restricts the LLM to the retrieved context only; the exact
# refusal sentence is part of the app's contract with the user.
STRICT_PROMPT_TEMPLATE = """You are a strict document-based assistant.
Use the following pieces of context to answer the user's question.
RESTRICTIONS:
1. ONLY use the information provided in the context below.
2. If the answer is not contained within the context, specifically say: "I'm sorry, but the provided documents do not contain information to answer this question."
3. Do NOT use your own outside knowledge.
Context:
{context}
Question: {question}
Helpful Answer:"""

STRICT_PROMPT = PromptTemplate(
    input_variables=["context", "question"],
    template=STRICT_PROMPT_TEMPLATE,
)
# --- 2. LOADING LOGIC ---
def load_any(path: str):
    """Load one uploaded file into LangChain documents, dispatching on extension.

    Supports .pdf, .txt (read as UTF-8) and .docx; any other extension
    yields an empty list so unsupported uploads are silently skipped.
    """
    dispatch = (
        (".pdf", lambda p: PyPDFLoader(p)),
        (".txt", lambda p: TextLoader(p, encoding="utf-8")),
        (".docx", lambda p: Docx2txtLoader(p)),
    )
    lowered = path.lower()
    for suffix, make_loader in dispatch:
        if lowered.endswith(suffix):
            return make_loader(path).load()
    return []
# --- 3. HYBRID PROCESSING ---
def process_files(files, response_length):
    """Build the strict hybrid-retrieval chat chain from uploaded documents.

    Parameters
    ----------
    files : list | None
        Gradio file objects (each exposing a ``.name`` filesystem path).
    response_length : int | float
        Slider value used as the LLM's ``max_tokens`` budget.

    Returns
    -------
    tuple
        ``(chain, status)`` — a ConversationalRetrievalChain on success or
        ``None`` on failure, plus a user-facing status message.
    """
    if not files or not api_key:
        return None, "⚠️ Missing files or GROQ_API key in Secrets."
    try:
        # Load every uploaded file; unsupported extensions contribute nothing.
        docs = []
        for file_obj in files:
            docs.extend(load_any(file_obj.name))
        if not docs:
            # Guard: FAISS/BM25 raise cryptic errors on empty input, so fail
            # early with a message the user can act on.
            return None, "❌ Error: no readable text found in the uploaded files."

        splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
        chunks = splitter.split_documents(docs)
        if not chunks:
            return None, "❌ Error: the documents produced no text chunks."

        # Hybrid retrieval: dense semantic search (FAISS) + sparse keyword
        # search (BM25), merged with equal weights.
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        faiss_db = FAISS.from_documents(chunks, embeddings)
        faiss_retriever = faiss_db.as_retriever(search_kwargs={"k": 3})
        bm25_retriever = BM25Retriever.from_documents(chunks)
        bm25_retriever.k = 3
        ensemble_retriever = EnsembleRetriever(
            retrievers=[faiss_retriever, bm25_retriever],
            weights=[0.5, 0.5]
        )
        llm = ChatGroq(
            groq_api_key=api_key,
            model="llama-3.3-70b-versatile",
            temperature=0,  # deterministic, document-grounded answers
            max_tokens=int(response_length)
        )
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer"  # required: the chain returns sources too
        )
        chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=ensemble_retriever,
            combine_docs_chain_kwargs={"prompt": STRICT_PROMPT},
            memory=memory,
            return_source_documents=True,
            output_key="answer"
        )
        return chain, f"✅ Knowledge base built! Max answer length: {response_length} tokens."
    except Exception as e:
        # Surface any build failure in the Status textbox instead of crashing the UI.
        return None, f"❌ Error: {str(e)}"
# --- 4. CHAT FUNCTION ---
def chat_function(message, history, chain):
    """Answer a user message through the retrieval chain built by process_files.

    Parameters
    ----------
    message : str
        The user's question.
    history : list
        Chat history supplied by gr.ChatInterface (unused here — the chain
        keeps its own ConversationBufferMemory).
    chain : ConversationalRetrievalChain | None
        The built chain, or None if the user has not built one yet.

    Returns the answer text followed by a deduplicated, sorted list of the
    retrieved source filenames; the footer is omitted when no sources exist.
    """
    if not chain:
        return "⚠️ Build the chatbot first."
    res = chain.invoke({"question": message})
    answer = res["answer"]
    # Deduplicate and sort so the footer is stable across calls
    # (set iteration order is otherwise nondeterministic).
    sources = sorted({
        os.path.basename(d.metadata.get("source", "unknown"))
        for d in res.get("source_documents", [])
    })
    if not sources:
        # Avoid rendering a dangling empty "Sources used:" footer.
        return answer
    return answer + "\n\n----- \n**Sources used:** " + ", ".join(sources)
# --- 5. UI BUILDING ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🛡️ Strict Hybrid Multi-RAG")

    # Holds the chain built by process_files; None until the user clicks Build.
    chain_state = gr.State(None)

    with gr.Row():
        # Left column: upload + build controls.
        with gr.Column(scale=1):
            file_input = gr.File(file_count="multiple", label="1. Upload Documents")
            len_slider = gr.Slider(
                minimum=100,
                maximum=4000,
                value=1000,
                step=100,
                label="2. Response Length",
            )
            build_btn = gr.Button("3. Build Restricted Chatbot", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)
        # Right column: the chat panel; chain_state is threaded into
        # chat_function as an additional input.
        with gr.Column(scale=2):
            gr.ChatInterface(fn=chat_function, additional_inputs=[chain_state])

    build_btn.click(
        process_files,
        inputs=[file_input, len_slider],
        outputs=[chain_state, status],
    )

if __name__ == "__main__":
    demo.launch()