Ultronprime commited on
Commit
9c74ac0
·
verified ·
1 Parent(s): 9c81028

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -43
app.py CHANGED
@@ -2,23 +2,23 @@ import os
2
  import gradio as gr
3
  import logging
4
  import numpy as np
5
- from transformers import AutoModel, AutoTokenizer
6
  from sentence_transformers import SentenceTransformer
7
  import torch
8
  from torch.cuda.amp import autocast
9
  from spaces import GPU
 
10
 
11
- # Constants
12
  EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
13
  CACHE_DIR = os.getenv("CACHE_DIR", "/tmp/cache")
14
  PERSISTENT_PATH = os.getenv("PERSISTENT_PATH", "/tmp/data")
15
- HF_TOKEN = "YOUR_HF_TOKEN" # Replace with your Hugging Face token
16
 
17
- # Create directories
18
  os.makedirs(CACHE_DIR, exist_ok=True)
19
  os.makedirs(PERSISTENT_PATH, exist_ok=True)
20
 
21
- # Logging Setup
22
  LOG_DIR = os.getenv("LOG_DIR", "/data/logs")
23
  os.makedirs(LOG_DIR, exist_ok=True)
24
  LOG_FILE = LOG_DIR + "/app.log"
@@ -47,35 +47,39 @@ def initialize_model():
47
  @GPU()
48
  def generate_embedding(text, focus):
49
  global model
50
- if model is None:
51
- initialize_model()
52
 
53
  try:
54
  with autocast("cuda"):
55
- embedding = model.encode([text])[0].tolist()
56
- return embedding, ""
 
 
57
  except Exception as e:
58
  error_msg = f"Error generating embedding: {str(e)}"
59
  logger.error(error_msg)
60
  return "", error_msg
61
 
62
  @GPU()
63
- def save_embedding(embedding, name):
64
  try:
65
- np.save(f"{PERSISTENT_PATH}/{name}.npy", np.array(embedding))
66
- return f"Embedding saved as {name}.npy"
 
 
67
  except Exception as e:
68
  error_msg = f"Error saving embedding: {str(e)}"
69
  logger.error(error_msg)
70
  return error_msg
71
 
72
  @GPU()
73
- def convert_to_json(embedding, name):
74
  try:
75
- import json
76
- with open(f"{PERSISTENT_PATH}/{name}.json", "w") as f:
77
- json.dump(embedding, f)
78
- return f"Embedding saved as {name}.json"
79
  except Exception as e:
80
  error_msg = f"Error converting to JSON: {str(e)}"
81
  logger.error(error_msg)
@@ -84,23 +88,37 @@ def convert_to_json(embedding, name):
84
  @GPU()
85
  def process_files(files, focus):
86
  global model
87
- if model is None:
88
- initialize_model()
89
 
90
  try:
91
  all_embeddings = []
 
92
  for file in files:
93
- with open(file.name, 'r') as f:
94
- text = f.read()
95
- with autocast("cuda"):
96
- embedding = model.encode([text])[0].tolist()
97
- all_embeddings.append(embedding)
98
- return all_embeddings, ""
 
 
 
 
 
 
 
 
 
 
 
 
99
  except Exception as e:
100
- error_msg = f"Error processing files: {str(e)}"
101
  logger.error(error_msg)
102
  return "", error_msg
103
 
 
104
  def create_gradio_interface():
105
  with gr.Blocks() as demo:
106
  gr.Markdown("## Text Embedding Generator")
@@ -113,41 +131,45 @@ def create_gradio_interface():
113
  file_input = gr.File(label="Upload Files", file_count="multiple")
114
 
115
  generate_button = gr.Button("Generate Embedding")
116
- embedding_output = gr.Textbox(label="Embedding Vector", lines=5)
117
- error_box = gr.Textbox(label="Status/Error Messages")
118
-
119
- save_name_input = gr.Textbox(label="Save Embedding As")
120
- save_button = gr.Button("Save Embedding")
121
- save_status = gr.Textbox(label="Save Status")
122
-
123
- convert_button = gr.Button("Convert to JSON")
124
- convert_status = gr.Textbox(label="Convert Status")
125
- download_button = gr.Button("Download JSON")
126
- download_output = gr.File(label="Download JSON")
 
 
127
 
