| import torch | |
| from transformers import pipeline | |
def model_fn(model_dir):
    """Build and return a ``transformers`` pipeline for ``model_dir``.

    The pipeline is created with bfloat16 as the requested torch dtype,
    ``device_map="auto"``, remote code trusted, and 8-bit loading requested
    through the model keyword arguments.

    Args:
        model_dir: Value handed to ``pipeline`` as its ``model`` argument.
            NOTE(review): presumably a local directory holding the model
            artifacts -- confirm against the caller.

    Returns:
        The object returned by ``transformers.pipeline``.
    """
    # Keyword arguments forwarded to the underlying model loader.
    loader_kwargs = {"load_in_8bit": True}

    # NOTE(review): trust_remote_code=True lets the model repository supply
    # its own Python code; only point this at model directories you trust.
    return pipeline(
        model=model_dir,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
        model_kwargs=loader_kwargs,
    )