import torch
from typing import Dict, List, Any
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# Multi-model serving is currently disabled; a single model is loaded in
# EndpointHandler.__init__ instead. The original list is kept for reference.
# multi_model_list = [
#     {"model_id": "gemma-2B-2nd_filtered_3_full", "model_path": "omarabb315/gemma-2B-2nd_filtered_3_full", "task": "text-generation"},
#     {"model_path": "omarabb315/gemma-2B-2nd_filtered_3_16bit", "task": "text-generation"},
#     {"model_path": "omarabb315/Gemma-2-9B-filtered_3_4bits", "task": "text-generation"},
# ]


class EndpointHandler:
    """Callable wrapper around a single ``text-generation`` pipeline.

    The pipeline is built once in ``__init__`` and reused for every call.
    NOTE(review): the class name and ``path`` argument look like the Hugging
    Face Inference Endpoints custom-handler contract -- confirm against the
    deployment target.
    """

    def __init__(self, path=""):
        """Create the pipeline for the hard-coded model.

        Args:
            path: Accepted for interface compatibility but not used; the
                model is always loaded from the hard-coded hub id below.
        """
        # Disabled multi-model loading, kept for reference:
        # self.multi_model = {}
        # load all the models onto device
        # for model in multi_model_list:
        #     self.multi_model[model["model_id"]] = pipeline(model["task"], model=model["model_path"], trust_remote_code=True)
        model_id = "omarabb315/gemma-2B-2nd_filtered_3_full"
        task_id = "text-generation"
        # NOTE(review): trust_remote_code=True lets the model repository run
        # its own Python code at load time; only keep it for trusted repos.
        self.pipeline = pipeline(task_id, model=model_id, trust_remote_code=True)

    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
        """Run the pipeline on one request payload.

        Args:
            data: Mapping-like request body. ``"inputs"`` is removed from it
                and passed to the pipeline; if the key is absent, ``data``
                itself is used as the input. ``"parameters"``, when present
                and not ``None``, is expanded into keyword arguments for the
                pipeline call. The payload is modified in place.

        Returns:
            Whatever ``self.pipeline`` returns for the given inputs.
            NOTE(review): the annotated return type may not match the output
            of a text-generation pipeline -- confirm and adjust.
        """
        # Deserialize the incoming request. When "inputs" is missing, `inputs`
        # is `data` itself, so the pop below also strips "parameters" from the
        # object that is handed to the pipeline.
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
        # model_id = data.pop("model_id", None)

        # Disabled model selection, kept for reference:
        # check if model_id is in the list of models
        # if model_id is None or model_id not in self.multi_model:
        #     raise ValueError(f"model_id: {model_id} is not valid. Available models are: {list(self.multi_model.keys())}")

        # Without explicit parameters the pipeline runs with its defaults.
        if parameters is None:
            return self.pipeline(inputs)

        # Forward every entry of "parameters" as a keyword argument.
        return self.pipeline(inputs, **parameters)