Melissa Roemmele committed on
Commit
a9751eb
·
1 Parent(s): 0336ef0

Updated handler.py

Browse files
Files changed (1) hide show
  1. handler.py +5 -5
handler.py CHANGED
@@ -7,7 +7,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
7
  class EndpointHandler():
8
  def __init__(self, path=""):
9
  model = AutoModelForCausalLM.from_pretrained(path,
10
- torch_dtype=torch.float16,
11
  trust_remote_code=True)
12
  tokenizer = AutoTokenizer.from_pretrained(path)
13
  #device = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -19,10 +19,10 @@ class EndpointHandler():
19
  def __call__(self, data: Dict[str, Any]):
20
  inputs = data.pop("inputs", data)
21
  parameters = data.pop("parameters", {})
22
- # with torch.autocast(self.pipeline.device.type, dtype=torch.float16):
23
- outputs = self.pipeline(inputs,
24
- **parameters)
25
- return outputs
26
 
27
 
28
  # class EndpointHandler:
 
7
  class EndpointHandler():
8
  def __init__(self, path=""):
9
  model = AutoModelForCausalLM.from_pretrained(path,
10
+ torch_dtype=torch.bfloat16,
11
  trust_remote_code=True)
12
  tokenizer = AutoTokenizer.from_pretrained(path)
13
  #device = "cuda:0" if torch.cuda.is_available() else "cpu"
 
19
  def __call__(self, data: Dict[str, Any]):
20
  inputs = data.pop("inputs", data)
21
  parameters = data.pop("parameters", {})
22
+ with torch.autocast(self.pipeline.device.type, dtype=torch.bfloat16):
23
+ outputs = self.pipeline(inputs,
24
+ **parameters)
25
+ return outputs
26
 
27
 
28
  # class EndpointHandler: