Melissa Roemmele committed on
Commit
758ff51
·
1 Parent(s): fbc6b9c

Updated handler.py

Browse files
Files changed (1) hide show
  1. handler.py +24 -12
handler.py CHANGED
@@ -4,22 +4,34 @@ from typing import Any, Dict
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
class EndpointHandler:
    """Inference-endpoint wrapper around a Transformers text-generation pipeline."""

    def __init__(self, path=""):
        # Load tokenizer and causal-LM weights (bfloat16) from the checkpoint path.
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        # Prefer the first GPU when one is available, otherwise run on CPU.
        target_device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.pipeline = transformers.pipeline(
            'text-generation',
            model=model,
            tokenizer=tokenizer,
            device=target_device,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Run generation on the request payload and return the pipeline output.

        The payload may carry the prompt under "inputs" and extra generation
        kwargs under "parameters"; both keys are popped from *data*.
        """
        torch.cuda.empty_cache()
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})
        with torch.autocast(self.pipeline.device.type, dtype=torch.bfloat16):
            return self.pipeline(inputs, **parameters, use_cache=True)
 
 
 
 
 
 
 
 
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
 
7
def load_pipeline(path):
    """Build a text-generation pipeline for the checkpoint at *path*.

    Loads the tokenizer and the causal-LM weights in bfloat16
    (trusting remote code), then places the pipeline on the first GPU
    when one is available, falling back to CPU otherwise.
    """
    tok = AutoTokenizer.from_pretrained(path)
    lm = AutoModelForCausalLM.from_pretrained(
        path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    # Choose the execution device once, at load time.
    target = "cuda:0" if torch.cuda.is_available() else "cpu"
    return transformers.pipeline(
        'text-generation',
        model=lm,
        tokenizer=tok,
        device=target,
    )
18
+
19
+
20
class EndpointHandler:
    """Inference-endpoint wrapper around a Transformers text-generation pipeline.

    If a generation call fails, the pipeline is rebuilt from ``self.path`` so
    the endpoint can recover from a corrupted model/CUDA state, and the
    original exception is re-raised for the caller.
    """

    def __init__(self, path=""):
        # Remember the checkpoint path so the pipeline can be reloaded later.
        self.path = path
        self.pipeline = load_pipeline(self.path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Run generation on the request payload and return the pipeline output.

        The payload may carry the prompt under "inputs" and extra generation
        kwargs under "parameters"; both keys are popped from *data*.
        Raises whatever the underlying pipeline raises, after attempting a
        pipeline reload.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})
        with torch.autocast(self.pipeline.device.type, dtype=torch.bfloat16):
            try:
                return self.pipeline(inputs, **parameters, use_cache=True)
            except Exception:
                print("Exception encountered. Reloading pipeline")
                # Drop the broken pipeline and free cached GPU memory BEFORE
                # loading a fresh copy; otherwise both models are resident at
                # once and the reload itself can run out of memory.
                self.pipeline = None
                torch.cuda.empty_cache()
                self.pipeline = load_pipeline(self.path)
                # Bare raise preserves the original exception and traceback.
                raise