import os
import gradio as gr
# Classic & Community Imports
from langchain_classic.chains import ConversationalRetrievalChain
from langchain_classic.memory import ConversationBufferMemory
from langchain_groq import ChatGroq
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.prompts import PromptTemplate
# --- 1. SETUP API & SYSTEM PROMPT ---
# Hugging Face Spaces exposes repository secrets as environment variables,
# so the Groq key is read via os.getenv rather than hard-coded.
api_key = os.getenv("GROQ_API")

# System prompt that restricts the LLM to the retrieved context only. The
# exact refusal sentence is mandated so users get a consistent "not in the
# documents" reply instead of a hallucinated answer.
STRICT_PROMPT_TEMPLATE = """You are a strict document-based assistant.
Use the following pieces of context to answer the user's question.
RESTRICTIONS:
1. ONLY use the information provided in the context below.
2. If the answer is not contained within the context, specifically say: "I'm sorry, but the provided documents do not contain information to answer this question."
3. Do NOT use your own outside knowledge.
Context:
{context}
Question: {question}
Helpful Answer:"""

# {context} receives the retrieved chunks and {question} the user's query —
# the variable names ConversationalRetrievalChain's combine-docs step expects.
STRICT_PROMPT = PromptTemplate(
    template=STRICT_PROMPT_TEMPLATE,
    input_variables=["context", "question"]
)
# --- 2. LOADING LOGIC ---
def load_any(path: str):
    """Load a single document chosen by file extension.

    Unsupported extensions yield an empty list so callers can skip them
    silently.
    """
    lowered = path.lower()
    # Suffix -> zero-arg loader factory; first matching suffix wins.
    factories = {
        ".pdf": lambda: PyPDFLoader(path),
        ".txt": lambda: TextLoader(path, encoding="utf-8"),
        ".docx": lambda: Docx2txtLoader(path),
    }
    for suffix, make in factories.items():
        if lowered.endswith(suffix):
            return make().load()
    return []
# --- 3. HYBRID PROCESSING ---
def process_files(files, response_length):
    """Build a strict hybrid-RAG conversational chain from uploaded files.

    Returns a (chain, status_message) pair; the chain is None when inputs
    are missing or any build step fails, so the UI can show the message in
    the status box instead of crashing.
    """
    if not files or not api_key:
        return None, "⚠️ Missing files or GROQ_API key in Secrets."
    try:
        # Gather every page/section from all uploaded files.
        documents = []
        for uploaded in files:
            documents.extend(load_any(uploaded.name))

        # Overlapping chunks keep context intact across split boundaries.
        chunks = RecursiveCharacterTextSplitter(
            chunk_size=800, chunk_overlap=100
        ).split_documents(documents)

        # Dense retriever: FAISS over MiniLM sentence embeddings...
        embedder = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        dense = FAISS.from_documents(chunks, embedder).as_retriever(
            search_kwargs={"k": 3}
        )
        # ...plus a sparse keyword retriever (BM25), blended 50/50.
        sparse = BM25Retriever.from_documents(chunks)
        sparse.k = 3
        hybrid = EnsembleRetriever(
            retrievers=[dense, sparse],
            weights=[0.5, 0.5]
        )

        # temperature=0 keeps answers deterministic; max_tokens comes from
        # the UI slider.
        llm = ChatGroq(
            groq_api_key=api_key,
            model="llama-3.3-70b-versatile",
            temperature=0,
            max_tokens=int(response_length)
        )
        # output_key="answer" tells the memory which result field to store,
        # since return_source_documents=True adds extra keys to the output.
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer"
        )
        chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=hybrid,
            combine_docs_chain_kwargs={"prompt": STRICT_PROMPT},
            memory=memory,
            return_source_documents=True,
            output_key="answer"
        )
        return chain, f"✅ Knowledge base built! Max answer length: {response_length} tokens."
    except Exception as e:
        return None, f"❌ Error: {str(e)}"
# --- 4. CHAT FUNCTION ---
def chat_function(message, history, chain):
    """Answer one chat turn via the RAG chain, appending a source footer.

    Args:
        message: the user's question.
        history: gradio chat history (unused here; the chain keeps its own
            ConversationBufferMemory).
        chain: the ConversationalRetrievalChain built by process_files, or
            None if the knowledge base has not been built yet.

    Returns:
        The model's answer, followed by a "Sources used" footer when any
        source documents were retrieved.
    """
    if not chain:
        return "⚠️ Build the chatbot first."
    res = chain.invoke({"question": message})
    answer = res["answer"]
    # De-duplicate and sort source filenames so the footer is deterministic
    # (a plain set has arbitrary ordering between runs).
    sources = sorted({
        os.path.basename(doc.metadata.get("source", "unknown"))
        for doc in res.get("source_documents", [])
    })
    if not sources:
        # No retrieved documents (e.g. a refusal answer): don't append an
        # empty "Sources used:" header.
        return answer
    return answer + "\n\n----- \n**Sources used:** " + ", ".join(sources)
# --- 5. UI BUILDING ---
# Two-column layout: build controls on the left, chat on the right. The
# built chain lives in gr.State so it persists per-session across callbacks.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🛡️ Strict Hybrid Multi-RAG")
    chain_state = gr.State(None)  # holds the ConversationalRetrievalChain (or None until built)
    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(file_count="multiple", label="1. Upload Documents")
            len_slider = gr.Slider(minimum=100, maximum=4000, value=1000, step=100, label="2. Response Length")
            build_btn = gr.Button("3. Build Restricted Chatbot", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)
        with gr.Column(scale=2):
            # chain_state is passed as an extra input so chat_function
            # receives the per-session chain built by process_files.
            gr.ChatInterface(fn=chat_function, additional_inputs=[chain_state])
    # Building populates both the chain state and the status textbox.
    build_btn.click(process_files, inputs=[file_input, len_slider], outputs=[chain_state, status])

if __name__ == "__main__":
    demo.launch()