mila2030 committed on
Commit
f4997a0
·
verified ·
1 Parent(s): 04a5685

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +159 -0
handler.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # handler.py
2
+ # Hugging Face Inference Toolkit custom handler for chat-style or plain text prompts.
3
+ # Supports two input formats:
4
+ # 1) HF standard: {"inputs": "your prompt", "parameters": {...}}
5
+ # 2) Chat format: {"messages": [{"role":"system"|"user"|"assistant","content":"..."}], "parameters": {...}}
6
+
7
+ import os
8
+ import json
9
+ import torch
10
+
11
+ from typing import Any, Dict, List, Optional, Union
12
+ from transformers import (
13
+ AutoTokenizer,
14
+ AutoModelForCausalLM,
15
+ TextIteratorStreamer,
16
+ )
17
+
18
# Optional: respect a few env knobs (set in Endpoint settings)
DEFAULT_MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "512"))
DEFAULT_TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
DEFAULT_TOP_P = float(os.getenv("TOP_P", "0.9"))
DEFAULT_TOP_K = int(os.getenv("TOP_K", "50"))
DEFAULT_DO_SAMPLE = os.getenv("DO_SAMPLE", "true").lower() in {"1", "true", "yes"}
DEFAULT_REPETITION_PEN = float(os.getenv("REPETITION_PENALTY", "1.05"))


class EndpointHandler:
    """Hugging Face Inference Endpoints custom handler.

    Loads a causal LM and its tokenizer from ``model_dir`` and serves two
    inbound payload shapes:

      1) HF standard: ``{"inputs": "prompt", "parameters": {...}}``
      2) Chat format: ``{"messages": [{"role": ..., "content": ...}],
                         "parameters": {...}}``

    ``__call__`` returns ``{"text": ..., "generated_text": ..., "usage": {...}}``
    on success, or ``{"error": "..."}`` on failure (the handler never raises).
    """

    def __init__(self, model_dir: str, *args, **kwargs):
        """Load model + tokenizer from ``model_dir`` (provided by the endpoint).

        Args:
            model_dir: Local path to the model artifacts.
        """
        # Prefer GPU when available; bf16 on CUDA halves memory vs fp32.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.bfloat16 if self.device == "cuda" else torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            torch_dtype=dtype,
            # device_map="auto" shards across visible GPUs; on CPU let
            # transformers place the weights normally.
            device_map="auto" if self.device == "cuda" else None,
        )
        self.model.eval()

        # Some chat models ship without a pad token; fall back to EOS.
        # BUGFIX: compare against None — a valid pad_token_id of 0 is falsy,
        # and the old truthiness check would clobber it with eos_token_id.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    # -------- Utilities --------
    def _to_prompt_from_messages(self, messages: List[Dict[str, str]]) -> str:
        """Render chat ``messages`` into a single prompt string.

        Uses the tokenizer's chat template when available; otherwise builds a
        minimal ``Role: content`` transcript ending with ``Assistant:``.

        Args:
            messages: ``[{"role": "system"|"user"|"assistant", "content": "..."}]``

        Returns:
            The rendered prompt string.
        """
        if hasattr(self.tokenizer, "apply_chat_template"):
            try:
                return self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
            except Exception:
                pass  # template missing/incompatible — use the fallback below

        # Minimal fallback prompt (unknown roles are rendered as "User").
        role_map = {"system": "System", "user": "User", "assistant": "Assistant"}
        lines = []
        for m in messages:
            role = role_map.get(m.get("role", "user"), "User")
            content = m.get("content", "")
            lines.append(f"{role}: {content}")
        lines.append("Assistant:")
        return "\n".join(lines)

    def _pack_inputs(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize an inbound payload to a prompt string + generation kwargs.

        Accepts either ``{"inputs": "...", "parameters": {...}}`` or
        ``{"messages": [...], "parameters": {...}}``.

        Args:
            payload: The decoded request body.

        Returns:
            ``{"prompt": str, "gen_kwargs": dict}`` ready for ``generate()``.
        """
        parameters = payload.get("parameters", {}) or {}

        # Per-request parameters override the env-derived defaults.
        max_new_tokens = int(parameters.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS))
        temperature = float(parameters.get("temperature", DEFAULT_TEMPERATURE))
        top_p = float(parameters.get("top_p", DEFAULT_TOP_P))
        top_k = int(parameters.get("top_k", DEFAULT_TOP_K))
        do_sample = bool(parameters.get("do_sample", DEFAULT_DO_SAMPLE))
        repetition_pen = float(parameters.get("repetition_penalty", DEFAULT_REPETITION_PEN))

        # ROBUSTNESS: truthiness (not `in`) so a present-but-empty/None
        # "messages" falls through to the "inputs" path instead of producing
        # a bare "Assistant:" prompt from the chat fallback.
        if payload.get("messages"):
            prompt = self._to_prompt_from_messages(payload["messages"])
        else:
            prompt = payload.get("inputs", "")
            if not isinstance(prompt, str):
                # Some clients send list[str]; take the first element.
                if isinstance(prompt, list) and prompt and isinstance(prompt[0], str):
                    prompt = prompt[0]
                else:
                    prompt = str(prompt)

        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "do_sample": do_sample,
            "repetition_penalty": repetition_pen,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        return {"prompt": prompt, "gen_kwargs": gen_kwargs}

    # -------- Main inference entry --------
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one generation request.

        Args:
            data: The decoded request body (see class docstring for shapes).

        Returns:
            ``{"text": str, "generated_text": str, "usage": {...}}`` where
            ``usage`` carries approximate prompt/completion/total token
            counts, or ``{"error": str}`` if anything failed. Errors are
            returned (never raised) so the container keeps serving.
        """
        try:
            packed = self._pack_inputs(data)
            prompt = packed["prompt"]
            gen_kwargs = packed["gen_kwargs"]

            if not prompt.strip():
                return {"text": "Empty prompt."}

            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

            with torch.no_grad():
                output_ids = self.model.generate(
                    **inputs,
                    **gen_kwargs,
                )

            # Keep only the newly generated tokens (drop the echoed prompt).
            gen_ids = output_ids[0][inputs["input_ids"].shape[-1]:]
            text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)

            # Basic usage metrics (approximate — tokenizer counts, not billing).
            prompt_tokens = int(inputs["input_ids"].numel())
            completion_tokens = int(gen_ids.numel())
            total_tokens = prompt_tokens + completion_tokens

            return {
                "text": text,
                "generated_text": text,  # alt field some tools expect
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": total_tokens,
                },
            }
        except Exception as e:
            # Never crash the container: return a JSON error instead.
            return {"error": str(e)}