# app.py — Hugging Face Space by faiz0983 (commit ad21633 "Update app.py", verified; 4.7 kB)
import functools
import os

import gradio as gr

# Classic & Community Imports
from langchain_classic.chains import ConversationalRetrievalChain
from langchain_classic.memory import ConversationBufferMemory
from langchain_groq import ChatGroq
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader, TextLoader, Docx2txtLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.prompts import PromptTemplate
# --- 1. SETUP API & SYSTEM PROMPT ---
# Hugging Face Spaces exposes repository secrets to the app as environment
# variables. This Space's secret is named "GROQ_API"; the conventional
# "GROQ_API_KEY" name is accepted as a fallback so either secret works.
# api_key is falsy when neither is set, and process_files then refuses to build.
api_key = os.getenv("GROQ_API") or os.getenv("GROQ_API_KEY")

# Prompt used by the retrieval chain's document-combining step:
# {context} receives the retrieved chunks, {question} the user's question.
# The wording restricts the model to the uploaded documents only.
STRICT_PROMPT_TEMPLATE = """You are a strict document-based assistant.
Use the following pieces of context to answer the user's question.
RESTRICTIONS:
1. ONLY use the information provided in the context below.
2. If the answer is not contained within the context, specifically say: "I'm sorry, but the provided documents do not contain information to answer this question."
3. Do NOT use your own outside knowledge.
Context:
{context}
Question: {question}
Helpful Answer:"""

STRICT_PROMPT = PromptTemplate(
    template=STRICT_PROMPT_TEMPLATE,
    input_variables=["context", "question"],
)
# --- 2. LOADING LOGIC ---
def load_any(path: str):
    """Load a single uploaded file into LangChain documents.

    The loader is chosen by a case-insensitive suffix match on ``path``:
    ``.pdf`` -> PyPDFLoader, ``.txt`` -> TextLoader (UTF-8), ``.docx`` ->
    Docx2txtLoader. Any other suffix is skipped and yields an empty list.
    """
    lowered = path.lower()
    # Loaders are wrapped in lambdas so only the matching one is constructed.
    suffix_loaders = (
        (".pdf", lambda: PyPDFLoader(path)),
        (".txt", lambda: TextLoader(path, encoding="utf-8")),
        (".docx", lambda: Docx2txtLoader(path)),
    )
    for suffix, make_loader in suffix_loaders:
        if lowered.endswith(suffix):
            return make_loader().load()
    return []
# --- 3. HYBRID PROCESSING ---
# Sentence-transformer used for the dense (FAISS) half of the hybrid retriever.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"


@functools.lru_cache(maxsize=1)
def _get_embeddings(model_name):
    """Load the embedding model once and reuse it across knowledge-base builds."""
    return HuggingFaceEmbeddings(model_name=model_name)


def _upload_path(file_obj):
    """Return the filesystem path of a Gradio upload (plain str or object with ``.name``)."""
    return getattr(file_obj, "name", file_obj)


def _load_chunks(files):
    """Load every supported upload and split it into overlapping text chunks."""
    docs = []
    for file_obj in files:
        docs.extend(load_any(_upload_path(file_obj)))
    splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
    return splitter.split_documents(docs)


def _build_hybrid_retriever(chunks):
    """Combine dense (FAISS) and keyword (BM25) retrieval, 3 hits each, equal weights."""
    faiss_db = FAISS.from_documents(chunks, _get_embeddings(EMBEDDING_MODEL_NAME))
    faiss_retriever = faiss_db.as_retriever(search_kwargs={"k": 3})
    bm25_retriever = BM25Retriever.from_documents(chunks)
    bm25_retriever.k = 3
    return EnsembleRetriever(
        retrievers=[faiss_retriever, bm25_retriever],
        weights=[0.5, 0.5],
    )


def _build_chain(retriever, max_tokens):
    """Wire the Groq LLM, chat memory and strict prompt into a retrieval QA chain."""
    llm = ChatGroq(
        groq_api_key=api_key,
        model="llama-3.3-70b-versatile",
        temperature=0,  # minimize sampling randomness for document-grounded answers
        max_tokens=max_tokens,
    )
    # The chain returns both "answer" and source documents, so memory is told
    # which output key holds the text to remember.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        output_key="answer",
    )
    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        combine_docs_chain_kwargs={"prompt": STRICT_PROMPT},
        memory=memory,
        return_source_documents=True,
        output_key="answer",
    )


