# Hugging Face Space: Strict Hybrid Multi-RAG chatbot (Groq + FAISS + BM25)
| import os | |
| import gradio as gr | |
| # LangChain Core | |
| from langchain.chains import ConversationalRetrievalChain | |
| from langchain.memory import ConversationBufferMemory | |
| from langchain.prompts import PromptTemplate | |
| from langchain.retrievers import EnsembleRetriever | |
| # Providers | |
| from langchain_groq import ChatGroq | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| # Community | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_community.document_loaders import ( | |
| PyPDFLoader, | |
| TextLoader, | |
| Docx2txtLoader | |
| ) | |
| from langchain_community.retrievers import BM25Retriever | |
| # Text Splitters | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| # -------------------------------------------------- | |
| # 1. API KEY | |
| # -------------------------------------------------- | |
# Groq API key, read from the environment at import time.
# NOTE(review): the env var is named "GROQ_API", not "GROQ_API_KEY" —
# confirm this matches the secret name configured for the Space.
GROQ_API_KEY = os.getenv("GROQ_API")

# Prompt that confines the model to the retrieved context: it must not use
# outside knowledge and must emit a fixed refusal sentence when the answer
# is absent. {context} and {question} are filled in by the chain.
STRICT_PROMPT_TEMPLATE = """You are a strict document-based assistant.
Use ONLY the information provided in the context.
RULES:
1. Do not use outside knowledge.
2. If the answer is not present, say:
"I'm sorry, but the provided documents do not contain information to answer this question."
Context:
{context}
Question: {question}
Answer:
"""

# PromptTemplate handed to the combine-docs step of the retrieval chain.
STRICT_PROMPT = PromptTemplate(
    template=STRICT_PROMPT_TEMPLATE,
    input_variables=["context", "question"]
)
| # -------------------------------------------------- | |
| # 2. FILE LOADER | |
| # -------------------------------------------------- | |
def load_any(path: str):
    """Load a single document, dispatching on its case-insensitive extension.

    Supported: .pdf (PyPDFLoader), .txt (TextLoader, UTF-8), .docx
    (Docx2txtLoader). Any other extension yields an empty list rather
    than raising, so mixed uploads degrade gracefully.
    """
    # Lambdas keep loader construction lazy: nothing is instantiated
    # unless the extension actually matches.
    handlers = (
        (".pdf", lambda target: PyPDFLoader(target).load()),
        (".txt", lambda target: TextLoader(target, encoding="utf-8").load()),
        (".docx", lambda target: Docx2txtLoader(target).load()),
    )
    lowered = path.lower()
    for suffix, handler in handlers:
        if lowered.endswith(suffix):
            return handler(path)
    return []
| # -------------------------------------------------- | |
| # 3. PROCESS FILES / BUILD CHAIN | |
| # -------------------------------------------------- | |
def process_files(files, response_length):
    """Build a strict ConversationalRetrievalChain over uploaded documents.

    Args:
        files: Gradio file objects; each exposes ``.name`` (a temp-file path).
        response_length: Max completion tokens for the LLM (slider value).

    Returns:
        (chain, status_message) tuple. ``chain`` is None on any failure so
        the Gradio state stays unset and the status textbox explains why.
    """
    if not files or not GROQ_API_KEY:
        return None, "⚠️ Missing documents or GROQ_API key."
    try:
        docs = []
        for f in files:
            docs.extend(load_any(f.name))
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            chunk_overlap=100
        )
        chunks = splitter.split_documents(docs)
        # Guard: FAISS.from_documents / BM25Retriever.from_documents raise on
        # an empty corpus (e.g. only unsupported file types were uploaded);
        # fail here with a readable status instead of a generic error.
        if not chunks:
            return None, "⚠️ No readable text found in the uploaded files."
        # --- Hybrid retrieval: dense (FAISS) + sparse (BM25), equal weights ---
        embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        faiss_db = FAISS.from_documents(chunks, embeddings)
        faiss_retriever = faiss_db.as_retriever(search_kwargs={"k": 3})
        bm25_retriever = BM25Retriever.from_documents(chunks)
        bm25_retriever.k = 3
        retriever = EnsembleRetriever(
            retrievers=[faiss_retriever, bm25_retriever],
            weights=[0.5, 0.5]
        )
        # temperature=0 keeps answers deterministic and document-grounded.
        llm = ChatGroq(
            groq_api_key=GROQ_API_KEY,
            model="llama-3.3-70b-versatile",
            temperature=0,
            max_tokens=int(response_length)
        )
        # output_key="answer" is required because the chain also returns
        # source_documents; memory must know which key to store.
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer"
        )
        chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            combine_docs_chain_kwargs={"prompt": STRICT_PROMPT},
            memory=memory,
            return_source_documents=True,
            output_key="answer"
        )
        return chain, f"✅ Chatbot ready (max {response_length} tokens)"
    except Exception as e:
        # Surface the failure in the status textbox rather than crashing the UI.
        return None, f"❌ Error: {str(e)}"
| # -------------------------------------------------- | |
| # 4. CHAT FUNCTION | |
| # -------------------------------------------------- | |
def chat_function(message, history, chain):
    """Answer one chat turn via the retrieval chain.

    Args:
        message: The user's question.
        history: Gradio-maintained chat history. Unused: the chain's own
            ConversationBufferMemory tracks conversation state, and any
            chat_history passed to invoke() would be overwritten by the
            memory's loaded variables anyway.
        chain: The ConversationalRetrievalChain from process_files, or
            None if "Build Chatbot" has not been clicked yet.

    Returns:
        The answer text, with a deduplicated, alphabetically sorted
        source-file footer when source documents were returned.
    """
    if chain is None:
        return "⚠️ Please build the chatbot first."
    # Send only the question; the chain's memory supplies chat_history.
    result = chain.invoke({"question": message})
    answer = result["answer"]
    sources = {
        os.path.basename(doc.metadata.get("source", "unknown"))
        for doc in result.get("source_documents", [])
    }
    if sources:
        # sorted() makes the footer deterministic (set order is not).
        answer += "\n\n---\n**Sources:** " + ", ".join(sorted(sources))
    return answer
| # -------------------------------------------------- | |
| # 5. GRADIO UI | |
| # -------------------------------------------------- | |
# --------------------------------------------------
# Gradio UI: file upload + build button on the left, chat on the right.
# Built at import time so Spaces can serve `demo` directly.
# --------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🛡️ Strict Hybrid Multi-RAG (Groq + FAISS + BM25)")
    # Per-session holder for the built retrieval chain; None until the
    # user clicks "Build Chatbot".
    chain_state = gr.State(None)
    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                file_count="multiple",
                label="Upload Documents"
            )
            # Forwarded to ChatGroq's max_tokens in process_files.
            len_slider = gr.Slider(
                100, 4000, value=1000, step=100,
                label="Max Answer Tokens"
            )
            build_btn = gr.Button(
                "Build Chatbot",
                variant="primary"
            )
            status = gr.Textbox(
                label="Status",
                interactive=False
            )
        with gr.Column(scale=2):
            # chain_state is threaded into chat_function as its third
            # positional argument via additional_inputs.
            gr.ChatInterface(
                fn=chat_function,
                additional_inputs=[chain_state]
            )
    # Building the chain updates both the session state and the status box.
    build_btn.click(
        process_files,
        inputs=[file_input, len_slider],
        outputs=[chain_state, status]
    )

if __name__ == "__main__":
    demo.launch()