riteshraut commited on
Commit
2e541fd
·
1 Parent(s): 08eb411
Files changed (1) hide show
  1. app.py +54 -157
app.py CHANGED
@@ -3,39 +3,29 @@ import nltk
3
  from functools import wraps
4
 
5
  # ============================ NLTK MONKEY-PATCH (MUST BE FIRST) ============================
6
- # This is the definitive fix. We are intercepting the nltk.download function
7
- # to force it to always use the correct, writable directory.
8
-
9
  print("Applying NLTK monkey-patch...")
10
  NLTK_DATA_DIR = '/tmp/nltk_data'
11
  os.environ['NLTK_DATA'] = NLTK_DATA_DIR
12
  os.makedirs(NLTK_DATA_DIR, exist_ok=True)
13
 
14
- # Store the original download function
15
  _original_nltk_download = nltk.download
16
 
17
- # Create a new, patched download function
18
  @wraps(_original_nltk_download)
19
  def _patched_nltk_download(info_or_id, download_dir=None, **kwargs):
20
- # If the download_dir is not specified (which is the case in the faulty
21
- # 'unstructured' call), force it to our writable directory.
22
  if download_dir is None:
23
  download_dir = NLTK_DATA_DIR
24
-
25
  print(f"Patched NLTK download called for '{info_or_id}', ensuring download_dir='{download_dir}'")
26
  return _original_nltk_download(info_or_id, download_dir=download_dir, **kwargs)
27
 
28
- # Replace the original function with our patched version
29
  nltk.download = _patched_nltk_download
30
  print("NLTK monkey-patch applied successfully.")
31
  # ========================================================================================
32
 
33
-
34
- # Now that the patch is active, we can proceed with imports and initial downloads.
35
  print("Running initial NLTK downloads...")
36
  nltk.download('punkt')
37
  nltk.download('stopwords')
38
- nltk.download('averaged_perceptron_tagger_eng')
39
  print("Initial NLTK downloads complete.")
40
 
41
  import time
@@ -44,22 +34,21 @@ from flask import Flask, request, render_template, session, jsonify, Response, s
44
  from werkzeug.utils import secure_filename
45
  from rag_processor import create_rag_chain
46
 
47
- # ============================ ADDITIONS START ============================
48
  from gtts import gTTS
49
  import io
50
  import re
51
- # ============================ ADDITIONS END ==============================
52
 
53
- # NOW import the rest of the modules that might use NLTK
54
- # Document Loaders
 
55
  from langchain_community.document_loaders import (
56
  TextLoader,
57
- UnstructuredPDFLoader,
58
  Docx2txtLoader,
59
- UnstructuredImageLoader,
60
  )
61
 
62
- # Text Splitter, Embeddings, Retrievers
63
  from langchain.text_splitter import RecursiveCharacterTextSplitter
64
  from langchain_huggingface import HuggingFaceEmbeddings
65
  from langchain_community.vectorstores import FAISS
@@ -69,37 +58,29 @@ from langchain_community.chat_message_histories import ChatMessageHistory
69
 
70
  # --- Basic Flask App Setup ---
71
  app = Flask(__name__)
72
- # A secret key is needed for session management
73
- app.config['SECRET_KEY'] = os.urandom(24)
74
- # Configure the upload folder
75
  app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
76
- # Ensure the upload folder exists
77
  os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
78
 
79
  # --- In-memory Storage & Global Model Loading ---
80
  rag_chains = {}
81
  message_histories = {}
82
 
83
- # Load the embedding model once when the application starts for efficiency.
84
  print("Loading embedding model...")
85
  EMBEDDING_MODEL = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
86
  print("Embedding model loaded successfully.")
87
 
88
- # A dictionary to map file extensions to their corresponding loader classes
 
89
  LOADER_MAPPING = {
90
  ".txt": TextLoader,
91
- ".pdf": UnstructuredPDFLoader,
92
  ".docx": Docx2txtLoader,
93
- ".jpeg": UnstructuredImageLoader,
94
- ".jpg": UnstructuredImageLoader,
95
- ".png": UnstructuredImageLoader,
96
  }
