qhillerich committed on
Commit
0f9b150
·
verified ·
1 Parent(s): d63b692

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +466 -432
handler.py CHANGED
@@ -1,513 +1,547 @@
1
  """
2
- handler.py — Hugging Face Inference Endpoints (custom handler)
3
-
4
- Goal:
5
- - Fix: `ffmpeg` is not a registered plugin name
6
- - Do NOT use huggingface_inference_toolkit "plugins" for ffmpeg
7
- - Instead: resolve an ffmpeg executable path and call it directly
8
- - Option B: prefer imageio-ffmpeg (repo-only dependency), fallback to system ffmpeg if available
9
-
10
- What you must also do in repo:
11
- - requirements.txt: add `imageio-ffmpeg>=0.4.9`
12
-
13
- Notes:
14
- - This file is intentionally "full-fat": robust input parsing, clear errors, temp-file hygiene,
15
- optional output formats, and a ready-to-wire inference section.
16
- - You can paste your actual model inference in the TODO section, or use the provided
17
- Transformers pipeline example.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  """
19
 
20
  from __future__ import annotations
21
 
22
  import base64
23
  import io
24
- import json
25
  import os
26
- import re
27
- import stat
28
- import subprocess
29
- import tempfile
30
  import time
 
 
31
  from dataclasses import dataclass
32
- from typing import Any, Dict, Optional, Tuple, Union
33
 
34
- # Optional: imageio-ffmpeg provides an ffmpeg executable path
35
- try:
36
- import imageio_ffmpeg # type: ignore
37
- except Exception:
38
- imageio_ffmpeg = None
39
 
40
- # Optional: Transformers ASR pipeline example
41
- # If you don't want this, remove it and wire your own model.
 
 
42
  try:
43
- from transformers import pipeline # type: ignore
 
 
 
44
  except Exception:
45
- pipeline = None
46
-
47
-
48
- # -----------------------------
49
- # Utilities: errors & logging
50
- # -----------------------------
51
-
52
- class HandlerError(RuntimeError):
53
- """Raised for user-facing errors in request handling."""
54
 
55
 
56
  def _now_ms() -> int:
57
  return int(time.time() * 1000)
58
 
59
 
60
- def _truncate(s: str, n: int = 2000) -> str:
61
- if s is None:
62
- return ""
63
- s = str(s)
64
- return s if len(s) <= n else s[:n] + f"...(truncated {len(s)-n} chars)"
65
 
66
 
67
- # -----------------------------
68
- # Utilities: subprocess runner
69
- # -----------------------------
70
-
71
- def _run(cmd: list[str], *, timeout_s: Optional[int] = None) -> subprocess.CompletedProcess:
72
  """
73
- Run a command and raise a readable error with STDOUT/STDERR if it fails.
 
 
 
 
 
74
  """
75
- try:
76
- return subprocess.run(
77
- cmd,
78
- check=True,
79
- capture_output=True,
80
- text=True,
81
- timeout=timeout_s,
82
- )
83
- except subprocess.TimeoutExpired as e:
84
- raise HandlerError(f"Command timed out after {timeout_s}s: {' '.join(cmd)}") from e
85
- except subprocess.CalledProcessError as e:
86
- stdout = _truncate(e.stdout or "", 2000)
87
- stderr = _truncate(e.stderr or "", 2000)
88
- raise HandlerError(
89
- "Command failed.\n"
90
- f"CMD: {' '.join(cmd)}\n"
91
- f"EXIT: {e.returncode}\n"
92
- f"STDOUT: {stdout}\n"
93
- f"STDERR: {stderr}\n"
94
- ) from e
95
- except FileNotFoundError as e:
96
- raise HandlerError(f"Executable not found for command: {' '.join(cmd)}") from e
97
-
98
-
99
- # -----------------------------
100
- # ffmpeg resolution
101
- # -----------------------------
102
-
103
- def _is_executable(path: str) -> bool:
104
- return os.path.isfile(path) and os.access(path, os.X_OK)
105
-
106
-
107
- def _chmod_exec(path: str) -> None:
108
- try:
109
- st = os.stat(path)
110
- os.chmod(path, st.st_mode | stat.S_IEXEC | stat.S_IXGRP | stat.S_IXOTH)
111
- except Exception:
112
- # best-effort
113
- pass
114
 
 
 
115
 
116
- def _get_ffmpeg_path() -> str:
117
- """
118
- Resolve an ffmpeg executable path without any HF plugin mechanism.
119
- Priority:
120
- 1) imageio-ffmpeg managed exe (if installed)
121
- 2) system ffmpeg on PATH
122
- """
123
- # 1) imageio-ffmpeg
124
- if imageio_ffmpeg is not None:
125
- try:
126
- p = imageio_ffmpeg.get_ffmpeg_exe()
127
- if os.path.isfile(p):
128
- _chmod_exec(p)
129
- _run([p, "-version"], timeout_s=10)
130
- return p
131
- except Exception:
132
- pass
133
 
