shamaayan's picture
copy reqs
6c25421
raw
history blame contribute delete
830 Bytes
import os
import json
import fasttext
from typing import Union, List
def model_fn(model_dir):
    """Load the fastText model stored inside ``model_dir``.

    Args:
        model_dir: Directory that contains ``fasttext_model_300.bin``.

    Returns:
        The object returned by ``fasttext.load_model`` for that file.
    """
    # NOTE(review): the *_fn names look like the SageMaker inference handler
    # convention -- confirm against the deployment setup.
    model_path = os.path.join(model_dir, "fasttext_model_300.bin")
    return fasttext.load_model(model_path)
def input_fn(input_data, content_type):
    """Decode a JSON request body and return its ``"inputs"`` entry.

    Args:
        input_data: JSON document (anything ``json.loads`` accepts) whose
            top-level object has an ``"inputs"`` key.
        content_type: Accepted for interface compatibility; not inspected.

    Returns:
        The value stored under ``"inputs"``.

    Raises:
        KeyError: If the decoded object has no ``"inputs"`` key.
        json.JSONDecodeError: If ``input_data`` is not valid JSON.
    """
    return json.loads(input_data)["inputs"]
def predict_fn(data: Union[List[str], str], model):
    """Compute sentence vectors for one sentence or a list of sentences.

    Args:
        data: A single string, or a list whose items are each passed to
            ``model.get_sentence_vector``.
        model: Object exposing ``get_sentence_vector``; its result must
            provide ``tolist()``.

    Returns:
        For a string, the ``tolist()`` form of its sentence vector; for a
        list, one such entry per item, in input order.

    Raises:
        ValueError: If ``data`` is neither a ``str`` nor a ``list``.
    """

    def _embed(sentence):
        # The model attribute is looked up only when a vector is requested,
        # so unsupported input still fails with ValueError first.
        return model.get_sentence_vector(sentence).tolist()

    if isinstance(data, str):
        return _embed(data)
    if isinstance(data, list):
        return list(map(_embed, data))
    raise ValueError(f"Unsupported data type: {type(data)}")
def output_fn(prediction, accept):
    """Serialize a prediction as a JSON object with an ``"outputs"`` key.

    Args:
        prediction: Value to place under ``"outputs"``.
        accept: Accepted for interface compatibility; not inspected.

    Returns:
        A JSON string. Non-ASCII characters are emitted unescaped, and any
        value the encoder cannot serialize is replaced by ``str(value)``.
    """
    payload = {"outputs": prediction}
    return json.dumps(payload, ensure_ascii=False, default=str)