# Source snapshot: commit e290a20 (file size 467 bytes).
import os
import torch
from transformers import AutoTokenizer, pipeline
# File name of the pickled model object inside the model directory.
GPT_WEIGHTS_NAME = "pyg.pt"


def model_fn(model_dir):
    """Load the pickled model and tokenizer from *model_dir* and wrap them in a pipeline.

    Args:
        model_dir: Directory containing the pickled model file named by
            ``GPT_WEIGHTS_NAME`` plus the tokenizer files read by
            ``AutoTokenizer.from_pretrained``.

    Returns:
        A ``transformers`` "text-generation" pipeline placed on the first GPU
        when CUDA is available, otherwise on CPU.
    """
    use_cuda = torch.cuda.is_available()
    # transformers pipelines take an int device index: 0 = first GPU, -1 = CPU.
    device = 0 if use_cuda else -1

    # Without map_location, tensors saved from a GPU process are restored onto
    # their original CUDA device, which fails on a CPU-only host. Map them to
    # the same device the pipeline will run on.
    map_location = "cuda:0" if use_cuda else "cpu"

    # NOTE(review): torch.load unpickles arbitrary Python objects, so the
    # weights file must come from a trusted source. This file stores a whole
    # pickled model; on torch releases where ``weights_only`` defaults to True
    # this call may need ``weights_only=False`` -- confirm against the
    # deployed torch version.
    model = torch.load(
        os.path.join(model_dir, GPT_WEIGHTS_NAME), map_location=map_location
    )
    # The module may have been pickled in training mode; switch to eval mode
    # so training-only layers such as dropout do not perturb generation.
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    return pipeline(
        "text-generation", model=model, tokenizer=tokenizer, device=device
    )