# app.py
import os
import glob
import tempfile
from typing import List

import streamlit as st

# LangChain / loaders / vectorstore / embeddings / LLM
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain.chains import RetrievalQA

st.set_page_config(page_title="RAG Papers Chat (Groq)", layout="wide")

# Sentence-transformers model used to embed chunks for the FAISS index.
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"


# -----------------------
# Load custom CSS
# -----------------------
def load_css(path="style.css"):
    """Inject the stylesheet at ``path`` into the page; do nothing if it is missing.

    The file contents are wrapped in a ``<style>`` element and rendered with
    ``unsafe_allow_html=True`` so the browser applies them to the app.
    """
    if not os.path.exists(path):
        return
    with open(path, encoding="utf-8") as css_file:
        css = css_file.read()
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)


load_css()

# -----------------------
# Sidebar / settings
# -----------------------
st.sidebar.title("⚙️ Settings")
chunk_size = st.sidebar.number_input("Chunk size", min_value=256, max_value=5000, value=1000, step=100)
chunk_overlap = st.sidebar.number_input("Chunk overlap", min_value=0, max_value=1000, value=200, step=50)
top_k = st.sidebar.slider("Top-k chunks to retrieve", min_value=1, max_value=10, value=3)
# NOTE(review): "llama-3.1-8b-8192" and "mixtral-3b-16384" do not match Groq
# model IDs known to the reviewer (e.g. "llama3-8b-8192", "mixtral-8x7b-32768");
# confirm against Groq's current model list before relying on them.
model_choice = st.sidebar.selectbox(
    "Groq model",
    options=["llama-3.1-8b-instant", "llama-3.1-8b-8192", "mixtral-3b-16384"],
    index=0,
)
st.sidebar.markdown("🔑 Your **Groq API key** must be set as a secret (`GROQ_API_KEY`) in Hugging Face Settings.")


