qhillerich committed on
Commit
a2cbd86
·
verified ·
1 Parent(s): d37ebe0

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +56 -27
handler.py CHANGED
@@ -2,18 +2,10 @@
2
  handler.py — Hugging Face Inference Endpoint custom handler
3
  Outputs: GIF, WebM, ZIP(frames)
4
 
5
- This version includes:
6
- - Defensive patches to avoid hosted runtime failing with:
7
- "`ffmpeg` is not a registered plugin name."
8
- - Robust frame extraction for shapes like (T,H,W,C), (B,T,H,W,C), (T,C,H,W), (B,T,C,H,W)
9
- - Output encoders:
10
- - GIF: Pillow only (no ffmpeg)
11
- - ZIP: PNG frames zipped (no ffmpeg)
12
- - WebM: imageio + imageio-ffmpeg via IMAGEIO_FFMPEG_EXE env var (NO executable= arg)
13
-
14
- IMPORTANT:
15
- - HF gateway often requires top-level { "inputs": {...} }.
16
- - Send requests wrapped in "inputs".
17
  """
18
 
19
  from __future__ import annotations
@@ -137,6 +129,13 @@ def _b64(data: bytes) -> str:
137
  return base64.b64encode(data).decode("utf-8")
138
 
139
 
 
 
 
 
 
 
 
140
  def _clamp_uint8_frame(frame: np.ndarray) -> np.ndarray:
141
  """
142
  Normalize a frame into uint8 RGB (H,W,3).
@@ -200,10 +199,6 @@ def _encode_gif(frames: List[np.ndarray], fps: int) -> bytes:
200
  def _encode_webm(frames: List[np.ndarray], fps: int, quality: str = "good") -> bytes:
201
  """
202
  Encode WebM (VP9) via imageio.
203
-
204
- IMPORTANT:
205
- - Do NOT pass executable=...; HF's imageio build can reject that parameter.
206
- - We rely on IMAGEIO_FFMPEG_EXE env var set at import time.
207
  """
208
  if not frames:
209
  raise ValueError("No frames to encode WebM.")
@@ -280,6 +275,7 @@ class GenParams:
280
  seed: Optional[int]
281
  num_inference_steps: int
282
  guidance_scale: float
 
283
 
284
 
285
  def _unwrap_inputs(payload: Dict[str, Any]) -> Dict[str, Any]:
@@ -292,8 +288,8 @@ def _parse_request(payload: Dict[str, Any]) -> Tuple[GenParams, List[str], bool,
292
  data = _unwrap_inputs(payload)
293
 
294
  prompt = str(data.get("prompt") or data.get("inputs") or "").strip()
295
- if not prompt:
296
- raise ValueError("Missing `prompt` (or `inputs`).")
297
 
298
  negative_prompt = str(data.get("negative_prompt") or "").strip()
299
 
@@ -304,6 +300,9 @@ def _parse_request(payload: Dict[str, Any]) -> Tuple[GenParams, List[str], bool,
304
  seed = data.get("seed")
305
  seed = int(seed) if seed is not None and str(seed).strip() != "" else None
306
 
 
 
 
307
  num_inference_steps = int(data.get("num_inference_steps") or 30)
308
  guidance_scale = float(data.get("guidance_scale") or 7.5)
309
 
@@ -333,6 +332,7 @@ def _parse_request(payload: Dict[str, Any]) -> Tuple[GenParams, List[str], bool,
333
  seed=seed,
334
  num_inference_steps=max(1, num_inference_steps),
335
  guidance_scale=guidance_scale,
 
336
  )
337
  return params, outputs, return_base64, out_cfg
338
 
@@ -347,13 +347,13 @@ class EndpointHandler:
347
  self.pipe = None
348
  self.init_error: Optional[str] = None
349
 
350
- print("=== CUSTOM handler.py LOADED (webm uses IMAGEIO_FFMPEG_EXE only) ===", flush=True)
351
  print(f"=== HF toolkit patch diag: {HF_TOOLKIT_PATCH_DIAG} ===", flush=True)
352
  print(f"=== imageio-ffmpeg exe: {_FFMPEG_EXE} ===", flush=True)
353
 
354
  try:
355
  import torch # type: ignore
356
- from diffusers import DiffusionPipeline # type: ignore
357
 
358
  device = "cuda" if torch.cuda.is_available() else "cpu"
359
  dtype = torch.float16 if device == "cuda" else torch.float32
@@ -361,7 +361,14 @@ class EndpointHandler:
361
  subdir = os.getenv("HF_MODEL_SUBDIR", "").strip()
362
  model_path = self.repo_path if not subdir else os.path.join(self.repo_path, subdir)
363
 
364
- self.pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype)
 
 
 
 
 
 
 
365
 
366
  try:
367
  self.pipe.to(device)
@@ -373,6 +380,10 @@ class EndpointHandler:
373
  self.pipe.enable_vae_slicing()
374
  except Exception:
375
  pass
 
 
 
 
376
 
377
  except Exception as e:
378
  self.init_error = str(e)
@@ -434,6 +445,8 @@ class EndpointHandler:
434
  }
