File size: 15,606 Bytes
f513b53
 
7678f2a
f513b53
c06d586
7c6674a
f513b53
7678f2a
7c6674a
 
c06d586
f513b53
7c6674a
7678f2a
f513b53
 
7c6674a
c06d586
7678f2a
7c6674a
7678f2a
c06d586
7678f2a
7c6674a
7678f2a
 
f513b53
c06d586
7678f2a
7c6674a
7678f2a
7c6674a
 
 
 
 
c06d586
7678f2a
7c6674a
7678f2a
 
 
c06d586
 
7c6674a
c06d586
 
 
7c6674a
c06d586
 
 
7c6674a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574210c
7c6674a
7678f2a
7c6674a
7678f2a
7c6674a
 
 
c06d586
 
 
 
 
 
 
7c6674a
 
c06d586
 
 
 
 
 
7c6674a
 
 
c06d586
7c6674a
 
 
 
 
c06d586
 
7c6674a
7678f2a
 
c06d586
7c6674a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c06d586
7678f2a
7c6674a
7678f2a
c06d586
 
7c6674a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574210c
7c6674a
 
 
 
 
 
 
574210c
7c6674a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
import streamlit as st
import logging
import os
from io import BytesIO
import pdfplumber
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from sentence_transformers import SentenceTransformer
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments
from datasets import load_dataset
import re

# Setup logging for Spaces: configure the root logger for timestamped INFO+ records,
# then expose a module-level logger used by every function below.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)

# Lazy load models
@st.cache_resource(ttl=1800)
def _load_embeddings_model_cached():
    """Load the sentence-embedding model once per cache window.

    Exceptions propagate on purpose: st.cache_resource does not store a result
    when the decorated function raises, so a failed download is retried on the
    next call instead of a cached ``None`` breaking the app for the whole TTL.
    """
    logger.info("Loading embeddings model")
    return SentenceTransformer("all-MiniLM-L12-v2")


def load_embeddings_model():
    """Return the shared SentenceTransformer embedding model.

    Returns:
        The cached model, or None if it could not be loaded. Load errors are
        logged and shown in the UI via ``st.error``.
    """
    try:
        return _load_embeddings_model_cached()
    except Exception as e:
        logger.error(f"Embeddings load error: {str(e)}")
        st.error(f"Embedding model error: {str(e)}")
        return None

@st.cache_resource(ttl=1800)
def _load_qa_pipeline_cached():
    """Build the QA pipeline, preferring a model fine-tuned on the default dataset.

    Falls back to the stock google/flan-t5-small text2text pipeline when dataset
    preparation or fine-tuning returns None. Exceptions propagate so that a failed
    load is not cached by st.cache_resource.

    NOTE(review): this always uses load_and_prepare_dataset()'s default dataset, so a
    model fine-tuned from the sidebar on another dataset is not picked up here.
    """
    logger.info("Loading QA pipeline")
    dataset = load_and_prepare_dataset()
    if dataset:
        fine_tuned_pipeline = fine_tune_qa_model(dataset)
        if fine_tuned_pipeline:
            return fine_tuned_pipeline
    return pipeline("text2text-generation", model="google/flan-t5-small", max_length=300)


def load_qa_pipeline():
    """Return the shared text2text QA pipeline.

    Returns:
        The cached pipeline, or None if it could not be built. Errors are logged
        and shown in the UI via ``st.error``; they are not cached, so the next call
        retries.
    """
    try:
        return _load_qa_pipeline_cached()
    except Exception as e:
        logger.error(f"QA model load error: {str(e)}")
        st.error(f"QA model error: {str(e)}")
        return None

@st.cache_resource(ttl=1800)
def _load_summary_pipeline_cached():
    """Load the distilbart summarization pipeline once per cache window.

    Exceptions propagate so a failed load is not cached by st.cache_resource.
    """
    logger.info("Loading summary pipeline")
    return pipeline("summarization", model="sshleifer/distilbart-cnn-6-6", max_length=150)


def load_summary_pipeline():
    """Return the shared summarization pipeline.

    Returns:
        The cached pipeline, or None if it could not be loaded. Errors are logged
        and shown in the UI via ``st.error``; the next call retries.
    """
    try:
        return _load_summary_pipeline_cached()
    except Exception as e:
        logger.error(f"Summary model load error: {str(e)}")
        st.error(f"Summary model error: {str(e)}")
        return None

