File size: 830 Bytes
50b1ecc 9f03147 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import os
import json
import fasttext
from typing import Union, List
def model_fn(model_dir):
    """Load the fastText model shipped in the model directory.

    Args:
        model_dir: Directory that contains ``fasttext_model_300.bin``.

    Returns:
        The model object produced by ``fasttext.load_model``.
    """
    model_path = os.path.join(model_dir, "fasttext_model_300.bin")
    return fasttext.load_model(model_path)
def input_fn(input_data, content_type):
    """Decode a JSON request body and return its ``"inputs"`` field.

    Args:
        input_data: Raw request body; anything ``json.loads`` accepts
            (``str``, ``bytes`` or ``bytearray``).
        content_type: Request content type. It is not inspected: the body is
            always parsed as JSON.

    Returns:
        The value stored under the ``"inputs"`` key of the decoded object.

    Raises:
        json.JSONDecodeError: If the body is not valid JSON.
        KeyError: If the decoded JSON object has no ``"inputs"`` key.
        TypeError: If the decoded JSON value is not an object (e.g. a list).
    """
    return json.loads(input_data)["inputs"]
def _sentence_vector(model, sentence):
    """Embed a single string with ``model`` and return it as a plain list.

    Raises:
        ValueError: If ``sentence`` is not a ``str``.
    """
    if not isinstance(sentence, str):
        raise ValueError(f"Unsupported element type: {type(sentence)}")
    # fastText's get_sentence_vector processes one line at a time and raises
    # ValueError on any text containing a newline, so fold newlines into
    # spaces instead of letting multi-line inputs crash the request.
    return model.get_sentence_vector(sentence.replace("\n", " ")).tolist()


def predict_fn(data: Union[List[str], str], model):
    """Compute fastText sentence embeddings for one string or a list of strings.

    Args:
        data: A single sentence, or a list of sentences.
        model: The loaded fastText model (as returned by ``model_fn``).

    Returns:
        For a ``str`` input, one embedding as a list of numbers. For a
        ``list`` input, a list of such embeddings in input order (an empty
        list for an empty input).

    Raises:
        ValueError: If ``data`` is neither ``str`` nor ``list``, or a list
            element is not a ``str``.
    """
    if isinstance(data, str):
        return _sentence_vector(model, data)
    if isinstance(data, list):
        return [_sentence_vector(model, sentence) for sentence in data]
    raise ValueError(f"Unsupported data type: {type(data)}")
def output_fn(prediction, accept):
    """Serialize a prediction as JSON text of the form ``{"outputs": ...}``.

    Args:
        prediction: Value returned by ``predict_fn``.
        accept: Requested response content type. It is not inspected: the
            response is always JSON.

    Returns:
        The JSON string. Non-ASCII characters are written as-is rather than
        ASCII-escaped, and any value ``json`` cannot encode natively is
        converted with ``str``.
    """
    payload = {"outputs": prediction}
    return json.dumps(payload, default=str, ensure_ascii=False)
|