134
- # 2) system ffmpeg
135
- try:
136
- _run(["ffmpeg", "-version"], timeout_s=10)
137
- return "ffmpeg"
138
- except Exception as e:
139
- raise HandlerError(
140
- "ffmpeg is not available.\n"
141
- "Fix options:\n"
142
- " - Add `imageio-ffmpeg>=0.4.9` to requirements.txt (recommended repo-only fix)\n"
143
- " - Or ensure `ffmpeg` exists in the runtime image (custom container)\n"
144
- f"Last error: {e}"
145
- )
146
 
147
 
148
- def _get_ffprobe_path(ffmpeg_path: str) -> str:
149
  """
150
- Try to infer ffprobe path if available.
151
- If using imageio-ffmpeg, ffprobe may not be included; we treat it as optional.
152
  """
153
- # If ffmpeg is a full path, try sibling ffprobe
154
- if os.path.sep in ffmpeg_path:
155
- cand = os.path.join(os.path.dirname(ffmpeg_path), "ffprobe")
156
- if _is_executable(cand):
157
- return cand
158
- # fallback to system ffprobe if present
159
- try:
160
- _run(["ffprobe", "-version"], timeout_s=10)
161
- return "ffprobe"
162
- except Exception:
163
- return "" # optional
164
-
165
-
166
- # -----------------------------
167
- # Input parsing helpers
168
- # -----------------------------
169
-
170
- @dataclass
171
- class MediaPayload:
172
- raw_bytes: bytes
173
- filename: str
174
- content_type: str
175
-
176
-
177
- _B64_RE = re.compile(r"^[A-Za-z0-9+/=\s]+$")
178
 
179
 
180
- def _maybe_base64_decode(s: str) -> Optional[bytes]:
181
  """
182
- Attempt base64 decode if the string looks like base64.
183
- Returns bytes if successful, else None.
184
  """
185
- ss = s.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
- # handle data URL: data:audio/wav;base64,....
188
- if ss.startswith("data:") and "base64," in ss:
189
  try:
190
- b64 = ss.split("base64,", 1)[1]
191
- return base64.b64decode(b64, validate=False)
192
- except Exception:
193
- return None
194
 
195
- # plain base64 (heuristic)
196
- if len(ss) >= 64 and _B64_RE.match(ss):
 
197
  try:
198
- return base64.b64decode(ss, validate=False)
199
  except Exception:
200
- return None
201
-
202
- return None
203
 
204
 
205
- def _coerce_to_media_payload(data: Union[Dict[str, Any], bytes, str]) -> MediaPayload:
206
  """
207
- Accept common HF endpoint payload patterns and normalize to bytes + filename/content-type.
208
- Supported forms:
209
- - bytes / bytearray
210
- - base64 string (optionally data URL)
211
- - dict {"inputs": <bytes|base64|string|dict>}
212
- - dict with keys: "audio"/"data"/"content" containing bytes or base64
213
- - dict may include "filename", "content_type"
214
  """
215
- filename = "input_media"
216
- content_type = "application/octet-stream"
217
-
218
- if isinstance(data, (bytes, bytearray)):
219
- return MediaPayload(raw_bytes=bytes(data), filename=filename, content_type=content_type)
220
-
221
- if isinstance(data, str):
222
- decoded = _maybe_base64_decode(data)
223
- if decoded is None:
224
- raise HandlerError(
225
- "String input must be base64 (or data:...;base64,...). "
226
- "If you're sending JSON, wrap your bytes in base64."
227
- )
228
- return MediaPayload(raw_bytes=decoded, filename=filename, content_type=content_type)
229
-
230
- if not isinstance(data, dict):
231
- raise HandlerError("Unsupported input type. Send bytes, base64 string, or a JSON object.")
232
-
233
- # Pull metadata if present
234
- filename = str(data.get("filename") or data.get("name") or filename)
235
- content_type = str(data.get("content_type") or data.get("mime_type") or content_type)
236
-
237
- # Common HF: {"inputs": ...}
238
- if "inputs" in data:
239
- inner = data["inputs"]
240
- # If inputs is a dict with richer structure
241
- if isinstance(inner, dict):
242
- # allow nested metadata
243
- filename = str(inner.get("filename") or inner.get("name") or filename)
244
- content_type = str(inner.get("content_type") or inner.get("mime_type") or content_type)
245
- for k in ("data", "audio", "content", "bytes"):
246
- if k in inner:
247
- return _coerce_to_media_payload({**inner, "data": inner[k], "filename": filename, "content_type": content_type})
248
- # If inputs is bytes/base64
249
- return _coerce_to_media_payload(inner)
250
-
251
- # Other common keys
252
- for k in ("audio", "data", "content", "bytes"):
253
- if k in data:
254
- v = data[k]
255
- if isinstance(v, (bytes, bytearray)):
256
- return MediaPayload(raw_bytes=bytes(v), filename=filename, content_type=content_type)
257
- if isinstance(v, str):
258
- decoded = _maybe_base64_decode(v)
259
- if decoded is None:
260
- raise HandlerError(f"Field `{k}` is a string but not base64/data-url.")
261
- return MediaPayload(raw_bytes=decoded, filename=filename, content_type=content_type)
262
- if isinstance(v, dict):
263
- # nested object containing base64
264
- return _coerce_to_media_payload(v)
265
-
266
- raise HandlerError(
267
- "Could not find media bytes in request. "
268
- "Provide bytes directly, a base64 string, or a JSON object with `inputs` or `audio`/`data`."
269
- )
270
 
 
 
 
 
 
 
 
 
 
271
 
