# Captured from a Hugging Face Spaces page; Space status at capture time: Runtime error
| from sentence_transformers import SentenceTransformer | |
| import os | |
| from huggingface_hub import hf_hub_download | |
| from joblib import load # <-- Import this to load the model | |
# Hugging Face access token from the HF_TOKEN environment variable
# (os.getenv returns None when it is unset).
hf_token = os.getenv('HF_TOKEN')

# Hugging Face model ID of the sentence-embedding model and the local
# directory the downloaded copy is saved to.
hf_model_id = 'Alibaba-NLP/gte-base-en-v1.5'
model_dir = '/tmp/sentence_transformer'  # Use /tmp for write permissions
# Path of the pre-trained classifier file. It is relative, so it is resolved
# against the current working directory.
clf_model_path = 'models/logistic_regression_model.pkl'

# Create the local model directory if it does not exist yet.
os.makedirs(model_dir, exist_ok=True)

# First run: download the model from the Hub and save a local copy.
# Later runs: load that copy. The presence of config.json in model_dir is
# used as the "already saved" marker.
if not os.path.exists(os.path.join(model_dir, 'config.json')):
    print(f"Downloading model '{hf_model_id}' from Hugging Face...")
    # `token` is the replacement for the deprecated `use_auth_token` argument.
    # trust_remote_code=True allows loading model code hosted on the Hub,
    # which this model presumably needs (NOTE(review): confirm on the model card).
    model = SentenceTransformer(hf_model_id, token=hf_token, trust_remote_code=True)
    model.save(model_dir)
else:
    print(f"Loading model from local directory: {model_dir}")
    # The saved copy is loaded with the same trust_remote_code setting that
    # was used for the initial download.
    model = SentenceTransformer(model_dir, trust_remote_code=True)
# Load the pre-trained logistic regression classifier into the module-level
# `clf`, which predict_label() reads. It stays None if the file is missing,
# and predict_label() then reports an error instead of predicting.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code. Only load a classifier file from a trusted source.
clf = None
if os.path.exists(clf_model_path):
    clf = load(clf_model_path)
    print("Logistic Regression model loaded successfully.")
else:
    # Report the path that was actually checked, so the user knows where to
    # place the file.
    print(f"Logistic Regression model not found at '{clf_model_path}'. "
          "Ensure it is saved there.")
# Define predict_label function
def predict_label(text):
    """Classify text using the sentence embedder plus the logistic regression model.

    Args:
        text: A single input string, or a list/tuple of strings. For a
            sequence, only the result for the first item is returned.

    Returns:
        A ``(label, probability)`` tuple for the first input:
        - ``label`` is ``str(prediction)``.
        - ``probability`` is the highest class probability as a ``float``.
          For scikit-learn classifiers this is normally the probability of
          the predicted class.

        On any failure (classifier not loaded, empty input, encoding or
        prediction error), the error is printed and ``("Error", 0.0)`` is
        returned instead of raising.
    """
    try:
        if clf is None:
            raise ValueError("Logistic Regression model is not loaded.")

        # The encoder and the classifier both operate on batches, so wrap a
        # single input in a list.
        if isinstance(text, (list, tuple)):
            texts = list(text)
        else:
            texts = [text]
        if not texts:
            raise ValueError("No input text provided.")

        embeddings = model.encode(texts)
        if len(embeddings) == 0:
            raise ValueError("No embeddings generated.")

        prediction = clf.predict(embeddings)
        # Use only the first row's probabilities, so the confidence belongs to
        # the same input as the returned label. A .max() over the whole
        # matrix could pick up another item in the batch.
        probability = clf.predict_proba(embeddings)[0].max()

        # String form of the class label, e.g. "0"/"1" if the classifier was
        # trained on integer labels 0/1 (assumption; confirm against training).
        label = str(prediction[0])
        return label, float(probability)
    except Exception as e:
        # Top-level boundary for callers: log and return a sentinel value
        # rather than propagating the exception.
        print(f"Error in predict_label: {e}")
        return "Error", 0.0