Spaces:
Runtime error
Runtime error
| # stage2_vace.py | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
class VACEWrapper:
    """Stage 2 synthesis backed by the Wan2.1 VACE video-editing model (diffusers).

    The VACE checkpoint is loaded with ``WanVACEPipeline``. That pipeline accepts
    the ``video`` / ``mask`` / ``reference_images`` conditioning that
    :meth:`synthesize` passes; ``WanImageToVideoPipeline`` does not.
    """

    # Spatial size granularity required for height/width.
    SPATIAL_MULTIPLE = 16
    # Wan's temporal VAE needs num_frames == TEMPORAL_STRIDE * k + 1.
    TEMPORAL_STRIDE = 4

    def __init__(
        self,
        device="cuda",
        model_id="Wan-AI/Wan2.1-VACE-1.3B-diffusers",
        cpu_offload=True,
    ):
        """Load the VACE pipeline.

        Args:
            device: torch device string the model should run on.
            model_id: Hugging Face repo id of a diffusers-format VACE checkpoint.
            cpu_offload: use diffusers model CPU offload (ignored on CPU devices).
        """
        import torch
        from diffusers import WanVACEPipeline

        self.device = device
        self.pipe = WanVACEPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
        )
        if cpu_offload and torch.device(device).type != "cpu":
            # Offload manages device placement itself; calling .to(device) first
            # only makes diffusers move every component back to the CPU.
            self.pipe.enable_model_cpu_offload(device=device)
        else:
            self.pipe.to(device)

    @staticmethod
    def _binarize_masks(masks):
        """Return uint8 {0, 255} masks from bool, {0, 1} float or {0, 255} input."""
        m = np.asarray(masks)
        if m.dtype == bool:
            return m.astype(np.uint8) * 255
        # Masks in [0, 1] (the convention used by the fallback stage) are cut at
        # 0.5; 0..255 masks are cut at their midpoint.
        thresh = 0.5 if m.size == 0 or m.max() <= 1 else 127
        return np.where(m > thresh, 255, 0).astype(np.uint8)

    @staticmethod
    def _to_uint8_frames(frames):
        """Stack pipeline output frames ([0, 1] floats or uint8) into a uint8 array."""
        arr = np.stack([np.asarray(f) for f in frames])
        if arr.dtype != np.uint8:
            arr = (np.clip(arr, 0.0, 1.0) * 255.0).round().astype(np.uint8)
        return arr

    def synthesize(
        self,
        original_frames,
        synthesis_masks,
        inpaint_masks,
        first_frame_ref,
        text_prompt="",
    ):
        """Regenerate the masked regions of a clip with VACE.

        Args:
            original_frames: (T, H, W, 3) frames. ``Image.fromarray`` needs uint8;
                presumably RGB — confirm with the caller.
            synthesis_masks: (T, H, W) masks of where the object should appear.
            inpaint_masks: (T, H, W) masks of regions to erase.
                Both mask inputs may be bool, {0, 1} floats or {0, 255} uint8.
                Their union becomes the VACE mask (255 = regenerate, 0 = keep).
            first_frame_ref: (h, w, 3) uint8 image passed as ``reference_images``,
                or None to run without a reference.
            text_prompt: optional text prompt.

        Returns:
            (T, H, W, 3) uint8 frames at the input resolution.

        Raises:
            ValueError: empty input, mask count != frame count, or frames
                smaller than SPATIAL_MULTIPLE pixels in either dimension.
        """
        import cv2
        from PIL import Image

        frames = np.asarray(original_frames)
        T, orig_H, orig_W = frames.shape[:3]
        if T == 0:
            raise ValueError("original_frames is empty")
        if len(synthesis_masks) != T or len(inpaint_masks) != T:
            raise ValueError("mask sequences must have one entry per frame")

        step = self.SPATIAL_MULTIPLE
        # Floor (not round) to a multiple of 16, as VACE requires.
        H = (orig_H // step) * step
        W = (orig_W // step) * step
        if H == 0 or W == 0:
            raise ValueError(
                f"frames must be at least {step}x{step} pixels, got {orig_H}x{orig_W}"
            )

        # Union of both masks: VACE regenerates white pixels and keeps black ones.
        combined = np.maximum(
            self._binarize_masks(synthesis_masks),
            self._binarize_masks(inpaint_masks),
        )
        ref = None if first_frame_ref is None else np.asarray(first_frame_ref)

        if (H, W) != (orig_H, orig_W):
            frames = np.stack([cv2.resize(f, (W, H)) for f in frames])
            combined = np.stack([
                cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST) for m in combined
            ])
        if ref is not None and ref.shape[:2] != (H, W):
            ref = cv2.resize(ref, (W, H))

        # Wan needs num_frames == 4k + 1; otherwise diffusers rounds it down and
        # the generated clip no longer matches the video/mask conditioning.
        # Pad by repeating the last frame/mask, then drop the padding afterwards.
        stride = self.TEMPORAL_STRIDE
        num_frames = ((T - 1 + stride - 1) // stride) * stride + 1
        pad = num_frames - T
        if pad:
            frames = np.concatenate([frames, np.repeat(frames[-1:], pad, axis=0)])
            combined = np.concatenate([combined, np.repeat(combined[-1:], pad, axis=0)])

        output = self.pipe(
            video=[Image.fromarray(f) for f in frames],
            mask=[Image.fromarray(m) for m in combined],
            prompt=text_prompt,
            negative_prompt="static, blurry, low quality",
            reference_images=None if ref is None else [Image.fromarray(ref)],
            num_frames=num_frames,
            height=H,
            width=W,
            guidance_scale=5.0,
            num_inference_steps=25,
            output_type="np",
        ).frames[0]

        # "np" output is float in [0, 1]; return uint8 like the input frames.
        result = self._to_uint8_frames(output)[:T]
        # Restore original resolution
        if (H, W) != (orig_H, orig_W):
            result = np.stack([cv2.resize(f, (orig_W, orig_H)) for f in result])
        return result
class SimpleCompositeStage2:
    """
    Fallback Stage 2: simple alpha compositing.
    No diffusion model needed.
    Works for: clean background, simple objects.
    Quality: low but fast for debugging the pipeline.
    """

    def synthesize(
        self,
        original_frames: np.ndarray,  # [T, H, W, 3]
        synthesis_masks: np.ndarray,  # [T, H, W]
        inpaint_masks: np.ndarray,  # [T, H, W]
        object_crop: np.ndarray,  # [H_obj, W_obj, 3]
        object_mask: np.ndarray,  # [H_obj, W_obj] binary
    ) -> np.ndarray:
        """
        Composite object into new positions using simple alpha blending.

        Useful for validating box trajectory before diffusion.

        For every frame t:
          1. Pixels where inpaint_masks[t] is non-zero are filled in with
             ``_inpaint_simple``. This happens even when the frame has no target
             box, so the object does not linger at its old position.
          2. The bounding box of ``synthesis_masks[t] > 0.5`` (including its last
             row and column) receives ``object_crop`` resized to the box,
             alpha-blended using ``object_mask``.

        object_mask may be bool / {0, 1} or {0, 255}; the latter is rescaled to
        [0, 1] so the blend weights stay in range.

        Returns a copy of original_frames with the edits applied.
        """
        import cv2

        result = original_frames.copy()

        alpha_src = np.asarray(object_mask, dtype=np.float32)
        if alpha_src.size and alpha_src.max() > 1.0:
            alpha_src = alpha_src / 255.0

        for t in range(original_frames.shape[0]):
            # Erase original position (simple fill with nearby bg)
            erase_mask = inpaint_masks[t]
            if np.any(erase_mask):
                result[t] = _inpaint_simple(result[t], erase_mask)

            # Find box from synthesis mask
            ys, xs = np.nonzero(synthesis_masks[t] > 0.5)
            if ys.size == 0:
                continue
            # +1: max() is inclusive, slice ends are exclusive.
            y1, y2 = int(ys.min()), int(ys.max()) + 1
            x1, x2 = int(xs.min()), int(xs.max()) + 1
            bh, bw = y2 - y1, x2 - x1

            # Resize object and its alpha to the target box size
            obj_resized = cv2.resize(object_crop, (bw, bh), interpolation=cv2.INTER_LINEAR)
            alpha = cv2.resize(alpha_src, (bw, bh), interpolation=cv2.INTER_LINEAR)[:, :, None]

            # Composite object at new position
            roi = result[t, y1:y2, x1:x2]
            blended = obj_resized * alpha + roi * (1 - alpha)
            result[t, y1:y2, x1:x2] = np.clip(blended, 0, 255).astype(np.uint8)
        return result
def _inpaint_simple(frame: np.ndarray, mask: np.ndarray, radius: int = 3) -> np.ndarray:
    """Fill the masked region of ``frame`` using OpenCV's Telea inpainting.

    Args:
        frame: image to repair; cv2.inpaint expects 8-bit 1- or 3-channel input.
        mask: single-channel mask with frame's height/width. Any value > 0
            marks a pixel to fill (bool, {0, 1} float and {0, 255} all work).
        radius: neighbourhood radius passed to cv2.inpaint (default 3).

    Returns:
        The image returned by cv2.inpaint.
    """
    import cv2

    # Binarise rather than computing `mask * 255`: for 0/255 masks that product
    # overflows uint8 (or is an out-of-range float->uint8 cast) and can drop pixels.
    mask_uint8 = (np.asarray(mask) > 0).astype(np.uint8) * 255
    return cv2.inpaint(frame, mask_uint8, radius, cv2.INPAINT_TELEA)