avimittal30 committed on
Commit
a7aaec4
·
1 Parent(s): 793774f

pushing files

Browse files
Files changed (4) hide show
  1. app.py +144 -0
  2. data.py +92 -0
  3. helper.py +115 -0
  4. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import tempfile
4
+ import pickle
5
+ import faiss
6
+ import numpy as np
7
+ from helper import extract_text_from_pdf, chunk_text, embedding_function, embedding_model, query_llm_with_context
8
+ import logging
9
+
10
# Logging: INFO-level root configuration plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Page chrome: tab title, icon and a wide layout.
st.set_page_config(page_title="PDF RAG System", page_icon="📚", layout="wide")

# Heading followed by a short description of the app.
st.title("📚 PDF RAG System")
st.markdown("""
This application allows you to upload a PDF file, ask questions about its content, and get AI-generated answers based on the document.
""")

# Step 1: the PDF upload widget.
st.header("1. Upload PDF")
uploaded_file = st.file_uploader("Choose a PDF file", type="pdf")

# Seed the session state on first run only; values that already exist in the
# session are left untouched so they survive Streamlit's script re-runs.
_SESSION_DEFAULTS = {
    'pdf_processed': False,  # has the current upload been indexed?
    'index': None,           # FAISS index built from the PDF chunks
    'chunks': None,          # text chunks, positionally aligned with the index
    'pdf_path': None,        # temp-file path of the saved upload
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
40
+
41
# Process the uploaded PDF once per session (until the user resets).
if uploaded_file is not None and not st.session_state.pdf_processed:
    with st.spinner("Processing PDF..."):
        # The extractor takes a file path, so persist the upload to a temp file.
        # delete=False keeps it on disk after the handle closes; the reset
        # button (or the failure branch below) removes it.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
            st.session_state.pdf_path = tmp_file.name

        # Extract text from PDF and split it into overlapping chunks
        pdf_text = extract_text_from_pdf(st.session_state.pdf_path)
        chunks = chunk_text(pdf_text, chunk_size=1000, chunk_overlap=100)

        if chunks:
            # FAISS needs a 2-D float32 matrix; cast unconditionally so the
            # dtype is right whether the helper returns lists or an ndarray.
            embeddings = np.asarray(embedding_function(chunks), dtype='float32')

            # Exact L2 index sized to the embedding dimension
            index = faiss.IndexFlatL2(embeddings.shape[1])
            index.add(embeddings)

            # Save the index and chunks
            faiss.write_index(index, "./faiss_index")
            with open("./document_chunks.pkl", 'wb') as f:
                pickle.dump(chunks, f)

            # Update session state only after everything succeeded
            st.session_state.chunks = chunks
            st.session_state.index = index
            st.session_state.pdf_processed = True
        else:
            # Nothing to index (e.g. a PDF without a text layer): building the
            # index would fail on embeddings.shape[1], so stop here and drop
            # the temp file instead of leaking it.
            if st.session_state.pdf_path and os.path.exists(st.session_state.pdf_path):
                os.unlink(st.session_state.pdf_path)
            st.session_state.pdf_path = None

    if st.session_state.pdf_processed:
        st.success(f"PDF processed successfully! {len(chunks)} chunks created.")
    else:
        st.error("No extractable text was found in this PDF, so it cannot be indexed.")
82
+
83
# Query section
st.header("2. Ask a Question")
query = st.text_input("Enter your question about the PDF content:")

# Retrieval + generation runs only when the button is pressed, a question was
# entered, and a PDF has already been indexed in this session.
if st.button("Get Answer") and query and st.session_state.pdf_processed:
    with st.spinner("Retrieving relevant information and generating answer..."):
        try:
            # Generate embedding for the query
            query_embedding = embedding_model.encode([query], convert_to_numpy=True).astype('float32')

            # Search the index
            n_results = 5
            distances, indices = st.session_state.index.search(query_embedding, n_results)

            # FAISS pads the result with index -1 when the index holds fewer
            # than n_results vectors. Drop those slots: chunks[-1] would
            # silently return the last chunk, and the padded distance would
            # wreck the normalisation below.
            hits = [
                (int(idx), float(dist))
                for idx, dist in zip(indices[0], distances[0])
                if idx >= 0
            ]
            documents = [st.session_state.chunks[idx] for idx, _ in hits]

            # Convert L2 distances (lower is better) to [0, 1] scores where 1
            # is most similar. Guard the all-zero case to avoid dividing by 0.
            max_distance = max((dist for _, dist in hits), default=0.0)
            if max_distance > 0:
                similarity_scores = [1 - (dist / max_distance) for _, dist in hits]
            else:
                similarity_scores = [1.0 for _ in hits]

            # Create context from retrieved documents
            context = (documents, similarity_scores)

            # Query the LLM with context
            answer = query_llm_with_context(query, context, top_n=3)

            # Display the answer
            st.header("3. Answer")
            st.write(answer)

            # Display the retrieved documents (first 500 characters of each)
            with st.expander("View Retrieved Documents"):
                for i, (doc, score) in enumerate(zip(documents, similarity_scores)):
                    st.markdown(f"**Document {i+1}** (Relevance: {score:.4f})")
                    st.text(doc[:500] + "..." if len(doc) > 500 else doc)
                    st.markdown("---")

        except Exception as e:
            # Top-level UI boundary: show the error and keep the traceback in the log.
            st.error(f"An error occurred: {str(e)}")
            logger.exception("Error during query processing")
126
+
127
# Reset button: discard the current document and start over.
if st.button("Reset and Upload New PDF"):
    # Clean up the temporary copy of the uploaded PDF
    if st.session_state.pdf_path and os.path.exists(st.session_state.pdf_path):
        os.unlink(st.session_state.pdf_path)

    # Reset session state
    st.session_state.pdf_processed = False
    st.session_state.index = None
    st.session_state.chunks = None
    st.session_state.pdf_path = None

    # Reload the page. st.experimental_rerun was deprecated and later removed
    # from Streamlit, so prefer st.rerun and fall back only on old versions.
    _rerun = getattr(st, "rerun", None) or st.experimental_rerun
    _rerun()

# Footer
st.markdown("---")
st.markdown("Built with Streamlit, FAISS, and Ollama")
data.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from helper import extract_text_from_pdf, chunk_text, embedding_function, embedding_model, generate_hypothetical_answer, query_llm_with_context
2
+ import numpy as np
3
+ import faiss
4
+ import pickle
5
+ import os
6
+ import logging
7
+ from helper import query_llm_with_context
8
logging.basicConfig(level=logging.INFO)

# Path for storing the FAISS index and document chunks
index_path = "./faiss_index"
chunks_path = "./document_chunks.pkl"

# Raw string so the Windows backslashes are not parsed as escape sequences
# (\G, \A and \I are invalid escapes and trigger warnings on recent Pythons).
pdf_path = r'C:\Git Projects\AnnualReport_rag\IBM.pdf'

print('Extracting text from pdf...')
pdf_text = extract_text_from_pdf(pdf_path)

print('Chunking pdf...')
chunks = chunk_text(pdf_text, chunk_size=1000, chunk_overlap=100)

# Without chunks there is nothing to embed, and the embeddings[0] / shape[1]
# accesses below would fail with an unhelpful IndexError.
if not chunks:
    raise SystemExit(f"No extractable text found in {pdf_path}")

print('Embedding chunks...')
embeddings = embedding_function(chunks)

print(f"Embeddings type: {type(embeddings)}")
print(f"First embedding type: {type(embeddings[0])}")
print(f"First embedding shape or length: {len(embeddings[0]) if hasattr(embeddings[0], '__len__') else 'unknown'}")

# FAISS needs a float32 matrix. Cast unconditionally so an ndarray of another
# dtype is converted too, not just plain lists.
if not isinstance(embeddings, np.ndarray):
    print("Converting embeddings to numpy array...")
embeddings = np.asarray(embeddings, dtype='float32')

# Get the dimension of the embeddings
dimension = embeddings.shape[1]
print(f"Embedding dimension: {dimension}")

# Initialize FAISS index
print('Initializing FAISS index...')
index = faiss.IndexFlatL2(dimension)  # L2 distance for similarity search

# Add vectors to the index
print('Adding vectors to FAISS index...')
index.add(embeddings)

# Save the index
print('Saving FAISS index...')
faiss.write_index(index, index_path)

# Save the document chunks for retrieval
print('Saving document chunks...')
with open(chunks_path, 'wb') as f:
    pickle.dump(chunks, f)

print(f"Total vectors in index: {index.ntotal}")
56
+
57
+
58
def retrieve_documents(query, n_results=5):
    """Return the chunks nearest to *query* together with relevance scores.

    Args:
        query: Question text to embed and search for.
        n_results: Maximum number of chunks to return.

    Returns:
        A ``(documents, similarity_scores)`` pair of equal-length lists,
        ordered as returned by the index search. Scores are L2 distances
        rescaled to ``[0, 1]`` relative to the farthest returned hit, where
        1 is most similar. Both lists are empty when nothing is found.
    """
    # Generate embedding for the query
    query_embedding = embedding_model.encode([query], convert_to_numpy=True).astype('float32')

    # Search the index
    distances, indices = index.search(query_embedding, n_results)

    # FAISS pads the result with index -1 when the index holds fewer than
    # n_results vectors; chunks[-1] would silently return the last chunk, so
    # those slots are dropped.
    hits = [
        (int(idx), float(dist))
        for idx, dist in zip(indices[0], distances[0])
        if idx >= 0
    ]
    if not hits:
        return [], []

    # Get the documents
    documents = [chunks[idx] for idx, _ in hits]

    # Normalise distances to [0, 1] (lower distance -> higher score). If every
    # hit is an exact match the maximum is 0, so skip the division.
    max_distance = max(dist for _, dist in hits)
    if max_distance > 0:
        similarity_scores = [1 - (dist / max_distance) for _, dist in hits]
    else:
        similarity_scores = [1.0 for _ in hits]

    return documents, similarity_scores
74
+
75
+
76
# Test the retrieval
query = "how has the profitability of the company been in last five years"
print('Retrieving documents...')
general_docs, general_scores = retrieve_documents(query, n_results=15)
print(f"Number of docs returned for general query: {len(general_docs)}")

# Print the results
# for i, (doc, score) in enumerate(zip(general_docs, general_scores)):
#     print(f"\nResult {i+1} (Score: {score:.4f}):")
#     print(f"{doc[:200]}...")

# Expand the query with an LLM-written hypothetical answer before retrieving
# again. The separating space keeps the last word of the query from fusing
# with the first word of the generated text.
new_query = query + " " + generate_hypothetical_answer(query)
combined_context = retrieve_documents(new_query, n_results=15)

answer = query_llm_with_context(query, combined_context, top_n=3)

# f-string prefix was missing, so the placeholder used to be printed literally.
print(f'final_response:{answer}')
helper.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
3
+ from pypdf import PdfReader
4
+ import requests
5
+ import json
6
+
7
+
8
def extract_text_from_pdf(pdf_path):
    """Return the text of every page of the PDF at *pdf_path* as one string.

    Each page's text is followed by a newline; leading and trailing
    whitespace of the combined result is stripped.
    """
    pages = PdfReader(pdf_path).pages
    combined = "".join(page.extract_text() + "\n" for page in pages)
    return combined.strip()
14
+
15
def chunk_text(text, chunk_size=500, chunk_overlap=100):
    """Split *text* into chunks using a recursive character splitter.

    Args:
        text: The text to split.
        chunk_size: Passed to the splitter as ``chunk_size``.
        chunk_overlap: Passed to the splitter as ``chunk_overlap``, so that
            neighbouring chunks share some context.

    Returns:
        Whatever the splitter's ``split_text`` returns for *text*.
    """
    splitter_options = {
        "chunk_size": chunk_size,
        "chunk_overlap": chunk_overlap,
        # Try paragraph breaks first, then lines, then words, then characters.
        "separators": ["\n\n", "\n", " ", ""],
    }
    return RecursiveCharacterTextSplitter(**splitter_options).split_text(text)
22
+
23
# Shared sentence-embedding model, instantiated once when the module is imported.
embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")


def embedding_function(texts):
    """Encode *texts* with the shared model and return plain Python lists."""
    vectors = embedding_model.encode(texts, convert_to_numpy=True)
    return vectors.tolist()
27
+
28
+
29
+
30
def generate_hypothetical_answer(query, model="llama2", timeout=300):
    """Ask the local Ollama server for a plausible answer to *query*.

    Args:
        query: The question to answer.
        model: Name of a model already pulled in Ollama.
        timeout: Seconds to wait for the HTTP call. Without a timeout
            ``requests`` waits indefinitely if the server stalls.

    Returns:
        The generated text with surrounding whitespace stripped, or the
        fixed failure message below if the request or JSON parsing fails.
    """
    import requests

    # Ollama API endpoint (default is localhost:11434)
    ollama_url = "http://localhost:11434/api/generate"

    # Prepare the prompt
    prompt = f"Generate a plausible answer to the question:\n\n{query}\n\nAnswer:"

    # Prepare the request payload; stream=False asks for a single JSON reply
    payload = {
        "model": model,
        "prompt": prompt,
        "stream": False,
    }

    try:
        # Make the API request to Ollama
        response = requests.post(ollama_url, json=payload, timeout=timeout)
        response.raise_for_status()  # Raise an exception for HTTP errors

        # Parse the response and extract the generated text; tolerate a
        # missing or null "response" field.
        result = response.json()
        generated_text = result.get("response") or ""
        return generated_text.strip()

    except (requests.RequestException, ValueError) as e:
        # Best effort: report the problem and return a sentinel string.
        print(f"Error generating hypothetical answer: {e}")
        return "Failed to generate a hypothetical answer."
62
+
63
+
64
+
65
+
66
def query_llm_with_context(query, context, top_n=3, model="llama2", timeout=300):
    """Answer *query* with the local Ollama server, grounded in retrieved text.

    Args:
        query: The user's question.
        context: A ``(documents, scores)`` pair. Only the documents are used,
            and they are taken in the order given, so the caller is expected
            to pass them best-first.
        top_n: Number of leading documents to include in the prompt.
        model: Name of a model already pulled in Ollama.
        timeout: Seconds to wait for the HTTP call. Without a timeout
            ``requests`` waits indefinitely if the server stalls.

    Returns:
        The generated answer with surrounding whitespace stripped, or the
        fixed failure message below if the request or JSON parsing fails.
    """
    documents, _scores = context

    # Use only the top N documents
    top_docs = documents[:top_n]

    # Join the selected documents with an explicit boundary marker
    context_text = "\n\n===Document Boundary===\n\n".join(top_docs)

    # Create a prompt with the context and query
    prompt = f"""
Context information is below.
---------------------
{context_text}
---------------------

Given the context information and not prior knowledge, answer the following query:
Query: {query}
"""

    # Call Ollama API instead of OpenAI
    ollama_url = "http://localhost:11434/api/generate"

    # Prepare the request payload; stream=False asks for a single JSON reply
    payload = {
        "model": model,
        "prompt": prompt,
        "stream": False,
    }

    try:
        # Make the API request to Ollama
        response = requests.post(ollama_url, json=payload, timeout=timeout)
        response.raise_for_status()  # Raise an exception for HTTP errors

        # Parse the response and extract the generated text; tolerate a
        # missing or null "response" field.
        result = response.json()
        generated_text = result.get("response") or ""
        return generated_text.strip()

    except (requests.RequestException, ValueError) as e:
        # Best effort: report the problem and return a sentinel string.
        print(f"Error querying LLM with context: {e}")
        return "Failed to generate an answer with the provided context."
112
+
113
+
114
+
115
+
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ sentence-transformers
2
+ chromadb
3
+ pypdf
4
+ langchain
5
+ openai
6
+ faiss-cpu