# -----------------------
# Utility functions
# -----------------------
# Not cached: every upload is written to a fresh temporary path, so a
# path-keyed cache could never hit and would only accumulate entries.
def load_and_split_pdfs(file_paths: List[str], chunk_size: int, chunk_overlap: int):
    """Load every PDF in ``file_paths`` and split its pages into overlapping chunks.

    Args:
        file_paths: Paths of the PDF files to load, processed in order.
        chunk_size: Target chunk length passed to the text splitter.
        chunk_overlap: Overlap between consecutive chunks passed to the splitter.

    Returns:
        A flat list of LangChain documents (chunks) from all files. Each chunk
        keeps the loader's metadata, including ``source``, which the chat
        section displays as the citation.
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    all_docs = []
    for path in file_paths:
        pages = PyPDFLoader(path).load()
        all_docs.extend(splitter.split_documents(pages))
    return all_docs


@st.cache_resource(show_spinner=False)
def _get_embeddings(model_name: str):
    """Load the HuggingFace embedding model once per process and reuse it."""
    return HuggingFaceEmbeddings(model_name=model_name)


def build_vectorstore(docs, embedding_model: str = EMBEDDING_MODEL):
    """Embed ``docs`` with ``embedding_model`` and index them in a new FAISS vectorstore.

    Only the embedding model is cached; the index itself is kept by the caller
    in ``st.session_state``, so caching it globally would just retain memory.
    """
    return FAISS.from_documents(docs, _get_embeddings(embedding_model))


def initialize_llm(model_name: str):
    """Return a deterministic (temperature 0) Groq chat model.

    Reads the API key from the ``GROQ_API_KEY`` environment variable; if it is
    missing, shows an error and halts the Streamlit script run via ``st.stop()``.
    """
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        st.error("🚨 No `GROQ_API_KEY` found. Please add it in Hugging Face Space → Settings → Secrets.")
        st.stop()
    return ChatGroq(model=model_name, api_key=api_key, temperature=0)


def _save_uploads(uploaded_files, dest_dir: str) -> List[str]:
    """Write each uploaded PDF under ``dest_dir`` and return the saved file paths.

    Each file keeps its original name so the ``source`` metadata shown as a
    citation is meaningful, and each gets its own numbered subdirectory so two
    uploads with the same name cannot overwrite each other.
    """
    paths = []
    for index, uploaded in enumerate(uploaded_files):
        file_dir = os.path.join(dest_dir, str(index))
        os.makedirs(file_dir)
        # basename() drops directory parts of the client-supplied name so the
        # file cannot be written outside file_dir.
        file_name = os.path.basename(uploaded.name)
        if file_name in {"", ".", ".."}:
            file_name = f"upload_{index}.pdf"
        path = os.path.join(file_dir, file_name)
        with open(path, "wb") as out:
            # getvalue() returns the full contents regardless of read position.
            out.write(uploaded.getvalue())
        paths.append(path)
    return paths


def _format_sources(source_docs) -> str:
    """Return a sorted, de-duplicated markdown bullet list of source file names."""
    names = sorted({os.path.basename(doc.metadata.get("source", "unknown")) for doc in source_docs})
    return "\n".join(f"- {name}" for name in names)


# -----------------------
# Main UI
# -----------------------
st.title("📚 RAG Chat for Research Papers — Streamlit (Groq)")
st.write("Upload multiple PDFs and ask questions. Answers will include deduplicated file sources.")

uploaded_files = st.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True)
process_btn = st.button("Process uploaded PDFs")

if process_btn:
    if not uploaded_files:
        st.warning("Please upload one or more PDF files first.")
    elif chunk_overlap >= chunk_size:
        st.error("Chunk overlap must be smaller than chunk size.")
    else:
        # The saved PDFs only need to exist while they are parsed; the
        # directory is removed afterwards instead of leaking one file per upload.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_paths = _save_uploads(uploaded_files, tmp_dir)
            st.success("✅ PDFs saved. Processing...")
            with st.spinner("Splitting into chunks..."):
                docs = load_and_split_pdfs(tmp_paths, chunk_size, chunk_overlap)
        st.write(f"✅ Created {len(docs)} chunks.")

        if not docs:
            st.error("No text could be extracted from the uploaded PDFs (for example, scanned PDFs without a text layer).")
        else:
            with st.spinner("Building FAISS vectorstore..."):
                vectorstore = build_vectorstore(docs)
            st.session_state["vectorstore"] = vectorstore
            st.session_state["processed"] = True
            # Bump the index version so the chat section rebuilds its QA chain
            # against this vectorstore instead of a previously built one.
            st.session_state["index_version"] = st.session_state.get("index_version", 0) + 1
            st.success("✅ Vectorstore ready! Ask questions below.")

# -----------------------
# Chat section
# -----------------------
st.markdown("---")
st.subheader("💬 Chat with your papers")

if "processed" not in st.session_state:
    st.info("Process PDFs first to build the index.")
else:
    # Recreate the LLM whenever the sidebar model selection changes.
    if st.session_state.get("llm_model") != model_choice:
        st.session_state["llm"] = initialize_llm(model_choice)
        st.session_state["llm_model"] = model_choice

    # Recreate the QA chain whenever the index, the model or top-k changes, so
    # re-processing PDFs and sidebar changes take effect immediately.
    chain_key = (st.session_state.get("index_version", 0), model_choice, top_k)
    if st.session_state.get("qa_chain_key") != chain_key:
        retriever = st.session_state["vectorstore"].as_retriever(search_kwargs={"k": top_k})
        st.session_state["qa_chain"] = RetrievalQA.from_chain_type(
            llm=st.session_state["llm"],
            retriever=retriever,
            chain_type="stuff",
            return_source_documents=True,
        )
        st.session_state["qa_chain_key"] = chain_key

    if "history" not in st.session_state:
        st.session_state["history"] = []

    query = st.text_input("Enter your question")
    ask = st.button("Ask")

    if ask and query.strip():
        result = None
        with st.spinner("Thinking..."):
            try:
                result = st.session_state["qa_chain"].invoke({"query": query})
            except Exception as exc:  # UI boundary: show API/network failures instead of a traceback
                st.error(f"🚨 Failed to get an answer: {exc}")
        if result is not None:
            answer = result.get("result", "")
            sources_text = _format_sources(result.get("source_documents", []))
            full_answer = f"{answer}\n\n📚 **Sources:**\n{sources_text}" if sources_text else answer
            st.session_state["history"].append((query, full_answer))

    if st.session_state["history"]:
        st.markdown("### 📜 Conversation History")
        for user_msg, bot_msg in reversed(st.session_state["history"]):
            st.markdown(f"**You:** {user_msg}")
            st.markdown(f"**Bot:** {bot_msg}")