File size: 18,255 Bytes
f513b53
 
7678f2a
f513b53
c06d586
afc3005
 
7c6674a
f513b53
7678f2a
7c6674a
 
afc3005
 
c06d586
afc3005
f513b53
7c6674a
7678f2a
f513b53
 
7c6674a
c06d586
7678f2a
7c6674a
7678f2a
c06d586
7678f2a
7c6674a
7678f2a
 
f513b53
c06d586
7678f2a
7c6674a
7678f2a
7c6674a
 
 
 
 
afc3005
7678f2a
7c6674a
7678f2a
 
 
c06d586
 
7c6674a
c06d586
 
 
7c6674a
c06d586
 
 
7c6674a
afc3005
7c6674a
 
 
afc3005
 
7c6674a
 
 
 
 
 
afc3005
7c6674a
 
 
 
 
 
 
 
 
 
afc3005
7c6674a
 
 
 
 
 
 
 
 
afc3005
7c6674a
 
 
afc3005
7c6674a
 
 
 
afc3005
 
7c6674a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
afc3005
7c6674a
 
 
afc3005
7c6674a
 
 
 
 
574210c
afc3005
7678f2a
7c6674a
7678f2a
7c6674a
 
 
c06d586
 
afc3005
 
 
 
 
 
c06d586
 
 
 
 
7c6674a
 
c06d586
 
 
 
 
 
7c6674a
 
 
c06d586
7c6674a
 
 
 
afc3005
 
 
7c6674a
7678f2a
 
c06d586
7c6674a
 
afc3005
7c6674a
 
 
afc3005
7c6674a
 
 
 
 
 
afc3005
c06d586
7678f2a
7c6674a
7678f2a
c06d586
 
afc3005
7c6674a
 
 
 
 
 
 
afc3005
7c6674a
 
 
 
afc3005
7c6674a
 
 
 
 
afc3005
 
 
 
 
 
7c6674a
 
 
574210c
afc3005
7c6674a
 
 
 
 
 
574210c
7c6674a
 
 
 
 
afc3005
 
 
 
7c6674a
 
 
 
 
afc3005
 
 
 
 
 
 
 
 
7c6674a
 
afc3005
7c6674a
 
 
 
 
 
 
 
 
 
afc3005
 
 
 
 
 
 
 
 
 
 
 
 
7c6674a
 
 
 
afc3005
7c6674a
 
 
 
 
 
 
 
 
 
 
 
 
afc3005
7c6674a
 
 
 
 
afc3005
 
 
 
 
 
 
 
 
 
 
 
 
 
7c6674a
 
 
 
 
 
 
afc3005
7c6674a
afc3005
 
 
7c6674a
 
 
 
 
 
 
 
afc3005
7c6674a
afc3005
 
 
7c6674a
 
 
 
 
 
 
 
 
 
 
 
 
afc3005
 
 
 
 
7c6674a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
import streamlit as st
import logging
import os
from io import BytesIO
import pdfplumber
from PIL import Image
import pytesseract
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from sentence_transformers import SentenceTransformer
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments
from datasets import load_dataset
from rank_bm25 import BM25Okapi
from rouge_score import rouge_scorer
import re
import time

# Configure root logging (timestamp - level - message) for Spaces
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used by every function in this file
logger = logging.getLogger(__name__)

# Lazy load models
@st.cache_resource(ttl=1800)
def load_embeddings_model():
    """Load the sentence-embedding model used to build the vector stores.

    Returns:
        A ``SentenceTransformer("all-MiniLM-L12-v2")`` instance, or ``None``
        when loading raises (the failure is logged and shown via ``st.error``).
    """
    logger.info("Loading embeddings model")
    try:
        encoder = SentenceTransformer("all-MiniLM-L12-v2")
    except Exception as exc:
        reason = str(exc)
        logger.error(f"Embeddings load error: {reason}")
        st.error(f"Embedding model error: {reason}")
        return None
    return encoder