272
- # -----------------------------
273
- # Media conversion: any -> wav
274
- # -----------------------------
275
 
276
- def _write_temp_file(directory: str, name: str, data: bytes) -> str:
277
- path = os.path.join(directory, name)
278
- with open(path, "wb") as f:
279
- f.write(data)
280
- return path
 
 
 
 
281
 
282
 
283
- def _convert_to_wav(
284
- media_bytes: bytes,
285
- *,
286
- ffmpeg_path: str,
287
- target_sr: int = 16000,
288
- target_channels: int = 1,
289
- output_pcm: str = "s16le",
290
- ) -> str:
291
  """
292
- Convert arbitrary audio/video bytes to a WAV file (PCM).
293
- Returns a temp WAV file path that the caller should delete.
294
  """
295
- # Work inside a temp dir, then copy output to a NamedTemporaryFile outside the context
296
- with tempfile.TemporaryDirectory() as d:
297
- in_path = _write_temp_file(d, "input_media.bin", media_bytes)
298
- out_path = os.path.join(d, "output.wav")
299
-
300
- cmd = [
301
- ffmpeg_path,
302
- "-y",
303
- "-hide_banner",
304
- "-loglevel",
305
- "error",
306
- "-i",
307
- in_path,
308
- "-vn",
309
- "-ac",
310
- str(target_channels),
311
- "-ar",
312
- str(target_sr),
313
- "-acodec",
314
- f"pcm_{output_pcm}",
315
- out_path,
316
- ]
317
- _run(cmd, timeout_s=120)
318
-
319
- tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
320
- tmp_path = tmp.name
321
- tmp.close()
322
-
323
- with open(out_path, "rb") as src, open(tmp_path, "wb") as dst:
324
- dst.write(src.read())
325
-
326
- return tmp_path
327
-
328
-
329
- def _read_file_bytes(path: str) -> bytes:
330
- with open(path, "rb") as f:
331
- return f.read()
332
-
333
-
334
- # -----------------------------
335
- # Output helpers
336
- # -----------------------------
337
-
338
- def _as_base64(b: bytes) -> str:
339
- return base64.b64encode(b).decode("utf-8")
340
 
 
 
341
 
342
- def _response(
343
- *,
344
- ok: bool,
345
- text: str = "",
346
- extra: Optional[Dict[str, Any]] = None,
347
- diagnostics: Optional[Dict[str, Any]] = None,
348
- ) -> Dict[str, Any]:
349
- out: Dict[str, Any] = {"ok": ok, "text": text}
350
- if extra:
351
- out.update(extra)
352
- if diagnostics:
353
- out["diagnostics"] = diagnostics
354
- return out
355
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
 
357
- # -----------------------------
358
- # EndpointHandler
359
- # -----------------------------
 
 
360
 
361
- class EndpointHandler:
362
- """
363
- Hugging Face Inference Endpoint handler.
 
 
 
364
 
365
- This handler:
366
- - Accepts bytes/base64/json payloads
367
- - Converts input media to WAV using ffmpeg (no plugin system)
368
- - Runs ASR (example: Transformers pipeline)
369
- - Returns text and optional timing/diagnostics
370
 
371
- To use your own model:
372
- - Replace `_infer_transcription()` with your model logic
373
- - Or set `self.asr` to your pipeline/model in __init__
374
- """
375
 
376
- def __init__(self, path: str = "") -> None:
377
- self.model_path = path or ""
378
- self.ffmpeg_path = _get_ffmpeg_path()
379
- self.ffprobe_path = _get_ffprobe_path(self.ffmpeg_path)
380
-
381
- # Settings (can be overridden per-request)
382
- self.default_sr = int(os.getenv("TARGET_SAMPLE_RATE", "16000"))
383
- self.default_channels = int(os.getenv("TARGET_CHANNELS", "1"))
384
-
385
- # Optional: initialize an ASR pipeline if transformers is available.
386
- # If you're not doing ASR, delete this and implement your task.
387
- self.asr = None
388
- if pipeline is not None:
389
- # If your repo contains a model, `path` is typically the model directory.
390
- # If not, you can hardcode a model id here.
391
- model_id_or_path = self.model_path if self.model_path else os.getenv("ASR_MODEL_ID", "").strip()
392
- if model_id_or_path:
393
- # You can choose task="automatic-speech-recognition"
394
- # and pass device_map / torch_dtype in advanced setups.
395
- self.asr = pipeline("automatic-speech-recognition", model=model_id_or_path)
396
-
397
- # Startup self-test (fast fail)
398
- _run([self.ffmpeg_path, "-version"], timeout_s=10)
399
-
400
- def __call__(self, data: Union[Dict[str, Any], bytes, str]) -> Dict[str, Any]:
401
  t0 = _now_ms()
402
 
403
- # Parse request
404
  try:
