# ObjectInsertion / pipeline.py
# Leema Krishna Murali
# Initial commit
# f3d0a26
# pipeline.py
import numpy as np
import torch
from utils.video_utils import load_video, save_video
from utils.box_utils import boxes_to_mask_sequence
from stage1_approx import stage1_linear, stage1_cotracker
from stage2_vace import VACEWrapper, SimpleCompositeStage2
class TRACEPrototype:
    """Two-stage video editing prototype.

    Stage 1 turns sparse keyframe boxes into a per-frame box trajectory
    (``stage1_cotracker`` when CoTracker is loaded, otherwise
    ``stage1_linear``). Stage 2 renders the edited video from per-frame masks
    using either ``VACEWrapper`` or the ``SimpleCompositeStage2`` fallback.

    Optional models (CoTracker, SAM2, Qwen-Image-Edit) are loaded best-effort:
    a load failure is printed and the corresponding attribute stays ``None``,
    which makes the pipeline take its simpler fallback path.
    """

    def __init__(
        self,
        use_vace: bool = False,
        use_cotracker: bool = False,
        sam2_config: str = "sam2_hiera_l.yaml",
        sam2_checkpoint: str = "sam2_hiera_large.pt",
    ):
        """Build the pipeline and try to load its optional models.

        Args:
            use_vace: Use ``VACEWrapper`` for Stage 2 instead of
                ``SimpleCompositeStage2``.
            use_cotracker: Try to load CoTracker3 (online) via ``torch.hub``
                and move it to CUDA; used by Stage 1 of ``run_motion_edit``.
            sam2_config: Model config name passed to ``build_sam2`` as its
                config argument.
            sam2_checkpoint: SAM2 weights path passed to ``build_sam2`` as
                its checkpoint argument.
        """
        # ── Stage 2: Video Synthesis ──────────────────────────────────
        if use_vace:
            self.stage2 = VACEWrapper()
        else:
            self.stage2 = SimpleCompositeStage2()

        # ── CoTracker for Stage 1 ─────────────────────────────────────
        self.cotracker = None
        if use_cotracker:
            try:
                self.cotracker = torch.hub.load(
                    "facebookresearch/co-tracker",
                    "cotracker3_online",
                ).cuda()
                print("CoTracker loaded.")
            except Exception as e:
                # Optional component: report and fall back, don't crash.
                print(f"CoTracker failed to load: {e}")
                print("Falling back to linear interpolation.")

        # ── SAM2 for object segmentation ─────────────────────────────
        self.sam2 = None
        try:
            from sam2.build_sam import build_sam2
            from sam2.sam2_image_predictor import SAM2ImagePredictor

            # build_sam2(config_file, ckpt_path, ...): the first argument is
            # the model config, the second the weights. Passing only the .pt
            # file (as this code previously did) used the checkpoint as the
            # config and supplied no weights at all.
            self.sam2 = SAM2ImagePredictor(
                build_sam2(sam2_config, sam2_checkpoint)
            )
            print("SAM2 loaded.")
        except Exception as e:
            print(f"SAM2 not available: {e}")
            print("Will use box masks directly instead of segmentation.")

        # ── Qwen-Image-Edit for object insertion ──────────────────────
        self.qwen_edit_pipe = None
        try:
            from frame_editor import load_qwen_image_edit

            self.qwen_edit_pipe = load_qwen_image_edit(
                use_lightning=True, device="cuda"
            )
            print("Qwen-Image-Edit loaded.")
        except Exception as e:
            print(f"Qwen-Image-Edit not available: {e}")

    def run_motion_edit(
        self,
        video_path: str,
        keyboxes: dict,  # {frame_idx: [x1, y1, x2, y2]}
        text_prompt: str,
        output_path: str = None,
        frames: np.ndarray = None,  # pass directly to avoid reloading
    ) -> np.ndarray:
        """Edit the trajectory of an existing object in the video.

        Args:
            video_path: Video to load with ``load_video``; ignored when
                ``frames`` is given.
            keyboxes: ``{frame_idx: [x1, y1, x2, y2]}``. Must include frame 0
                (the object's current location, the start) and at least one
                other frame (the target location, the end).
            text_prompt: Prompt forwarded to the VACE Stage 2; not used by
                ``SimpleCompositeStage2``.
            output_path: If given, the result is also written with
                ``save_video``.
            frames: Pre-loaded video, unpacked below as ``(T, H, W, C)``.

        Returns:
            The frames produced by Stage 2.

        Raises:
            KeyError: If ``keyboxes`` has no entry for frame 0.
        """
        # Validate before the (potentially expensive) video load.
        orig_box = self._require_start_box(keyboxes)

        if frames is None:
            frames = load_video(video_path)
        T, H, W, _ = frames.shape

        # ── Stage 1: Compute target trajectory ───────────────────────
        if self.cotracker is not None:
            pred_boxes = stage1_cotracker(frames, keyboxes, self.cotracker)
        else:
            pred_boxes = stage1_linear(keyboxes, T)

        # ── Build masks ───────────────────────────────────────────────
        # Synthesis mask: where to PLACE the object (new trajectory).
        synthesis_masks = boxes_to_mask_sequence(pred_boxes, H, W)

        # Inpainting mask: where to ERASE the object (original position).
        # stage1_linear gets a single keyframe here; presumably that holds the
        # original box over all T frames (TODO confirm in stage1_approx).
        orig_boxes = stage1_linear({0: orig_box}, T)
        inpaint_masks = boxes_to_mask_sequence(orig_boxes, H, W)

        seg_mask = None
        if self.sam2 is not None:
            from frame_editor import segment_existing_object

            # Only frame 0 is refined with the SAM2 mask; the remaining
            # frames keep the box mask.
            seg_mask = segment_existing_object(frames[0], orig_box, self.sam2)
            inpaint_masks[0] = seg_mask

        # ── Stage 2: Synthesize video ─────────────────────────────────
        if isinstance(self.stage2, VACEWrapper):
            result = self.stage2.synthesize(
                original_frames=frames,
                synthesis_masks=synthesis_masks,
                inpaint_masks=inpaint_masks,
                first_frame_ref=frames[0],
                text_prompt=text_prompt,
            )
        else:
            # SimpleCompositeStage2 needs an object crop plus a mask of the
            # same spatial size; crop with clamped coordinates so both agree
            # even when the box touches or exceeds the frame border.
            obj_crop, (x1, y1, x2, y2) = self._crop_box(frames[0], orig_box)
            if seg_mask is not None:
                obj_mask = seg_mask[y1:y2, x1:x2]
            else:
                obj_mask = np.ones(obj_crop.shape[:2], dtype=np.float32)
            result = self.stage2.synthesize(
                original_frames=frames,
                synthesis_masks=synthesis_masks,
                inpaint_masks=inpaint_masks,
                object_crop=obj_crop,
                object_mask=obj_mask,
            )

        # ── Save if path provided ─────────────────────────────────────
        if output_path is not None:
            save_video(result, output_path)
            print(f"Saved to {output_path}")
        return result

    def run_object_insertion(
        self,
        video_path: str,
        object_description: str,
        keyboxes: dict,  # {frame_idx: [x1, y1, x2, y2]}
        text_prompt: str,
        output_path: str = None,
        frames: np.ndarray = None,
    ) -> np.ndarray:
        """Insert a new object into the video and animate it along a trajectory.

        Qwen-Image-Edit paints the object into frame 0 only (when it is
        loaded; otherwise the original frame 0 is used). Stage 2 then
        propagates it through all frames along the Stage 1 trajectory.

        Args:
            video_path: Video to load with ``load_video``; ignored when
                ``frames`` is given.
            object_description: Description of the object to insert, passed
                to the Qwen-Image-Edit helper.
            keyboxes: ``{frame_idx: [x1, y1, x2, y2]}`` keyframe boxes for the
                new object's trajectory.
            text_prompt: Prompt forwarded to the VACE Stage 2; not used by
                ``SimpleCompositeStage2``.
            output_path: If given, the result is also written with
                ``save_video``.
            frames: Pre-loaded video, unpacked below as ``(T, H, W, C)``.

        Returns:
            The frames produced by Stage 2.
        """
        if frames is None:
            frames = load_video(video_path)
        T, H, W, _ = frames.shape

        # Stage 1: trajectory. Always linear here; the object is not in the
        # source video yet, presumably why CoTracker is not used.
        pred_boxes = stage1_linear(keyboxes, T)

        # Edit first frame with Qwen-Image-Edit.
        if self.qwen_edit_pipe is not None:
            from frame_editor import insert_object_qwen_edit

            edited_first_frame = insert_object_qwen_edit(
                first_frame=frames[0],
                box=pred_boxes[0],
                object_description=object_description,
                pipe=self.qwen_edit_pipe,
            )
        else:
            print("Qwen-Image-Edit not available, using original first frame.")
            edited_first_frame = frames[0]

        # Synthesis masks: where to place the object along the trajectory.
        synthesis_masks = boxes_to_mask_sequence(pred_boxes, H, W)
        # No inpaint masks needed — nothing to erase for insertion.
        inpaint_masks = np.zeros((T, H, W), dtype=np.uint8)

        # Stage 2
        if isinstance(self.stage2, VACEWrapper):
            result = self.stage2.synthesize(
                original_frames=frames,
                synthesis_masks=synthesis_masks,
                inpaint_masks=inpaint_masks,
                first_frame_ref=edited_first_frame,
                text_prompt=text_prompt,
            )
        else:
            # Crop with clamped coordinates so the all-ones mask matches the
            # crop's actual size at frame borders.
            obj_crop, _ = self._crop_box(edited_first_frame, pred_boxes[0])
            obj_mask = np.ones(obj_crop.shape[:2], dtype=np.float32)
            result = self.stage2.synthesize(
                original_frames=frames,
                synthesis_masks=synthesis_masks,
                inpaint_masks=inpaint_masks,
                object_crop=obj_crop,
                object_mask=obj_mask,
            )

        if output_path is not None:
            save_video(result, output_path)
            print(f"Saved to {output_path}")
        return result

    @staticmethod
    def _require_start_box(keyboxes: dict):
        """Return the frame-0 box from ``keyboxes`` or raise a descriptive KeyError."""
        try:
            return keyboxes[0]
        except KeyError:
            raise KeyError(
                "keyboxes must contain frame 0 (the object's current "
                f"location); got frames {list(keyboxes)}"
            ) from None

    @staticmethod
    def _crop_box(frame: np.ndarray, box) -> tuple:
        """Crop ``frame`` to ``box`` = ``[x1, y1, x2, y2]`` after clamping it.

        Coordinates are truncated with ``int`` and clamped to the frame so
        that negative values do not wrap around (NumPy treats negative slice
        bounds as offsets from the end) and out-of-frame values do not make
        the crop smaller than a mask sized from the raw box.

        Returns:
            ``(crop, (x1, y1, x2, y2))`` with the clamped integer coordinates
            actually used for slicing; a degenerate box yields an empty crop.
        """
        h, w = frame.shape[:2]
        x1, y1, x2, y2 = (int(v) for v in box)
        x1 = min(max(x1, 0), w)
        y1 = min(max(y1, 0), h)
        x2 = min(max(x2, x1), w)
        y2 = min(max(y2, y1), h)
        return frame[y1:y2, x1:x2], (x1, y1, x2, y2)