Schmadge committed on
Commit
225c8e1
·
1 Parent(s): f54fe31

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +53 -21
handler.py CHANGED
@@ -1,17 +1,51 @@
 
1
  import torch
2
-
3
- from typing import Any, Dict
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
5
 
6
-
7
class EndpointHandler:
    """Custom Inference-Endpoints handler (pre-update version, removed by this commit).

    Loads a causal LM from *path* and a hard-coded GPT-NeoX tokenizer.
    """

    def __init__(self, path=''):
        # load model and tokenizer from path
        # NOTE(review): the tokenizer is pinned to EleutherAI/gpt-neox-20b while
        # the model comes from `path` — vocabularies only match if the deployed
        # checkpoint is NeoX-based; confirm against the actual model at `path`.
        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
        # trust_remote_code=True executes code shipped inside the model repo —
        # only safe for trusted checkpoints.
        self.model = AutoModelForCausalLM.from_pretrained(
            path, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True
        )
        # device_map="auto" already places the model across devices; this
        # string is kept for moving input tensors in __call__.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
17
  # process input
@@ -19,15 +53,13 @@ class EndpointHandler:
19
  parameters = data.pop("parameters", None)
20
 
21
  # preprocess
22
- inputs = self.tokenizer(inputs, return_tensors="pt").to(self.device)
23
-
24
- # pass inputs with all kwargs in data
25
- if parameters is not None:
26
- outputs = self.model.generate(**inputs, **parameters)
27
- else:
28
- outputs = self.model.generate(**inputs)
29
-
30
- # postprocess the prediction
31
- prediction = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
32
-
33
- return [{"generated_text": prediction}]
 
1
+ import warnings
2
  import torch
 
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from typing import Any, Dict
5
 
6
class InstructionTextGenerationPipeline:
    """Instruction-following text-generation handler built around MPT-7B-Instruct."""

    def __init__(
        self,
        path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    ) -> None:
        # Model weights come from `path`. trust_remote_code=True executes code
        # shipped inside the checkpoint repo — only use with trusted sources.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code
        )
        # Tokenizer is pinned to the mosaicml/mpt-7b-instruct repo regardless
        # of `path` — assumes the model at `path` shares that vocabulary.
        tokenizer = AutoTokenizer.from_pretrained(
            "mosaicml/mpt-7b-instruct",
            trust_remote_code=trust_remote_code
        )
        if tokenizer.pad_token_id is None:
            warnings.warn(
                "pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id."
            )
            # Reuse EOS as the padding token so generate() can pad batches.
            tokenizer.pad_token = tokenizer.eos_token

        tokenizer.padding_side = "right"  # "left"
        self.tokenizer = tokenizer

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.eval()
        # NOTE(review): dtype was already set via from_pretrained above, so the
        # dtype argument here is redundant (but harmless); the call's real job
        # is moving the model to self.device.
        self.model.to(device=self.device, dtype=torch_dtype)

        # Default generation settings; per-request "parameters" override these
        # in __call__ (temperature 0.01 with do_sample=True is near-greedy).
        self.generate_kwargs = {
            "temperature": 0.01,
            "top_p": 0.92,
            "top_k": 0,
            "max_new_tokens": 512,
            "use_cache": True,
            "do_sample": True,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
            "repetition_penalty": 1.0
        }
46
+
47
+ def format_instruction(self, instruction):
48
+ return PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
49
 
50
  def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
51
  # process input
 
53
  parameters = data.pop("parameters", None)
54
 
55
  # preprocess
56
+ s = PROMPT_FOR_GENERATION_FORMAT.format(instruction=inputs)
57
+ input_ids = self.tokenizer(s, return_tensors="pt").input_ids.to(self.device)
58
+ gkw = {**self.generate_kwargs, **parameters}
59
+ # pass inputs with all kwargs in data
60
+ with torch.no_grad():
61
+ output_ids = self.model.generate(input_ids, **gkw)
62
+ # Slice the output_ids tensor to get only new tokens
63
+ new_tokens = output_ids[0, len(input_ids[0]) :]
64
+ output_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
65
+ return [{"generated_text": output_text}]