Ultronprime committed on
Commit
d4cee85
·
verified ·
1 Parent(s): 1b5d6e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -71
app.py CHANGED
@@ -8,14 +8,9 @@ from dataclasses import dataclass
8
  from datetime import datetime
9
  from pathlib import Path
10
  import gc
11
-
12
- import torch
13
- from torch.cuda.amp import autocast
14
- from transformers import AutoModel, AutoTokenizer
15
- from sentence_transformers import SentenceTransformer
16
- from charset_normalizer import from_bytes
17
- import numpy as np
18
- import requests
19
 
20
  # Custom Exception Class
21
  class GPUQuotaExceededError(Exception):
@@ -26,19 +21,22 @@ EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
26
  CHUNK_SIZE = 500
27
  BATCH_SIZE = 32
28
  CACHE_DIR = os.getenv("CACHE_DIR", "/tmp/cache")
29
- PERSISTENT_PATH = os.getenv("PERSISTENT_PATH", "/data")
30
 
31
- # Create directories
32
- os.makedirs(CACHE_DIR, exist_ok=True)
33
  os.makedirs(PERSISTENT_PATH, exist_ok=True)
 
 
 
 
34
 
35
  # Logging Setup
36
- LOG_DIR = os.getenv("LOG_DIR", "/data/logs")
37
  os.makedirs(LOG_DIR, exist_ok=True)
38
- LOG_FILE = Path(LOG_DIR) / "app.log"
39
 
40
  logging.basicConfig(
41
- filename=str(LOG_FILE),
42
  level=logging.INFO,
43
  format="%(asctime)s - %(levelname)s - %(message)s",
44
  )
@@ -137,10 +135,12 @@ def process_files(files):
137
  embeddings = handle_gpu_operation(lambda: get_model().encode(batch))
138
  all_embeddings.extend(embeddings)
139
 
140
- # Save results
141
- np.save(f"{PERSISTENT_PATH}/embeddings.npy", np.array(all_embeddings))
 
142
 
143
- with open(f"{PERSISTENT_PATH}/chunks.txt", "w", encoding="utf-8") as f:
 
144
  for chunk in all_chunks:
145
  f.write(chunk + "\n===CHUNK_SEPARATOR===\n")
146
 
@@ -162,16 +162,16 @@ def semantic_search(query, top_k=5):
162
  return "Model initialization failed. Please try again."
163
 
164
  try:
165
- # Load saved embeddings
166
- stored_embeddings = np.load(f"{PERSISTENT_PATH}/embeddings.npy")
167
 
168
- # Load stored chunks
169
- with open(f"{PERSISTENT_PATH}/chunks.txt", "r", encoding="utf-8") as f:
170
  chunks = f.read().split("\n===CHUNK_SEPARATOR===\n")
171
  chunks = [c for c in chunks if c.strip()] # Remove empty chunks
172
 
173
  # Get query embedding
174
- query_embedding = handle_gpu_operation(lambda: get_model().encode([query]))[0] # Use get_model() to get the model
175
 
176
  # Calculate similarities
177
  similarities = np.dot(stored_embeddings, query_embedding) / (
@@ -201,40 +201,33 @@ def search_and_format(query, num_results):
201
  return "Please enter a search query"
202
  return semantic_search(query, top_k=num_results)
203
 
204
- def download_results(text):
205
- if not text:
206
- return None
207
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
208
- filename = f"search_results_{timestamp}.txt"
209
- with open(filename, "w", encoding="utf-8") as f:
210
- f.write(text)
211
- return filename
212
-
213
- @spaces.GPU
214
- def safe_generate_embedding(text):
215
- global model
216
- if model is None: # Check if model is initialized
217
- initialize_model() # Initialize only if needed and within GPU context
218
-
219
  try:
220
- embedding = handle_gpu_operation(
221
- lambda: get_model().encode([text])[0].tolist() # Use get_model() to get the model
222
- )
223
- return embedding, "", False
224
- except GPUQuotaExceededError as e:
225
- error_msg = str(e)
226
- logger.error(error_msg)
227
- return "", error_msg, True
228
  except Exception as e:
229
- error_msg = f"Error generating embedding: {str(e)}"
230
- logger.error(error_msg)
231
- return "", error_msg, True
232
 
233
- def download_embeddings():
234
- embeddings_path = f"{PERSISTENT_PATH}/embeddings.npy"
235
- if not os.path.exists(embeddings_path):
236
- return None
237
- return embeddings_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
  def create_gradio_interface():
240
  with gr.Blocks() as demo:
@@ -270,7 +263,6 @@ def create_gradio_interface():
270
  lines=10,
271
  show_copy_button=True
272
  )
273
- download_button = gr.Button("⬇️ Download Results")
274
 
275
  search_button.click(
276
  fn=search_and_format,
@@ -278,27 +270,26 @@ def create_gradio_interface():
278
  outputs=results_output
279
  )
280
 
281
- download_button.click(
282
- fn=download_results,
283
- inputs=[results_output],
284
- outputs=[gr.File(label="Download Search Results")]
 
285
  )
286
 
287
- with gr.Tab("Inspect Embeddings"):
288
- embed_input = gr.Textbox(label="Enter Text for Embedding")
289
- embed_button = gr.Button("Generate Embedding")
290
- embed_output = gr.Textbox(label="Embedding Vector", lines=5)
291
-
292
- embed_button.click(
293
- safe_generate_embedding,
294
- inputs=[embed_input],
295
- outputs=[embed_output, error_box, error_box]
296
  )
297
 
298
- download_embeddings_button = gr.Button("⬇️ Download Embeddings")
299
- download_embeddings_button.click(
300
- fn=download_embeddings,
301
- outputs=[gr.File(label="Download Embeddings")]
 
302
  )
