shankz7 commited on
Commit
f00f443
·
verified ·
1 Parent(s): 1378471

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +61 -0
handler.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from peft import PeftModel
3
+ from unsloth import FastLanguageModel
4
+
5
+
6
class EndpointHandler:
    """Inference-endpoint handler for a LoRA-adapted Mistral-7B-Instruct model.

    Loads the 4-bit quantized base model through unsloth, applies the PEFT
    adapter stored in the current directory, and serves text generation.
    """

    def __init__(self, path=""):
        # NOTE(review): `path` is ignored and the adapter is loaded from "./"
        # — presumably the deployed repository root; confirm against the
        # endpoint's working directory.
        llm_model, tokenizer = initialize_model_and_tokenizer(
            "mistralai/Mistral-7B-Instruct-v0.2", 2048
        )
        llm_model = PeftModel.from_pretrained(llm_model, "./")
        llm_model.eval()
        self.llm_model = llm_model
        self.tokenizer = tokenizer

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            prompt (:obj:`str`): the prompt to complete (popped from ``data``;
                defaults to the empty string if absent)
        Return:
            A :obj:`dict` with a single ``"response"`` key; will be serialized
            and returned.
        """
        # get inputs
        prompt = data.pop("prompt", "")
        # Place the encoded inputs on the same device as the model weights.
        # (Bug fix: the original referenced an undefined name `device_map`.)
        device = next(self.llm_model.parameters()).device
        model_input = self.tokenizer(prompt, return_tensors="pt").to(device)
        output = self.llm_model.generate(
            input_ids=model_input["input_ids"],
            use_cache=False,
            temperature=0.1,
            top_k=1,
            top_p=1.0,
            repetition_penalty=1.4,
            max_new_tokens=256,
            do_sample=True,
            # Bug fix: the original referenced a bare `tokenizer`, which is
            # not defined in this scope — only `self.tokenizer` is.
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            num_beams=1,
            num_return_sequences=1,
        )
        output = self.tokenizer.decode(output[0])
        # NOTE(review): `.split("Response:")[1]` raises IndexError when the
        # model output contains no "Response:" marker — verify the prompt
        # template guarantees it.
        result = (
            output
            .split(self.tokenizer.eos_token)[0]
            .split("Response:")[1]
            .strip()
            .split("###")[0]
            .replace("```json", "")
            .replace("```", "")
        )
        return {"response": result}
44
+
45
+
46
def initialize_model_and_tokenizer(model_id: str, max_seq_length: int):
    """Load a base model (4-bit quantized) and its tokenizer via unsloth.

    Args:
        model_id: Hugging Face model identifier to load.
        max_seq_length: maximum sequence length the model is configured for.

    Returns:
        A ``(model, tokenizer)`` tuple ready for adapter loading / inference.
    """
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_id,
        max_seq_length=max_seq_length,
        # None -> auto dtype detection (fp16 on Tesla T4/V100, bf16 on Ampere+)
        dtype=None,
        # 4-bit quantization to reduce memory usage; could be disabled.
        load_in_4bit=True,
    )

    # Disable the KV-cache flag on the config and make padding behave:
    # pad with the EOS token, padding on the right.
    model.config.use_cache = False
    model.config.pad_token_id = model.config.eos_token_id
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    return model, tokenizer