import torch
from typing import Any, Dict, List

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# Multi-model support (disabled): candidate models that could be served side by side.
# NOTE(review): the closing bracket below must stay commented along with the list
# opener — an uncommented bare `]` is a SyntaxError at module level.
# multi_model_list = [
#     {"model_id": "gemma-2B-2nd_filtered_3_full", "model_path": "omarabb315/gemma-2B-2nd_filtered_3_full", "task": "text-generation"},
#     {"model_path": "omarabb315/gemma-2B-2nd_filtered_3_16bit", "task": "text-generation"},
#     {"model_path": "omarabb315/Gemma-2-9B-filtered_3_4bits", "task": "text-generation"},
# ]
class EndpointHandler():
    """Hugging Face Inference Endpoints handler serving one text-generation pipeline.

    Loads the model once at startup and answers requests of the shape
    ``{"inputs": ..., "parameters": {...}}`` (the standard Endpoints payload).
    """

    def __init__(self, path=""):
        """Load the text-generation pipeline.

        Args:
            path: Local model path provided by the Endpoints runtime (unused;
                the model is pulled from the Hub by id instead).
        """
        # Multi-model variant (previously sketched here as commented-out code):
        # build a dict of pipelines keyed by model_id and dispatch per request.
        # Single-model setup for now:
        model_id = "omarabb315/gemma-2B-2nd_filtered_3_full"
        task_id = "text-generation"
        # NOTE(review): trust_remote_code executes code shipped in the model
        # repo — acceptable only because the repo is owner-controlled; confirm.
        self.pipeline = pipeline(task_id, model=model_id, trust_remote_code=True)

    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
        """Run inference on a deserialized incoming request.

        Args:
            data: Either the standard Endpoints dict payload with optional
                "inputs" and "parameters" keys, or a raw prompt passed directly.

        Returns:
            Whatever the underlying pipeline returns for the given inputs.
        """
        if not isinstance(data, dict):
            # Raw payload (e.g. a bare prompt string): forward it unchanged
            # instead of crashing on .pop() below.
            return self.pipeline(data)
        # If there is no "inputs" key, the whole payload is treated as the inputs.
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
        # Pass generation kwargs through only when the client supplied them.
        if parameters is not None:
            prediction = self.pipeline(inputs, **parameters)
        else:
            prediction = self.pipeline(inputs)
        return prediction