Files changed (1) hide show
  1. app.py +226 -84
app.py CHANGED
@@ -6,11 +6,9 @@ from langchain_community.document_loaders import PyPDFLoader
6
  from langchain_text_splitters import RecursiveCharacterTextSplitter
7
  from langchain_huggingface import HuggingFaceEmbeddings
8
  from langchain_community.vectorstores import FAISS
9
- from langchain_huggingface import HuggingFacePipeline
10
- from langchain_classic.prompts import PromptTemplate
11
- from langchain_classic.chains import RetrievalQA
12
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
13
  from huggingface_hub import login
 
14
 
15
 
16
  # --- Page Config & Styling ---
@@ -52,25 +50,91 @@ st.markdown("""
52
  [data-testid="stSidebar"] {
53
  padding-bottom: 50px;
54
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  </style>
56
  """, unsafe_allow_html=True)
57
 
58
  # --- Session State Management ---
59
- if 'qa_chain' not in st.session_state: st.session_state.qa_chain = None
60
- if 'messages' not in st.session_state: st.session_state.messages = []
61
- if 'processing_done' not in st.session_state: st.session_state.processing_done = False
 
 
 
 
 
 
 
62
 
63
  # --- Authentication (Secrets Only) ---
64
  hf_token = os.environ.get("HF_TOKEN")
65
 
66
- # --- Model Loading (Cached & CPU Optimized) ---
67
 
68
  @st.cache_resource
69
  def load_embedding_model():
70
  """Load the embedding model once to save time."""
71
  try:
72
- # Using a lightweight, fast embedding model
73
- embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
 
 
 
74
  return embeddings
75
  except Exception as e:
76
  st.error(f"Error loading embedding model: {e}")
@@ -78,161 +142,239 @@ def load_embedding_model():
78
 
79
  @st.cache_resource
80
  def load_llm_model(token):
81
- """Load the Gemma LLM once."""
82
  try:
83
  login(token=token)
84
- model_id = "google/gemma-2-2b-it"
85
 
86
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
87
 
88
- # Load model to CPU (float32 is safe for CPU stability)
89
  model = AutoModelForCausalLM.from_pretrained(
90
  model_id,
91
  device_map="cpu",
92
- torch_dtype=torch.float32,
 
93
  token=token
94
  )
95
-
96
- pipe = pipeline(
97
- "text-generation",
98
- model=model,
99
- tokenizer=tokenizer,
100
- max_new_tokens=512,
101
- temperature=0.1,
102
- repetition_penalty=1.1,
103
- return_full_text=False
104
- )
105
- return pipe
106
  except Exception as e:
107
- return None
 
108
 
109
- # --- PDF Processing ---
110
- def process_document(uploaded_file, model_pipeline, embedding_model):
 
111
  try:
112
  # Save temp file
113
  with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
114
  tmp.write(uploaded_file.getvalue())
115
  tmp_path = tmp.name
116
 
117
- # Load & Split
118
  loader = PyPDFLoader(tmp_path)
119
  docs = loader.load()
120
- splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
 
 
 
 
 
 
121
  chunks = splitter.split_documents(docs)
122
 
123
- # Vector Store (FAISS is faster for in-memory)
124
  vector_store = FAISS.from_documents(chunks, embedding_model)
125
 
126
- # Chain Setup
127
- llm = HuggingFacePipeline(pipeline=model_pipeline)
128
-
129
- template = """<start_of_turn>user
130
- Answer the question based strictly on the context below. Keep answers concise.
131
- Context: {context}
132
- Question: {question}<end_of_turn>
133
- <start_of_turn>model
134
- """
135
- prompt = PromptTemplate(template=template, input_variables=["context", "question"])
136
 
137
- qa_chain = RetrievalQA.from_chain_type(
138
- llm=llm,
139
- retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
140
- chain_type_kwargs={"prompt": prompt},
141
- return_source_documents=True
142
- )
143
- return qa_chain
144
  except Exception as e:
145
  st.error(f"Error processing PDF: {e}")