@st.cache_resource(ttl=1800)
def load_qa_pipeline():
    """Return the text2text pipeline used to answer questions.

    A fine-tuned pipeline is preferred: the prepared dataset is passed to
    ``fine_tune_qa_model`` and its result is returned when truthy. Otherwise
    the stock ``google/flan-t5-base`` pipeline is returned.

    NOTE(review): fine-tuning is attempted inside this loader whenever the
    dataset is available, so the first call may be slow.

    Returns:
        A pipeline object, or ``None`` when anything raises (the failure is
        logged and shown via ``st.error``).
    """
    logger.info("Loading QA pipeline")
    try:
        training_data = load_and_prepare_dataset()
        tuned = fine_tune_qa_model(training_data) if training_data else None
        if tuned:
            return tuned
        # Fall back to the base model when no fine-tuned pipeline is available
        return pipeline("text2text-generation", model="google/flan-t5-base", max_length=300)
    except Exception as exc:
        reason = str(exc)
        logger.error(f"QA model load error: {reason}")
        st.error(f"QA model error: {reason}")
        return None

@st.cache_resource(ttl=1800)
def load_summary_pipeline():
    """Load the summarization pipeline (``sshleifer/distilbart-cnn-6-6``).

    Returns:
        The pipeline, or ``None`` when loading raises (the failure is logged
        and shown via ``st.error``).
    """
    logger.info("Loading summary pipeline")
    try:
        summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-6-6", max_length=150)
    except Exception as exc:
        reason = str(exc)
        logger.error(f"Summary model load error: {reason}")
        st.error(f"Summary model error: {reason}")
        return None
    return summarizer

# Load and prepare dataset (e.g., SQuAD)
@st.cache_data(ttl=3600)
def load_and_prepare_dataset(dataset_name="squad", max_samples=1000):
    """Load a QA dataset and convert it to ``input_text`` / ``target_text`` pairs.

    The first 80% of the train split is loaded, shuffled with a fixed seed and
    truncated to at most ``max_samples`` rows. Each row becomes
    ``"question: <q> context: <c>"`` -> first answer text (or ``""`` when the
    row has no answer).

    Args:
        dataset_name: Name passed to ``load_dataset``. The preprocessing reads
            the ``question``, ``context`` and ``answers`` columns, so the
            dataset must provide them (SQuAD-style layout).
        max_samples: Upper bound on the number of rows kept.

    Returns:
        The mapped dataset with only ``input_text`` and ``target_text``
        columns, or ``None`` when loading or preprocessing raises.
    """
    logger.info(f"Loading dataset: {dataset_name}")
    try:
        dataset = load_dataset(dataset_name, split="train[:80%]")
        dataset = dataset.shuffle(seed=42).select(range(min(max_samples, len(dataset))))

        def preprocess(examples):
            inputs = [
                f"question: {q} context: {c}"
                for q, c in zip(examples['question'], examples['context'])
            ]
            # In batched mode `answers` is a list with one dict per example,
            # so it must be iterated rather than indexed by 'text'.
            targets = [
                answer['text'][0] if answer['text'] else ""
                for answer in examples['answers']
            ]
            return {'input_text': inputs, 'target_text': targets}

        dataset = dataset.map(preprocess, batched=True, remove_columns=dataset.column_names)
        return dataset
    except Exception as e:
        logger.error(f"Dataset load error: {str(e)}")
        return None

# Fine-tune QA model
@st.cache_resource(ttl=3600)
def fine_tune_qa_model(dataset):
    """Fine-tune ``google/flan-t5-base`` on ``dataset`` and return a pipeline.

    Args:
        dataset: Dataset with ``input_text`` and ``target_text`` columns (the
            columns read by the tokenization step below).

    Returns:
        A ``text2text-generation`` pipeline loaded from ``./fine_tuned_model``,
        or ``None`` when any step raises (the error is logged).
    """
    logger.info("Starting fine-tuning")
    try:
        model_name = "google/flan-t5-base"
        output_dir = "./fine_tuned_model"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        pad_token_id = tokenizer.pad_token_id

        def tokenize_function(examples):
            model_inputs = tokenizer(examples['input_text'], max_length=512, truncation=True, padding="max_length")
            labels = tokenizer(examples['target_text'], max_length=128, truncation=True, padding="max_length")
            # Mask padding in the labels with -100 so the loss ignores it;
            # otherwise the model is trained to emit pad tokens.
            model_inputs["labels"] = [
                [token_id if token_id != pad_token_id else -100 for token_id in label_ids]
                for label_ids in labels["input_ids"]
            ]
            return model_inputs

        tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=['input_text', 'target_text'])

        # No evaluation strategy is passed: "no" is the default, and omitting
        # the keyword avoids the evaluation_strategy/eval_strategy rename
        # between transformers releases.
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=2,
            per_device_train_batch_size=4,
            save_steps=500,
            logging_steps=100,
            learning_rate=3e-5,
            fp16=False,  # Set True if GPU available
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset,
        )
        trainer.train()

        # Persist model and tokenizer together so the pipeline can be
        # reloaded from a single directory.
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        logger.info("Fine-tuning complete")
        return pipeline("text2text-generation", model=output_dir, tokenizer=output_dir, max_length=300)
    except Exception as e:
        logger.error(f"Fine-tuning error: {str(e)}")
        return None

