zohaibterminator commited on
Commit
0fbd23f
·
verified ·
1 Parent(s): 4a52a67

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +51 -0
handler.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
3
+ import os
4
+ import torch
5
+ from subprocess import run
6
# Select the inference device and dtype once at import time.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# bfloat16 requires CUDA compute capability >= 8 (Ampere or newer).
# Guard with is_available() first: querying the device capability on a
# CUDA-less host raises RuntimeError (the original crashed on CPU-only
# machines, and its `== 8` test excluded capability-9 GPUs from bf16).
dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)
# Reinstall torch/vision/audio wheels built against CUDA 12.1 so the
# endpoint runtime matches the host GPU driver. Passing an argv list with
# shell=False avoids shell-string interpretation (the original used a
# single shell=True command string); check=True aborts startup on failure.
run(
    [
        "pip3", "install",
        "torch", "torchvision", "torchaudio",
        "--index-url", "https://download.pytorch.org/whl/cu121",
    ],
    check=True,
)
10
+
11
class EndpointHandler():
    """Custom Hugging Face Inference Endpoints handler.

    Loads a causal-LM text-generation pipeline once at startup and serves
    generation requests through ``__call__``.
    """

    def __init__(self, path=""):
        # Preload everything needed at inference time.
        # Read token for gated/private repos; configured as an endpoint secret.
        self.HF_READ_TOKEN = os.getenv("HF_READ_TOKEN")

        print("loading model")

        tokenizer = AutoTokenizer.from_pretrained(path, token=self.HF_READ_TOKEN)

        model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=path,
            token=self.HF_READ_TOKEN,
            torch_dtype=dtype,  # module-level dtype chosen from GPU capability
        ).to(device)

        self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
        # Prompt template; the request text is substituted via str.format().
        self.alpaca_prompt = """REDACTED"""
        print("model loaded")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run text generation for one request.

        data args:
            input (:obj:`str`): the text to format into the prompt template.
        Return:
            ``{"prediction": ...}`` on success, or
            ``[{"Error": "no input received."}]`` when no input was supplied.
            NOTE(review): the two branches return different container types
            (dict vs list); preserved for backward compatibility with
            existing callers.
        """
        # Bug fix: the original compared against the undefined name `Null`
        # (NameError on every request) and indexed data["input"] directly
        # (KeyError when the key is absent). Use .get() and compare to None.
        request = data.get("input")
        if request is None:
            return [{"Error": "no input received."}]
        prompt = self.alpaca_prompt.format(request)
        prediction = self.pipeline(prompt)
        return {"prediction": prediction}