146
  return None
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  # --- Main Layout ---
149
 
150
  # 1. Sidebar Configuration
151
  with st.sidebar:
152
- st.title("πŸ€– Configuration")
153
  st.markdown("---")
154
 
155
  if not hf_token:
156
  st.error("🚨 **HF_TOKEN missing!**")
157
- st.info("Go to Space Settings -> Repository Secrets and add your Hugging Face Access Token as `HF_TOKEN`.")
158
  st.stop()
159
  else:
160
- st.success("βœ… Huggingface Active")
161
 
162
  st.subheader("πŸ“„ Document Upload")
163
- uploaded_file = st.file_uploader("Upload your PDF", type="pdf", help="Max file size ~200MB")
164
 
165
  if uploaded_file:
166
  process_btn = st.button("πŸš€ Process Document", type="primary", use_container_width=True)
167
 
168
  if process_btn:
169
- with st.spinner("🧠 Analyzing PDF"):
170
  # Load models (cached)
171
- llm_pipeline = load_llm_model(hf_token)
172
  embed_model = load_embedding_model()
173
 
174
- if llm_pipeline and embed_model:
175
- qa_chain = process_document(uploaded_file, llm_pipeline, embed_model)
176
- if qa_chain:
177
- st.session_state.qa_chain = qa_chain
 
 
178
  st.session_state.processing_done = True
179
- st.success("Done! You can now chat.")
 
180
  else:
181
- st.error("Failed to process document.")
182
  else:
183
- st.error("Failed to load AI models. Check token permissions.")
184
 
185
  if st.session_state.processing_done:
186
  st.markdown("---")
 
 
 
187
  if st.button("πŸ—‘οΈ Clear Chat History", use_container_width=True):
188
  st.session_state.messages = []
189
  st.rerun()
 
 
 
 
 
 
190
 
191
  # 2. Main Chat Area
192
  st.title("πŸ“—πŸ’¬ DocTalk - Chat With PDF")
193
- #st.caption("Powered by Google Gemma-2-2B-IT")
194
 
195
  if st.session_state.processing_done:
196
- # Display History
197
  for msg in st.session_state.messages:
198
  with st.chat_message(msg["role"]):
199
  st.markdown(msg["content"])
200
 
201
  # Chat Input
202
  if user_input := st.chat_input("Ask a question about your document..."):
 
203
  st.session_state.messages.append({"role": "user", "content": user_input})
204
  with st.chat_message("user"):
205
  st.markdown(user_input)
206
 
 
207
  with st.chat_message("assistant"):
208
- with st.spinner("Thinking..."):
209
- try:
210
- response = st.session_state.qa_chain.invoke({"query": user_input})
211
- answer = response['result']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
- st.markdown(answer)
214
- st.session_state.messages.append({"role": "assistant", "content": answer})
215
 
216
- # Optional: Show sources
217
- with st.expander("πŸ”Ž View Source Context"):
218
- for doc in response['source_documents']:
219
- st.caption(f"Page {doc.metadata.get('page', '?')}: {doc.page_content[:200]}...")
220
-
221
- except Exception as e:
222
- st.error(f"An error occurred: {e}")
 
 
 
 
 
 
 
223
  else:
224
  # Empty State
225
- st.info("πŸ‘‹ **Welcome!** Please upload a PDF in the sidebar to begin chatting.")
226
- st.markdown("""
227
- **How it works:**
228
- 1. Upload a PDF document.
229
- 2. Click 'Process Document'.
230
- 3. Ask questions and get answers based strictly on your file.
231
- """)
 
 
 
 
 
 
 
 
 
 
232
 
233
  # --- Footer ---
234
  st.markdown("""
235
  <div class="footer">
236
- Made with ❀️ with Streamlit and Gemma model, by Tannu Yadav
237
  </div>
238
  """, unsafe_allow_html=True)
 
6
  from langchain_text_splitters import RecursiveCharacterTextSplitter
7
  from langchain_huggingface import HuggingFaceEmbeddings
8
  from langchain_community.vectorstores import FAISS
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 
 
 
10
  from huggingface_hub import login