# Load and prepare dataset (e.g., SQuAD)
@st.cache_resource(ttl=3600)
def load_and_prepare_dataset(dataset_name="squad", max_samples=1000):
    """Load a SQuAD-style QA dataset and build T5-style input/target text columns.

    Args:
        dataset_name: Hugging Face dataset name; it must provide 'question',
            'context' and 'answers' columns in its "train" split.
        max_samples: number of rows to keep after a seeded shuffle (clamped to
            the dataset size).

    Returns:
        The dataset with added 'input_text' ("question: ... context: ...") and
        'target_text' (first answer, or "" when there is none) columns, or None
        on any error (logged).
    """
    logger.info(f"Loading dataset: {dataset_name}")
    try:
        dataset = load_dataset(dataset_name, split="train")
        missing = {"question", "context", "answers"} - set(dataset.column_names)
        if missing:
            logger.error(f"Dataset {dataset_name} lacks required columns: {sorted(missing)}")
            return None
        dataset = dataset.shuffle(seed=42).select(range(min(max_samples, len(dataset))))

        def preprocess(examples):
            inputs = [f"question: {q} context: {c}" for q, c in zip(examples['question'], examples['context'])]
            # With batched=True, 'answers' is a list holding one {'text': [...], ...}
            # dict per row, not a single dict of lists.
            targets = [a['text'][0] if a['text'] else "" for a in examples['answers']]
            return {'input_text': inputs, 'target_text': targets}

        return dataset.map(preprocess, batched=True)
    except Exception as e:
        logger.error(f"Dataset load error: {str(e)}")
        return None

# Fine-tune QA model
@st.cache_resource(ttl=3600)
def fine_tune_qa_model(dataset):
    """Fine-tune google/flan-t5-small for one epoch and return a generation pipeline.

    Args:
        dataset: a datasets.Dataset with 'input_text' and 'target_text' columns,
            as produced by load_and_prepare_dataset.

    Returns:
        A text2text-generation pipeline loaded from ./fine_tuned_model, or None if
        any step fails (the error is logged).
    """
    logger.info("Starting fine-tuning")
    try:
        model_name = "google/flan-t5-small"
        output_dir = "./fine_tuned_model"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        pad_id = tokenizer.pad_token_id

        def tokenize_function(examples):
            model_inputs = tokenizer(examples['input_text'], max_length=512, truncation=True, padding="max_length")
            labels = tokenizer(examples['target_text'], max_length=128, truncation=True, padding="max_length")
            # Label positions equal to -100 are ignored by the loss. Without this mask
            # the padded tail of every target would also be trained as output.
            model_inputs["labels"] = [
                [-100 if token == pad_id else token for token in label_ids]
                for label_ids in labels["input_ids"]
            ]
            return model_inputs

        tokenized_dataset = dataset.map(tokenize_function, batched=True)

        # No evaluation is run; the evaluation-strategy default is already "no", and
        # omitting it avoids the renamed `evaluation_strategy` argument.
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=1,
            per_device_train_batch_size=4,
            save_steps=500,
            logging_steps=100,
            learning_rate=5e-5,
            fp16=False,
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset,
        )
        trainer.train()

        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)
        logger.info("Fine-tuning complete")
        return pipeline("text2text-generation", model=output_dir, tokenizer=output_dir, max_length=300)
    except Exception as e:
        logger.error(f"Fine-tuning error: {str(e)}")
        return None

# Augment vector store with dataset
def augment_vector_store(vector_store, dataset_name="squad", max_samples=500):
    """Add "Context/Answer" passages from a QA dataset to an existing vector store.

    Args:
        vector_store: a LangChain FAISS store. When it is None or falsy, nothing
            is done and it is returned as-is.
        dataset_name: Hugging Face dataset with 'context' and 'answers' columns.
        max_samples: maximum number of leading "train" rows to use (unshuffled,
            clamped to the dataset size).

    Returns:
        The same store object, with the passages added when augmentation succeeds.
        Errors are logged and the store is returned unchanged.

    NOTE(review): these generic passages share the index with the PDF's own chunks,
    so similarity_search may return them instead of PDF content.
    """
    logger.info(f"Augmenting vector store with dataset: {dataset_name}")
    if not vector_store:
        # Nothing to augment, so skip the dataset download entirely.
        return vector_store
    try:
        dataset = load_dataset(dataset_name, split="train")
        dataset = dataset.select(range(min(max_samples, len(dataset))))
        # Skip rows whose answer list is empty instead of letting one IndexError
        # abort the whole augmentation.
        chunks = [
            f"Context: {context}\nAnswer: {answers['text'][0]}"
            for context, answers in zip(dataset['context'], dataset['answers'])
            if answers['text']
        ]
        embeddings_model = load_embeddings_model()
        if embeddings_model and chunks:
            embeddings = embeddings_model.encode(chunks)
            vector_store.add_embeddings(zip(chunks, embeddings))
        return vector_store
    except Exception as e:
        logger.error(f"Vector store augmentation error: {str(e)}")
        return vector_store

