Ultronprime committed on
Commit bb01969 · verified · 1 Parent(s): 09c1ee0

Update app.py

Files changed (1)
  1. app.py +30 -34
app.py CHANGED
@@ -45,7 +45,7 @@ os.makedirs(NPY_CACHE, exist_ok=True, mode=0o777)
 LOG_DIR = os.getenv("LOG_DIR", os.path.join(PERSISTENT_PATH, "logs"))
 os.makedirs(LOG_DIR, exist_ok=True, mode=0o777)

-# Set Hugging Face cache directory to PERSISTENT_PATH
+# Set Hugging Face cache directory to persistent storage
 os.environ["HF_HOME"] = os.path.join(PERSISTENT_PATH, ".huggingface")
 os.makedirs(os.environ["HF_HOME"], exist_ok=True, mode=0o777)
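Review note: huggingface_hub resolves HF_HOME when it is first imported, so this block should stay above any transformers/sentence_transformers import to take effect. A minimal sketch of the intended ordering (the PERSISTENT_PATH default below is illustrative, not from app.py):

import os

PERSISTENT_PATH = os.getenv("PERSISTENT_PATH", "/data")  # hypothetical default

# Point the Hugging Face cache at persistent storage before any HF import.
os.environ["HF_HOME"] = os.path.join(PERSISTENT_PATH, ".huggingface")
os.makedirs(os.environ["HF_HOME"], exist_ok=True, mode=0o777)

# Libraries that read HF_HOME are imported only after the variable is set.
from sentence_transformers import SentenceTransformer  # noqa: E402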
 
@@ -64,7 +64,9 @@ def initialize_model():
     global model
     try:
         if model is None:
-            model = SentenceTransformer(EMBEDDING_MODEL_NAME, cache_folder=os.path.join(PERSISTENT_PATH, "models"))
+            model_cache = os.path.join(PERSISTENT_PATH, "models")
+            os.makedirs(model_cache, exist_ok=True, mode=0o777)
+            model = SentenceTransformer(EMBEDDING_MODEL_NAME, cache_folder=model_cache)
             logger.info(f"Initialized model: {EMBEDDING_MODEL_NAME}")
             return True
     except requests.exceptions.ConnectionError as e:
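Review note: creating the cache directory before constructing the model avoids a first-run failure on an empty volume. The same lazy-singleton pattern as a standalone sketch (model name and logger are stand-ins mirroring app.py):

import logging
import os

from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"  # illustrative
model = None

def initialize_model(cache_root: str = "/data") -> bool:
    """Load the embedding model once, caching weights under cache_root."""
    global model
    if model is None:
        model_cache = os.path.join(cache_root, "models")
        os.makedirs(model_cache, exist_ok=True)
        model = SentenceTransformer(EMBEDDING_MODEL_NAME, cache_folder=model_cache)
        logger.info("Initialized model: %s", EMBEDDING_MODEL_NAME)
    return True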
@@ -78,6 +80,7 @@ def initialize_model():
 def handle_gpu_operation(func):
     try:
         start_time = datetime.now()
+        # Updated autocast usage as per deprecation notice
         with autocast(device_type='cuda', dtype=torch.float16):
             result = func()
         end_time = datetime.now()
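Review note: the new comment refers to PyTorch deprecating torch.cuda.amp.autocast in favor of torch.amp.autocast with an explicit device_type, which is what the with-statement here uses. A sketch of the surrounding wrapper as I read it (timing plus mixed precision; requires a CUDA device):

from datetime import datetime

import torch
from torch.amp import autocast

def handle_gpu_operation(func):
    """Run func under float16 autocast and report its wall-clock duration."""
    start_time = datetime.now()
    # New-style API: the device type is passed explicitly.
    with autocast(device_type="cuda", dtype=torch.float16):
        result = func()
    elapsed = (datetime.now() - start_time).total_seconds()
    print(f"GPU operation took {elapsed:.2f}s")  # app.py uses its logger instead
    return result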
@@ -121,7 +124,7 @@ def process_files(files):

     valid_files = [f for f in files if f.name.lower().endswith('.txt')]
     if not valid_files:
-        return "No .txt files found in upload. Please ensure you upload .txt files.", "", ""
+        return "No .txt files found. Please upload valid .txt files.", "", ""

     all_chunks = []
     processed_files = 0
@@ -133,6 +136,7 @@ def process_files(files):
             detected_encoding = from_bytes(content).best().encoding
             decoded_content = content.decode(detected_encoding, errors='ignore')

+            # Split content into chunks
             chunks = [decoded_content[i:i+CHUNK_SIZE] for i in range(0, len(decoded_content), CHUNK_SIZE)]
             all_chunks.extend(chunks)
             processed_files += 1
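Review note: from_bytes here is presumably charset_normalizer's detector; best() returns the most likely match and can return None on undecodable input, which this hunk does not guard against. A self-contained sketch of the detect-decode-chunk step with that guard added (the CHUNK_SIZE value is illustrative):

from charset_normalizer import from_bytes

CHUNK_SIZE = 40  # illustrative; app.py defines its own size

def chunk_bytes(content: bytes) -> list[str]:
    """Decode raw bytes with a detected encoding and split into fixed-size chunks."""
    best = from_bytes(content).best()
    encoding = best.encoding if best else "utf-8"  # guard: best() may be None
    decoded = content.decode(encoding, errors="ignore")
    return [decoded[i:i + CHUNK_SIZE] for i in range(0, len(decoded), CHUNK_SIZE)]

sample = "Plain UTF-8 text, long enough to span more than one chunk.".encode("utf-8")
print(chunk_bytes(sample))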
@@ -141,7 +145,7 @@ def process_files(files):
             logger.error(f"Error processing file {file.name}: {str(e)}")

     if not all_chunks:
-        return "No valid content found in the uploaded .txt files.", "", ""
+        return "No valid content found in the uploaded files.", "", ""

     # Generate embeddings in batches
     all_embeddings = []
@@ -156,7 +160,6 @@ def process_files(files):
     # Save results to OUTPUTS_DIR
     embeddings_path = os.path.join(OUTPUTS_DIR, "embeddings.npy")
     np.save(embeddings_path, np.array(all_embeddings))
-
     chunks_path = os.path.join(OUTPUTS_DIR, "chunks.txt")
     with open(chunks_path, "w", encoding="utf-8") as f:
         for chunk in all_chunks:
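Review note: the write loop is cut off by the hunk boundary, but semantic_search splits chunks.txt on "\n===CHUNK_SEPARATOR===\n", so the writer presumably appends that separator after each chunk. A round-trip sketch under that assumption:

import numpy as np

SEPARATOR = "\n===CHUNK_SEPARATOR===\n"

def save_results(chunks: list[str], embeddings: np.ndarray, outputs_dir: str = ".") -> None:
    """Persist embeddings as .npy and chunks as separator-delimited text."""
    np.save(f"{outputs_dir}/embeddings.npy", embeddings)
    with open(f"{outputs_dir}/chunks.txt", "w", encoding="utf-8") as f:
        for chunk in chunks:
            f.write(chunk + SEPARATOR)

def load_chunks(outputs_dir: str = ".") -> list[str]:
    """Inverse of save_results for the text side."""
    with open(f"{outputs_dir}/chunks.txt", "r", encoding="utf-8") as f:
        return [c for c in f.read().split(SEPARATOR) if c.strip()]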
@@ -179,19 +182,16 @@ def semantic_search(query, top_k=5):
         return "Model not initialized. Please process files first."

     try:
-        # Load saved embeddings from OUTPUTS_DIR
-        stored_embeddings = np.load(os.path.join(OUTPUTS_DIR, "embeddings.npy"))
-
-        # Load stored chunks from OUTPUTS_DIR
-        with open(os.path.join(OUTPUTS_DIR, "chunks.txt"), "r", encoding="utf-8") as f:
+        # Load saved embeddings and chunks from OUTPUTS_DIR
+        embeddings_file = os.path.join(OUTPUTS_DIR, "embeddings.npy")
+        chunks_file = os.path.join(OUTPUTS_DIR, "chunks.txt")
+        stored_embeddings = np.load(embeddings_file)
+        with open(chunks_file, "r", encoding="utf-8") as f:
            chunks = f.read().split("\n===CHUNK_SEPARATOR===\n")
        chunks = [c for c in chunks if c.strip()]

        # Get query embedding
-        if model:
-            query_embedding = model.encode([query])[0]
-        else:
-            return "Model not initialized. Please process files first."
+        query_embedding = model.encode([query])[0]

        # Calculate similarities
        similarities = np.dot(stored_embeddings, query_embedding) / (
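Review note: the hunk ends mid-expression; given the np.dot numerator, the denominator is presumably the product of the vector norms, i.e. plain cosine similarity. A self-contained sketch under that assumption, including the top-k selection from the next hunk:

import numpy as np

def cosine_similarities(stored: np.ndarray, query: np.ndarray) -> np.ndarray:
    """Cosine similarity of one query vector against an (n, d) embedding matrix."""
    return np.dot(stored, query) / (
        np.linalg.norm(stored, axis=1) * np.linalg.norm(query)
    )

def top_k_indices(similarities: np.ndarray, top_k: int = 5) -> np.ndarray:
    """Indices of the top_k highest similarities, best match first."""
    return np.argsort(similarities)[-top_k:][::-1]

emb = np.random.rand(10, 384)  # toy data
q = np.random.rand(384)
print(top_k_indices(cosine_similarities(emb, q), top_k=3))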
@@ -200,8 +200,6 @@ def semantic_search(query, top_k=5):

        # Get top results
        top_indices = np.argsort(similarities)[-top_k:][::-1]
-
-        # Format results
        results = []
        for idx in top_indices:
            results.append(f"""
@@ -209,9 +207,7 @@ Similarity: {similarities[idx]:.3f}
 Content: {chunks[idx]}
 -------------------
 """)
-
        return "\n".join(results)
-
    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        return f"Search error occurred: {str(e)}"
@@ -223,11 +219,12 @@ def search_and_format(query, num_results):

 def browse_outputs():
     try:
+        # Open the outputs directory in a web browser (may work on some systems)
         webbrowser.open(f"file://{OUTPUTS_DIR}")
-        return "Opened outputs directory"
+        return "Opened outputs directory."
     except Exception as e:
         logger.error(f"Error opening file browser: {str(e)}")
-        return "Error opening file browser"
+        return "Error opening file browser."

 def download_results():
     required_files = ["embeddings.npy", "chunks.txt"]
@@ -235,15 +232,15 @@ def download_results():
     if missing:
         logger.error(f"Missing files: {missing}")
         return None
-
     try:
         zip_path = os.path.join(OUTPUTS_DIR, "results.zip")
         with zipfile.ZipFile(zip_path, 'w') as zipf:
             for file in required_files:
-                zipf.write(os.path.join(OUTPUTS_DIR, file), file)
+                file_path = os.path.join(OUTPUTS_DIR, file)
+                zipf.write(file_path, file)
         return zip_path
     except Exception as e:
-        logger.error(f"Error creating download: {str(e)}")
+        logger.error(f"Error creating download archive: {str(e)}")
         return None

 def create_gradio_interface():
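Review note: splitting zipf.write onto two lines makes it clearer that the second argument is the archive name, so files land at the zip root instead of under their absolute paths. The same pattern as a sketch:

import os
import zipfile

def zip_results(outputs_dir: str, filenames: list[str]) -> str:
    """Bundle the given files into results.zip inside outputs_dir."""
    zip_path = os.path.join(outputs_dir, "results.zip")
    with zipfile.ZipFile(zip_path, "w") as zipf:
        for name in filenames:
            file_path = os.path.join(outputs_dir, name)
            zipf.write(file_path, arcname=name)  # arcname keeps paths relative
    return zip_path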
@@ -262,12 +259,18 @@ def create_gradio_interface():
         process_button = gr.Button("Generate Embeddings")
         output_text = gr.Textbox(label="Status")

+        process_button.click(
+            fn=process_files,
+            inputs=[file_input],
+            outputs=[output_text, error_box, error_box]
+        )
+
     with gr.Tab("Search"):
         query_input = gr.Textbox(
             label="Enter your search query",
             placeholder="Enter text to search through your documents..."
         )
-        top_k = gr.Slider(
+        top_k_slider = gr.Slider(
             minimum=1,
             maximum=20,
             value=5,
@@ -280,10 +283,9 @@
             lines=10,
             show_copy_button=True
         )
-
         search_button.click(
             fn=search_and_format,
-            inputs=[query_input, top_k],
+            inputs=[query_input, top_k_slider],
             outputs=results_output
         )
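Review note: moving process_button.click inside create_gradio_interface registers the handler while the Blocks context is live, and renaming top_k to top_k_slider keeps the component distinct from the search function's parameter. A reduced sketch of the wiring (assuming Gradio 4.x; the stub stands in for search_and_format):

import gradio as gr

def search_and_format(query: str, num_results: int) -> str:
    return f"(stub) top {num_results} results for: {query}"

with gr.Blocks() as demo:
    query_input = gr.Textbox(label="Enter your search query")
    top_k_slider = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Results")
    search_button = gr.Button("Search")
    results_output = gr.Textbox(label="Results", lines=10, show_copy_button=True)

    # Component references, not raw values, are passed to the event binding.
    search_button.click(
        fn=search_and_format,
        inputs=[query_input, top_k_slider],
        outputs=results_output,
    )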
 
@@ -297,17 +299,11 @@ def create_gradio_interface():
         browse_button = gr.Button("📁 Browse Outputs")
         browse_button.click(
             fn=browse_outputs,
-            outputs=None
+            outputs=[gr.Textbox(label="Browse Status")]
         )

-    process_button.click(
-        process_files,
-        inputs=[file_input],
-        outputs=[output_text, error_box, error_box]
-    )
-
     return demo

 if __name__ == "__main__":
     demo = create_gradio_interface()
-    demo.launch(server_name="0.0.0.0")
+    demo.launch(server_name="0.0.0.0")
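Review note: webbrowser.open executes on the server, so the file:// link cannot reach a remote user's browser; returning a status string (now surfaced in the Browse Status textbox) at least reports the outcome. server_name="0.0.0.0" binds Gradio on all interfaces, the usual setting for containerized deployments. A sketch of the status-returning helper:

import logging
import webbrowser

logger = logging.getLogger(__name__)
OUTPUTS_DIR = "/data/outputs"  # stand-in for app.py's configured path

def browse_outputs() -> str:
    """Try to open the outputs directory locally; report the outcome as text."""
    try:
        # Only meaningful when the server and the browser share a machine.
        webbrowser.open(f"file://{OUTPUTS_DIR}")
        return "Opened outputs directory."
    except Exception as e:
        logger.error("Error opening file browser: %s", e)
        return "Error opening file browser."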
 