crodri committed (verified)
Commit e48c5e9 · 1 Parent(s): d78fef8

Rename handler.py.bak to handler.py

Files changed (2)
  1. handler.py +237 -0
  2. handler.py.bak +0 -296
handler.py ADDED
@@ -0,0 +1,237 @@
+ # handler.py
+ # Hugging Face Inference Endpoints - Custom Handler
+ #
+ # This handler starts an internal SGLang server (OpenAI-compatible) and proxies
+ # requests to it. It supports both:
+ #   - HF "inputs": str (single prompt)
+ #   - HF "inputs": list[{"role": "...", "content": "..."}] (chat style)
+ #
+ # Expected request body patterns (common in HF endpoints):
+ #   - {"inputs": "Hello", "parameters": {"max_new_tokens": 256, "temperature": 0.7}}
+ #   - {"inputs": [{"role":"user","content":"Hello"}], "parameters": {...}}
+
+ from __future__ import annotations
+
+ import json
+ import os
+ import socket
+ import subprocess
+ import time
+ from typing import Any, Dict, List, Optional, Union
+
+ import requests
+
+
+ def _is_port_open(host: str, port: int, timeout_s: float = 0.5) -> bool:
+     try:
+         with socket.create_connection((host, port), timeout=timeout_s):
+             return True
+     except OSError:
+         return False
+
+
+ def _wait_for_server(host: str, port: int, health_url: str, timeout_s: int = 300) -> None:
+     start = time.time()
+     # 1) Wait for the TCP port
+     while time.time() - start < timeout_s:
+         if _is_port_open(host, port):
+             break
+         time.sleep(0.5)
+
+     # 2) Wait for /health (preferred)
+     while time.time() - start < timeout_s:
+         try:
+             r = requests.get(health_url, timeout=2)
+             if r.status_code == 200:
+                 return
+         except requests.RequestException:
+             pass
+         time.sleep(0.5)
+
+     raise RuntimeError(
+         f"SGLang server did not become ready within {timeout_s}s "
+         f"(host={host}, port={port}, health={health_url})."
+     )
+
+
+ def _coerce_messages(inputs: Any) -> List[Dict[str, str]]:
+     """
+     Convert HF inputs into OpenAI chat messages.
+     """
+     if isinstance(inputs, str):
+         return [{"role": "user", "content": inputs}]
+
+     if isinstance(inputs, list):
+         # Already messages?
+         # We accept a list of dicts with role/content, or a list of strings.
+         if all(isinstance(x, dict) for x in inputs):
+             msgs: List[Dict[str, str]] = []
+             for m in inputs:
+                 role = str(m.get("role", "user"))
+                 content = m.get("content", "")
+                 if content is None:
+                     content = ""
+                 msgs.append({"role": role, "content": str(content)})
+             return msgs
+         if all(isinstance(x, str) for x in inputs):
+             # Treat as a multi-line user prompt
+             return [{"role": "user", "content": "\n".join(inputs)}]
+
+     # Fallback: stringify
+     return [{"role": "user", "content": json.dumps(inputs, ensure_ascii=False)}]
+
+
+ def _map_generation_params(hf_params: Dict[str, Any]) -> Dict[str, Any]:
+     """
+     Map typical HF params to OpenAI-compatible chat completion params.
+     Unknown keys are passed through only where it is safe.
+     """
+     if hf_params is None:
+         hf_params = {}
+
+     # Common HF keys: max_new_tokens, temperature, top_p, repetition_penalty, stop, seed
+     out: Dict[str, Any] = {}
+
+     max_new_tokens = hf_params.get("max_new_tokens", hf_params.get("max_tokens"))
+     if max_new_tokens is not None:
+         out["max_tokens"] = int(max_new_tokens)
+
+     for k in ("temperature", "top_p", "seed"):
+         if k in hf_params and hf_params[k] is not None:
+             out[k] = hf_params[k]
+
+     # HF sometimes uses "stop" (str or list[str])
+     if "stop" in hf_params and hf_params["stop"] is not None:
+         out["stop"] = hf_params["stop"]
+
+     # OpenAI-compatible streaming flag; the HF toolkit generally expects a non-streaming response
+     if "stream" in hf_params:
+         out["stream"] = bool(hf_params["stream"])
+     else:
+         out["stream"] = False
+
+     # Best-effort pass-through for presence/frequency penalty if provided
+     for k in ("presence_penalty", "frequency_penalty"):
+         if k in hf_params and hf_params[k] is not None:
+             out[k] = hf_params[k]
+
+     return out
+
+
+ class EndpointHandler:
+     """
+     Hugging Face Inference Endpoints custom handler:
+       - __init__(model_dir): load/init anything
+       - __call__(data): run inference
+     """
+
+     def __init__(self, model_dir: str, **_: Any) -> None:
+         # HF mounts the repo under model_dir (typically /repository)
+         self.model_dir = model_dir
+
+         # Where SGLang will listen
+         self.host = os.getenv("SGLANG_HOST", "127.0.0.1")
+         self.port = int(os.getenv("SGLANG_PORT", "30000"))
+
+         # Model identifier/path
+         # For Inference Endpoints, weights/artifacts are available under model_dir.
+         # Using the local path avoids extra Hub downloads.
+         self.model_path = os.getenv("SGLANG_MODEL_PATH", model_dir)
+
+         # Optional: tokenizer path (defaults to the model path)
+         self.tokenizer_path = os.getenv("SGLANG_TOKENIZER_PATH", self.model_path)
+
+         # Optional: tensor parallel size, chunked prefill, etc. (SGLang server args)
+         self.tp_size = int(os.getenv("SGLANG_TP_SIZE", "1"))
+         self.chunked_prefill_size = os.getenv("SGLANG_CHUNKED_PREFILL_SIZE", "")  # e.g. "4096"
+         self.max_running_requests = os.getenv("SGLANG_MAX_RUNNING_REQUESTS", "")  # e.g. "64"
+
+         # If you already have a command you want to run, you can override entirely:
+         #   SGLANG_LAUNCH_CMD='python -m sglang.launch_server --model-path ...'
+         launch_cmd = os.getenv("SGLANG_LAUNCH_CMD", "").strip()
+         if launch_cmd:
+             cmd = launch_cmd.split()  # note: whitespace-only split; quoted arguments are not supported
+         else:
+             # Default launch command (SGLang OpenAI-compatible server)
+             cmd = [
+                 "python",
+                 "-m",
+                 "sglang.launch_server",
+                 "--model-path",
+                 self.model_path,
+                 "--tokenizer-path",
+                 self.tokenizer_path,
+                 "--host",
+                 "0.0.0.0",
+                 "--port",
+                 str(self.port),
+                 "--tp-size",
+                 str(self.tp_size),
+             ]
+
+             if self.chunked_prefill_size:
+                 cmd += ["--chunked-prefill-size", str(self.chunked_prefill_size)]
+             if self.max_running_requests:
+                 cmd += ["--max-running-requests", str(self.max_running_requests)]
+
+         self.health_url = f"http://{self.host}:{self.port}/health"
+         self.chat_url = f"http://{self.host}:{self.port}/v1/chat/completions"
+
+         # Start the SGLang server if it is not already up
+         if not _is_port_open(self.host, self.port):
+             # Important: do NOT use stdout=PIPE in production unless you drain it (deadlocks).
+             self.proc: Optional[subprocess.Popen] = subprocess.Popen(
+                 cmd,
+                 env=os.environ.copy(),
+             )
+         else:
+             self.proc = None
+
+         _wait_for_server(self.host, self.port, self.health_url, timeout_s=int(os.getenv("SGLANG_STARTUP_TIMEOUT", "600")))
+
+         # Model name presented to the OpenAI-compatible API (some servers treat "model" as optional)
+         self.served_model_name = os.getenv("SGLANG_SERVED_MODEL_NAME", "ALIA-40b-instruct-nvfp4")
+
+     def __call__(self, data: Dict[str, Any]) -> Union[str, Dict[str, Any]]:
+         inputs = data.get("inputs", data)  # sometimes HF passes the full payload as inputs
+         params = data.get("parameters", {}) or {}
+
+         messages = _coerce_messages(inputs)
+         gen = _map_generation_params(params)
+
+         payload: Dict[str, Any] = {
+             "model": self.served_model_name,
+             "messages": messages,
+             **gen,
+         }
+
+         # Optional: allow the caller to set response_format / tools, etc. via "parameters".
+         # We pass through a small allowlist safely.
+         for k in ("response_format", "tools", "tool_choice"):
+             if k in params and params[k] is not None:
+                 payload[k] = params[k]
+
+         try:
+             r = requests.post(self.chat_url, json=payload, timeout=float(os.getenv("SGLANG_REQUEST_TIMEOUT", "300")))
+             r.raise_for_status()
+             out = r.json()
+         except requests.RequestException as e:
+             raise RuntimeError(f"SGLang request failed: {e}") from e
+
+         # Normalize the return value to what HF widgets commonly expect:
+         # either a raw string or a dict with "generated_text".
+         try:
+             text = out["choices"][0]["message"]["content"]
+         except Exception:
+             # Fallback: return the full response
+             return out
+
+         # If the caller asked for details (or return_full_text), include the raw payload
+         if bool(params.get("return_full_text")) or bool(params.get("details")):
+             return {
+                 "generated_text": text,
+                 "raw": out,
+             }
+
+         return text
+
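For reference, a minimal client-side sketch of the two request shapes documented at the top of handler.py. The endpoint URL and token names below (HF_ENDPOINT_URL, HF_TOKEN) are hypothetical placeholders, not part of this commit:

    # client_example.py - hypothetical caller for the handler above
    import os

    import requests

    ENDPOINT_URL = os.environ["HF_ENDPOINT_URL"]  # e.g. https://<name>.endpoints.huggingface.cloud
    HEADERS = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}

    # Single-prompt form: the handler wraps the string into one user message
    # and returns a plain string by default.
    r = requests.post(
        ENDPOINT_URL,
        headers=HEADERS,
        json={"inputs": "Hello", "parameters": {"max_new_tokens": 256, "temperature": 0.7}},
        timeout=300,
    )
    print(r.json())

    # Chat form with "details": the handler returns {"generated_text": ..., "raw": ...}.
    r = requests.post(
        ENDPOINT_URL,
        headers=HEADERS,
        json={
            "inputs": [{"role": "user", "content": "Hello"}],
            "parameters": {"max_new_tokens": 128, "details": True},
        },
        timeout=300,
    )
    print(r.json()["generated_text"])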
handler.py.bak DELETED
@@ -1,296 +0,0 @@
- # handler.py
- # handler.py
- # Hugging Face Inference Endpoints "custom handler" for TensorRT-LLM (trtllm-serve),
- # including NVFP4-quantized engines.
- #
- # Expected by HF Inference Toolkit:
- #   - file name: handler.py (repo root)
- #   - class: EndpointHandler with __init__(path) and __call__(data)
- #
- # This handler:
- #   1) starts `trtllm-serve <model_dir>` once (lazy init)
- #   2) forwards requests to the local OpenAI-compatible HTTP API
- #
- # Environment variables (optional):
- #   TRTLLM_HOST           default: 127.0.0.1
- #   TRTLLM_PORT           default: 8000
- #   TRTLLM_START_CMD      default: "trtllm-serve"
- #   TRTLLM_START_ARGS     default: "" (extra args appended verbatim)
- #   TRTLLM_HEALTH_PATH    default: "/health"
- #   TRTLLM_READY_TIMEOUT  default: 180 (seconds)
- #   TRTLLM_VERBOSE        default: "0"
- #
- # Notes:
- #   - If your container uses a different binary or endpoints, set TRTLLM_START_CMD
- #     and/or adjust _chat/_completion URLs below.
- #   - HF will call __call__ with a dict similar to:
- #       {"inputs": "...", "parameters": {...}}
- #     or for chat:
- #       {"messages": [...], "parameters": {...}}
-
- from __future__ import annotations
-
- import json
- import os
- import subprocess
- import time
- import threading
- from typing import Any, Dict, Optional
-
- try:
-     import requests
- except Exception:  # pragma: no cover
-     requests = None  # type: ignore
-
-
- class EndpointHandler:
-     _lock = threading.Lock()
-     _server_proc: Optional[subprocess.Popen] = None
-     _server_started: bool = False
-
-     def __init__(self, path: str):
-         # HF passes the model directory path (repo checkout) here.
-         self.model_dir = path
-
-         self.host = os.getenv("TRTLLM_HOST", "127.0.0.1")
-         self.port = int(os.getenv("TRTLLM_PORT", "8000"))
-         self.base_url = f"http://{self.host}:{self.port}"
-
-         self.health_path = os.getenv("TRTLLM_HEALTH_PATH", "/health")
-         self.ready_timeout = int(os.getenv("TRTLLM_READY_TIMEOUT", "180"))
-
-         self.start_cmd = os.getenv("TRTLLM_START_CMD", "trtllm-serve")
-         self.start_args = os.getenv("TRTLLM_START_ARGS", "").strip()
-
-         self.verbose = os.getenv("TRTLLM_VERBOSE", "0").strip() in ("1", "true", "TRUE", "yes", "YES")
-
-     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-         self._ensure_server()
-
-         # HF commonly uses:
-         #   - data["inputs"] + data["parameters"]
-         # For chat-like:
-         #   - data["messages"] + data["parameters"]
-         parameters = data.get("parameters") or {}
-         if not isinstance(parameters, dict):
-             parameters = {}
-
-         # If the caller provides "messages", treat it as chat.
-         if "messages" in data and isinstance(data["messages"], list):
-             return self._handle_chat(data["messages"], parameters)
-
-         # Otherwise treat as completion.
-         inputs = data.get("inputs")
-         if inputs is None:
-             # Some clients use "prompt"
-             inputs = data.get("prompt")
-
-         if isinstance(inputs, list):
-             # Batch prompts: run sequentially (simple + robust).
-             outputs = [self._handle_completion(prompt, parameters) for prompt in inputs]
-             return {"results": outputs}
-
-         if not isinstance(inputs, str):
-             raise ValueError("Expected 'inputs' (or 'prompt') to be a string or list of strings.")
-
-         return self._handle_completion(inputs, parameters)
-
-     # -------------------------
-     # TensorRT-LLM server start
-     # -------------------------
-     def _ensure_server(self) -> None:
-         with self._lock:
-             if self._server_started:
-                 return
-
-             # If server already reachable (e.g., started by container entrypoint), skip spawning.
-             if self._is_healthy():
-                 self._server_started = True
-                 return
-
-             cmd = [self.start_cmd, self.model_dir]
-
-             if self.start_args:
-                 # Append extra args verbatim, allowing the user to pass things like:
-                 #   "--backend pytorch --max_batch_size 4 --port 8000"
-                 cmd.extend(self.start_args.split())
-
-             # Ensure server binds to desired port if user didn't specify it.
-             # If you already pass "--port" in TRTLLM_START_ARGS, this is redundant but harmless.
-             if "--port" not in cmd:
-                 cmd.extend(["--port", str(self.port)])
-
-             if self.verbose:
-                 print(f"[handler] Starting TensorRT-LLM server: {' '.join(cmd)}", flush=True)
-
-             # Start server process.
-             # Important: do not use shell=True.
-             self._server_proc = subprocess.Popen(
-                 cmd,
-                 stdout=subprocess.PIPE,
-                 stderr=subprocess.STDOUT,
-                 env=os.environ.copy(),
-                 text=True,
-                 bufsize=1,
-             )
-
-             # Wait until healthy
-             self._wait_until_ready()
-
-             self._server_started = True
-
-     def _wait_until_ready(self) -> None:
-         deadline = time.time() + self.ready_timeout
-         last_line = None
-
-         while time.time() < deadline:
-             if self._server_proc is not None:
-                 # If process exited early, surface logs.
-                 code = self._server_proc.poll()
-                 if code is not None:
-                     logs = self._drain_logs(max_lines=2000)
-                     raise RuntimeError(
-                         f"TensorRT-LLM server exited with code {code} before becoming ready.\n"
-                         f"Last logs:\n{logs}"
-                     )
-
-             if self._is_healthy():
-                 if self.verbose:
-                     print("[handler] TensorRT-LLM server is healthy.", flush=True)
-                 return
-
-             # Optionally peek at logs to help debugging (non-blocking-ish).
-             if self.verbose:
-                 line = self._read_one_log_line()
-                 if line:
-                     last_line = line.strip()
-                     print(f"[trtllm] {last_line}", flush=True)
-
-             time.sleep(0.5)
-
-         logs = self._drain_logs(max_lines=500)
-         raise TimeoutError(
-             f"TensorRT-LLM server not ready after {self.ready_timeout}s. "
-             f"Health endpoint: {self.base_url}{self.health_path}\n"
-             f"Recent logs:\n{logs}"
-         )
-
-     def _is_healthy(self) -> bool:
-         try:
-             if requests is None:
-                 return False
-             r = requests.get(f"{self.base_url}{self.health_path}", timeout=1.5)
-             return 200 <= r.status_code < 300
-         except Exception:
-             return False
-
-     def _read_one_log_line(self) -> Optional[str]:
-         try:
-             if self._server_proc and self._server_proc.stdout:
-                 return self._server_proc.stdout.readline()
-         except Exception:
-             return None
-         return None
-
-     def _drain_logs(self, max_lines: int = 500) -> str:
-         if not self._server_proc or not self._server_proc.stdout:
-             return ""
-         lines = []
-         try:
-             for _ in range(max_lines):
-                 line = self._server_proc.stdout.readline()
-                 if not line:
-                     break
-                 lines.append(line.rstrip("\n"))
-         except Exception:
-             pass
-         return "\n".join(lines)
-
-     # -------------------------
-     # Request forwarding
-     # -------------------------
-     def _handle_chat(self, messages: list, parameters: Dict[str, Any]) -> Dict[str, Any]:
-         payload = {
-             "model": parameters.pop("model", "trtllm"),
-             "messages": messages,
-         }
-         payload.update(self._map_parameters(parameters))
-
-         # TensorRT-LLM OpenAI-compatible chat endpoint
-         url = f"{self.base_url}/v1/chat/completions"
-         resp = self._post_json(url, payload)
-
-         # Normalize output for HF consumers.
-         # Prefer returning OpenAI-like response, but also provide HF-style "generated_text".
-         generated_text = None
-         try:
-             generated_text = resp["choices"][0]["message"]["content"]
-         except Exception:
-             pass
-
-         return {
-             "generated_text": generated_text,
-             "raw": resp,
-         }
-
-     def _handle_completion(self, prompt: str, parameters: Dict[str, Any]) -> Dict[str, Any]:
-         payload = {
-             "model": parameters.pop("model", "trtllm"),
-             "prompt": prompt,
-         }
-         payload.update(self._map_parameters(parameters))
-
-         # TensorRT-LLM OpenAI-compatible completions endpoint
-         url = f"{self.base_url}/v1/completions"
-         resp = self._post_json(url, payload)
-
-         generated_text = None
-         try:
-             generated_text = resp["choices"][0]["text"]
-         except Exception:
-             pass
-
-         return {
-             "generated_text": generated_text,
-             "raw": resp,
-         }
-
-     def _post_json(self, url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
-         if requests is None:
-             raise RuntimeError(
-                 "The 'requests' package is not available in the container. "
-                 "Install it or replace _post_json with urllib."
-             )
-         headers = {"Content-Type": "application/json"}
-         r = requests.post(url, headers=headers, data=json.dumps(payload), timeout=600)
-         if r.status_code >= 400:
-             raise RuntimeError(f"Upstream TRTLLM error {r.status_code}: {r.text}")
-         return r.json()
-
-     def _map_parameters(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
-         """
-         Map common HF generation parameters to OpenAI-compatible fields.
-         TensorRT-LLM may ignore unsupported fields; this mapping is conservative.
-         """
-         out: Dict[str, Any] = {}
-
-         # Common parameters
-         if "max_new_tokens" in parameters and "max_tokens" not in parameters:
-             out["max_tokens"] = parameters["max_new_tokens"]
-         if "max_tokens" in parameters:
-             out["max_tokens"] = parameters["max_tokens"]
-
-         for k in ("temperature", "top_p", "seed", "stop"):
-             if k in parameters:
-                 out[k] = parameters[k]
-
-         # HF sometimes uses repetition_penalty; OpenAI doesn't have it.
-         # TensorRT-LLM may support presence/frequency penalties; pass through if provided.
-         for k in ("presence_penalty", "frequency_penalty"):
-             if k in parameters:
-                 out[k] = parameters[k]
-
-         # Streaming is not supported by this handler (HF expects a single response).
-         # Ignore "stream" if present.
-         return out
-
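Migration note: the removed TRT-LLM handler accepted chat requests via a top-level "messages" key and always returned {"generated_text": ..., "raw": ...}, whereas the new SGLang handler expects chat messages under "inputs" and returns a plain string unless "details" is set. A hypothetical adapter sketch for old-style callers (the helper name and its defaulting behavior are illustrative, not part of either file):

    # adapt_request.py - hypothetical helper, not part of this commit
    from typing import Any, Dict

    def adapt_old_request(old: Dict[str, Any]) -> Dict[str, Any]:
        """Rewrite a TRT-LLM-handler-style body into the new SGLang-handler shape."""
        new: Dict[str, Any] = {"parameters": dict(old.get("parameters") or {})}
        if isinstance(old.get("messages"), list):
            new["inputs"] = old["messages"]  # chat: top-level "messages" -> "inputs"
        else:
            new["inputs"] = old.get("inputs", old.get("prompt", ""))
        new["parameters"].setdefault("details", True)  # keep the old dict-shaped response
        return new

    assert adapt_old_request({"messages": [{"role": "user", "content": "Hi"}]})["inputs"][0]["role"] == "user"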