97
 
98
  def get_session_history(session_id: str) -> ChatMessageHistory:
99
- """
100
- Retrieves the chat history for a given session ID. If it doesn't exist,
101
- a new history object is created.
102
- """
103
  if session_id not in message_histories:
104
  message_histories[session_id] = ChatMessageHistory()
105
  return message_histories[session_id]
@@ -111,23 +92,18 @@ def index():
111
 
112
  @app.route('/upload', methods=['POST'])
113
  def upload_files():
114
- """Handles multiple file uploads, processing, and RAG chain creation."""
115
- # Ensure NLTK is still configured correctly
116
- if 'NLTK_DATA' not in os.environ:
117
- os.environ['NLTK_DATA'] = '/tmp/nltk_data'
118
-
119
  files = request.files.getlist('file')
120
 
121
  if not files or all(f.filename == '' for f in files):
122
  return jsonify({'status': 'error', 'message': 'No selected files.'}), 400
123
-
124
  all_docs = []
125
  all_filenames = []
126
  failed_files = []
127
-
128
  try:
129
- print(f"Processing {len(files)} files...")
130
-
131
  for file in files:
132
  if file and file.filename:
133
  filename = secure_filename(file.filename)
@@ -136,159 +112,90 @@ def upload_files():
136
 
137
  try:
138
  file.save(filepath)
139
- print(f"Saved file: {filename} at {filepath}")
140
 
141
  file_extension = os.path.splitext(filename)[1].lower()
142
- if file_extension not in LOADER_MAPPING:
143
- print(f"Skipping unsupported file type: {filename}")
144
- failed_files.append(f"{filename} (unsupported format)")
145
- continue
146
-
147
- loader_class = LOADER_MAPPING[file_extension]
148
- loader_kwargs = {}
149
 
150
- if file_extension in [".jpeg", ".jpg", ".png"]:
151
- loader_kwargs['mode'] = 'single'
152
- elif file_extension == ".pdf":
153
- loader_kwargs['strategy'] = 'hi_res'
154
- loader_kwargs['languages'] = ['eng']
155
-
156
- print(f"Loading {filename} with {loader_class.__name__}...")
157
- loader = loader_class(filepath, **loader_kwargs)
158
- loaded_docs = loader.load()
159
-
160
- # Check if documents were actually loaded
161
- if loaded_docs:
162
- print(f"Successfully loaded {len(loaded_docs)} documents from {filename}")
163
- # Check if the documents have content
164
- for doc in loaded_docs:
165
- if hasattr(doc, 'page_content') and doc.page_content:
166
- print(f"Document content preview (first 100 chars): {doc.page_content[:100]}")
167
- else:
168
- print(f"Warning: Document from {filename} has no content")
169
- all_docs.extend(loaded_docs)
170
- else:
171
- print(f"Warning: No documents loaded from {filename}")
172
- failed_files.append(f"{filename} (no content extracted)")
173
 
 
 
 
 
 
 
 
 
 
 
 
174
  except Exception as e:
175
  print(f"Error processing file {filename}: {e}")
176
- failed_files.append(f"{filename} (processing error: {str(e)})")
177
  continue
178
 
179
  if not all_docs:
180
- error_msg = "No processable content was extracted from the uploaded files."
181
- if failed_files:
182
- error_msg += f" Failed files: {', '.join(failed_files)}"
183
- print(error_msg)
184
  return jsonify({'status': 'error', 'message': error_msg}), 400
185
 
186
- # --- Process all documents together ---
187
- print(f"Total documents loaded: {len(all_docs)}")
188
 
189
- # Split documents into chunks
190
- text_splitter = RecursiveCharacterTextSplitter(
191
- chunk_size=1000,
192
- chunk_overlap=200,
193
- length_function=len,
194
- separators=["\n\n", "\n", " ", ""]
195
- )
196
  splits = text_splitter.split_documents(all_docs)
197
 
198
- print(f"Documents split into {len(splits)} chunks")
199
-
200
- # Verify that splits have content
201
  if not splits:
