Mir-2002 committed
Commit a81cb37 · 1 Parent(s): 992672e

added handler for inference endpoint

Files changed (1):
  1. handler.py +59 -0
handler.py ADDED
@@ -0,0 +1,59 @@
+ from typing import Any, Dict, List
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ import torch
+ 
+ MAX_INPUT_LENGTH = 256
+ MAX_OUTPUT_LENGTH = 128
+ 
+ class EndpointHandler:
+     def __init__(self, model_dir: str = "", **kwargs: Any) -> None:
+         """Initialize the model and tokenizer when the endpoint starts."""
+         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
+         # The model was fine-tuned for a sequence-to-sequence task (CodeT5+),
+         # so the seq2seq auto class is used here.
+         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
+         self.model.eval()  # inference only; disables dropout etc.
+         # Run on GPU when one is available, otherwise fall back to CPU.
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model.to(self.device)
+ 
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         """Handle an incoming inference request."""
+         inputs = data.get("inputs")
+         if not inputs:
+             raise ValueError("No 'inputs' found in the request data.")
+ 
+         # Wrap a single string in a list so batched and single requests
+         # go through the same code path.
+         if isinstance(inputs, str):
+             inputs = [inputs]
+ 
+         # Pre-processing: adjust max_length and padding to match how the
+         # model was trained.
+         tokenized_inputs = self.tokenizer(
+             inputs,
+             max_length=MAX_INPUT_LENGTH,
+             padding=True,
+             truncation=True,
+             return_tensors="pt",
+         ).to(self.device)
+ 
+         # Inference. Generation arguments (max_length, num_beams, ...) are
+         # examples; tune them for the task.
+         with torch.no_grad():
+             outputs = self.model.generate(
+                 tokenized_inputs["input_ids"],
+                 attention_mask=tokenized_inputs["attention_mask"],
+                 max_length=MAX_OUTPUT_LENGTH,
+                 num_beams=8,
+                 no_repeat_ngram_size=3,
+                 pad_token_id=self.tokenizer.pad_token_id,
+             )
+ 
+         # Post-processing: decode token ids back to text.
+         decoded_outputs = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ 
+         # Format the output as a list of dictionaries.
+         return [{"generated_text": text} for text in decoded_outputs]
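
For a quick local smoke test before deploying, the handler can be exercised directly. This is a minimal sketch, not part of the commit; the model_dir value "." and the example payload are assumptions (point model_dir at the directory holding the fine-tuned weights and tokenizer):

from handler import EndpointHandler

# Hypothetical local test: "." assumes the fine-tuned model and tokenizer
# live in the current directory.
handler = EndpointHandler(model_dir=".")
payload = {"inputs": "def add(a, b):\n    return a + b"}  # example request body
print(handler(payload))  # e.g. [{"generated_text": "..."}]

A deployed Inference Endpoint passes the parsed JSON request body to __call__ in the same shape, so a payload that works here should work against the endpoint as well.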