kangqiao-ctrl committed on
Commit
08c0847
·
1 Parent(s): 5be07b8

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +37 -60
handler.py CHANGED
@@ -1,67 +1,44 @@
 
 
1
 
2
- import torch
3
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
4
- import peft
5
- from peft import PeftModel
6
 
7
- # def run():
8
- # base_model_id = "mistralai/Mistral-7B-v0.1"
9
- # bnb_config = BitsAndBytesConfig(
10
- # load_in_4bit=True,
11
- # bnb_4bit_use_double_quant=True,
12
- # bnb_4bit_quant_type="nf4",
13
- # bnb_4bit_compute_dtype=torch.bfloat16
14
- # )
15
-
16
- # base_model = AutoModelForCausalLM.from_pretrained(
17
- # base_model_id, # Mistral, same as before
18
- # quantization_config=bnb_config, # Same quantization config as before
19
- # device_map="auto",
20
- # trust_remote_code=True,
21
- # # use_auth_token=True
22
- # )
23
-
24
- # tokenizer = AutoTokenizer.from_pretrained(base_model_id, add_bos_token=True, trust_remote_code=True)
25
-
26
- # ft_model = PeftModel.from_pretrained(base_model, "./checkpoint-100")
27
-
28
- # return ft_model
29
 
 
 
 
30
 
31
 
32
  class EndpointHandler():
33
  def __init__(self, path=""):
34
- base_model_id = "mistralai/Mistral-7B-v0.1"
35
- bnb_config = BitsAndBytesConfig(
36
- load_in_4bit=True,
37
- bnb_4bit_use_double_quant=True,
38
- bnb_4bit_quant_type="nf4",
39
- bnb_4bit_compute_dtype=torch.bfloat16
40
- )
41
-
42
- base_model = AutoModelForCausalLM.from_pretrained(
43
- base_model_id, # Mistral, same as before
44
- quantization_config=bnb_config, # Same quantization config as before
45
- device_map="auto",
46
- trust_remote_code=True,
47
- # use_auth_token=True
48
- )
49
-
50
- tokenizer = AutoTokenizer.from_pretrained(base_model_id, add_bos_token=True, trust_remote_code=True)
51
-
52
- self.ft_model = PeftModel.from_pretrained(base_model, "./checkpoint-100")
53
-
54
-
55
-
56
- def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
57
-
58
- eval_prompt = data.get("inputs", data)
59
-
60
- model_input = tokenizer(eval_prompt, return_tensors="pt").to("cuda")
61
-
62
- ft_model.eval()
63
- with torch.no_grad():
64
- prediction = (tokenizer.decode(self.ft_model.generate(**model_input, max_new_tokens=100, repetition_penalty=1.15)[0], skip_special_tokens=True))
65
-
66
-
67
- return prediction
 
1
+ from typing import Dict, Any
2
+ import logging
3
 
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ from peft import PeftConfig, PeftModel
6
+ import torch.cuda
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
LOGGER = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# Device the tokenized prompt is moved to before generation.
device = "cuda" if torch.cuda.is_available() else "cpu"


class EndpointHandler:
    """Callable handler that generates text with a PEFT adapter on its base model.

    The adapter configuration found at ``path`` names the base model; the
    base model and its tokenizer are loaded from that name and the adapter
    weights from ``path`` are applied on top.
    """

    def __init__(self, path=""):
        """Load the base model, its tokenizer and the adapter weights.

        Args:
            path: Location of the PEFT adapter. It is passed unchanged to
                ``PeftConfig.from_pretrained`` and
                ``PeftModel.from_pretrained``.
        """
        config = PeftConfig.from_pretrained(path)
        base_model_id = config.base_model_name_or_path
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id, load_in_8bit=True, device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_id)
        # Wrap the base model with the adapter weights stored at `path`.
        self.model = PeftModel.from_pretrained(base_model, path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a continuation for the prompt contained in ``data``.

        Args:
            data: Request payload. ``data["inputs"]`` is the text prompt
                (required). ``data["parameters"]`` is an optional mapping
                forwarded as keyword arguments to ``self.model.generate``.
                The payload is only read, never modified.

        Returns:
            A dict with the single key ``"generated_text"`` holding the
            decoded first generated sequence.

        Raises:
            ValueError: If ``data`` has no ``"inputs"`` entry or it is None.
        """
        LOGGER.info("Received data: %s", data)

        # Read with .get() rather than .pop() so the caller's dict is left
        # intact, including when the request is rejected below.
        prompt = data.get("inputs")
        if prompt is None:
            raise ValueError("Missing prompt.")
        parameters = data.get("parameters")
        generation_kwargs = {} if parameters is None else parameters

        # Preprocess
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device)

        # Forward
        LOGGER.info("Start generation.")
        output = self.model.generate(input_ids=input_ids, **generation_kwargs)

        # Postprocess
        # NOTE(review): decode() is called without skip_special_tokens, so any
        # special tokens in the output stay in the returned text -- confirm
        # this is what clients expect.
        prediction = self.tokenizer.decode(output[0])
        LOGGER.info("Generated text: %s", prediction)
        return {"generated_text": prediction}