# Process PDF with enhanced extraction
# Runs of lines indented by 2+ spaces/tabs are treated as likely code listings.
# `[ \t]` (not `\s`) keeps a match from swallowing blank lines, and `\S.*` makes
# each match include the line's content rather than only its indentation.
_INDENTED_BLOCK_RE = re.compile(r"^[ \t]{2,}\S.*(?:\n[ \t]{2,}\S.*)*", re.MULTILINE)


def _is_monospace_char(obj):
    """Return True for pdfplumber char objects whose font name contains "mono"."""
    return obj.get("object_type") == "char" and "mono" in (obj.get("fontname") or "").lower()


def _extract_pdf_content(uploaded_file, max_pages=20):
    """Extract (text, code_text) from the first `max_pages` pages of an uploaded PDF.

    text: page text, table rows (cells joined with " | "), and words with font
        size > 12 repeated on their own lines.
    code_text: newline-joined monospaced-font text and indented line runs.
    """
    text_parts = []
    code_blocks = []
    with pdfplumber.open(BytesIO(uploaded_file.getvalue())) as pdf:
        for page in pdf.pages[:max_pages]:
            # Guard against a None/empty result (e.g. pages without a text layer).
            page_text = page.extract_text(layout=False) or ""
            if page_text:
                text_parts.append(page_text + "\n")

            # Filtering the page down to monospaced chars lets pdfplumber rebuild
            # words and lines, rather than collecting one character per code line.
            mono_text = (page.filter(_is_monospace_char).extract_text() or "").strip()
            if mono_text:
                code_blocks.append(mono_text)

            for match in _INDENTED_BLOCK_RE.finditer(page_text):
                code_blocks.append(match.group().strip())

            for table in page.extract_tables():
                rows = [" | ".join("" if cell is None else str(cell) for cell in row) for row in table if row]
                text_parts.append("\n".join(rows) + "\n")

            # extract_words only reports font size when asked via extra_attrs;
            # without it `size` is absent and no heading would ever be detected.
            for word in page.extract_words(extra_attrs=["size"]):
                if word.get("size", 0) > 12:
                    text_parts.append(f"\n{word['text']}\n")

    return "".join(text_parts), "\n".join(code_blocks).strip()


def _build_faiss_store(chunks, embeddings_model):
    """Embed `chunks` in one batch and index them in FAISS; None when there are no chunks."""
    if not chunks:
        return None
    vectors = embeddings_model.encode(chunks)
    return FAISS.from_embeddings(zip(chunks, vectors), embeddings_model.encode)


def process_pdf(uploaded_file):
    """Extract text and code from an uploaded PDF and index each in a FAISS store.

    Args:
        uploaded_file: the uploaded PDF (any object whose ``getvalue()`` returns
            the PDF bytes), or None when nothing was uploaded.

    Returns:
        (text_vector_store, code_vector_store, text, code_text). Either store may be
        None: no chunks, or the embedding model is unavailable. The text store is
        augmented with dataset passages. On failure the error is shown via
        ``st.error`` and ``(None, None, "", "")`` is returned.
    """
    logger.info("Processing PDF with enhanced extraction")
    if uploaded_file is None:
        st.warning("Please upload a PDF first.")
        return None, None, "", ""
    try:
        text, code_text = _extract_pdf_content(uploaded_file)
        if not text.strip():
            raise ValueError("No text extracted from PDF")

        # Extracted text is joined with single newlines, so a "\n\n" separator rarely
        # matches and the splitter would emit one oversized chunk; split on lines.
        text_splitter = CharacterTextSplitter(separator="\n", chunk_size=500, chunk_overlap=100, keep_separator=True)
        text_chunks = text_splitter.split_text(text)[:50]
        code_chunks = text_splitter.split_text(code_text)[:25] if code_text else []

        embeddings_model = load_embeddings_model()
        if not embeddings_model:
            return None, None, text, code_text

        text_vector_store = _build_faiss_store(text_chunks, embeddings_model)
        code_vector_store = _build_faiss_store(code_chunks, embeddings_model)

        # Augment text vector store with dataset
        if text_vector_store:
            text_vector_store = augment_vector_store(text_vector_store)

        logger.info("PDF processed successfully with enhanced extraction")
        return text_vector_store, code_vector_store, text, code_text
    except Exception as e:
        logger.error(f"PDF processing error: {str(e)}")
        st.error(f"PDF error: {str(e)}")
        return None, None, "", ""

