Spaces:
Build error
Build error
| import json | |
| import os | |
| import gradio as gr | |
| import logging | |
| import traceback | |
| import spaces | |
| from typing import Optional, List | |
| from dataclasses import dataclass | |
| from datetime import datetime | |
| from pathlib import Path | |
| import gc | |
| import torch | |
| from torch.amp import autocast | |
| from transformers import AutoModel, AutoTokenizer | |
| from sentence_transformers import SentenceTransformer | |
| import numpy as np | |
| import requests | |
| from charset_normalizer import from_bytes | |
| import zipfile | |
| import tempfile | |
| import shutil | |
# Custom exception type. Nothing in the visible code raises or catches it;
# presumably reserved for GPU-quota failures -- confirm against callers.
class GPUQuotaExceededError(Exception):
    """Application-specific exception that adds no behavior to ``Exception``."""
# --- Embedding settings -----------------------------------------------------
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
# CHUNK_SIZE and BATCH_SIZE are defined here but not referenced by the
# functions visible in this file.
CHUNK_SIZE = 500
BATCH_SIZE = 32


def _ensure_dir(path):
    """Create *path* (with parents) if it is missing and return it unchanged.

    The requested 0o777 mode is still subject to the process umask.
    """
    os.makedirs(path, exist_ok=True, mode=0o777)
    return path


# --- Storage layout -----------------------------------------------------------
# Everything lives under one root, overridable through the PERSISTENT_PATH
# environment variable (defaults to /data, the Spaces persistent storage mount).
PERSISTENT_PATH = _ensure_dir(os.getenv("PERSISTENT_PATH", "/data"))
TEMP_DIR = _ensure_dir(os.path.join(PERSISTENT_PATH, "temp"))
OUTPUTS_DIR = _ensure_dir(os.path.join(PERSISTENT_PATH, "outputs"))
NPY_CACHE = _ensure_dir(os.path.join(PERSISTENT_PATH, "npy_cache"))
# The log directory can be redirected independently through LOG_DIR.
LOG_DIR = _ensure_dir(os.getenv("LOG_DIR", os.path.join(PERSISTENT_PATH, "logs")))

