Hugging Face Spaces commit view — "Update app.py" (fixing a build error).
Changed file: app.py (diff shown below).
|
@@ -4,7 +4,7 @@ import logging
|
|
| 4 |
import numpy as np
|
| 5 |
from sentence_transformers import SentenceTransformer
|
| 6 |
import torch
|
| 7 |
-
from torch.amp import autocast
|
| 8 |
from spaces import GPU
|
| 9 |
import json # Import json for direct JSON output in UI
|
| 10 |
|
|
@@ -32,26 +32,34 @@ logger = logging.getLogger(__name__)
|
|
| 32 |
|
| 33 |
# Model initialization
|
| 34 |
model = None
|
|
|
|
| 35 |
|
| 36 |
def initialize_model():
|
| 37 |
-
global model
|
| 38 |
try:
|
| 39 |
if model is None:
|
| 40 |
model = SentenceTransformer(EMBEDDING_MODEL_NAME, cache_folder=CACHE_DIR, use_auth_token=HF_TOKEN)
|
| 41 |
logger.info(f"Initialized model: {EMBEDDING_MODEL_NAME}")
|
| 42 |
-
|
|
|
|
|
|
|
| 43 |
except Exception as e:
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
@GPU()
|
| 48 |
def generate_embedding(text, focus):
|
| 49 |
-
global model
|
| 50 |
if model is None:
|
| 51 |
-
initialize_model()
|
|
|
|
|
|
|
| 52 |
|
| 53 |
try:
|
| 54 |
-
with torch.amp.autocast('cuda'):
|
| 55 |
embedding_vector = model.encode([text])[0].tolist() # Get embedding as list
|
| 56 |
# Convert embedding to JSON string for direct display in UI
|
| 57 |
embedding_json_str = json.dumps(embedding_vector)
|
|
@@ -87,9 +95,11 @@ def convert_to_json(embedding_json, name): # Expect JSON string as input
|
|
| 87 |
|
| 88 |
@GPU()
|
| 89 |
def process_files(files, focus):
|
| 90 |
-
global model
|
| 91 |
if model is None:
|
| 92 |
-
initialize_model()
|
|
|
|
|
|
|
| 93 |
|
| 94 |
try:
|
| 95 |
all_embeddings = []
|
|
@@ -98,7 +108,7 @@ def process_files(files, focus):
|
|
| 98 |
try:
|
| 99 |
with open(file.name, 'r') as f:
|
| 100 |
text = f.read()
|
| 101 |
-
with torch.amp.autocast('cuda'):
|
| 102 |
embedding = model.encode([text])[0].tolist()
|
| 103 |
all_embeddings.append(embedding)
|
| 104 |
file_statuses.append(f"File '{file.name}' processed successfully.")
|
|
@@ -123,6 +133,8 @@ def create_gradio_interface():
|
|
| 123 |
with gr.Blocks() as demo:
|
| 124 |
gr.Markdown("## Text Embedding Generator")
|
| 125 |
|
|
|
|
|
|
|
| 126 |
with gr.Row():
|
| 127 |
text_input = gr.Textbox(label="Enter Text")
|
| 128 |
focus_input = gr.Textbox(label="Main Focus of Embedding (e.g., company structure, staff positions, etc.)")
|
|
@@ -149,6 +161,11 @@ def create_gradio_interface():
|
|
| 149 |
process_output = gr.Textbox(label="Processed Files (Embeddings JSON - limited display)", lines=3) # Limited lines for process output
|
| 150 |
process_status = gr.Textbox(label="File Processing Status") # Status for file processing
|
| 151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
generate_button.click(
|
| 154 |
generate_embedding,
|
|
@@ -183,5 +200,10 @@ def create_gradio_interface():
|
|
| 183 |
return demo
|
| 184 |
|
| 185 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
demo = create_gradio_interface()
|
| 187 |
demo.launch(server_name="0.0.0.0")
|
|
|
|
| 4 |
import numpy as np
|
| 5 |
from sentence_transformers import SentenceTransformer
|
| 6 |
import torch
|
| 7 |
+
from torch.amp import autocast
|
| 8 |
from spaces import GPU
|
| 9 |
import json # Import json for direct JSON output in UI
|
| 10 |
|
|
|
|
| 32 |
|
| 33 |
# Model initialization state — populated lazily by initialize_model().
model = None  # SentenceTransformer instance; None until the first successful init
model_initialization_error = ""  # last init failure message; "" means no error
|
| 36 |
|
| 37 |
def initialize_model():
    """Lazily create the global SentenceTransformer model.

    Returns:
        tuple[bool, str]: ``(success, error_message)``. On success the error
        message is ``""``; on failure it contains the exception text, which is
        also stored in the module-level ``model_initialization_error`` so the
        UI layer can display it.
    """
    global model, model_initialization_error
    try:
        if model is None:
            # NOTE(review): use_auth_token is deprecated in newer
            # huggingface_hub/sentence-transformers (renamed to `token`) —
            # confirm the pinned version before switching.
            model = SentenceTransformer(
                EMBEDDING_MODEL_NAME, cache_folder=CACHE_DIR, use_auth_token=HF_TOKEN
            )
            logger.info("Initialized model: %s", EMBEDDING_MODEL_NAME)
        # Clear any previous error on BOTH paths. The original only cleared it
        # inside the `if model is None` branch, so a stale error message could
        # survive after the model was already initialized successfully.
        model_initialization_error = ""
        return True, ""
    except Exception as e:
        error_msg = f"Model initialization failed: {str(e)}"
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception(error_msg)
        model_initialization_error = error_msg  # expose the failure to the UI
        return False, error_msg
|
| 51 |
+
|
| 52 |
|
| 53 |
@GPU()
|
| 54 |
def generate_embedding(text, focus):
|
| 55 |
+
global model, model_initialization_error
|
| 56 |
if model is None:
|
| 57 |
+
success, error_message = initialize_model() # Call initialize_model and get status
|
| 58 |
+
if not success:
|
| 59 |
+
return "", error_message # Return initialization error to UI
|
| 60 |
|
| 61 |
try:
|
| 62 |
+
with torch.amp.autocast('cuda'):
|
| 63 |
embedding_vector = model.encode([text])[0].tolist() # Get embedding as list
|
| 64 |
# Convert embedding to JSON string for direct display in UI
|
| 65 |
embedding_json_str = json.dumps(embedding_vector)
|
|
|
|
| 95 |
|
| 96 |
@GPU()
|
| 97 |
def process_files(files, focus):
|
| 98 |
+
global model, model_initialization_error
|
| 99 |
if model is None:
|
| 100 |
+
success, error_message = initialize_model() # Call initialize_model and get status
|
| 101 |
+
if not success:
|
| 102 |
+
return "", error_message # Return initialization error to UI
|
| 103 |
|
| 104 |
try:
|
| 105 |
all_embeddings = []
|
|
|
|
| 108 |
try:
|
| 109 |
with open(file.name, 'r') as f:
|
| 110 |
text = f.read()
|
| 111 |
+
with torch.amp.autocast('cuda'):
|
| 112 |
embedding = model.encode([text])[0].tolist()
|
| 113 |
all_embeddings.append(embedding)
|
| 114 |
file_statuses.append(f"File '{file.name}' processed successfully.")
|
|
|
|
| 133 |
with gr.Blocks() as demo:
|
| 134 |
gr.Markdown("## Text Embedding Generator")
|
| 135 |
|
| 136 |
+
initialization_status_box = gr.Textbox(label="Initialization Status", value=model_initialization_error, visible=False) # Hidden box to hold init error
|
| 137 |
+
|
| 138 |
with gr.Row():
|
| 139 |
text_input = gr.Textbox(label="Enter Text")
|
| 140 |
focus_input = gr.Textbox(label="Main Focus of Embedding (e.g., company structure, staff positions, etc.)")
|
|
|
|
| 161 |
process_output = gr.Textbox(label="Processed Files (Embeddings JSON - limited display)", lines=3) # Limited lines for process output
|
| 162 |
process_status = gr.Textbox(label="File Processing Status") # Status for file processing
|
| 163 |
|
| 164 |
+
demo.load( # Call initialize_model on app load
|
| 165 |
+
lambda: ("", model_initialization_error), # Dummy output for other components, error for initialization_status_box
|
| 166 |
+
outputs=[status_box, initialization_status_box] # status_box for general messages, init status for hidden box
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
|
| 170 |
generate_button.click(
|
| 171 |
generate_embedding,
|
|
|
|
| 200 |
return demo
|
| 201 |
|
| 202 |
if __name__ == "__main__":
    # Warm the model up at startup so configuration problems surface
    # immediately instead of on the first user request.
    startup_ok, startup_error = initialize_model()
    if not startup_ok:
        # Report to the console; the UI is still launched so the hidden
        # initialization-status box can surface the same error to users.
        print(f"App startup failed due to model initialization error:\n{startup_error}")

    demo = create_gradio_interface()
    demo.launch(server_name="0.0.0.0")
|