File size: 8,428 Bytes
f3d0a26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
# pipeline.py

import numpy as np
import torch
from utils.video_utils import load_video, save_video
from utils.box_utils import boxes_to_mask_sequence
from stage1_approx import stage1_linear, stage1_cotracker
from stage2_vace import VACEWrapper, SimpleCompositeStage2


class TRACEPrototype:
    """Two-stage prototype for box-guided video editing.

    Stage 1 expands sparse key boxes into a per-frame box trajectory
    (``stage1_cotracker`` when a CoTracker model was loaded, otherwise
    ``stage1_linear``). Stage 2 synthesizes the edited video from the original
    frames plus the derived masks, through either ``VACEWrapper`` or
    ``SimpleCompositeStage2``.

    Attributes:
        stage2: The Stage 2 synthesis backend chosen at construction time.
        cotracker: CoTracker model, or ``None`` if not requested / failed to load.
        sam2: ``SAM2ImagePredictor`` instance, or ``None`` if unavailable.
        qwen_edit_pipe: Qwen-Image-Edit pipeline, or ``None`` if unavailable.
    """

    def __init__(self, use_vace: bool = False, use_cotracker: bool = False):
        """Pick the Stage 2 backend and load the optional helper models.

        Every optional model is loaded best-effort: a failure is printed and
        the corresponding attribute stays ``None`` so the run methods can fall
        back to a simpler code path instead of crashing at construction.

        Args:
            use_vace: Use ``VACEWrapper`` for Stage 2 instead of
                ``SimpleCompositeStage2``.
            use_cotracker: Try to load CoTracker through ``torch.hub`` (moved
                to CUDA) for Stage 1.
        """
        # ── Stage 2: Video Synthesis ──────────────────────────────────
        self.stage2 = VACEWrapper() if use_vace else SimpleCompositeStage2()

        # ── CoTracker for Stage 1 ─────────────────────────────────────
        self.cotracker = None
        if use_cotracker:
            try:
                self.cotracker = torch.hub.load(
                    "facebookresearch/co-tracker",
                    "cotracker3_online"
                ).cuda()
                print("CoTracker loaded.")
            except Exception as e:
                # Deliberate best-effort: Stage 1 falls back to stage1_linear.
                print(f"CoTracker failed to load: {e}")
                print("Falling back to linear interpolation.")

        # ── SAM2 for object segmentation ─────────────────────────────
        self.sam2 = None
        try:
            # Imported lazily so the module still imports without sam2 installed.
            from sam2.build_sam import build_sam2
            from sam2.sam2_image_predictor import SAM2ImagePredictor
            self.sam2 = SAM2ImagePredictor(
                build_sam2("sam2_hiera_large.pt")
            )
            print("SAM2 loaded.")
        except Exception as e:
            print(f"SAM2 not available: {e}")
            print("Will use box masks directly instead of segmentation.")

        # ── Qwen-Image-Edit for object insertion ──────────────────────
        self.qwen_edit_pipe = None
        try:
            from frame_editor import load_qwen_image_edit
            self.qwen_edit_pipe = load_qwen_image_edit(
                use_lightning=True, device="cuda"
            )
            print("Qwen-Image-Edit loaded.")
        except Exception as e:
            print(f"Qwen-Image-Edit not available: {e}")

    @staticmethod
    def _crop_object(frame, box, mask=None):
        """Crop ``box`` out of ``frame`` and return ``(crop, crop_mask)``.

        The box is clamped to the frame bounds first. Without clamping, a
        negative coordinate would be interpreted by NumPy as an index from the
        end, and a box overhanging the frame would yield a crop smaller than a
        mask sized from the raw box coordinates.

        Args:
            frame: Image array indexed as ``[row, col, ...]``.
            box: ``[x1, y1, x2, y2]``; values are truncated to ``int``.
            mask: Optional per-pixel mask with the same height/width as
                ``frame``. When given, the same window is cut out of it;
                otherwise an all-ones ``float32`` mask matching the crop is
                returned.

        Returns:
            Tuple ``(crop, crop_mask)`` whose first two dimensions agree
            (assuming ``mask`` matches the frame size). Both are empty if the
            box lies entirely outside the frame.
        """
        height, width = frame.shape[:2]
        x1, y1, x2, y2 = (int(v) for v in box)
        x1, x2 = (max(0, min(x, width)) for x in (x1, x2))
        y1, y2 = (max(0, min(y, height)) for y in (y1, y2))

        crop = frame[y1:y2, x1:x2]
        if mask is None:
            # Size the mask from the actual crop so the two can never disagree.
            crop_mask = np.ones(crop.shape[:2], dtype=np.float32)
        else:
            crop_mask = mask[y1:y2, x1:x2]
        return crop, crop_mask

    @staticmethod
    def _save_if_requested(result, output_path):
        """Write ``result`` with ``save_video`` when an output path was given."""
        if output_path is not None:
            save_video(result, output_path)
            print(f"Saved to {output_path}")

    def run_motion_edit(
        self,
        video_path: str,
        keyboxes: dict,       # {frame_idx: [x1, y1, x2, y2]}
        text_prompt: str,
        output_path: str = None,
        frames: np.ndarray = None  # pass directly to avoid reloading
    ) -> np.ndarray:
        """
        Edit the trajectory of an existing object in the video.

        keyboxes must include:
          - frame 0: current object location (start)
          - at least one other frame: target location (end)

        Args:
            video_path: Video to load; only read when ``frames`` is None.
            keyboxes: ``{frame_idx: [x1, y1, x2, y2]}`` key boxes.
            text_prompt: Prompt forwarded to the VACE backend (unused by the
                simple composite backend).
            output_path: If given, the result is also saved there.
            frames: Pre-loaded frames, unpacked as ``(T, H, W, C)``.

        Returns:
            Whatever ``self.stage2.synthesize`` returns for the edited video.

        Raises:
            KeyError: If ``keyboxes`` has no entry for frame 0.
        """
        # Fail fast, before any expensive video loading or tracking.
        try:
            orig_box = keyboxes[0]
        except KeyError:
            raise KeyError(
                "keyboxes must contain frame 0 (the object's current location)"
            ) from None

        # Load video if frames not passed directly
        if frames is None:
            frames = load_video(video_path)
        T, H, W, _ = frames.shape

        # ── Stage 1: Compute target trajectory ───────────────────────
        if self.cotracker is not None:
            pred_boxes = stage1_cotracker(
                frames, keyboxes, self.cotracker
            )
        else:
            pred_boxes = stage1_linear(keyboxes, T)

        # ── Build masks ───────────────────────────────────────────────
        # Synthesis mask: where to PLACE the object (new trajectory)
        synthesis_masks = boxes_to_mask_sequence(pred_boxes, H, W)

        # Inpainting mask: where to ERASE the object (original position).
        # Built from the frame-0 box alone, expanded over all T frames.
        orig_boxes = stage1_linear({0: orig_box}, T)
        inpaint_masks = boxes_to_mask_sequence(orig_boxes, H, W)

        # With SAM2 available, refine frame 0 with a precise segmentation mask.
        seg_mask = None
        if self.sam2 is not None:
            from frame_editor import segment_existing_object
            seg_mask = segment_existing_object(
                frames[0], orig_box, self.sam2
            )
            inpaint_masks[0] = seg_mask

        # ── Stage 2: Synthesize video ─────────────────────────────────
        if isinstance(self.stage2, VACEWrapper):
            result = self.stage2.synthesize(
                original_frames=frames,
                synthesis_masks=synthesis_masks,
                inpaint_masks=inpaint_masks,
                first_frame_ref=frames[0],
                text_prompt=text_prompt
            )
        else:
            # SimpleCompositeStage2: needs object crop (and its mask; the
            # SAM2 mask when available, otherwise the full box).
            obj_crop, obj_mask = self._crop_object(
                frames[0], orig_box, seg_mask
            )
            result = self.stage2.synthesize(
                original_frames=frames,
                synthesis_masks=synthesis_masks,
                inpaint_masks=inpaint_masks,
                object_crop=obj_crop,
                object_mask=obj_mask
            )

        # ── Save if path provided ─────────────────────────────────────
        self._save_if_requested(result, output_path)

        return result

    def run_object_insertion(
        self,
        video_path: str,
        object_description: str,
        keyboxes: dict,       # {frame_idx: [x1, y1, x2, y2]}
        text_prompt: str,
        output_path: str = None,
        frames: np.ndarray = None,
    ) -> np.ndarray:
        """
        Insert a new object into the video and animate it along a trajectory.
        Qwen-Image-Edit paints the object into frame 0 only.
        Stage 2 propagates it through all frames.

        Args:
            video_path: Video to load; only read when ``frames`` is None.
            object_description: Description of the object handed to
                Qwen-Image-Edit.
            keyboxes: ``{frame_idx: [x1, y1, x2, y2]}`` key boxes for the
                inserted object's trajectory.
            text_prompt: Prompt forwarded to the VACE backend (unused by the
                simple composite backend).
            output_path: If given, the result is also saved there.
            frames: Pre-loaded frames, unpacked as ``(T, H, W, C)``.

        Returns:
            Whatever ``self.stage2.synthesize`` returns for the edited video.
        """
        if frames is None:
            frames = load_video(video_path)
        T, H, W, _ = frames.shape

        # Stage 1: trajectory. Always stage1_linear here; CoTracker is not
        # consulted by this method.
        pred_boxes = stage1_linear(keyboxes, T)

        # Edit first frame with Qwen-Image-Edit
        if self.qwen_edit_pipe is not None:
            from frame_editor import insert_object_qwen_edit
            edited_first_frame = insert_object_qwen_edit(
                first_frame=frames[0],
                box=pred_boxes[0],
                object_description=object_description,
                pipe=self.qwen_edit_pipe,
            )
        else:
            print("Qwen-Image-Edit not available, using original first frame.")
            edited_first_frame = frames[0]

        # Synthesis masks: where to place object along trajectory
        synthesis_masks = boxes_to_mask_sequence(pred_boxes, H, W)
        # No inpaint masks needed - nothing to erase for insertion
        inpaint_masks = np.zeros((T, H, W), dtype=np.uint8)

        # Stage 2
        if isinstance(self.stage2, VACEWrapper):
            result = self.stage2.synthesize(
                original_frames=frames,
                synthesis_masks=synthesis_masks,
                inpaint_masks=inpaint_masks,
                first_frame_ref=edited_first_frame,
                text_prompt=text_prompt,
            )
        else:
            # The composite backend pastes the frame-0 box contents of the
            # edited frame; the whole box counts as object.
            obj_crop, obj_mask = self._crop_object(
                edited_first_frame, pred_boxes[0]
            )
            result = self.stage2.synthesize(
                original_frames=frames,
                synthesis_masks=synthesis_masks,
                inpaint_masks=inpaint_masks,
                object_crop=obj_crop,
                object_mask=obj_mask,
            )

        self._save_if_requested(result, output_path)

        return result