ObjectInsertion / stage2_vace.py
Leema Krishna Murali
Initial commit
f3d0a26
# stage2_vace.py
import numpy as np
import torch
from PIL import Image
class VACEWrapper:
    """Stage 2 synthesis backed by the Wan2.1 VACE diffusion pipeline.

    Loads ``Wan-AI/Wan2.1-VACE-1.3B-diffusers`` through ``diffusers`` and uses
    it to regenerate the masked regions of a clip, conditioned on a reference
    image and a text prompt.
    """

    def __init__(self, device="cuda"):
        """Load the VACE pipeline.

        Args:
            device: Torch device used for inference. For a non-CPU device the
                pipeline is run with model CPU offloading targeting that
                device; for ``"cpu"`` the pipeline is simply moved there.
        """
        # Imported lazily so that importing this module does not require
        # diffusers to be installed.
        from diffusers import WanVACEPipeline

        self.device = device
        # `synthesize` passes `video=`, `mask=` and `reference_images=`, which
        # are inputs of the VACE pipeline (not of the image-to-video one).
        self.pipe = WanVACEPipeline.from_pretrained(
            "Wan-AI/Wan2.1-VACE-1.3B-diffusers",
            torch_dtype=torch.bfloat16,
        )
        if torch.device(device).type == "cpu":
            self.pipe.to(device)
        else:
            # Offloading manages device placement itself; moving the whole
            # pipeline to the accelerator first would only waste memory.
            self.pipe.enable_model_cpu_offload(device=device)

    @staticmethod
    def _as_uint8_mask(masks):
        """Return *masks* as a uint8 array on the 0-255 scale.

        Boolean masks and masks whose maximum is <= 1 are treated as
        normalised and scaled by 255; anything else is assumed to already be
        on the 0-255 scale and is only clipped and cast.
        """
        masks = np.asarray(masks)
        if masks.dtype == bool or (masks.size and masks.max() <= 1):
            masks = masks.astype(np.float32) * 255.0
        return np.clip(masks, 0, 255).astype(np.uint8)

    @staticmethod
    def _pad_tail(frames, pad):
        """Append *pad* copies of the last entry of *frames* along axis 0."""
        if pad <= 0:
            return frames
        tail = np.repeat(frames[-1:], pad, axis=0)
        return np.concatenate([frames, tail], axis=0)

    def synthesize(
        self,
        original_frames,
        synthesis_masks,
        inpaint_masks,
        first_frame_ref,
        text_prompt="",
    ):
        """Regenerate the masked regions of a clip with VACE.

        Args:
            original_frames: Array unpacked as ``[T, H, W, C]``; each frame is
                handed to ``PIL.Image.fromarray`` (presumably uint8 RGB --
                confirm against the caller).
            synthesis_masks: Per-frame masks, ``[T, H, W]``, marking where the
                object should appear. Either 0/1 (bool, int or float) or
                0-255.
            inpaint_masks: Per-frame masks, same conventions, marking the
                region to erase.
            first_frame_ref: Reference image passed as the single entry of
                ``reference_images``.
            text_prompt: Prompt forwarded to the pipeline.

        Returns:
            ``np.ndarray`` with ``T`` frames at the input resolution, built
            from the PIL frames returned by the pipeline.

        Raises:
            ValueError: If there are no frames or a spatial dimension is
                smaller than 16 pixels.
        """
        import cv2  # not imported at module level

        original_frames = np.asarray(original_frames)
        T, orig_H, orig_W, _ = original_frames.shape
        if T == 0:
            raise ValueError("original_frames must contain at least one frame")

        # Round down to a multiple of 16 (VACE requirement)
        H = (orig_H // 16) * 16
        W = (orig_W // 16) * 16
        if H == 0 or W == 0:
            raise ValueError(
                f"frames must be at least 16x16 pixels, got {orig_H}x{orig_W}"
            )

        # Normalise before resizing: cv2.resize rejects bool arrays, and a 0/1
        # mask written straight into an 8-bit image is effectively all black.
        synthesis_masks = self._as_uint8_mask(synthesis_masks)
        inpaint_masks = self._as_uint8_mask(inpaint_masks)

        resized = (H, W) != (orig_H, orig_W)
        if resized:
            original_frames = np.stack(
                [cv2.resize(f, (W, H)) for f in original_frames]
            )
            first_frame_ref = cv2.resize(first_frame_ref, (W, H))
            # Nearest-neighbour keeps the masks hard-edged.
            synthesis_masks = np.stack([
                cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST)
                for m in synthesis_masks
            ])
            inpaint_masks = np.stack([
                cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST)
                for m in inpaint_masks
            ])

        # Union of both masks as a saturating add (uint16 avoids wrap-around).
        combined = np.clip(
            synthesis_masks.astype(np.uint16) + inpaint_masks.astype(np.uint16),
            0,
            255,
        ).astype(np.uint8)

        # Wan pipelines expect `num_frames - 1` to be divisible by the VAE's
        # temporal factor and otherwise round the count themselves. Pad with
        # copies of the last frame and trim afterwards so the result always
        # has exactly T frames.
        temporal = getattr(self.pipe, "vae_scale_factor_temporal", 4)
        pad = -(T - 1) % temporal
        frames = self._pad_tail(original_frames, pad)
        combined = self._pad_tail(combined, pad)

        output = self.pipe(
            video=[Image.fromarray(f) for f in frames],
            mask=[Image.fromarray(m) for m in combined],
            prompt=text_prompt,
            negative_prompt="static, blurry, low quality",
            reference_images=[Image.fromarray(first_frame_ref)],
            num_frames=len(frames),
            height=H,
            width=W,
            guidance_scale=5.0,
            num_inference_steps=25,
            # Ask for PIL frames explicitly so the conversion below yields
            # 8-bit images regardless of the pipeline's default output type.
            output_type="pil",
        ).frames[0]

        result = np.stack([np.array(f) for f in output[:T]], axis=0)

        # Restore original resolution
        if resized:
            result = np.stack(
                [cv2.resize(f, (orig_W, orig_H)) for f in result]
            )
        return result