405
- payload = _coerce_to_media_payload(data)
406
- except Exception as e:
407
- return _response(ok=False, text=str(e), diagnostics={"stage": "parse"})
408
-
409
- # Allow per-request overrides
410
- req: Dict[str, Any] = data if isinstance(data, dict) else {}
411
- target_sr = int(req.get("target_sr") or req.get("sample_rate") or self.default_sr)
412
- target_channels = int(req.get("target_channels") or req.get("channels") or self.default_channels)
413
-
414
- # Convert to wav
415
- wav_path = ""
416
- try:
417
- wav_path = _convert_to_wav(
418
- payload.raw_bytes,
419
- ffmpeg_path=self.ffmpeg_path,
420
- target_sr=target_sr,
421
- target_channels=target_channels,
422
- )
423
  t1 = _now_ms()
424
 
425
- # Inference
426
- text, model_meta = self._infer_transcription(wav_path, req=req)
427
- t2 = _now_ms()
428
-
429
- # Optional: include wav bytes or base64 (off by default)
430
- include_wav_b64 = bool(req.get("include_wav_base64", False))
431
- extra: Dict[str, Any] = {
432
- "filename": payload.filename,
433
- "content_type": payload.content_type,
434
- "model": model_meta,
435
- }
436
- if include_wav_b64:
437
- extra["wav_base64"] = _as_base64(_read_file_bytes(wav_path))
438
-
439
- return _response(
440
- ok=True,
441
- text=text,
442
- extra=extra,
443
- diagnostics={
444
- "ffmpeg": self.ffmpeg_path,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
  "timing_ms": {
446
- "total": t2 - t0,
447
- "convert": t1 - t0,
448
- "inference": t2 - t1,
449
- },
450
- "audio": {
451
- "target_sr": target_sr,
452
- "target_channels": target_channels,
453
  },
 
 
 
454
  },
455
- )
456
 
457
  except Exception as e:
458
- return _response(
459
- ok=False,
460
- text=str(e),
461
- diagnostics={
462
- "stage": "convert_or_infer",
463
- "ffmpeg": self.ffmpeg_path,
464
  },
465
- )
466
- finally:
467
- if wav_path:
468
- try:
469
- os.remove(wav_path)
470
- except Exception:
471
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
 
473
- # -----------------------------
474
- # Inference (replace this)
475
- # -----------------------------
476
 