def process_files(files, response_length):
    """Build a strict hybrid-RAG conversational chain from uploaded documents.

    Args:
        files: Uploads from the ``gr.File(file_count="multiple")`` component;
            each item is either a path string or an object exposing the path
            as ``.name``. Unsupported file types are skipped by ``load_any``.
        response_length: Slider value, converted to ``int`` and used as the
            LLM's ``max_tokens``.

    Returns:
        ``(chain, status_message)`` on success, or ``(None, status_message)``
        when files or the API key are missing, no text could be extracted, or
        building the chain raised.
    """
    # `not files` is checked first so the key is only consulted when files exist.
    if not files or not api_key:
        return None, "⚠️ Missing files or GROQ_API key in Secrets."
    try:
        max_tokens = int(response_length)
        chunks = _load_chunks(files)
        if not chunks:
            # FAISS/BM25 cannot index an empty corpus; report it plainly instead
            # of surfacing their internal error.
            return None, "⚠️ No readable text found. Upload .pdf, .txt or .docx files."
        chain = _build_chain(_build_hybrid_retriever(chunks), max_tokens)
    except Exception as e:
        # UI boundary: show the failure in the status box rather than crashing the handler.
        return None, f"❌ Error: {str(e)}"
    return chain, f"✅ Knowledge base built! Max answer length: {max_tokens} tokens."
# --- 4. CHAT FUNCTION ---
def chat_function(message, history, chain):
    """Answer one chat turn with the session's retrieval chain.

    Args:
        message: The user's latest question.
        history: Prior turns supplied by ``gr.ChatInterface``. It is unused
            because the chain built by ``process_files`` carries its own
            conversation memory.
        chain: The chain stored in ``gr.State`` by ``process_files``, or None
            if the knowledge base has not been built yet.

    Returns:
        The model's answer. If the chain returned source documents, the answer
        is followed by a footer listing their file names, deduplicated and sorted.
        If the chain raised, an error message is returned instead.
    """
    if chain is None:
        return "⚠️ Build the chatbot first."
    try:
        res = chain.invoke({"question": message})
    except Exception as e:
        # Report LLM/API failures in the chat, matching process_files' error style.
        return f"❌ Error: {str(e)}"

    answer = res["answer"]
    # `or "unknown"` also covers a present-but-None source, which basename rejects.
    sources = sorted({
        os.path.basename(doc.metadata.get("source") or "unknown")
        for doc in res.get("source_documents") or []
    })
    if not sources:
        return answer
    return answer + "\n\n----- \n**Sources used:** " + ", ".join(sources)
# --- 5. UI BUILDING ---
def _build_demo():
    """Assemble the Gradio interface and return the Blocks app.

    Left column: document upload, answer-length slider, build button and a
    read-only status box. Right column: a ChatInterface that receives the
    per-session chain through a ``gr.State`` filled in by the build button.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🛡️ Strict Hybrid Multi-RAG")
        # Per-session slot for the chain returned by process_files (None until built).
        session_chain = gr.State(None)

        with gr.Row():
            with gr.Column(scale=1):
                uploads = gr.File(file_count="multiple", label="1. Upload Documents")
                length_slider = gr.Slider(
                    minimum=100,
                    maximum=4000,
                    value=1000,
                    step=100,
                    label="2. Response Length",
                )
                build_button = gr.Button("3. Build Restricted Chatbot", variant="primary")
                status_box = gr.Textbox(label="Status", interactive=False)
            with gr.Column(scale=2):
                gr.ChatInterface(fn=chat_function, additional_inputs=[session_chain])

        # process_files returns (chain, message) -> (session state, status box).
        build_button.click(
            fn=process_files,
            inputs=[uploads, length_slider],
            outputs=[session_chain, status_box],
        )
    return app


demo = _build_demo()

if __name__ == "__main__":
    demo.launch()