aaporosh commited on
Commit
56d0815
·
verified ·
1 Parent(s): 058a20c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +413 -399
app.py CHANGED
@@ -2,410 +2,424 @@ import streamlit as st
2
  import logging
3
  import os
4
  from io import BytesIO
5
- import pdfplumber
6
- from PIL import Image
7
- import pytesseract
8
- from langchain.text_splitter import CharacterTextSplitter
9
- from langchain_community.vectorstores import FAISS
10
- from sentence_transformers import SentenceTransformer
11
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments
12
- from datasets import load_dataset
13
- from rank_bm25 import BM25Okapi
14
- from rouge_score import rouge_scorer
15
  import re
16
  import time
 
 
 
17
 
18
- # Setup logging for Spaces
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
20
- logger = logging.getLogger(__name__)
21
-
22
- # Lazy load models
23
- @st.cache_resource(ttl=1800)
24
- def load_embeddings_model():
25
- logger.info("Loading embeddings model")
26
- try:
27
- return SentenceTransformer("all-MiniLM-L6-v2")
28
- except Exception as e:
29
- logger.error(f"Embeddings load error: {str(e)}")
30
- st.error(f"Embedding model error: {str(e)}")
31
- return None
32
-
33
- @st.cache_resource(ttl=1800)
34
- def load_qa_pipeline():
35
- logger.info("Loading QA pipeline")
36
- try:
37
- dataset = load_and_prepare_dataset()
38
- if dataset:
39
- fine_tuned_pipeline = fine_tune_qa_model(dataset)
40
- if fine_tuned_pipeline:
41
- return fine_tuned_pipeline
42
- return pipeline("text2text-generation", model="google/flan-t5-small", max_length=300)
43
- except Exception as e:
44
- logger.error(f"QA model load error: {str(e)}")
45
- st.error(f"QA model error: {str(e)}")
46
- return None
47
-
48
- @st.cache_resource(ttl=1800)
49
- def load_summary_pipeline():
50
- logger.info("Loading summary pipeline")
51
- try:
52
- return pipeline("summarization", model="facebook/bart-large-cnn", max_length=250)
53
- except Exception as e:
54
- logger.error(f"Summary model load error: {str(e)}")
55
- st.error(f"Summary model error: {str(e)}")
56
- return None
57
-
58
- # Load and prepare dataset (e.g., SQuAD)
59
- @st.cache_data(ttl=3600)
60
- def load_and_prepare_dataset(dataset_name="squad", max_samples=1000):
61
- logger.info(f"Loading dataset: {dataset_name}")
62
- try:
63
- dataset = load_dataset(dataset_name, split="train[:80%]")
64
- dataset = dataset.shuffle(seed=42).select(range(min(max_samples, len(dataset))))
65
-
66
- def preprocess(examples):
67
- inputs = [f"question: {q} context: {c}" for q, c in zip(examples['question'], examples['context'])]
68
- targets = examples['answers']['text']
69
- return {'input_text': inputs, 'target_text': [t[0] if t else "" for t in targets]}
70
-
71
- dataset = dataset.map(preprocess, batched=True, remove_columns=dataset.column_names)
72
- return dataset
73
- except Exception as e:
74
- logger.error(f"Dataset load error: {str(e)}")
75
- return None
76
-
77
- # Fine-tune QA model
78
- @st.cache_resource(ttl=3600)
79
- def fine_tune_qa_model(dataset):
80
- logger.info("Starting fine-tuning")
81
- try:
82
- model_name = "google/flan-t5-small"
83
- tokenizer = AutoTokenizer.from_pretrained(model_name)
84
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
85
-
86
- def tokenize_function(examples):
87
- model_inputs = tokenizer(examples['input_text'], max_length=512, truncation=True, padding="max_length")
88
- labels = tokenizer(examples['target_text'], max_length=128, truncation=True, padding="max_length")
89
- model_inputs["labels"] = labels["input_ids"]
90
- return model_inputs
91
-
92
- tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=['input_text', 'target_text'])
93
-
94
- training_args = TrainingArguments(
95
- output_dir="./fine_tuned_model",
96
- num_train_epochs = 2,
97
- per_device_train_batch_size=4,
98
- save_steps=500,
99
- logging_steps=100,
100
- evaluation_strategy="no",
101
- learning_rate=3e-5,
102
- fp16=False,
103
- )
104
-
105
- trainer = Trainer(
106
- model=model,
107
- args=training_args,
108
- train_dataset=tokenized_dataset,
109
- )
110
- trainer.train()
111
-
112
- model.save_pretrained("./fine_tuned_model")
113
- tokenizer.save_pretrained("./fine_tuned_model")
114
- logger.info("Fine-tuning complete")
115
- return pipeline("text2text-generation", model="./fine_tuned_model", tokenizer="./fine_tuned_model", max_length=300)
116
- except Exception as e:
117
- logger.error(f"Fine-tuning error: {str(e)}")
118
- return None
119
-
120
- # Augment vector store with dataset
121
- def augment_vector_store(vector_store, dataset_name="squad", max_samples=300):
122
- logger.info(f"Augmenting vector store with dataset: {dataset_name}")
123
- try:
124
- dataset = load_dataset(dataset_name, split="train").select(range(min(max_samples, len(dataset))))
125
- chunks = [f"Context: {c}\nAnswer: {a['text'][0]}" for c, a in zip(dataset['context'], dataset['answers'])]
126
- embeddings_model = load_embeddings_model()
127
- if embeddings_model and vector_store:
128
- embeddings = embeddings_model.encode(chunks, batch_size=128, show_progress_bar=False)
129
- vector_store.add_embeddings(zip(chunks, embeddings))
130
- return vector_store
131
- except Exception as e:
132
- logger.error(f"Vector store augmentation error: {str(e)}")
133
- return vector_store
134
-
135
- # Process PDF with enhanced extraction and OCR fallback
136
- def process_pdf(uploaded_file):
137
- logger.info("Processing PDF with enhanced extraction")
138
- try:
139
- text = ""
140
- code_blocks = []
141
- with pdfplumber.open(BytesIO(uploaded_file.getvalue())) as pdf:
142
- for page in pdf.pages[:8]:
143
- extracted = page.extract_text(layout=False)
144
- if not extracted:
145
- try:
146
- img = page.to_image(resolution=150).original
147
- extracted = pytesseract.image_to_string(img, config='--psm 6')
148
- except Exception as ocr_e:
149
- logger.warning(f"OCR failed: {str(ocr_e)}")
150
- if extracted:
151
- # Clean text: remove headers/footers (simple heuristic)
152
- lines = extracted.split("\n")
153
- cleaned_lines = [line for line in lines if not re.match(r'^\s*(Page \d+|.*\d{4}-\d{4}|Copyright.*)\s*$', line, re.I)]
154
- text += "\n".join(cleaned_lines) + "\n"
155
- for char in page.chars:
156
- if 'fontname' in char and 'mono' in char['fontname'].lower():
157
- code_blocks.append(char['text'])
158
- code_text = page.extract_text()
159
- code_matches = re.finditer(r'(^\s{2,}.*?(?:\n\s{2,}.*?)*)', code_text, re.MULTILINE)
160
- for match in code_matches:
161
- code_blocks.append(match.group().strip())
162
- tables = page.extract_tables()
163
- if tables:
164
- for table in tables:
165
- text += "\n".join([" | ".join(map(str, row)) for row in table if row]) + "\n"
166
- for obj in page.extract_words():
167
- if obj.get('size', 0) > 12:
168
- text += f"\n{obj['text']}\n"
169
-
170
- code_text = "\n".join(code_blocks).strip()
171
- if not text:
172
- raise ValueError("No text extracted from PDF")
173
-
174
- text_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=250, chunk_overlap=40, keep_separator=True)
175
- text_chunks = text_splitter.split_text(text)[:25]
176
- code_chunks = text_splitter.split_text(code_text)[:10] if code_text else []
177
-
178
- embeddings_model = load_embeddings_model()
179
- if not embeddings_model:
180
- return None, None, text, code_text
181
-
182
- text_vector_store = FAISS.from_embeddings(
183
- zip(text_chunks, [embeddings_model.encode(chunk, show_progress_bar=False, batch_size=128) for chunk in text_chunks]),
184
- embeddings_model.encode
185
- ) if text_chunks else None
186
- code_vector_store = FAISS.from_embeddings(
187
- zip(code_chunks, [embeddings_model.encode(chunk, show_progress_bar=False, batch_size=128) for chunk in code_chunks]),
188
- embeddings_model.encode
189
- ) if code_chunks else None
190
-
191
- if text_vector_store:
192
- text_vector_store = augment_vector_store(text_vector_store)
193
-
194
- logger.info("PDF processed successfully")
195
- return text_vector_store, code_vector_store, text, code_text
196
- except Exception as e:
197
- logger.error(f"PDF processing error: {str(e)}")
198
- st.error(f"PDF error: {str(e)}")
199
- return None, None, "", ""
200
-
201
- # Summarize PDF with ROUGE metrics and improved topic focus
202
- def summarize_pdf(text):
203
- logger.info("Generating summary")
204
- try:
205
- summary_pipeline = load_summary_pipeline()
206
- if not summary_pipeline:
207
- return "Summary model unavailable."
208
-
209
- text_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=250, chunk_overlap=40)
210
- chunks = text_splitter.split_text(text)
211
-
212
- # Hybrid search for relevant chunks
213
- embeddings_model = load_embeddings_model()
214
- if embeddings_model and chunks:
215
- temp_vector_store = FAISS.from_embeddings(
216
- zip(chunks, [embeddings_model.encode(chunk, show_progress_bar=False) for chunk in chunks]),
217
- embeddings_model.encode
218
- )
219
- bm25 = BM25Okapi([chunk.split() for chunk in chunks])
220
- query = "main topic and key points"
221
- bm25_docs = bm25.get_top_n(query.split(), chunks, n=4)
222
- faiss_docs = temp_vector_store.similarity_search(query, k=4)
223
- selected_chunks = list(set(bm25_docs + [doc.page_content for doc in faiss_docs]))[:4]
224
  else:
225
- selected_chunks = chunks[:4]
226
-
227
- summaries = []
228
- for chunk in selected_chunks:
229
- summary = summary_pipeline(f"Summarize the main topic and key points in detail: {chunk[:250]}", max_length=100, min_length=50, do_sample=False)[0]['summary_text']
230
- summaries.append(summary.strip())
231
-
232
- combined_summary = " ".join(summaries)
233
- if len(combined_summary.split()) > 250:
234
- combined_summary = " ".join(combined_summary.split()[:250])
235
-
236
- word_count = len(combined_summary.split())
237
- scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
238
- scores = scorer.score(text[:500], combined_summary)
239
- logger.info(f"ROUGE scores: {scores}")
240
-
241
- return f"**Main Topic Summary** ({word_count} words):\n{combined_summary}\n\n**ROUGE-1**: {scores['rouge1'].fmeasure:.2f}"
242
- except Exception as e:
243
- logger.error(f"Summary error: {str(e)}")
244
- return f"Oops, something went wrong summarizing: {str(e)}"
245
-
246
- # Answer question with hybrid search
247
- def answer_question(text_vector_store, code_vector_store, query):
248
- logger.info(f"Processing query: {query}")
249
- try:
250
- if not text_vector_store and not code_vector_store:
251
- return "Please upload a PDF first!"
252
-
253
- qa_pipeline = load_qa_pipeline()
254
- if not qa_pipeline:
255
- return "Sorry, the QA model is unavailable right now."
256
-
257
- is_code_query = any(keyword in query.lower() for keyword in ["code", "script", "function", "programming", "give me code", "show code"])
258
- if is_code_query and code_vector_store:
259
- docs = code_vector_store.similarity_search(query, k=3)
260
- code = "\n".join(doc.page_content for doc in docs)
261
- explanation = qa_pipeline(f"Explain this code: {code[:500]}")[0]['generated_text']
262
- return f"**Code**:\n```python\n{code}\n```\n**Explanation**:\n{explanation}"
263
-
264
- vector_store = text_vector_store
265
- if not vector_store:
266
- return "No relevant content found for your query."
267
-
268
- # Hybrid search: FAISS + BM25
269
- text_chunks = [doc.page_content for doc in vector_store.similarity_search(query, k=10)]
270
- bm25 = BM25Okapi([chunk.split() for chunk in text_chunks])
271
- bm25_docs = bm25.get_top_n(query.split(), text_chunks, n=5)
272
- faiss_docs = vector_store.similarity_search(query, k=5)
273
- combined_docs = list(set(bm25_docs + [doc.page_content for doc in faiss_docs]))[:5]
274
- context = "\n".join(combined_docs)
275
-
276
- prompt = f"Use the following PDF content to answer the question accurately and concisely. Avoid speculation and focus on the provided context:\n\n{context}\n\nQuestion: {query}\nAnswer:"
277
- response = qa_pipeline(prompt)[0]['generated_text']
278
- logger.info("Answer generated")
279
- return f"**Answer**:\n{response.strip()}\n\n**Source Context**:\n{context[:500]}..."
280
- except Exception as e:
281
- logger.error(f"Query error: {str(e)}")
282
- return f"Sorry, something went wrong: {str(e)}"
283
-
284
- # Streamlit UI
285
- try:
286
- st.set_page_config(page_title="Smart PDF Q&A", page_icon="📄", layout="wide")
287
- st.markdown("""
288
- <style>
289
- .main { max-width: 900px; margin: 0 auto; padding: 20px; }
290
- .sidebar { background-color: #f8f9fa; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
291
- .chat-container { border: 1px solid #ddd; border-radius: 12px; padding: 15px; height: 60vh; overflow-y: auto; margin-top: 20px; background-color: #fafafa; }
292
- .stChatMessage { border-radius: 12px; padding: 12px; margin: 8px; max-width: 75%; transition: all 0.3s ease; }
293
- .user { background-color: #e6f3ff; align-self: flex-end; border: 1px solid #b3d4fc; }
294
- .assistant { background-color: #f0f0f0; border: 1px solid #ccc; }
295
- .dark .user { background-color: #2a2a72; color: #fff; border: 1px solid #4a4ab2; }
296
- .dark .assistant { background-color: #2e2e2e; color: #fff; border: 1px solid #4a4a4a; }
297
- .stButton>button { background-color: #4CAF50; color: white; border: none; padding: 10px 20px; border-radius: 8px; font-weight: bold; }
298
- .stButton>button:hover { background-color: #45a049; transform: scale(1.05); }
299
- pre { background-color: #f8f8f8; padding: 12px; border-radius: 8px; overflow-x: auto; }
300
- .header { background: linear-gradient(90deg, #4CAF50, #81C784); color: white; padding: 15px; border-radius: 8px; text-align: center; box-shadow: 0 2px 4px rgba(0,0,0,0.2); }
301
- .progress-bar { background-color: #e0e0e0; border-radius: 5px; height: 10px; }
302
- .progress-fill { background-color: #4CAF50; height: 100%; border-radius: 5px; transition: width 0.5s ease; }
303
- </style>
304
- """, unsafe_allow_html=True)
305
-
306
- st.markdown('<div class="header"><h1>Smart PDF Q&A</h1></div>', unsafe_allow_html=True)
307
- st.markdown("Upload a PDF to ask questions, get a ~200-word summary, or extract code with 'give me code'. Optimized for speed and accuracy!")
308
-
309
- # Initialize session state
310
- if "messages" not in st.session_state:
311
- st.session_state.messages = []
312
- if "text_vector_store" not in st.session_state:
313
- st.session_state.text_vector_store = None
314
- if "code_vector_store" not in st.session_state:
315
- st.session_state.code_vector_store = None
316
- if "pdf_text" not in st.session_state:
317
- st.session_state.pdf_text = ""
318
- if "code_text" not in st.session_state:
319
- st.session_state.code_text = ""
320
-
321
- # Sidebar with controls
322
- with st.sidebar:
323
- st.markdown('<div class="sidebar">', unsafe_allow_html=True)
324
- theme = st.radio("Theme", ["Light", "Dark"], index=0)
325
- dataset_name = st.selectbox("Select Dataset for Fine-Tuning", ["squad", "cnn_dailymail", "bigcode/the-stack"], index=0)
326
- if st.button("Fine-Tune Model"):
327
- progress_bar = st.progress(0)
328
- for i in range(100):
329
- time.sleep(0.008)
330
- progress_bar.progress(i + 1)
331
- dataset = load_and_prepare_dataset(dataset_name=dataset_name)
332
- if dataset:
333
- fine_tuned_pipeline = fine_tune_qa_model(dataset)
334
- if fine_tuned_pipeline:
335
- st.success("Model fine-tuned successfully!")
336
- else:
337
- st.error("Fine-tuning failed.")
338
- if st.button("Clear Chat"):
339
- st.session_state.messages = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  st.experimental_rerun()
341
- if st.button("Retry Summarization") and st.session_state.pdf_text:
342
- progress_bar = st.progress(0)
343
- with st.spinner("Retrying summarization..."):
344
- for i in range(100):
345
- time.sleep(0.008)
346
- progress_bar.progress(i + 1)
347
- summary = summarize_pdf(st.session_state.pdf_text)
348
- st.session_state.messages.append({"role": "assistant", "content": summary})
349
- st.markdown(summary, unsafe_allow_html=True)
350
- st.markdown('</div>', unsafe_allow_html=True)
351
-
352
- # PDF upload and processing
353
- uploaded_file = st.file_uploader("Upload a PDF", type=["pdf"])
354
- col1, col2 = st.columns([1, 1])
355
- with col1:
356
- if st.button("Process PDF"):
357
- progress_bar = st.progress(0)
358
- with st.spinner("Processing PDF..."):
359
- for i in range(100):
360
- time.sleep(0.02)
361
- progress_bar.progress(i + 1)
362
- st.session_state.text_vector_store, st.session_state.code_vector_store, st.session_state.pdf_text, st.session_state.code_text = process_pdf(uploaded_file)
363
- if st.session_state.text_vector_store or st.session_state.code_vector_store:
364
- st.success("PDF processed! Ask away or summarize.")
365
- st.session_state.messages = []
366
- else:
367
- st.error("Failed to process PDF.")
368
- with col2:
369
- if st.button("Summarize PDF") and st.session_state.pdf_text:
370
- progress_bar = st.progress(0)
371
- with st.spinner("Summarizing..."):
372
- for i in range(100):
373
- time.sleep(0.008)
374
- progress_bar.progress(i + 1)
375
- summary = summarize_pdf(st.session_state.pdf_text)
376
- st.session_state.messages.append({"role": "assistant", "content": summary})
377
- st.markdown(summary, unsafe_allow_html=True)
378
-
379
- # Chat interface
380
- st.markdown('<div class="chat-container">', unsafe_allow_html=True)
381
- if st.session_state.text_vector_store or st.session_state.code_vector_store:
382
- prompt = st.chat_input("Ask a question (e.g., 'Give me code' or 'What’s the main idea?'):")
383
- if prompt:
384
- st.session_state.messages.append({"role": "user", "content": prompt})
385
  with st.chat_message("user"):