# Augment vector store with dataset
def augment_vector_store(vector_store, dataset_name="squad", max_samples=500):
    """Add ``Context/Answer`` snippets from a QA dataset to ``vector_store``.

    NOTE(review): this mixes dataset content into the same store that holds
    the PDF chunks, so retrieval may return non-PDF text — confirm that this
    is intended.

    Args:
        vector_store: Store exposing ``add_embeddings``; returned unchanged
            when falsy.
        dataset_name: Name passed to ``load_dataset``; the ``context`` and
            ``answers`` columns are read.
        max_samples: Maximum number of dataset rows to add.

    Returns:
        The same ``vector_store`` object, augmented when possible. On any
        error the store is returned as-is and the error is logged.
    """
    logger.info(f"Augmenting vector store with dataset: {dataset_name}")
    try:
        # Nothing to augment: return before downloading the dataset.
        if not vector_store:
            return vector_store
        embeddings_model = load_embeddings_model()
        if not embeddings_model:
            return vector_store

        # Load first, then truncate. The original read `dataset` inside its
        # own initialiser, which raised UnboundLocalError on every call.
        dataset = load_dataset(dataset_name, split="train")
        dataset = dataset.select(range(min(max_samples, len(dataset))))

        chunks = [
            f"Context: {context}\nAnswer: {answers['text'][0]}"
            for context, answers in zip(dataset['context'], dataset['answers'])
            if answers['text']  # skip rows without an answer
        ]
        if chunks:
            embeddings = embeddings_model.encode(chunks, batch_size=32, show_progress_bar=False)
            vector_store.add_embeddings(zip(chunks, embeddings))
        return vector_store
    except Exception as e:
        logger.error(f"Vector store augmentation error: {str(e)}")
        return vector_store

