import os

import torch
from transformers import AutoTokenizer, pipeline
# Filename of the torch-serialized model artifact inside model_dir.
GPT_WEIGHTS_NAME = "pyg.pt"
def model_fn(model_dir):
    """Load the serialized model and tokenizer and build a text-generation pipeline.

    SageMaker-style model-loading entry point: called once at container start
    with the directory that holds the model artifacts.

    Args:
        model_dir: Directory containing the torch-serialized model
            (``GPT_WEIGHTS_NAME``) and the tokenizer files readable by
            ``AutoTokenizer.from_pretrained``.

    Returns:
        A ``transformers`` "text-generation" pipeline placed on GPU 0 when
        CUDA is available, otherwise on CPU (device -1).
    """
    device = 0 if torch.cuda.is_available() else -1
    # map_location keeps the load from failing on CPU-only hosts when the
    # checkpoint was saved from GPU tensors.
    map_location = torch.device("cuda:0") if device == 0 else torch.device("cpu")
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # trusted artifacts placed in model_dir.
    model = torch.load(
        os.path.join(model_dir, GPT_WEIGHTS_NAME), map_location=map_location
    )
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    return pipeline(
        "text-generation", model=model, tokenizer=tokenizer, device=device
    )