386
- st.markdown(prompt)
 
387
  with st.chat_message("assistant"):
388
- progress_bar = st.progress(0)
389
- with st.spinner('<div class="spinner">⏳ Processing...</div>'):
390
- for i in range(100):
391
- time.sleep(0.004)
392
- progress_bar.progress(i + 1)
393
- answer = answer_question(st.session_state.text_vector_store, st.session_state.code_vector_store, prompt)
394
- st.markdown(answer, unsafe_allow_html=True)
395
- st.session_state.messages.append({"role": "assistant", "content": answer})
396
-
397
- # Display chat history
398
- for message in st.session_state.messages:
399
- with st.chat_message(message["role"]):
400
- st.markdown(message["content"], unsafe_allow_html=True)
401
-
402
- st.markdown('</div>', unsafe_allow_html=True)
403
-
404
- # Download chat history
405
- if st.session_state.messages:
406
- chat_text = "\n".join(f"{m['role'].capitalize()}: {m['content']}" for m in st.session_state.messages)
407
- st.download_button("Download Chat History", chat_text, "chat_history.txt")
408
-
409
- except Exception as e:
410
- logger.error(f"App initialization failed: {str(e)}")
411
- st.error(f"App failed to start: {str(e)}. Check Spaces logs or contact support.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import logging
3
  import os
4
  from io import BytesIO
 
 
 
 
 
 
 
 
 
 
5
  import re
6
  import time
7
+ from typing import List, Tuple, Optional
8
+
9
+ import pdfplumber
10
 
11
+ # Optional OCR (guarded)
12
+ try:
13
+ import pytesseract
14
+ OCR_AVAILABLE = True
15
+ except Exception:
16
+ OCR_AVAILABLE = False
17
+
18
+ from rank_bm25 import BM25Okapi
19
+
20
+ # Embeddings + Vector store
21
+ from sentence_transformers import SentenceTransformer
22
+ import numpy as np
23
+
24
+ try:
25
+ import faiss # direct FAISS for speed and control
26
+ FAISS_OK = True
27
+ except Exception:
28
+ FAISS_OK = False
29
+
30
+ # Lightweight HF pipelines
31
+ from transformers import pipeline
32
+
33
+ # ----------------------------
34
+ # App & Logging Setup
35
+ # ----------------------------
36
+ st.set_page_config(page_title="Smart PDF Chat & Summarizer", page_icon="📄", layout="wide")
37
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
38
+ logger = logging.getLogger("smart_pdf")
39
+
40
+ # ----------------------------
41
+ # Caching: models & utilities
42
+ # ----------------------------
43
+ @st.cache_resource(show_spinner=False)
44
+ def get_embedder(name: str = "sentence-transformers/all-MiniLM-L6-v2"):
45
+ return SentenceTransformer(name)
46
+
47
+ @st.cache_resource(show_spinner=False)
48
+ def get_qa_pipeline():
49
+ # Small, fast instruction model
50
+ return pipeline(
51
+ "text2text-generation",
52
+ model="google/flan-t5-small",
53
+ device=-1,
54
+ max_length=220
55
+ )
56
+
57
+ @st.cache_resource(show_spinner=False)
58
+ def get_summarizer():
59
+ # DistilBART is much faster than bart-large-cnn
60
+ return pipeline(
61
+ "summarization",
62
+ model="sshleifer/distilbart-cnn-12-6",
63
+ device=-1,
64
+ max_length=220,
65
+ min_length=80,
66
+ do_sample=False,
67
+ )
68
+
69
+ # ----------------------------
70
+ # PDF processing
71
+ # ----------------------------
72
+
73
+ def _looks_like_code(line: str) -> bool:
74
+ if len(line.strip()) == 0:
75
+ return False
76
+ # Heuristics for code-y lines
77
+ code_tokens = [
78
+ r"\b(def|class|import|from|return|if|elif|else|for|while|try|except|finally|with)\b",
79
+ r"[{}`;<>]|::|=>|#|//|/\*|\*/",
80
+ r"\(|\)|\[|\]|\{|\}",
81
+ ]
82
+ matches = sum(bool(re.search(p, line)) for p in code_tokens)
83
+ indent = len(line) - len(line.lstrip())
84
+ return matches >= 1 or indent >= 4
85
+
86
+
87
+ def extract_text_and_code_from_pdf(file_bytes: bytes, ocr_fallback: bool = True, max_pages: int = 50) -> Tuple[str, List[str]]:
88
+ """Return (plain_text, code_blocks[]) from a PDF with simple OCR fallback."""
89
+ text_parts: List[str] = []
90
+ code_lines: List[str] = []
91
+
92
+ with pdfplumber.open(BytesIO(file_bytes)) as pdf:
93
+ pages = pdf.pages[:max_pages]
94
+ for page in pages:
95
+ # 1) Try text extraction
96
+ extracted = page.extract_text(x_tolerance=1.5, y_tolerance=1.0) or ""
97
+
98
+ # 2) OCR fallback if page empty and OCR available
99
+ if not extracted.strip() and ocr_fallback and OCR_AVAILABLE:
100
+ try:
101
+ img = page.to_image(resolution=180).original
102
+ extracted = pytesseract.image_to_string(img, config='--psm 6') or ""
103
+ except Exception as e:
104
+ logger.warning(f"OCR failed on a page: {e}")
105
+
106
+ # 3) Clean and collect
107
+ if extracted:
108
+ # Remove common headers/footers by simple rules
109
+ lines = [ln for ln in extracted.splitlines() if not re.match(r"^(Page\s*\d+|Copyright.*)$", ln, flags=re.I)]
110
+ text_parts.append("\n".join(lines))
111
+
112
+ # Code detection: fenced blocks first
113
+ fenced = re.findall(r"```[\w-]*\n([\s\S]*?)```", extracted, flags=re.M)
114
+ for blk in fenced:
115
+ blk = blk.strip()
116
+ if blk:
117
+ code_lines.append(blk)
118
+
119
+ # Otherwise, line-wise heuristic
120
+ for ln in lines:
121
+ if _looks_like_code(ln):
122
+ code_lines.append(ln)
123
+
124
+ # 4) Tables -> pipe-separated rows
125
+ try:
126
+ tables = page.extract_tables() or []
127
+ for tb in tables:
128
+ for row in tb:
129
+ if row and any(str(c).strip() for c in row):
130
+ text_parts.append(" | ".join(str(c).strip() for c in row))
131
+ except Exception:
132
+ pass
133
+
134
+ full_text = "\n\n".join(tp for tp in text_parts if tp.strip())
135
+
136
+ # Merge adjacent code lines into blocks
137
+ code_blocks: List[str] = []
138
+ if code_lines:
139
+ current: List[str] = []
140
+ for ln in code_lines:
141
+ if ln.strip():
142
+ current.append(ln)
143
+ else:
144
+ if current:
145
+ code_blocks.append("\n".join(current))
146
+ current = []
147
+ if current:
148
+ code_blocks.append("\n".join(current))
149
+
150
+ # Deduplicate & trim giant blocks
151
+ seen = set()
152
+ unique_blocks = []
153
+ for blk in code_blocks:
154
+ key = blk.strip()
155
+ if key and key not in seen:
156
+ seen.add(key)
157
+ # cap extreme long blocks for UI; still allow download of full
158
+ unique_blocks.append(blk[:8000])
159
+
160
+ return full_text, unique_blocks
161
+
162
+ # ----------------------------
163
+ # Chunking & Indexing
164
+ # ----------------------------
165
+
166
+ def chunk_text(text: str, chunk_size: int = 700, chunk_overlap: int = 120) -> List[str]:
167
+ text = re.sub(r"\n{3,}", "\n\n", text).strip()
168
+ paras = [p.strip() for p in re.split(r"\n\s*\n", text) if p.strip()]
169
+ chunks: List[str] = []
170
+ buf: str = ""
171
+ for para in paras:
172
+ if not buf:
173
+ buf = para
174
+ elif len(buf) + len(para) + 1 <= chunk_size:
175
+ buf += "\n" + para
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  else:
177
+ chunks.append(buf)
178
+ # overlap
179
+ overlap = buf[-chunk_overlap:] if chunk_overlap > 0 else ""
180
+ buf = (overlap + "\n" + para).strip()
181
+ if buf:
182
+ chunks.append(buf)
183
+ return chunks
184
+
185
+ @st.cache_resource(show_spinner=False)
186
+ def build_indexes(chunks: List[str]):
187
+ embedder = get_embedder()
188
+ matrix = embedder.encode(chunks, show_progress_bar=False, batch_size=64, normalize_embeddings=True)
189
+ matrix = np.asarray(matrix).astype('float32')
190
+
191
+ bm25 = BM25Okapi([c.split() for c in chunks])
192
+
193
+ if FAISS_OK:
194
+ index = faiss.IndexFlatIP(matrix.shape[1])
195
+ index.add(matrix)
196
+ return {
197
+ "chunks": chunks,
198
+ "embeddings": matrix,
199
+ "faiss": index,
200
+ "bm25": bm25,
201
+ }
202
+ else:
203
+ # Fallback: cosine via numpy (slower but OK for small docs)
204
+ return {
205
+ "chunks": chunks,
206
+ "embeddings": matrix,
207
+ "faiss": None,
208
+ "bm25": bm25,
209
+ }
210
+
211
+ # ----------------------------
212
+ # Retrieval + QA
213
+ # ----------------------------
214
+
215
+ def retrieve(topk: int, query: str, idx):
216
+ chunks = idx["chunks"]
217
+ embeddings = idx["embeddings"]
218
+ bm25 = idx["bm25"]
219
+
220
+ # BM25
221
+ bm25_docs = bm25.get_top_n(query.split(), chunks, n=min(topk, len(chunks)))
222
+
223
+ # FAISS / cosine
224
+ embedder = get_embedder()
225
+ qv = embedder.encode([query], normalize_embeddings=True)[0].astype('float32')
226
+
227
+ if idx["faiss"] is not None:
228
+ D, I = idx["faiss"].search(np.array([qv]), min(topk, len(chunks)))
229
+ faiss_docs = [chunks[i] for i in I[0]]
230
+ else:
231
+ # cosine with numpy
232
+ sims = embeddings @ qv
233
+ order = np.argsort(-sims)[:topk]
234
+ faiss_docs = [chunks[i] for i in order]
235
+
236
+ # Merge uniques with preference to BM25 then FAISS
237
+ merged: List[str] = []
238
+ seen = set()
239
+ for c in bm25_docs + faiss_docs:
240
+ if c not in seen:
241
+ merged.append(c)
242
+ seen.add(c)
243
+ if len(merged) >= topk:
244
+ break
245
+ return merged
246
+
247
+
248
+ def rag_answer(query: str, idx, max_ctx_chars: int = 3000) -> str:
249
+ ctx_chunks = retrieve(6, query, idx)
250
+ # Concatenate up to a char budget
251
+ ctx = "\n\n".join(ctx_chunks)
252
+ if len(ctx) > max_ctx_chars:
253
+ ctx = ctx[:max_ctx_chars]
254
+ qa = get_qa_pipeline()
255
+ prompt = (
256
+ "Answer the question using ONLY the provided context. "
257
+ "If the answer is not in the context, say 'I couldn't find that in the PDF.'\n\n"
258
+ f"Context:\n{ctx}\n\nQuestion: {query}\nAnswer:"
259
+ )
260
+ out = qa(prompt)[0]["generated_text"].strip()
261
+ return out
262
+
263
+
264
+ def summarize_text(full_text: str) -> str:
265
+ summarizer = get_summarizer()
266
+ # Summarize in parts for long docs
267
+ chunks = chunk_text(full_text, chunk_size=1200, chunk_overlap=150)
268
+ partials = []
269
+ for ch in chunks[:8]: # cap to keep it snappy on CPU
270
+ partials.append(summarizer(ch)[0]["summary_text"].strip())
271
+ # Final stitch summary
272
+ stitched = " ".join(partials)
273
+ if len(stitched) > 2000:
274
+ stitched = summarizer(stitched[:3000])[0]["summary_text"].strip()
275
+ return stitched
276
+
277
+ # ----------------------------
278
+ # UI
279
+ # ----------------------------
280
+
281
+ st.markdown(
282
+ """
283
+ <style>
284
+ .app-header {background: linear-gradient(90deg,#10b981,#22c55e); color: white; padding: 16px; border-radius: 14px; text-align:center; box-shadow: 0 6px 20px rgba(16,185,129,.25)}
285
+ .card {border:1px solid #e5e7eb; border-radius: 14px; padding: 16px; background: #fff}
286
+ .muted {color:#6b7280}
287
+ .kbd {background:#f3f4f6; border:1px solid #e5e7eb; border-radius:6px; padding:2px 6px; font-family: ui-monospace, SFMono-Regular, Menlo, Monaco}
288
+ </style>
289
+ """,
290
+ unsafe_allow_html=True,
291
+ )
292
+
293
+ st.markdown('<div class="app-header"><h1>📄 Smart PDF Chat & Summarizer</h1><p class="muted">Fast answers, focused summaries, and automatic code extraction</p></div>', unsafe_allow_html=True)
294
+
295
+ # Session state
296
+ if "idx" not in st.session_state:
297
+ st.session_state.idx = None
298
+ if "pdf_text" not in st.session_state:
299
+ st.session_state.pdf_text = ""
300
+ if "code_blocks" not in st.session_state:
301
+ st.session_state.code_blocks = []
302
+
303
+ # Sidebar
304
+ with st.sidebar:
305
+ st.subheader("Upload & Options")
306
+ file = st.file_uploader("Upload a PDF", type=["pdf"], help="Max ~50 pages for speed. Uses OCR fallback if needed.")
307
+ max_pages = st.slider("Max pages to parse", 5, 100, 50, help="Lower = faster")
308
+ do_ocr = st.toggle("Enable OCR fallback (slower)", value=False)
309
+ chunk_size = st.slider("Chunk size", 300, 1400, 700, step=50)
310
+ overlap = st.slider("Chunk overlap", 0, 300, 120, step=10)
311
+
312
+ colA, colB = st.columns(2)
313
+ with colA:
314
+ if st.button("⚙️ Build Index", use_container_width=True, type="primary"):
315
+ if not file:
316
+ st.warning("Please upload a PDF first.")
317
+ else:
318
+ with st.spinner("Reading & indexing PDF…"):
319
+ data = file.read()
320
+ text, code_blocks = extract_text_and_code_from_pdf(data, ocr_fallback=do_ocr, max_pages=max_pages)
321
+ st.session_state.pdf_text = text
322
+ st.session_state.code_blocks = code_blocks
323
+
324
+ if not text.strip():
325
+ st.error("Couldn't extract any text from the PDF.")
326
+ else:
327
+ chunks = chunk_text(text, chunk_size=chunk_size, chunk_overlap=overlap)
328
+ st.session_state.idx = build_indexes(chunks)
329
+ st.success(f"Indexed {len(chunks)} chunks. Ready!")
330
+ with colB:
331
+ if st.button("🧹 Clear", use_container_width=True):
332
+ st.session_state.idx = None
333
+ st.session_state.pdf_text = ""
334
+ st.session_state.code_blocks = []
335
  st.experimental_rerun()
336
+
337
+ if st.session_state.code_blocks:
338
+ st.caption("Detected code blocks. You can copy or download from the Summary tab.")
339
+
340
+ # Main area — two sections exactly: Chat & Summary
341
+ chat_tab, summary_tab = st.tabs(["💬 Chat", "📝 Summary (with Code)"])
342
+
343
+ with chat_tab:
344
+ st.markdown("<div class='card'>Ask questions about your PDF. Retrieval-augmented answers use only the document context.</div>", unsafe_allow_html=True)
345
+
346
+ if st.session_state.idx is None:
347
+ st.info("Upload a PDF and click **Build Index** in the sidebar.")
348
+ else:
349
+ user_q = st.chat_input("Ask anything about the PDF…")
350
+ if "chat" not in st.session_state:
351
+ st.session_state.chat = []
352
+
353
+ # Render history
354
+ for role, content in st.session_state.get("chat", []):
355
+ with st.chat_message(role):
356
+ st.markdown(content)
357
+
358
+ if user_q:
359
+ st.session_state.chat.append(("user", user_q))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
360
  with st.chat_message("user"):
361
+ st.markdown(user_q)
362
+
363
  with st.chat_message("assistant"):
364
+ with st.spinner("Thinking…"):
365
+ try:
366
+ ans = rag_answer(user_q, st.session_state.idx)
367
+ except Exception as e:
368
+ ans = f"Sorry, I hit an error while answering: {e}"
369
+ st.markdown(ans)
370
+ st.session_state.chat.append(("assistant", ans))
371
+
372
+ with summary_tab:
373
+ st.markdown("<div class='card'>One-click concise summary of the entire document, plus extracted programming code if detected.</div>", unsafe_allow_html=True)
374
+
375
+ col1, col2 = st.columns([1,1])
376
+ with col1:
377
+ if st.button("🔎 Summarize PDF", type="primary", use_container_width=True):
378
+ if not st.session_state.pdf_text.strip():
379
+ st.warning("No parsed text yet. Upload & Build Index first.")
380
+ else:
381
+ with st.spinner("Summarizing…"):
382
+ try:
383
+ sm = summarize_text(st.session_state.pdf_text)
384
+ st.session_state.summary = sm
385
+ st.success("Summary generated.")
386
+ except Exception as e:
387
+ st.error(f"Summarization failed: {e}")
388
+ with col2:
389
+ if st.session_state.pdf_text:
390
+ st.download_button(
391
+ "⬇️ Download raw extracted text",
392
+ st.session_state.pdf_text,
393
+ file_name="extracted_text.txt",
394
+ use_container_width=True,
395
+ )
396
+
397
+ if st.session_state.get("summary"):
398
+ st.subheader("Summary")
399
+ st.write(st.session_state.summary)
400
+
401
+ st.divider()
402
+
403
+ st.subheader("Extracted Code")
404
+ if st.session_state.code_blocks:
405
+ for i, blk in enumerate(st.session_state.code_blocks, start=1):
406
+ with st.expander(f"Code block #{i}"):
407
+ st.code(blk, language=None)
408
+ st.download_button(
409
+ f"Download code #{i}",
410
+ blk,
411
+ file_name=f"code_block_{i}.txt",
412
+ key=f"dl_{i}",
413
+ )
414
+ all_code = "\n\n\n".join(st.session_state.code_blocks)
415
+ st.download_button("⬇️ Download all code", all_code, file_name="all_code.txt")
416
+ else:
417
+ st.caption("No code-like content detected yet.")
418
+
419
+ # Footer tips
420
+ st.markdown(
421
+ """
422
+ <div class="muted" style="margin-top:24px">⚡ Tips for faster responses: use smaller PDFs, lower the "Max pages" and "Chunk size" in the sidebar, and keep OCR off unless needed.</div>
423
+ """,
424
+ unsafe_allow_html=True,
425
+ )