embedding_FastAPI / service /prediction_service.py
Chittrarasu's picture
deploy
81c6189
raw
history blame
1.81 kB
from sentence_transformers import SentenceTransformer
import os
from huggingface_hub import hf_hub_download
# Get the Hugging Face token from environment variable (None if HF_TOKEN is unset;
# the download then proceeds unauthenticated, which is fine for public models).
hf_token = os.getenv('HF_TOKEN')
# Hugging Face Model ID and local model directory
hf_model_id = 'Alibaba-NLP/gte-base-en-v1.5'
model_dir = '/tmp/sentence_transformer' # Use /tmp for write permissions
# Create model directory if not exists
os.makedirs(model_dir, exist_ok=True)
# Download model if not already downloaded.
# config.json is used as a sentinel: model.save() writes it, so its presence
# means a previous download completed. NOTE(review): a download interrupted
# after config.json is written would leave a broken cache — TODO confirm
# whether that can happen in this deployment.
if not os.path.exists(os.path.join(model_dir, 'config.json')):
    print(f"Downloading model '{hf_model_id}' from Hugging Face...")
    # NOTE(review): `use_auth_token` is deprecated in newer
    # sentence-transformers / huggingface_hub releases in favor of `token` —
    # verify against the pinned library version before changing.
    # trust_remote_code=True is required because gte-base-en-v1.5 ships
    # custom model code on the Hub.
    model = SentenceTransformer(hf_model_id, use_auth_token=hf_token, trust_remote_code=True)
    model.save(model_dir)
else:
    print(f"Loading model from local directory: {model_dir}")
    model = SentenceTransformer(model_dir, trust_remote_code=True) # Added trust_remote_code=True
def predict_label(text):
    """Embed *text* with the module-level SentenceTransformer and classify it.

    Args:
        text: A single string or a list of strings to classify.

    Returns:
        A ``(label, probability)`` tuple where ``label`` is the predicted
        class as a string (e.g. "0" or "1") and ``probability`` is the
        classifier's top class probability as a float. On any failure the
        sentinel ``("Error", 0.0)`` is returned instead of raising.

    NOTE(review): this function references ``clf`` (a fitted classifier with
    ``predict``/``predict_proba``), which is never defined in this file. As
    written, the resulting ``NameError`` is swallowed by the broad ``except``
    and every call returns ``("Error", 0.0)`` — confirm where ``clf`` is
    supposed to be loaded.
    """
    try:
        # model.encode expects a list; wrap a bare string.
        batch = text if isinstance(text, list) else [text]
        embeddings = model.encode(batch)
        # Guard against an empty embedding batch before classification.
        if len(embeddings) == 0:
            raise ValueError("No embeddings generated.")
        # Classify with the (externally provided) fitted classifier.
        prediction = clf.predict(embeddings)
        probability = clf.predict_proba(embeddings).max()
        # First sample's label, normalized to a string ("0" or "1").
        return str(prediction[0]), float(probability)
    except Exception as e:
        # Best-effort API: log and fall back to the error sentinel rather
        # than propagating (callers depend on always getting a tuple back).
        print(f"Error in predict_label: {e}")
        return "Error", 0.0