# Summarize PDF
def summarize_pdf(text):
    """Summarize the opening of a document in at most 150 words.

    Only the first two chunks (about 500 characters each) produced by the "\\n\\n"
    splitter are summarized, and each chunk is cut to 500 characters before it is
    sent to the model.

    Returns:
        A user-facing string: the prefixed summary, a notice when the summary
        model is unavailable, or an error message (exceptions are not raised).
    """
    logger.info("Generating summary")
    try:
        summarizer = load_summary_pipeline()
        if not summarizer:
            return "Summary model unavailable."

        splitter = CharacterTextSplitter(separator="\n\n", chunk_size=500, chunk_overlap=50)
        leading_chunks = splitter.split_text(text)[:2]
        pieces = [
            summarizer(piece[:500], max_length=100, min_length=30, do_sample=False)[0]['summary_text'].strip()
            for piece in leading_chunks
        ]

        combined_summary = " ".join(pieces)
        words = combined_summary.split()
        # Enforce the ~150-word budget; shorter summaries keep their original spacing.
        if len(words) > 150:
            combined_summary = " ".join(words[:150])
        logger.info("Summary generated")
        return f"Sure, here's a concise summary of the PDF:\n{combined_summary}"
    except Exception as e:
        logger.error(f"Summary error: {str(e)}")
        return f"Oops, something went wrong summarizing: {str(e)}"

# Answer question with improved response
def answer_question(text_vector_store, code_vector_store, query):
    """Answer a chat query from the processed PDF.

    Code-style requests (words starting with "code", "script", "function" or
    "programming") return the PDF's extracted code from ``st.session_state.code_text``
    when a code store exists. Otherwise the top 5 text chunks are retrieved and
    passed, with the query, to the QA pipeline.

    Args:
        text_vector_store: FAISS store over the PDF text, or None.
        code_vector_store: FAISS store over extracted code, or None.
        query: the user's question.

    Returns:
        A user-facing answer string. Errors are reported in the string, not raised.
    """
    logger.info(f"Processing query: {query}")
    try:
        if not text_vector_store and not code_vector_store:
            return "Please upload a PDF first!"

        # Match keywords at word starts so words like "decode" or "barcode" are not
        # mistaken for code requests ("give me code"/"show code" still match).
        is_code_query = re.search(r"\b(?:code|script|function|programming)", query.lower()) is not None
        if is_code_query and code_vector_store:
            return f"Here's the code from the PDF:\n```python\n{st.session_state.code_text}\n```"

        vector_store = text_vector_store
        if not vector_store:
            return "No relevant content found for your query."

        # Load (and on a cold cache possibly fine-tune) the QA model only when it is needed.
        qa_pipeline = load_qa_pipeline()
        if not qa_pipeline:
            return "Sorry, the QA model is unavailable right now."

        docs = vector_store.similarity_search(query, k=5)
        context = "\n".join(doc.page_content for doc in docs)
        prompt = f"Context: {context}\nQuestion: {query}\nProvide a detailed, accurate answer based on the context, prioritizing relevant information. Respond as a helpful assistant:"
        response = qa_pipeline(prompt)[0]['generated_text']
        logger.info("Answer generated")
        return f"Got it! Here's a detailed answer:\n{response.strip()}"
    except Exception as e:
        logger.error(f"Query error: {str(e)}")
        return f"Sorry, something went wrong: {str(e)}"