class SimpleCompositeStage2:
    """
    Fallback Stage 2: simple alpha compositing.

    No diffusion model needed.
    Works for: clean background, simple objects.
    Quality: low but fast for debugging the pipeline.
    """

    def synthesize(
        self,
        original_frames: np.ndarray,  # [T, H, W, 3]
        synthesis_masks: np.ndarray,  # [T, H, W]
        inpaint_masks: np.ndarray,  # [T, H, W]
        object_crop: np.ndarray,  # [H_obj, W_obj, 3]
        object_mask: np.ndarray,  # [H_obj, W_obj] binary
    ) -> np.ndarray:
        """
        Composite object into new positions using simple alpha blending.
        Useful for validating box trajectory before diffusion.

        For every frame the region flagged by ``inpaint_masks[t]`` is erased
        with ``_inpaint_simple``; then, if ``synthesis_masks[t]`` is
        non-empty, the object crop is resized to that mask's bounding box and
        alpha-blended into it.

        Args:
            original_frames: Input clip; it is copied, never modified.
            synthesis_masks: Per-frame masks; pixels ``> 0.5`` define the
                target bounding box of the object.
            inpaint_masks: Per-frame masks; any non-zero pixel is erased.
            object_crop: Image of the object to paste.
            object_mask: Alpha for ``object_crop``, either on a 0-1 or a
                0-255 scale.

        Returns:
            A new array shaped like ``original_frames`` with the composite.
        """
        result = original_frames.copy()
        # Loop-invariant, so convert the alpha mask once up front.
        alpha = self._normalized_alpha(object_mask)

        for t in range(result.shape[0]):
            # Erase original position (simple fill with nearby bg). This runs
            # even when there is nothing to paste in this frame, otherwise the
            # object would stay visible at its old location.
            erase_mask = inpaint_masks[t]
            if np.any(erase_mask):
                result[t] = _inpaint_simple(result[t], erase_mask)

            # Find box from synthesis mask
            box = self._mask_box(synthesis_masks[t])
            if box is None:
                continue
            y1, y2, x1, x2 = box

            # Composite object at new position
            result[t, y1:y2, x1:x2] = self._blend(
                result[t, y1:y2, x1:x2], object_crop, alpha
            )
        return result

    @staticmethod
    def _normalized_alpha(object_mask):
        """Return *object_mask* as float32 alpha on a 0-1 scale."""
        alpha = np.asarray(object_mask, dtype=np.float32)
        if alpha.size and alpha.max() > 1.0:
            # 0-255 mask: rescale so the blend weights stay within [0, 1].
            alpha = alpha / 255.0
        return alpha

    @staticmethod
    def _mask_box(mask):
        """Return ``(y1, y2, x1, x2)`` of the pixels ``> 0.5``, or ``None``.

        ``y2`` and ``x2`` are exclusive, so the box includes the last masked
        row/column and a one-pixel-thin mask still yields a non-empty box.
        """
        ys, xs = np.nonzero(np.asarray(mask) > 0.5)
        if ys.size == 0:
            return None
        # Plain ints: they are used as slice bounds and as a cv2 size.
        return int(ys.min()), int(ys.max()) + 1, int(xs.min()), int(xs.max()) + 1

    @staticmethod
    def _blend(roi, object_crop, alpha):
        """Resize the object and its alpha to *roi*'s size and blend them."""
        import cv2

        box_h, box_w = roi.shape[:2]
        # Resize object to target box size (cv2 takes the size as (w, h)).
        obj_resized = cv2.resize(
            object_crop, (box_w, box_h), interpolation=cv2.INTER_LINEAR
        )
        alpha_resized = cv2.resize(
            alpha, (box_w, box_h), interpolation=cv2.INTER_LINEAR
        )
        alpha_3ch = alpha_resized[:, :, None]
        blended = obj_resized * alpha_3ch + roi * (1 - alpha_3ch)
        return np.clip(blended, 0, 255).astype(np.uint8)
def _inpaint_simple(frame: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Simple Telea inpainting for object removal.

    Args:
        frame: Image passed to ``cv2.inpaint`` as the source.
        mask: Same height/width as ``frame``; every pixel ``> 0`` is filled
            in from its surroundings. Works for 0/1, boolean and 0/255 masks.

    Returns:
        The inpainted image returned by ``cv2.inpaint``.
    """
    import cv2

    # cv2.inpaint treats any non-zero mask pixel as "to be filled". Binarise
    # explicitly rather than multiplying by 255: that product wraps around for
    # 0/255 uint8 masks and overflows the uint8 cast for float masks above 1.0.
    mask_uint8 = (np.asarray(mask) > 0).astype(np.uint8) * 255
    # 3 is the inpainting radius in pixels.
    return cv2.inpaint(frame, mask_uint8, 3, cv2.INPAINT_TELEA)