File size: 14,672 Bytes
0e40d5c
0446596
 
 
29058d8
0446596
 
 
 
 
430911f
0446596
 
430911f
29058d8
0446596
 
 
 
 
 
 
 
 
 
 
0e40d5c
0446596
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430911f
0446596
 
 
 
 
 
 
 
 
0e40d5c
873b7fb
0446596
 
 
29058d8
 
 
 
 
 
 
 
 
 
 
670bd2c
29058d8
 
 
 
 
 
 
0446596
 
 
 
 
29058d8
 
 
 
 
 
0446596
29058d8
 
0446596
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29058d8
 
 
 
 
0446596
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29058d8
0446596
29058d8
 
0446596
 
29058d8
0446596
 
 
 
29058d8
0446596
 
 
 
 
 
 
 
2c0aa0c
0446596
 
 
 
 
 
29058d8
 
 
 
 
 
 
 
 
 
 
 
0446596
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29058d8
 
 
 
 
 
 
 
 
 
 
 
0446596
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29058d8
0446596
 
 
 
 
 
 
 
 
 
29058d8
 
 
 
 
 
 
 
 
 
 
 
 
0446596
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29058d8
0446596
29058d8
0446596
 
 
 
 
29058d8
 
 
 
 
 
 
 
 
 
 
 
0446596
 
 
 
 
 
 
 
 
 
 
29058d8
 
 
 
 
 
 
0446596
29058d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0446596
29058d8
 
 
 
 
0446596
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29058d8
0446596
 
 
 
 
 
 
 
 
29058d8
0446596
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
import streamlit as st
import os
import tempfile
from typing import List, Optional
import pickle

# Core libraries
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
from langchain.llms import HuggingFacePipeline
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS

# Document loaders
from langchain.document_loaders import PyPDFLoader

