d-s-b committed on
Commit
c86eed8
·
verified ·
1 Parent(s): fc477be

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +39 -36
handler.py CHANGED
@@ -1,13 +1,18 @@
1
- from typing import Dict, List, Any
2
  import torch
 
 
3
 
4
  class EndpointHandler:
5
- def __init__(self, path=""):
6
- from transformers import AutoModelForCausalLM, AutoTokenizer
7
- self.tokenizer = AutoTokenizer.from_pretrained(path)
8
- self.model = AutoModelForCausalLM.from_pretrained(path, torch_dtype="auto", device_map="auto")
9
-
10
- self.inference_prompt_style = """Below is an instruction that describes a task, paired with an input that provides further context.
 
 
 
 
11
  Write a response that appropriately completes the request.
12
  Identify the most suitable meme template based on the provided example situations.
13
 
@@ -20,32 +25,30 @@ Please identify the meme name that best fits the given examples_list.
20
 
21
  ### Response:
22
  """
23
-
24
- def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
25
- question = data.pop("inputs", data)
26
- parameters = data.pop("parameters", {})
27
-
28
- # Set default parameters
29
- max_new_tokens = parameters.get("max_new_tokens", 12000)
30
-
31
- # Format prompt
32
- prompt = self.inference_prompt_style.format(question) + self.tokenizer.eos_token
33
-
34
- # Tokenize
35
- inputs = self.tokenizer([prompt], return_tensors="pt")
36
-
37
- # Generate
38
- outputs = self.model.generate(
39
- input_ids=inputs.input_ids,
40
- attention_mask=inputs.attention_mask,
41
- max_new_tokens=max_new_tokens,
42
- eos_token_id=self.tokenizer.eos_token_id,
43
- use_cache=True,
44
- **parameters
45
- )
46
-
47
- # Decode and extract response
48
- response = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
49
- result = response[0].split("### Response:")[1]
50
-
51
- return [{"generated_text": result}]
 
 
1
  import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ from typing import Dict, List, Any
4
 
5
  class EndpointHandler:
6
+ def __init__(self, path="d-s-b/meme"):
7
+ self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
8
+ self.model = AutoModelForCausalLM.from_pretrained(
9
+ path,
10
+ torch_dtype="auto",
11
+ device_map="auto",
12
+ trust_remote_code=True
13
+ )
14
+
15
+ self.inference_prompt_style = """Below is an instruction that describes a task, paired with an input that provides further context.
16
  Write a response that appropriately completes the request.
17
  Identify the most suitable meme template based on the provided example situations.
18
 
 
25
 
26
  ### Response:
27
  """
28
+
29
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
30
+ question = data.pop("inputs", data)
31
+ parameters = data.pop("parameters", {})
32
+
33
+ max_new_tokens = parameters.get("max_new_tokens", 512)
34
+
35
+ prompt = self.inference_prompt_style.format(question)
36
+
37
+ inputs = self.tokenizer([prompt], return_tensors="pt")
38
+
39
+ outputs = self.model.generate(
40
+ input_ids=inputs.input_ids,
41
+ attention_mask=inputs.attention_mask,
42
+ max_new_tokens=max_new_tokens,
43
+ temperature=0.7,
44
+ do_sample=True,
45
+ eos_token_id=self.tokenizer.eos_token_id,
46
+ pad_token_id=self.tokenizer.eos_token_id,
47
+ use_cache=True,
48
+ **parameters
49
+ )
50
+
51
+ response = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
52
+ result = response[0].split("### Response:")[1].strip()
53
+
54
+ return [{"generated_text": result}]