tjwrld committed (verified)
Commit 33ce21d · 1 Parent(s): 9dc9e74

Create main.py

Files changed (1): main.py +179 -0
main.py ADDED
@@ -0,0 +1,179 @@
"""
Streamlit RAG app: extract text from an uploaded PDF (PyMuPDF), split it
into overlapping word chunks (NLTK), embed the chunks with Gemini, index
them in FAISS, and answer questions with gemini-1.5-flash over the
retrieved context.
"""
import streamlit as st
import fitz  # PyMuPDF
import nltk
from nltk.tokenize import word_tokenize
import google.generativeai as genai
import faiss
import numpy as np
import os

# Ensure NLTK resources are downloaded
# (newer NLTK releases also need punkt_tab for word_tokenize)
nltk.download("punkt")
nltk.download("punkt_tab")

# Configure Gemini API (use environment variable or Streamlit secrets for API key)

# GEMINI_API_KEY = ""  # Replace with your actual API key
# genai.configure(api_key=GEMINI_API_KEY)

genai.configure(api_key=os.environ["AI_API_KEY"])
gemini_model = genai.GenerativeModel('gemini-1.5-flash')

# Function to extract text from the uploaded PDF using PyMuPDF (fitz)
def extract_text_from_pdf(pdf_file):
    try:
        doc = fitz.open(stream=pdf_file.read(), filetype="pdf")
        text = ""
        for page_num in range(len(doc)):
            page = doc.load_page(page_num)
            text += page.get_text()
        return text
    except Exception as e:
        st.error(f"Error extracting text from PDF: {e}")
        return None

# Function to split text into overlapping chunks using NLTK tokenization
def split_text_into_chunks(text, chunk_size=500, overlap=100):
    try:
        words = word_tokenize(text)
        chunks = []
        # Step forward by (chunk_size - overlap) words so consecutive
        # chunks share `overlap` words of context.
        for i in range(0, len(words), chunk_size - overlap):
            chunk = " ".join(words[i:i + chunk_size])
            chunks.append(chunk)
        return chunks
    except Exception as e:
        st.error(f"Error splitting text into chunks: {e}")
        return []

# Function to generate embeddings for a list of text chunks
def generate_embeddings(chunks, title="PDF Document"):
    embeddings = []
    for chunk in chunks:
        try:
            embedding = genai.embed_content(
                model="models/embedding-001",
                content=chunk,
                task_type="retrieval_document",
                title=title
            )
            embeddings.append(embedding["embedding"])
        except Exception as e:
            st.error(f"Error generating embedding for chunk: {e}")
    return embeddings

# Function to store embeddings in FAISS
def store_embeddings_in_faiss(embeddings):
    try:
        embeddings_array = np.array(embeddings).astype('float32')
        dimension = embeddings_array.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings_array)
        return index
    except Exception as e:
        st.error(f"Error storing embeddings in FAISS: {e}")
        return None

# Function to retrieve relevant chunks using FAISS
def retrieve_relevant_chunks(query_embedding, index, chunks, top_k=3):
    try:
        query_embedding = np.array(query_embedding).astype('float32').reshape(1, -1)
        distances, indices = index.search(query_embedding, top_k)
        relevant_chunks = [chunks[i] for i in indices[0]]
        return relevant_chunks
    except Exception as e:
        st.error(f"Error retrieving relevant chunks: {e}")
        return []

# Function to generate an answer using Gemini API
def generate_answer(query, context_chunks):
    try:
        context = "\n".join(context_chunks)
        prompt = f"""
Context:
{context}

Question:
{query}

Answer the question based on the context provided above.
"""
        response = gemini_model.generate_content(prompt)
        return response.text
    except Exception as e:
        st.error(f"Error generating answer: {e}")
        return "Unable to generate an answer due to an error."

# Streamlit UI
with st.sidebar:
    st.title("Navigation")
    hide_st_style = '''
        <style>
        #MainMenu {visibility: hidden;}
        footer {visibility: hidden;}
        header {visibility: hidden;}
        </style>
    '''
    st.markdown(hide_st_style, unsafe_allow_html=True)
    page = st.radio("Options", ["Home", "Privacy Policy"], label_visibility="collapsed")

if page == "Home":
    st.title("Gemini RAG Application")
    st.markdown("Upload a PDF document and ask questions to get answers using Google's Gemini API.")

    pdf_file = st.file_uploader("Choose a PDF file", type="pdf")

    if pdf_file is not None:
        with st.spinner("Extracting text..."):
            extracted_text = extract_text_from_pdf(pdf_file)

        if extracted_text:
            with st.spinner("Splitting text into overlapping chunks..."):
                chunks = split_text_into_chunks(extracted_text, chunk_size=500, overlap=100)

            if chunks:
                with st.status(f"Total chunks: {len(chunks)}"):
                    for i, chunk in enumerate(chunks):
                        st.subheader(f"Chunk {i + 1}")
                        st.text_area(f"Chunk {i + 1} Text", chunk, height=200, key=f"chunk_{i}")

                with st.spinner("Generating embeddings..."):
                    embeddings = generate_embeddings(chunks)

                if embeddings:
                    with st.spinner("Storing embeddings in FAISS..."):
                        index = store_embeddings_in_faiss(embeddings)

                    if index is not None:
                        st.success("Embeddings have been successfully stored in the FAISS vector database.")

                        query = st.text_input("Enter your question:")
                        if query:
                            with st.spinner("Generating query embedding..."):
                                query_embedding = genai.embed_content(
                                    model="models/embedding-001",
                                    content=query,
                                    task_type="retrieval_query"
                                )["embedding"]

                            with st.spinner("Retrieving relevant chunks..."):
                                relevant_chunks = retrieve_relevant_chunks(query_embedding, index, chunks, top_k=3)

                            if relevant_chunks:
                                with st.status("### Relevant Context Chunks:"):
                                    for i, chunk in enumerate(relevant_chunks):
                                        st.subheader(f"Chunk {i + 1}")
                                        st.text_area(f"Relevant Chunk {i + 1} Text", chunk, height=200, key=f"relevant_chunk_{i}")

                                with st.spinner("Generating answer..."):
                                    answer = generate_answer(query, relevant_chunks)
                                    st.write("### Answer:")
                                    st.write(answer)
                            else:
                                st.warning("No relevant chunks found.")
                    else:
                        st.error("Failed to store embeddings in FAISS.")
                else:
                    st.error("Failed to generate embeddings.")
            else:
                st.error("No chunks generated from the text.")
        else:
            st.error("No text extracted. The document might be image-based or corrupted.")
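With a Gemini API key exported as AI_API_KEY, the app starts with streamlit run main.py. The chunking stride (each step advances chunk_size - overlap = 400 words, so consecutive chunks share 100 words) can be sanity-checked outside Streamlit; the following is a minimal sketch assuming only nltk is installed, with a synthetic word0 ... word1199 text that is purely illustrative:

    import nltk
    from nltk.tokenize import word_tokenize

    nltk.download("punkt", quiet=True)
    nltk.download("punkt_tab", quiet=True)  # needed on newer NLTK releases

    def split_text_into_chunks(text, chunk_size=500, overlap=100):
        # Same stride logic as main.py: advance by chunk_size - overlap words.
        words = word_tokenize(text)
        step = chunk_size - overlap
        return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), step)]

    sample = " ".join(f"word{i}" for i in range(1200))
    chunks = split_text_into_chunks(sample)
    print(len(chunks))           # 3 chunks: words 0-499, 400-899, 800-1199
    print(chunks[1].split()[0])  # word400 -> chunk 2 overlaps chunk 1 by 100 words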
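The retrieval core can likewise be verified in isolation. This round-trip sketch assumes faiss-cpu and numpy, and uses random 768-dimensional vectors (the size models/embedding-001 is expected to return); a vector queried against the index should come back as its own nearest neighbour at distance ~0:

    import faiss
    import numpy as np

    rng = np.random.default_rng(0)
    vecs = rng.random((10, 768), dtype=np.float32)  # stand-ins for chunk embeddings

    index = faiss.IndexFlatL2(768)  # exact L2 search, as in store_embeddings_in_faiss
    index.add(vecs)

    distances, indices = index.search(vecs[3:4], 3)  # top_k=3, as in the app
    print(indices[0])    # first hit is 3 (the query vector itself)
    print(distances[0])  # first distance is ~0.0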