# app.py
import os
import tempfile
from typing import List

import streamlit as st

# LangChain / loaders / vectorstore / embeddings / LLM
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain.chains import RetrievalQA
st.set_page_config(page_title="RAG Papers Chat (Groq)", layout="wide")

# -----------------------
# Load custom CSS
# -----------------------
def load_css(path="style.css"):
    if os.path.exists(path):
        with open(path) as f:
            st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

load_css()

# -----------------------
# Sidebar / settings
# -----------------------
st.sidebar.title("⚙️ Settings")
chunk_size = st.sidebar.number_input("Chunk size", min_value=256, max_value=5000, value=1000, step=100)
chunk_overlap = st.sidebar.number_input("Chunk overlap", min_value=0, max_value=1000, value=200, step=50)
top_k = st.sidebar.slider("Top-k chunks to retrieve", min_value=1, max_value=10, value=3)
model_choice = st.sidebar.selectbox(
    "Groq model",
    options=["llama-3.1-8b-instant", "llama3-8b-8192", "mixtral-8x7b-32768"],
    index=0,
)
st.sidebar.markdown("🔑 Your **Groq API key** must be set as a secret (`GROQ_API_KEY`) in Hugging Face Settings.")

# -----------------------
# Utility functions
# -----------------------
def load_and_split_pdfs(file_paths: List[str], chunk_size: int, chunk_overlap: int):
    """Load PDFs and return a list of split documents (LangChain docs)."""
    all_docs = []
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    for path in file_paths:
        loader = PyPDFLoader(path)
        loaded = loader.load()
        splitted = splitter.split_documents(loaded)
        all_docs.extend(splitted)
    return all_docs
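
# Note on chunking: with a non-zero chunk_overlap, adjacent chunks share
# (roughly) that many trailing/leading characters, so a sentence that
# straddles a chunk boundary stays retrievable whole in at least one chunk.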

def build_vectorstore(docs):
    """Create HuggingFace embeddings + FAISS vectorstore."""
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vectorstore = FAISS.from_documents(docs, embeddings)
    return vectorstore
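
# Note: all-MiniLM-L6-v2 produces 384-dimensional embeddings and runs fine on
# CPU, which suits a free Space; the FAISS index here is held in memory only.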

def initialize_llm(model_name: str):
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        st.error("🚨 No `GROQ_API_KEY` found. Please add it in Hugging Face Space → Settings → Secrets.")
        st.stop()
    return ChatGroq(model=model_name, api_key=api_key, temperature=0)
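
# temperature=0 makes generation (near-)deterministic, which keeps answers
# anchored to the retrieved chunks rather than free-form elaboration.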

# -----------------------
# Main UI
# -----------------------
st.title("📚 RAG Chat for Research Papers – Streamlit (Groq)")
st.write("Upload multiple PDFs and ask questions. Answers will include deduplicated file sources.")

uploaded_files = st.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True)
process_btn = st.button("Process uploaded PDFs")

if process_btn:
    if not uploaded_files:
        st.warning("Please upload one or more PDF files first.")
    else:
        # Persist uploads to temp files on disk so PyPDFLoader can open them
        tmp_paths = []
        for f in uploaded_files:
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
            tmp.write(f.read())
            tmp.flush()
            tmp.close()
            tmp_paths.append(tmp.name)
        st.success("✅ PDFs saved. Processing...")
        with st.spinner("Splitting into chunks..."):
            docs = load_and_split_pdfs(tmp_paths, chunk_size, chunk_overlap)
        st.write(f"✅ Created {len(docs)} chunks.")
        with st.spinner("Building FAISS vectorstore..."):
            vectorstore = build_vectorstore(docs)
        st.session_state["vectorstore"] = vectorstore
        st.session_state["processed"] = True
        st.success("✅ Vectorstore ready! Ask questions below.")
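
# The vectorstore lives in st.session_state, so it survives Streamlit's
# script reruns within a browser session but is rebuilt for each new session.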

# -----------------------
# Chat section
# -----------------------
st.markdown("---")
st.subheader("💬 Chat with your papers")

if "processed" not in st.session_state:
    st.info("Process PDFs first to build the index.")
else:
    if "llm" not in st.session_state:
        st.session_state["llm"] = initialize_llm(model_choice)
    if "qa_chain" not in st.session_state:
        retriever = st.session_state["vectorstore"].as_retriever(search_kwargs={"k": top_k})
        st.session_state["qa_chain"] = RetrievalQA.from_chain_type(
            llm=st.session_state["llm"],
            retriever=retriever,
            chain_type="stuff",
            return_source_documents=True,
        )
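        # "stuff" packs all top-k chunks into a single prompt; simple and fine
        # for small k. Because the chain is cached in session_state, later
        # sidebar changes (or re-processing PDFs) won't rebuild it within the
        # same session.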
| if "history" not in st.session_state: | |
| st.session_state["history"] = [] | |

    query = st.text_input("Enter your question")
    ask = st.button("Ask")

    if ask and query.strip():
        with st.spinner("Thinking..."):
            result = st.session_state["qa_chain"].invoke({"query": query})
        answer = result.get("result", "")
        source_docs = result.get("source_documents", [])
        # Deduplicate source file paths across the retrieved chunks
        unique_sources = list({doc.metadata.get("source", "unknown") for doc in source_docs})
        sources_text = "\n".join([f"- {os.path.basename(s)}" for s in unique_sources])
        full_answer = f"{answer}\n\n📚 **Sources:**\n{sources_text}"
        st.session_state["history"].append((query, full_answer))

    st.markdown("### 📜 Conversation History")
    for user_msg, bot_msg in reversed(st.session_state["history"]):
        st.markdown(f"**You:** {user_msg}")
        st.markdown(f"**Bot:** {bot_msg}")