| import os | |
| import json | |
| import fasttext | |
| from typing import Union, List | |
def model_fn(model_dir):
    """Load the fastText model file bundled inside *model_dir*.

    Args:
        model_dir: Directory containing ``fasttext_model_300.bin``.

    Returns:
        Whatever ``fasttext.load_model`` returns for that file.
    """
    model_path = os.path.join(model_dir, "fasttext_model_300.bin")
    return fasttext.load_model(model_path)
def input_fn(input_data, content_type):
    """Decode a JSON request body and hand back its ``"inputs"`` entry.

    Args:
        input_data: Raw request body; anything ``json.loads`` accepts
            (``str``, ``bytes`` or ``bytearray``).
        content_type: Declared request content type. It is not inspected;
            the body is always parsed as JSON.

    Returns:
        The value stored under the ``"inputs"`` key of the decoded object.

    Raises:
        json.JSONDecodeError: If the body is not valid JSON.
        KeyError: If the decoded object has no ``"inputs"`` key.
    """
    return json.loads(input_data)['inputs']
def predict_fn(data: Union[List[str], str], model):
    """Embed one sentence, or a list of sentences, with ``model``.

    Args:
        data: A single string, or a list of strings, to embed.
        model: Object exposing ``get_sentence_vector(str)``, whose result
            has a ``tolist()`` method (e.g. the model loaded by ``model_fn``).

    Returns:
        For a string input, the ``tolist()`` of its sentence vector. For a
        list input, a list of those, in the same order as ``data``.

    Raises:
        ValueError: If ``data`` is neither a ``str`` nor a ``list``, or if a
            list element is not a ``str``.
    """
    def _embed(text: str) -> list:
        # fastText's get_sentence_vector refuses any text containing "\n"
        # (it handles exactly one line) and raises ValueError. Request bodies
        # routinely contain newlines. fastText already splits tokens on
        # whitespace, newline included, so replacing "\n" with a space keeps
        # the tokenization while avoiding the crash.
        return model.get_sentence_vector(text.replace("\n", " ")).tolist()

    if isinstance(data, str):
        return _embed(data)
    if isinstance(data, list):
        # Validate everything before doing any work, so a bad element
        # produces a clear error instead of an obscure one from fastText.
        for index, sentence in enumerate(data):
            if not isinstance(sentence, str):
                raise ValueError(
                    f"Unsupported element type at index {index}: {type(sentence)}"
                )
        return [_embed(sentence) for sentence in data]
    raise ValueError(f"Unsupported data type: {type(data)}")
def output_fn(prediction, accept):
    """Serialize *prediction* as a JSON object under the ``"outputs"`` key.

    Args:
        prediction: Value produced by ``predict_fn``.
        accept: Requested response content type. It is not inspected; the
            response is always JSON.

    Returns:
        A JSON string. Non-ASCII characters are emitted as-is, and any value
        ``json`` cannot encode natively is converted with ``str()``.
    """
    payload = {"outputs": prediction}
    return json.dumps(payload, default=str, ensure_ascii=False)