| from transformers import pipeline | |
| import torch | |
def model_fn(model_dir):
    """Build and return the transformers pipeline used to serve requests.

    Overrides the default model load function in the HuggingFace Deep
    Learning Container. The returned object is what the container treats as
    "the model" for subsequent prediction calls (per the SageMaker Hugging
    Face inference toolkit contract -- confirm against the toolkit version
    deployed).

    Args:
        model_dir: Path supplied by the container. It is NOT used here. The
            model is resolved by the repo id "fuwangwang/mpt-7b" instead of
            being loaded from this directory.

    Returns:
        The object returned by ``transformers.pipeline(...)`` for
        "fuwangwang/mpt-7b". It has bfloat16 weights, placed on devices via
        ``device_map="auto"``.
    """
    # NOTE: keep the single-parameter signature. The SageMaker HF inference
    # toolkit inspects model_fn's arity to decide whether to also pass a
    # `context` argument, so adding parameters would change how it is called.

    # SECURITY(review): trust_remote_code=True makes transformers import and
    # execute Python code from the "fuwangwang/mpt-7b" repository. It is
    # presumably required for MPT's custom model code. Because no
    # `revision=` is pinned, the code that runs can change whenever the repo
    # changes. Consider pinning an audited commit.
    instruct_pipeline = pipeline(
        model="fuwangwang/mpt-7b",
        torch_dtype=torch.bfloat16,  # halves memory vs fp32; needs bf16-capable hardware for speed
        trust_remote_code=True,
        device_map="auto",  # let Accelerate place weights across available devices
    )
    return instruct_pipeline