# --- Hugging Face cache and credentials ---------------------------------------
# NOTE(review): HF_HOME is assigned after the Hugging Face libraries were
# imported at the top of the file -- confirm it is still honored at this point.
os.environ["HF_HOME"] = os.path.join(PERSISTENT_PATH, ".huggingface")
_ensure_dir(os.environ["HF_HOME"])
# Token is read from the environment only; it is None when HF_TOKEN is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# --- Logging --------------------------------------------------------------------
# All records at INFO and above go to <LOG_DIR>/app.log.
logging.basicConfig(
    filename=os.path.join(LOG_DIR, "app.log"),
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# --- Lazily initialized model state -------------------------------------------
# `model` holds the loaded SentenceTransformer once initialize_model() succeeds;
# `model_initialization_error` keeps the last initialization error text ("" if none).
model = None
model_initialization_error = ""
def initialize_model():
    """
    Load the sentence-transformer model once and keep it in the module-level ``model``.

    The model files are cached under ``<PERSISTENT_PATH>/models``. On failure the
    error text (including the traceback) is logged and also stored in the
    module-level ``model_initialization_error``.

    Returns:
        tuple[bool, str]: ``(True, "")`` when the model is loaded (or was already
        loaded), otherwise ``(False, error_message)``.
    """
    global model, model_initialization_error
    try:
        # Guard clause: nothing to do when a model is already in place.
        if model is not None:
            return True, ""

        cache_dir = os.path.join(PERSISTENT_PATH, "models")
        os.makedirs(cache_dir, exist_ok=True, mode=0o777)
        # NOTE(review): `use_auth_token` is the keyword used by the original code;
        # check it against the installed sentence-transformers version.
        model = SentenceTransformer(EMBEDDING_MODEL_NAME, cache_folder=cache_dir, use_auth_token=HF_TOKEN)
        logger.info(f"Initialized model: {EMBEDDING_MODEL_NAME}")
        model_initialization_error = ""
        return True, ""
    except requests.exceptions.RequestException as exc:
        # Network problems get their own message so they are easy to spot.
        failure = f"Connection error during model download: {str(exc)}\n{traceback.format_exc()}"
    except Exception as exc:
        failure = f"Model initialization failed: {str(exc)}\n{traceback.format_exc()}"

    # Shared failure tail for both handlers: log, remember, report.
    logger.error(failure)
    model_initialization_error = failure
    return False, failure
def generate_embedding(text, focus):
    """
    Embed a single piece of text and return the vector as a JSON string.

    Args:
        text: The text to embed. Must be a non-empty, non-whitespace string.
        focus: Accepted because the UI passes it; not used by this function.

    Returns:
        tuple[str, str]: ``(embedding_json, message)``. On success the first item
        is a JSON array of floats and the message is ``""``; on any failure the
        first item is ``""`` and the message describes the problem.
    """
    # Validate first so an empty submission never triggers a model load and
    # never produces a meaningless embedding of the empty string.
    if not isinstance(text, str) or not text.strip():
        return "", "No text provided: enter some text to embed."

    if model is None:
        success, error_message = initialize_model()
        if not success:
            return "", error_message  # Surface the initialization error in the UI

    try:
        with torch.amp.autocast('cuda'):
            embedding_vector = model.encode([text])[0].tolist()
        # The UI textbox displays the vector as a JSON string.
        return json.dumps(embedding_vector), ""
    except Exception as e:
        error_msg = f"Error generating embedding: {str(e)}"
        logger.error(error_msg)
        return "", error_msg
def save_embedding(embedding_json, name):
    """
    Save an embedding to ``<PERSISTENT_PATH>/<name>.npy``.

    Args:
        embedding_json: The embedding as a JSON string (as shown in the UI).
        name: Target file name without extension. It comes straight from the UI,
            so it must be a plain name: empty names and names containing path
            separators are rejected to keep the file inside ``PERSISTENT_PATH``.

    Returns:
        str: A status message; either the path that was written or an error
        description prefixed with ``"Error saving embedding: "``.
    """
    try:
        # Untrusted input: never let the name point outside the storage root.
        if not isinstance(name, str) or not name.strip():
            raise ValueError("a file name is required")
        if "/" in name or "\\" in name:
            raise ValueError("file name must not contain path separators")

        embedding = json.loads(embedding_json)  # Parse the JSON string back to a list
        filepath = os.path.join(PERSISTENT_PATH, f"{name}.npy")
        np.save(filepath, np.array(embedding))
        return f"Embedding saved to: {filepath}"
    except Exception as e:
        error_msg = f"Error saving embedding: {str(e)}"
        logger.error(error_msg)
        return error_msg
def convert_to_json(embedding_json, name):
    """
    Write an embedding JSON string to ``<PERSISTENT_PATH>/<name>.json``.

    Args:
        embedding_json: The embedding as a JSON string (as shown in the UI). It is
            checked to be valid JSON and then written unchanged.
        name: Target file name without extension. It comes straight from the UI,
            so it must be a plain name: empty names and names containing path
            separators are rejected to keep the file inside ``PERSISTENT_PATH``.

    Returns:
        str: A status message; either the path that was written or an error
        description prefixed with ``"Error converting to JSON: "``.
    """
    try:
        # Untrusted input: never let the name point outside the storage root.
        if not isinstance(name, str) or not name.strip():
            raise ValueError("a file name is required")
        if "/" in name or "\\" in name:
            raise ValueError("file name must not contain path separators")

        # Refuse to create a ".json" file whose content is not JSON (e.g. when
        # no embedding has been generated yet and the textbox is empty).
        json.loads(embedding_json)

        filepath = os.path.join(PERSISTENT_PATH, f"{name}.json")
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(embedding_json)  # Write the validated JSON string as-is
        return f"Embedding saved as JSON to: {filepath}"
    except Exception as e:
        error_msg = f"Error converting to JSON: {str(e)}"
        logger.error(error_msg)
        return error_msg
def _decode_file_bytes(raw):
    """Turn raw file bytes into text: UTF-8 first, then charset detection, then lossy UTF-8."""
    try:
        return raw.decode("utf-8-sig")  # also strips a leading BOM
    except UnicodeDecodeError:
        pass
    try:
        from charset_normalizer import from_bytes as detect_charset
    except ImportError:
        detect_charset = None
    if detect_charset is not None:
        best_match = detect_charset(raw).best()
        if best_match is not None:
            return str(best_match)
    # Last resort: keep going with replacement characters rather than failing.
    return raw.decode("utf-8", errors="replace")


def process_files(files, focus):
    """
    Embed each uploaded file and return all embeddings as one JSON string.

    Args:
        files: Uploaded files from the UI. Each entry may be a path string or an
            object exposing the path as ``.name``; ``None`` or an empty list is
            reported as "no files" instead of raising.
        focus: Accepted because the UI passes it; not used by this function.

    Returns:
        tuple[str, str]: ``(embeddings_json, status_message)``. The JSON holds one
        embedding per successfully processed file (failed files are skipped, so
        positions may not line up with ``files``); the status message has one
        line per file. On a global failure the first item is ``""``.
    """
    if not files:
        return "", "No files provided: upload at least one file to process."

    if model is None:
        success, error_message = initialize_model()
        if not success:
            return "", error_message  # Surface the initialization error in the UI

    try:
        all_embeddings = []
        file_statuses = []  # One status line per file
        for file in files:
            # Resolve the path up front so the error handler below can always
            # report it, whatever kind of object the UI handed over.
            path = file if isinstance(file, str) else getattr(file, "name", str(file))
            try:
                with open(path, 'rb') as f:
                    # The encoder needs text, not raw bytes.
                    text = _decode_file_bytes(f.read())
                with torch.amp.autocast('cuda'):
                    embedding = model.encode([text])[0].tolist()
                all_embeddings.append(embedding)
                file_statuses.append(f"File '{path}' processed successfully.")
            except Exception as file_e:
                # A single bad file must not abort the whole batch.
                error_msg = f"Error processing file '{path}': {str(file_e)}"
                logger.error(error_msg)
                file_statuses.append(error_msg)

        status_message = "\n".join(file_statuses)
        # JSON string for UI display (can get long for many files).
        all_embeddings_json = json.dumps(all_embeddings)
        return all_embeddings_json, status_message
    except Exception as e:
        error_msg = f"Error in process_files function: {str(e)}"
        logger.error(error_msg)
        return "", error_msg
| # Build the Gradio Blocks UI, connect each button to its handler function, and return the Blocks object. | |
| # NOTE(review): indentation was lost in this copy of the file, so the exact nesting of the Row/Accordion | |
| # containers below cannot be read off from here -- confirm the layout against the running app. | |
| def create_gradio_interface(): | |
| with gr.Blocks() as demo: | |
| gr.Markdown("## Text Embedding Generator") | |
| # The initial value is whatever model_initialization_error holds at the moment the interface is built. | |
| initialization_status_box = gr.Textbox(label="Initialization Status", value=model_initialization_error, visible=False) # Hidden box to hold init error | |
| with gr.Row(): | |
| text_input = gr.Textbox(label="Enter Text") | |
| focus_input = gr.Textbox(label="Main Focus of Embedding (e.g., company structure, staff positions, etc.)") | |
| with gr.Row(): | |
| file_input = gr.File(label="Upload Files", file_count="multiple") | |
| generate_button = gr.Button("Generate Embedding") | |
| embedding_output = gr.Textbox(label="Embedding Vector (JSON)", lines=5) # Shows the embedding as a JSON string | |
| status_box = gr.Textbox(label="Status/Messages") # Receives the message returned by generate_embedding | |
| with gr.Accordion("Save and Download Options", open=False): # Accordion for save/download options, collapsed by default | |
| save_name_input = gr.Textbox(label="Save Embedding As (Name without extension)") | |
| with gr.Row(): | |
| save_button = gr.Button("Save as .npy") | |
| convert_button = gr.Button("Save as .json") | |
| with gr.Row(): | |
| save_status = gr.Textbox(label="Save Status") | |
| convert_status = gr.Textbox(label="Convert Status") | |
| download_button = gr.Button("Download JSON") | |
| download_output = gr.File(label="Download JSON File") | |
| process_button = gr.Button("Process Files") | |
| process_output = gr.Textbox(label="Processed Files (Embeddings JSON - limited display)", lines=3) # Limited lines for process output | |
| process_status = gr.Textbox(label="File Processing Status") # Status for file processing | |
| # Event wiring starts here. | |
| demo.load( # On page load: clear status_box and copy the stored init error into the hidden box (initialize_model is not called here) | |
| lambda: ("", model_initialization_error), # Empty string for status_box, current error text for initialization_status_box | |
| outputs=[status_box, initialization_status_box] # status_box for general messages, init status for hidden box | |
| ) | |
| generate_button.click( | |
| generate_embedding, | |
| inputs=[text_input, focus_input], | |
| outputs=[embedding_output, status_box] # (embedding JSON, message) as returned by generate_embedding | |
| ) | |
| save_button.click( | |
| save_embedding, | |
| inputs=[embedding_output, save_name_input], # The displayed embedding JSON string plus the chosen name | |
| outputs=[save_status] | |
| ) | |
| convert_button.click( | |
| convert_to_json, | |
| inputs=[embedding_output, save_name_input], # The displayed embedding JSON string plus the chosen name | |
| outputs=[convert_status] | |
| ) | |
| # The download handler only builds the path <PERSISTENT_PATH>/<name>.json (None for an empty name); | |
| # it does not check that the file exists or that the name stays inside PERSISTENT_PATH. | |
| download_button.click( | |
| lambda name: os.path.join(PERSISTENT_PATH, f"{name}.json") if name else None, # Handle empty name, use os.path.join | |
| inputs=[save_name_input], | |
| outputs=[download_output] | |
| ) | |
| process_button.click( | |
| process_files, | |
| inputs=[file_input, focus_input], | |
| outputs=[process_output, process_status] # (embeddings JSON, per-file status) as returned by process_files | |
| ) | |
| return demo | |
if __name__ == "__main__":
    # Load the model eagerly so startup problems show up in the console/logs.
    initialization_success, initialization_error_message = initialize_model()
    if not initialization_success:
        # Deliberately not fatal: the UI still starts, and the handlers retry
        # initialization and report the error to the user.
        print(f"App startup failed due to model initialization error:\n{initialization_error_message}")
    demo = create_gradio_interface()
    # Serve files from the configured storage root. PERSISTENT_PATH defaults to
    # "/data" but can be overridden through the environment, so a hard-coded
    # "/data" here would block downloads whenever it is overridden.
    demo.launch(server_name="0.0.0.0", allowed_paths=[PERSISTENT_PATH])