embedding_FastAPI / service /prediction_service.py
Chittrarasu's picture
deploy
4e3ad22
raw
history blame
852 Bytes
from sentence_transformers import SentenceTransformer
import os
from huggingface_hub import hf_hub_download
# Model identifier on the Hugging Face Hub and the local directory the model
# is stored in. The directory lives under /tmp for write permissions.
hf_model_id = 'Alibaba-NLP/gte-base-en-v1.5'
model_dir = '/tmp/sentence_transformer'

# Hugging Face access token, read from the HF_TOKEN environment variable
# (None when the variable is not set).
hf_token = os.environ.get('HF_TOKEN')

# Make sure the model directory exists; an already-existing one is fine.
os.makedirs(model_dir, exist_ok=True)
# Load the embedding model into the module-level name `model`.
# First run: fetch it from the Hugging Face Hub and save a copy to model_dir.
# Later runs: load the saved copy. The presence of config.json in model_dir
# is used as the "already downloaded" marker.
if not os.path.exists(os.path.join(model_dir, 'config.json')):
    print(f"Downloading model '{hf_model_id}' from Hugging Face...")
    # NOTE(review): `use_auth_token` is reported as deprecated in favour of
    # `token` in newer sentence-transformers releases - confirm the pinned
    # version before renaming the keyword.
    model = SentenceTransformer(
        hf_model_id,
        use_auth_token=hf_token,
        trust_remote_code=True,
    )
    model.save(model_dir)
else:
    print(f"Loading model from local directory: {model_dir}")
    # Pass trust_remote_code here as well, matching the download branch:
    # a model that needs custom code on first load needs it again when
    # reloaded from the saved copy (presumably the saved config still points
    # at that custom code - verify against the model card).
    model = SentenceTransformer(model_dir, trust_remote_code=True)