Spaces:
Runtime error
Runtime error
File size: 2,411 Bytes
9288345 7098798 3369b7c 9288345 7098798 4e3ad22 5e0bbc5 7098798 4b5415a 35b327b 3369b7c 4b5415a 81c6189 3369b7c 81c6189 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
from sentence_transformers import SentenceTransformer
import os
from huggingface_hub import hf_hub_download
from joblib import load # <-- Import this to load the model
# Hugging Face access token, read from the HF_TOKEN environment variable
# (None when the variable is unset).
hf_token = os.getenv('HF_TOKEN')

# Hub ID of the embedding model and the local directory it is saved to.
hf_model_id = 'Alibaba-NLP/gte-base-en-v1.5'
model_dir = '/tmp/sentence_transformer'  # Use /tmp for write permissions
# Location of the serialized classifier (a relative path).
clf_model_path = 'models/logistic_regression_model.pkl'

# Create the model directory if it does not exist yet.
os.makedirs(model_dir, exist_ok=True)

# A config.json inside model_dir is used as the marker that the embedding
# model was already saved there by an earlier run.
if not os.path.exists(os.path.join(model_dir, 'config.json')):
    print(f"Downloading model '{hf_model_id}' from Hugging Face...")
    # NOTE(review): newer sentence-transformers releases appear to prefer
    # `token=` over `use_auth_token=` -- confirm against the pinned version.
    model = SentenceTransformer(hf_model_id, use_auth_token=hf_token, trust_remote_code=True)
    model.save(model_dir)
else:
    print(f"Loading model from local directory: {model_dir}")
    model = SentenceTransformer(model_dir, trust_remote_code=True)

# `clf` is always defined at module level; it stays None when the classifier
# file is missing so that callers can detect the unloaded state.
clf = None
if os.path.exists(clf_model_path):
    # NOTE(security): joblib.load unpickles the file, which can execute
    # arbitrary code -- only load classifier files from a trusted source.
    clf = load(clf_model_path)
    print("Logistic Regression model loaded successfully.")
else:
    # Report the path that was actually checked (the old text pointed at /tmp,
    # which is not where clf_model_path lives).
    print(f"Logistic Regression model not found at '{clf_model_path}'.")
# Define predict_label function
def predict_label(text):
    """Classify ``text`` using the module-level ``model`` and ``clf``.

    Args:
        text: A single input, or a list of inputs. Anything that is not a
            list is wrapped in a one-element list before encoding. When a
            list with several items is given, only the first item's result
            is returned.

    Returns:
        A ``(label, probability)`` tuple: ``label`` is the string form of the
        classifier's prediction for the first input, and ``probability`` is
        the highest class probability for that same input, as a float.
        If anything fails (classifier not loaded, no embeddings produced, or
        any exception raised while encoding/predicting), the error is printed
        and ``("Error", 0.0)`` is returned instead of raising.
    """
    try:
        if clf is None:
            raise ValueError("Logistic Regression model is not loaded.")

        # The encoder is always called with a list, so wrap a single input.
        if not isinstance(text, list):
            text = [text]

        embeddings = model.encode(text)
        if len(embeddings) == 0:
            raise ValueError("No embeddings generated.")

        prediction = clf.predict(embeddings)
        # Take the confidence from row 0 only, so that it belongs to the same
        # input as the returned label. A global .max() over all rows could
        # report the confidence of a different item in a multi-item list.
        probability = clf.predict_proba(embeddings)[0].max()

        return str(prediction[0]), float(probability)
    except Exception as e:
        # Deliberate best-effort boundary: log and return a sentinel result.
        print(f"Error in predict_label: {e}")
        return "Error", 0.0
|