Files changed (1) hide show
  1. app.py +161 -67
app.py CHANGED
@@ -9,8 +9,9 @@ from langchain_community.vectorstores import FAISS
9
  from langchain_huggingface import HuggingFacePipeline
10
  from langchain_classic.prompts import PromptTemplate
11
  from langchain_classic.chains import RetrievalQA
12
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
13
  from huggingface_hub import login
 
14
 
15
 
16
  # --- Page Config & Styling ---
@@ -52,6 +53,49 @@ st.markdown("""
52
  [data-testid="stSidebar"] {
53
  padding-bottom: 50px;
54
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  </style>
56
  """, unsafe_allow_html=True)
57
 
@@ -59,18 +103,25 @@ st.markdown("""
59
  if 'qa_chain' not in st.session_state: st.session_state.qa_chain = None
60
  if 'messages' not in st.session_state: st.session_state.messages = []
61
  if 'processing_done' not in st.session_state: st.session_state.processing_done = False
 
 
 
62
 
63
  # --- Authentication (Secrets Only) ---
64
  hf_token = os.environ.get("HF_TOKEN")
65
 
66
- # --- Model Loading (Cached & CPU Optimized) ---
67
 
68
  @st.cache_resource
69
  def load_embedding_model():
70
  """Load the embedding model once to save time."""
71
  try:
72
  # Using a lightweight, fast embedding model
73
- embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
 
 
 
 
74
  return embeddings
75
  except Exception as e:
76
  st.error(f"Error loading embedding model: {e}")
@@ -78,73 +129,93 @@ def load_embedding_model():
78
 
79
  @st.cache_resource
80
  def load_llm_model(token):
81
- """Load the Gemma LLM once."""
82
  try:
83
  login(token=token)
84
- model_id = "google/gemma-2-2b-it"
85
 
86
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
87
 
88
- # Load model to CPU (float32 is safe for CPU stability)
89
  model = AutoModelForCausalLM.from_pretrained(
90
  model_id,
91
  device_map="cpu",
92
- torch_dtype=torch.float32,
 
93
  token=token
94
  )
95
-
96
- pipe = pipeline(
97
- "text-generation",
98
- model=model,
99
- tokenizer=tokenizer,
100
- max_new_tokens=512,
101
- temperature=0.1,
102
- repetition_penalty=1.1,
103
- return_full_text=False
104
- )
105
- return pipe
106
  except Exception as e:
107
- return None
 
108
 
109
- # --- PDF Processing ---
110
- def process_document(uploaded_file, model_pipeline, embedding_model):
111
  try:
112
  # Save temp file
113
  with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
114
  tmp.write(uploaded_file.getvalue())
115
  tmp_path = tmp.name
116
 
117
- # Load & Split
118
  loader = PyPDFLoader(tmp_path)
119
  docs = loader.load()
120
- splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
 
 
 
 
 
121
  chunks = splitter.split_documents(docs)
122
 
123
  # Vector Store (FAISS is faster for in-memory)
124
  vector_store = FAISS.from_documents(chunks, embedding_model)
125
 
126
- # Chain Setup
127
- llm = HuggingFacePipeline(pipeline=model_pipeline)
128
-
129
- template = """<start_of_turn>user
130
- Answer the question based strictly on the context below. Keep answers concise.
131
- Context: {context}
132
- Question: {question}<end_of_turn>
133
- <start_of_turn>model
134
- """
135
- prompt = PromptTemplate(template=template, input_variables=["context", "question"])
136
 
137
- qa_chain = RetrievalQA.from_chain_type(
138
- llm=llm,
139
- retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
140
- chain_type_kwargs={"prompt": prompt},
141
- return_source_documents=True
142
- )
143
- return qa_chain
144
  except Exception as e:
145
  st.error(f"Error processing PDF: {e}")
146
  return None
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  # --- Main Layout ---
149
 
150
  # 1. Sidebar Configuration
@@ -154,7 +225,7 @@ with st.sidebar:
154
 
155
  if not hf_token:
156
  st.error("🚨 **HF_TOKEN missing!**")
157
- st.info("Go to Space Settings -> Repository Secrets and add your Hugging Face Access Token as `HF_TOKEN`.")
158
  st.stop()
159
  else:
160
  st.success("βœ… Huggingface Active")
@@ -166,17 +237,20 @@ with st.sidebar:
166
  process_btn = st.button("πŸš€ Process Document", type="primary", use_container_width=True)
167
 
168
  if process_btn:
169
- with st.spinner("🧠 Analyzing PDF"):
170
  # Load models (cached)
171
- llm_pipeline = load_llm_model(hf_token)
172
  embed_model = load_embedding_model()
173
 
174
- if llm_pipeline and embed_model:
175
- qa_chain = process_document(uploaded_file, llm_pipeline, embed_model)
176
- if qa_chain:
177
- st.session_state.qa_chain = qa_chain
 
 
178
  st.session_state.processing_done = True
179
- st.success("Done! You can now chat.")
 
180
  else:
181
  st.error("Failed to process document.")
182
  else:
@@ -184,13 +258,14 @@ with st.sidebar:
184
 
185
  if st.session_state.processing_done:
186
  st.markdown("---")
 
 
187
  if st.button("πŸ—‘οΈ Clear Chat History", use_container_width=True):
188
  st.session_state.messages = []
189
  st.rerun()
190
 
191
  # 2. Main Chat Area
192
  st.title("πŸ“—πŸ’¬ DocTalk - Chat With PDF")
193
- #st.caption("Powered by Google Gemma-2-2B-IT")
194
 
195
  if st.session_state.processing_done:
196
  # Display History
@@ -205,29 +280,48 @@ if st.session_state.processing_done:
205
  st.markdown(user_input)
206
 
207
  with st.chat_message("assistant"):
208
- with st.spinner("Thinking..."):
209
- try:
210
- response = st.session_state.qa_chain.invoke({"query": user_input})
211
- answer = response['result']
212
-
213
- st.markdown(answer)
214
- st.session_state.messages.append({"role": "assistant", "content": answer})
215
-
216
- # Optional: Show sources
217
- with st.expander("πŸ”Ž View Source Context"):
218
- for doc in response['source_documents']:
219
- st.caption(f"Page {doc.metadata.get('page', '?')}: {doc.page_content[:200]}...")
220
-
221
- except Exception as e:
222
- st.error(f"An error occurred: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  else:
224
  # Empty State
225
  st.info("πŸ‘‹ **Welcome!** Please upload a PDF in the sidebar to begin chatting.")
226
  st.markdown("""
227
  **How it works:**
228
- 1. Upload a PDF document.
229
- 2. Click 'Process Document'.
230
- 3. Ask questions and get answers based strictly on your file.
231
  """)
232
 
233
  # --- Footer ---
 
9
  from langchain_huggingface import HuggingFacePipeline
10
  from langchain_classic.prompts import PromptTemplate
11
  from langchain_classic.chains import RetrievalQA
12
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, TextIteratorStreamer
13
  from huggingface_hub import login
14
+ from threading import Thread
15
 
16
 
17
  # --- Page Config & Styling ---
 
53
  [data-testid="stSidebar"] {
54
  padding-bottom: 50px;
55
  }
56
+ /* Responsive Design */
57
+ @media (max-width: 768px) {
58
+ /* Make sidebar collapsible on mobile */
59
+ [data-testid="stSidebar"] {
60
+ width: 100% !important;
61
+ }
62
+
63
+ /* Adjust chat input for mobile */
64
+ .stChatInput {
65
+ font-size: 16px !important;
66
+ }
67
+
68
+ /* Better spacing on mobile */
69
+ .block-container {
70
+ padding: 1rem !important;
71
+ }
72
+
73
+ /* Footer text smaller on mobile */
74
+ .footer {
75
+ font-size: 12px;
76
+ padding: 8px;
77
+ }
78
+ }
79
+ @media (max-width: 480px) {
80
+ /* Extra small devices */
81
+ h1 {
82
+ font-size: 1.5rem !important;
83
+ }
84
+
85
+ .stButton button {
86
+ font-size: 14px !important;
87
+ }
88
+ }
89
+ /* Touch-friendly buttons */
90
+ .stButton button {
91
+ min-height: 44px;
92
+ padding: 0.5rem 1rem;
93
+ }
94
+ /* Better chat message display on mobile */
95
+ [data-testid="stChatMessage"] {
96
+ max-width: 100%;
97
+ padding: 0.5rem;
98
+ }
99
  </style>
100
  """, unsafe_allow_html=True)