# Configure Streamlit page
st.set_page_config(
    page_title="PDF RAG System",
    page_icon="πŸ“š",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Inject custom CSS for the HTML classes used by st.markdown(..., unsafe_allow_html=True)
# calls further down: .main-header (page title), .sidebar-header (sidebar heading).
# NOTE(review): .success-message, .error-message and .source-box are defined here but
# not referenced by any visible markdown in this file — possibly used elsewhere or dead.
st.markdown("""
<style>
    .main-header {
        font-size: 2.5rem;
        color: #1f77b4;
        text-align: center;
        margin-bottom: 2rem;
    }
    .sidebar-header {
        font-size: 1.5rem;
        color: #ff7f0e;
        margin-bottom: 1rem;
    }
    .success-message {
        padding: 1rem;
        background-color: #d4edda;
        border: 1px solid #c3e6cb;
        border-radius: 0.5rem;
        color: #155724;
        margin: 1rem 0;
    }
    .error-message {
        padding: 1rem;
        background-color: #f8d7da;
        border: 1px solid #f5c6cb;
        border-radius: 0.5rem;
        color: #721c24;
        margin: 1rem 0;
    }
    .source-box {
        background-color: #f8f9fa;
        border-left: 4px solid #007bff;
        padding: 1rem;
        margin: 0.5rem 0;
        border-radius: 0 0.5rem 0.5rem 0;
    }
</style>
""", unsafe_allow_html=True)

# Seed per-session defaults once; Streamlit re-runs the script on every
# interaction, so each key is only written if it is not already present.
_SESSION_DEFAULTS = {
    'qa_chain': None,              # RetrievalQA chain, set after processing
    'vectorstore': None,           # FAISS index, set after processing
    'documents_processed': False,  # gates the Q&A column
    'chat_history': [],            # list of {question, answer, sources} dicts
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

@st.cache_resource
def setup_llm(model_name="google/flan-t5-small"):
    """Load a seq2seq HuggingFace model and wrap it for LangChain.

    Args:
        model_name: HF hub id of a text2text (T5-style) model.

    Returns:
        A HuggingFacePipeline LLM on success, or None after showing an
        on-screen error. Cached across reruns by st.cache_resource.
    """
    with st.spinner("🤖 Loading language model..."):
        try:
            tok = AutoTokenizer.from_pretrained(model_name)
            mdl = AutoModelForSeq2SeqLM.from_pretrained(model_name)
            text_gen = pipeline(
                "text2text-generation",
                model=mdl,
                tokenizer=tok,
                max_new_tokens=300,
                temperature=0.3,
                do_sample=True,
                device=-1,  # -1 forces CPU
            )
            return HuggingFacePipeline(pipeline=text_gen)
        except Exception as e:
            st.error(f"Error loading model: {e}")
            return None

@st.cache_resource
def setup_embeddings(model_name="all-MiniLM-L6-v2"):
    """Load a sentence-transformers embedding model.

    Returns a HuggingFaceEmbeddings instance, or None after showing an
    on-screen error. Cached across reruns by st.cache_resource.
    """
    with st.spinner("🔢 Loading embedding model..."):
        try:
            return HuggingFaceEmbeddings(model_name=model_name)
        except Exception as e:
            st.error(f"Error loading embeddings: {e}")
            return None

def process_uploaded_files(uploaded_files, embeddings):
    """Load uploaded PDFs, split them into chunks, and build a FAISS store.

    Args:
        uploaded_files: Streamlit UploadedFile objects (PDFs).
        embeddings: embedding model used to vectorize the chunks.

    Returns:
        (vectorstore, text_chunks) on success, or (None, []) when nothing
        could be processed. Per-file failures are reported via st.error and
        skipped rather than aborting the whole batch.
    """
    if not uploaded_files:
        return None, []

    documents = []

    for uploaded_file in uploaded_files:
        tmp_file_path = None
        try:
            # PyPDFLoader needs a real path, so spool the upload to disk.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
                tmp_file.write(uploaded_file.read())
                tmp_file_path = tmp_file.name

            loader = PyPDFLoader(tmp_file_path)
            docs = loader.load()

            # Tag every page with its originating upload for source display.
            for doc in docs:
                doc.metadata['source_file'] = uploaded_file.name

            documents.extend(docs)
            st.success(f"✅ Processed: {uploaded_file.name} ({len(docs)} pages)")

        except Exception as e:
            st.error(f"❌ Error processing {uploaded_file.name}: {e}")
        finally:
            # BUG FIX: the original unlinked only on success, leaking the temp
            # file whenever PDF loading raised. Always clean up.
            if tmp_file_path and os.path.exists(tmp_file_path):
                os.unlink(tmp_file_path)

    if not documents:
        return None, []

    # Split pages into overlapping chunks sized for the retrieval prompt.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
        separators=["\n\n", "\n", " ", ""]
    )

    text_chunks = text_splitter.split_documents(documents)

    # Record chunk provenance for debugging / display.
    for i, text in enumerate(text_chunks):
        text.metadata.update({
            "chunk_id": i,
            "chunk_size": len(text.page_content)
        })

    st.info(f"✂️ Created {len(text_chunks)} text chunks")

    # Embed all chunks into a local FAISS index.
    try:
        vectorstore = FAISS.from_documents(text_chunks, embeddings)
        st.success(f"✅ Successfully created vector database with {len(text_chunks)} chunks!")
        return vectorstore, text_chunks
    except Exception as e:
        st.error(f"❌ Error creating vector database: {e}")
        return None, []

def create_qa_chain(llm, vectorstore, k=5):
    """Build a "stuff"-type RetrievalQA chain over the vector store.

    Args:
        llm: the LangChain LLM produced by setup_llm.
        vectorstore: FAISS store produced by process_uploaded_files.
        k: number of chunks the retriever fetches per question.

    Returns:
        A RetrievalQA chain (with source documents enabled), or None when
        either input is missing or chain construction fails.
    """
    if not vectorstore or not llm:
        return None

    # Runtime prompt text must match the deployed behavior exactly.
    qa_prompt = PromptTemplate(
        template="""Use the following context to answer the question. If you cannot find the answer in the context, say "I cannot find this information in the provided documents."

Context: {context}

Question: {question}

Answer:""",
        input_variables=["context", "question"],
    )

    try:
        return RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=vectorstore.as_retriever(search_kwargs={"k": k}),
            chain_type_kwargs={"prompt": qa_prompt},
            return_source_documents=True,
        )
    except Exception as e:
        st.error(f"Error creating QA chain: {e}")
        return None

def ask_question(qa_chain, question):
    """Run one query through the QA chain.

    Returns a dict with keys "question", "answer" and "source_documents",
    or None when the chain is missing or the query fails (error shown
    on-screen).
    """
    if not qa_chain:
        return None

    try:
        result = qa_chain({"query": question})
    except Exception as e:
        st.error(f"❌ Error processing question: {e}")
        return None

    return {
        "question": question,
        "answer": result["result"],
        "source_documents": result.get("source_documents", []),
    }

def search_similar_chunks(vectorstore, query, k=5):
    """Return up to *k* chunks most similar to *query* (no LLM answer).

    Returns an empty list when the store is missing or the search fails
    (error shown on-screen).
    """
    if not vectorstore:
        return []

    try:
        return vectorstore.similarity_search(query, k=k)
    except Exception as e:
        st.error(f"Error searching: {e}")
        return []

