Mfusenig commited on
Commit
520e9a8
·
verified ·
1 Parent(s): 5c66470

Upload handler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. handler.py +50 -0
handler.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
+ import torch
4
+
5
+ class EndpointHandler:
6
+ def __init__(self, path=""):
7
+ """
8
+ Initializes the handler by loading the T5Gemma model and tokenizer.
9
+ trust_remote_code=True is essential for new architectures.
10
+ """
11
+ self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
12
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(
13
+ path,
14
+ torch_dtype=torch.bfloat16,
15
+ trust_remote_code=True
16
+ )
17
+ self.model.eval()
18
+
19
+ def __call__(self, data):
20
+ """
21
+ This method is called for each inference request. It now uses the
22
+ tokenizer's chat template, which is the correct and most robust
23
+ method for formatting inputs for this model.
24
+ """
25
+ # Get inputs and generation parameters
26
+ inputs_text = data.pop("inputs", [])
27
+ parameters = data
28
+
29
+ if isinstance(inputs_text, str):
30
+ inputs_text = [inputs_text]
31
+
32
+ # Create the chat message structure that apply_chat_template expects
33
+ messages_list = [[{"role": "user", "content": text}] for text in inputs_text]
34
+
35
+ # Apply the model's specific chat template to format the input correctly
36
+ # The tokenizer handles padding for batched inputs automatically.
37
+ input_ids = [
38
+ self.tokenizer.apply_chat_template(
39
+ messages, add_generation_prompt=True, return_tensors="pt"
40
+ ) for messages in messages_list
41
+ ]
42
+
43
+ # Batch generation
44
+ outputs = []
45
+ for ids in input_ids:
46
+ output_tokens = self.model.generate(ids, **parameters)
47
+ # For T5, the output contains only the generated tokens
48
+ outputs.append(self.tokenizer.decode(output_tokens[0], skip_special_tokens=True))
49
+
50
+ return outputs