101
 
 
103
  if 'qa_chain' not in st.session_state: st.session_state.qa_chain = None
104
  if 'messages' not in st.session_state: st.session_state.messages = []
105
  if 'processing_done' not in st.session_state: st.session_state.processing_done = False
106
+ if 'vector_store' not in st.session_state: st.session_state.vector_store = None
107
+ if 'model' not in st.session_state: st.session_state.model = None
108
+ if 'tokenizer' not in st.session_state: st.session_state.tokenizer = None
109
 
110
  # --- Authentication (Secrets Only) ---
111
  hf_token = os.environ.get("HF_TOKEN")
112
 
113
+ # --- Model Loading (Cached & Optimized) ---
114
 
115
@st.cache_resource
def load_embedding_model():
    """Build and cache the sentence-embedding model used to index PDF chunks.

    Returns the embeddings object, or falls through to None if loading fails
    (the error is surfaced in the Streamlit UI).
    """
    try:
        # Lightweight MiniLM encoder on CPU; normalized vectors play well
        # with FAISS similarity search.
        return HuggingFaceEmbeddings(
            model_name="all-MiniLM-L6-v2",
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True},
        )
    except Exception as e:
        st.error(f"Error loading embedding model: {e}")
 
129
 
130
@st.cache_resource
def load_llm_model(token):
    """Authenticate with the Hugging Face Hub and load Gemma for streaming.

    Args:
        token: Hugging Face access token (from the HF_TOKEN secret).

    Returns:
        A (model, tokenizer) pair on success, or (None, None) on failure
        (the error is shown in the Streamlit UI).
    """
    repo_id = "google/gemma-2-2b-it"
    try:
        login(token=token)
        gemma_tokenizer = AutoTokenizer.from_pretrained(repo_id, token=token)
        # CPU-only load: float32 for numeric stability on CPU, and
        # low_cpu_mem_usage to cap peak RAM while weights stream in.
        gemma_model = AutoModelForCausalLM.from_pretrained(
            repo_id,
            device_map="cpu",
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
            token=token,
        )
    except Exception as e:
        st.error(f"Error loading LLM: {e}")
        return None, None
    return gemma_model, gemma_tokenizer
