busybisi committed
Commit b219423 · verified · 1 Parent(s): 9538bc8

Upload handler.py with huggingface_hub

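For reference, an upload like this can be scripted with the huggingface_hub client. A minimal sketch, assuming a local handler.py and a placeholder repo id (the actual repository name is not shown on this page):

from huggingface_hub import HfApi

api = HfApi()  # picks up the token from `huggingface-cli login` by default
api.upload_file(
    path_or_fileobj="handler.py",   # local file to upload
    path_in_repo="handler.py",      # destination path inside the repo
    repo_id="busybisi/dolores-ai",  # placeholder repo id, not confirmed by this page
    commit_message="Upload handler.py with huggingface_hub",
)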
Files changed (1)
1. handler.py +94 -0
handler.py ADDED
@@ -0,0 +1,94 @@
+from typing import Dict, List, Any
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+class EndpointHandler:
+    """
+    Custom handler for the DoloresAI model on Hugging Face Inference Endpoints.
+    """
+
+    def __init__(self, path=""):
+        """
+        Initialize the handler with the model and tokenizer.
+
+        Args:
+            path (str): Path to the model directory.
+        """
+        # Load tokenizer and model
+        self.tokenizer = AutoTokenizer.from_pretrained(path)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            path,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            low_cpu_mem_usage=True,
+        )
+
+        # Verify that the model and tokenizer vocab sizes match
+        assert self.model.config.vocab_size == len(self.tokenizer), \
+            f"Vocab size mismatch: model={self.model.config.vocab_size}, tokenizer={len(self.tokenizer)}"
+
+    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
+        """
+        Process inference requests.
+
+        Args:
+            data (Dict): Input data with format:
+                {
+                    "inputs": str,        # The prompt text
+                    "parameters": {       # Optional generation parameters
+                        "max_new_tokens": int,
+                        "temperature": float,
+                        "top_p": float,
+                        "do_sample": bool,
+                        "repetition_penalty": float
+                    }
+                }
+
+        Returns:
+            List[Dict]: Generated text response.
+        """
+        # Extract the prompt and optional generation parameters
+        inputs = data.pop("inputs", "")
+        parameters = data.pop("parameters", {})
+
+        # Default generation parameters
+        max_new_tokens = parameters.get("max_new_tokens", 512)
+        temperature = parameters.get("temperature", 0.7)
+        top_p = parameters.get("top_p", 0.9)
+        do_sample = parameters.get("do_sample", True)
+        repetition_penalty = parameters.get("repetition_penalty", 1.1)
+
+        # Tokenize the input, leaving room in the context window for the
+        # generated tokens (clamped so max_length never drops below 1)
+        encoded = self.tokenizer(
+            inputs,
+            return_tensors="pt",
+            truncation=True,
+            max_length=max(1, self.model.config.max_position_embeddings - max_new_tokens),
+        ).to(self.model.device)
+
+        # Generate the response; unpacking the encoding also passes the
+        # attention mask, which generate() needs to distinguish padding
+        with torch.no_grad():
+            outputs = self.model.generate(
+                **encoded,
+                max_new_tokens=max_new_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                do_sample=do_sample,
+                repetition_penalty=repetition_penalty,
+                pad_token_id=self.tokenizer.eos_token_id,
+                eos_token_id=self.tokenizer.eos_token_id,
+            )
+
+        # Decode only the newly generated tokens; slicing the decoded string
+        # by len(inputs) is unreliable because detokenization does not always
+        # reproduce the prompt text exactly
+        prompt_length = encoded["input_ids"].shape[-1]
+        response_text = self.tokenizer.decode(
+            outputs[0][prompt_length:], skip_special_tokens=True
+        ).strip()
+
+        return [{"generated_text": response_text}]
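
Once the endpoint is deployed, a request matching the schema documented in __call__ can be sent with a plain HTTP POST. A minimal client sketch, with placeholder endpoint URL and token:

import requests

API_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder URL
HEADERS = {"Authorization": "Bearer <HF_TOKEN>"}  # placeholder token

payload = {
    "inputs": "Write a haiku about lighthouses.",
    "parameters": {"max_new_tokens": 64, "temperature": 0.7},
}

response = requests.post(API_URL, headers=HEADERS, json=payload)
print(response.json())  # expected shape: [{"generated_text": "..."}]

The handler can also be exercised locally before deployment by instantiating EndpointHandler(path) with a model directory and calling it with the same payload dict.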