File size: 467 Bytes
e290a20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import os
import torch
from transformers import AutoTokenizer, pipeline

GPT_WEIGHTS_NAME = "pyg.pt"


def model_fn(model_dir):
    """Load a serialized GPT model and its tokenizer, returning a pipeline.

    SageMaker-style model-loading entry point: given the directory the
    model artifacts were extracted into, deserialize the model weights
    (``pyg.pt``, see module-level ``GPT_WEIGHTS_NAME``) and the tokenizer
    files, then wrap both in a ``text-generation`` pipeline.

    Parameters
    ----------
    model_dir : str
        Directory containing the serialized model file and the tokenizer
        files consumable by ``AutoTokenizer.from_pretrained``.

    Returns
    -------
    transformers.Pipeline
        A ``text-generation`` pipeline on GPU 0 when CUDA is available,
        otherwise on CPU.
    """
    # transformers pipeline device convention: 0 = first CUDA device, -1 = CPU.
    device = 0 if torch.cuda.is_available() else -1

    # map_location keeps a checkpoint that was saved on a GPU machine
    # loadable on a CPU-only host; without it torch.load raises a CUDA
    # deserialization error even though we fall back to device=-1 below.
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from a trusted source.
    model = torch.load(
        os.path.join(model_dir, GPT_WEIGHTS_NAME),
        map_location="cuda:0" if device == 0 else "cpu",
    )

    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    generation = pipeline(
        "text-generation", model=model, tokenizer=tokenizer, device=device
    )

    return generation