202
  return jsonify({
203
  'status': 'error',
204
- 'message': 'Documents were loaded but no text chunks were created. The files might be empty or contain only non-text content.'
205
  }), 400
206
-
207
- # Additional check for empty chunks
208
- non_empty_splits = [s for s in splits if s.page_content and s.page_content.strip()]
209
- if not non_empty_splits:
210
- return jsonify({
211
- 'status': 'error',
212
- 'message': 'All text chunks are empty. Please check if your files contain readable text.'
213
- }), 400
214
-
215
- if len(non_empty_splits) < len(splits):
216
- print(f"Warning: {len(splits) - len(non_empty_splits)} empty chunks were filtered out")
217
- splits = non_empty_splits
218
 
219
- print(f"Creating vector store with {len(splits)} non-empty chunks...")
 
220
 
221
- try:
222
- vectorstore = FAISS.from_documents(documents=splits, embedding=EMBEDDING_MODEL)
223
- print("Vector store created successfully")
224
- except IndexError as e:
225
- print(f"IndexError creating vector store: {e}")
226
- return jsonify({
227
- 'status': 'error',
228
- 'message': 'Failed to create embeddings. The documents might not contain enough text content.'
229
- }), 500
230
- except Exception as e:
231
- print(f"Error creating vector store: {e}")
232
- return jsonify({
233
- 'status': 'error',
234
- 'message': f'Failed to create vector store: {str(e)}'
235
- }), 500
236
-
237
- # Create retrievers
238
- print("Creating BM25 retriever...")
239
  bm25_retriever = BM25Retriever.from_documents(splits)
240
  bm25_retriever.k = 5
241
-
242
- print("Creating FAISS retriever...")
243
  faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
244
 
245
- print("Creating ensemble retriever...")
246
- ensemble_retriever = EnsembleRetriever(
247
- retrievers=[bm25_retriever, faiss_retriever],
248
- weights=[0.5, 0.5]
249
- )
250
 
251
- # Create session and RAG chain
252
  session_id = str(uuid.uuid4())
253
  rag_chains[session_id] = create_rag_chain(ensemble_retriever, get_session_history)
254
- print(f"RAG chain created for session {session_id} with {len(all_filenames)} documents.")
255
-
256
  session['session_id'] = session_id
257
 
258
- # Prepare response
259
- display_filenames = ", ".join(all_filenames)
260
- response_data = {'status': 'success', 'filename': display_filenames}
261
 
 
262
  if failed_files:
263
  response_data['warnings'] = f"Some files could not be processed: {', '.join(failed_files)}"
264
 
265
  return jsonify(response_data)
266
 
267
  except Exception as e:
268
- print(f"Unexpected error creating RAG chain: {e}")
269
  import traceback
270
  traceback.print_exc()
271
- return jsonify({'status': 'error', 'message': f'Failed to process files: {str(e)}'}), 500
272
 
273
  @app.route('/chat', methods=['POST'])
274
  def chat():
275
- """Handles chat messages and streams the response with memory."""
276
  data = request.get_json()
277
  question = data.get('question')
278
  session_id = session.get('session_id')
279
 
280
- if not all([question, session_id]):
281
- return jsonify({'status': 'error', 'message': 'Missing data in request.'}), 400
282
-
283
- if session_id not in rag_chains:
284
- return jsonify({'status': 'error', 'message': 'Session not found. Please upload documents again.'}), 400
285
 
286
  try:
287
  rag_chain = rag_chains[session_id]
288
  config = {"configurable": {"session_id": session_id}}
289
 
290
  def generate():
291
- """A generator function to stream the response."""
292
  for chunk in rag_chain.stream({"question": question, "config": config}):
293
  yield chunk
294
 
@@ -298,29 +205,22 @@ def chat():
298
  print(f"Error during chat invocation: {e}")
299
  return Response("An error occurred while getting the answer.", status=500, mimetype='text/plain')
300
 
301
- # ============================ ADDITIONS START ============================
302
-
303
  def clean_markdown_for_tts(text: str) -> str:
304
- """Removes markdown formatting for cleaner text-to-speech output."""
305
- # Remove bold (**text**) and italics (*text* or _text_)
306
  text = re.sub(r'\*(\*?)(.*?)\1\*', r'\2', text)
307
  text = re.sub(r'\_(.*?)\_', r'\1', text)
308
- # Remove inline code (`code`)
309
  text = re.sub(r'`(.*?)`', r'\1', text)
310
- # Remove headings (e.g., #, ##, ###)
311
  text = re.sub(r'^\s*#{1,6}\s+', '', text, flags=re.MULTILINE)
312
- # Remove list item markers (*, -, 1.)
313
  text = re.sub(r'^\s*[\*\-]\s+', '', text, flags=re.MULTILINE)
314
  text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)
315
- # Remove blockquotes (>)
316
  text = re.sub(r'^\s*>\s?', '', text, flags=re.MULTILINE)
317
- # Replace multiple newlines with a single space
318
  text = re.sub(r'\n+', ' ', text)
319
  return text.strip()
320
 
321
  @app.route('/tts', methods=['POST'])
322
  def text_to_speech():
323
- """Generates audio from text and returns it as an MP3 stream."""
324
  data = request.get_json()
325
  text = data.get('text')
326
 
@@ -328,9 +228,7 @@ def text_to_speech():
328
  return jsonify({'status': 'error', 'message': 'No text provided.'}), 400
329
 
330
  try:
331
- # Clean the text before sending to gTTS
332
  clean_text = clean_markdown_for_tts(text)
333
-
334
  tts = gTTS(clean_text, lang='en')
335
  mp3_fp = io.BytesIO()
336
  tts.write_to_fp(mp3_fp)
@@ -339,7 +237,6 @@ def text_to_speech():
339
  except Exception as e:
340
  print(f"Error in TTS generation: {e}")
341
  return jsonify({'status': 'error', 'message': 'Failed to generate audio.'}), 500
342
- # ============================ ADDITIONS END ==============================
343
 
344
  if __name__ == '__main__':
345
  app.run(debug=True, port=5001)
 
3
  from functools import wraps
4
 
5
  # ============================ NLTK MONKEY-PATCH (MUST BE FIRST) ============================
6
+ # This patch ensures NLTK downloads to a writable directory on platforms like Hugging Face Spaces.
 
 
7
  print("Applying NLTK monkey-patch...")
8
  NLTK_DATA_DIR = '/tmp/nltk_data'
9
  os.environ['NLTK_DATA'] = NLTK_DATA_DIR
10
  os.makedirs(NLTK_DATA_DIR, exist_ok=True)
11
 
 
12
  _original_nltk_download = nltk.download
13
 
 
14
  @wraps(_original_nltk_download)
15
  def _patched_nltk_download(info_or_id, download_dir=None, **kwargs):
 
 
16
  if download_dir is None:
17
  download_dir = NLTK_DATA_DIR
 
18
  print(f"Patched NLTK download called for '{info_or_id}', ensuring download_dir='{download_dir}'")
19
  return _original_nltk_download(info_or_id, download_dir=download_dir, **kwargs)
20
 
 
21
  nltk.download = _patched_nltk_download
22
  print("NLTK monkey-patch applied successfully.")
23
  # ========================================================================================
24
 
25
+ # Now that the patch is active, we can proceed with initial downloads.
 
26
  print("Running initial NLTK downloads...")
27
  nltk.download('punkt')
28
  nltk.download('stopwords')
 
29
  print("Initial NLTK downloads complete.")
30
 
31
  import time
 
34
  from werkzeug.utils import secure_filename
35
  from rag_processor import create_rag_chain
36
 
37
+ # --- Text-to-Speech Additions ---
38
  from gtts import gTTS
39
  import io
40
  import re
 
41
 
42
+ # --- MODIFIED: Lightweight Document Loaders ---
43
+ # We are only importing loaders for text-based files to keep the app lightweight.
44
+ # PyPDFLoader is used for text-based PDFs. Unstructured loaders for images are removed.
45
  from langchain_community.document_loaders import (
46
  TextLoader,
 
47
  Docx2txtLoader,
48
+ PyPDFLoader, # Lightweight PDF loader
49
  )