11
+ from threading import Thread
12
 
13
 
14
  # --- Page Config & Styling ---
 
50
  [data-testid="stSidebar"] {
51
  padding-bottom: 50px;
52
  }
53
+ /* Responsive Design */
54
+ @media (max-width: 768px) {
55
+ /* Make sidebar collapsible on mobile */
56
+ [data-testid="stSidebar"] {
57
+ width: 100% !important;
58
+ }
59
+
60
+ /* Adjust chat input for mobile */
61
+ .stChatInput {
62
+ font-size: 16px !important;
63
+ }
64
+
65
+ /* Better spacing on mobile */
66
+ .block-container {
67
+ padding: 1rem !important;
68
+ }
69
+
70
+ /* Footer text smaller on mobile */
71
+ .footer {
72
+ font-size: 12px;
73
+ padding: 8px;
74
+ }
75
+ }
76
+ @media (max-width: 480px) {
77
+ /* Extra small devices */
78
+ h1 {
79
+ font-size: 1.5rem !important;
80
+ }
81
+
82
+ .stButton button {
83
+ font-size: 14px !important;
84
+ }
85
+ }
86
+ /* Touch-friendly buttons */
87
+ .stButton button {
88
+ min-height: 44px;
89
+ padding: 0.5rem 1rem;
90
+ }
91
+ /* Better chat message display on mobile */
92
+ [data-testid="stChatMessage"] {
93
+ max-width: 100%;
94
+ padding: 0.5rem;
95
+ }
96
+ /* Animated typing indicator */
97
+ @keyframes blink {
98
+ 0%, 49% { opacity: 1; }
99
+ 50%, 100% { opacity: 0; }
100
+ }
101
+ @keyframes pulse {
102
+ 0%, 100% { transform: scale(1); opacity: 1; }
103
+ 50% { transform: scale(1.2); opacity: 0.7; }
104
+ }
105
+ @keyframes shimmer {
106
+ 0% { background-position: -100% 0; }
107
+ 100% { background-position: 100% 0; }
108
+ }
109
  </style>
