d-s-b committed on
Commit
5de9685
·
verified ·
1 Parent(s): f92601e

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +17 -42
handler.py CHANGED
@@ -1,53 +1,28 @@
1
- import torch
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
  from typing import Dict, List, Any
 
 
4
 
5
  class EndpointHandler:
6
- def __init__(self, path="d-s-b/meme"):
7
- self.tokenizer = AutoTokenizer.from_pretrained("d-s-b/meme")
 
8
  self.model = AutoModelForCausalLM.from_pretrained(
9
- "d-s-b/meme",
10
- torch_dtype="auto",
11
  device_map="auto"
12
  )
13
-
14
- self.inference_prompt_style = """Below is an instruction that describes a task, paired with an input that provides further context.
15
- Write a response that appropriately completes the request.
16
- Identify the most suitable meme template based on the provided example situations.
17
-
18
- ### Instruction:
19
- You are a meme expert who knows how to map real-life situations to the correct meme name.
20
- Please identify the meme name that best fits the given examples_list.
21
-
22
- ### Input (examples_list):
23
- {}
24
-
25
- ### Response:
26
- """
27
-
28
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
29
- question = data.pop("inputs", data)
30
- parameters = data.pop("parameters", {})
31
-
32
- max_new_tokens = parameters.get("max_new_tokens", 512)
33
-
34
- prompt = self.inference_prompt_style.format(question)
35
-
36
- inputs = self.tokenizer([prompt], return_tensors="pt")
37
 
 
38
  outputs = self.model.generate(
39
- input_ids=inputs.input_ids,
40
- attention_mask=inputs.attention_mask,
41
- max_new_tokens=max_new_tokens,
42
- temperature=0.7,
43
- do_sample=True,
44
- eos_token_id=self.tokenizer.eos_token_id,
45
- pad_token_id=self.tokenizer.eos_token_id,
46
- use_cache=True,
47
- **parameters
48
  )
49
 
50
- response = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
51
- result = response[0].split("### Response:")[1].strip()
52
-
53
- return [{"generated_text": result}]
 
 
 
1
  from typing import Dict, List, Any
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
 
5
class EndpointHandler:
    """Inference Endpoints handler: loads a causal LM from a local snapshot
    path and generates a text completion for each request."""

    def __init__(self, path=""):
        # Load from the local path supplied by the endpoint runtime,
        # not from the hub.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        # Many causal LMs ship without a pad token; fall back to EOS so
        # generate() neither warns nor mis-pads (the previous revision of
        # this handler passed pad_token_id=eos_token_id for the same reason).
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Handle one inference request.

        Args:
            data: request payload of the form
                {"inputs": <prompt string>,
                 "parameters": {"max_length": int, "temperature": float}}
                where "parameters" and its keys are optional.

        Returns:
            A one-element list: [{"generated_text": <decoded output,
            which includes the prompt tokens>}].
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        encoded = self.tokenizer(inputs, return_tensors="pt")
        # FIX: with device_map="auto" the model weights may live on GPU while
        # the tokenizer returns CPU tensors; move the inputs to the model's
        # device to avoid a device-mismatch error inside generate().
        encoded = {k: v.to(self.model.device) for k, v in encoded.items()}

        outputs = self.model.generate(
            **encoded,
            max_length=parameters.get("max_length", 100),
            temperature=parameters.get("temperature", 0.7),
            do_sample=True,
            pad_token_id=self.tokenizer.pad_token_id,
        )

        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return [{"generated_text": response}]