477
- def _infer_transcription(self, wav_path: str, *, req: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
478
  """
479
- Default behavior: if transformers pipeline is configured, run ASR.
480
- Otherwise, return a diagnostic placeholder.
481
 
482
- Replace this method with your real inference logic if needed.
 
 
 
 
 
 
 
483
  """
484
- if self.asr is None:
485
- # Placeholder: you should replace with your actual model logic.
486
- # This is still useful to verify that ffmpeg conversion works and the endpoint runs end-to-end.
487
- return (
488
- f"OK: converted to wav successfully (path={os.path.basename(wav_path)}). "
489
- "ASR pipeline not configured. Set ASR_MODEL_ID env var or pass a model path.",
490
- {"type": "none", "note": "no ASR pipeline"},
491
  )
492
 
493
- # Transformers pipeline accepts a file path for ASR
494
- # You can pass options like chunk_length_s, stride_length_s for long audio
495
- asr_kwargs: Dict[str, Any] = {}
496
- if "chunk_length_s" in req:
497
- asr_kwargs["chunk_length_s"] = float(req["chunk_length_s"])
498
- if "stride_length_s" in req:
499
- asr_kwargs["stride_length_s"] = float(req["stride_length_s"])
500
 
501
- result = self.asr(wav_path, **asr_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
502
 
503
- # Normalize output
504
- # Typical result: {"text": "...", ...}
505
- if isinstance(result, dict) and "text" in result:
506
- text = str(result["text"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
507
  else:
508
- text = json.dumps(result)
509
-
510
- return text, {
511
- "type": "transformers-pipeline",
512
- "task": "automatic-speech-recognition",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
  }
 
 
1
  """
2
+ handler.py — Hugging Face Inference Endpoint custom handler
3
+ Outputs: GIF, WebM, ZIP(frames)
4
+
5
+ Key points:
6
+ - No "huggingface_inference_toolkit plugin" usage at all.
7
+ - WebM encoding uses imageio + imageio-ffmpeg (ffmpeg binary resolved internally).
8
+ - GIF encoding uses Pillow (no ffmpeg needed).
9
+ - ZIP output is a zip of PNG frames.
10
+
11
+ Request JSON (examples):
12
+ {
13
+ "prompt": "a cinematic shot of a hawk flying over snowy mountains",
14
+ "negative_prompt": "low quality, blurry",
15
+ "num_frames": 48,
16
+ "fps": 16,
17
+ "height": 512,
18
+ "width": 512,
19
+ "seed": 123,
20
+ "outputs": ["gif", "webm", "zip"], // any subset
21
+ "return_base64": true, // default true
22
+ "gif": {"fps": 12}, // optional overrides
23
+ "webm": {"fps": 24, "quality": "good"}, // quality: "fast"|"good"|"best"
24
+ "zip": {"format": "png"} // currently png only
25
+ }
26
+
27
+ Response JSON:
28
+ {
29
+ "ok": true,
30
+ "diagnostics": {...},
31
+ "outputs": {
32
+ "gif_base64": "...",
33
+ "webm_base64": "...",
34
+ "zip_base64": "..."
35
+ }
36
+ }
37
+
38
+ Notes on payload sizes:
39
+ - base64 video payloads can be large. For production, consider uploading to R2/S3
40
+ and returning a URL instead. This handler keeps it in-response for simplicity.
41
  """
42
 
43
  from __future__ import annotations
44
 
45
  import base64
46
  import io
 
47
  import os
 
 
 
 
48
  import time
49
+ import tempfile
50
+ import zipfile
51
  from dataclasses import dataclass
52
+ from typing import Any, Dict, List, Optional, Tuple, Union
53
 
54
+ import numpy as np
55
+ from PIL import Image
 
 
 
56
 
57
+ # WebM encoding (uses ffmpeg resolved by imageio-ffmpeg)
58
+ import imageio
59
+
60
+ # Ensure imageio uses the packaged ffmpeg binary (not HF toolkit plugins)
61
  try:
62
+ import imageio_ffmpeg # type: ignore
63
+ _FFMPEG_EXE = imageio_ffmpeg.get_ffmpeg_exe()
64
+ # imageio reads this env var to locate ffmpeg
65
+ os.environ["IMAGEIO_FFMPEG_EXE"] = _FFMPEG_EXE
66
  except Exception:
67
+ _FFMPEG_EXE = ""
 
 
 
 
 
 
 
 
68
 
69
 
70
  def _now_ms() -> int:
71
  return int(time.time() * 1000)
72
 
73
 
74
+ def _b64(data: bytes) -> str:
75
+ return base64.b64encode(data).decode("utf-8")
 
 
 
76
 
77
 
78
+ def _clamp_uint8_frame(frame: np.ndarray) -> np.ndarray:
 
 
 
 
79
  """
80
+ Ensure frame is uint8 HxWx3 (RGB).
81
+ Accepts:
82
+ - float in [0,1] or [-1,1]
83
+ - uint8 already
84
+ - grayscale (HxW) -> RGB
85
+ - RGBA -> RGB
86
  """
87
+ if not isinstance(frame, np.ndarray):
88
+ frame = np.array(frame)
89
+
90
+ # squeeze batch-like dims if present (best-effort)
91
+ if frame.ndim == 4 and frame.shape[0] == 1:
92
+ frame = frame[0]
93
+
94
+ if frame.ndim == 2:
95
+ frame = np.stack([frame, frame, frame], axis=-1)
96
+
97
+ if frame.ndim != 3:
98
+ raise ValueError(f"Frame must be HxW, HxWxC, or 1xHxWxC; got shape {frame.shape}")
99
+
100
+ # Channels fixups
101
+ if frame.shape[-1] == 4:
102
+ frame = frame[..., :3]
103
+ elif frame.shape[-1] == 1:
104
+ frame = np.repeat(frame, 3, axis=-1)
105
+ elif frame.shape[-1] != 3:
106
+ # sometimes CxHxW
107
+ if frame.shape[0] == 3 and frame.ndim == 3:
108
+ frame = np.transpose(frame, (1, 2, 0))
109
+ else:
110
+ raise ValueError(f"Unsupported channel dimension: {frame.shape}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
+ if frame.dtype == np.uint8:
113
+ return frame
114
 
115
+ # Convert float -> uint8
116
+ f = frame.astype(np.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
+ # If looks like [-1,1], map to [0,1]
119
+ if f.min() < 0.0:
120
+ f = (f + 1.0) / 2.0
121
+
122
+ f = np.clip(f, 0.0, 1.0)
123
+ return (f * 255.0).round().astype(np.uint8)
 
 
 
 
 
 
124
 
125
 
126
def _encode_gif(frames: List[np.ndarray], fps: int) -> bytes:
    """
    Encode GIF using Pillow (no ffmpeg dependency).

    Frames are normalized to uint8 RGB first; playback speed is approximated
    with a per-frame duration of 1000/fps milliseconds.
    """
    if not frames:
        raise ValueError("No frames to encode.")

    frame_delay_ms = int(1000 / max(1, fps))

    images = []
    for raw in frames:
        images.append(Image.fromarray(_clamp_uint8_frame(raw)))

    out = io.BytesIO()
    first, rest = images[0], images[1:]
    first.save(
        out,
        format="GIF",
        save_all=True,
        append_images=rest,
        duration=frame_delay_ms,
        loop=0,
        optimize=False,
        disposal=2,
    )
    return out.getvalue()
 
 
 
 
 
 
 
 
147
 
148
 
149
def _encode_webm(frames: List[np.ndarray], fps: int, quality: str = "good") -> bytes:
    """
    Encode WebM using imageio (ffmpeg under the hood via imageio-ffmpeg).
    quality: "fast" | "good" | "best"

    Returns the encoded WebM container as bytes; raises ValueError when no
    frames were given.

    Fix: the previous version passed `-preset`, which is an x264/x265 option
    that libvpx-vp9 does not accept; VP9 speed/quality is controlled via
    `-deadline` and `-cpu-used` instead. `-crf N` + `-b:v 0` selects
    constant-quality mode (lower crf = better quality, larger files).
    """
    if not frames:
        raise ValueError("No frames to encode.")

    # Map the coarse quality knob to pragmatic VP9 settings.
    quality = (quality or "good").lower()
    if quality == "fast":
        crf, deadline, cpu_used = 42, "realtime", 8
    elif quality == "best":
        crf, deadline, cpu_used = 28, "best", 0
    else:
        crf, deadline, cpu_used = 34, "good", 2

    # Write to a temp file then return bytes
    with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as tmp:
        out_path = tmp.name

    try:
        writer = imageio.get_writer(
            out_path,
            fps=max(1, fps),
            format="FFMPEG",
            codec="libvpx-vp9",
            # ffmpeg_params are passed through to the ffmpeg invocation
            ffmpeg_params=[
                "-pix_fmt", "yuv420p",
                "-crf", str(crf),
                "-b:v", "0",
                "-deadline", deadline,
                "-cpu-used", str(cpu_used),
            ],
        )

        try:
            for f in frames:
                writer.append_data(_clamp_uint8_frame(f))
        finally:
            # Always close the writer so ffmpeg finalizes the container.
            writer.close()

        with open(out_path, "rb") as f:
            return f.read()
    finally:
        # Best-effort temp-file cleanup.
        try:
            os.remove(out_path)
        except Exception:
            pass
 
 
203
 
204
 
205
def _encode_zip_frames(frames: List[np.ndarray]) -> bytes:
    """
    Zip frames as PNG images: frame_000000.png, ...
    """
    if not frames:
        raise ValueError("No frames to zip.")

    archive = io.BytesIO()
    zf = zipfile.ZipFile(archive, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=6)
    with zf:
        for idx, raw in enumerate(frames):
            png = io.BytesIO()
            Image.fromarray(_clamp_uint8_frame(raw)).save(png, format="PNG", optimize=True)
            zf.writestr(f"frame_{idx:06d}.png", png.getvalue())
    return archive.getvalue()
221
 
 
 
 
222
 
223
@dataclass
class GenParams:
    """Normalized generation parameters produced by `_parse_request`."""
    prompt: str  # required; taken from request `prompt` or `inputs`
    negative_prompt: str  # empty string when not provided
    num_frames: int  # clamped to >= 1 by _parse_request
    fps: int  # clamped to >= 1 by _parse_request
    height: int  # clamped to >= 64 by _parse_request
    width: int  # clamped to >= 64 by _parse_request
    seed: Optional[int]  # None when omitted or blank in the request
 
233
 
234
+ class EndpointHandler:
 
 
 
 
 
 
 
235
  """
236
+ Custom handler entrypoint for Hugging Face Inference Endpoints.
 
237
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
+ def __init__(self, path: str = "") -> None:
240
+ self.repo_path = path or ""
241
 
242
+ # Attempt to initialize a diffusers pipeline if available.
243
+ # If your repo uses a different entrypoint, edit `_generate_frames()`.
244
+ self.pipe = None
245
+ self._init_error = None
 
 
 
 
 
 
 
 
 
246
 
247
+ try:
248
+ import torch # type: ignore
249
+ from diffusers import DiffusionPipeline # type: ignore
250
+
251
+ # Prefer fp16 if CUDA is available; otherwise float32.
252
+ device = "cuda" if torch.cuda.is_available() else "cpu"
253
+ dtype = torch.float16 if device == "cuda" else torch.float32
254
+
255
+ # Load from repository path (the model code/checkpoints are in the repo)
256
+ # This is the most generic path for "diffusers-like" repos.
257
+ self.pipe = DiffusionPipeline.from_pretrained(
258
+ self.repo_path if self.repo_path else None,
259
+ torch_dtype=dtype,
260
+ )
261
 
262
+ # Move to device if possible
263
+ try:
264
+ self.pipe.to(device)
265
+ except Exception:
266
+ pass
267
 
268
+ # Some pipelines benefit from enabling memory optimizations
269
+ try:
270
+ if hasattr(self.pipe, "enable_vae_slicing"):
271
+ self.pipe.enable_vae_slicing()
272
+ except Exception:
273
+ pass
274
 
275
+ except Exception as e:
276
+ self._init_error = str(e)
277
+ self.pipe = None
 
 
278
 
279
+ # Quick diagnostic: ensure imageio-ffmpeg resolved (for WebM)
280
+ self.ffmpeg_exe = _FFMPEG_EXE
 
 
281
 
282
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Endpoint entrypoint: parse the request, generate frames, then encode
        each requested output format (gif/webm/zip).

        Returns {"ok": True, "outputs": {...}, "diagnostics": {...}} on
        success, or {"ok": False, "error": ..., "diagnostics": {...}} on
        failure — this method never raises to the caller.
        """
        t0 = _now_ms()

        try:
            params, outputs, return_b64, out_cfg = self._parse_request(data)
            frames, gen_diag = self._generate_frames(params, out_cfg=out_cfg)
            t1 = _now_ms()

            result_outputs: Dict[str, Any] = {}

            # GIF (per-format fps override falls back to the request fps)
            if "gif" in outputs:
                gif_fps = int((out_cfg.get("gif") or {}).get("fps") or params.fps)
                gif_bytes = _encode_gif(frames, fps=gif_fps)
                result_outputs["gif_base64" if return_b64 else "gif_bytes"] = _b64(gif_bytes) if return_b64 else gif_bytes
                t_gif = _now_ms()
            else:
                # Carry the previous timestamp forward so per-stage deltas
                # below stay meaningful for skipped stages.
                t_gif = t1

            # WebM
            if "webm" in outputs:
                webm_cfg = out_cfg.get("webm") or {}
                webm_fps = int(webm_cfg.get("fps") or params.fps)
                webm_quality = str(webm_cfg.get("quality") or "good")
                webm_bytes = _encode_webm(frames, fps=webm_fps, quality=webm_quality)
                result_outputs["webm_base64" if return_b64 else "webm_bytes"] = _b64(webm_bytes) if return_b64 else webm_bytes
                t_webm = _now_ms()
            else:
                t_webm = t_gif

            # ZIP frames
            if "zip" in outputs:
                zip_bytes = _encode_zip_frames(frames)
                result_outputs["zip_base64" if return_b64 else "zip_bytes"] = _b64(zip_bytes) if return_b64 else zip_bytes
                t_zip = _now_ms()
            else:
                t_zip = t_webm

            return {
                "ok": True,
                "outputs": result_outputs,
                "diagnostics": {
                    # Wall-clock per-stage timings; skipped stages report 0.
                    "timing_ms": {
                        "total": t_zip - t0,
                        "generate": t1 - t0,
                        "gif": (t_gif - t1) if "gif" in outputs else 0,
                        "webm": (t_webm - t_gif) if "webm" in outputs else 0,
                        "zip": (t_zip - t_webm) if "zip" in outputs else 0,
                    },
                    "generator": gen_diag,
                    "ffmpeg_exe": self.ffmpeg_exe,
                    "init_error": self._init_error,
                },
            }

        except Exception as e:
            # Never propagate: report the failure plus enough context to debug.
            return {
                "ok": False,
                "error": str(e),
                "diagnostics": {
                    "ffmpeg_exe": self.ffmpeg_exe,
                    "init_error": self._init_error,
                },
            }
346
+
347
+ # ----------------------------
348
+ # Request parsing
349
+ # ----------------------------
350
+
351
+ def _parse_request(self, data: Dict[str, Any]) -> Tuple[GenParams, List[str], bool, Dict[str, Any]]:
352
+ if not isinstance(data, dict):
353
+ raise ValueError("Request must be a JSON object.")
354
+
355
+ prompt = str(data.get("prompt") or data.get("inputs") or "").strip()
356
+ if not prompt:
357
+ raise ValueError("Missing `prompt` (or `inputs`).")
358
+
359
+ negative_prompt = str(data.get("negative_prompt") or "").strip()
360
+
361
+ num_frames = int(data.get("num_frames") or data.get("frames") or 32)
362
+ fps = int(data.get("fps") or 12)
363
+ height = int(data.get("height") or 512)
364
+ width = int(data.get("width") or 512)
365
+ seed = data.get("seed")
366
+ seed = int(seed) if seed is not None and str(seed).strip() != "" else None
367
+
368
+ outputs = data.get("outputs") or ["gif"]
369
+ if isinstance(outputs, str):
370
+ outputs = [outputs]
371
+ outputs = [str(x).lower() for x in outputs]
372
+
373
+ allowed = {"gif", "webm", "zip"}
374
+ outputs = [o for o in outputs if o in allowed]
375
+ if not outputs:
376
+ outputs = ["gif"]
377
+
378
+ return_b64 = bool(data.get("return_base64", True))
379
+ out_cfg = data.get("output_config") or {}
380
+ # also allow top-level gif/webm/zip config objects
381
+ for k in ("gif", "webm", "zip"):
382
+ if k in data and isinstance(data[k], dict):
383
+ out_cfg[k] = data[k]
384
+
385
+ params = GenParams(
386
+ prompt=prompt,
387
+ negative_prompt=negative_prompt,
388
+ num_frames=max(1, num_frames),
389
+ fps=max(1, fps),
390
+ height=max(64, height),
391
+ width=max(64, width),
392
+ seed=seed,
393
+ )
394
+ return params, outputs, return_b64, out_cfg
395
 
396
+ # ----------------------------
397
+ # Frame generation
398
+ # ----------------------------
399
 
400
    def _generate_frames(self, params: GenParams, out_cfg: Dict[str, Any]) -> Tuple[List[np.ndarray], Dict[str, Any]]:
        """
        Generates frames as a list of uint8 RGB numpy arrays.
        This is the only place you should need to customize for your specific repo/model.

        Current implementation:
        - If a diffusers pipeline is available:
          pipe(prompt=..., negative_prompt=..., height=..., width=..., num_frames=...)
          Then tries common output fields: frames / videos / images.
        - Otherwise: raises with init details.

        If your model call is different (e.g., special args like guidance_scale, num_inference_steps),
        add them here.

        Returns (frames_u8, diag) where diag summarizes the call parameters
        and the extracted frame count/size.
        """
        if self.pipe is None:
            raise RuntimeError(
                "Model pipeline is not initialized. "
                "If your repo doesn't use diffusers DiffusionPipeline, edit _generate_frames(). "
                f"Init error: {self._init_error}"
            )

        # Optional knobs (safe defaults); read from out_cfg, not the params
        # dataclass, so callers can tune them per-request.
        num_inference_steps = int(out_cfg.get("num_inference_steps") or 30)
        guidance_scale = float(out_cfg.get("guidance_scale") or 7.5)

        # Seed (best effort): a torch.Generator pinned to the active device.
        generator = None
        try:
            import torch  # type: ignore
            if params.seed is not None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
                generator = torch.Generator(device=device).manual_seed(params.seed)
        except Exception:
            generator = None

        # Call the pipeline (generic diffusers-style)
        kwargs: Dict[str, Any] = {
            "prompt": params.prompt,
            # None (not "") so pipelines that reject empty negatives are safe;
            # None-valued kwargs are filtered out before the call below.
            "negative_prompt": params.negative_prompt if params.negative_prompt else None,
            "height": params.height,
            "width": params.width,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
        }

        # Common video arg names across repos
        # Some pipelines use num_frames, some use video_length, some use num_frames.
        # We'll try a few.
        # (If your repo is strict, adjust this section.)
        # NOTE(review): any pipeline exception — not just an unknown-kwarg
        # TypeError — advances to the next arg name; only the last error is
        # reported on total failure.
        called = False
        last_err: Optional[Exception] = None
        output = None
        for frame_arg in ("num_frames", "video_length", "num_video_frames"):
            try:
                call_kwargs = dict(kwargs)
                call_kwargs[frame_arg] = params.num_frames
                if generator is not None:
                    call_kwargs["generator"] = generator
                output = self.pipe(**{k: v for k, v in call_kwargs.items() if v is not None})
                called = True
                break
            except Exception as e:
                last_err = e
                continue

        if not called:
            raise RuntimeError(f"Pipeline call failed for frame args. Last error: {last_err}")

        # Extract frames from common output structures
        frames: List[np.ndarray] = []

        # diffusers outputs vary:
        # - output.frames (list of PIL/np)
        # - output.videos (tensor/np)
        # - output.images (list of PIL for single-frame)
        if hasattr(output, "frames") and output.frames is not None:
            frames_raw = output.frames
            frames = [np.array(f) for f in frames_raw]
        elif hasattr(output, "videos") and output.videos is not None:
            vids = output.videos
            arr = None

            # torch tensor or numpy
            try:
                import torch  # type: ignore
                if isinstance(vids, torch.Tensor):
                    arr = vids.detach().cpu().numpy()
                else:
                    arr = np.array(vids)
            except Exception:
                arr = np.array(vids)

            # Common shapes:
            # (B, T, C, H, W) or (B, T, H, W, C) or (T, H, W, C)
            if arr.ndim == 5:
                # pick first batch
                arr = arr[0]
            if arr.ndim == 4 and arr.shape[1] in (1, 3, 4):
                # likely (T, C, H, W) -> (T, H, W, C)
                # NOTE(review): heuristic — a (T, H, W, C) video whose H is
                # 1/3/4 would be transposed incorrectly; confirm per model.
                arr = np.transpose(arr, (0, 2, 3, 1))
            if arr.ndim != 4:
                raise ValueError(f"Unexpected video tensor shape: {arr.shape}")

            frames = [arr[t] for t in range(arr.shape[0])]
        elif hasattr(output, "images") and output.images is not None:
            imgs = output.images
            # if it's just one image, treat as 1-frame "video"
            if isinstance(imgs, list):
                frames = [np.array(im) for im in imgs]
            else:
                frames = [np.array(imgs)]
        else:
            # final fallback: try dict-like
            if isinstance(output, dict):
                for key in ("frames", "videos", "images"):
                    if key in output and output[key] is not None:
                        v = output[key]
                        if key == "videos":
                            # same batch/transpose heuristics as above
                            arr = np.array(v)
                            if arr.ndim == 5:
                                arr = arr[0]
                            if arr.ndim == 4 and arr.shape[1] in (1, 3, 4):
                                arr = np.transpose(arr, (0, 2, 3, 1))
                            frames = [arr[t] for t in range(arr.shape[0])]
                        else:
                            if isinstance(v, list):
                                frames = [np.array(x) for x in v]
                            else:
                                frames = [np.array(v)]
                        break

        if not frames:
            raise RuntimeError("Could not extract frames from pipeline output (no frames/videos/images found).")

        # Normalize to uint8 RGB
        frames_u8 = [_clamp_uint8_frame(f) for f in frames]

        diag = {
            "prompt_len": len(params.prompt),
            "negative_prompt_len": len(params.negative_prompt),
            "num_frames": len(frames_u8),
            "height": int(frames_u8[0].shape[0]),
            "width": int(frames_u8[0].shape[1]),
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "seed": params.seed,
        }
        return frames_u8, diag