abashar committed on
Commit
7a0f5b2
·
verified ·
1 Parent(s): de6e79a

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +4 -4
handler.py CHANGED
@@ -7,9 +7,9 @@ device = 0 if torch.cuda.is_available() else -1
7
 
8
# Registry of models to preload; each entry names a Hugging Face model repo
# and the pipeline task it should be loaded for.
multi_model_list = [
    {"model_id": "omarabb315/gemma-2B-2nd_filtered_3_full", "task": "text-generation"},
    {"model_id": "omarabb315/gemma-2B-2nd_filtered_3_16bit", "task": "text-generation"},
    {"model_id": "omarabb315/Gemma-2-9B-filtered_3_4bits", "task": "text-generation"},
]
14
 
15
  class EndpointHandler():
@@ -17,7 +17,7 @@ class EndpointHandler():
17
  self.multi_model={}
18
  # load all the models onto device
19
  for model in multi_model_list:
20
- self.multi_model[model["model_id"]] = pipeline(model["task"], model=model["model_id"], device=device)
21
 
22
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
23
  # deserialize incoming request
 
7
 
8
# Registry of models to preload. Each active entry carries:
#   model_id   — the key under which the loaded pipeline is stored
#   model_path — the Hugging Face repo to load the weights from
#   task       — the transformers pipeline task name
multi_model_list = [
    {
        "model_id": "gemma-2B-2nd_filtered_3_full",
        "model_path": "omarabb315/gemma-2B-2nd_filtered_3_full",
        "task": "text-generation",
    },
    # {"model_path": "omarabb315/gemma-2B-2nd_filtered_3_16bit", "task": "text-generation"},
    # {"model_path": "omarabb315/Gemma-2-9B-filtered_3_4bits", "task": "text-generation"},
]
14
 
15
  class EndpointHandler():
 
17
  self.multi_model={}
18
  # load all the models onto device
19
  for model in multi_model_list:
20
+ self.multi_model[model["model_id"]] = pipeline(model["task"], model=model["model_path"], device=device)
21
 
22
  def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
23
  # deserialize incoming request