Mfusenig committed on
Commit
012a858
·
verified ·
1 Parent(s): 520e9a8

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +22 -41
handler.py CHANGED
@@ -1,50 +1,31 @@
1
import os  # NOTE(review): not used anywhere in this module — candidate for removal
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch


class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a seq2seq (T5Gemma) model.

    Accepts payloads of the form ``{"inputs": str | list[str], **generate_kwargs}``
    and returns a list with one decoded generation per input string.
    """

    def __init__(self, path=""):
        """Load the tokenizer and model from *path*.

        ``trust_remote_code=True`` is required for architectures whose
        modeling code ships with the checkpoint rather than transformers.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        self.model.eval()

    def __call__(self, data):
        """Handle one inference request.

        ``data["inputs"]`` may be a single string or a list of strings; all
        remaining keys in *data* are forwarded verbatim to ``model.generate``.
        Returns a list of decoded strings, one per input.
        """
        inputs_text = data.pop("inputs", [])
        parameters = data  # leftover keys are treated as generation kwargs

        if isinstance(inputs_text, str):
            inputs_text = [inputs_text]

        # The chat template is the model's canonical input format; the
        # tokenizer handles the prompt scaffolding for us.
        messages_list = [[{"role": "user", "content": text}] for text in inputs_text]
        input_ids = [
            self.tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, return_tensors="pt"
            )
            for messages in messages_list
        ]

        # Inputs are processed sequentially, one generate() call per request
        # (the earlier "Batch generation" comment was misleading).
        outputs = []
        # fix: run generation without autograd bookkeeping
        with torch.inference_mode():
            for ids in input_ids:
                # fix: tensors must live on the same device as the model
                ids = ids.to(self.model.device)
                output_tokens = self.model.generate(ids, **parameters)
                # For seq2seq models the output contains only generated tokens.
                outputs.append(
                    self.tokenizer.decode(output_tokens[0], skip_special_tokens=True)
                )

        return outputs
 
 
 
1
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch


class EndpointHandler:
    """Inference Endpoints handler wrapping a seq2seq chat model.

    Accepts the standard payload ``{"inputs": str, "parameters": {...}}``
    and returns ``{"generated_text": str}``.
    """

    def __init__(self, path=""):
        """Load the tokenizer and bfloat16 model weights from *path*."""
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
        )
        # fix: the previous revision called eval(); restore it so dropout /
        # other train-mode layers are disabled at inference time.
        self.model.eval()

    def __call__(self, data):
        """Handle one inference request.

        Reads ``data["inputs"]`` (falling back to *data* itself when the key
        is absent, as before) and merges an optional ``data["parameters"]``
        dict over the built-in generation defaults — backward compatible,
        since the defaults are unchanged when no parameters are sent.
        """
        if isinstance(data, dict) and "inputs" in data:
            inputs = data.pop("inputs")
            overrides = data.pop("parameters", {})
        else:
            # Legacy fallback: treat the whole payload as the prompt content.
            inputs = data
            overrides = {}

        # fix: generation settings were hard-coded; keep them as defaults but
        # honor the standard "parameters" field of the endpoint payload.
        gen_kwargs = {"max_new_tokens": 1024, "temperature": 0.1, "do_sample": True}
        if isinstance(overrides, dict):
            gen_kwargs.update(overrides)

        messages = [{"role": "user", "content": inputs}]
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
        )
        # fix: move inputs to the model's device (CPU tensors fed to a GPU
        # model would raise at generate time).
        input_ids = input_ids.to(self.model.device)

        # fix: no autograd bookkeeping during generation
        with torch.inference_mode():
            outputs = self.model.generate(input_ids, **gen_kwargs)

        return {
            "generated_text": self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        }