110
  """, unsafe_allow_html=True)
111
 
112
# --- Session State Management ---
# Seed every key the app reads so later attribute access never fails across reruns.
_STATE_DEFAULTS = (
    ("messages", []),          # chat transcript: list of {"role", "content"} dicts
    ("processing_done", False),  # flips True once a PDF has been indexed
    ("vector_store", None),    # FAISS index built from the uploaded PDF
    ("model", None),           # cached causal LM
    ("tokenizer", None),       # cached tokenizer paired with the model
)
for _key, _default in _STATE_DEFAULTS:
    if _key not in st.session_state:
        st.session_state[_key] = _default
123
 
124
  # --- Authentication (Secrets Only) ---
125
  hf_token = os.environ.get("HF_TOKEN")
126
 
127
+ # --- Model Loading (Cached & Optimized) ---
128
 
129
  @st.cache_resource
130
  def load_embedding_model():
131
  """Load the embedding model once to save time."""
132
  try:
133
+ embeddings = HuggingFaceEmbeddings(
134
+ model_name="all-MiniLM-L6-v2",
135
+ model_kwargs={'device': 'cpu'},
136
+ encode_kwargs={'normalize_embeddings': True}
137
+ )
138
  return embeddings
139
  except Exception as e:
140
  st.error(f"Error loading embedding model: {e}")
 
142
 
143
@st.cache_resource
def load_llm_model(token):
    """Authenticate with the Hub and load Gemma once (cached by Streamlit).

    Returns a ``(model, tokenizer)`` pair for streaming generation, or
    ``(None, None)`` if anything in the download/load fails.
    """
    model_id = "google/gemma-2-2b-it"
    try:
        login(token=token)
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)

        # CPU-only deployment: float32 is the safe CPU dtype, and
        # low_cpu_mem_usage streams weights to reduce peak RAM during load.
        load_kwargs = {
            "device_map": "cpu",
            "torch_dtype": torch.float32,
            "low_cpu_mem_usage": True,
            "token": token,
        }
        model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading LLM: {e}")
        return None, None
165
 
166
# --- PDF Processing (Optimized for better accuracy) ---
def process_document(uploaded_file, embedding_model):
    """Turn an uploaded PDF into a FAISS vector store.

    Parameters:
        uploaded_file: Streamlit ``UploadedFile`` holding the raw PDF bytes.
        embedding_model: embeddings instance used to vectorize the chunks.

    Returns:
        The populated FAISS vector store, or ``None`` on any failure
        (the error is shown in the UI via ``st.error``).
    """
    tmp_path = None
    try:
        # PyPDFLoader needs a real filesystem path, so spill the upload
        # to a temp file first (delete=False: we unlink it ourselves).
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
            tmp.write(uploaded_file.getvalue())
            tmp_path = tmp.name

        loader = PyPDFLoader(tmp_path)
        docs = loader.load()

        # Balanced chunking for better accuracy.
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=100,
            separators=["\n\n", "\n", " ", ""]
        )
        chunks = splitter.split_documents(docs)

        return FAISS.from_documents(chunks, embedding_model)
    except Exception as e:
        st.error(f"Error processing PDF: {e}")
        return None
    finally:
        # Bug fix: previously os.unlink ran only on the success path, so the
        # temp PDF leaked whenever loading/splitting/indexing raised.
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)
197
 
198
def get_relevant_context(vector_store, question):
    """Fetch the top-3 chunks relevant to *question* from the vector store.

    Returns ``(context_text, docs)`` where ``context_text`` joins the chunk
    texts with blank lines; ``("", [])`` when retrieval fails.
    """
    try:
        retriever = vector_store.as_retriever(search_kwargs={"k": 3})
        hits = retriever.invoke(question)
        joined = "\n\n".join(doc.page_content for doc in hits)
    except Exception as e:
        st.error(f"Error retrieving context: {e}")
        return "", []
    return joined, hits
208
+
209
def stream_response(model, tokenizer, prompt):
    """Yield the model's answer incrementally, token chunk by token chunk.

    ``model.generate`` runs on a background thread and writes into a
    TextIteratorStreamer; this generator drains the streamer so the caller
    can render partial output. On failure a single error string is yielded.
    """
    try:
        encoded = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)

        # skip_prompt: don't echo the input back; skip_special_tokens: clean text only.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        gen_kwargs = {
            **encoded,
            "streamer": streamer,
            "max_new_tokens": 512,
            "temperature": 0.3,
            "top_p": 0.95,
            "repetition_penalty": 1.1,
            "do_sample": True,
            "pad_token_id": tokenizer.eos_token_id,
        }

        worker = Thread(target=model.generate, kwargs=gen_kwargs)
        worker.start()

        # Drain the streamer on this (main) thread as tokens arrive.
        yield from streamer

        worker.join()
    except Exception as e:
        yield f"Error generating response: {e}"
241
+
242
# --- Main Layout ---

# 1. Sidebar Configuration
with st.sidebar:
    st.title("Configuration")
    st.markdown("---")

    # Hard requirement: without a token we cannot load Gemma at all.
    if not hf_token:
        st.error("🚨 **HF_TOKEN missing!**")
        st.info("Go to Space Settings β†’ Repository Secrets and add your Hugging Face Access Token as `HF_TOKEN`.")
        st.stop()
    else:
        st.success("βœ… Hugging Face Connected")

    st.subheader("πŸ“„ Document Upload")
    uploaded_file = st.file_uploader("Upload your PDF", type="pdf", help="Upload a PDF document to chat with")

    if uploaded_file:
        if st.button("πŸš€ Process Document", type="primary", use_container_width=True):
            with st.spinner("🧠 Analyzing PDF document..."):
                # Both loaders are @st.cache_resource, so repeat clicks are cheap.
                model, tokenizer = load_llm_model(hf_token)
                embed_model = load_embedding_model()

                if not (model and tokenizer and embed_model):
                    st.error("❌ Failed to load AI models. Check your token permissions.")
                else:
                    vector_store = process_document(uploaded_file, embed_model)
                    if not vector_store:
                        st.error("❌ Failed to process document. Please try again.")
                    else:
                        # Stash everything the chat loop needs, then rerun so
                        # the chat UI appears immediately.
                        st.session_state.vector_store = vector_store
                        st.session_state.model = model
                        st.session_state.tokenizer = tokenizer
                        st.session_state.processing_done = True
                        st.success("βœ… Document processed! Start chatting below.")
                        st.rerun()

    if st.session_state.processing_done:
        st.markdown("---")
        st.success("βœ… Start Chatting")
        st.info(f"πŸ“„ **{uploaded_file.name if uploaded_file else 'Document'}** loaded")

        if st.button("πŸ—‘οΈ Clear Chat History", use_container_width=True):
            st.session_state.messages = []
            st.rerun()

        if st.button("πŸ”„ Upload New Document", use_container_width=True):
            st.session_state.processing_done = False
            st.session_state.vector_store = None
            st.session_state.messages = []
            st.rerun()
296
 
297
# 2. Main Chat Area
st.title("πŸ“—πŸ’¬ DocTalk - Chat With PDF")

if st.session_state.processing_done:
    # Replay the transcript so history survives Streamlit reruns.
    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])

    # Chat Input
    if user_input := st.chat_input("Ask a question about your document..."):
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.markdown(user_input)

        # Generate assistant response
        with st.chat_message("assistant"):
            try:
                context, source_docs = get_relevant_context(st.session_state.vector_store, user_input)

                if not context:
                    st.warning("⚠️ Could not find relevant information in the document.")
                else:
                    # Gemma chat template: user turn with context, then model turn.
                    prompt = f"""<start_of_turn>user
