# RAG-BASED-APP / app.py — Hugging Face Space by Subayyal (commit 49dfb24, "Create app.py")
# app.py
import os
import glob
import tempfile
from typing import List
import streamlit as st
# LangChain / loaders / vectorstore / embeddings / LLM
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain.chains import RetrievalQA
st.set_page_config(page_title="RAG Papers Chat (Groq)", layout="wide")


# -----------------------
# Load custom CSS
# -----------------------
def load_css(path="style.css"):
    """Inject the stylesheet at *path* as an inline <style> tag, if it exists.

    Missing files are silently ignored so the app still runs without a theme.
    """
    if os.path.exists(path):
        # Explicit UTF-8: the default locale encoding is platform-dependent
        # and can garble the stylesheet on non-UTF-8 systems (e.g. Windows).
        with open(path, encoding="utf-8") as f:
            st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)


load_css()
# -----------------------
# Sidebar / settings
# -----------------------
st.sidebar.title("βš™οΈ Settings")

# Chunking and retrieval knobs for the RAG pipeline.
chunk_size = st.sidebar.number_input(
    "Chunk size", min_value=256, max_value=5000, value=1000, step=100
)
chunk_overlap = st.sidebar.number_input(
    "Chunk overlap", min_value=0, max_value=1000, value=200, step=50
)
top_k = st.sidebar.slider("Top-k chunks to retrieve", min_value=1, max_value=10, value=3)

# Groq chat model selection; the first entry is the default.
_GROQ_MODEL_OPTIONS = ["llama-3.1-8b-instant", "llama-3.1-8b-8192", "mixtral-3b-16384"]
model_choice = st.sidebar.selectbox("Groq model", options=_GROQ_MODEL_OPTIONS, index=0)

st.sidebar.markdown("πŸ”‘ Your **Groq API key** must be set as a secret (`GROQ_API_KEY`) in Hugging Face Settings.")
# -----------------------
# Utility functions
# -----------------------
@st.cache_data(show_spinner=False)
def load_and_split_pdfs(file_paths: List[str], chunk_size: int, chunk_overlap: int):
    """Load every PDF in *file_paths* and split the pages into text chunks.

    Returns a flat list of LangChain Document chunks produced by
    RecursiveCharacterTextSplitter with the given size/overlap settings.
    Cached by Streamlit on (file_paths, chunk_size, chunk_overlap).
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    chunks = []
    for pdf_path in file_paths:
        pages = PyPDFLoader(pdf_path).load()
        chunks.extend(splitter.split_documents(pages))
    return chunks
@st.cache_resource(show_spinner=False)
def build_vectorstore(docs):
    """Embed *docs* with a MiniLM sentence-transformer and index them in FAISS.

    Cached as a resource so the index is built once per set of documents.
    """
    encoder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return FAISS.from_documents(docs, encoder)
def initialize_llm(model_name: str):
    """Return a deterministic (temperature=0) ChatGroq client for *model_name*.

    Halts the Streamlit script with an error banner when the GROQ_API_KEY
    environment variable / secret is not configured.
    """
    api_key = os.environ.get("GROQ_API_KEY")
    if api_key:
        return ChatGroq(model=model_name, api_key=api_key, temperature=0)
    # st.stop() raises, so nothing past this point runs without a key.
    st.error("🚨 No `GROQ_API_KEY` found. Please add it in Hugging Face Space β†’ Settings β†’ Secrets.")
    st.stop()
# -----------------------
# Main UI
# -----------------------
st.title("πŸ“š RAG Chat for Research Papers β€” Streamlit (Groq)")
st.write("Upload multiple PDFs and ask questions. Answers will include deduplicated file sources.")

uploaded_files = st.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True)
process_btn = st.button("Process uploaded PDFs")

if process_btn:
    if not uploaded_files:
        st.warning("Please upload one or more PDF files first.")
    else:
        # PyPDFLoader needs real file paths, so persist each upload to disk.
        # A fresh temp directory lets us keep the ORIGINAL filenames, so the
        # "Sources" citations later show paper names instead of anonymous
        # tmpXXXX names (which is what NamedTemporaryFile produced before).
        tmp_dir = tempfile.mkdtemp(prefix="rag_pdfs_")
        tmp_paths = []
        for f in uploaded_files:
            dest = os.path.join(tmp_dir, os.path.basename(f.name))
            # `with` closes the handle (the old code only flushed it —
            # a file-handle leak, and locked files on Windows).
            with open(dest, "wb") as out:
                out.write(f.read())
            tmp_paths.append(dest)
        st.success("βœ… PDFs saved. Processing...")
        with st.spinner("Splitting into chunks..."):
            docs = load_and_split_pdfs(tmp_paths, chunk_size, chunk_overlap)
        st.write(f"βœ… Created {len(docs)} chunks.")
        with st.spinner("Building FAISS vectorstore..."):
            vectorstore = build_vectorstore(docs)
        # The text is extracted and indexed; delete the temp PDFs so repeated
        # uploads don't leak disk space (best-effort — ignore FS errors).
        for path in tmp_paths:
            try:
                os.unlink(path)
            except OSError:
                pass
        try:
            os.rmdir(tmp_dir)
        except OSError:
            pass
        st.session_state["vectorstore"] = vectorstore
        st.session_state["processed"] = True
        st.success("βœ… Vectorstore ready! Ask questions below.")
# -----------------------
# Chat section
# -----------------------
st.markdown("---")
st.subheader("πŸ’¬ Chat with your papers")

if "processed" not in st.session_state:
    st.info("Process PDFs first to build the index.")
else:
    # BUG FIX: the chain/LLM used to be cached in session_state exactly once,
    # so later changes to the model, top-k, or a re-processed vectorstore were
    # silently ignored. Rebuild whenever any of those inputs change.
    chain_sig = (model_choice, top_k, id(st.session_state["vectorstore"]))
    if st.session_state.get("qa_chain_sig") != chain_sig:
        llm = initialize_llm(model_choice)
        retriever = st.session_state["vectorstore"].as_retriever(search_kwargs={"k": top_k})
        st.session_state["qa_chain"] = RetrievalQA.from_chain_type(
            llm=llm,
            retriever=retriever,
            chain_type="stuff",
            return_source_documents=True,
        )
        st.session_state["qa_chain_sig"] = chain_sig
    if "history" not in st.session_state:
        st.session_state["history"] = []

    query = st.text_input("Enter your question")
    ask = st.button("Ask")
    if ask and query.strip():
        with st.spinner("Thinking..."):
            result = st.session_state["qa_chain"]({"query": query})
        answer = result.get("result", "")
        source_docs = result.get("source_documents", [])
        # Deduplicate source file paths; sort for a stable display order.
        unique_sources = {doc.metadata.get("source", "unknown") for doc in source_docs}
        sources_text = "\n".join(f"- {os.path.basename(s)}" for s in sorted(unique_sources))
        full_answer = f"{answer}\n\nπŸ“š **Sources:**\n{sources_text}"
        st.session_state["history"].append((query, full_answer))

    # Newest exchange first.
    st.markdown("### πŸ“œ Conversation History")
    for user_msg, bot_msg in reversed(st.session_state["history"]):
        st.markdown(f"**You:** {user_msg}")
        st.markdown(f"**Bot:** {bot_msg}")