# Hugging Face Space metadata (page chrome captured with the file source):
#   author: faiz0983 — commit: "Update app.py" (ee4f5d4, verified) — 5.33 kB
import os
import gradio as gr
# LangChain Core
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.retrievers import EnsembleRetriever
# Providers
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
# Community
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import (
PyPDFLoader,
TextLoader,
Docx2txtLoader
)
from langchain_community.retrievers import BM25Retriever
# Text Splitters
from langchain_text_splitters import RecursiveCharacterTextSplitter
# --------------------------------------------------
# 1. API KEY
# --------------------------------------------------
# NOTE(review): the secret is read from the env var "GROQ_API" (not the more
# conventional "GROQ_API_KEY") — confirm this matches the Space's secret name.
GROQ_API_KEY = os.getenv("GROQ_API")
# Prompt that forbids the model from using outside knowledge; {context} and
# {question} are filled in by the retrieval QA chain at query time.
STRICT_PROMPT_TEMPLATE = """You are a strict document-based assistant.
Use ONLY the information provided in the context.
RULES:
1. Do not use outside knowledge.
2. If the answer is not present, say:
"I'm sorry, but the provided documents do not contain information to answer this question."
Context:
{context}
Question: {question}
Answer:
"""
STRICT_PROMPT = PromptTemplate(
    template=STRICT_PROMPT_TEMPLATE,
    input_variables=["context", "question"]
)
# --------------------------------------------------
# 2. FILE LOADER
# --------------------------------------------------
def load_any(path: str):
    """Load documents from *path*, dispatching on the file extension.

    Supports .pdf, .txt and .docx; any other extension yields an empty
    list so unsupported uploads are silently skipped by the caller.
    """
    # Extension → lazily-invoked loader; lambdas defer construction until
    # the matching extension is actually found.
    dispatch = {
        ".pdf": lambda: PyPDFLoader(path).load(),
        ".txt": lambda: TextLoader(path, encoding="utf-8").load(),
        ".docx": lambda: Docx2txtLoader(path).load(),
    }
    lowered = path.lower()
    for ext, loader in dispatch.items():
        if lowered.endswith(ext):
            return loader()
    return []
# --------------------------------------------------
# 3. PROCESS FILES / BUILD CHAIN
# --------------------------------------------------
def process_files(files, response_length):
    """Build a strict hybrid-RAG conversational chain from uploaded files.

    Parameters
    ----------
    files : list | None
        Uploads from Gradio. Depending on the Gradio version these are
        tempfile-like wrappers (with a ``.name`` path) or plain path strings;
        both are accepted.
    response_length : int | float
        Maximum number of tokens the LLM may generate per answer.

    Returns
    -------
    tuple
        ``(chain, status_message)`` — ``chain`` is ``None`` on any failure,
        with a human-readable status for the UI.
    """
    if not files or not GROQ_API_KEY:
        return None, "⚠️ Missing documents or GROQ_API key."
    try:
        docs = []
        for f in files:
            # Tolerate both plain path strings and file objects with .name.
            path = f if isinstance(f, str) else f.name
            docs.extend(load_any(path))
        if not docs:
            # Unsupported extensions (or empty files) produce no documents;
            # fail early instead of crashing inside FAISS/BM25 below.
            return None, "⚠️ No readable content found in the uploaded files."
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            chunk_overlap=100
        )
        chunks = splitter.split_documents(docs)
        if not chunks:
            return None, "⚠️ Documents contained no extractable text."
        # --- Hybrid retrieval: dense (FAISS) + sparse (BM25), equal weights ---
        embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        faiss_db = FAISS.from_documents(chunks, embeddings)
        faiss_retriever = faiss_db.as_retriever(search_kwargs={"k": 3})
        bm25_retriever = BM25Retriever.from_documents(chunks)
        bm25_retriever.k = 3
        retriever = EnsembleRetriever(
            retrievers=[faiss_retriever, bm25_retriever],
            weights=[0.5, 0.5]
        )
        llm = ChatGroq(
            groq_api_key=GROQ_API_KEY,
            model="llama-3.3-70b-versatile",
            temperature=0,  # deterministic, document-grounded answers
            max_tokens=int(response_length)
        )
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer"  # chain returns extra keys; pick the answer
        )
        chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            combine_docs_chain_kwargs={"prompt": STRICT_PROMPT},
            memory=memory,
            return_source_documents=True,
            output_key="answer"
        )
        return chain, f"✅ Chatbot ready (max {response_length} tokens)"
    except Exception as e:
        # Boundary handler: surface any build failure in the UI status box
        # rather than letting the Gradio callback raise.
        return None, f"❌ Error: {str(e)}"
# --------------------------------------------------
# 4. CHAT FUNCTION
# --------------------------------------------------
def chat_function(message, history, chain):
    """Answer *message* via the built chain and append source filenames.

    Parameters
    ----------
    message : str
        The user's question.
    history : list
        Gradio chat history (unused warning path aside, forwarded to the chain).
    chain : ConversationalRetrievalChain | None
        The chain stored in ``gr.State``; ``None`` until "Build Chatbot" runs.

    Returns
    -------
    str
        The model's answer, optionally followed by a "Sources" footer.
    """
    if chain is None:
        return "⚠️ Please build the chatbot first."
    # NOTE(review): the chain carries its own ConversationBufferMemory, so the
    # Gradio-format `history` passed here may be redundant or ignored —
    # confirm against the installed LangChain version.
    result = chain.invoke({
        "question": message,
        "chat_history": history
    })
    answer = result["answer"]
    sources = {
        os.path.basename(doc.metadata.get("source", "unknown"))
        for doc in result.get("source_documents", [])
    }
    if sources:
        # sorted() keeps the listing deterministic — sets are unordered,
        # so the raw join produced a different order on each call.
        answer += "\n\n---\n**Sources:** " + ", ".join(sorted(sources))
    return answer
# --------------------------------------------------
# 5. GRADIO UI
# --------------------------------------------------
# Layout: left column uploads documents and builds the chain; right column
# hosts the chat. The built chain lives in per-session gr.State.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🛡️ Strict Hybrid Multi-RAG (Groq + FAISS + BM25)")
    # Holds the ConversationalRetrievalChain once "Build Chatbot" succeeds;
    # starts as None so chat_function can warn before a build.
    chain_state = gr.State(None)
    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                file_count="multiple",
                label="Upload Documents"
            )
            # Value is forwarded to ChatGroq max_tokens in process_files.
            len_slider = gr.Slider(
                100, 4000, value=1000, step=100,
                label="Max Answer Tokens"
            )
            build_btn = gr.Button(
                "Build Chatbot",
                variant="primary"
            )
            status = gr.Textbox(
                label="Status",
                interactive=False
            )
        with gr.Column(scale=2):
            # chain_state is appended to (message, history) as the third
            # argument of chat_function.
            gr.ChatInterface(
                fn=chat_function,
                additional_inputs=[chain_state]
            )
    # Build writes (chain, status message) back to state + status box.
    build_btn.click(
        process_files,
        inputs=[file_input, len_slider],
        outputs=[chain_state, status]
    )
if __name__ == "__main__":
    demo.launch()