303
 
304
  process_button.click(
 
8
  from datetime import datetime
9
  from pathlib import Path
10
  import gc
11
+ import zipfile
12
+ import shutil
13
+ import tempfile
 
 
 
 
 
14
 
15
  # Custom Exception Class
16
  class GPUQuotaExceededError(Exception):
 
21
  CHUNK_SIZE = 500
22
  BATCH_SIZE = 32
23
  CACHE_DIR = os.getenv("CACHE_DIR", "/tmp/cache")
24
+ PERSISTENT_PATH = os.getenv("PERSISTENT_PATH", "/workspace")
25
 
26
+ # Directories setup
 
27
  os.makedirs(PERSISTENT_PATH, exist_ok=True)
28
+ TEMP_DIR = os.path.join(PERSISTENT_PATH, "temp")
29
+ os.makedirs(TEMP_DIR, exist_ok=True)
30
+ OUTPUTS_DIR = os.path.join(PERSISTENT_PATH, "outputs")
31
+ os.makedirs(OUTPUTS_DIR, exist_ok=True)
32
 
33
  # Logging Setup
34
+ LOG_DIR = os.getenv("LOG_DIR", os.path.join(PERSISTENT_PATH, "logs"))
35
  os.makedirs(LOG_DIR, exist_ok=True)
36
+ LOG_FILE = os.path.join(LOG_DIR, "app.log")
37
 
38
  logging.basicConfig(
39
+ filename=LOG_FILE,
40
  level=logging.INFO,
41
  format="%(asctime)s - %(levelname)s - %(message)s",
42
  )
 
135
  embeddings = handle_gpu_operation(lambda: get_model().encode(batch))
136
  all_embeddings.extend(embeddings)
137
 
138
+ # Save results to OUTPUTS_DIR
139
+ embeddings_path = os.path.join(OUTPUTS_DIR, "embeddings.npy")
140
+ np.save(embeddings_path, np.array(all_embeddings))
141
 
142
+ chunks_path = os.path.join(OUTPUTS_DIR, "chunks.txt")
143
+ with open(chunks_path, "w", encoding="utf-8") as f:
144
  for chunk in all_chunks:
145
  f.write(chunk + "\n===CHUNK_SEPARATOR===\n")
146
 
 
162
  return "Model initialization failed. Please try again."
163
 
164
  try:
165
+ # Load saved embeddings from OUTPUTS_DIR
166
+ stored_embeddings = np.load(os.path.join(OUTPUTS_DIR, "embeddings.npy"))
167
 
168
+ # Load stored chunks from OUTPUTS_DIR
169
+ with open(os.path.join(OUTPUTS_DIR, "chunks.txt"), "r", encoding="utf-8") as f:
170
  chunks = f.read().split("\n===CHUNK_SEPARATOR===\n")
171
  chunks = [c for c in chunks if c.strip()] # Remove empty chunks
172
 
173
  # Get query embedding
174
+ query_embedding = handle_gpu_operation(lambda: get_model().encode([query]))[0]
175
 
176
  # Calculate similarities
177
  similarities = np.dot(stored_embeddings, query_embedding) / (
 
201
  return "Please enter a search query"
202
  return semantic_search(query, top_k=num_results)
203
 
204
+ def browse_outputs():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  try:
206
+ os.startfile(OUTPUTS_DIR) # For Windows, on Linux use subprocess.run(['xdg-open', OUTPUTS_DIR])
 
 
 
 
 
 
 
207
  except Exception as e:
208
+ logger.error(f"Error opening file browser: {str(e)}")
209
+ return "Error opening file browser"
 
210
 
211
+ def download_results_from_disk():
212
+ try:
213
+ output_files = [
214
+ os.path.join(OUTPUTS_DIR, "embeddings.npy"),
215
+ os.path.join(OUTPUTS_DIR, "chunks.txt")
216
+ ]
217
+
218
+ # Create a temporary zip file
219
+ temp_dir = tempfile.gettempdir()
220
+ zip_path = os.path.join(temp_dir, "results.zip")
221
+
222
+ with zipfile.ZipFile(zip_path, 'w') as zipf:
223
+ for file in output_files:
224
+ if os.path.exists(file):
225
+ zipf.write(file, os.path.basename(file))
226
+
227
+ return zip_path
228
+ except Exception as e:
229
+ logger.error(f"Error creating download: {str(e)}")
230
+ return "Error creating download file"
231
 
232
  def create_gradio_interface():
233
  with gr.Blocks() as demo:
 
263
  lines=10,
264
  show_copy_button=True
265
  )
 
266
 
267
  search_button.click(
268
  fn=search_and_format,
 
270
  outputs=results_output
271
  )
272
 
273
+ # Download Results Button
274
+ download_results_button = gr.Button("⬇️ Download Search Results")
275
+ download_results_button.click(
276
+ fn=download_results_from_disk,
277
+ outputs=[gr.File(label="Download Results")]
278
  )
279
 
280
+ with gr.Tab("_FILES_"):
281
+ # Browse Outputs Button
282
+ browse_button = gr.Button("📁 Browse Outputs", variant="primary")
283
+ browse_button.click(
284
+ fn=browse_outputs,
285
+ outputs=None
 
 
 
286
  )
287
 
288
+ # Download All Results Button
289
+ download_all_button = gr.Button("⬇️ Download All Results", variant="primary")
290
+ download_all_button.click(
291
+ fn=download_results_from_disk,
292
+ outputs=[gr.File(label="Download All Results")]
293
  )
294
 
295
  process_button.click(