50
 
51
+ # --- Standard LangChain Components ---
52
  from langchain.text_splitter import RecursiveCharacterTextSplitter
53
  from langchain_huggingface import HuggingFaceEmbeddings
54
  from langchain_community.vectorstores import FAISS
 
58
 
59
  # --- Basic Flask App Setup ---
60
  app = Flask(__name__)
61
+ app.config['SECRET_KEY'] = os.urandom(24)
 
 
62
  app.config['UPLOAD_FOLDER'] = '/tmp/uploads'
 
63
  os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
64
 
65
  # --- In-memory Storage & Global Model Loading ---
66
  rag_chains = {}
67
  message_histories = {}
68
 
69
+ # The 'all-MiniLM-L6-v2' model is already a great lightweight choice. No changes needed here.
70
  print("Loading embedding model...")
71
  EMBEDDING_MODEL = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
72
  print("Embedding model loaded successfully.")
73
 
74
+ # --- MODIFIED: Lightweight Loader Mapping ---
75
+ # This mapping now only includes loaders for text-based files.
76
  LOADER_MAPPING = {
77
  ".txt": TextLoader,
78
+ ".pdf": PyPDFLoader,
79
  ".docx": Docx2txtLoader,
 
 
 
80
  }
81
 
82
  def get_session_history(session_id: str) -> ChatMessageHistory:
83
+ """Retrieves or creates a chat history for a given session ID."""
 
 
 
84
  if session_id not in message_histories:
85
  message_histories[session_id] = ChatMessageHistory()
86
  return message_histories[session_id]
 
92
 
93
  @app.route('/upload', methods=['POST'])
94
  def upload_files():
95
+ """Handles file uploads using a lightweight, text-only processing strategy."""
 
 
 
 
96
  files = request.files.getlist('file')
97
 
98
  if not files or all(f.filename == '' for f in files):
99
  return jsonify({'status': 'error', 'message': 'No selected files.'}), 400
100
+
101
  all_docs = []
102
  all_filenames = []
103
  failed_files = []
104
+
105
  try:
106
+ print(f"Processing {len(files)} files with a lightweight strategy...")
 
107
  for file in files:
108
  if file and file.filename:
109
  filename = secure_filename(file.filename)
 
112
 
113
  try:
114
  file.save(filepath)
115
+ print(f"Saved file: {filename}")
116
 
117
  file_extension = os.path.splitext(filename)[1].lower()
 
 
 
 
 
 
 
118
 
119
+ # --- REVISED: Simplified Loading Logic ---
120
+ if file_extension in LOADER_MAPPING:
121
+ loader_class = LOADER_MAPPING[file_extension]
122
+ print(f"Loading {filename} with {loader_class.__name__}...")
123
+ loader = loader_class(filepath)
124
+ loaded_docs = loader.load()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
+ # Crucial Check: Ensure content was actually extracted.
127
+ # This is important for scanned PDFs, where PyPDFLoader will produce no text.
128
+ if loaded_docs and any(doc.page_content.strip() for doc in loaded_docs):
129
+ all_docs.extend(loaded_docs)
130
+ else:
131
+ print(f"Warning: No text content found in {filename}. It might be empty or image-based.")
132
+ failed_files.append(f"{filename} (no text found)")
133
+ else:
134
+ print(f"Skipping unsupported file type: {filename}")
135
+ failed_files.append(f"{filename} (unsupported format)")
136
+
137
  except Exception as e:
138
  print(f"Error processing file {filename}: {e}")
139
+ failed_files.append(f"{filename} (processing error)")
140
  continue
141
 
142
  if not all_docs:
143
+ error_msg = "No processable text content was extracted from the uploaded files. Please ensure files are not empty, corrupted, or image-based."
 
 
 
144
  return jsonify({'status': 'error', 'message': error_msg}), 400
145
 
146
+ # --- Process all documents together (No changes from here on) ---
147
+ print(f"Total documents with text loaded: {len(all_docs)}")
148
 
