Ultronprime committed on
Commit
4be0978
·
verified ·
1 Parent(s): c211fcb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -11
app.py CHANGED
@@ -4,7 +4,7 @@ import logging
4
  import numpy as np
5
  from sentence_transformers import SentenceTransformer
6
  import torch
7
- from torch.amp import autocast # Corrected import for autocast
8
  from spaces import GPU
9
  import json # Import json for direct JSON output in UI
10
 
@@ -32,26 +32,34 @@ logger = logging.getLogger(__name__)
32
 
33
  # Model initialization
34
  model = None
 
35
 
36
  def initialize_model():
37
- global model
38
  try:
39
  if model is None:
40
  model = SentenceTransformer(EMBEDDING_MODEL_NAME, cache_folder=CACHE_DIR, use_auth_token=HF_TOKEN)
41
  logger.info(f"Initialized model: {EMBEDDING_MODEL_NAME}")
42
- return True
 
 
43
  except Exception as e:
44
- logger.error(f"Model initialization failed: {str(e)}")
45
- return False
 
 
 
46
 
47
  @GPU()
48
  def generate_embedding(text, focus):
49
- global model
50
  if model is None:
51
- initialize_model()
 
 
52
 
53
  try:
54
- with torch.amp.autocast('cuda'): # Corrected autocast usage
55
  embedding_vector = model.encode([text])[0].tolist() # Get embedding as list
56
  # Convert embedding to JSON string for direct display in UI
57
  embedding_json_str = json.dumps(embedding_vector)
@@ -87,9 +95,11 @@ def convert_to_json(embedding_json, name): # Expect JSON string as input
87
 
88
  @GPU()
89
  def process_files(files, focus):
90
- global model
91
  if model is None:
92
- initialize_model()
 
 
93
 
94
  try:
95
  all_embeddings = []
@@ -98,7 +108,7 @@ def process_files(files, focus):
98
  try:
99
  with open(file.name, 'r') as f:
100
  text = f.read()
101
- with torch.amp.autocast('cuda'): # Corrected autocast usage
102
  embedding = model.encode([text])[0].tolist()
103
  all_embeddings.append(embedding)
104
  file_statuses.append(f"File '{file.name}' processed successfully.")
@@ -123,6 +133,8 @@ def create_gradio_interface():
123
  with gr.Blocks() as demo:
124
  gr.Markdown("## Text Embedding Generator")
125
 
 
 
126
  with gr.Row():
127
  text_input = gr.Textbox(label="Enter Text")
128
  focus_input = gr.Textbox(label="Main Focus of Embedding (e.g., company structure, staff positions, etc.)")
@@ -149,6 +161,11 @@ def create_gradio_interface():
149
  process_output = gr.Textbox(label="Processed Files (Embeddings JSON - limited display)", lines=3) # Limited lines for process output
150
  process_status = gr.Textbox(label="File Processing Status") # Status for file processing
151
 
 
 
 
 
 
152
 
153
  generate_button.click(
154
  generate_embedding,
@@ -183,5 +200,10 @@ def create_gradio_interface():
183
  return demo
184
 
185
  if __name__ == "__main__":
 
 
 
 
 
186
  demo = create_gradio_interface()
187
  demo.launch(server_name="0.0.0.0")
 
4
  import numpy as np
5
  from sentence_transformers import SentenceTransformer
6
  import torch
7
+ from torch.amp import autocast
8
  from spaces import GPU
9
  import json # Import json for direct JSON output in UI
10
 
 
32
 
33
# Model initialization
model = None
model_initialization_error = ""  # Module-level slot so the UI can surface the last init error


def initialize_model():
    """Lazily create the global SentenceTransformer model.

    Safe to call repeatedly: if the model already exists it is reused.

    Returns:
        tuple[bool, str]: ``(success, error_message)``. On success the
        error message is the empty string; on failure it describes the
        problem and is also stored in ``model_initialization_error`` so
        the Gradio UI can display it later.
    """
    global model, model_initialization_error
    try:
        if model is None:
            # NOTE(review): `use_auth_token` is the legacy parameter name in
            # sentence-transformers; newer releases prefer `token` — confirm
            # against the pinned library version.
            model = SentenceTransformer(
                EMBEDDING_MODEL_NAME, cache_folder=CACHE_DIR, use_auth_token=HF_TOKEN
            )
            # Lazy %-style args: the message is only built if INFO is enabled.
            logger.info("Initialized model: %s", EMBEDDING_MODEL_NAME)
        # Single success exit for both the fresh-init and already-initialized
        # paths; always clear any stale error (the original only cleared it on
        # fresh initialization, leaving a stale message after a later success).
        model_initialization_error = ""
        return True, ""
    except Exception as e:
        error_msg = f"Model initialization failed: {str(e)}"
        logger.error(error_msg)
        model_initialization_error = error_msg  # Store for display in the UI
        return False, error_msg
+
52
 
53
  @GPU()
54
  def generate_embedding(text, focus):
55
+ global model, model_initialization_error
56
  if model is None:
57
+ success, error_message = initialize_model() # Call initialize_model and get status
58
+ if not success:
59
+ return "", error_message # Return initialization error to UI
60
 
61
  try:
62
+ with torch.amp.autocast('cuda'):
63
  embedding_vector = model.encode([text])[0].tolist() # Get embedding as list
64
  # Convert embedding to JSON string for direct display in UI
65
  embedding_json_str = json.dumps(embedding_vector)
 
95
 
96
  @GPU()
97
  def process_files(files, focus):
98
+ global model, model_initialization_error
99
  if model is None:
100
+ success, error_message = initialize_model() # Call initialize_model and get status
101
+ if not success:
102
+ return "", error_message # Return initialization error to UI
103
 
104
  try:
105
  all_embeddings = []
 
108
  try:
109
  with open(file.name, 'r') as f:
110
  text = f.read()
111
+ with torch.amp.autocast('cuda'):
112
  embedding = model.encode([text])[0].tolist()
113
  all_embeddings.append(embedding)
114
  file_statuses.append(f"File '{file.name}' processed successfully.")
 
133
  with gr.Blocks() as demo:
134
  gr.Markdown("## Text Embedding Generator")
135
 
136
+ initialization_status_box = gr.Textbox(label="Initialization Status", value=model_initialization_error, visible=False) # Hidden box to hold init error
137
+
138
  with gr.Row():
139
  text_input = gr.Textbox(label="Enter Text")
140
  focus_input = gr.Textbox(label="Main Focus of Embedding (e.g., company structure, staff positions, etc.)")
 
161
  process_output = gr.Textbox(label="Processed Files (Embeddings JSON - limited display)", lines=3) # Limited lines for process output
162
  process_status = gr.Textbox(label="File Processing Status") # Status for file processing
163
 
164
+ demo.load( # Call initialize_model on app load
165
+ lambda: ("", model_initialization_error), # Dummy output for other components, error for initialization_status_box
166
+ outputs=[status_box, initialization_status_box] # status_box for general messages, init status for hidden box
167
+ )
168
+
169
 
170
  generate_button.click(
171
  generate_embedding,
 
200
  return demo
201
 
202
if __name__ == "__main__":
    # Initialize the model eagerly so configuration problems surface at startup.
    startup_ok, startup_err = initialize_model()
    if not startup_ok:
        # Print to console for startup errors. The app still launches on
        # purpose: the stored initialization error can then be shown in the UI.
        print(f"App startup failed due to model initialization error:\n{startup_err}")

    demo = create_gradio_interface()
    demo.launch(server_name="0.0.0.0")