shankz7 commited on
Commit
ef719e5
·
verified ·
1 Parent(s): 0c202de

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +10 -21
handler.py CHANGED
@@ -2,31 +2,20 @@ from typing import Dict, List, Any
2
  from peft import PeftModel
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
 
5
 
6
- class EndpointHandler:
7
- def __init__(self, path=""):
8
-
 
9
  self.device_map = "cuda" # the device to load the model onto
 
10
 
11
- llm_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
12
- tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
13
 
14
- # llm_model = PeftModel.from_pretrained(llm_model, ".")
15
- llm_model.eval()
16
- self.llm_model = llm_model
17
- self.tokenizer = tokenizer
18
-
19
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
20
- """
21
- data args:
22
- inputs (:obj: `str`)
23
- date (:obj: `str`)
24
- Return:
25
- A :obj:`list` | `dict`: will be serialized and returned
26
- """
27
- # get inputs
28
- prompt = data.pop("prompt", "")
29
- model_input = self.tokenizer(prompt, return_tensors="pt").to(self.device_map)
30
  output = self.llm_model.generate(input_ids=model_input["input_ids"].to(self.device_map),
31
  use_cache=False,
32
  temperature=0.1, top_k=1, top_p=1.0, repetition_penalty=1.4,
 
2
  from peft import PeftModel
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
6
 
7
+ class EndpointHandler():
8
+ def __init__(self, path="."):
9
+ self.model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
10
+ self.model.eval()
11
  self.device_map = "cuda" # the device to load the model onto
12
+ self.tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
13
 
 
 
14
 
15
+ def __call__(self, inputs: str):
16
+ if len(inputs) == 0:
17
+ raise ValueError("prompt cannot be empty")
18
+ model_input = self.tokenizer(inputs, return_tensors="pt").to(self.device_map)
 
 
 
 
 
 
 
 
 
 
 
 
19
  output = self.llm_model.generate(input_ids=model_input["input_ids"].to(self.device_map),
20
  use_cache=False,
21
  temperature=0.1, top_k=1, top_p=1.0, repetition_penalty=1.4,