File size: 4,606 Bytes
f3d0a26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# stage2_vace.py
import numpy as np
import torch
from PIL import Image

class VACEWrapper:
    """Stage 2 backend that drives the diffusers Wan2.1 VACE pipeline.

    The model is loaded once in ``__init__``; :meth:`synthesize` re-renders a
    clip, regenerating the pixels covered by the synthesis and inpaint masks
    while conditioning on a reference image and an optional text prompt.
    """

    # VACE needs height and width to be multiples of 16.
    _SPATIAL_MULTIPLE = 16
    # Wan pipelines expect num_frames == 4k + 1 and round other values down,
    # which would silently drop trailing frames; we pad instead.
    _TEMPORAL_STRIDE = 4

    def __init__(
        self,
        device="cuda",
        model_id="Wan-AI/Wan2.1-VACE-1.3B-diffusers",
        enable_cpu_offload=True,
    ):
        """Load the VACE pipeline.

        Args:
            device: accelerator the pipeline runs on.
            model_id: Hugging Face repo id of a diffusers-format VACE checkpoint.
            enable_cpu_offload: if True (the previous behaviour), use diffusers
                model CPU offload onto ``device``; otherwise move the whole
                pipeline to ``device``.
        """
        import torch
        from diffusers import WanVACEPipeline

        self.device = device
        # A VACE checkpoint must be driven by WanVACEPipeline: it is the Wan
        # pipeline that accepts ``video`` / ``mask`` / ``reference_images``
        # (WanImageToVideoPipeline takes a single ``image`` instead).
        self.pipe = WanVACEPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
        )
        if enable_cpu_offload:
            # Offloading manages device placement itself; moving the full
            # pipeline to the GPU beforehand only causes a needless VRAM spike.
            self.pipe.enable_model_cpu_offload(device=device)
        else:
            self.pipe.to(device)

    def synthesize(
        self,
        original_frames,
        synthesis_masks,
        inpaint_masks,
        first_frame_ref,
        text_prompt="",
        *,
        negative_prompt="static, blurry, low quality",
        guidance_scale=5.0,
        num_inference_steps=25,
    ):
        """Regenerate the masked regions of a clip with VACE.

        Args:
            original_frames: (T, H, W, 3) frames; expected uint8 RGB so that
                ``PIL.Image.fromarray`` accepts them.
            synthesis_masks: (T, H, W) masks of where the object should appear.
            inpaint_masks: (T, H, W) masks of regions to erase.
                Both mask stacks may be bool, {0, 1} (int or float) or 0-255
                encoded; they are normalised to [0, 1] and unioned.
            first_frame_ref: reference image passed to VACE as
                ``reference_images``; resized together with the frames.
            text_prompt: positive prompt.
            negative_prompt, guidance_scale, num_inference_steps: forwarded to
                the pipeline (defaults equal the previously hard-coded values).

        Returns:
            uint8 array of shape (T, H, W, 3) at the input resolution.

        Raises:
            ValueError: if ``original_frames`` contains no frames.
        """
        import cv2
        from PIL import Image

        T, orig_H, orig_W, _ = original_frames.shape
        if T == 0:
            raise ValueError("original_frames must contain at least one frame")

        # Floor to a multiple of 16 (VACE requirement), never going below 16.
        base = self._SPATIAL_MULTIPLE
        H = max(base, (orig_H // base) * base)
        W = max(base, (orig_W // base) * base)
        resized = (H, W) != (orig_H, orig_W)

        # Union of both masks as one 0-255 uint8 stack. Nearest-neighbour
        # resizing commutes with this per-pixel op, so it is resized once.
        combined = np.clip(
            self._to_unit_mask(synthesis_masks) + self._to_unit_mask(inpaint_masks),
            0.0,
            1.0,
        )
        combined = np.rint(combined * 255.0).astype(np.uint8)

        if resized:
            original_frames = np.stack([cv2.resize(f, (W, H)) for f in original_frames])
            first_frame_ref = cv2.resize(first_frame_ref, (W, H))
            combined = np.stack([
                cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST) for m in combined
            ])

        # Pad to the next 4k + 1 length by repeating the last frame/mask; the
        # extra output frames are trimmed below so callers get exactly T frames.
        stride = self._TEMPORAL_STRIDE
        num_frames = ((T - 1 + stride - 1) // stride) * stride + 1
        pad = num_frames - T
        if pad:
            original_frames = np.concatenate(
                [original_frames, np.repeat(original_frames[-1:], pad, axis=0)]
            )
            combined = np.concatenate([combined, np.repeat(combined[-1:], pad, axis=0)])

        output = self.pipe(
            video=[Image.fromarray(f) for f in original_frames],
            mask=[Image.fromarray(m) for m in combined],
            prompt=text_prompt,
            negative_prompt=negative_prompt,
            reference_images=[Image.fromarray(first_frame_ref)],
            num_frames=num_frames,
            height=H,
            width=W,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            output_type="np",
        ).frames[0]

        result = self._frames_to_uint8(output)[:T]

        # Restore original resolution.
        if resized:
            result = np.stack([cv2.resize(f, (orig_W, orig_H)) for f in result])
        return result

    @staticmethod
    def _to_unit_mask(masks):
        """Return ``masks`` as float32 values in [0, 1].

        Values above 1 are taken to mean 0-255 encoding and divided by 255;
        bool and {0, 1} masks pass through unchanged. (A plain integer cast
        would turn a {0, 1} float mask into 0/1 out of 255, i.e. an almost
        empty VACE mask.) A 0-255 mask whose maximum is <= 1 is ambiguous and
        is treated as {0, 1}.
        """
        m = np.asarray(masks, dtype=np.float32)
        if m.size and m.max() > 1.0:
            m = m / 255.0
        return np.clip(m, 0.0, 1.0)

    @staticmethod
    def _frames_to_uint8(frames):
        """Stack pipeline output frames into a uint8 (T, H, W, C) array.

        Float frames (``output_type="np"``) are treated as [0, 1] and scaled
        to 0-255; integer frames are cast directly.
        """
        arr = np.stack([np.asarray(f) for f in frames], axis=0)
        if np.issubdtype(arr.dtype, np.floating):
            arr = np.rint(np.clip(arr, 0.0, 1.0) * 255.0)
        return arr.astype(np.uint8)




class SimpleCompositeStage2:
    """
    Fallback Stage 2: simple alpha compositing.
    No diffusion model needed.
    Works for: clean background, simple objects.
    Quality: low but fast for debugging the pipeline.
    """

    def synthesize(
        self,
        original_frames: np.ndarray,   # [T, H, W, 3]
        synthesis_masks: np.ndarray,   # [T, H, W]
        inpaint_masks: np.ndarray,     # [T, H, W]
        object_crop: np.ndarray,       # [H_obj, W_obj, 3]
        object_mask: np.ndarray,       # [H_obj, W_obj] binary
    ) -> np.ndarray:
        """
        Composite object into new positions using simple alpha blending.
        Useful for validating box trajectory before diffusion.

        For each frame t:
          1. Pixels where ``inpaint_masks[t]`` is non-zero are filled by
             ``_inpaint_simple`` (erases the object at its old position).
             This happens even when the frame has no synthesis target, so
             the original object does not reappear in such frames.
          2. The tight bounding box of ``synthesis_masks[t] > 0.5`` is found
             (including its last masked row/column), ``object_crop`` is
             resized to it and alpha-blended in using ``object_mask`` as alpha.

        Args:
            original_frames: frames to edit; not modified in place.
            synthesis_masks: per-frame target-region masks (thresholded at 0.5).
            inpaint_masks: per-frame masks of the region to erase.
            object_crop: object pixels to paste.
            object_mask: alpha for ``object_crop``; {0, 1}/bool or 0-255 encoded.

        Returns:
            A new array with the same shape and dtype as ``original_frames``.
        """
        import cv2

        T = original_frames.shape[0]
        result = original_frames.copy()

        # Alpha source in [0, 1]; a 0/255-encoded mask would otherwise push
        # the blend far outside the uint8 range.
        alpha_src = object_mask.astype(np.float32)
        if alpha_src.size and alpha_src.max() > 1.0:
            alpha_src /= 255.0

        for t in range(T):
            # Erase the original position first, independent of whether the
            # object has a target box in this frame.
            erase_mask = inpaint_masks[t]
            if np.any(erase_mask):
                result[t] = _inpaint_simple(result[t], erase_mask)

            # Find box from synthesis mask
            ys, xs = np.nonzero(synthesis_masks[t] > 0.5)
            if ys.size == 0:
                continue

            # Exclusive upper bounds so the box covers the last masked row and
            # column (and 1-pixel-wide targets are not skipped).
            y1, y2 = int(ys.min()), int(ys.max()) + 1
            x1, x2 = int(xs.min()), int(xs.max()) + 1
            bh, bw = y2 - y1, x2 - x1

            # Resize object and its alpha to the target box size
            obj_resized = cv2.resize(
                object_crop, (bw, bh),
                interpolation=cv2.INTER_LINEAR,
            )
            alpha = cv2.resize(
                alpha_src, (bw, bh),
                interpolation=cv2.INTER_LINEAR,
            ).reshape(bh, bw, 1)

            # Composite object at new position
            roi = result[t, y1:y2, x1:x2]
            blended = obj_resized * alpha + roi * (1.0 - alpha)
            result[t, y1:y2, x1:x2] = np.clip(blended, 0, 255).astype(np.uint8)

        return result


def _inpaint_simple(frame: np.ndarray, mask: np.ndarray, radius: int = 3) -> np.ndarray:
    """Simple Telea inpainting for object removal.

    Args:
        frame: image handed to ``cv2.inpaint`` (OpenCV requires an 8-bit
            1- or 3-channel image).
        mask: same height/width as ``frame``; every pixel with a value > 0
            is inpainted.
        radius: ``inpaintRadius`` passed to ``cv2.inpaint`` (default 3, the
            previously hard-coded value).

    Returns:
        The inpainted frame as returned by ``cv2.inpaint``.
    """
    import cv2

    # Binarize explicitly. ``(mask * 255).astype(np.uint8)`` relied on uint8
    # wraparound for 0/255 masks (255 * 255 wraps to 1) and truncated float
    # values below 1/255 to 0, silently dropping them from the mask.
    mask_uint8 = np.where(np.asarray(mask) > 0, 255, 0).astype(np.uint8)
    return cv2.inpaint(frame, mask_uint8, radius, cv2.INPAINT_TELEA)