128
  process_button = gr.Button("Process Files")
129
- process_output = gr.Textbox(label="Processed Files", lines=5)
 
 
130
 
131
  generate_button.click(
132
  generate_embedding,
133
  inputs=[text_input, focus_input],
134
- outputs=[embedding_output, error_box]
135
  )
136
 
137
  save_button.click(
138
  save_embedding,
139
- inputs=[embedding_output, save_name_input],
140
  outputs=[save_status]
141
  )
142
 
143
  convert_button.click(
144
  convert_to_json,
145
- inputs=[embedding_output, save_name_input],
146
  outputs=[convert_status]
147
  )
148
 
149
  download_button.click(
150
- lambda name: f"{PERSISTENT_PATH}/{name}.json",
151
  inputs=[save_name_input],
152
  outputs=[download_output]
153
  )
@@ -155,7 +177,7 @@ def create_gradio_interface():
155
  process_button.click(
156
  process_files,
157
  inputs=[file_input, focus_input],
158
- outputs=[process_output, error_box]
159
  )
160
 
161
  return demo
 
2
  import gradio as gr
3
  import logging
4
  import numpy as np
 
5
  from sentence_transformers import SentenceTransformer
6
  import torch
7
  from torch.cuda.amp import autocast
8
  from spaces import GPU
9
+ import json # Import json for direct JSON output in UI
10
 
11
+ # Constants (Keep your HF token secure - use environment variables if possible for real deployments)
12
  EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
13
  CACHE_DIR = os.getenv("CACHE_DIR", "/tmp/cache")
14
  PERSISTENT_PATH = os.getenv("PERSISTENT_PATH", "/tmp/data")
15
+ HF_TOKEN = "YOUR_HF_TOKEN" # REMEMBER TO REPLACE THIS - BEST TO USE ENVIRONMENT VARIABLE
16
 
17
+ # Create directories (still useful to try, even if /tmp/ is ephemeral)
18
  os.makedirs(CACHE_DIR, exist_ok=True)
19
  os.makedirs(PERSISTENT_PATH, exist_ok=True)
20
 
21
+ # Logging Setup (keep logging - it's helpful for debugging)
22
  LOG_DIR = os.getenv("LOG_DIR", "/data/logs")
23
  os.makedirs(LOG_DIR, exist_ok=True)
24
  LOG_FILE = LOG_DIR + "/app.log"
 
47
  @GPU()
48
  def generate_embedding(text, focus):
49
  global model
50
+ if model is None:
51
+ initialize_model()
52
 
53
  try:
54
  with autocast("cuda"):
55
+ embedding_vector = model.encode([text])[0].tolist() # Get embedding as list
56
+ # Convert embedding to JSON string for direct display in UI
57
+ embedding_json_str = json.dumps(embedding_vector)
58
+ return embedding_json_str, "" # Return JSON string to UI
59
  except Exception as e:
60
  error_msg = f"Error generating embedding: {str(e)}"
61
  logger.error(error_msg)
62
  return "", error_msg
63
 
64
  @GPU()
65
+ def save_embedding(embedding_json, name): # Expect JSON string as input from UI
66
  try:
67
+ embedding = json.loads(embedding_json) # Parse JSON string back to list
68
+ filepath = f"{PERSISTENT_PATH}/{name}.npy" # Construct full filepath
69
+ np.save(filepath, np.array(embedding))
70
+ return f"Embedding saved to: {filepath}" # Return filepath in status
71
  except Exception as e:
72
  error_msg = f"Error saving embedding: {str(e)}"
73
  logger.error(error_msg)
74
  return error_msg
75
 
76
  @GPU()
77
+ def convert_to_json(embedding_json, name): # Expect JSON string as input
78
  try:
79
+ filepath = f"{PERSISTENT_PATH}/{name}.json" # Construct full filepath
80
+ with open(filepath, "w") as f:
81
+ f.write(embedding_json) # Directly write the JSON string
82
+ return f"Embedding saved as JSON to: {filepath}" # Return filepath in status
83
  except Exception as e:
84
  error_msg = f"Error converting to JSON: {str(e)}"
85
  logger.error(error_msg)
 
88
  @GPU()
89
  def process_files(files, focus):
90
  global model
91
+ if model is None:
92
+ initialize_model()
93
 
94
  try:
95
  all_embeddings = []
96
+ file_statuses = [] # To track status for each file
97
  for file in files:
98
+ try:
99
+ with open(file.name, 'r') as f:
100
+ text = f.read()
101
+ with autocast("cuda"):
102
+ embedding = model.encode([text])[0].tolist()
103
+ all_embeddings.append(embedding)
104
+ file_statuses.append(f"File '{file.name}' processed successfully.")
105
+ except Exception as file_e:
106
+ error_msg = f"Error processing file '{file.name}': {str(file_e)}"
107
+ logger.error(error_msg)
108
+ file_statuses.append(error_msg)
109
+
110
+ # Prepare status message for all files
111
+ status_message = "\n".join(file_statuses)
112
+ # Convert embeddings to JSON string for UI display (for demonstration - might be too long for large files)
113
+ all_embeddings_json = json.dumps(all_embeddings)
114
+
115
+ return all_embeddings_json, status_message # Return JSON string and status message
116
  except Exception as e:
117
+ error_msg = f"Error in process_files function: {str(e)}"
118
  logger.error(error_msg)
119
  return "", error_msg
120
 
121
+
122
  def create_gradio_interface():
123
  with gr.Blocks() as demo:
124
  gr.Markdown("## Text Embedding Generator")
 
131
  file_input = gr.File(label="Upload Files", file_count="multiple")
132
 
133
  generate_button = gr.Button("Generate Embedding")
134
+ embedding_output = gr.Textbox(label="Embedding Vector (JSON)", lines=5) # Label changed to JSON
135
+ status_box = gr.Textbox(label="Status/Messages") # Renamed error_box to status_box
136
+
137
+ with gr.Accordion("Save and Download Options", open=False): # Accordion for save/download options
138
+ save_name_input = gr.Textbox(label="Save Embedding As (Name without extension)")
139
+ with gr.Row():
140
+ save_button = gr.Button("Save as .npy")
141
+ convert_button = gr.Button("Save as .json")
142
+ with gr.Row():
143
+ save_status = gr.Textbox(label="Save Status")
144
+ convert_status = gr.Textbox(label="Convert Status")
145
+ download_button = gr.Button("Download JSON")
146
+ download_output = gr.File(label="Download JSON File")
147
 
148
  process_button = gr.Button("Process Files")
149
+ process_output = gr.Textbox(label="Processed Files (Embeddings JSON - limited display)", lines=3) # Limited lines for process output
150
+ process_status = gr.Textbox(label="File Processing Status") # Status for file processing
151
+
152
 
153
  generate_button.click(
154
  generate_embedding,
155
  inputs=[text_input, focus_input],
156
+ outputs=[embedding_output, status_box] # Renamed error_box to status_box
157
  )
158
 
159
  save_button.click(
160
  save_embedding,
161
+ inputs=[embedding_output, save_name_input], # Input is now embedding_output (JSON string)
162
  outputs=[save_status]
163
  )
164
 
165
  convert_button.click(
166
  convert_to_json,
167
+ inputs=[embedding_output, save_name_input], # Input is embedding_output (JSON string)
168
  outputs=[convert_status]
169
  )
170
 
171
  download_button.click(
172
+ lambda name: f"{PERSISTENT_PATH}/{name}.json" if name else None, # Handle empty name
173
  inputs=[save_name_input],
174
  outputs=[download_output]
175
  )
 
177
  process_button.click(
178
  process_files,
179
  inputs=[file_input, focus_input],
180
+ outputs=[process_output, process_status] # outputs for process_files
181
  )
182
 
183
  return demo