| import os | |
| import torch | |
| from transformers import AutoModel, AutoTokenizer | |
| from sentence_transformers import SentenceTransformer | |
| from sagemaker_inference import content_types, decoder, default_inference_handler, encoder | |
def model_fn(model_dir):
    """Load the SentenceTransformer model from the SageMaker model directory.

    Called once by the SageMaker inference toolkit when the container starts.
    """
    return SentenceTransformer(model_dir)
def input_fn(request_body, request_content_type):
    """Deserialize the incoming request body.

    Only JSON payloads are accepted; any other content type raises ValueError.
    """
    if request_content_type != content_types.JSON:
        raise ValueError(f"Requested unsupported ContentType in content_type: {request_content_type}")
    return decoder.decode(request_body, content_types.JSON)
def predict_fn(input_data, model):
    """Encode the deserialized input with the loaded model and return embeddings."""
    return model.encode(input_data)
def output_fn(prediction, accept):
    """Serialize the model output for the response.

    Only a JSON Accept header is supported; anything else raises ValueError.
    """
    if accept != content_types.JSON:
        raise ValueError(f"Requested unsupported ContentType in Accept: {accept}")
    return encoder.encode(prediction, content_types.JSON)