File size: 524 Bytes
fae8ff7
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import joblib
import mlflow.pyfunc
from sentence_transformers import SentenceTransformer


class MiniLMClassifierWrapper(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc model chaining a sentence encoder and a classifier.

    Input text is embedded with a ``SentenceTransformer`` and the resulting
    embeddings are passed to a joblib-serialized classifier's ``predict``.
    """

    def load_context(self, context):
        """Load the encoder and the classifier from the model's artifacts.

        Args:
            context: MLflow model context; its ``artifacts`` mapping must
                contain the keys ``"encoder_path"`` and ``"classifier_path"``.
        """
        self.encoder = SentenceTransformer(context.artifacts["encoder_path"])
        # NOTE: joblib.load unpickles the file, which can execute arbitrary
        # code -- only load classifier artifacts from a trusted source.
        self.classifier = joblib.load(context.artifacts["classifier_path"])

    @staticmethod
    def _as_sentences(model_input):
        """Normalize *model_input* into a flat list of sentences.

        ``SentenceTransformer.encode`` expects a string or a list of strings.
        MLflow may hand ``predict`` a pandas DataFrame/Series or a numpy
        array instead; iterating a DataFrame yields its column labels, so
        passing it through unchanged would embed the wrong values.

        For 2-D inputs (DataFrame, 2-D array) the first column is used.
        NOTE(review): assumes the text lives in the first column -- confirm
        against the model signature / callers.
        """
        if isinstance(model_input, str):
            # A bare string would yield a 1-D embedding; keep a batch axis.
            return [model_input]
        if hasattr(model_input, "iloc"):
            # pandas object: a DataFrame is 2-D, a Series is 1-D.
            if getattr(model_input, "ndim", 1) == 2:
                model_input = model_input.iloc[:, 0]
            return model_input.tolist()
        if hasattr(model_input, "tolist"):
            # numpy-like array.
            values = model_input.tolist()
            if not isinstance(values, list):
                # 0-d array holding a single value.
                return [values]
            return [
                row[0] if isinstance(row, (list, tuple)) else row
                for row in values
            ]
        return list(model_input)

    def predict(self, context, model_input):
        """Embed the input text and return the classifier's predictions.

        Args:
            context: MLflow model context (unused here).
            model_input: A string, an iterable of strings, a pandas
                Series/DataFrame, or a numpy array of strings. For 2-D
                inputs the first column is taken as the text.

        Returns:
            Whatever ``self.classifier.predict`` returns for the embeddings,
            one prediction per input sentence.
        """
        sentences = self._as_sentences(model_input)
        embeddings = self.encoder.encode(sentences)
        return self.classifier.predict(embeddings)