# Process PDF with enhanced extraction and OCR fallback
def process_pdf(uploaded_file):
    """Extract text and code from an uploaded PDF and index them.

    Only the first 20 pages are read. For each page the plain text is
    extracted (with an OCR fallback when nothing comes back), monospace
    characters and indented lines are collected as code, and tables are
    flattened to ``cell | cell`` rows appended to the text.

    Args:
        uploaded_file: Object exposing ``getvalue()`` that returns the PDF
            bytes, or ``None`` when nothing was uploaded.

    Returns:
        ``(text_vector_store, code_vector_store, text, code_text)``. Either
        store may be ``None``; on failure the result is
        ``(None, None, "", "")`` and the error is logged and shown via
        ``st.error``.
    """
    logger.info("Processing PDF with enhanced extraction")
    if uploaded_file is None:
        # Explicit guard instead of an AttributeError message from getvalue().
        logger.error("PDF processing error: no file uploaded")
        st.error("PDF error: no file uploaded")
        return None, None, "", ""
    try:
        text_parts = []
        code_blocks = []
        with pdfplumber.open(BytesIO(uploaded_file.getvalue())) as pdf:
            for page in pdf.pages[:20]:
                page_text = page.extract_text(layout=False)
                extracted = page_text
                if not extracted:  # OCR fallback for scanned PDFs
                    try:
                        img = page.to_image(resolution=150).original
                        extracted = pytesseract.image_to_string(img, config='--psm 6')
                    except Exception as ocr_e:
                        logger.warning(f"OCR failed: {str(ocr_e)}")
                if extracted:
                    text_parts.append(extracted + "\n")

                # Collect monospace characters as ONE snippet per page;
                # appending them individually yields one character per line.
                mono_text = "".join(
                    char.get('text', '')
                    for char in page.chars
                    if 'mono' in (char.get('fontname') or '').lower()
                )
                if mono_text.strip():
                    code_blocks.append(mono_text)

                # Indented runs of lines are treated as code. `or ""` keeps
                # re.finditer from receiving None on an empty page.
                code_matches = re.finditer(r'(^\s{2,}.*?(?:\n\s{2,}.*?)*)', page_text or "", re.MULTILINE)
                for match in code_matches:
                    code_blocks.append(match.group().strip())

                for table in page.extract_tables() or []:
                    rows = [
                        " | ".join("" if cell is None else str(cell) for cell in row)
                        for row in table
                        if row
                    ]
                    text_parts.append("\n".join(rows) + "\n")

                # NOTE(review): pdfplumber's extract_words() appears to include
                # 'size' only when requested via extra_attrs — confirm. If so,
                # this heading heuristic never fires as written.
                for obj in page.extract_words():
                    if obj.get('size', 0) > 12:
                        text_parts.append(f"\n{obj['text']}\n")

        text = "".join(text_parts)
        code_text = "\n".join(code_blocks).strip()
        if not text:
            raise ValueError("No text extracted from PDF")

        text_splitter = CharacterTextSplitter(separator="\n\n", chunk_size=400, chunk_overlap=80, keep_separator=True)
        text_chunks = text_splitter.split_text(text)[:80]
        code_chunks = text_splitter.split_text(code_text)[:40] if code_text else []

        embeddings_model = load_embeddings_model()
        if not embeddings_model:
            return None, None, text, code_text

        def build_store(chunks):
            # Encode all chunks in one batch rather than one call per chunk.
            if not chunks:
                return None
            vectors = embeddings_model.encode(chunks, show_progress_bar=False)
            return FAISS.from_embeddings(zip(chunks, vectors), embeddings_model.encode)

        text_vector_store = build_store(text_chunks)
        code_vector_store = build_store(code_chunks)

        if text_vector_store:
            text_vector_store = augment_vector_store(text_vector_store)

        logger.info("PDF processed successfully")
        return text_vector_store, code_vector_store, text, code_text
    except Exception as e:
        logger.error(f"PDF processing error: {str(e)}")
        st.error(f"PDF error: {str(e)}")
        return None, None, "", ""

# Summarize PDF with ROUGE metrics
def summarize_pdf(text):
    """Summarize the start of ``text`` and report a ROUGE-1 score.

    The text is split on blank lines; the first two chunks (each cut to 400
    characters) are summarized and the results joined. A joined summary longer
    than 150 words is cut to its first 150 words. ROUGE is scored against the
    first 400 characters of ``text``.

    Args:
        text: Full extracted document text.

    Returns:
        A markdown string containing the summary and its ROUGE-1 F-measure, or
        a plain message when the model is unavailable or an error occurs.
    """
    logger.info("Generating summary")
    try:
        summarizer = load_summary_pipeline()
        if not summarizer:
            return "Summary model unavailable."

        splitter = CharacterTextSplitter(separator="\n\n", chunk_size=400, chunk_overlap=50)
        leading_chunks = splitter.split_text(text)[:2]

        partial_summaries = [
            summarizer(piece[:400], max_length=100, min_length=30, do_sample=False)[0]['summary_text'].strip()
            for piece in leading_chunks
        ]

        combined_summary = " ".join(partial_summaries)
        words = combined_summary.split()
        if len(words) > 150:
            # Enforce the ~150-word budget
            combined_summary = " ".join(words[:150])

        rouge = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
        scores = rouge.score(text[:400], combined_summary)
        logger.info(f"ROUGE scores: {scores}")

        return f"**Summary**:\n{combined_summary}\n\n**ROUGE-1**: {scores['rouge1'].fmeasure:.2f}"
    except Exception as exc:
        reason = str(exc)
        logger.error(f"Summary error: {reason}")
        return f"Oops, something went wrong summarizing: {reason}"

