Melissa Roemmele committed on
Commit
210820d
·
1 Parent(s): a9751eb

Updated handler.py

Browse files
Files changed (1) hide show
  1. handler.py +4 -3
handler.py CHANGED
@@ -8,13 +8,14 @@ class EndpointHandler():
8
  def __init__(self, path=""):
9
  model = AutoModelForCausalLM.from_pretrained(path,
10
  torch_dtype=torch.bfloat16,
11
- trust_remote_code=True)
 
 
12
  tokenizer = AutoTokenizer.from_pretrained(path)
13
  #device = "cuda:0" if torch.cuda.is_available() else "cpu"
14
  self.pipeline = transformers.pipeline('text-generation',
15
  model=model,
16
- tokenizer=tokenizer,
17
- device_map="auto")
18
 
19
  def __call__(self, data: Dict[str, Any]):
20
  inputs = data.pop("inputs", data)
 
8
  def __init__(self, path=""):
9
  model = AutoModelForCausalLM.from_pretrained(path,
10
  torch_dtype=torch.bfloat16,
11
+ trust_remote_code=True,
12
+ device_map="auto")
13
+ print(model.hf_device_map)
14
  tokenizer = AutoTokenizer.from_pretrained(path)
15
  #device = "cuda:0" if torch.cuda.is_available() else "cpu"
16
  self.pipeline = transformers.pipeline('text-generation',
17
  model=model,
18
+ tokenizer=tokenizer)
 
19
 
20
  def __call__(self, data: Dict[str, Any]):
21
  inputs = data.pop("inputs", data)