# Source: embedding_FastAPI/service/prediction_service.py
# Author: Chittrarasu · commit 9288345 ("deploy") · 517 bytes
import pickle
from sentence_transformers import SentenceTransformer
import numpy as np
# Load the trained classifier and the sentence-embedding model once, at import
# time, so every call to predict_label() reuses the same instances.
#
# NOTE: 'models/logistic_regression_model.pkl' is opened relative to the
# process's current working directory (not relative to this file); the
# SentenceTransformer path is presumably resolved the same way. Start the
# service from the directory that contains `models/`.
#
# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# Only load model artifacts that come from a trusted source.
with open('models/logistic_regression_model.pkl', 'rb') as f:
    logistic_model = pickle.load(f)

model = SentenceTransformer('models/sentence_transformer')
def predict_label(message: str):
    """Classify a single text message.

    The message is embedded with the module-level SentenceTransformer
    ``model`` and the embedding is scored by the pickled ``logistic_model``.

    Args:
        message: The raw text to classify.

    Returns:
        A ``(prediction, probability)`` tuple.

        - ``prediction`` is the first element of ``logistic_model.predict``'s
          output. A NumPy scalar is unwrapped to the equivalent built-in
          Python value (``int``, ``str``, ``bool`` ...).
        - ``probability`` is the largest per-class value from
          ``predict_proba`` for this message, as a plain ``float``.
    """
    # Passing a one-element list produces a batch containing a single
    # embedding; the [0] indexing below picks that single row's results.
    embedding = model.encode([message])

    prediction = logistic_model.predict(embedding)[0]
    # Indexing predict()'s array output typically yields a NumPy scalar
    # (e.g. numpy.int64), which JSON encoders generally reject. Unwrap it so
    # callers get a plain Python value, matching the float() conversion
    # applied to the probability below.
    if isinstance(prediction, np.generic):
        prediction = prediction.item()

    probability = logistic_model.predict_proba(embedding)[0].max()
    return prediction, float(probability)