# Answer question with hybrid search
def answer_question(text_vector_store, code_vector_store, query):
    """Answer ``query`` from the indexed PDF content.

    Code-related queries (detected by keyword) are answered from the code
    store when one exists: the three nearest snippets are returned together
    with a generated explanation. All other queries use a hybrid retrieval
    over the text store: the ten nearest chunks are re-ranked with BM25 and
    merged with the five nearest chunks to form the prompt context.

    Args:
        text_vector_store: Store exposing ``similarity_search`` for prose
            chunks, or a falsy value.
        code_vector_store: Store exposing ``similarity_search`` for code
            chunks, or a falsy value.
        query: The user's question.

    Returns:
        A markdown-formatted answer, or a plain message when no store or model
        is available or an error occurs.
    """
    logger.info(f"Processing query: {query}")
    try:
        if not text_vector_store and not code_vector_store:
            return "Please upload a PDF first!"

        qa_pipeline = load_qa_pipeline()
        if not qa_pipeline:
            return "Sorry, the QA model is unavailable right now."

        lowered_query = query.lower()
        code_keywords = ("code", "script", "function", "programming", "give me code", "show code")
        is_code_query = any(keyword in lowered_query for keyword in code_keywords)
        if is_code_query and code_vector_store:
            docs = code_vector_store.similarity_search(query, k=3)
            code = "\n".join(doc.page_content for doc in docs)
            explanation = qa_pipeline(f"Explain this code: {code[:500]}")[0]['generated_text']
            return f"**Code**:\n```python\n{code}\n```\n**Explanation**:\n{explanation}"

        vector_store = text_vector_store
        if not vector_store:
            return "No relevant content found for your query."

        # Hybrid search: FAISS + BM25. One FAISS query supplies both the BM25
        # candidate pool (top 10) and the semantic hits (top 5).
        candidates = [doc.page_content for doc in vector_store.similarity_search(query, k=10)]
        if not candidates:
            # BM25 cannot be built from an empty corpus.
            return "No relevant content found for your query."
        # Lower-case both sides so BM25 term matching ignores case.
        bm25 = BM25Okapi([chunk.lower().split() for chunk in candidates])
        bm25_docs = bm25.get_top_n(lowered_query.split(), candidates, n=5)
        # dict.fromkeys de-duplicates while keeping rank order; a set made the
        # selected context vary between runs.
        combined_docs = list(dict.fromkeys(bm25_docs + candidates[:5]))[:5]
        context = "\n".join(combined_docs)

        prompt = f"Use the following PDF content to answer the question accurately and concisely. Avoid speculation and focus on the provided context:\n\n{context}\n\nQuestion: {query}\nAnswer:"
        response = qa_pipeline(prompt)[0]['generated_text']
        logger.info("Answer generated")
        return f"**Answer**:\n{response.strip()}\n\n**Source Context**:\n{context[:500]}..."
    except Exception as e:
        logger.error(f"Query error: {str(e)}")
        return f"Sorry, something went wrong: {str(e)}"

