Melissa Roemmele committed on
Commit
2b38b94
·
1 Parent(s): 210820d

Updated handler.py

Browse files
Files changed (1) hide show
  1. handler.py +47 -47
handler.py CHANGED
@@ -4,58 +4,58 @@ from typing import Any, Dict
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
 
7
- class EndpointHandler():
8
- def __init__(self, path=""):
9
- model = AutoModelForCausalLM.from_pretrained(path,
10
- torch_dtype=torch.bfloat16,
11
- trust_remote_code=True,
12
- device_map="auto")
13
- print(model.hf_device_map)
14
- tokenizer = AutoTokenizer.from_pretrained(path)
15
- #device = "cuda:0" if torch.cuda.is_available() else "cpu"
16
- self.pipeline = transformers.pipeline('text-generation',
17
- model=model,
18
- tokenizer=tokenizer)
19
-
20
- def __call__(self, data: Dict[str, Any]):
21
- inputs = data.pop("inputs", data)
22
- parameters = data.pop("parameters", {})
23
- with torch.autocast(self.pipeline.device.type, dtype=torch.bfloat16):
24
- outputs = self.pipeline(inputs,
25
- **parameters)
26
- return outputs
27
-
28
-
29
- # class EndpointHandler:
30
  # def __init__(self, path=""):
31
- # # load model and tokenizer from path
32
- # self.tokenizer = AutoTokenizer.from_pretrained(path)
33
- # self.device = "cuda" if torch.cuda.is_available() else "cpu"
34
- # self.model = AutoModelForCausalLM.from_pretrained(path,
35
- # device_map="auto",
36
- # torch_dtype=torch.float16,
37
- # trust_remote_code=True)
38
-
39
- # def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
40
- # # process input
 
 
41
  # inputs = data.pop("inputs", data)
42
  # parameters = data.pop("parameters", {})
43
- # return_full_text = parameters.pop("return_full_text", True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- # # preprocess
46
- # inputs = self.tokenizer(inputs,
47
- # return_tensors="pt",
48
- # return_token_type_ids=False)
49
- # inputs = inputs.to(self.device)
50
- # input_len = len(inputs[0])
51
 
52
- # outputs = self.model.generate(**inputs, **parameters)[0]
53
 
54
- # if not return_full_text:
55
- # outputs = outputs[input_len:]
56
 
57
- # # postprocess the prediction
58
- # prediction = self.tokenizer.decode(outputs,
59
- # skip_special_tokens=True)
60
 
61
- # return [{"generated_text": prediction}]
 
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
 
7
+ # class EndpointHandler():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  # def __init__(self, path=""):
9
+ # model = AutoModelForCausalLM.from_pretrained(path,
10
+ # torch_dtype=torch.bfloat16,
11
+ # trust_remote_code=True,
12
+ # device_map="auto")
13
+ # print(model.hf_device_map)
14
+ # tokenizer = AutoTokenizer.from_pretrained(path)
15
+ # #device = "cuda:0" if torch.cuda.is_available() else "cpu"
16
+ # self.pipeline = transformers.pipeline('text-generation',
17
+ # model=model,
18
+ # tokenizer=tokenizer)
19
+
20
+ # def __call__(self, data: Dict[str, Any]):
21
  # inputs = data.pop("inputs", data)
22
  # parameters = data.pop("parameters", {})
23
+ # with torch.autocast(self.pipeline.device.type, dtype=torch.bfloat16):
24
+ # outputs = self.pipeline(inputs,
25
+ # **parameters)
26
+ # return outputs
27
+
28
+
29
class EndpointHandler:
    """Custom Inference Endpoints handler for a causal language model.

    Loads a tokenizer and an ``AutoModelForCausalLM`` from ``path`` and
    serves text generation via ``__call__``.
    """

    def __init__(self, path=""):
        # Load model and tokenizer from the local checkpoint path.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # device_map="auto" lets accelerate shard/place the weights;
        # fp16 halves memory (assumes a GPU endpoint — TODO confirm).
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )

    def __call__(self, data: Dict[str, Any]) -> "list[dict[str, str]]":
        """Generate text for the request payload.

        Args:
            data: request dict with keys ``"inputs"`` (prompt string) and
                optional ``"parameters"`` (kwargs forwarded to ``generate``;
                ``return_full_text`` is extracted here, default ``True``).

        Returns:
            ``[{"generated_text": <decoded prediction>}]`` — a one-element
            list, matching the HF pipeline output convention.
        """
        # Process input; NOTE: pops mutate the caller's dict (HF handler
        # convention — the payload is per-request and not reused).
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})
        return_full_text = parameters.pop("return_full_text", True)

        # Preprocess: tokenize and move tensors to the model device.
        inputs = self.tokenizer(inputs,
                                return_tensors="pt",
                                return_token_type_ids=False)
        inputs = inputs.to(self.device)
        # Prompt length in tokens. Fixed: the previous len(inputs[0]) only
        # works for fast tokenizers; read the input_ids tensor shape instead.
        input_len = inputs["input_ids"].shape[-1]

        # Inference only — no_grad avoids building autograd state.
        with torch.no_grad():
            outputs = self.model.generate(**inputs, **parameters)[0]

        # generate() returns prompt + continuation; optionally strip prompt.
        if not return_full_text:
            outputs = outputs[input_len:]

        # Postprocess the prediction back to text.
        prediction = self.tokenizer.decode(outputs,
                                           skip_special_tokens=True)

        return [{"generated_text": prediction}]