# app.py
import os
import glob
import tempfile
from typing import List
import streamlit as st
# LangChain / loaders / vectorstore / embeddings / LLM
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain.chains import RetrievalQA
# Must run before any other st.* call; sets the browser tab title and wide layout.
st.set_page_config(page_title="RAG Papers Chat (Groq)", layout="wide")
# -----------------------
# Load custom CSS
# -----------------------
def load_css(path: str = "style.css") -> None:
    """Inject a local CSS file into the Streamlit page.

    Silently does nothing when *path* does not exist, so the app still
    runs without a stylesheet.

    Args:
        path: Filesystem path to the CSS file to inline into the page.
    """
    if os.path.exists(path):
        # Explicit encoding: CSS may contain non-ASCII (e.g. emoji in
        # `content:`), and the platform default codec is not guaranteed UTF-8.
        with open(path, encoding="utf-8") as f:
            st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

load_css()
# -----------------------
# Sidebar / settings
# -----------------------
# -----------------------
# Sidebar: retrieval and model knobs exposed to the user.
# -----------------------
st.sidebar.title("⚙️ Settings")
chunk_size = st.sidebar.number_input("Chunk size", min_value=256, max_value=5000, value=1000, step=100)
chunk_overlap = st.sidebar.number_input("Chunk overlap", min_value=0, max_value=1000, value=200, step=50)
top_k = st.sidebar.slider("Top-k chunks to retrieve", min_value=1, max_value=10, value=3)
# NOTE(review): the original options included "llama-3.1-8b-8192" and
# "mixtral-3b-16384", which are not valid Groq model IDs and would fail at
# request time; replaced with currently supported models. Default (index 0)
# is unchanged.
model_choice = st.sidebar.selectbox(
    "Groq model",
    options=["llama-3.1-8b-instant", "llama-3.3-70b-versatile", "gemma2-9b-it"],
    index=0,
)
st.sidebar.markdown("🔑 Your **Groq API key** must be set as a secret (`GROQ_API_KEY`) in Hugging Face Settings.")
# -----------------------
# Utility functions
# -----------------------
@st.cache_data(show_spinner=False)
def load_and_split_pdfs(file_paths: List[str], chunk_size: int, chunk_overlap: int):
    """Load each PDF at *file_paths* and split its pages into overlapping chunks.

    Returns a flat list of LangChain ``Document`` chunks across all files.
    Cached by Streamlit on (paths, chunk_size, chunk_overlap).
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    chunks = []
    for pdf_path in file_paths:
        pages = PyPDFLoader(pdf_path).load()
        chunks.extend(text_splitter.split_documents(pages))
    return chunks
@st.cache_resource(show_spinner=False)
def build_vectorstore(docs):
    """Embed *docs* with a MiniLM sentence-transformer and index them in FAISS.

    Cached as a resource so the index is built once per session/input.
    """
    embedder = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    return FAISS.from_documents(docs, embedder)
def initialize_llm(model_name: str):
    """Build a deterministic (temperature=0) ChatGroq client.

    Reads the API key from the ``GROQ_API_KEY`` environment variable and
    halts the Streamlit script run (``st.stop``) with an error banner when
    the key is missing.

    Args:
        model_name: Groq model identifier to use.
    """
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        # Repaired mojibake in the original user-facing message.
        st.error("🚨 No `GROQ_API_KEY` found. Please add it in Hugging Face Space → Settings → Secrets.")
        st.stop()
    return ChatGroq(model=model_name, api_key=api_key, temperature=0)
# -----------------------
# Main UI: upload PDFs, chunk them, and build the vector index.
# -----------------------
st.title("📚 RAG Chat for Research Papers — Streamlit (Groq)")
st.write("Upload multiple PDFs and ask questions. Answers will include deduplicated file sources.")

uploaded_files = st.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True)
process_btn = st.button("Process uploaded PDFs")

if process_btn:
    if not uploaded_files:
        st.warning("Please upload one or more PDF files first.")
    else:
        # PyPDFLoader needs real filesystem paths, so persist each upload to a
        # temp file. delete=False keeps the file on disk after close();
        # close() releases the handle (the original leaked one open file
        # descriptor per uploaded PDF).
        tmp_paths = []
        for f in uploaded_files:
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
            tmp.write(f.read())
            tmp.close()
            tmp_paths.append(tmp.name)
        st.success("✅ PDFs saved. Processing...")
        with st.spinner("Splitting into chunks..."):
            docs = load_and_split_pdfs(tmp_paths, chunk_size, chunk_overlap)
        st.write(f"✅ Created {len(docs)} chunks.")
        with st.spinner("Building FAISS vectorstore..."):
            vectorstore = build_vectorstore(docs)
        # Keep the index in session_state so it survives Streamlit reruns.
        st.session_state["vectorstore"] = vectorstore
        st.session_state["processed"] = True
        st.success("✅ Vectorstore ready! Ask questions below.")
# -----------------------
# Chat section: question answering over the built index.
# -----------------------
st.markdown("---")
st.subheader("💬 Chat with your papers")

if "processed" not in st.session_state:
    st.info("Process PDFs first to build the index.")
else:
    # Rebuild the LLM / QA chain whenever the sidebar settings change. The
    # original cached both in session_state unconditionally, so switching
    # the Groq model or top-k after the first question silently had no
    # effect.
    settings = (model_choice, top_k)
    if st.session_state.get("qa_settings") != settings:
        st.session_state["llm"] = initialize_llm(model_choice)
        retriever = st.session_state["vectorstore"].as_retriever(search_kwargs={"k": top_k})
        st.session_state["qa_chain"] = RetrievalQA.from_chain_type(
            llm=st.session_state["llm"],
            retriever=retriever,
            chain_type="stuff",
            return_source_documents=True,
        )
        st.session_state["qa_settings"] = settings
    if "history" not in st.session_state:
        st.session_state["history"] = []

    query = st.text_input("Enter your question")
    ask = st.button("Ask")
    if ask and query.strip():
        with st.spinner("Thinking..."):
            result = st.session_state["qa_chain"]({"query": query})
        answer = result.get("result", "")
        source_docs = result.get("source_documents", [])
        # Deduplicate source file paths so each PDF is listed only once.
        unique_sources = list({doc.metadata.get("source", "unknown") for doc in source_docs})
        sources_text = "\n".join(f"- {os.path.basename(s)}" for s in unique_sources)
        full_answer = f"{answer}\n\n📚 **Sources:**\n{sources_text}"
        st.session_state["history"].append((query, full_answer))

    st.markdown("### 📝 Conversation History")
    # Newest exchange first.
    for user_msg, bot_msg in reversed(st.session_state["history"]):
        st.markdown(f"**You:** {user_msg}")
        st.markdown(f"**Bot:** {bot_msg}")
|