| |
|
| | """
|
| | Created on Thu Nov 14 10:23:53 2024
|
| |
|
| | @author: mj118
|
| | """
|
| |
|
| |
|
| | import torch
|
| | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
| |
|
| |
|
# Device index handed to every transformers pipeline: GPU 0 when CUDA is
# available, otherwise -1 (CPU).
if torch.cuda.is_available():
    device = 0
else:
    device = -1
|
| |
|
| |
|
# Models served by EndpointHandler. Each entry names a Hugging Face Hub model
# ("model_id") and the pipeline task it is loaded with ("task"); requests
# select a model by that "model_id".
multi_model_list = [
    dict(model_id="MahmoudIbrahim/Mistral_Nemo_Arabic", task="text-generation"),
    dict(model_id="Naseej/noon-7b", task="text-generation"),
]
|
| |
|
class EndpointHandler:
    """Serve several ``transformers`` pipelines behind a single endpoint.

    One pipeline is built per configured model when the handler is
    constructed; each request chooses a pipeline by its ``model_id``.
    """

    def __init__(self, path="", model_list=None):
        """Build one pipeline per configured model.

        Args:
            path: Presumably the model directory passed in by the serving
                runtime. It is not used here because models are loaded by
                Hub id. It is kept so the constructor signature is unchanged.
            model_list: Optional iterable of ``{"model_id": ..., "task": ...}``
                dicts to load. Defaults to the module-level
                ``multi_model_list``.
        """
        if model_list is None:
            model_list = multi_model_list
        # NOTE(review): every model is loaded eagerly onto the same device;
        # confirm the target hardware has memory for all of them at once.
        self.multi_model = {
            spec["model_id"]: pipeline(spec["task"], model=spec["model_id"], device=device)
            for spec in model_list
        }

    def __call__(self, data):
        """Run the pipeline selected by ``data["model_id"]``.

        Args:
            data: Request payload mapping. Recognised keys:
                ``inputs`` is passed to the pipeline as its first argument.
                If it is absent, the remaining payload (minus ``parameters``
                and ``model_id``) is passed instead.
                ``parameters`` is an optional mapping of keyword arguments
                for the pipeline call.
                ``model_id`` is the key of the pipeline to use.
                The caller's mapping is not modified.

        Returns:
            Whatever the selected pipeline returns.

        Raises:
            ValueError: if ``model_id`` is missing, unhashable, or not one of
                the loaded models.
        """
        # Work on a shallow copy so the caller's payload survives the pops
        # below (e.g. a retried request still carries its model_id).
        data = dict(data)
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
        model_id = data.pop("model_id", None)

        # EAFP lookup: a missing/None id raises KeyError, and an unhashable
        # id (e.g. a JSON list) raises TypeError; both become one ValueError.
        try:
            model = self.multi_model[model_id]
        except (KeyError, TypeError):
            raise ValueError(
                f"model_id: {model_id} is not valid. "
                f"Available models are: {list(self.multi_model.keys())}"
            ) from None

        if parameters is not None:
            return model(inputs, **parameters)
        return model(inputs)