Answer the question based strictly on the context below. Be concise and accurate.
Context: {context}
Question: {user_input}<end_of_turn>
<start_of_turn>model
"""
                    # Stream tokens into a placeholder, with a blinking cursor
                    # while generation is in flight.
                    cursor_html = " <span style='animation: blink 1s infinite; color: #00d4ff; font-weight: bold;'>✍</span>"
                    response_placeholder = st.empty()
                    full_response = ""
                    for chunk in stream_response(st.session_state.model, st.session_state.tokenizer, prompt):
                        full_response += chunk
                        response_placeholder.markdown(full_response + cursor_html, unsafe_allow_html=True)

                    # Final render without the cursor, then persist to history.
                    response_placeholder.markdown(full_response)
                    st.session_state.messages.append({"role": "assistant", "content": full_response})

                    # Show the retrieved chunks that grounded the answer.
                    if source_docs:
                        with st.expander("πŸ”Ž View Source Context"):
                            for i, doc in enumerate(source_docs, start=1):
                                st.markdown(f"**Source {i}** (Page {doc.metadata.get('page', 'Unknown')})")
                                st.caption(doc.page_content[:300] + "..." if len(doc.page_content) > 300 else doc.page_content)
                                st.markdown("---")
            except Exception as e:
                st.error(f"❌ An error occurred: {e}")
                st.info("Please try asking your question again or upload a new document.")
else:
    # Empty state: onboarding instructions before any document is processed.
    st.info("πŸ‘‹ **Welcome to DocTalk!** Upload a PDF document in the sidebar to begin chatting.")

    col1, col2, col3 = st.columns(3)

    with col1:
        st.markdown("### πŸ“€ Upload")
        st.markdown("Upload your PDF document using the sidebar")

    with col2:
        st.markdown("### πŸ”„ Process")
        st.markdown("Click 'Process Document' to analyze it")

    with col3:
        st.markdown("### πŸ’¬ Chat")
        st.markdown("Ask questions and get instant answers")

    st.markdown("---")
374
 
375
# --- Footer ---
# Rendered as raw HTML so the .footer CSS class from the style block applies.
_FOOTER_HTML = """
<div class="footer">
    Made with ❀️ using Streamlit and Gemma model, by Tannu Yadav
</div>
"""
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)