MJ92 commited on
Commit
e05cbbb
·
verified ·
1 Parent(s): 84d3d70

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +44 -0
handler.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Thu Nov 14 10:23:53 2024
4
+
5
+ @author: mj118
6
+ """
7
+
8
+ # handler.py
9
+ import torch
10
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
11
+
12
+ # check for GPU
13
+ device = 0 if torch.cuda.is_available() else -1
14
+
15
+ # multi-model list
16
+ multi_model_list = [
17
+ {"model_id": "MahmoudIbrahim/Mistral_Nemo_Arabic", "task": "text-generation"},
18
+ {"model_id": "inceptionai/jais-family-2p7b-chat", "task": "text-generation"},
19
+ ]
20
+
21
+ class EndpointHandler():
22
+ def __init__(self, path=""):
23
+ self.multi_model={}
24
+ # load all the models onto device
25
+ for model in multi_model_list:
26
+ self.multi_model[model["model_id"]] = pipeline(model["task"], model=model["model_id"], device=device)
27
+
28
+ def __call__(self, data):
29
+ # deserialize incomin request
30
+ inputs = data.pop("inputs", data)
31
+ parameters = data.pop("parameters", None)
32
+ model_id = data.pop("model_id", None)
33
+
34
+ # check if model_id is in the list of models
35
+ if model_id is None or model_id not in self.multi_model:
36
+ raise ValueError(f"model_id: {model_id} is not valid. Available models are: {list(self.multi_model.keys())}")
37
+
38
+ # pass inputs with all kwargs in data
39
+ if parameters is not None:
40
+ prediction = self.multi_model[model_id](inputs, **parameters)
41
+ else:
42
+ prediction = self.multi_model[model_id](inputs)
43
+ # postprocess the prediction
44
+ return prediction