Update handler.py
Browse files- handler.py +466 -432
handler.py
CHANGED
|
@@ -1,513 +1,547 @@
|
|
| 1 |
"""
|
| 2 |
-
handler.py — Hugging Face Inference
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
-
|
| 7 |
-
-
|
| 8 |
-
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
"""
|
| 19 |
|
| 20 |
from __future__ import annotations
|
| 21 |
|
| 22 |
import base64
|
| 23 |
import io
|
| 24 |
-
import json
|
| 25 |
import os
|
| 26 |
-
import re
|
| 27 |
-
import stat
|
| 28 |
-
import subprocess
|
| 29 |
-
import tempfile
|
| 30 |
import time
|
|
|
|
|
|
|
| 31 |
from dataclasses import dataclass
|
| 32 |
-
from typing import Any, Dict, Optional, Tuple, Union
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
import imageio_ffmpeg # type: ignore
|
| 37 |
-
except Exception:
|
| 38 |
-
imageio_ffmpeg = None
|
| 39 |
|
| 40 |
-
#
|
| 41 |
-
|
|
|
|
|
|
|
| 42 |
try:
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
| 44 |
except Exception:
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
# -----------------------------
|
| 49 |
-
# Utilities: errors & logging
|
| 50 |
-
# -----------------------------
|
| 51 |
-
|
| 52 |
-
class HandlerError(RuntimeError):
|
| 53 |
-
"""Raised for user-facing errors in request handling."""
|
| 54 |
|
| 55 |
|
| 56 |
def _now_ms() -> int:
|
| 57 |
return int(time.time() * 1000)
|
| 58 |
|
| 59 |
|
| 60 |
-
def
|
| 61 |
-
|
| 62 |
-
return ""
|
| 63 |
-
s = str(s)
|
| 64 |
-
return s if len(s) <= n else s[:n] + f"...(truncated {len(s)-n} chars)"
|
| 65 |
|
| 66 |
|
| 67 |
-
|
| 68 |
-
# Utilities: subprocess runner
|
| 69 |
-
# -----------------------------
|
| 70 |
-
|
| 71 |
-
def _run(cmd: list[str], *, timeout_s: Optional[int] = None) -> subprocess.CompletedProcess:
|
| 72 |
"""
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
"""
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
# -----------------------------
|
| 100 |
-
# ffmpeg resolution
|
| 101 |
-
# -----------------------------
|
| 102 |
-
|
| 103 |
-
def _is_executable(path: str) -> bool:
|
| 104 |
-
return os.path.isfile(path) and os.access(path, os.X_OK)
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
def _chmod_exec(path: str) -> None:
|
| 108 |
-
try:
|
| 109 |
-
st = os.stat(path)
|
| 110 |
-
os.chmod(path, st.st_mode | stat.S_IEXEC | stat.S_IXGRP | stat.S_IXOTH)
|
| 111 |
-
except Exception:
|
| 112 |
-
# best-effort
|
| 113 |
-
pass
|
| 114 |
|
|
|
|
|
|
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
Resolve an ffmpeg executable path without any HF plugin mechanism.
|
| 119 |
-
Priority:
|
| 120 |
-
1) imageio-ffmpeg managed exe (if installed)
|
| 121 |
-
2) system ffmpeg on PATH
|
| 122 |
-
"""
|
| 123 |
-
# 1) imageio-ffmpeg
|
| 124 |
-
if imageio_ffmpeg is not None:
|
| 125 |
-
try:
|
| 126 |
-
p = imageio_ffmpeg.get_ffmpeg_exe()
|
| 127 |
-
if os.path.isfile(p):
|
| 128 |
-
_chmod_exec(p)
|
| 129 |
-
_run([p, "-version"], timeout_s=10)
|
| 130 |
-
return p
|
| 131 |
-
except Exception:
|
| 132 |
-
pass
|
| 133 |
|
| 134 |
-
#
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
"ffmpeg is not available.\n"
|
| 141 |
-
"Fix options:\n"
|
| 142 |
-
" - Add `imageio-ffmpeg>=0.4.9` to requirements.txt (recommended repo-only fix)\n"
|
| 143 |
-
" - Or ensure `ffmpeg` exists in the runtime image (custom container)\n"
|
| 144 |
-
f"Last error: {e}"
|
| 145 |
-
)
|
| 146 |
|
| 147 |
|
| 148 |
-
def
|
| 149 |
"""
|
| 150 |
-
|
| 151 |
-
If using imageio-ffmpeg, ffprobe may not be included; we treat it as optional.
|
| 152 |
"""
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
@dataclass
|
| 171 |
-
class MediaPayload:
|
| 172 |
-
raw_bytes: bytes
|
| 173 |
-
filename: str
|
| 174 |
-
content_type: str
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
_B64_RE = re.compile(r"^[A-Za-z0-9+/=\s]+$")
|
| 178 |
|
| 179 |
|
| 180 |
-
def
|
| 181 |
"""
|
| 182 |
-
|
| 183 |
-
|
| 184 |
"""
|
| 185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
|
| 187 |
-
# handle data URL: data:audio/wav;base64,....
|
| 188 |
-
if ss.startswith("data:") and "base64," in ss:
|
| 189 |
try:
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
|
| 195 |
-
|
| 196 |
-
|
|
|
|
| 197 |
try:
|
| 198 |
-
|
| 199 |
except Exception:
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
return None
|
| 203 |
|
| 204 |
|
| 205 |
-
def
|
| 206 |
"""
|
| 207 |
-
|
| 208 |
-
Supported forms:
|
| 209 |
-
- bytes / bytearray
|
| 210 |
-
- base64 string (optionally data URL)
|
| 211 |
-
- dict {"inputs": <bytes|base64|string|dict>}
|
| 212 |
-
- dict with keys: "audio"/"data"/"content" containing bytes or base64
|
| 213 |
-
- dict may include "filename", "content_type"
|
| 214 |
"""
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
if isinstance(data, (bytes, bytearray)):
|
| 219 |
-
return MediaPayload(raw_bytes=bytes(data), filename=filename, content_type=content_type)
|
| 220 |
-
|
| 221 |
-
if isinstance(data, str):
|
| 222 |
-
decoded = _maybe_base64_decode(data)
|
| 223 |
-
if decoded is None:
|
| 224 |
-
raise HandlerError(
|
| 225 |
-
"String input must be base64 (or data:...;base64,...). "
|
| 226 |
-
"If you're sending JSON, wrap your bytes in base64."
|
| 227 |
-
)
|
| 228 |
-
return MediaPayload(raw_bytes=decoded, filename=filename, content_type=content_type)
|
| 229 |
-
|
| 230 |
-
if not isinstance(data, dict):
|
| 231 |
-
raise HandlerError("Unsupported input type. Send bytes, base64 string, or a JSON object.")
|
| 232 |
-
|
| 233 |
-
# Pull metadata if present
|
| 234 |
-
filename = str(data.get("filename") or data.get("name") or filename)
|
| 235 |
-
content_type = str(data.get("content_type") or data.get("mime_type") or content_type)
|
| 236 |
-
|
| 237 |
-
# Common HF: {"inputs": ...}
|
| 238 |
-
if "inputs" in data:
|
| 239 |
-
inner = data["inputs"]
|
| 240 |
-
# If inputs is a dict with richer structure
|
| 241 |
-
if isinstance(inner, dict):
|
| 242 |
-
# allow nested metadata
|
| 243 |
-
filename = str(inner.get("filename") or inner.get("name") or filename)
|
| 244 |
-
content_type = str(inner.get("content_type") or inner.get("mime_type") or content_type)
|
| 245 |
-
for k in ("data", "audio", "content", "bytes"):
|
| 246 |
-
if k in inner:
|
| 247 |
-
return _coerce_to_media_payload({**inner, "data": inner[k], "filename": filename, "content_type": content_type})
|
| 248 |
-
# If inputs is bytes/base64
|
| 249 |
-
return _coerce_to_media_payload(inner)
|
| 250 |
-
|
| 251 |
-
# Other common keys
|
| 252 |
-
for k in ("audio", "data", "content", "bytes"):
|
| 253 |
-
if k in data:
|
| 254 |
-
v = data[k]
|
| 255 |
-
if isinstance(v, (bytes, bytearray)):
|
| 256 |
-
return MediaPayload(raw_bytes=bytes(v), filename=filename, content_type=content_type)
|
| 257 |
-
if isinstance(v, str):
|
| 258 |
-
decoded = _maybe_base64_decode(v)
|
| 259 |
-
if decoded is None:
|
| 260 |
-
raise HandlerError(f"Field `{k}` is a string but not base64/data-url.")
|
| 261 |
-
return MediaPayload(raw_bytes=decoded, filename=filename, content_type=content_type)
|
| 262 |
-
if isinstance(v, dict):
|
| 263 |
-
# nested object containing base64
|
| 264 |
-
return _coerce_to_media_payload(v)
|
| 265 |
-
|
| 266 |
-
raise HandlerError(
|
| 267 |
-
"Could not find media bytes in request. "
|
| 268 |
-
"Provide bytes directly, a base64 string, or a JSON object with `inputs` or `audio`/`data`."
|
| 269 |
-
)
|
| 270 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
|
| 272 |
-
# -----------------------------
|
| 273 |
-
# Media conversion: any -> wav
|
| 274 |
-
# -----------------------------
|
| 275 |
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
|
| 282 |
|
| 283 |
-
|
| 284 |
-
media_bytes: bytes,
|
| 285 |
-
*,
|
| 286 |
-
ffmpeg_path: str,
|
| 287 |
-
target_sr: int = 16000,
|
| 288 |
-
target_channels: int = 1,
|
| 289 |
-
output_pcm: str = "s16le",
|
| 290 |
-
) -> str:
|
| 291 |
"""
|
| 292 |
-
|
| 293 |
-
Returns a temp WAV file path that the caller should delete.
|
| 294 |
"""
|
| 295 |
-
# Work inside a temp dir, then copy output to a NamedTemporaryFile outside the context
|
| 296 |
-
with tempfile.TemporaryDirectory() as d:
|
| 297 |
-
in_path = _write_temp_file(d, "input_media.bin", media_bytes)
|
| 298 |
-
out_path = os.path.join(d, "output.wav")
|
| 299 |
-
|
| 300 |
-
cmd = [
|
| 301 |
-
ffmpeg_path,
|
| 302 |
-
"-y",
|
| 303 |
-
"-hide_banner",
|
| 304 |
-
"-loglevel",
|
| 305 |
-
"error",
|
| 306 |
-
"-i",
|
| 307 |
-
in_path,
|
| 308 |
-
"-vn",
|
| 309 |
-
"-ac",
|
| 310 |
-
str(target_channels),
|
| 311 |
-
"-ar",
|
| 312 |
-
str(target_sr),
|
| 313 |
-
"-acodec",
|
| 314 |
-
f"pcm_{output_pcm}",
|
| 315 |
-
out_path,
|
| 316 |
-
]
|
| 317 |
-
_run(cmd, timeout_s=120)
|
| 318 |
-
|
| 319 |
-
tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
| 320 |
-
tmp_path = tmp.name
|
| 321 |
-
tmp.close()
|
| 322 |
-
|
| 323 |
-
with open(out_path, "rb") as src, open(tmp_path, "wb") as dst:
|
| 324 |
-
dst.write(src.read())
|
| 325 |
-
|
| 326 |
-
return tmp_path
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
def _read_file_bytes(path: str) -> bytes:
|
| 330 |
-
with open(path, "rb") as f:
|
| 331 |
-
return f.read()
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
# -----------------------------
|
| 335 |
-
# Output helpers
|
| 336 |
-
# -----------------------------
|
| 337 |
-
|
| 338 |
-
def _as_base64(b: bytes) -> str:
|
| 339 |
-
return base64.b64encode(b).decode("utf-8")
|
| 340 |
|
|
|
|
|
|
|
| 341 |
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
extra: Optional[Dict[str, Any]] = None,
|
| 347 |
-
diagnostics: Optional[Dict[str, Any]] = None,
|
| 348 |
-
) -> Dict[str, Any]:
|
| 349 |
-
out: Dict[str, Any] = {"ok": ok, "text": text}
|
| 350 |
-
if extra:
|
| 351 |
-
out.update(extra)
|
| 352 |
-
if diagnostics:
|
| 353 |
-
out["diagnostics"] = diagnostics
|
| 354 |
-
return out
|
| 355 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
|
| 357 |
-
#
|
| 358 |
-
|
| 359 |
-
|
|
|
|
|
|
|
| 360 |
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
|
|
|
|
|
|
|
|
|
| 364 |
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
- Runs ASR (example: Transformers pipeline)
|
| 369 |
-
- Returns text and optional timing/diagnostics
|
| 370 |
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
- Or set `self.asr` to your pipeline/model in __init__
|
| 374 |
-
"""
|
| 375 |
|
| 376 |
-
def
|
| 377 |
-
self.model_path = path or ""
|
| 378 |
-
self.ffmpeg_path = _get_ffmpeg_path()
|
| 379 |
-
self.ffprobe_path = _get_ffprobe_path(self.ffmpeg_path)
|
| 380 |
-
|
| 381 |
-
# Settings (can be overridden per-request)
|
| 382 |
-
self.default_sr = int(os.getenv("TARGET_SAMPLE_RATE", "16000"))
|
| 383 |
-
self.default_channels = int(os.getenv("TARGET_CHANNELS", "1"))
|
| 384 |
-
|
| 385 |
-
# Optional: initialize an ASR pipeline if transformers is available.
|
| 386 |
-
# If you're not doing ASR, delete this and implement your task.
|
| 387 |
-
self.asr = None
|
| 388 |
-
if pipeline is not None:
|
| 389 |
-
# If your repo contains a model, `path` is typically the model directory.
|
| 390 |
-
# If not, you can hardcode a model id here.
|
| 391 |
-
model_id_or_path = self.model_path if self.model_path else os.getenv("ASR_MODEL_ID", "").strip()
|
| 392 |
-
if model_id_or_path:
|
| 393 |
-
# You can choose task="automatic-speech-recognition"
|
| 394 |
-
# and pass device_map / torch_dtype in advanced setups.
|
| 395 |
-
self.asr = pipeline("automatic-speech-recognition", model=model_id_or_path)
|
| 396 |
-
|
| 397 |
-
# Startup self-test (fast fail)
|
| 398 |
-
_run([self.ffmpeg_path, "-version"], timeout_s=10)
|
| 399 |
-
|
| 400 |
-
def __call__(self, data: Union[Dict[str, Any], bytes, str]) -> Dict[str, Any]:
|
| 401 |
t0 = _now_ms()
|
| 402 |
|
| 403 |
-
# Parse request
|
| 404 |
try:
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
return _response(ok=False, text=str(e), diagnostics={"stage": "parse"})
|
| 408 |
-
|
| 409 |
-
# Allow per-request overrides
|
| 410 |
-
req: Dict[str, Any] = data if isinstance(data, dict) else {}
|
| 411 |
-
target_sr = int(req.get("target_sr") or req.get("sample_rate") or self.default_sr)
|
| 412 |
-
target_channels = int(req.get("target_channels") or req.get("channels") or self.default_channels)
|
| 413 |
-
|
| 414 |
-
# Convert to wav
|
| 415 |
-
wav_path = ""
|
| 416 |
-
try:
|
| 417 |
-
wav_path = _convert_to_wav(
|
| 418 |
-
payload.raw_bytes,
|
| 419 |
-
ffmpeg_path=self.ffmpeg_path,
|
| 420 |
-
target_sr=target_sr,
|
| 421 |
-
target_channels=target_channels,
|
| 422 |
-
)
|
| 423 |
t1 = _now_ms()
|
| 424 |
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
"timing_ms": {
|
| 446 |
-
"total":
|
| 447 |
-
"
|
| 448 |
-
"
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
"target_sr": target_sr,
|
| 452 |
-
"target_channels": target_channels,
|
| 453 |
},
|
|
|
|
|
|
|
|
|
|
| 454 |
},
|
| 455 |
-
|
| 456 |
|
| 457 |
except Exception as e:
|
| 458 |
-
return
|
| 459 |
-
ok
|
| 460 |
-
|
| 461 |
-
diagnostics
|
| 462 |
-
"
|
| 463 |
-
"
|
| 464 |
},
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
|
| 473 |
-
# ----------------------------
|
| 474 |
-
#
|
| 475 |
-
# ----------------------------
|
| 476 |
|
| 477 |
-
def
|
| 478 |
"""
|
| 479 |
-
|
| 480 |
-
|
| 481 |
|
| 482 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 483 |
"""
|
| 484 |
-
if self.
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
f"
|
| 489 |
-
"ASR pipeline not configured. Set ASR_MODEL_ID env var or pass a model path.",
|
| 490 |
-
{"type": "none", "note": "no ASR pipeline"},
|
| 491 |
)
|
| 492 |
|
| 493 |
-
#
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
if "chunk_length_s" in req:
|
| 497 |
-
asr_kwargs["chunk_length_s"] = float(req["chunk_length_s"])
|
| 498 |
-
if "stride_length_s" in req:
|
| 499 |
-
asr_kwargs["stride_length_s"] = float(req["stride_length_s"])
|
| 500 |
|
| 501 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 502 |
|
| 503 |
-
#
|
| 504 |
-
#
|
| 505 |
-
|
| 506 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
else:
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 513 |
}
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
+
handler.py — Hugging Face Inference Endpoint custom handler
|
| 3 |
+
Outputs: GIF, WebM, ZIP(frames)
|
| 4 |
+
|
| 5 |
+
Key points:
|
| 6 |
+
- No "huggingface_inference_toolkit plugin" usage at all.
|
| 7 |
+
- WebM encoding uses imageio + imageio-ffmpeg (ffmpeg binary resolved internally).
|
| 8 |
+
- GIF encoding uses Pillow (no ffmpeg needed).
|
| 9 |
+
- ZIP output is a zip of PNG frames.
|
| 10 |
+
|
| 11 |
+
Request JSON (examples):
|
| 12 |
+
{
|
| 13 |
+
"prompt": "a cinematic shot of a hawk flying over snowy mountains",
|
| 14 |
+
"negative_prompt": "low quality, blurry",
|
| 15 |
+
"num_frames": 48,
|
| 16 |
+
"fps": 16,
|
| 17 |
+
"height": 512,
|
| 18 |
+
"width": 512,
|
| 19 |
+
"seed": 123,
|
| 20 |
+
"outputs": ["gif", "webm", "zip"], // any subset
|
| 21 |
+
"return_base64": true, // default true
|
| 22 |
+
"gif": {"fps": 12}, // optional overrides
|
| 23 |
+
"webm": {"fps": 24, "quality": "good"}, // quality: "fast"|"good"|"best"
|
| 24 |
+
"zip": {"format": "png"} // currently png only
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
Response JSON:
|
| 28 |
+
{
|
| 29 |
+
"ok": true,
|
| 30 |
+
"diagnostics": {...},
|
| 31 |
+
"outputs": {
|
| 32 |
+
"gif_base64": "...",
|
| 33 |
+
"webm_base64": "...",
|
| 34 |
+
"zip_base64": "..."
|
| 35 |
+
}
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
Notes on payload sizes:
|
| 39 |
+
- base64 video payloads can be large. For production, consider uploading to R2/S3
|
| 40 |
+
and returning a URL instead. This handler keeps it in-response for simplicity.
|
| 41 |
"""
|
| 42 |
|
| 43 |
from __future__ import annotations
|
| 44 |
|
| 45 |
import base64
|
| 46 |
import io
|
|
|
|
| 47 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
import time
|
| 49 |
+
import tempfile
|
| 50 |
+
import zipfile
|
| 51 |
from dataclasses import dataclass
|
| 52 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 53 |
|
| 54 |
+
import numpy as np
|
| 55 |
+
from PIL import Image
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
+
# WebM encoding (uses ffmpeg resolved by imageio-ffmpeg)
|
| 58 |
+
import imageio
|
| 59 |
+
|
| 60 |
+
# Ensure imageio uses the packaged ffmpeg binary (not HF toolkit plugins)
|
| 61 |
try:
|
| 62 |
+
import imageio_ffmpeg # type: ignore
|
| 63 |
+
_FFMPEG_EXE = imageio_ffmpeg.get_ffmpeg_exe()
|
| 64 |
+
# imageio reads this env var to locate ffmpeg
|
| 65 |
+
os.environ["IMAGEIO_FFMPEG_EXE"] = _FFMPEG_EXE
|
| 66 |
except Exception:
|
| 67 |
+
_FFMPEG_EXE = ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
|
| 70 |
def _now_ms() -> int:
|
| 71 |
return int(time.time() * 1000)
|
| 72 |
|
| 73 |
|
| 74 |
+
def _b64(data: bytes) -> str:
|
| 75 |
+
return base64.b64encode(data).decode("utf-8")
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
|
| 78 |
+
def _clamp_uint8_frame(frame: np.ndarray) -> np.ndarray:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
"""
|
| 80 |
+
Ensure frame is uint8 HxWx3 (RGB).
|
| 81 |
+
Accepts:
|
| 82 |
+
- float in [0,1] or [-1,1]
|
| 83 |
+
- uint8 already
|
| 84 |
+
- grayscale (HxW) -> RGB
|
| 85 |
+
- RGBA -> RGB
|
| 86 |
"""
|
| 87 |
+
if not isinstance(frame, np.ndarray):
|
| 88 |
+
frame = np.array(frame)
|
| 89 |
+
|
| 90 |
+
# squeeze batch-like dims if present (best-effort)
|
| 91 |
+
if frame.ndim == 4 and frame.shape[0] == 1:
|
| 92 |
+
frame = frame[0]
|
| 93 |
+
|
| 94 |
+
if frame.ndim == 2:
|
| 95 |
+
frame = np.stack([frame, frame, frame], axis=-1)
|
| 96 |
+
|
| 97 |
+
if frame.ndim != 3:
|
| 98 |
+
raise ValueError(f"Frame must be HxW, HxWxC, or 1xHxWxC; got shape {frame.shape}")
|
| 99 |
+
|
| 100 |
+
# Channels fixups
|
| 101 |
+
if frame.shape[-1] == 4:
|
| 102 |
+
frame = frame[..., :3]
|
| 103 |
+
elif frame.shape[-1] == 1:
|
| 104 |
+
frame = np.repeat(frame, 3, axis=-1)
|
| 105 |
+
elif frame.shape[-1] != 3:
|
| 106 |
+
# sometimes CxHxW
|
| 107 |
+
if frame.shape[0] == 3 and frame.ndim == 3:
|
| 108 |
+
frame = np.transpose(frame, (1, 2, 0))
|
| 109 |
+
else:
|
| 110 |
+
raise ValueError(f"Unsupported channel dimension: {frame.shape}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
+
if frame.dtype == np.uint8:
|
| 113 |
+
return frame
|
| 114 |
|
| 115 |
+
# Convert float -> uint8
|
| 116 |
+
f = frame.astype(np.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
+
# If looks like [-1,1], map to [0,1]
|
| 119 |
+
if f.min() < 0.0:
|
| 120 |
+
f = (f + 1.0) / 2.0
|
| 121 |
+
|
| 122 |
+
f = np.clip(f, 0.0, 1.0)
|
| 123 |
+
return (f * 255.0).round().astype(np.uint8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
|
| 126 |
+
def _encode_gif(frames: List[np.ndarray], fps: int) -> bytes:
|
| 127 |
"""
|
| 128 |
+
Encode GIF using Pillow (no ffmpeg dependency).
|
|
|
|
| 129 |
"""
|
| 130 |
+
if not frames:
|
| 131 |
+
raise ValueError("No frames to encode.")
|
| 132 |
+
pil_frames = [Image.fromarray(_clamp_uint8_frame(f)) for f in frames]
|
| 133 |
+
duration_ms = int(1000 / max(1, fps))
|
| 134 |
+
|
| 135 |
+
buf = io.BytesIO()
|
| 136 |
+
pil_frames[0].save(
|
| 137 |
+
buf,
|
| 138 |
+
format="GIF",
|
| 139 |
+
save_all=True,
|
| 140 |
+
append_images=pil_frames[1:],
|
| 141 |
+
duration=duration_ms,
|
| 142 |
+
loop=0,
|
| 143 |
+
optimize=False,
|
| 144 |
+
disposal=2,
|
| 145 |
+
)
|
| 146 |
+
return buf.getvalue()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
|
| 149 |
+
def _encode_webm(frames: List[np.ndarray], fps: int, quality: str = "good") -> bytes:
|
| 150 |
"""
|
| 151 |
+
Encode WebM using imageio (ffmpeg under the hood via imageio-ffmpeg).
|
| 152 |
+
quality: "fast" | "good" | "best"
|
| 153 |
"""
|
| 154 |
+
if not frames:
|
| 155 |
+
raise ValueError("No frames to encode.")
|
| 156 |
+
|
| 157 |
+
# Choose VP9 settings. These are pragmatic defaults.
|
| 158 |
+
# For smaller file sizes: lower bitrate or higher crf.
|
| 159 |
+
# For quality: lower crf (but larger files).
|
| 160 |
+
quality = (quality or "good").lower()
|
| 161 |
+
if quality == "fast":
|
| 162 |
+
crf = 42
|
| 163 |
+
preset = "veryfast"
|
| 164 |
+
elif quality == "best":
|
| 165 |
+
crf = 28
|
| 166 |
+
preset = "slow"
|
| 167 |
+
else:
|
| 168 |
+
crf = 34
|
| 169 |
+
preset = "medium"
|
| 170 |
+
|
| 171 |
+
# Write to a temp file then return bytes
|
| 172 |
+
with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as tmp:
|
| 173 |
+
out_path = tmp.name
|
| 174 |
+
|
| 175 |
+
try:
|
| 176 |
+
writer = imageio.get_writer(
|
| 177 |
+
out_path,
|
| 178 |
+
fps=max(1, fps),
|
| 179 |
+
format="FFMPEG",
|
| 180 |
+
codec="libvpx-vp9",
|
| 181 |
+
# ffmpeg_params are passed through to ffmpeg invocation
|
| 182 |
+
ffmpeg_params=[
|
| 183 |
+
"-pix_fmt", "yuv420p",
|
| 184 |
+
"-crf", str(crf),
|
| 185 |
+
"-b:v", "0",
|
| 186 |
+
"-preset", preset,
|
| 187 |
+
],
|
| 188 |
+
)
|
| 189 |
|
|
|
|
|
|
|
| 190 |
try:
|
| 191 |
+
for f in frames:
|
| 192 |
+
writer.append_data(_clamp_uint8_frame(f))
|
| 193 |
+
finally:
|
| 194 |
+
writer.close()
|
| 195 |
|
| 196 |
+
with open(out_path, "rb") as f:
|
| 197 |
+
return f.read()
|
| 198 |
+
finally:
|
| 199 |
try:
|
| 200 |
+
os.remove(out_path)
|
| 201 |
except Exception:
|
| 202 |
+
pass
|
|
|
|
|
|
|
| 203 |
|
| 204 |
|
| 205 |
+
def _encode_zip_frames(frames: List[np.ndarray]) -> bytes:
|
| 206 |
"""
|
| 207 |
+
Zip frames as PNG images: frame_000000.png, ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
"""
|
| 209 |
+
if not frames:
|
| 210 |
+
raise ValueError("No frames to zip.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
+
buf = io.BytesIO()
|
| 213 |
+
with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=6) as zf:
|
| 214 |
+
for i, f in enumerate(frames):
|
| 215 |
+
arr = _clamp_uint8_frame(f)
|
| 216 |
+
im = Image.fromarray(arr)
|
| 217 |
+
frame_buf = io.BytesIO()
|
| 218 |
+
im.save(frame_buf, format="PNG", optimize=True)
|
| 219 |
+
zf.writestr(f"frame_{i:06d}.png", frame_buf.getvalue())
|
| 220 |
+
return buf.getvalue()
|
| 221 |
|
|
|
|
|
|
|
|
|
|
| 222 |
|
| 223 |
+
@dataclass
|
| 224 |
+
class GenParams:
|
| 225 |
+
prompt: str
|
| 226 |
+
negative_prompt: str
|
| 227 |
+
num_frames: int
|
| 228 |
+
fps: int
|
| 229 |
+
height: int
|
| 230 |
+
width: int
|
| 231 |
+
seed: Optional[int]
|
| 232 |
|
| 233 |
|
| 234 |
+
class EndpointHandler:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
"""
|
| 236 |
+
Custom handler entrypoint for Hugging Face Inference Endpoints.
|
|
|
|
| 237 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
|
| 239 |
+
def __init__(self, path: str = "") -> None:
|
| 240 |
+
self.repo_path = path or ""
|
| 241 |
|
| 242 |
+
# Attempt to initialize a diffusers pipeline if available.
|
| 243 |
+
# If your repo uses a different entrypoint, edit `_generate_frames()`.
|
| 244 |
+
self.pipe = None
|
| 245 |
+
self._init_error = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
+
try:
|
| 248 |
+
import torch # type: ignore
|
| 249 |
+
from diffusers import DiffusionPipeline # type: ignore
|
| 250 |
+
|
| 251 |
+
# Prefer fp16 if CUDA is available; otherwise float32.
|
| 252 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 253 |
+
dtype = torch.float16 if device == "cuda" else torch.float32
|
| 254 |
+
|
| 255 |
+
# Load from repository path (the model code/checkpoints are in the repo)
|
| 256 |
+
# This is the most generic path for "diffusers-like" repos.
|
| 257 |
+
self.pipe = DiffusionPipeline.from_pretrained(
|
| 258 |
+
self.repo_path if self.repo_path else None,
|
| 259 |
+
torch_dtype=dtype,
|
| 260 |
+
)
|
| 261 |
|
| 262 |
+
# Move to device if possible
|
| 263 |
+
try:
|
| 264 |
+
self.pipe.to(device)
|
| 265 |
+
except Exception:
|
| 266 |
+
pass
|
| 267 |
|
| 268 |
+
# Some pipelines benefit from enabling memory optimizations
|
| 269 |
+
try:
|
| 270 |
+
if hasattr(self.pipe, "enable_vae_slicing"):
|
| 271 |
+
self.pipe.enable_vae_slicing()
|
| 272 |
+
except Exception:
|
| 273 |
+
pass
|
| 274 |
|
| 275 |
+
except Exception as e:
|
| 276 |
+
self._init_error = str(e)
|
| 277 |
+
self.pipe = None
|
|
|
|
|
|
|
| 278 |
|
| 279 |
+
# Quick diagnostic: ensure imageio-ffmpeg resolved (for WebM)
|
| 280 |
+
self.ffmpeg_exe = _FFMPEG_EXE
|
|
|
|
|
|
|
| 281 |
|
| 282 |
+
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
t0 = _now_ms()
|
| 284 |
|
|
|
|
| 285 |
try:
|
| 286 |
+
params, outputs, return_b64, out_cfg = self._parse_request(data)
|
| 287 |
+
frames, gen_diag = self._generate_frames(params, out_cfg=out_cfg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
t1 = _now_ms()
|
| 289 |
|
| 290 |
+
result_outputs: Dict[str, Any] = {}
|
| 291 |
+
|
| 292 |
+
# GIF
|
| 293 |
+
if "gif" in outputs:
|
| 294 |
+
gif_fps = int((out_cfg.get("gif") or {}).get("fps") or params.fps)
|
| 295 |
+
gif_bytes = _encode_gif(frames, fps=gif_fps)
|
| 296 |
+
result_outputs["gif_base64" if return_b64 else "gif_bytes"] = _b64(gif_bytes) if return_b64 else gif_bytes
|
| 297 |
+
t_gif = _now_ms()
|
| 298 |
+
else:
|
| 299 |
+
t_gif = t1
|
| 300 |
+
|
| 301 |
+
# WebM
|
| 302 |
+
if "webm" in outputs:
|
| 303 |
+
webm_cfg = out_cfg.get("webm") or {}
|
| 304 |
+
webm_fps = int(webm_cfg.get("fps") or params.fps)
|
| 305 |
+
webm_quality = str(webm_cfg.get("quality") or "good")
|
| 306 |
+
webm_bytes = _encode_webm(frames, fps=webm_fps, quality=webm_quality)
|
| 307 |
+
result_outputs["webm_base64" if return_b64 else "webm_bytes"] = _b64(webm_bytes) if return_b64 else webm_bytes
|
| 308 |
+
t_webm = _now_ms()
|
| 309 |
+
else:
|
| 310 |
+
t_webm = t_gif
|
| 311 |
+
|
| 312 |
+
# ZIP frames
|
| 313 |
+
if "zip" in outputs:
|
| 314 |
+
zip_bytes = _encode_zip_frames(frames)
|
| 315 |
+
result_outputs["zip_base64" if return_b64 else "zip_bytes"] = _b64(zip_bytes) if return_b64 else zip_bytes
|
| 316 |
+
t_zip = _now_ms()
|
| 317 |
+
else:
|
| 318 |
+
t_zip = t_webm
|
| 319 |
+
|
| 320 |
+
return {
|
| 321 |
+
"ok": True,
|
| 322 |
+
"outputs": result_outputs,
|
| 323 |
+
"diagnostics": {
|
| 324 |
"timing_ms": {
|
| 325 |
+
"total": t_zip - t0,
|
| 326 |
+
"generate": t1 - t0,
|
| 327 |
+
"gif": (t_gif - t1) if "gif" in outputs else 0,
|
| 328 |
+
"webm": (t_webm - t_gif) if "webm" in outputs else 0,
|
| 329 |
+
"zip": (t_zip - t_webm) if "zip" in outputs else 0,
|
|
|
|
|
|
|
| 330 |
},
|
| 331 |
+
"generator": gen_diag,
|
| 332 |
+
"ffmpeg_exe": self.ffmpeg_exe,
|
| 333 |
+
"init_error": self._init_error,
|
| 334 |
},
|
| 335 |
+
}
|
| 336 |
|
| 337 |
except Exception as e:
|
| 338 |
+
return {
|
| 339 |
+
"ok": False,
|
| 340 |
+
"error": str(e),
|
| 341 |
+
"diagnostics": {
|
| 342 |
+
"ffmpeg_exe": self.ffmpeg_exe,
|
| 343 |
+
"init_error": self._init_error,
|
| 344 |
},
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
# ----------------------------
|
| 348 |
+
# Request parsing
|
| 349 |
+
# ----------------------------
|
| 350 |
+
|
| 351 |
+
def _parse_request(self, data: Dict[str, Any]) -> Tuple[GenParams, List[str], bool, Dict[str, Any]]:
|
| 352 |
+
if not isinstance(data, dict):
|
| 353 |
+
raise ValueError("Request must be a JSON object.")
|
| 354 |
+
|
| 355 |
+
prompt = str(data.get("prompt") or data.get("inputs") or "").strip()
|
| 356 |
+
if not prompt:
|
| 357 |
+
raise ValueError("Missing `prompt` (or `inputs`).")
|
| 358 |
+
|
| 359 |
+
negative_prompt = str(data.get("negative_prompt") or "").strip()
|
| 360 |
+
|
| 361 |
+
num_frames = int(data.get("num_frames") or data.get("frames") or 32)
|
| 362 |
+
fps = int(data.get("fps") or 12)
|
| 363 |
+
height = int(data.get("height") or 512)
|
| 364 |
+
width = int(data.get("width") or 512)
|
| 365 |
+
seed = data.get("seed")
|
| 366 |
+
seed = int(seed) if seed is not None and str(seed).strip() != "" else None
|
| 367 |
+
|
| 368 |
+
outputs = data.get("outputs") or ["gif"]
|
| 369 |
+
if isinstance(outputs, str):
|
| 370 |
+
outputs = [outputs]
|
| 371 |
+
outputs = [str(x).lower() for x in outputs]
|
| 372 |
+
|
| 373 |
+
allowed = {"gif", "webm", "zip"}
|
| 374 |
+
outputs = [o for o in outputs if o in allowed]
|
| 375 |
+
if not outputs:
|
| 376 |
+
outputs = ["gif"]
|
| 377 |
+
|
| 378 |
+
return_b64 = bool(data.get("return_base64", True))
|
| 379 |
+
out_cfg = data.get("output_config") or {}
|
| 380 |
+
# also allow top-level gif/webm/zip config objects
|
| 381 |
+
for k in ("gif", "webm", "zip"):
|
| 382 |
+
if k in data and isinstance(data[k], dict):
|
| 383 |
+
out_cfg[k] = data[k]
|
| 384 |
+
|
| 385 |
+
params = GenParams(
|
| 386 |
+
prompt=prompt,
|
| 387 |
+
negative_prompt=negative_prompt,
|
| 388 |
+
num_frames=max(1, num_frames),
|
| 389 |
+
fps=max(1, fps),
|
| 390 |
+
height=max(64, height),
|
| 391 |
+
width=max(64, width),
|
| 392 |
+
seed=seed,
|
| 393 |
+
)
|
| 394 |
+
return params, outputs, return_b64, out_cfg
|
| 395 |
|
| 396 |
+
# ----------------------------
|
| 397 |
+
# Frame generation
|
| 398 |
+
# ----------------------------
|
| 399 |
|
| 400 |
+
def _generate_frames(self, params: GenParams, out_cfg: Dict[str, Any]) -> Tuple[List[np.ndarray], Dict[str, Any]]:
    """
    Generates frames as a list of uint8 RGB numpy arrays.
    This is the only place you should need to customize for your specific repo/model.

    Current implementation:
      - If a diffusers pipeline is available:
          pipe(prompt=..., negative_prompt=..., height=..., width=..., num_frames=...)
        Then tries common output fields: frames / videos / images.
      - Otherwise: raises with init details.

    If your model call is different (e.g., special args like guidance_scale, num_inference_steps),
    add them here.

    Args:
        params: Normalized generation parameters (prompt, size, frame count, seed).
        out_cfg: Merged output config; also carries optional sampler knobs
            `num_inference_steps` and `guidance_scale`.

    Returns:
        Tuple of (frames, diag) where `frames` is a non-empty list of uint8
        RGB numpy arrays and `diag` is a dict of diagnostics about the run.

    Raises:
        RuntimeError: if the pipeline is uninitialized, if every frame-count
            argument name fails, or if no frames can be extracted.
        ValueError: if a `videos` tensor has an unexpected shape.
    """
    if self.pipe is None:
        raise RuntimeError(
            "Model pipeline is not initialized. "
            "If your repo doesn't use diffusers DiffusionPipeline, edit _generate_frames(). "
            f"Init error: {self._init_error}"
        )

    # Optional sampler knobs (safe defaults). Note: `or` means explicit 0 falls
    # back to the default as well.
    num_inference_steps = int(out_cfg.get("num_inference_steps") or 30)
    guidance_scale = float(out_cfg.get("guidance_scale") or 7.5)

    # Seed (best effort): if torch is unavailable or generator creation fails,
    # we silently run unseeded rather than failing the request.
    generator = None
    try:
        import torch  # type: ignore
        if params.seed is not None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            generator = torch.Generator(device=device).manual_seed(params.seed)
    except Exception:
        generator = None

    # Call the pipeline (generic diffusers-style). None values are stripped
    # before the call so optional args don't override pipeline defaults.
    kwargs: Dict[str, Any] = {
        "prompt": params.prompt,
        "negative_prompt": params.negative_prompt if params.negative_prompt else None,
        "height": params.height,
        "width": params.width,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
    }

    # Common frame-count arg names across repos: some pipelines use num_frames,
    # some video_length, some num_video_frames. Try each until one succeeds.
    # NOTE(review): any exception (not just TypeError on the arg name) advances
    # to the next candidate, so a real inference error may be retried — the
    # last error is surfaced below if all attempts fail.
    called = False
    last_err: Optional[Exception] = None
    output = None
    for frame_arg in ("num_frames", "video_length", "num_video_frames"):
        try:
            call_kwargs = dict(kwargs)
            call_kwargs[frame_arg] = params.num_frames
            if generator is not None:
                call_kwargs["generator"] = generator
            output = self.pipe(**{k: v for k, v in call_kwargs.items() if v is not None})
            called = True
            break
        except Exception as e:
            last_err = e
            continue

    if not called:
        raise RuntimeError(f"Pipeline call failed for frame args. Last error: {last_err}")

    # Extract frames from common output structures.
    frames: List[np.ndarray] = []

    # diffusers outputs vary:
    #   - output.frames (list of PIL/np)
    #   - output.videos (tensor/np)
    #   - output.images (list of PIL for single-frame)
    if hasattr(output, "frames") and output.frames is not None:
        # NOTE(review): some diffusers video pipelines return `frames` as a
        # nested list (one list per batch item); np.array on such an element
        # would not be a single frame — confirm against the actual pipeline.
        frames_raw = output.frames
        frames = [np.array(f) for f in frames_raw]
    elif hasattr(output, "videos") and output.videos is not None:
        vids = output.videos
        arr = None

        # torch tensor or numpy — detach/move to CPU if it's a tensor.
        try:
            import torch  # type: ignore
            if isinstance(vids, torch.Tensor):
                arr = vids.detach().cpu().numpy()
            else:
                arr = np.array(vids)
        except Exception:
            arr = np.array(vids)

        # Common shapes:
        #   (B, T, C, H, W) or (B, T, H, W, C) or (T, H, W, C)
        if arr.ndim == 5:
            # pick first batch
            arr = arr[0]
        if arr.ndim == 4 and arr.shape[1] in (1, 3, 4):
            # likely (T, C, H, W) -> (T, H, W, C)
            # NOTE(review): heuristic — a genuine (T, H, W, C) video with
            # height 1/3/4 would be transposed incorrectly.
            arr = np.transpose(arr, (0, 2, 3, 1))
        if arr.ndim != 4:
            raise ValueError(f"Unexpected video tensor shape: {arr.shape}")

        frames = [arr[t] for t in range(arr.shape[0])]
    elif hasattr(output, "images") and output.images is not None:
        imgs = output.images
        # if it's just one image, treat as 1-frame "video"
        if isinstance(imgs, list):
            frames = [np.array(im) for im in imgs]
        else:
            frames = [np.array(imgs)]
    else:
        # final fallback: try dict-like output with the same three keys.
        if isinstance(output, dict):
            for key in ("frames", "videos", "images"):
                if key in output and output[key] is not None:
                    v = output[key]
                    if key == "videos":
                        # Same shape normalization as the attribute path above.
                        arr = np.array(v)
                        if arr.ndim == 5:
                            arr = arr[0]
                        if arr.ndim == 4 and arr.shape[1] in (1, 3, 4):
                            arr = np.transpose(arr, (0, 2, 3, 1))
                        frames = [arr[t] for t in range(arr.shape[0])]
                    else:
                        if isinstance(v, list):
                            frames = [np.array(x) for x in v]
                        else:
                            frames = [np.array(v)]
                    break

    if not frames:
        raise RuntimeError("Could not extract frames from pipeline output (no frames/videos/images found).")

    # Normalize each frame to uint8 RGB via the module-level helper.
    frames_u8 = [_clamp_uint8_frame(f) for f in frames]

    # Diagnostics returned to the caller alongside the frames.
    diag = {
        "prompt_len": len(params.prompt),
        "negative_prompt_len": len(params.negative_prompt),
        "num_frames": len(frames_u8),
        "height": int(frames_u8[0].shape[0]),
        "width": int(frames_u8[0].shape[1]),
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "seed": params.seed,
    }
    return frames_u8, diag