File size: 5,524 Bytes
49dfb24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# app.py
import os
import glob
import tempfile
from typing import List
import streamlit as st

# LangChain / loaders / vectorstore / embeddings / LLM
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain.chains import RetrievalQA

# Page config must be the first Streamlit call executed in the script.
st.set_page_config(page_title="RAG Papers Chat (Groq)", layout="wide")

# -----------------------
# Load custom CSS
# -----------------------
def load_css(path="style.css"):
    """Inject a local CSS file into the page, if it exists.

    A missing file is silently ignored so the app still runs without
    custom styling.
    """
    if os.path.exists(path):
        # Explicit encoding: a bare open() uses the platform default, which
        # can mis-decode non-ASCII characters in the stylesheet on Windows.
        with open(path, encoding="utf-8") as f:
            st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

load_css()

# -----------------------
# Sidebar / settings
# -----------------------
st.sidebar.title("βš™οΈ Settings")
# Chunking parameters are fed to RecursiveCharacterTextSplitter below.
chunk_size = st.sidebar.number_input("Chunk size", min_value=256, max_value=5000, value=1000, step=100)
chunk_overlap = st.sidebar.number_input("Chunk overlap", min_value=0, max_value=1000, value=200, step=50)
top_k = st.sidebar.slider("Top-k chunks to retrieve", min_value=1, max_value=10, value=3)
# Model IDs must match Groq's published model list exactly.  The previous
# options "llama-3.1-8b-8192" and "mixtral-3b-16384" are not valid Groq IDs
# (they mix up llama3-8b-8192 / mixtral-8x7b-32768) and would fail at query
# time with a model-not-found error.
model_choice = st.sidebar.selectbox(
    "Groq model",
    options=["llama-3.1-8b-instant", "llama3-8b-8192", "mixtral-8x7b-32768"],
    index=0
)
st.sidebar.markdown("πŸ”‘ Your **Groq API key** must be set as a secret (`GROQ_API_KEY`) in Hugging Face Settings.")

# -----------------------
# Utility functions
# -----------------------
@st.cache_data(show_spinner=False)
def load_and_split_pdfs(file_paths: List[str], chunk_size: int, chunk_overlap: int):
    """Read each PDF in *file_paths* and split its pages into overlapping chunks.

    Returns a flat list of LangChain Document chunks ready for embedding.
    Cached by Streamlit on (file_paths, chunk_size, chunk_overlap).
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    chunks = []
    for pdf_path in file_paths:
        pages = PyPDFLoader(pdf_path).load()
        chunks += text_splitter.split_documents(pages)
    return chunks

@st.cache_resource(show_spinner=False)
def build_vectorstore(docs):
    """Embed *docs* with a MiniLM sentence-transformer and index them in FAISS."""
    embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return FAISS.from_documents(docs, embedder)

def initialize_llm(model_name: str):
    """Build a deterministic (temperature=0) ChatGroq client.

    Halts the Streamlit run with an error message when the GROQ_API_KEY
    secret is absent (st.stop() raises, so execution never continues).
    """
    key = os.environ.get("GROQ_API_KEY")
    if key:
        return ChatGroq(model=model_name, api_key=key, temperature=0)
    st.error("🚨 No `GROQ_API_KEY` found. Please add it in Hugging Face Space β†’ Settings β†’ Secrets.")
    st.stop()

# -----------------------
# Main UI
# -----------------------
st.title("πŸ“š RAG Chat for Research Papers β€” Streamlit (Groq)")
st.write("Upload multiple PDFs and ask questions. Answers will include deduplicated file sources.")

uploaded_files = st.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True)
process_btn = st.button("Process uploaded PDFs")

if process_btn:
    if not uploaded_files:
        st.warning("Please upload one or more PDF files first.")
    else:
        # PyPDFLoader needs a filesystem path, so spill each upload to a temp
        # file.  delete=False keeps the file on disk after close(); closing the
        # handle (the original leaked it) also matters on Windows, where an
        # open NamedTemporaryFile cannot be reopened by another reader.
        tmp_paths = []
        for f in uploaded_files:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
                tmp.write(f.read())
                tmp_paths.append(tmp.name)

        st.success("βœ… PDFs saved. Processing...")

        with st.spinner("Splitting into chunks..."):
            docs = load_and_split_pdfs(tmp_paths, chunk_size, chunk_overlap)
            st.write(f"βœ… Created {len(docs)} chunks.")

        with st.spinner("Building FAISS vectorstore..."):
            vectorstore = build_vectorstore(docs)

        # Best-effort cleanup: the chunks are in memory now, so the temp
        # files are no longer needed (the original left them behind).
        for p in tmp_paths:
            try:
                os.remove(p)
            except OSError:
                pass

        st.session_state["vectorstore"] = vectorstore
        st.session_state["processed"] = True
        st.success("βœ… Vectorstore ready! Ask questions below.")

# -----------------------
# Chat section
# -----------------------
st.markdown("---")
st.subheader("πŸ’¬ Chat with your papers")

if "processed" not in st.session_state:
    st.info("Process PDFs first to build the index.")
else:
    # Rebuild the LLM/retriever/chain on every rerun instead of caching them
    # in session_state: the cached versions went permanently stale whenever
    # the user reprocessed PDFs (new vectorstore, old retriever) or changed
    # top_k / the model in the sidebar.  Construction is cheap; only the
    # vectorstore itself is heavy, and that stays in session_state.
    llm = initialize_llm(model_choice)
    retriever = st.session_state["vectorstore"].as_retriever(search_kwargs={"k": top_k})
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        chain_type="stuff",
        return_source_documents=True,
    )

    if "history" not in st.session_state:
        st.session_state["history"] = []

    query = st.text_input("Enter your question")
    ask = st.button("Ask")

    if ask and query.strip():
        with st.spinner("Thinking..."):
            # .invoke() is the current LangChain entry point; calling the
            # chain object directly (qa_chain({...})) is deprecated.
            result = qa_chain.invoke({"query": query})
            answer = result.get("result", "")
            source_docs = result.get("source_documents", [])

            # Deduplicate AND sort so the source list is stable across runs
            # (raw set iteration order is arbitrary).
            unique_sources = sorted({doc.metadata.get("source", "unknown") for doc in source_docs})
            sources_text = "\n".join(f"- {os.path.basename(s)}" for s in unique_sources)

            full_answer = f"{answer}\n\nπŸ“š **Sources:**\n{sources_text}"
            st.session_state["history"].append((query, full_answer))

    st.markdown("### πŸ“œ Conversation History")
    # Newest exchange first.
    for user_msg, bot_msg in reversed(st.session_state["history"]):
        st.markdown(f"**You:** {user_msg}")
        st.markdown(f"**Bot:** {bot_msg}")