Melissa Roemmele committed on
Commit
ac83b16
·
1 Parent(s): 758ff51

Updated handler.py

Browse files
Files changed (1) hide show
  1. handler.py +26 -27
handler.py CHANGED
@@ -1,37 +1,36 @@
1
  import torch
2
- import transformers
3
  from typing import Any, Dict
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
 
7
def load_pipeline(path):
    """Build a text-generation pipeline from the checkpoint at *path*.

    The tokenizer and model are loaded from the same path; the model is kept
    in bfloat16 and remote code is trusted (required by some custom models).
    The pipeline is placed on the first GPU when available, else on CPU.
    """
    target_device = "cuda:0" if torch.cuda.is_available() else "cpu"
    tok = AutoTokenizer.from_pretrained(path)
    lm = AutoModelForCausalLM.from_pretrained(
        path, torch_dtype=torch.bfloat16, trust_remote_code=True
    )
    return transformers.pipeline(
        'text-generation', model=lm, tokenizer=tok, device=target_device
    )
18
-
19
-
20
class EndpointHandler:
    """Serve text-generation requests through a transformers pipeline.

    The checkpoint path is remembered so the pipeline can be rebuilt if a
    request fails (e.g. after a CUDA out-of-memory error).
    """

    def __init__(self, path=""):
        # Keep the path for pipeline reloads on failure.
        self.path = path
        self.pipeline = load_pipeline(self.path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        # "inputs" is the prompt text; "parameters" are optional generation
        # kwargs forwarded to the pipeline. Falls back to the whole payload
        # when no "inputs" key is present.
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})
        with torch.autocast(self.pipeline.device.type, dtype=torch.bfloat16):
            try:
                outputs = self.pipeline(inputs, **parameters, use_cache=True)
                return outputs
            except Exception:
                # Fixed typo ("encounted" -> "encountered"). Free cached GPU
                # memory *before* reloading so the fresh pipeline has room,
                # then re-raise with the original traceback (bare `raise`
                # instead of `raise e`).
                print("Exception encountered. Reloading pipeline")
                torch.cuda.empty_cache()
                self.pipeline = load_pipeline(self.path)
                raise
 
 
 
 
 
 
 
 
 
1
  import torch
 
2
  from typing import Any, Dict
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
class EndpointHandler:
    """Inference endpoint handler: loads a causal LM once, generates per request."""

    def __init__(self, path=""):
        # load model and tokenizer from path
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # device_map="auto" lets the loader place/shard the weights itself;
        # float16 halves the memory footprint versus float32.
        self.model = AutoModelForCausalLM.from_pretrained(
            path, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True
        )
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        # NOTE(review): despite the annotation, this returns a one-element
        # list of dicts, matching the HF text-generation pipeline shape —
        # kept as-is since endpoint clients consume that shape.
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)

        # preprocess; use pop() with a default because many tokenizers
        # (e.g. GPT-style) emit no "token_type_ids" — a bare pop() would
        # raise KeyError on the very first request.
        inputs = self.tokenizer(inputs, return_tensors="pt")
        inputs.pop("token_type_ids", None)
        inputs.pop("attention_mask", None)
        inputs = inputs.to(self.device)

        # pass inputs with all generation kwargs supplied by the request
        if parameters is not None:
            outputs = self.model.generate(**inputs, **parameters)
        else:
            outputs = self.model.generate(**inputs)

        # postprocess the prediction
        prediction = self.tokenizer.decode(
            outputs[0], skip_special_tokens=True)

        return [{"generated_text": prediction}]