435
 
436
  except Exception as e:
 
 
437
  return {
438
  "ok": False,
439
  "error": str(e),
@@ -469,21 +482,37 @@ class EndpointHandler:
469
  "width": params.width,
470
  "num_inference_steps": params.num_inference_steps,
471
  "guidance_scale": params.guidance_scale,
 
472
  }
 
 
 
 
 
 
 
 
473
 
474
  # Try common frame arg names across video pipelines
475
  output = None
476
  last_err: Optional[Exception] = None
 
 
477
  for frame_arg in ("num_frames", "video_length", "num_video_frames"):
478
  try:
479
  call_kwargs = dict(kwargs)
480
  call_kwargs[frame_arg] = params.num_frames
481
  if generator is not None:
482
  call_kwargs["generator"] = generator
483
- output = self.pipe(**{k: v for k, v in call_kwargs.items() if v is not None})
 
 
 
 
484
  break
485
  except Exception as e:
486
  last_err = e
 
487
  continue
488
 
489
  if output is None:
@@ -491,6 +520,8 @@ class EndpointHandler:
491
 
492
  frames: List[np.ndarray] = []
493
 
 
 
494
  # 1) output.frames — may be list OR ndarray/tensor-like
495
  if hasattr(output, "frames") and getattr(output, "frames") is not None:
496
  frames_raw = getattr(output, "frames")
@@ -537,7 +568,7 @@ class EndpointHandler:
537
 
538
  # 3) output.images (single frame or list)
539
  elif hasattr(output, "images") and getattr(output, "images") is not None:
540
- imgs = getattr(output, "images")
541
  if isinstance(imgs, list):
542
  frames = [np.array(im) for im in imgs]
543
  else:
@@ -584,8 +615,6 @@ class EndpointHandler:
584
  "num_frames": len(frames_u8),
585
  "height": int(frames_u8[0].shape[0]),
586
  "width": int(frames_u8[0].shape[1]),
587
- "num_inference_steps": params.num_inference_steps,
588
- "guidance_scale": params.guidance_scale,
589
- "seed": params.seed,
590
  }
591
- return frames_u8, diag
 
2
  handler.py — Hugging Face Inference Endpoint custom handler
3
  Outputs: GIF, WebM, ZIP(frames)
4
 
5
+ This version maintains UNIVERSAL compatibility:
6
+ - Defensive argument guessing (num_frames vs video_length)
7
+ - Robust output shape parsing (TBL, BCTHW, etc.)
8
+ - Adds support for Image-to-Video via an `image` input (base64-encoded)
 
 
 
 
 
 
 
 
9
  """
10
 
11
  from __future__ import annotations
 
129
  return base64.b64encode(data).decode("utf-8")
130
 
131
 
132
+ def _b64_to_pil(b64_str: str) -> Image.Image:
133
+ if "," in b64_str:
134
+ b64_str = b64_str.split(",")[1]
135
+ data = base64.b64decode(b64_str)
136
+ return Image.open(io.BytesIO(data)).convert("RGB")
137
+
138
+
139
  def _clamp_uint8_frame(frame: np.ndarray) -> np.ndarray:
140
  """
141
  Normalize a frame into uint8 RGB (H,W,3).
 
199
  def _encode_webm(frames: List[np.ndarray], fps: int, quality: str = "good") -> bytes:
200
  """
201
  Encode WebM (VP9) via imageio.
 
 
 
 
202
  """
203
  if not frames:
204
  raise ValueError("No frames to encode WebM.")
 
275
  seed: Optional[int]
276
  num_inference_steps: int
277
  guidance_scale: float
278
+ image_b64: Optional[str] = None
279
 
280
 
281
  def _unwrap_inputs(payload: Dict[str, Any]) -> Dict[str, Any]:
 
288
  data = _unwrap_inputs(payload)
289
 
290
  prompt = str(data.get("prompt") or data.get("inputs") or "").strip()
291
+ if not prompt and "image" not in data:
292
+ pass
293
 
294
  negative_prompt = str(data.get("negative_prompt") or "").strip()
295
 
 
300
  seed = data.get("seed")
301
  seed = int(seed) if seed is not None and str(seed).strip() != "" else None
302
 
303
+ # Image input for I2V
304
+ image_b64 = data.get("image") or data.get("image_base64")
305
+
306
  num_inference_steps = int(data.get("num_inference_steps") or 30)
307
  guidance_scale = float(data.get("guidance_scale") or 7.5)