149
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
 
 
 
 
 
 
150
  splits = text_splitter.split_documents(all_docs)
151
 
 
 
 
152
  if not splits:
153
  return jsonify({
154
  'status': 'error',
155
+ 'message': 'Loaded documents but could not create text chunks. Check file content.'
156
  }), 400
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
+ print(f"Documents split into {len(splits)} chunks. Creating vector store...")
159
+ vectorstore = FAISS.from_documents(documents=splits, embedding=EMBEDDING_MODEL)
160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  bm25_retriever = BM25Retriever.from_documents(splits)
162
  bm25_retriever.k = 5
 
 
163
  faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
164
 
165
+ ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5])
 
 
 
 
166
 
 
167
  session_id = str(uuid.uuid4())
168
  rag_chains[session_id] = create_rag_chain(ensemble_retriever, get_session_history)
 
 
169
  session['session_id'] = session_id
170
 
171
+ print(f"RAG chain created for session {session_id}.")
 
 
172
 
173
+ response_data = {'status': 'success', 'filename': ", ".join(all_filenames)}
174
  if failed_files:
175
  response_data['warnings'] = f"Some files could not be processed: {', '.join(failed_files)}"
176
 
177
  return jsonify(response_data)
178
 
179
  except Exception as e:
 
180
  import traceback
181
  traceback.print_exc()
182
+ return jsonify({'status': 'error', 'message': f'An unexpected error occurred: {str(e)}'}), 500
183
 
184
  @app.route('/chat', methods=['POST'])
185
  def chat():
186
+ """Handles chat messages and streams the response."""
187
  data = request.get_json()
188
  question = data.get('question')
189
  session_id = session.get('session_id')
190
 
191
+ if not all([question, session_id]) or session_id not in rag_chains:
192
+ return jsonify({'status': 'error', 'message': 'Session not found or invalid. Please upload documents again.'}), 400
 
 
 
193
 
194
  try:
195
  rag_chain = rag_chains[session_id]
196
  config = {"configurable": {"session_id": session_id}}
197
 
198
  def generate():
 
199
  for chunk in rag_chain.stream({"question": question, "config": config}):
200
  yield chunk
201
 
 
205
  print(f"Error during chat invocation: {e}")
206
  return Response("An error occurred while getting the answer.", status=500, mimetype='text/plain')
207
 
208
+ # ============================ Text-to-Speech Functions ============================
 
209
  def clean_markdown_for_tts(text: str) -> str:
210
+ """Removes markdown for cleaner text-to-speech output."""
 
211
  text = re.sub(r'\*(\*?)(.*?)\1\*', r'\2', text)
212
  text = re.sub(r'\_(.*?)\_', r'\1', text)
 
213
  text = re.sub(r'`(.*?)`', r'\1', text)
 
214
  text = re.sub(r'^\s*#{1,6}\s+', '', text, flags=re.MULTILINE)
 
215
  text = re.sub(r'^\s*[\*\-]\s+', '', text, flags=re.MULTILINE)
216
  text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)
 
217
  text = re.sub(r'^\s*>\s?', '', text, flags=re.MULTILINE)
 
218
  text = re.sub(r'\n+', ' ', text)
219
  return text.strip()
220
 
221
  @app.route('/tts', methods=['POST'])
222
  def text_to_speech():
223
+ """Generates audio from text."""
224
  data = request.get_json()
225
  text = data.get('text')
226
 
 
228
  return jsonify({'status': 'error', 'message': 'No text provided.'}), 400
229
 
230
  try:
 
231
  clean_text = clean_markdown_for_tts(text)
 
232
  tts = gTTS(clean_text, lang='en')
233
  mp3_fp = io.BytesIO()
234
  tts.write_to_fp(mp3_fp)
 
237
  except Exception as e:
238
  print(f"Error in TTS generation: {e}")
239
  return jsonify({'status': 'error', 'message': 'Failed to generate audio.'}), 500
 
240
 
241
  if __name__ == '__main__':
242
  app.run(debug=True, port=5001)