# stage2_vace.py
import numpy as np
import torch
from PIL import Image
class VACEWrapper:
    """Stage 2 synthesizer built on the Wan2.1 VACE diffusers pipeline.

    The masked regions of a clip are regenerated by VACE, conditioned on a
    text prompt and a reference image.
    """

    def __init__(
        self,
        device="cuda",
        model_id="Wan-AI/Wan2.1-VACE-1.3B-diffusers",
        cpu_offload=True,
    ):
        """Load the VACE pipeline in bfloat16.

        Args:
            device: Torch device the denoising runs on.
            model_id: Hugging Face repo id of a diffusers-format VACE checkpoint.
            cpu_offload: If True (the previous, default behaviour) weights stay
                on the CPU and each sub-model is moved to ``device`` only while
                it runs. If False the whole pipeline is moved to ``device``.
        """
        # Imported here rather than at module level, so importing this module
        # does not pull in diffusers.
        # VACE checkpoints need WanVACEPipeline: WanImageToVideoPipeline has no
        # `video` / `mask` / `reference_images` arguments.
        from diffusers import WanVACEPipeline

        self.device = device
        self.pipe = WanVACEPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
        )
        if cpu_offload:
            # Offloading manages device placement itself; calling .to(device)
            # first would load every weight onto the GPU only to move it back.
            self.pipe.enable_model_cpu_offload(device=device)
        else:
            self.pipe.to(device)

    def synthesize(
        self,
        original_frames,
        synthesis_masks,
        inpaint_masks,
        first_frame_ref,
        text_prompt="",
    ):
        """Regenerate the masked regions of a clip with VACE.

        Args:
            original_frames: video of shape (T, H, W, 3); expected uint8,
                since frames go through ``Image.fromarray``.
            synthesis_masks: (T, H, W) masks of where the object should be
                synthesized. Bool, 0/1 and 0/255 scales are all accepted.
            inpaint_masks: (T, H, W) masks of regions to fill with background,
                same scale conventions.
            first_frame_ref: (H, W, 3) image passed as the VACE reference image.
            text_prompt: optional text conditioning.

        Returns:
            uint8 array of shape (T, orig_H, orig_W, 3).

        Raises:
            ValueError: if ``original_frames`` contains no frames.
        """
        import cv2

        T, orig_H, orig_W, _ = original_frames.shape
        if T == 0:
            raise ValueError("original_frames must contain at least one frame")

        # VACE regenerates white (255) mask pixels. Normalize both masks to the
        # 0..255 scale first so 0/1 masks are not sent as a near-black image.
        combined = np.clip(
            self._to_uint8_mask(synthesis_masks).astype(np.uint16)
            + self._to_uint8_mask(inpaint_masks).astype(np.uint16),
            0,
            255,
        ).astype(np.uint8)

        # Wan needs H and W divisible by 16: round DOWN, but never below 16.
        H = max(16, orig_H // 16 * 16)
        W = max(16, orig_W // 16 * 16)
        resized = (H, W) != (orig_H, orig_W)
        if resized:
            original_frames = np.stack([cv2.resize(f, (W, H)) for f in original_frames])
            first_frame_ref = cv2.resize(first_frame_ref, (W, H))
            # Nearest-neighbour keeps the mask hard-edged (and commutes with
            # the pixel-wise union computed above).
            combined = np.stack(
                [cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST) for m in combined]
            )

        # Wan generates 4k + 1 frames. Pad by repeating the last frame so the
        # whole clip is used, then drop the padding from the output.
        num_frames = self._wan_num_frames(T)
        video_pil = [Image.fromarray(f) for f in self._pad_frames(original_frames, num_frames)]
        mask_pil = [Image.fromarray(m) for m in self._pad_frames(combined, num_frames)]

        output = self.pipe(
            video=video_pil,
            mask=mask_pil,
            prompt=text_prompt,
            negative_prompt="static, blurry, low quality",
            reference_images=[Image.fromarray(first_frame_ref)],
            num_frames=num_frames,
            height=H,
            width=W,
            guidance_scale=5.0,
            num_inference_steps=25,
            output_type="np",
        ).frames[0]

        # diffusers' "np" frames are floats in [0, 1]; bring them back to the
        # uint8 scale of the input.
        result = self._to_uint8_frames(np.asarray(output)[:T])
        if resized:
            result = np.stack([cv2.resize(f, (orig_W, orig_H)) for f in result])
        return result

    @staticmethod
    def _to_uint8_mask(masks):
        """Return masks as uint8 in [0, 255]; bool / [0, 1]-valued masks are scaled by 255."""
        arr = np.asarray(masks, dtype=np.float32)
        if arr.size and arr.max() <= 1.0:
            arr = arr * 255.0
        return np.clip(np.rint(arr), 0, 255).astype(np.uint8)

    @staticmethod
    def _wan_num_frames(num_frames):
        """Return the smallest frame count of the form 4k + 1 that is >= num_frames."""
        return (max(num_frames, 1) + 2) // 4 * 4 + 1

    @staticmethod
    def _pad_frames(frames, length):
        """Repeat the last entry of ``frames`` along axis 0 until it has ``length`` entries."""
        missing = length - len(frames)
        if missing <= 0:
            return frames
        return np.concatenate([frames, np.repeat(frames[-1:], missing, axis=0)], axis=0)

    @staticmethod
    def _to_uint8_frames(frames):
        """Convert pipeline frames to uint8; float input is treated as [0, 1]."""
        arr = np.asarray(frames)
        if np.issubdtype(arr.dtype, np.floating):
            arr = np.rint(np.clip(arr, 0.0, 1.0) * 255.0)
        return arr.astype(np.uint8)
class SimpleCompositeStage2:
    """
    Fallback Stage 2: simple alpha compositing.
    No diffusion model needed.
    Works for: clean background, simple objects.
    Quality: low but fast for debugging the pipeline.
    """

    def synthesize(
        self,
        original_frames: np.ndarray,  # [T, H, W, 3]
        synthesis_masks: np.ndarray,  # [T, H, W]
        inpaint_masks: np.ndarray,  # [T, H, W]
        object_crop: np.ndarray,  # [H_obj, W_obj, 3]
        object_mask: np.ndarray,  # [H_obj, W_obj] binary (0/1 or 0/255)
    ) -> np.ndarray:
        """
        Composite object into new positions using simple alpha blending.
        Useful for validating box trajectory before diffusion.

        For every frame t:
          1. if inpaint_masks[t] has any non-zero pixel, that region is filled
             with `_inpaint_simple`. This happens even when the object has no
             target box in the frame, so the original copy never lingers.
          2. the bounding box of synthesis_masks[t] > 0.5 is computed,
             object_crop is resized into it and blended using object_mask as
             alpha.

        Returns:
            A new array shaped like original_frames (the input is not modified).
        """
        import cv2

        result = original_frames.copy()
        # Identical for every frame, so convert once outside the loop.
        alpha_src = self._as_alpha(object_mask)

        for t in range(original_frames.shape[0]):
            # Erase the original position (simple fill with nearby background).
            erase_mask = inpaint_masks[t]
            if erase_mask.sum() > 0:
                result[t] = _inpaint_simple(result[t], erase_mask)

            box = self._mask_bbox(synthesis_masks[t])
            if box is None:
                continue
            y1, y2, x1, x2 = box
            bh, bw = y2 - y1, x2 - x1  # >= 1: box ends are exclusive

            # Resize object and its alpha to the target box size.
            obj_resized = cv2.resize(object_crop, (bw, bh), interpolation=cv2.INTER_LINEAR)
            alpha = cv2.resize(alpha_src, (bw, bh), interpolation=cv2.INTER_LINEAR)
            alpha = alpha.reshape(bh, bw, 1)

            # Composite the object at its new position.
            roi = result[t, y1:y2, x1:x2]
            result[t, y1:y2, x1:x2] = (obj_resized * alpha + roi * (1 - alpha)).astype(np.uint8)

        return result

    @staticmethod
    def _mask_bbox(mask):
        """Return (y1, y2, x1, x2) of ``mask > 0.5`` with exclusive ends, or None if empty."""
        ys, xs = np.nonzero(np.asarray(mask) > 0.5)
        if ys.size == 0:
            return None
        return int(ys.min()), int(ys.max()) + 1, int(xs.min()), int(xs.max()) + 1

    @staticmethod
    def _as_alpha(mask):
        """Return ``mask`` as float32 alpha in [0, 1]; 0/255-scaled masks are rescaled."""
        alpha = np.asarray(mask, dtype=np.float32)
        if alpha.size and alpha.max() > 1.0:
            alpha = alpha / 255.0
        return np.clip(alpha, 0.0, 1.0)
def _inpaint_simple(frame: np.ndarray, mask: np.ndarray, radius: int = 3) -> np.ndarray:
    """Remove the masked region from ``frame`` with OpenCV's Telea inpainting.

    Args:
        frame: 8-bit image to repair (cv2.inpaint requires 8-bit input).
        mask: array with frame's height/width; every pixel > 0 is inpainted.
            Bool, 0/1 and 0/255 masks are handled alike.
        radius: neighbourhood radius passed to cv2.inpaint (default 3, as before).

    Returns:
        The inpainted image produced by cv2.inpaint.
    """
    import cv2

    return cv2.inpaint(frame, _inpaint_mask_u8(mask), radius, cv2.INPAINT_TELEA)


def _inpaint_mask_u8(mask: np.ndarray) -> np.ndarray:
    """Binarize ``mask`` into the 0/255 uint8 single-channel form cv2.inpaint expects.

    Thresholding replaces the old ``(mask * 255).astype(np.uint8)``. That
    version only worked for 0/255 masks through uint8 wrap-around and relied
    on out-of-range float -> uint8 casts for float masks in 0..255.
    """
    return (np.asarray(mask) > 0).astype(np.uint8) * 255
|