308
 
 
332
  seed=seed,
333
  num_inference_steps=max(1, num_inference_steps),
334
  guidance_scale=guidance_scale,
335
+ image_b64=image_b64
336
  )
337
  return params, outputs, return_base64, out_cfg
338
 
 
347
  self.pipe = None
348
  self.init_error: Optional[str] = None
349
 
350
+ print("=== CUSTOM handler.py LOADED (Universal Mode) ===", flush=True)
351
  print(f"=== HF toolkit patch diag: {HF_TOOLKIT_PATCH_DIAG} ===", flush=True)
352
  print(f"=== imageio-ffmpeg exe: {_FFMPEG_EXE} ===", flush=True)
353
 
354
  try:
355
  import torch # type: ignore
356
+ from diffusers import DiffusionPipeline, LTXConditionPipeline
357
 
358
  device = "cuda" if torch.cuda.is_available() else "cpu"
359
  dtype = torch.float16 if device == "cuda" else torch.float32
 
361
  subdir = os.getenv("HF_MODEL_SUBDIR", "").strip()
362
  model_path = self.repo_path if not subdir else os.path.join(self.repo_path, subdir)
363
 
364
+ # --- Attempt to load LTXConditionPipeline first (for I2V Support) ---
365
+ # If that fails (e.g. model isn't LTX or diffusers version old), fallback to generic.
366
+ try:
367
+ print("Attempting to load LTXConditionPipeline...", flush=True)
368
+ self.pipe = LTXConditionPipeline.from_pretrained(model_path, torch_dtype=dtype)
369
+ except Exception as e:
370
+ print(f"LTXConditionPipeline load failed ({e}), falling back to generic DiffusionPipeline...", flush=True)
371
+ self.pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype)
372
 
373
  try:
374
  self.pipe.to(device)
 
380
  self.pipe.enable_vae_slicing()
381
  except Exception:
382
  pass
383
+
384
+ # Optimization for LTX / newer diffusers
385
+ if hasattr(self.pipe, "vae") and hasattr(self.pipe.vae, "enable_tiling"):
386
+ self.pipe.vae.enable_tiling()
387
 
388
  except Exception as e:
389
  self.init_error = str(e)
 
445
  }
446
 
447
  except Exception as e:
448
+ import traceback
449
+ traceback.print_exc()
450
  return {
451
  "ok": False,
452
  "error": str(e),
 
482
  "width": params.width,
483
  "num_inference_steps": params.num_inference_steps,
484
  "guidance_scale": params.guidance_scale,
485
+ # "num_frames" is intentionally OMITTED here to be handled by the loop below
486
  }
487
+
488
+ # Handle Image-to-Video
489
+ # Use simple argument passing if pipeline supports it (LTXConditionPipeline does)
490
+ # If image is present, we pass it.
491
+ if params.image_b64:
492
+ print("Received image input, performing Image-to-Video.", flush=True)
493
+ pil_image = _b64_to_pil(params.image_b64)
494
+ kwargs["image"] = pil_image
495
 
496
  # Try common frame arg names across video pipelines
497
  output = None
498
  last_err: Optional[Exception] = None
499
+
500
+ # UNIVERSAL LOOP: Try all known frame arguments
501
  for frame_arg in ("num_frames", "video_length", "num_video_frames"):
502
  try:
503
  call_kwargs = dict(kwargs)
504
  call_kwargs[frame_arg] = params.num_frames
505
  if generator is not None:
506
  call_kwargs["generator"] = generator
507
+
508
+ # Filter out None values just in case
509
+ clean_kwargs = {k: v for k, v in call_kwargs.items() if v is not None}
510
+
511
+ output = self.pipe(**clean_kwargs)
512
  break
513
  except Exception as e:
514
  last_err = e
515
+ # Don't print spam, just try next arg
516
  continue
517
 
518
  if output is None:
 
520
 
521
  frames: List[np.ndarray] = []
522
 
523
+ # UNIVERSAL OUTPUT PARSING: Handle all known shapes
524
+
525
  # 1) output.frames — may be list OR ndarray/tensor-like
526
  if hasattr(output, "frames") and getattr(output, "frames") is not None:
527
  frames_raw = getattr(output, "frames")
 
568
 
569
  # 3) output.images (single frame or list)
570
  elif hasattr(output, "images") and getattr(output, "images") is not None:
571
+ imgs = getattr(output, "images")
572
  if isinstance(imgs, list):
573
  frames = [np.array(im) for im in imgs]
574
  else:
 
615
  "num_frames": len(frames_u8),
616
  "height": int(frames_u8[0].shape[0]),
617
  "width": int(frames_u8[0].shape[1]),
618
+ "mode": "i2v" if params.image_b64 else "t2v"
 
 
619
  }
620
+ return frames_u8, diag