# Streamlit UI
try:
    # Page setup and styling
    st.set_page_config(page_title="Smart PDF Q&A", page_icon="📄", layout="wide")
    st.markdown("""
        <style>
        .main { max-width: 900px; margin: 0 auto; padding: 20px; }
        .sidebar { background-color: #f8f9fa; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
        .chat-container { border: 1px solid #ddd; border-radius: 12px; padding: 15px; height: 60vh; overflow-y: auto; margin-top: 20px; background-color: #fafafa; }
        .stChatMessage { border-radius: 12px; padding: 12px; margin: 8px; max-width: 75%; transition: all 0.3s ease; }
        .user { background-color: #e6f3ff; align-self: flex-end; border: 1px solid #b3d4fc; }
        .assistant { background-color: #f0f0f0; border: 1px solid #ccc; }
        .dark .user { background-color: #2a2a72; color: #fff; border: 1px solid #4a4ab2; }
        .dark .assistant { background-color: #2e2e2e; color: #fff; border: 1px solid #4a4a4a; }
        .stButton>button { background-color: #4CAF50; color: white; border: none; padding: 10px 20px; border-radius: 8px; font-weight: bold; }
        .stButton>button:hover { background-color: #45a049; transform: scale(1.05); }
        pre { background-color: #f8f8f8; padding: 12px; border-radius: 8px; overflow-x: auto; }
        .header { background: linear-gradient(90deg, #4CAF50, #81C784); color: white; padding: 15px; border-radius: 8px; text-align: center; box-shadow: 0 2px 4px rgba(0,0,0,0.2); }
        .progress-bar { background-color: #e0e0e0; border-radius: 5px; height: 10px; }
        .progress-fill { background-color: #4CAF50; height: 100%; border-radius: 5px; transition: width 0.5s ease; }
        </style>
    """, unsafe_allow_html=True)

    st.markdown('<div class="header"><h1>Smart PDF Q&A</h1></div>', unsafe_allow_html=True)
    st.markdown("Upload a PDF to ask questions, summarize (~150 words), or extract code with 'give me code'. Fast, accurate, and smooth!")

    # Initialize session state
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "text_vector_store" not in st.session_state:
        st.session_state.text_vector_store = None
    if "code_vector_store" not in st.session_state:
        st.session_state.code_vector_store = None
    if "pdf_text" not in st.session_state:
        st.session_state.pdf_text = ""
    if "code_text" not in st.session_state:
        st.session_state.code_text = ""

    # Sidebar with controls
    with st.sidebar:
        st.markdown('<div class="sidebar">', unsafe_allow_html=True)
        # NOTE(review): the selected theme is not used anywhere in this script.
        theme = st.radio("Theme", ["Light", "Dark"], index=0)
        dataset_name = st.selectbox("Select Dataset for Fine-Tuning", ["squad", "cnn_dailymail", "bigcode/the-stack"], index=0)
        if st.button("Fine-Tune Model"):
            # A spinner covers the real work; the simulated progress bar only
            # delayed it.
            with st.spinner("Fine-tuning model..."):
                dataset = load_and_prepare_dataset(dataset_name=dataset_name)
                fine_tuned_pipeline = fine_tune_qa_model(dataset) if dataset else None
            if fine_tuned_pipeline:
                st.success("Model fine-tuned successfully!")
            else:
                # Also reports a dataset that could not be loaded or prepared.
                st.error("Fine-tuning failed.")
        if st.button("Clear Chat"):
            st.session_state.messages = []
            # st.experimental_rerun is gone in newer Streamlit; prefer
            # st.rerun and fall back only on older releases.
            rerun = getattr(st, "rerun", None) or st.experimental_rerun
            rerun()
        st.markdown('</div>', unsafe_allow_html=True)

    # PDF upload and processing
    uploaded_file = st.file_uploader("Upload a PDF", type=["pdf"])
    col1, col2 = st.columns([1, 1])
    with col1:
        if st.button("Process PDF"):
            if uploaded_file is None:
                st.error("Please upload a PDF first.")
            else:
                with st.spinner("Processing PDF..."):
                    (
                        st.session_state.text_vector_store,
                        st.session_state.code_vector_store,
                        st.session_state.pdf_text,
                        st.session_state.code_text,
                    ) = process_pdf(uploaded_file)
                if st.session_state.text_vector_store or st.session_state.code_vector_store:
                    st.success("PDF processed! Ask away or summarize.")
                    st.session_state.messages = []
                else:
                    st.error("Failed to process PDF.")
    with col2:
        if st.button("Summarize PDF"):
            if st.session_state.pdf_text:
                with st.spinner("Summarizing..."):
                    summary = summarize_pdf(st.session_state.pdf_text)
                # Only stored here; the history loop below renders it once.
                st.session_state.messages.append({"role": "assistant", "content": summary})
            else:
                st.warning("Process a PDF before requesting a summary.")

    # Chat interface
    st.markdown('<div class="chat-container">', unsafe_allow_html=True)

    # Display chat history first so the exchange handled below is not drawn
    # twice in the same run. Messages contain PDF- and user-derived text, so
    # they are rendered as plain markdown without raw HTML.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if st.session_state.text_vector_store or st.session_state.code_vector_store:
        prompt = st.chat_input("Ask a question (e.g., 'Give me code' or 'What’s the main idea?'):")
        if prompt:
            st.session_state.messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.markdown(prompt)
            with st.chat_message("assistant"):
                with st.spinner("⏳ Processing..."):
                    answer = answer_question(st.session_state.text_vector_store, st.session_state.code_vector_store, prompt)
                st.markdown(answer)
            st.session_state.messages.append({"role": "assistant", "content": answer})

    st.markdown('</div>', unsafe_allow_html=True)

    # Download chat history
    if st.session_state.messages:
        chat_text = "\n".join(f"{m['role'].capitalize()}: {m['content']}" for m in st.session_state.messages)
        st.download_button("Download Chat History", chat_text, "chat_history.txt")

except Exception as e:
    logger.error(f"App initialization failed: {str(e)}")
    st.error(f"App failed to start: {str(e)}. Check Spaces logs or contact support.")