# Main App Interface
def main():
    st.markdown('<h1 class="main-header">πŸ“š PDF RAG System</h1>', unsafe_allow_html=True)
    st.markdown("Upload PDF documents and ask questions about their content using AI-powered retrieval!")
    
    # Sidebar for configuration
    with st.sidebar:
        st.markdown('<h2 class="sidebar-header">βš™οΈ Configuration</h2>', unsafe_allow_html=True)
        
        # Model configuration
        st.subheader("πŸ€– Model Settings")
        llm_model = st.selectbox(
            "Language Model",
            ["google/flan-t5-small", "google/flan-t5-base"],
            help="Choose the language model (smaller models are faster)"
        )
        
        embedding_model = st.selectbox(
            "Embedding Model",
            ["all-MiniLM-L6-v2", "sentence-transformers/all-mpnet-base-v2"],
            help="Choose the embedding model"
        )
        
        retrieval_k = st.slider(
            "Number of chunks to retrieve",
            min_value=1,
            max_value=10,
            value=5,
            help="How many relevant chunks to use for answering questions"
        )
        
        st.subheader("πŸ’Ύ Vector Store")
        st.info("Using FAISS (local vector storage)")
        
        # Option to save/load vector store
        if st.session_state.vectorstore:
            if st.button("πŸ’Ύ Save Vector Store"):
                try:
                    # Save vector store to session state or file
                    st.session_state.vectorstore.save_local("faiss_index")
                    st.success("Vector store saved!")
                except Exception as e:
                    st.error(f"Error saving: {e}")
    
    # Main content area
    col1, col2 = st.columns([1, 1])
    
    with col1:
        st.subheader("πŸ“ Upload Documents")
        uploaded_files = st.file_uploader(
            "Choose PDF files",
            type=['pdf'],
            accept_multiple_files=True,
            help="Upload one or more PDF files to analyze"
        )
        
        if st.button("πŸš€ Process Documents", type="primary"):
            if not uploaded_files:
                st.warning("Please upload at least one PDF file.")
            else:
                with st.spinner("Processing documents..."):
                    # Setup models
                    llm = setup_llm(llm_model)
                    embeddings = setup_embeddings(embedding_model)
                    
                    if llm and embeddings:
                        # Process files
                        vectorstore, text_chunks = process_uploaded_files(uploaded_files, embeddings)
                        
                        if vectorstore:
                            # Create QA chain
                            qa_chain = create_qa_chain(llm, vectorstore, k=retrieval_k)
                            
                            if qa_chain:
                                # Store in session state
                                st.session_state.qa_chain = qa_chain
                                st.session_state.vectorstore = vectorstore
                                st.session_state.documents_processed = True
                                
                                st.balloons()
                                st.success("πŸŽ‰ Documents processed successfully! You can now ask questions.")
                            else:
                                st.error("Failed to create QA chain.")
                    else:
                        st.error("Failed to load models.")
    
    with col2:
        st.subheader("πŸ’¬ Ask Questions")
        
        if st.session_state.documents_processed:
            question = st.text_input(
                "Your question:",
                placeholder="What are the main topics discussed in the documents?",
                help="Ask any question about your uploaded documents"
            )
            
            col2a, col2b = st.columns([1, 1])
            
            with col2a:
                if st.button("πŸ” Get Answer"):
                    if question:
                        with st.spinner("Searching for answer..."):
                            result = ask_question(st.session_state.qa_chain, question)
                            
                            if result:
                                # Add to chat history
                                st.session_state.chat_history.append({
                                    "question": question,
                                    "answer": result["answer"],
                                    "sources": result["source_documents"]
                                })
                                
                                # Display answer
                                st.subheader("πŸ’‘ Answer:")
                                st.write(result["answer"])
                                
                                # Display sources
                                if result["source_documents"]:
                                    st.subheader("πŸ“š Sources:")
                                    for i, doc in enumerate(result["source_documents"][:3]):
                                        with st.expander(f"Source {i+1}: {doc.metadata.get('source_file', 'Unknown')}"):
                                            st.write(doc.page_content[:500] + "..." if len(doc.page_content) > 500 else doc.page_content)
                    else:
                        st.warning("Please enter a question.")
            
            with col2b:
                if st.button("πŸ” Search Similar"):
                    if question:
                        with st.spinner("Searching for similar content..."):
                            results = search_similar_chunks(st.session_state.vectorstore, question, k=5)
                            
                            if results:
                                st.subheader("πŸ” Similar Content:")
                                for i, doc in enumerate(results):
                                    with st.expander(f"Match {i+1}: {doc.metadata.get('source_file', 'Unknown')}"):
                                        st.write(doc.page_content[:300] + "..." if len(doc.page_content) > 300 else doc.page_content)
        else:
            st.info("πŸ‘† Please upload and process documents first to start asking questions.")
    
    # Chat History
    if st.session_state.chat_history:
        st.subheader("πŸ“ Chat History")
        
        for i, chat in enumerate(reversed(st.session_state.chat_history[-5:])):  # Show last 5
            with st.expander(f"Q: {chat['question'][:50]}..."):
                st.write("**Question:**", chat['question'])
                st.write("**Answer:**", chat['answer'])
                
                if chat['sources']:
                    st.write("**Sources:**")
                    for j, doc in enumerate(chat['sources'][:2]):  # Show top 2 sources
                        st.write(f"{j+1}. {doc.metadata.get('source_file', 'Unknown')}")
    
    # Clear session button
    if st.session_state.documents_processed:
        if st.button("πŸ—‘οΈ Clear Session"):
            st.session_state.qa_chain = None
            st.session_state.vectorstore = None
            st.session_state.documents_processed = False
            st.session_state.chat_history = []
            st.success("Session cleared! You can upload new documents.")
            st.rerun()

# Script entry point: `streamlit run <this file>` executes the module,
# so main() runs on every Streamlit rerun.
if __name__ == "__main__":
    main()