cephcyn committed on
Commit
1a63a25
·
verified ·
1 Parent(s): 843429b

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +37 -35
handler.py CHANGED
@@ -1,35 +1,37 @@
1
- from typing import Dict, List, Any
2
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
3
-
4
- # Need to set HF_TOKEN on the endpoint creation process for this to work
5
- model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
6
-
7
- class EndpointHandler:
8
- def __init__(self, path=""):
9
- # load the model
10
- tokenizer = AutoTokenizer.from_pretrained(model_name)
11
- model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
12
- # create inference pipeline
13
- self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
14
-
15
- def __call__(self, data: Dict[str, Any]) -> List[List[Dict[str, float]]]:
16
- """
17
- input args:
18
- data: a dict with elements...
19
- inputs: List[str] , inputs to batch-process
20
- parameters: Any , parameters to be passed into model
21
- outputs:
22
- list of {'generated_text': str} type outputs
23
- """
24
-
25
- inputs = data.pop("inputs", data)
26
- parameters = data.pop("parameters", None)
27
-
28
- # pass inputs with all kwargs in data
29
- if parameters is not None:
30
- predictions = self.pipeline(inputs, **parameters)
31
- else:
32
- predictions = self.pipeline(inputs)
33
-
34
- # postprocess the prediction
35
- return [{'generated_text': e} for e in predictions]
 
 
 
from typing import Dict, List, Any

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
3
+
4
+ # Need to set HF_TOKEN on the endpoint creation process for this to work
5
+ model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
6
+
7
class EndpointHandler:
    """Inference Endpoints handler wrapping a Llama 3.1 Instruct chat pipeline."""

    def __init__(self, path=""):
        # Build the text-generation pipeline directly from the model name.
        # device_map="auto" places model shards on the available device(s);
        # bfloat16 halves memory use relative to fp32.
        # NOTE: requires `import torch` at module level (used for the dtype).
        self.pipeline = pipeline(
            "text-generation",
            model=model_name,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        input args:
            data: a dict with elements...
                inputs: List[List[Dict[str, str]]] or List[str] , inputs to batch-process in conversational format
                parameters: Any , parameters to be passed into model
        outputs:
            list of {'next_chat_turn': ...} outputs — for conversational inputs
            the value is the final chat turn produced by the model; for plain
            string inputs it is the full generated string
        """

        # Fall back to treating the whole payload as the inputs when no
        # explicit "inputs" key is present (standard HF handler idiom).
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)

        # pass inputs with all kwargs in data
        if parameters is not None:
            predictions = self.pipeline(inputs, **parameters)
        else:
            predictions = self.pipeline(inputs)

        # postprocess the prediction
        return [{'next_chat_turn': self._extract_turn(e)} for e in predictions]

    @staticmethod
    def _extract_turn(prediction):
        """Pull the newest chat turn out of one pipeline prediction.

        For chat-format inputs, `generated_text` is the full message list, so
        return its last element. For plain-string inputs it is a str; the
        previous unconditional `[-1]` indexing truncated that case to a single
        character, so return the whole string instead.
        """
        generated = prediction[0]["generated_text"]
        return generated[-1] if isinstance(generated, list) else generated