Subayyal committed on
Commit 49dfb24 · verified · 1 Parent(s): 25988f9

Create app.py

Files changed (1): app.py (+146, -0)
app.py ADDED
@@ -0,0 +1,146 @@
# app.py
import os
import glob
import tempfile
from typing import List
import streamlit as st

# LangChain / loaders / vectorstore / embeddings / LLM
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain.chains import RetrievalQA

st.set_page_config(page_title="RAG Papers Chat (Groq)", layout="wide")

# -----------------------
# Load custom CSS
# -----------------------
def load_css(path="style.css"):
    if os.path.exists(path):
        with open(path) as f:
            st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

load_css()

# -----------------------
# Sidebar / settings
# -----------------------
st.sidebar.title("⚙️ Settings")
chunk_size = st.sidebar.number_input("Chunk size", min_value=256, max_value=5000, value=1000, step=100)
chunk_overlap = st.sidebar.number_input("Chunk overlap", min_value=0, max_value=1000, value=200, step=50)
top_k = st.sidebar.slider("Top-k chunks to retrieve", min_value=1, max_value=10, value=3)
model_choice = st.sidebar.selectbox(
    "Groq model",
    options=["llama-3.1-8b-instant", "llama3-8b-8192", "mixtral-8x7b-32768"],
    index=0,
)
st.sidebar.markdown("🔑 Your **Groq API key** must be set as a secret (`GROQ_API_KEY`) in Hugging Face Settings.")

# -----------------------
# Utility functions
# -----------------------
@st.cache_data(show_spinner=False)
def load_and_split_pdfs(file_paths: List[str], chunk_size: int, chunk_overlap: int):
    """Load PDFs and return the list of split documents (LangChain docs)."""
    all_docs = []
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    for path in file_paths:
        loader = PyPDFLoader(path)
        loaded = loader.load()
        splitted = splitter.split_documents(loaded)
        all_docs.extend(splitted)
    return all_docs
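# (st.cache_data above keys the cache on the argument values, so re-processing
# the same temp-file paths with identical chunk settings reuses the cached
# chunks instead of re-parsing the PDFs.)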

@st.cache_resource(show_spinner=False)
def build_vectorstore(docs):
    """Create HuggingFace embeddings + FAISS vectorstore."""
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vectorstore = FAISS.from_documents(docs, embeddings)
    return vectorstore
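# (st.cache_resource above keeps the FAISS index alive for the lifetime of the
# process, so repeated Streamlit reruns with the same docs skip re-embedding.)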

def initialize_llm(model_name: str):
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        st.error("🚨 No `GROQ_API_KEY` found. Please add it in Hugging Face Space → Settings → Secrets.")
        st.stop()
    return ChatGroq(model=model_name, api_key=api_key, temperature=0)

# -----------------------
# Main UI
# -----------------------
st.title("📚 RAG Chat for Research Papers - Streamlit (Groq)")
st.write("Upload multiple PDFs and ask questions. Answers will include deduplicated file sources.")

uploaded_files = st.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True)
process_btn = st.button("Process uploaded PDFs")

if process_btn:
    if not uploaded_files:
        st.warning("Please upload one or more PDF files first.")
    else:
        tmp_paths = []
        for f in uploaded_files:
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
            tmp.write(f.read())
            tmp.flush()
            tmp.close()
            tmp_paths.append(tmp.name)

        st.success("✅ PDFs saved. Processing...")

        with st.spinner("Splitting into chunks..."):
            docs = load_and_split_pdfs(tmp_paths, chunk_size, chunk_overlap)
            st.write(f"✅ Created {len(docs)} chunks.")

        with st.spinner("Building FAISS vectorstore..."):
            vectorstore = build_vectorstore(docs)

        st.session_state["vectorstore"] = vectorstore
        st.session_state["processed"] = True
        st.success("✅ Vectorstore ready! Ask questions below.")
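# (delete=False means the temp PDFs stay on disk for the app's lifetime; on a
# typical Hugging Face Space the storage is ephemeral, so they vanish on restart.)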

# -----------------------
# Chat section
# -----------------------
st.markdown("---")
st.subheader("💬 Chat with your papers")

if "processed" not in st.session_state:
    st.info("Process PDFs first to build the index.")
else:
    if "llm" not in st.session_state:
        st.session_state["llm"] = initialize_llm(model_choice)

    if "qa_chain" not in st.session_state:
        retriever = st.session_state["vectorstore"].as_retriever(search_kwargs={"k": top_k})
        st.session_state["qa_chain"] = RetrievalQA.from_chain_type(
            llm=st.session_state["llm"],
            retriever=retriever,
            chain_type="stuff",
            return_source_documents=True,
        )
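    # (The LLM and chain are cached in session_state on first use, so changing
    # the sidebar model or top_k afterwards has no effect until the session resets.)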

    if "history" not in st.session_state:
        st.session_state["history"] = []

    query = st.text_input("Enter your question")
    ask = st.button("Ask")

    if ask and query.strip():
        with st.spinner("Thinking..."):
            result = st.session_state["qa_chain"].invoke({"query": query})

        answer = result.get("result", "")
        source_docs = result.get("source_documents", [])

        unique_sources = list({doc.metadata.get("source", "unknown") for doc in source_docs})
        sources_text = "\n".join([f"- {os.path.basename(s)}" for s in unique_sources])

        full_answer = f"{answer}\n\n📚 **Sources:**\n{sources_text}"
        st.session_state["history"].append((query, full_answer))

    st.markdown("### 📜 Conversation History")
    for user_msg, bot_msg in reversed(st.session_state["history"]):
        st.markdown(f"**You:** {user_msg}")
        st.markdown(f"**Bot:** {bot_msg}")