turing-space / turing /modeling /models /MiniLMClassifierWrapper.py
github-actions[bot]
Sync turing folder from GitHub
fae8ff7
raw
history blame contribute delete
524 Bytes
import joblib
import mlflow.pyfunc
from sentence_transformers import SentenceTransformer
def _to_sentence_list(model_input):
    """Normalize a pyfunc ``model_input`` into a list of texts for the encoder.

    MLflow may hand ``predict`` a pandas DataFrame, a pandas Series, a numpy
    array, a list, or a single string. ``SentenceTransformer.encode`` expects
    a string or a list of strings; iterating a DataFrame directly yields its
    column labels rather than its rows, so DataFrames must be unwrapped first.

    Raises:
        ValueError: if a DataFrame-like input has a number of columns other
            than one, since it is ambiguous which column holds the text.
    """
    if isinstance(model_input, str):
        # A bare string would make encode() return a 1-D vector, which the
        # classifier's predict() cannot consume; treat it as a batch of one.
        return [model_input]
    columns = getattr(model_input, "columns", None)
    if columns is not None:
        # DataFrame-like input: the texts are the values of its single column.
        if len(columns) != 1:
            raise ValueError(
                "Expected exactly one text column, got "
                f"{len(columns)} columns: {list(columns)}"
            )
        model_input = model_input.iloc[:, 0]
    if hasattr(model_input, "tolist"):
        # pandas Series / 1-D numpy array -> plain Python list.
        model_input = model_input.tolist()
    return list(model_input)


class MiniLMClassifierWrapper(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc model that classifies text via sentence embeddings.

    Texts are embedded with a ``SentenceTransformer`` encoder and the
    embeddings are passed to a joblib-serialized classifier's ``predict``.
    The model must be logged with two artifacts:

    * ``encoder_path`` -- path (or model id) loadable by ``SentenceTransformer``.
    * ``classifier_path`` -- a joblib file holding an object with ``predict``.
    """

    def load_context(self, context):
        """Load the encoder and the classifier from the logged artifacts."""
        self.encoder = SentenceTransformer(context.artifacts["encoder_path"])
        # joblib.load unpickles the file and can execute arbitrary code;
        # only load classifier artifacts that come from a trusted source.
        self.classifier = joblib.load(context.artifacts["classifier_path"])

    def predict(self, context, model_input, params=None):
        """Return the classifier's predictions for each input text.

        Args:
            context: MLflow ``PythonModelContext`` (unused here; artifacts
                were already loaded in ``load_context``).
            model_input: texts to classify -- a single-column DataFrame, a
                Series, a 1-D array, a list of strings, or a single string.
            params: optional inference parameters accepted by newer MLflow
                versions; currently ignored.

        Returns:
            Whatever ``self.classifier.predict`` returns for the embeddings,
            one prediction per input text.

        Raises:
            ValueError: if ``model_input`` is a DataFrame whose column count
                is not exactly one.
        """
        sentences = _to_sentence_list(model_input)
        embeddings = self.encoder.encode(sentences)
        return self.classifier.predict(embeddings)