152
 
153
+ # --- PDF Processing (Optimized) ---
154
def process_document(uploaded_file, embedding_model):
    """Index an uploaded PDF into an in-memory FAISS vector store.

    Args:
        uploaded_file: Streamlit UploadedFile holding the raw PDF bytes.
        embedding_model: Embedding model used to vectorize the chunks.

    Returns:
        A FAISS vector store built over the document chunks, or None if
        any step fails (the error is shown in the Streamlit UI).
    """
    tmp_path = None
    try:
        # PyPDFLoader needs a real filesystem path, so spill the upload
        # to a temp file first.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
            tmp.write(uploaded_file.getvalue())
            tmp_path = tmp.name

        loader = PyPDFLoader(tmp_path)
        docs = loader.load()

        # Larger chunks with less overlap reduce the chunk count (and
        # therefore embedding time) while keeping enough local context.
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=1500,
            chunk_overlap=50
        )
        chunks = splitter.split_documents(docs)

        # Vector Store (FAISS is faster for in-memory)
        return FAISS.from_documents(chunks, embedding_model)
    except Exception as e:
        st.error(f"Error processing PDF: {e}")
        return None
    finally:
        # Fix: the original only unlinked on the happy path, leaking the
        # temp file whenever loading/splitting/embedding raised. Clean up
        # on success AND failure.
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)
182
 
183
def get_relevant_context(vector_store, question):
    """Fetch the chunks most relevant to *question* from the vector store.

    Returns:
        A (context, docs) pair: the retrieved chunk texts joined with blank
        lines, plus the raw documents (used by the source-citation expander).
    """
    # k=2 keeps the assembled prompt short enough for fast CPU generation.
    matches = vector_store.as_retriever(search_kwargs={"k": 2}).invoke(question)
    joined = "\n\n".join(d.page_content for d in matches)
    return joined, matches
