sam-motamed commited on
Commit
59a493b
·
verified ·
1 Parent(s): b51f7da

Move pipeline_void.py to diffusers/

Browse files
Files changed (1) hide show
  1. pipeline_void.py +0 -559
pipeline_void.py DELETED
@@ -1,559 +0,0 @@
1
- """
2
- VOID (Video Object and Interaction Deletion) Pipeline.
3
-
4
- Simple usage:
5
-
6
- from pipeline_void import VOIDPipeline
7
-
8
- pipe = VOIDPipeline.from_pretrained("netflix/void-model")
9
- result = pipe.inpaint("input.mp4", "quadmask.mp4", "A lime falls on the table.")
10
- result.save("output.mp4")
11
-
12
- Pass 2 refinement:
13
-
14
- pipe2 = VOIDPipeline.from_pretrained("netflix/void-model", void_pass=2)
15
- result2 = pipe2.inpaint("input.mp4", "quadmask.mp4", "A lime falls on the table.",
16
- pass1_video="output.mp4")
17
- result2.save("output_refined.mp4")
18
- """
19
-
20
- import os
21
- import json
22
- import subprocess
23
- import sys
24
- import tempfile
25
- from dataclasses import dataclass
26
- from typing import List, Optional, Tuple, Union
27
-
28
- import cv2
29
- import numpy as np
30
- import torch
31
- import torch.nn.functional as F
32
- from huggingface_hub import hf_hub_download, snapshot_download
33
- from safetensors.torch import load_file
34
- from diffusers import CogVideoXDDIMScheduler
35
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
36
-
37
- from cogvideox_transformer3d import CogVideoXTransformer3DModel
38
- from cogvideox_vae import AutoencoderKLCogVideoX
39
- from pipeline_cogvideox_fun_inpaint import CogVideoXFunInpaintPipeline
40
-
41
- # The base model that VOID is fine-tuned from
42
- BASE_MODEL_REPO = "alibaba-pai/CogVideoX-Fun-V1.5-5b-InP"
43
-
44
- # Checkpoint filenames in the VOID repo
45
- PASS_CHECKPOINTS = {
46
- 1: "void_pass1.safetensors",
47
- 2: "void_pass2.safetensors",
48
- }
49
-
50
- # Default negative prompt (from config/quadmask_cogvideox.py)
51
- DEFAULT_NEGATIVE_PROMPT = (
52
- "The video is not of a high quality, it has a low resolution. "
53
- "Watermark present in each frame. The background is solid. "
54
- "Strange body and strange trajectory. Distortion. "
55
- )
56
-
57
-
58
- @dataclass
59
- class VOIDOutput:
60
- """Output from VOID pipeline."""
61
- video: torch.Tensor # (T, H, W, 3) uint8
62
- video_float: torch.Tensor # (1, C, T, H, W) float [0, 1]
63
-
64
- def save(self, path: str, fps: int = 12):
65
- """Save output video to file."""
66
- import imageio
67
- frames = [f for f in self.video.cpu().numpy()]
68
- imageio.mimwrite(path, frames, fps=fps)
69
- print(f"Saved {len(frames)} frames to {path}")
70
-
71
-
72
- def _merge_void_weights(transformer, checkpoint_path):
73
- """Merge VOID checkpoint into base transformer, handling channel mismatch."""
74
- state_dict = load_file(checkpoint_path)
75
- param_name = "patch_embed.proj.weight"
76
-
77
- if state_dict[param_name].size(1) != transformer.state_dict()[param_name].size(1):
78
- latent_ch = 16
79
- feat_scale = 8
80
- feat_dim = int(latent_ch * feat_scale)
81
-
82
- new_weight = transformer.state_dict()[param_name].clone()
83
- new_weight[:, :feat_dim] = state_dict[param_name][:, :feat_dim]
84
- new_weight[:, -feat_dim:] = state_dict[param_name][:, -feat_dim:]
85
- state_dict[param_name] = new_weight
86
-
87
- m, u = transformer.load_state_dict(state_dict, strict=False)
88
- if m:
89
- print(f"[VOID] Missing keys: {len(m)}")
90
- if u:
91
- print(f"[VOID] Unexpected keys: {len(u)}")
92
-
93
- return transformer
94
-
95
-
96
- def _load_video(path: str, max_frames: int) -> np.ndarray:
97
- """Load video as numpy array (T, H, W, 3) uint8."""
98
- import imageio
99
- frames = list(imageio.imiter(path))
100
- frames = frames[:max_frames]
101
- return np.array(frames)
102
-
103
-
104
- def _prep_video_tensor(
105
- video_np: np.ndarray,
106
- sample_size: Tuple[int, int],
107
- ) -> torch.Tensor:
108
- """Convert video numpy array to pipeline input tensor.
109
-
110
- Returns: (1, C, T, H, W) float32 in [0, 1]
111
- """
112
- video = torch.from_numpy(video_np).float()
113
- video = video.permute(3, 0, 1, 2) / 255.0 # (C, T, H, W)
114
- video = F.interpolate(video, sample_size, mode="area")
115
- return video.unsqueeze(0) # (1, C, T, H, W)
116
-
117
-
118
- def _prep_mask_tensor(
119
- mask_np: np.ndarray,
120
- sample_size: Tuple[int, int],
121
- use_quadmask: bool = True,
122
- ) -> torch.Tensor:
123
- """Convert mask numpy array to pipeline input tensor.
124
-
125
- Quantizes to quadmask values [0, 63, 127, 255], inverts,
126
- and normalizes to [0, 1].
127
-
128
- Returns: (1, 1, T, H, W) float32 in [0, 1]
129
- """
130
- mask = torch.from_numpy(mask_np).float()
131
- if mask.ndim == 4:
132
- mask = mask[..., 0] # drop channel dim -> (T, H, W)
133
- mask = F.interpolate(mask.unsqueeze(0), sample_size, mode="area")
134
- mask = mask.unsqueeze(0) # (1, 1, T, H, W)
135
-
136
- if use_quadmask:
137
- # Quantize to 4 values
138
- mask = torch.where(mask <= 31, 0., mask)
139
- mask = torch.where((mask > 31) * (mask <= 95), 63., mask)
140
- mask = torch.where((mask > 95) * (mask <= 191), 127., mask)
141
- mask = torch.where(mask > 191, 255., mask)
142
- else:
143
- # Trimask: 3 values
144
- mask = torch.where(mask > 192, 255., mask)
145
- mask = torch.where((mask <= 192) * (mask >= 64), 128., mask)
146
- mask = torch.where(mask < 64, 0., mask)
147
-
148
- # Invert and normalize to [0, 1]
149
- mask = (255. - mask) / 255.
150
-
151
- return mask
152
-
153
-
154
- def _temporal_padding(
155
- tensor: torch.Tensor,
156
- min_length: int = 85,
157
- max_length: int = 197,
158
- dim: int = 2,
159
- ) -> torch.Tensor:
160
- """Pad video temporally by mirroring, matching CogVideoX requirements."""
161
- length = tensor.size(dim)
162
-
163
- min_len = (length // 4) * 4 + 1
164
- if min_len < length:
165
- min_len += 4
166
- if (min_len / 4) % 2 == 0:
167
- min_len += 4
168
- target_length = min(min_len, max_length)
169
- target_length = max(min_length, target_length)
170
-
171
- # Truncate if needed
172
- if dim == 2:
173
- tensor = tensor[:, :, :target_length]
174
- else:
175
- raise NotImplementedError(f"dim={dim} not supported")
176
-
177
- # Pad by mirroring
178
- while tensor.size(dim) < target_length:
179
- flipped = torch.flip(tensor, [dim])
180
- tensor = torch.cat([tensor, flipped], dim=dim)
181
-
182
- if dim == 2:
183
- tensor = tensor[:, :, :target_length]
184
-
185
- return tensor
186
-
187
-
188
- def _generate_warped_noise(
189
- pass1_video_path: str,
190
- target_shape: Tuple[int, int, int, int],
191
- device: torch.device,
192
- dtype: torch.dtype,
193
- ) -> torch.Tensor:
194
- """Generate warped noise from Pass 1 output video.
195
-
196
- Args:
197
- pass1_video_path: Path to Pass 1 output video.
198
- target_shape: (latent_T, latent_H, latent_W, latent_C)
199
- device: Target device.
200
- dtype: Target dtype.
201
-
202
- Returns: (1, T, C, H, W) warped noise tensor.
203
- """
204
- # Try to import rp and nw for direct warped noise generation
205
- try:
206
- # Fix for SLURM: rp crashes parsing GPU UUIDs like "GPU-9fca2b4f-..."
207
- # Set CUDA_VISIBLE_DEVICES to numeric index if it contains UUIDs
208
- cuda_env = os.environ.get("CUDA_VISIBLE_DEVICES", "")
209
- if cuda_env and not cuda_env.replace(",", "").isdigit():
210
- os.environ["CUDA_VISIBLE_DEVICES"] = "0"
211
-
212
- import rp
213
- rp.r._pip_import_autoyes = True
214
- rp.git_import('CommonSource')
215
- import rp.git.CommonSource.noise_warp as nw
216
- return _generate_warped_noise_direct(pass1_video_path, target_shape, device, dtype)
217
- except ImportError as e:
218
- print(f"[VOID] rp/noise_warp not available: {e}")
219
- except Exception as e:
220
- print(f"[VOID] Warped noise generation via rp failed: {e}")
221
- import traceback
222
- traceback.print_exc()
223
-
224
- # Fallback: try to find and run make_warped_noise.py as subprocess
225
- script_candidates = [
226
- os.path.join(os.path.dirname(__file__), "make_warped_noise.py"),
227
- os.path.join(os.path.dirname(__file__), "..", "inference", "cogvideox_fun", "make_warped_noise.py"),
228
- ]
229
- gwf_script = None
230
- for candidate in script_candidates:
231
- if os.path.exists(candidate):
232
- gwf_script = candidate
233
- break
234
-
235
- if gwf_script is None:
236
- raise RuntimeError(
237
- "Cannot generate warped noise: 'rp' package not installed and "
238
- "make_warped_noise.py not found. Install 'rp' package or provide "
239
- "pre-computed warped noise via warped_noise_path parameter."
240
- )
241
-
242
- with tempfile.TemporaryDirectory() as tmpdir:
243
- cmd = [sys.executable, gwf_script, os.path.abspath(pass1_video_path), tmpdir]
244
- print(f"[VOID] Generating warped noise (this may take a few minutes)...")
245
- result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
246
- if result.returncode != 0:
247
- raise RuntimeError(f"Warped noise generation failed:\n{result.stderr}")
248
-
249
- # Find the output noises.npy
250
- video_stem = os.path.splitext(os.path.basename(pass1_video_path))[0]
251
- noise_path = os.path.join(tmpdir, video_stem, "noises.npy")
252
- if not os.path.exists(noise_path):
253
- # Try flat path
254
- noise_path = os.path.join(tmpdir, "noises.npy")
255
- if not os.path.exists(noise_path):
256
- raise RuntimeError(f"Warped noise file not found after generation")
257
-
258
- return _load_warped_noise(noise_path, target_shape, device, dtype)
259
-
260
-
261
- def _generate_warped_noise_direct(
262
- video_path: str,
263
- target_shape: Tuple[int, int, int, int],
264
- device: torch.device,
265
- dtype: torch.dtype,
266
- ) -> torch.Tensor:
267
- """Generate warped noise directly using rp package."""
268
- import rp
269
- import rp.git.CommonSource.noise_warp as nw
270
-
271
- video = rp.load_video(video_path)
272
- video = rp.resize_list(video, length=72)
273
- video = rp.resize_images_to_hold(video, height=480, width=720)
274
- video = rp.crop_images(video, height=480, width=720, origin='center')
275
- video = rp.as_numpy_array(video)
276
-
277
- FRAME = 2**-1
278
- FLOW = 2**3
279
- LATENT = 8
280
-
281
- output = nw.get_noise_from_video(
282
- video,
283
- remove_background=False,
284
- visualize=False,
285
- save_files=False,
286
- noise_channels=16,
287
- resize_frames=FRAME,
288
- resize_flow=FLOW,
289
- downscale_factor=round(FRAME * FLOW) * LATENT,
290
- )
291
-
292
- noises = output.numpy_noises # (T, H, W, C)
293
- return _numpy_noise_to_tensor(noises, target_shape, device, dtype)
294
-
295
-
296
- def _load_warped_noise(
297
- noise_path: str,
298
- target_shape: Tuple[int, int, int, int],
299
- device: torch.device,
300
- dtype: torch.dtype,
301
- ) -> torch.Tensor:
302
- """Load and resize pre-computed warped noise."""
303
- noises = np.load(noise_path)
304
- if noises.dtype == np.float16:
305
- noises = noises.astype(np.float32)
306
- # Ensure THWC format
307
- if noises.shape[1] == 16: # TCHW -> THWC
308
- noises = np.transpose(noises, (0, 2, 3, 1))
309
- return _numpy_noise_to_tensor(noises, target_shape, device, dtype)
310
-
311
-
312
- def _numpy_noise_to_tensor(
313
- noises: np.ndarray,
314
- target_shape: Tuple[int, int, int, int],
315
- device: torch.device,
316
- dtype: torch.dtype,
317
- ) -> torch.Tensor:
318
- """Convert numpy noise (T, H, W, C) to pipeline tensor (1, T, C, H, W)."""
319
- latent_T, latent_H, latent_W, latent_C = target_shape
320
-
321
- # Temporal resize if needed
322
- if noises.shape[0] != latent_T:
323
- indices = np.linspace(0, noises.shape[0] - 1, latent_T)
324
- lower = np.floor(indices).astype(int)
325
- upper = np.ceil(indices).astype(int)
326
- frac = indices - lower
327
- noises = noises[lower] * (1 - frac[:, None, None, None]) + noises[upper] * frac[:, None, None, None]
328
-
329
- # Spatial resize if needed
330
- if noises.shape[1] != latent_H or noises.shape[2] != latent_W:
331
- resized = np.zeros((latent_T, latent_H, latent_W, latent_C), dtype=noises.dtype)
332
- for t in range(latent_T):
333
- for c in range(latent_C):
334
- resized[t, :, :, c] = cv2.resize(
335
- noises[t, :, :, c], (latent_W, latent_H),
336
- interpolation=cv2.INTER_LINEAR,
337
- )
338
- noises = resized
339
-
340
- # Convert to tensor: (T, H, W, C) -> (1, T, C, H, W)
341
- tensor = torch.from_numpy(noises).permute(0, 3, 1, 2).unsqueeze(0)
342
- return tensor.to(device=device, dtype=dtype)
343
-
344
-
345
- class VOIDPipeline(CogVideoXFunInpaintPipeline):
346
- """
347
- VOID: Video Object and Interaction Deletion.
348
-
349
- Removes objects and their physical interactions from videos using
350
- quadmask-conditioned video inpainting.
351
- """
352
-
353
- @classmethod
354
- def from_pretrained(
355
- cls,
356
- pretrained_model_name_or_path: str,
357
- void_pass: int = 1,
358
- base_model: str = BASE_MODEL_REPO,
359
- torch_dtype: torch.dtype = torch.bfloat16,
360
- **kwargs,
361
- ):
362
- """
363
- Load the VOID pipeline.
364
-
365
- Args:
366
- pretrained_model_name_or_path: HF repo ID or local path containing
367
- VOID checkpoint files (void_pass1.safetensors, etc.)
368
- void_pass: Which pass checkpoint to load (1 or 2). Default: 1.
369
- base_model: HF repo ID for the base CogVideoX-Fun model.
370
- torch_dtype: Weight dtype. Default: torch.bfloat16.
371
- """
372
- if void_pass not in PASS_CHECKPOINTS:
373
- raise ValueError(f"void_pass must be 1 or 2, got {void_pass}")
374
-
375
- # --- Download VOID checkpoint ---
376
- checkpoint_name = PASS_CHECKPOINTS[void_pass]
377
- print(f"[VOID] Loading Pass {void_pass} checkpoint...")
378
-
379
- if os.path.isdir(pretrained_model_name_or_path):
380
- checkpoint_path = os.path.join(pretrained_model_name_or_path, checkpoint_name)
381
- else:
382
- checkpoint_path = hf_hub_download(
383
- repo_id=pretrained_model_name_or_path,
384
- filename=checkpoint_name,
385
- )
386
-
387
- # --- Download and load base model ---
388
- print(f"[VOID] Loading base model: {base_model}")
389
- base_model_path = snapshot_download(repo_id=base_model)
390
-
391
- # Transformer (with VAE mask channels)
392
- print("[VOID] Loading transformer...")
393
- transformer = CogVideoXTransformer3DModel.from_pretrained(
394
- base_model_path,
395
- subfolder="transformer",
396
- low_cpu_mem_usage=True,
397
- torch_dtype=torch_dtype,
398
- use_vae_mask=True,
399
- )
400
-
401
- # Merge VOID weights
402
- print(f"[VOID] Merging Pass {void_pass} weights...")
403
- transformer = _merge_void_weights(transformer, checkpoint_path)
404
- transformer = transformer.to(torch_dtype)
405
-
406
- # VAE
407
- print("[VOID] Loading VAE...")
408
- vae = AutoencoderKLCogVideoX.from_pretrained(
409
- base_model_path, subfolder="vae"
410
- ).to(torch_dtype)
411
-
412
- # Tokenizer + Text encoder
413
- print("[VOID] Loading tokenizer and text encoder...")
414
- from transformers import T5Tokenizer, T5EncoderModel
415
- tokenizer = T5Tokenizer.from_pretrained(base_model_path, subfolder="tokenizer")
416
- text_encoder = T5EncoderModel.from_pretrained(
417
- base_model_path, subfolder="text_encoder", torch_dtype=torch_dtype,
418
- )
419
-
420
- # Scheduler
421
- scheduler = CogVideoXDDIMScheduler.from_pretrained(
422
- base_model_path, subfolder="scheduler"
423
- )
424
-
425
- # Build pipeline
426
- pipe = cls(
427
- tokenizer=tokenizer,
428
- text_encoder=text_encoder,
429
- vae=vae,
430
- transformer=transformer,
431
- scheduler=scheduler,
432
- )
433
- pipe._void_pass = void_pass
434
-
435
- print("[VOID] Pipeline ready!")
436
- return pipe
437
-
438
- def inpaint(
439
- self,
440
- video_path: str,
441
- mask_path: str,
442
- prompt: str,
443
- negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
444
- height: int = 384,
445
- width: int = 672,
446
- num_inference_steps: int = 30,
447
- guidance_scale: float = 1.0,
448
- strength: float = 1.0,
449
- temporal_window_size: int = 85,
450
- max_video_length: int = 197,
451
- fps: int = 12,
452
- seed: int = 42,
453
- pass1_video: Optional[str] = None,
454
- warped_noise_path: Optional[str] = None,
455
- use_quadmask: bool = True,
456
- ) -> VOIDOutput:
457
- """
458
- Run VOID inpainting on a video.
459
-
460
- Args:
461
- video_path: Path to input video (mp4).
462
- mask_path: Path to quadmask video (mp4). Grayscale with values:
463
- 0=object to remove, 63=overlap, 127=affected region, 255=background.
464
- prompt: Text description of the desired result after removal.
465
- E.g., "A lime falls on the table."
466
- negative_prompt: Negative prompt for generation quality.
467
- height: Output height (default 384).
468
- width: Output width (default 672).
469
- num_inference_steps: Denoising steps (default 30).
470
- guidance_scale: CFG scale (default 1.0 = no CFG).
471
- strength: Denoising strength (default 1.0).
472
- temporal_window_size: Frames per inference window (default 85).
473
- max_video_length: Max frames to process (default 197).
474
- fps: Output FPS (default 12).
475
- seed: Random seed (default 42).
476
- pass1_video: Path to Pass 1 output video, for Pass 2 warped noise init.
477
- warped_noise_path: Path to pre-computed warped noise (.npy).
478
- use_quadmask: Use 4-value quadmask (default True). Set False for trimask.
479
-
480
- Returns:
481
- VOIDOutput with .video (uint8) and .save() method.
482
- """
483
- sample_size = (height, width)
484
-
485
- # Align video length to VAE temporal compression ratio
486
- vae_temporal_ratio = self.vae.config.temporal_compression_ratio
487
- video_length = int((max_video_length - 1) // vae_temporal_ratio * vae_temporal_ratio) + 1
488
-
489
- # --- Load and prep video ---
490
- print("[VOID] Loading video and mask...")
491
- vid_np = _load_video(video_path, video_length)
492
- mask_np = _load_video(mask_path, video_length)
493
-
494
- video = _prep_video_tensor(vid_np, sample_size)
495
- mask = _prep_mask_tensor(mask_np, sample_size, use_quadmask=use_quadmask)
496
-
497
- # Temporal padding
498
- video = _temporal_padding(video, min_length=temporal_window_size, max_length=max_video_length)
499
- mask = _temporal_padding(mask, min_length=temporal_window_size, max_length=max_video_length)
500
-
501
- num_frames = min(video.shape[2], temporal_window_size)
502
-
503
- print(f"[VOID] Video: {video.shape}, Mask: {mask.shape}, Frames: {num_frames}")
504
-
505
- # --- Handle warped noise for Pass 2 ---
506
- latents = None
507
- if warped_noise_path is not None or pass1_video is not None:
508
- latent_T = (num_frames - 1) // 4 + 1
509
- latent_H = height // 8
510
- latent_W = width // 8
511
- latent_C = 16
512
- target_shape = (latent_T, latent_H, latent_W, latent_C)
513
-
514
- if warped_noise_path is not None:
515
- print(f"[VOID] Loading pre-computed warped noise from {warped_noise_path}")
516
- latents = _load_warped_noise(
517
- warped_noise_path, target_shape,
518
- device=torch.device("cpu"), dtype=torch.bfloat16,
519
- )
520
- else:
521
- print(f"[VOID] Generating warped noise from Pass 1 output...")
522
- latents = _generate_warped_noise(
523
- pass1_video, target_shape,
524
- device=torch.device("cpu"), dtype=torch.bfloat16,
525
- )
526
- print(f"[VOID] Warped noise: {latents.shape}, mean={latents.mean():.4f}, std={latents.std():.4f}")
527
-
528
- # --- Run inference ---
529
- generator = torch.Generator(device="cpu").manual_seed(seed)
530
-
531
- print(f"[VOID] Running inference ({num_frames} frames, {num_inference_steps} steps)...")
532
- with torch.no_grad():
533
- output = self(
534
- prompt=prompt,
535
- negative_prompt=negative_prompt,
536
- num_frames=num_frames,
537
- height=height,
538
- width=width,
539
- guidance_scale=guidance_scale,
540
- num_inference_steps=num_inference_steps,
541
- generator=generator,
542
- video=video,
543
- mask_video=mask,
544
- strength=strength,
545
- use_trimask=True,
546
- use_vae_mask=True,
547
- latents=latents,
548
- ).videos
549
-
550
- # --- Process output ---
551
- if isinstance(output, np.ndarray):
552
- output = torch.from_numpy(output)
553
-
554
- # output is (B, C, T, H, W) in [0, 1]
555
- video_float = output
556
- video_uint8 = (output[0].permute(1, 2, 3, 0).clamp(0, 1) * 255).to(torch.uint8)
557
-
558
- print(f"[VOID] Done! Output: {video_uint8.shape}")
559
- return VOIDOutput(video=video_uint8, video_float=video_float)