# Streamlit UI
try:
    st.set_page_config(page_title="Smart PDF Q&A", page_icon="📄", layout="wide")
    st.markdown("""
        <style>
        .main { max-width: 900px; margin: 0 auto; padding: 20px; }
        .sidebar { background-color: #f8f9fa; padding: 10px; border-radius: 5px; }
        .chat-container { border: 1px solid #ddd; border-radius: 10px; padding: 10px; height: 60vh; overflow-y: auto; margin-top: 20px; }
        .stChatMessage { border-radius: 10px; padding: 10px; margin: 5px; max-width: 70%; }
        .user { background-color: #e6f3ff; align-self: flex-end; }
        .assistant { background-color: #f0f0f0; }
        .dark .user { background-color: #2a2a72; color: #fff; }
        .dark .assistant { background-color: #2e2e2e; color: #fff; }
        .stButton>button { background-color: #4CAF50; color: white; border: none; padding: 8px 16px; border-radius: 5px; }
        .stButton>button:hover { background-color: #45a049; }
        pre { background-color: #f8f8f8; padding: 10px; border-radius: 5px; overflow-x: auto; }
        .header { background: linear-gradient(90deg, #4CAF50, #81C784); color: white; padding: 10px; border-radius: 5px; text-align: center; }
        </style>
    """, unsafe_allow_html=True)

    st.markdown('<div class="header"><h1>Smart PDF Q&A</h1></div>', unsafe_allow_html=True)
    st.markdown("Upload a PDF to ask questions, summarize (~150 words), or extract code with 'give me code'. Fast and friendly responses!")

    # Initialize session state
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "text_vector_store" not in st.session_state:
        st.session_state.text_vector_store = None
    if "code_vector_store" not in st.session_state:
        st.session_state.code_vector_store = None
    if "pdf_text" not in st.session_state:
        st.session_state.pdf_text = ""
    if "code_text" not in st.session_state:
        st.session_state.code_text = ""

    # Sidebar with toggle and dataset options
    with st.sidebar:
        st.markdown('<div class="sidebar">', unsafe_allow_html=True)
        # NOTE(review): `theme` is not read anywhere below, so the toggle has no effect yet.
        theme = st.radio("Theme", ["Light", "Dark"], index=0)
        dataset_name = st.selectbox("Select Dataset for Fine-Tuning", ["squad", "cnn_dailymail", "bigcode/the-stack"], index=0)
        if st.button("Fine-Tune Model"):
            with st.spinner("Fine-tuning model..."):
                dataset = load_and_prepare_dataset(dataset_name=dataset_name)
                if dataset:
                    fine_tuned_pipeline = fine_tune_qa_model(dataset)
                    if fine_tuned_pipeline:
                        st.success("Model fine-tuned successfully!")
                    else:
                        st.error("Fine-tuning failed.")
                else:
                    # Dataset loading failed (details are logged); tell the user.
                    st.error("Could not load a usable dataset for fine-tuning.")
        st.markdown('</div>', unsafe_allow_html=True)

    # PDF upload and processing
    uploaded_file = st.file_uploader("Upload a PDF", type=["pdf"])
    col1, col2 = st.columns([1, 1])
    with col1:
        if st.button("Process PDF"):
            if uploaded_file is None:
                st.warning("Please upload a PDF first.")
            else:
                with st.spinner("Processing PDF..."):
                    st.session_state.text_vector_store, st.session_state.code_vector_store, st.session_state.pdf_text, st.session_state.code_text = process_pdf(uploaded_file)
                    if st.session_state.text_vector_store or st.session_state.code_vector_store:
                        st.success("PDF processed! Ask away or summarize.")
                        st.session_state.messages = []
                    else:
                        st.error("Failed to process PDF.")
    with col2:
        if st.button("Summarize PDF"):
            if st.session_state.pdf_text:
                with st.spinner("Summarizing..."):
                    summary = summarize_pdf(st.session_state.pdf_text)
                # Rendered once by the chat-history loop below.
                st.session_state.messages.append({"role": "assistant", "content": summary})
            else:
                st.warning("Process a PDF before summarizing.")

    # Chat interface
    st.markdown('<div class="chat-container">', unsafe_allow_html=True)

    # Display chat history before handling new input, so each turn added during this
    # run is rendered exactly once. Content comes from the user, the PDF and model
    # output, none of which is trusted, so it is rendered without raw HTML.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if st.session_state.text_vector_store or st.session_state.code_vector_store:
        prompt = st.chat_input("Ask a question (e.g., 'Give me code' or 'What’s the main idea?'):")
        if prompt:
            st.session_state.messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.markdown(prompt)
            with st.chat_message("assistant"):
                # st.spinner has no unsafe_allow_html option, so pass plain text.
                with st.spinner("⏳ Thinking..."):
                    answer = answer_question(st.session_state.text_vector_store, st.session_state.code_vector_store, prompt)
                st.markdown(answer)
            st.session_state.messages.append({"role": "assistant", "content": answer})

    st.markdown('</div>', unsafe_allow_html=True)

    # Download chat history
    if st.session_state.messages:
        chat_text = "\n".join(f"{m['role'].capitalize()}: {m['content']}" for m in st.session_state.messages)
        st.download_button("Download Chat History", chat_text, "chat_history.txt")

except Exception as e:
    logger.error(f"App initialization failed: {str(e)}")
    st.error(f"App failed to start: {str(e)}. Check Spaces logs or contact support.")