189
+
190
def stream_response(model, tokenizer, prompt):
    """Yield decoded text chunks as the model generates them.

    Runs model.generate on a background thread and streams tokens through
    a TextIteratorStreamer so the UI can render the answer incrementally.
    """
    encoded = tokenizer(prompt, return_tensors="pt")

    # skip_prompt drops the echoed input; skip_special_tokens strips
    # control tokens from the decoded output.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # Sampling config: low temperature + repetition penalty for focused,
    # non-repetitive answers; 200 new tokens caps CPU latency.
    gen_args = {
        **encoded,
        "streamer": streamer,
        "max_new_tokens": 200,
        "temperature": 0.2,
        "top_p": 0.9,
        "repetition_penalty": 1.15,
        "do_sample": True,
    }

    # Generation must run off-thread so this generator can consume the
    # streamer concurrently.
    worker = Thread(target=model.generate, kwargs=gen_args)
    worker.start()

    # The streamer blocks until tokens arrive and stops at end-of-sequence.
    yield from streamer

    worker.join()
218
+
219
  # --- Main Layout ---
220
 
221
  # 1. Sidebar Configuration
 
225
 
226
  if not hf_token:
227
  st.error("🚨 **HF_TOKEN missing!**")
228
+ st.info("Go to Space Settings β†’ Repository Secrets and add your Hugging Face Access Token as `HF_TOKEN`.")
229
  st.stop()
230
  else:
231
  st.success("βœ… Huggingface Active")
 
237
  process_btn = st.button("πŸš€ Process Document", type="primary", use_container_width=True)
238
 
239
  if process_btn:
240
+ with st.spinner("🧠 Analyzing PDF ..."):
241
  # Load models (cached)
242
+ model, tokenizer = load_llm_model(hf_token)
243
  embed_model = load_embedding_model()
244
 
245
+ if model and tokenizer and embed_model:
246
+ vector_store = process_document(uploaded_file, embed_model)
247
+ if vector_store:
248
+ st.session_state.vector_store = vector_store
249
+ st.session_state.model = model
250
+ st.session_state.tokenizer = tokenizer
251
  st.session_state.processing_done = True
252
+ st.success("βœ… Done! You can now chat with streaming responses.")
253
+ st.rerun()
254
  else:
255
  st.error("Failed to process document.")
256
  else:
 
258
 
259
  if st.session_state.processing_done:
260
  st.markdown("---")
261
+ st.info("βœ… Document Processed")
262
+
263
  if st.button("πŸ—‘οΈ Clear Chat History", use_container_width=True):
264
  st.session_state.messages = []
265
  st.rerun()
266
 
267
  # 2. Main Chat Area
268
  st.title("πŸ“—πŸ’¬ DocTalk - Chat With PDF")
 
269
 
270
  if st.session_state.processing_done:
271
  # Display History
 
280
  st.markdown(user_input)
281
 
282
  with st.chat_message("assistant"):
283
+ try:
284
+ # Get relevant context
285
+ context, source_docs = get_relevant_context(st.session_state.vector_store, user_input)
286
+
287
+ # Build prompt
288
+ prompt = f"""<|system|>
289
+ You are a helpful assistant. Answer based only on the context provided. Be concise.</s>
290
+ <|user|>
291
+ Context: {context}
292
+ Question: {user_input}</s>
293
+ <|assistant|>
294
+ """
295
+
296
+ # Stream the response
297
+ response_placeholder = st.empty()
298
+ full_response = ""
299
+
300
+ for chunk in stream_response(st.session_state.model, st.session_state.tokenizer, prompt):
301
+ full_response += chunk
302
+ response_placeholder.markdown(full_response + "β–Œ")
303
+
304
+ # Final update without cursor
305
+ response_placeholder.markdown(full_response)
306
+
307
+ # Save to history
308
+ st.session_state.messages.append({"role": "assistant", "content": full_response})
309
+
310
+ # Optional: Show sources
311
+ with st.expander("πŸ”Ž View Source Context"):
312
+ for i, doc in enumerate(source_docs):
313
+ st.caption(f"**Source {i+1}** (Page {doc.metadata.get('page', '?')}): {doc.page_content[:150]}...")
314
+
315
+ except Exception as e:
316
+ st.error(f"An error occurred: {e}")
317
  else:
318
  # Empty State
319
  st.info("πŸ‘‹ **Welcome!** Please upload a PDF in the sidebar to begin chatting.")
320
  st.markdown("""
321
  **How it works:**
322
+ 1. Upload a PDF document
323
+ 2. Click 'Process Document'
324
+ 3. Ask questions and get **live streaming answers**
325
  """)
326
 
327
  # --- Footer ---