Melissa Roemmele committed on
Commit
b827f74
·
1 Parent(s): 5ebba4e

Updated handler.py

Browse files
Files changed (1) hide show
  1. handler.py +2 -2
handler.py CHANGED
@@ -8,7 +8,7 @@ class EndpointHandler:
8
  def __init__(self, path=""):
9
  tokenizer = AutoTokenizer.from_pretrained(path)
10
  model = AutoModelForCausalLM.from_pretrained(path,
11
- torch_dtype=torch.float16,
12
  trust_remote_code=True)
13
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
14
  self.pipeline = transformers.pipeline('text-generation',
@@ -19,6 +19,6 @@ class EndpointHandler:
19
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
20
  inputs = data.pop("inputs", data)
21
  parameters = data.pop("parameters", {})
22
- with torch.autocast(self.pipeline.device.type, dtype=torch.float16):
23
  outputs = self.pipeline(inputs, **parameters, use_cache=True)
24
  return outputs
 
8
  def __init__(self, path=""):
9
  tokenizer = AutoTokenizer.from_pretrained(path)
10
  model = AutoModelForCausalLM.from_pretrained(path,
11
+ torch_dtype=torch.bfloat16,
12
  trust_remote_code=True)
13
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
14
  self.pipeline = transformers.pipeline('text-generation',
 
19
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
20
  inputs = data.pop("inputs", data)
21
  parameters = data.pop("parameters", {})
22
+ with torch.autocast(self.pipeline.device.type, dtype=torch.bfloat16):
23
  outputs = self.pipeline(inputs, **parameters, use_cache=True)
24
  return outputs