abashar committed on
Commit
3315acd
·
verified ·
1 Parent(s): 7a0f5b2

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +16 -14
handler.py CHANGED
@@ -2,37 +2,39 @@ import torch
2
  from typing import Dict, List, Any
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
4
 
5
- # check for GPU
6
- device = 0 if torch.cuda.is_available() else -1
7
-
8
  # multi-model list
9
- multi_model_list = [
10
- {"model_id": "gemma-2B-2nd_filtered_3_full", "model_path": "omarabb315/gemma-2B-2nd_filtered_3_full", "task": "text-generation"},
11
  # {"model_path": "omarabb315/gemma-2B-2nd_filtered_3_16bit", "task": "text-generation"},
12
  # {"model_path": "omarabb315/Gemma-2-9B-filtered_3_4bits", "task": "text-generation"},
13
  ]
14
 
15
  class EndpointHandler():
16
  def __init__(self, path=""):
17
- self.multi_model={}
18
  # load all the models onto device
19
- for model in multi_model_list:
20
- self.multi_model[model["model_id"]] = pipeline(model["task"], model=model["model_path"], device=device)
 
 
 
 
21
 
22
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
23
  # deserialize incoming request
24
  inputs = data.pop("inputs", data)
25
  parameters = data.pop("parameters", None)
26
- model_id = data.pop("model_id", None)
 
27
 
28
  # check if model_id is in the list of models
29
- if model_id is None or model_id not in self.multi_model:
30
- raise ValueError(f"model_id: {model_id} is not valid. Available models are: {list(self.multi_model.keys())}")
31
 
32
  # pass inputs with all kwargs in data
33
  if parameters is not None:
34
- prediction = self.multi_model[model_id](inputs, **parameters)
35
  else:
36
- prediction = self.multi_model[model_id](inputs)
37
- # postprocess the prediction
38
  return prediction
 
2
from typing import Dict, List, Any

# NOTE: AutoTokenizer / AutoModelForSeq2SeqLM are kept because earlier
# revisions of this handler used them; only `pipeline` is used now.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# The multi-model registry from previous revisions was removed; the handler
# now loads a single hard-coded checkpoint in EndpointHandler.__init__.
# (The old commented-out list left a dangling `]` behind, which made the
# module fail to import with a SyntaxError — removed here.)
class EndpointHandler():
    """Hugging Face Inference Endpoints handler wrapping one text-generation pipeline."""

    def __init__(self, path=""):
        # `path` is the local checkout directory Inference Endpoints passes in;
        # this handler ignores it and pulls a fixed checkpoint from the Hub.
        model_id = "omarabb315/gemma-2B-2nd_filtered_3_full"
        task_id = "text-generation"
        self.pipeline = pipeline(task_id, model=model_id, trust_remote_code=True)

    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
        """Run generation on a deserialized request payload.

        Accepts either a dict of the form
        ``{"inputs": ..., "parameters": {...}}`` or a bare payload
        (e.g. a plain prompt string), and returns whatever the
        underlying transformers pipeline returns.
        """
        # Bare (non-dict) payloads have no "inputs"/"parameters" envelope;
        # treat the whole payload as the prompt. The previous code called
        # data.pop(...) unconditionally and crashed on plain strings.
        if not isinstance(data, dict):
            return self.pipeline(data)

        # Deserialize incoming request. If "inputs" is absent, the remaining
        # dict itself is used as the input (pre-existing behavior).
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)

        # Pass inputs with all generation kwargs, if any were supplied.
        if parameters is not None:
            return self.pipeline(inputs, **parameters)
        return self.pipeline(inputs)