# stage2_vace.py
"""Stage 2 synthesis back-ends.

``VACEWrapper`` drives the Wan2.1 VACE diffusion pipeline, while
``SimpleCompositeStage2`` is a diffusion-free fallback that alpha-blends an
object crop into the target boxes.
"""
import numpy as np
import torch
from PIL import Image


def _as_uint8_mask(masks):
    """Return ``masks`` as a uint8 array in the 0..255 range.

    Boolean masks and masks whose maximum is <= 1 are treated as binary/soft
    masks in [0, 1] and scaled by 255; anything else is assumed to already be
    in 0..255 and is only clipped.
    """
    masks = np.asarray(masks)
    if masks.dtype == np.bool_:
        return masks.astype(np.uint8) * 255
    if masks.size and masks.max() <= 1:
        masks = masks.astype(np.float32) * 255.0
    return np.clip(masks, 0, 255).astype(np.uint8)


class VACEWrapper:
    """Thin wrapper around the diffusers Wan2.1 VACE video pipeline."""

    def __init__(self, device="cuda"):
        """Load the VACE 1.3B checkpoint in bfloat16 and place it on ``device``."""
        # Imported lazily so the compositing fallback in this module stays
        # usable without diffusers installed.
        from diffusers import WanVACEPipeline

        self.device = device
        # The VACE checkpoint must be loaded with the VACE pipeline: it is the
        # one that accepts the video / mask / reference_images conditioning
        # passed in synthesize().
        self.pipe = WanVACEPipeline.from_pretrained(
            "Wan-AI/Wan2.1-VACE-1.3B-diffusers",
            torch_dtype=torch.bfloat16,
        )
        if torch.device(device).type == "cpu":
            # Nothing to offload from; just keep the weights on the CPU.
            self.pipe.to(device)
        else:
            # Model offload manages device placement itself, so the pipeline
            # is not moved to the accelerator beforehand.
            self.pipe.enable_model_cpu_offload(device=device)

    def synthesize(
        self,
        original_frames,
        synthesis_masks,
        inpaint_masks,
        first_frame_ref,
        text_prompt="",
    ):
        """Regenerate the masked regions of a clip with VACE.

        Args:
            original_frames: uint8 array of shape [T, H, W, 3].
            synthesis_masks: per-frame masks, [T, H, W], marking where the
                object should appear. Either binary/[0, 1] or 0..255.
            inpaint_masks: per-frame masks, [T, H, W], marking the region to
                fill in. Same value conventions as ``synthesis_masks``.
            first_frame_ref: uint8 reference image, [H, W, 3].
            text_prompt: text prompt forwarded to the pipeline.

        Returns:
            uint8 array of shape [T, H, W, 3] at the input resolution.

        Raises:
            ValueError: if ``original_frames`` contains no frames.
        """
        import cv2

        T, orig_H, orig_W, _ = original_frames.shape
        if T == 0:
            raise ValueError("original_frames must contain at least one frame")

        # Round down to a multiple of 16 (VACE requirement), never below 16.
        H = max(16, (orig_H // 16) * 16)
        W = max(16, (orig_W // 16) * 16)

        # Normalise both masks to 0..255 so binary 0/1 masks do not end up as
        # an almost-black conditioning mask.
        synthesis_masks = _as_uint8_mask(synthesis_masks)
        inpaint_masks = _as_uint8_mask(inpaint_masks)

        if H != orig_H or W != orig_W:
            original_frames = np.stack(
                [cv2.resize(f, (W, H)) for f in original_frames]
            )
            first_frame_ref = cv2.resize(first_frame_ref, (W, H))
            synthesis_masks = np.stack([
                cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST)
                for m in synthesis_masks
            ])
            inpaint_masks = np.stack([
                cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST)
                for m in inpaint_masks
            ])

        # Union of both masks; widen to uint16 so the sum cannot wrap.
        combined = np.clip(
            synthesis_masks.astype(np.uint16) + inpaint_masks.astype(np.uint16),
            0,
            255,
        ).astype(np.uint8)

        # The pipeline needs (num_frames - 1) to be a multiple of the VAE
        # temporal factor and otherwise rounds num_frames itself, which would
        # return a different number of frames than were passed in. Pad with
        # copies of the last frame instead and trim the output afterwards.
        temporal = getattr(self.pipe, "vae_scale_factor_temporal", 4)
        pad = (-(T - 1)) % temporal
        if pad:
            original_frames = np.concatenate(
                [original_frames, np.repeat(original_frames[-1:], pad, axis=0)]
            )
            combined = np.concatenate(
                [combined, np.repeat(combined[-1:], pad, axis=0)]
            )

        video_pil = [Image.fromarray(f) for f in original_frames]
        mask_pil = [Image.fromarray(m) for m in combined]
        ref_pil = Image.fromarray(first_frame_ref)

        output = self.pipe(
            video=video_pil,
            mask=mask_pil,
            prompt=text_prompt,
            negative_prompt="static, blurry, low quality",
            reference_images=[ref_pil],
            num_frames=T + pad,
            height=H,
            width=W,
            guidance_scale=5.0,
            num_inference_steps=25,
            # Ask for PIL frames explicitly so the result is uint8 like the
            # input rather than float arrays in [0, 1].
            output_type="pil",
        ).frames[0]

        # Drop the padding frames again.
        result = np.stack([np.asarray(f) for f in output], axis=0)[:T]

        # Restore original resolution
        if orig_H != H or orig_W != W:
            result = np.stack(
                [cv2.resize(f, (orig_W, orig_H)) for f in result]
            )

        return result


class SimpleCompositeStage2:
    """
    Fallback Stage 2: simple alpha compositing.
    No diffusion model needed. Works for: clean background, simple objects.
    Quality: low but fast for debugging the pipeline.
    """

    def synthesize(
        self,
        original_frames: np.ndarray,  # [T, H, W, 3]
        synthesis_masks: np.ndarray,  # [T, H, W]
        inpaint_masks: np.ndarray,    # [T, H, W]
        object_crop: np.ndarray,      # [H_obj, W_obj, 3]
        object_mask: np.ndarray,      # [H_obj, W_obj] binary
    ) -> np.ndarray:
        """
        Composite object into new positions using simple alpha blending.
        Useful for validating box trajectory before diffusion.

        For every frame the region in ``inpaint_masks`` is erased first, then
        the object crop is resized to the bounding box of ``synthesis_masks``
        and blended in using ``object_mask`` as alpha. The input frames are
        not modified; a new array of the same shape is returned.
        """
        import cv2

        num_frames = original_frames.shape[0]
        result = original_frames.copy()
        # Loop-invariant: the alpha source is the same for every frame.
        alpha_src = object_mask.astype(np.float32)

        for t in range(num_frames):
            # Erase original position (simple fill with nearby bg). Done
            # before the box lookup so frames without a target box are still
            # cleaned up.
            erase_mask = inpaint_masks[t]
            if (erase_mask > 0).any():
                result[t] = _inpaint_simple(result[t], erase_mask)

            # Find box from synthesis mask
            ys, xs = np.nonzero(synthesis_masks[t] > 0.5)
            if ys.size == 0:
                continue

            # max() is the last covered pixel, so the exclusive slice end is
            # one past it; this keeps the final row/column and 1-px masks.
            y1, y2 = int(ys.min()), int(ys.max()) + 1
            x1, x2 = int(xs.min()), int(xs.max()) + 1
            bh, bw = y2 - y1, x2 - x1

            # Resize object to target box size
            obj_resized = cv2.resize(
                object_crop, (bw, bh), interpolation=cv2.INTER_LINEAR
            )
            mask_resized = cv2.resize(
                alpha_src, (bw, bh), interpolation=cv2.INTER_LINEAR
            )
            mask_3ch = mask_resized[:, :, None]

            # Composite object at new position
            roi = result[t, y1:y2, x1:x2]
            result[t, y1:y2, x1:x2] = (
                obj_resized * mask_3ch + roi * (1 - mask_3ch)
            ).astype(np.uint8)

        return result


def _inpaint_simple(frame: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Simple telea inpainting for object removal.

    Every non-zero pixel of ``mask`` is inpainted. The mask is binarised
    explicitly rather than multiplied by 255, which would wrap around for
    uint8 masks that are already in the 0..255 range.
    """
    import cv2

    mask_uint8 = (np.asarray(mask) > 0).astype(np.uint8) * 255
    return cv2.inpaint(frame, mask_uint8, 3, cv2.INPAINT_TELEA)