DragStream / app.py
bowmanchow's picture
change flash-attn version
0346ed8
Raw
History Blame Contribute Delete
49 kB
import os
import subprocess
import sys
def install_flash_attn(flash_attn_version="2.8.3"):
    """Install flash-attn, preferring the pre-built wheel matching this environment.

    The wheel filename is derived from the running Python version, the installed
    PyTorch version, the CUDA major version PyTorch was built against and the
    C++11 ABI flag reported by PyTorch. Only ``linux_x86_64`` wheels are tried.
    If pip cannot install that wheel, the same flash-attn version is built from
    source instead.

    Args:
        flash_attn_version: flash-attn release to install. Used for the release
            download URL, the wheel filename and the source-build fallback.

    Returns:
        None. Returns early, installing nothing, when PyTorch reports no CUDA
        version.

    Raises:
        subprocess.CalledProcessError: If the source-build fallback fails.
    """
    import torch

    # CUDA version PyTorch was built with (e.g. "12.4"); None means no CUDA build.
    cuda_version = torch.version.cuda
    if cuda_version is None:
        print("No CUDA detected, skipping flash-attn installation.")
        return

    # Python tag (e.g. "cp310")
    py_major = sys.version_info.major
    py_minor = sys.version_info.minor
    cp_tag = f"cp{py_major}{py_minor}"

    # PyTorch version (e.g. "2.4" from "2.4.0+cu121")
    torch_version = torch.__version__.split("+")[0]
    torch_major_minor = ".".join(torch_version.split(".")[:2])

    # The wheel name is built with the CUDA major version only (e.g. "cu12").
    cuda_tag_short = f"cu{cuda_version.split('.')[0]}"

    # C++11 ABI flag PyTorch was compiled with
    cxx11_abi = torch._C._GLIBCXX_USE_CXX11_ABI
    abi_tag = "cxx11abiTRUE" if cxx11_abi else "cxx11abiFALSE"

    # Example: flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
    wheel_name = (
        f"flash_attn-{flash_attn_version}+"
        f"{cuda_tag_short}torch{torch_major_minor}{abi_tag}-"
        f"{cp_tag}-{cp_tag}-linux_x86_64.whl"
    )
    base_url = (
        "https://github.com/Dao-AILab/flash-attention/releases/download/"
        f"v{flash_attn_version}"
    )
    wheel_url = f"{base_url}/{wheel_name}"

    print("Detected environment:")
    print(f" Python: {py_major}.{py_minor} ({cp_tag})")
    print(f" PyTorch: {torch_version} (torch{torch_major_minor})")
    print(f" CUDA: {cuda_version} ({cuda_tag_short})")
    print(f" CXX11 ABI: {cxx11_abi} ({abi_tag})")
    print(f" Wheel URL: {wheel_url}")

    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", wheel_url],
        capture_output=True,
        text=True,
    )
    if result.returncode == 0:
        print("flash-attn installed successfully from pre-built wheel.")
        print(result.stdout)
        return

    print(f"Pre-built wheel failed:\n{result.stderr}")
    print("Falling back to building flash-attn from source (this may take a while)...")
    # Pin the fallback to the requested release so both install paths yield the
    # same flash-attn version.
    subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            f"flash-attn=={flash_attn_version}",
            "--no-build-isolation",
        ],
        check=True,
    )
# Run the flash-attn installation at import time, ahead of the remaining imports below.
# NOTE(review): presumably one of the project modules imported later needs flash-attn — confirm.
install_flash_attn()
# Restrict the CUDA devices visible to this process to device index 0.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import copy
import torch
from torchvision.io import write_video
from torch.utils.data import Dataset
import numpy as np
from pathlib import Path
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra
from b_spline import build_clamped_bspline, equidistant_points_on_spline
torch.set_grad_enabled(False)
from palette import _palette
import gradio as gr
import numpy as np
from scipy import ndimage
from PIL import Image
import os
from pathlib import Path
import cv2
# from sam_segment import predict_masks_with_sam
from segment_anything import SamPredictor, sam_model_registry
from tensor_utils import (
image_to_pil,
image_to_np,
bbox_from_mask,
draw_bbox_on_image,
draw_mask_on_image,
draw_points_on_image,
draw_lines_on_image,
trajectory_interpolate,
dilate_mask,
dilate_masks,
)
from optimize_utils import (
MultiTrajectory,
Trajectory,
)
import sys
from utils.misc import set_seed
from stream_inference_wrapper import StreamInferenceWrapper
from stream_drag_inference_wrapper import StreamDragInferenceWrapper
from utils.dataset import TextDataset
from video_operations import generate_video, optimize_video
# from compute_objmc import visualize_ground_truth_from_trajectory_file
def extract_layer_as_mask(image_editor, layer_index=0):
    """Return one drawn layer of an image-editor value as a boolean mask.

    Args:
        image_editor: Mapping with a ``"layers"`` entry holding PIL-like images
            (objects with ``.convert``), or None/empty when the editor has no
            value.
        layer_index: Index of the layer to extract.

    Returns:
        ``image_to_np(layer.convert("L")) > 0`` for the requested layer, or None
        when the editor value is empty, has no layers, or has too few layers.
    """
    # The editor can emit an empty value (e.g. after being cleared); treat that
    # as "no mask" instead of failing on the subscript.
    if not image_editor:
        return None
    layers = image_editor.get("layers") or []
    if len(layers) <= layer_index:
        return None
    return image_to_np(layers[layer_index].convert("L")) > 0
def apply_mask_to_image(
    mask: np.ndarray | None,
    image: np.ndarray | Image.Image,
    mask_color: list[int],
    alpha: float,
) -> None | Image.Image:
    """Overlay ``mask`` on ``image`` with the given colour and opacity.

    Returns None when there is no image, the image converted with
    ``image_to_pil`` when there is no mask, and otherwise the result of
    ``draw_mask_on_image``.
    """
    if image is None:
        return None
    if mask is not None:
        # The mask is passed through np.array before being drawn.
        return draw_mask_on_image(
            image,
            np.array(mask),
            mask_color=mask_color,
            alpha=alpha,
        )
    return image_to_pil(image)
def apply_movable_mask_to_image(
    mask: np.ndarray | None,
    image: np.ndarray | Image.Image,
):
    """Overlay the movable-area mask on ``image`` in white at 35% opacity."""
    movable_color = (255, 255, 255)
    movable_alpha = 0.35
    return apply_mask_to_image(mask, image, movable_color, movable_alpha)
def apply_target_mask_to_image(
    mask: np.ndarray | None,
    image: np.ndarray | Image.Image,
):
    """Overlay the target-area mask on ``image`` in red (255, 64, 64) at 50% opacity."""
    target_color = (255, 64, 64)
    target_alpha = 0.5
    return apply_mask_to_image(mask, image, target_color, target_alpha)
def get_video_last_frame(
    # video: Optional[torch.Tensor], # None or (t, h, w, c)
    video_path: str,
):
    """Load the last frame of a video file.

    The frame is first fetched by seeking to ``frame_count - 1``. If the
    container reports no usable frame count, or the seek/read fails, the video
    is read sequentially from the start and the last successfully decoded
    frame is used.

    Args:
        video_path: Path of the video file, or None.

    Returns:
        The last frame as an RGB PIL Image, or None when ``video_path`` is
        None, the video cannot be opened, no frame can be decoded, or any
        error occurs while reading.
    """
    print(f"Getting last frame from video: {video_path = }")
    if video_path is None:
        return None

    cap = cv2.VideoCapture(video_path)
    # Everything after construction lives inside try/finally so the capture is
    # released exactly once on every path, including a failed open.
    try:
        if not cap.isOpened():
            print(f"Failed to open video: {video_path}")
            return None

        frame = None
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if frame_count > 0:
            # Try direct seek to last frame
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
            ret, frame = cap.read()
            if (not ret) or frame is None:
                print("Direct seek failed, iterating through frames...")
                frame = None
        else:
            # The reported count is unusable; the sequential pass below may
            # still be able to decode frames.
            print(f"Video has non-positive frame count: {frame_count}")

        if frame is None:
            # Fallback: walk the whole video and keep the last decoded frame.
            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
            while True:
                ret_i, frame_i = cap.read()
                if not ret_i:
                    break
                frame = frame_i

        if frame is None:
            print("Could not retrieve last frame.")
            return None

        # OpenCV decodes to BGR; PIL expects RGB.
        return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    except Exception as e:
        # Best-effort helper for the UI: report the problem and return "no frame".
        print(f"Error extracting last frame: {e}")
        return None
    finally:
        cap.release()
def sam_predict_segmentation(
    sam_predictor: SamPredictor,
    origin_image: Image.Image | np.ndarray,
    restriction_mask: np.ndarray,  # (h, w), bool
    click_points: list[tuple[int, int]],
    previous_sam_logits: np.ndarray | None,  # (3, 256, 256)
):
    """Run a point-prompted SAM prediction and confine it to ``restriction_mask``.

    Every click is passed as a positive point (label 1). When logits from a
    previous call are given, their first channel is fed back as the mask
    prompt. Of the masks returned by the predictor, the first one is used.

    Returns:
        ``(mask, logits)``: the first predicted mask multiplied in place by
        ``restriction_mask``, and the predictor's logits multiplied in place by
        the restriction mask resized to 256x256.
    """
    image_np = image_to_np(origin_image)
    sam_predictor.set_image(image_np)

    if previous_sam_logits is None:
        print(f"{previous_sam_logits = }")
        mask_prompt = None
    else:
        print(f"{previous_sam_logits.shape = }")
        mask_prompt = previous_sam_logits[0:1]

    # NOTE(review): the commented-out debug prints in the original suggest masks
    # come back as (3, H, W) and logits as (3, 256, 256) — TODO confirm.
    masks, _scores, logits = sam_predictor.predict(
        point_coords=np.array(click_points),
        point_labels=np.ones((len(click_points),)),
        mask_input=mask_prompt,
        multimask_output=True,
    )

    mask = masks[0]
    mask *= restriction_mask

    # NOTE(review): the restriction mask is stretched straight to 256x256 here;
    # verify this matches how the predictor lays out its low-resolution logits.
    restriction_256 = cv2.resize(
        restriction_mask.astype(np.uint8),
        dsize=(256, 256),
        interpolation=cv2.INTER_LINEAR,
    )
    logits *= restriction_256
    return mask, logits
def sam_predict_segmentation_wrapper(
    sam_predictor: SamPredictor,
    original_image: Image.Image | np.ndarray,
    restriction_mask: np.ndarray | None,
    previous_click_points: list[tuple[int, int]],
    previous_sam_logits: np.ndarray | None,
    bypass_sam_model: bool,
    evt: gr.SelectData,
):
    """Handle a click on the segmentation image.

    The restriction mask is split into connected components (3x3 structuring
    element, so diagonal neighbours connect) and only the component under the
    click is kept as the restriction. A click outside every component (or a
    missing restriction mask) yields an all-False restriction.

    Returns:
        ``(mask, click_points, logits)``. With ``bypass_sam_model`` the mask is
        the clicked component itself, the click list holds only this click and
        logits is None. Otherwise the click is appended to the previous clicks
        and the values come from ``sam_predict_segmentation``.
    """
    original_image = image_to_pil(original_image).convert("RGB")

    if restriction_mask is not None:
        component_labels, _ = ndimage.label(restriction_mask, structure=np.ones((3, 3)))
    else:
        component_labels = np.zeros(
            (original_image.height, original_image.width), dtype=np.int32
        )

    # evt.index is (x, y); the label image is indexed [row, column].
    clicked_label = component_labels[evt.index[1], evt.index[0]]
    if clicked_label != 0:
        clicked_component = component_labels == clicked_label
    else:
        # Label 0 is background: nothing is selectable at this position.
        clicked_component = np.zeros_like(component_labels, dtype=bool)

    if bypass_sam_model:
        return clicked_component, [evt.index], None

    click_points = previous_click_points + [evt.index]
    mask, logits = sam_predict_segmentation(
        sam_predictor=sam_predictor,
        origin_image=original_image,
        restriction_mask=clicked_component,
        click_points=click_points,
        previous_sam_logits=previous_sam_logits,
    )
    return mask, click_points, logits
def draw_all_sam_masks(image: Image.Image | None, mask_list: list[np.ndarray]):
    """Overlay every mask in ``mask_list`` on ``image`` at 65% opacity.

    The mask at position ``i`` is coloured with ``_palette[i + 1]``. Returns
    None when ``image`` is None and the image unchanged when the list is empty.
    """
    if image is None:
        return None
    overlaid = image
    for palette_index, mask in enumerate(mask_list, start=1):
        overlaid = draw_mask_on_image(
            overlaid,
            mask,
            mask_color=tuple(_palette[palette_index]),
            alpha=0.65,
        )
    return overlaid
def draw_sam_mask_wrapper(
    original_image,
    movable_mask,
    current_mask: np.ndarray | None,
    previous_masks: list[np.ndarray],
    click_points: list[tuple[int, int]],
):
    """Render the segmentation preview.

    Draws, in order: the movable-area overlay, the saved masks followed by the
    current mask (when there is one), and the click points as green dots of
    radius 5. Returns None when there is no image to draw on.
    """
    canvas = apply_movable_mask_to_image(
        image=original_image,
        mask=movable_mask,
    )
    if canvas is None:
        return None

    pending = [] if current_mask is None else [current_mask]
    canvas = draw_all_sam_masks(canvas, previous_masks + pending)

    click_color = (0, 255, 0, 255)
    return draw_points_on_image(
        canvas,
        click_points,
        color=[click_color for _ in click_points],
        radius=5,
    )
def save_sam_masks(
current_mask: np.ndarray | None,
previous_masks: list[np.ndarray],
):
new_masks = previous_masks + ([current_mask] if current_mask is not None else [])
return None, new_masks, [], None
def select_target_sam_mask(
    masks_list: list[np.ndarray],
    evt: gr.SelectData,
):
    """Return the index of the first mask that covers the clicked pixel.

    ``evt.index`` is read as ``(x, y)`` and each mask is indexed ``[y, x]``.
    Returns -1 (after printing a notice) when no mask contains the point.
    """
    for mask_index, sam_mask in enumerate(masks_list):
        if sam_mask[evt.index[1], evt.index[0]]:
            return mask_index
    print(f"Mask not found for {evt.index = }")
    return -1
def draw_rotation_trajectory(
    image,
    points,
):
    """Draw a rotation trajectory.

    ``points[0]`` is drawn as a large green dot (radius 15). Every further
    point is drawn as a small dot (radius 5) whose colour shifts from red
    towards blue with its position in the list, and is joined to
    ``points[0]`` by a green line of width 3.
    """
    pivot = points[0]
    image = draw_points_on_image(
        image,
        [pivot],
        color="green",
        radius=15,
    )
    if len(points) <= 1:
        return image

    orbit_points = points[1:]
    orbit_count = len(orbit_points)
    orbit_colors = []
    for i in range(orbit_count):
        shift = int(float(i) / orbit_count * 255.0)
        orbit_colors.append((255 - shift, 64, shift, 255))

    image = draw_points_on_image(
        image,
        orbit_points,
        color=orbit_colors,
        radius=5,
    )
    for orbit_point in orbit_points:
        image = draw_lines_on_image(
            image,
            [pivot, orbit_point],
            color="green",
            width=3,
        )
    return image
def draw_translation_trajectory(
    image,
    points,
    control_points: list[tuple[int, int]] | None = None,
    is_draw_control_points: bool = True,
):
    """Draw a translation trajectory and, optionally, its control polygon.

    Args:
        image: Image to draw on.
        points: Sampled trajectory points.
        control_points: Control points the trajectory was built from. None is
            treated as "no control points".
        is_draw_control_points: Whether to draw the control points (green dots
            of radius 3 joined by green lines of width 2). They are only drawn
            when there are at least two of them.

    Returns:
        The image returned by the drawing helpers. A single-point trajectory is
        drawn as one dot and returned immediately, without control points.
    """
    # A None default replaces the shared mutable `[]` default.
    if control_points is None:
        control_points = []

    if len(points) == 1:
        return draw_points_on_image(
            image,
            points,
            color=[(255, 64, 0, 255)],
            radius=6,
        )

    if is_draw_control_points and (len(control_points) >= 2):
        control_color = (0, 255, 0, 255)
        image = draw_points_on_image(
            image,
            control_points,
            color=[control_color for _ in control_points],
            radius=3,
        )
        image = draw_lines_on_image(
            image,
            control_points,
            color=[control_color for _ in control_points],
            width=2,
        )

    # One colour per trajectory point, shifting from red at the first point to
    # blue at the last. The divisor is the number of segments.
    segment_count = len(points[1:])
    gradient = []
    for i in range(len(points)):
        shift = int(float(i) / segment_count * 255.0)
        gradient.append((255 - shift, 64, shift, 255))

    image = draw_lines_on_image(
        image,
        points,
        color=list(gradient),
        width=4,
    )
    image = draw_points_on_image(
        image,
        points,
        color=list(gradient),
        radius=6,
    )
    return image
def draw_all_trajectories(
    image,
    trajectory: MultiTrajectory,
    is_draw_control_points: bool = True,
):
    """Draw every stored trajectory of ``trajectory`` onto ``image``.

    Entries without an ``original_trajectory`` are skipped. Entries flagged
    ``"is_rotation"`` are drawn with ``draw_rotation_trajectory``; all others
    with ``draw_translation_trajectory``.
    """
    print(
        f"""
draw_all_trajectories:
"""
    )
    if trajectory.trajectories is None:
        return image

    for traj in trajectory.trajectories:
        spec = traj.original_trajectory
        if spec is None:
            continue
        if spec["is_rotation"]:
            image = draw_rotation_trajectory(image, spec["points"])
            continue
        image = draw_translation_trajectory(
            image,
            spec["points"],
            spec.get("control_points", []),
            is_draw_control_points=is_draw_control_points,
        )
    return image
def draw_trajectory_image(
    original_image,
    movable_mask,
    mask_index,
    masks_list: list[np.ndarray],
    trajectory: MultiTrajectory,
    is_draw_bbox: bool = True,
    is_draw_control_points: bool = True,
):
    """Compose the trajectory preview image.

    Layers are drawn in this order: movable-area overlay, all masks in
    ``masks_list``, the bounding box of the selected mask (only when
    ``mask_index`` is a valid index into ``masks_list`` and ``is_draw_bbox``
    is set), and finally all trajectories.
    """
    print(
        f"""
draw_trajectory_image:
{mask_index = }
"""
    )
    canvas = apply_movable_mask_to_image(
        mask=movable_mask,
        image=original_image,
    )
    canvas = draw_all_sam_masks(canvas, masks_list)

    selection_is_valid = (
        (mask_index is not None)
        and (mask_index >= 0)
        and (mask_index < len(masks_list))
    )
    if selection_is_valid and is_draw_bbox:
        selected_bbox = bbox_from_mask(masks_list[mask_index])
        canvas = draw_bbox_on_image(canvas, selected_bbox)

    return draw_all_trajectories(
        canvas,
        trajectory,
        is_draw_control_points=is_draw_control_points,
    )
def _apply_translation_click(current_trajectory, clicked_point, drag_animation_select, trajectory):
    """Add a clicked control point to a translation trajectory dict and resample its points."""
    current_trajectory["is_rotation"] = False
    control_points = current_trajectory.get("control_points", []) + [clicked_point]

    if drag_animation_select == "Drag":
        # A drag uses at most two control points; a third click starts a new
        # drag from the clicked point. Exactly 2 trajectory points are sampled.
        if len(control_points) > 2:
            control_points = [clicked_point]
        num_traj_points = 2
    elif drag_animation_select == "Animation":
        # No limit on control points; sample N = 1 + 3 * block_number points.
        num_traj_points = 1 + 3 * int(trajectory.block_number)
    else:
        raise ValueError(f"Invalid drag_animation_select: {drag_animation_select}")

    current_trajectory["control_points"] = control_points

    if len(control_points) < 2:
        # Not enough control points for a spline yet: repeat the single point.
        sampled_pts = [control_points[0]] * num_traj_points
    else:
        spline = build_clamped_bspline(control_points, degree=3)
        pts = equidistant_points_on_spline(spline, num_points=num_traj_points, grid=8000)
        sampled_pts = [(int(round(px)), int(round(py))) for px, py in pts]
    current_trajectory["points"] = sampled_pts


def _apply_rotation_click(current_trajectory, clicked_point, drag_animation_select, trajectory):
    """Add a clicked point to a rotation trajectory dict using the 3-point rule."""
    current_trajectory["is_rotation"] = True
    existing_points = current_trajectory.get("points")
    if existing_points is None:
        current_trajectory["points"] = [clicked_point]
        return

    pts = existing_points + [clicked_point]
    if len(pts) > 3:
        # A fourth click starts over from the new point.
        current_trajectory["points"] = [clicked_point]
    elif len(pts) < 3:
        current_trajectory["points"] = pts
    elif drag_animation_select == "Animation":
        # len(pts) == 3: pts[0] is the rotation centre; the other two points
        # are interpolated, first by block_number and then by 3.
        first = trajectory_interpolate(pts[1:], scale=int(trajectory.block_number))
        second = trajectory_interpolate(first, scale=3)
        current_trajectory["points"] = pts[0:1] + second
    else:
        # Drag: do not interpolate
        current_trajectory["points"] = pts


def update_trajectory(
    trajectory: MultiTrajectory,
    mask_index: int,
    drag_animation_select: str,
    translate_rotate_select: str,
    evt: gr.SelectData,
):
    """Record a click on the trajectory image for the currently selected mask.

    Args:
        trajectory: Current trajectory state. It is never modified; a deep copy
            is updated and returned.
        mask_index: Index of the selected mask. None or a negative value means
            "nothing selected" and leaves the trajectory unchanged.
        drag_animation_select: ``"Drag"`` or ``"Animation"``.
        translate_rotate_select: ``"Translation"`` or ``"Rotation"``.
        evt: Select event; ``evt.index`` is the clicked ``(x, y)``.

    Returns:
        The updated deep copy of ``trajectory``.

    Raises:
        ValueError: If ``translate_rotate_select`` is not a known mode, or if
            ``drag_animation_select`` is not a known mode while translating.
    """
    print(f"update_trajectory")
    # Work on a deep copy so Gradio sees a new object
    trajectory = copy.deepcopy(trajectory)

    # The index may be None when no mask has been selected yet; treat that like
    # the explicit "no match" value -1 instead of failing on the comparison.
    if mask_index is None or mask_index < 0:
        print(f"Invalid mask_index: {mask_index}")
        return trajectory
    # Number inputs may deliver a float; list indexing needs an int.
    mask_index = int(mask_index)

    x_center, y_center = evt.index  # evt.index is (x, y)
    clicked_point = (x_center, y_center)
    print(f"{clicked_point = }")

    # Ensure the trajectories list exists and is large enough
    if trajectory.trajectories is None:
        trajectory.trajectories = []
    while len(trajectory.trajectories) <= mask_index:
        trajectory.trajectories.append(Trajectory())

    existing_traj_obj = trajectory.trajectories[mask_index]
    if existing_traj_obj.original_trajectory is not None:
        current_trajectory = dict(existing_traj_obj.original_trajectory)
    else:
        current_trajectory = {}

    if translate_rotate_select == "Translation":
        _apply_translation_click(
            current_trajectory, clicked_point, drag_animation_select, trajectory
        )
    elif translate_rotate_select == "Rotation":
        _apply_rotation_click(
            current_trajectory, clicked_point, drag_animation_select, trajectory
        )
    else:
        raise ValueError("Invalid translation/rotation selection")

    # Store the updated trajectory dict on the Trajectory object.
    existing_traj_obj.set_original_trajectory(current_trajectory)
    return trajectory
def save_trajectory(
    save_dir: Path,
    saved_trajectory: MultiTrajectory,
    original_image: Image.Image,
    current_block_index: int,
    masks: list[np.ndarray],
):
    """Write the trajectory and a preview image of it into ``save_dir``.

    The directory is created when missing. Files are prefixed with
    ``block_<current_block_index>_<mode>``, where the mode is the trajectory's
    ``drag_or_animation_select`` or ``"Drag"`` when that is unset. The preview
    is rendered without bounding box and without control points and saved as
    ``<prefix>_trajectory.png``.
    """
    print("save_trajectory")
    print(f"{save_dir = }")
    print(f"{saved_trajectory = }")

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    drag_animation_select = saved_trajectory.drag_or_animation_select or "Drag"
    save_prefix = f"block_{current_block_index}_{drag_animation_select}"

    # Trajectory data is written by MultiTrajectory's own save method.
    saved_trajectory.save(
        save_dir=save_dir,
        prefix=save_prefix,
    )

    preview = draw_trajectory_image(
        original_image=original_image,
        movable_mask=saved_trajectory.movable_mask,
        mask_index=None,
        masks_list=masks,
        trajectory=saved_trajectory,
        is_draw_bbox=False,
        is_draw_control_points=False,
    )
    preview_path = save_dir / f"{save_prefix}_trajectory.png"
    preview.save(preview_path)
def clear_current_trajectory(
    idx: int,
    trajectory: MultiTrajectory,
):
    """Clear the trajectory at the given mask index, keeping its mask.

    Args:
        idx: Index of the trajectory to reset. Values that cannot be converted
            to an int, or that are out of range, leave the trajectory as is.
        trajectory: Current trajectory state. It is never modified; a deep copy
            is updated and returned.

    Returns:
        The deep copy, with the entry at ``idx`` replaced by a fresh
        ``Trajectory`` carrying the same mask when ``idx`` is valid.
    """
    trajectory = copy.deepcopy(trajectory)
    try:
        idx_int = int(idx)
    except (TypeError, ValueError, OverflowError):
        # No usable index (e.g. nothing selected yet): nothing to clear.
        return trajectory
    if not trajectory.trajectories:
        return trajectory
    if idx_int < 0 or idx_int >= len(trajectory.trajectories):
        return trajectory
    # Reset this trajectory (keep the mask)
    mask = trajectory.trajectories[idx_int].mask
    trajectory.trajectories[idx_int] = Trajectory(mask=mask)
    return trajectory
def clear_all_trajectories(
    trajectory: MultiTrajectory,
):
    """Clear all trajectories but keep the masks.

    Args:
        trajectory: Current trajectory state. It is never modified; a deep copy
            is updated and returned.

    Returns:
        The deep copy, with every entry replaced by a fresh ``Trajectory``
        carrying the same mask. A ``trajectories`` value of None is left as is.
    """
    trajectory = copy.deepcopy(trajectory)
    if trajectory.trajectories is not None:
        for i, existing in enumerate(trajectory.trajectories):
            trajectory.trajectories[i] = Trajectory(mask=existing.mask)
    return trajectory
def sync_trajectory_masks(saved_trajectory: MultiTrajectory, dilated_masks: list[np.ndarray]):
    """Make the trajectory list match the list of dilated masks.

    The trajectory list is truncated or padded with empty ``Trajectory``
    objects so it has one entry per mask, and each entry's ``mask`` is set to
    the corresponding dilated mask.

    Args:
        saved_trajectory: Current trajectory state. It is never modified; a
            deep copy is updated and returned.
        dilated_masks: Masks to attach, in order. None is treated like an
            empty list.

    Returns:
        The updated deep copy of ``saved_trajectory``.
    """
    saved_trajectory = copy.deepcopy(saved_trajectory)

    # Normalise both inputs up front so a missing mask list or a missing
    # trajectory list is handled the same way in every step below.
    masks = list(dilated_masks) if dilated_masks is not None else []
    existing = saved_trajectory.trajectories
    trajectories = list(existing) if existing is not None else []

    # Shrink: truncate
    del trajectories[len(masks):]
    # Expand: append new empty Trajectory objects
    for _ in range(len(masks) - len(trajectories)):
        trajectories.append(Trajectory())

    # Update each Trajectory.mask
    for traj, mask in zip(trajectories, masks):
        traj.mask = mask

    saved_trajectory.trajectories = trajectories
    return saved_trajectory
def add_listeners_to_trajectory(
    saved_trajectory: MultiTrajectory,
    prompt_box: gr.Textbox,
    trajectory_block_number_slider: gr.Slider,
    drag_animation_select: gr.Dropdown,
    movable_area_mask: gr.State,
    dilated_saved_sam_predicted_masks: gr.State,
):
    """Keep ``saved_trajectory`` in sync with the widgets that feed it.

    Registers one ``change`` listener per source (prompt, block number,
    drag/animation mode, movable mask, dilated masks). Each listener receives
    the trajectory plus the changed value and writes its result back to
    ``saved_trajectory``. The first four handlers set an attribute on the
    object they receive and return that same object; the mask handler is
    ``sync_trajectory_masks``.
    """

    def sync_trajectory_prompt(saved_trajectory: MultiTrajectory, prompt: str):
        # Copy the prompt text onto the trajectory.
        saved_trajectory.prompt = prompt
        return saved_trajectory

    def sync_trajectory_block_number(saved_trajectory: MultiTrajectory, block_number: int):
        # Copy the slider value onto the trajectory.
        saved_trajectory.block_number = block_number
        return saved_trajectory

    def sync_trajectory_drag_animation(
        saved_trajectory: MultiTrajectory, drag_animation_select: str
    ):
        # Copy the selected mode onto the trajectory.
        saved_trajectory.drag_or_animation_select = drag_animation_select
        return saved_trajectory

    def sync_trajectory_movable_mask(saved_trajectory: MultiTrajectory, movable_mask):
        # Copy the movable-area mask onto the trajectory.
        saved_trajectory.movable_mask = movable_mask
        return saved_trajectory

    # Shared by every listener below.
    listener_options = {
        "outputs": saved_trajectory,
        "trigger_mode": "always_last",
    }

    prompt_box.change(
        fn=sync_trajectory_prompt,
        inputs=[saved_trajectory, prompt_box],
        **listener_options,
    )
    trajectory_block_number_slider.change(
        fn=sync_trajectory_block_number,
        inputs=[saved_trajectory, trajectory_block_number_slider],
        **listener_options,
    )
    drag_animation_select.change(
        fn=sync_trajectory_drag_animation,
        inputs=[saved_trajectory, drag_animation_select],
        **listener_options,
    )
    movable_area_mask.change(
        fn=sync_trajectory_movable_mask,
        inputs=[saved_trajectory, movable_area_mask],
        **listener_options,
    )
    dilated_saved_sam_predicted_masks.change(
        fn=sync_trajectory_masks,
        inputs=[saved_trajectory, dilated_saved_sam_predicted_masks],
        **listener_options,
    )
def create_generate_video_ui(
    label_root: str | Path,
    text_dataset: Dataset,
    video_path: gr.State,
    stream_drag_inference: StreamDragInferenceWrapper,
    output_dir: str | Path,
    original_image: gr.State,
):
    """Build the prompt-selection and video-generation section of the UI.

    Creates the prompt index/prompt/save-directory row, the block sliders, the
    generate and refresh buttons and the video display, and wires them up:

    * changing the prompt index loads ``text_dataset[index]["prompts"]`` into
      the prompt box;
    * changing the prompt recomputes the save directory under ``label_root``;
    * the generate button calls ``generate_video`` and writes its results to
      ``video_path`` and the current-block slider;
    * a new ``video_path`` (or the refresh button) updates the video display;
    * a new ``video_path`` (or the "get last frame" button) stores the video's
      last frame in ``original_image``.

    Args:
        label_root: Root directory for save directories; a str or a Path.
        text_dataset: Dataset indexed by prompt index; items expose "prompts".
        video_path: State holding the path of the generated video.
        stream_drag_inference: Model wrapper passed to ``generate_video``.
        output_dir: Output directory passed to ``generate_video``.
        original_image: State receiving the last frame of the video.

    Returns:
        ``(prompt_index_number, save_dir_text_box, prompt_box,
        current_block_index_slider, generate_block_number_slider)``.
    """
    # The save-directory lambda joins with `/`, which a plain str does not
    # support; normalise once here.
    label_root = Path(label_root)

    with gr.Row():
        prompt_index_number = gr.Number(
            label="Step 1: Select Prompt Index Here",
            interactive=True,
            scale=1,
        )
        prompt_box = gr.Textbox(
            label="Prompt",
            interactive=True,
            scale=3,
        )
        save_dir_text_box = gr.Textbox(
            label="Save Directory",
            interactive=False,
            scale=1,
        )
    # The number input may deliver a float; indexing and the `:04d` format
    # below both need an int.
    prompt_index_number.change(
        fn=lambda prompt_index_number: text_dataset[int(prompt_index_number)]["prompts"],
        inputs=prompt_index_number,
        outputs=[
            prompt_box,
        ],
    )
    gr.on(
        triggers=[
            prompt_box.change,
        ],
        fn=lambda prompt_index_number, prompt: str(
            label_root / f"{int(prompt_index_number):04d}-{prompt[:50].replace(' ', '_')}"
        ),
        inputs=[prompt_index_number, prompt_box],
        outputs=save_dir_text_box,
        trigger_mode="always_last",
    )
    with gr.Row():
        current_block_index_slider = gr.Slider(
            label="Current Start Block Index",
            minimum=0,
            maximum=50,
            value=0,
            step=1,
        )
        generate_block_number_slider = gr.Slider(
            label="Step 2: Select Number of Blocks to Generate",
            minimum=1,
            maximum=50,
            value=2,
            step=1,
        )
    with gr.Row():
        begin_generate_button = gr.Button(
            value="Step 3: Click Here to Begin Generation",
        )
        refresh_video_display_button = gr.Button(value="Refresh Video Display")
    with gr.Row():
        video_display = gr.Video()
    begin_generate_button.click(
        fn=lambda pi, p, sbi, bn: generate_video(
            stream_inference_model=stream_drag_inference,
            prompt_index=pi,
            prompt=p,
            start_block_index=sbi,
            block_number=bn,
            output_dir=output_dir,
        ),
        inputs=[
            prompt_index_number,
            prompt_box,
            current_block_index_slider,
            generate_block_number_slider,
        ],
        outputs=[video_path, current_block_index_slider],
    )
    # Mirror the stored video path into the video player.
    gr.on(
        triggers=[
            refresh_video_display_button.click,
            video_path.change,
        ],
        fn=lambda video_path: video_path,
        inputs=video_path,
        outputs=video_display,
        trigger_mode="always_last",
    )
    with gr.Row():
        get_last_frame_button = gr.Button(
            value="Get Last Frame (Normally No Need to Click This, In Case the Last Frame Fails to Update due to Gradio Bug)",
        )
    # The last frame of the video becomes the image the later steps edit.
    gr.on(
        triggers=[
            video_path.change,
            get_last_frame_button.click,
        ],
        fn=get_video_last_frame,
        inputs=video_path,
        outputs=original_image,
    )
    return (
        prompt_index_number,
        save_dir_text_box,
        prompt_box,
        current_block_index_slider,
        generate_block_number_slider,
    )
def create_movable_area_ui(
    movable_area_mask: gr.State,
    original_image: gr.State,
):
    """Build the "Step 4" editor in which the user paints the editable area.

    The editor shows ``original_image``; layer 0 of what the user draws is
    converted with ``extract_layer_as_mask`` and stored in
    ``movable_area_mask``, both automatically on every editor change and on
    demand through a refresh button.
    """
    with gr.Row():
        movable_area_image_editor = gr.ImageEditor(
            label="Step 4: This is Last Frame of Video, Draw Editable Area Here. (Normally This Should Be Large and Cover all Possible Area Where the Object You Want to Move/Animate to)",
            type="pil",
            interactive=True,
            brush=gr.Brush(
                default_size=100,
                colors=[
                    "rgba(0, 0, 255, 0.5)",
                ],
                default_color="auto",
                color_mode="defaults",
            ),
        )
    # Every edit re-extracts the drawn layer as the movable-area mask.
    movable_area_image_editor.change(
        fn=extract_layer_as_mask,
        inputs=movable_area_image_editor,
        outputs=movable_area_mask,
        trigger_mode="always_last",
    )
    # A new base image is loaded into the editor unchanged.
    original_image.change(
        fn=lambda image: image,
        inputs=original_image,
        outputs=movable_area_image_editor,
        trigger_mode="always_last",
    )
    with gr.Row():
        refresh_movable_area_button = gr.Button(
            value="Refresh Movable Area (Normally No Need to Click This, In Case the Mask Fails to Update due to Gradio Bug)"
        )
    # Manual re-extraction, same handler as the automatic one above.
    refresh_movable_area_button.click(
        fn=extract_layer_as_mask,
        inputs=movable_area_image_editor,
        outputs=movable_area_mask,
        trigger_mode="always_last",
    )
def create_target_area_ui(
    target_area_mask: gr.State,
    original_image: gr.State,
    movable_area_mask: gr.State,
):
    """Build the "Step 5" editor in which the user paints the target area.

    The editor shows ``original_image`` with the movable-area overlay applied;
    layer 0 of what the user draws is converted with ``extract_layer_as_mask``
    and stored in ``target_area_mask``, both automatically on every editor
    change and on demand through a refresh button.
    """
    with gr.Row():
        target_area_image_editor = gr.ImageEditor(
            label="Step 5: Draw Target Area on the Object You Want to Move/Animate (Normally This Should Be a Subset of Editable Area) (Normally This Mask should be Bigger than the Desired Object)",
            type="pil",
            interactive=True,
            brush=gr.Brush(
                default_size=50,
                colors=[
                    "rgba(255, 0, 0, 0.5)",
                ],
                default_color="auto",
                color_mode="defaults",
            ),
        )
    # Every edit re-extracts the drawn layer as the target-area mask.
    target_area_image_editor.change(
        fn=extract_layer_as_mask,
        inputs=target_area_image_editor,
        outputs=target_area_mask,
        trigger_mode="always_last",
    )
    # Reload the editor background whenever the base image or the movable
    # mask changes.
    gr.on(
        triggers=[
            original_image.change,
            movable_area_mask.change,
        ],
        fn=apply_movable_mask_to_image,
        inputs=[
            movable_area_mask,
            original_image,
        ],
        outputs=target_area_image_editor,
        trigger_mode="always_last",
    )
    with gr.Row():
        refresh_target_area_button = gr.Button(
            value="Refresh Target Area (Normally No Need to Click This, In Case the Mask Fails to Update due to Gradio Bug)"
        )
    # Manual re-extraction, same handler as the automatic one above.
    refresh_target_area_button.click(
        fn=extract_layer_as_mask,
        inputs=target_area_image_editor,
        outputs=target_area_mask,
        trigger_mode="always_last",
    )
def create_sam_segmentation_ui(
    original_image: gr.State,
    movable_area_mask: gr.State,
    target_area_mask: gr.State,
    sam_predictor: SamPredictor,
    sam_click_points: gr.State,
    sam_saved_logits: gr.State,
    current_sam_predicted_mask: gr.State,
    saved_sam_predicted_masks: gr.State,
    dilated_current_sam_predicted_mask: gr.State,
    dilated_saved_sam_predicted_masks: gr.State,
):
    """Build the "Step 6" click-to-segment section.

    Shows the base image with the movable and target overlays. A click on it
    runs ``sam_predict_segmentation_wrapper`` with the target mask as the
    restriction and updates the current mask, the click points and the saved
    logits. The current mask and the saved masks are each dilated by the
    slider amount into their ``dilated_*`` counterparts whenever they, or the
    slider, change.
    """
    with gr.Row():
        refresh_sam_segment_click_image_button = gr.Button(
            value="Refresh Target Area Mask Display (Normally No Need to Click This, In Case the Mask Fails to Update due to Gradio Bug)"
        )
    with gr.Row():
        sam_segment_click_image = gr.Image(
            label="Step 6: Click to Perform SAM Segment on Target Area, Segment the Object You Want to Move/Animate. The SAM Mask is Restricted within the Target Area Mask",
            type="pil",
            interactive=True,
        )
    # Redraw the clickable image: movable overlay first, target overlay on top.
    gr.on(
        triggers=[
            original_image.change,
            movable_area_mask.change,
            target_area_mask.change,
            refresh_sam_segment_click_image_button.click,
        ],
        fn=lambda movable_mask, target_mask, image: apply_target_mask_to_image(
            target_mask,
            apply_movable_mask_to_image(
                movable_mask,
                image,
            ),
        ),
        inputs=[
            movable_area_mask,
            target_area_mask,
            original_image,
        ],
        outputs=sam_segment_click_image,
        trigger_mode="always_last",
    )
    with gr.Row():
        dilate_mask_slider = gr.Slider(
            label="Dilate Mask Pixel",
            minimum=0,
            maximum=50,
            value=15,
            step=1,
        )
        bypass_sam_model_check_box = gr.Checkbox(
            label="Bypass SAM Model",
            value=False,
        )

    # Adapter that binds the predictor, which is not a UI input, to the handler.
    def sam_predict_segmentation_wrapper_wrapper(
        oi,
        rm,
        pcp,
        psl,
        bs,
        evt: gr.SelectData,
    ):
        return sam_predict_segmentation_wrapper(
            sam_predictor=sam_predictor,
            original_image=oi,
            restriction_mask=rm,
            previous_click_points=pcp,
            previous_sam_logits=psl,
            bypass_sam_model=bs,
            evt=evt,
        )

    sam_segment_click_image.select(
        fn=sam_predict_segmentation_wrapper_wrapper,
        inputs=[
            original_image,
            target_area_mask,
            sam_click_points,
            sam_saved_logits,
            bypass_sam_model_check_box,
        ],
        outputs=[
            current_sam_predicted_mask,
            sam_click_points,
            sam_saved_logits,
        ],
        trigger_mode="always_last",
    )
    # Keep the dilated copy of the current mask up to date.
    gr.on(
        triggers=[
            current_sam_predicted_mask.change,
            dilate_mask_slider.change,
        ],
        fn=dilate_mask,
        inputs=[
            current_sam_predicted_mask,
            dilate_mask_slider,
        ],
        outputs=dilated_current_sam_predicted_mask,
        trigger_mode="always_last",
    )
    # Keep the dilated copies of the saved masks up to date.
    gr.on(
        triggers=[
            saved_sam_predicted_masks.change,
            dilate_mask_slider.change,
        ],
        fn=dilate_masks,
        inputs=[
            saved_sam_predicted_masks,
            dilate_mask_slider,
        ],
        outputs=dilated_saved_sam_predicted_masks,
        trigger_mode="always_last",
    )
def create_sam_mask_management_ui(
    original_image: gr.State,
    movable_area_mask: gr.State,
    dilated_current_sam_predicted_mask: gr.State,
    dilated_saved_sam_predicted_masks: gr.State,
    sam_click_points: gr.State,
    current_sam_predicted_mask: gr.State,
    saved_sam_predicted_masks: gr.State,
    sam_saved_logits: gr.State,
):
    """Build the "Step 7/8" section for saving, cancelling and selecting masks.

    Provides buttons to save the current mask, cancel it, or delete all masks,
    an image that displays the dilated masks and click points, and a read-only
    number showing which saved mask the user last clicked on.

    Returns:
        The ``gr.Number`` holding the currently selected mask index.
    """
    with gr.Row():
        save_sam_masks_button = gr.Button(
            value="Step 7: Save the Current SAM Mask",
        )
        cancel_sam_mask_button = gr.Button(value="Cancel Current SAM Mask")
        delete_sam_mask_button = gr.Button(value="Delete All SAM Masks")
    save_sam_masks_button.click(
        fn=save_sam_masks,
        inputs=[
            current_sam_predicted_mask,
            saved_sam_predicted_masks,
        ],
        outputs=[
            current_sam_predicted_mask,
            saved_sam_predicted_masks,
            sam_click_points,
            sam_saved_logits,
        ],
        trigger_mode="always_last",
    )
    with gr.Row():
        sam_segment_display_image = gr.Image(
            label="Step 8: Display the SAM Segmentation, Click to Select Target Object to Create Trajectory",
            type="pil",
            interactive=True,
        )
    # Redraw the preview whenever any of its inputs changes.
    gr.on(
        triggers=[
            original_image.change,
            movable_area_mask.change,
            dilated_current_sam_predicted_mask.change,
            dilated_saved_sam_predicted_masks.change,
            sam_click_points.change,
        ],
        fn=draw_sam_mask_wrapper,
        inputs=[
            original_image,
            movable_area_mask,
            dilated_current_sam_predicted_mask,
            dilated_saved_sam_predicted_masks,
            sam_click_points,
        ],
        outputs=sam_segment_display_image,
        trigger_mode="always_last",
    )
    # Cancel: drop the current mask, its clicks and its logits; saved masks stay.
    cancel_sam_mask_button.click(
        fn=lambda: (None, [], None),
        outputs=[
            current_sam_predicted_mask,
            sam_click_points,
            sam_saved_logits,
        ],
        trigger_mode="always_last",
    )
    # Delete all: additionally empty the list of saved masks.
    gr.on(
        triggers=[
            # target_area_mask.change,
            delete_sam_mask_button.click,
        ],
        fn=lambda: (None, [], [], None),
        outputs=[
            current_sam_predicted_mask,
            saved_sam_predicted_masks,
            sam_click_points,
            sam_saved_logits,
        ],
        trigger_mode="always_last",
    )
    with gr.Row():
        current_selected_mask_index_number = gr.Number(
            label="Current Selected Mask Index",
            interactive=False,
        )
    # The click is matched against the saved (undilated) masks, while the image
    # above is drawn from the dilated ones.
    sam_segment_display_image.select(
        fn=select_target_sam_mask,
        inputs=[
            saved_sam_predicted_masks,
        ],
        outputs=[
            current_selected_mask_index_number,
        ],
        trigger_mode="always_last",
    )
    return current_selected_mask_index_number
def create_trajectory_display_ui(
    original_image: gr.State,
    movable_area_mask: gr.State,
    dilated_saved_sam_predicted_masks: gr.State,
    saved_trajectory: gr.State,
    # NOTE(review): annotated as gr.State here, but create_trajectory_management_ui
    # annotates the same-named parameter as gr.Number — confirm which is passed.
    current_selected_mask_index_number: gr.State,
):
    """Build the "Step 9-12" section in which trajectories are drawn.

    Creates the block-number slider, the Drag/Animation and
    Translation/Rotation dropdowns and the trajectory image. The image is
    redrawn with ``draw_trajectory_image`` whenever one of its inputs changes,
    and a click on it updates ``saved_trajectory`` through
    ``update_trajectory``.

    Returns:
        ``(drag_animation_select, trajectory_block_number_slider)``.
    """
    with gr.Row():
        trajectory_block_number_slider = gr.Slider(
            label="Step 9: Select Number of Trajectory Blocks (For Animation Only, More Blocks Means Longer Animation, For Drag, This Should be 1)",
            minimum=1,
            maximum=10,
            value=1,
            step=1,
        )
    with gr.Row():
        drag_animation_select = gr.Dropdown(
            choices=["Drag", "Animation"],
            label="Step 10: Select Drag or Animation",
        )
        translate_rotate_select = gr.Dropdown(
            choices=["Translation", "Rotation"],
            label="Step 11: Select Translation or Rotation",
        )
    with gr.Row():
        trajectory_display_image = gr.Image(
            label="Step 12: Click on the Object in the Image to Create Trajectory. The Translation Trajectory is Controlled by Bspline Interpolation. The Rotation Trajectory is Controlled by 3 Points",
            type="pil",
            interactive=False,
        )
    # Redraw the trajectory preview whenever any of its inputs changes.
    gr.on(
        triggers=[
            original_image.change,
            movable_area_mask.change,
            current_selected_mask_index_number.change,
            dilated_saved_sam_predicted_masks.change,
            saved_trajectory.change,
        ],
        fn=draw_trajectory_image,
        inputs=[
            original_image,
            movable_area_mask,
            current_selected_mask_index_number,
            dilated_saved_sam_predicted_masks,
            saved_trajectory,
        ],
        outputs=trajectory_display_image,
        trigger_mode="always_last",
    )
    # A click adds a point to the trajectory of the selected mask.
    trajectory_display_image.select(
        fn=update_trajectory,
        inputs=[
            saved_trajectory,
            current_selected_mask_index_number,
            drag_animation_select,
            translate_rotate_select,
        ],
        outputs=saved_trajectory,
    )
    return drag_animation_select, trajectory_block_number_slider
def create_trajectory_management_ui(
    save_dir_text_box: gr.Textbox,
    original_image: gr.State,
    current_block_index_slider: gr.Slider,
    saved_trajectory: gr.State,
    dilated_saved_sam_predicted_masks: gr.State,
    current_selected_mask_index_number: gr.Number,
):
    """Build the "Step 13" row of buttons for saving and deleting trajectories.

    * Save: calls ``save_trajectory`` with the save directory, the trajectory,
      the base image, the current block index and the dilated masks.
    * Delete current: calls ``clear_current_trajectory`` for the selected mask
      index.
    * Delete all: calls ``clear_all_trajectories``.
    """
    with gr.Row():
        save_trajectory_button = gr.Button(
            value="Step 13: Save Trajectory",
        )
        delete_current_trajectory_button = gr.Button(value="Delete Current Trajectory")
        delete_all_trajectory_button = gr.Button(value="Delete All Trajectories")
    # Saving has no outputs: it only writes files.
    save_trajectory_button.click(
        fn=save_trajectory,
        inputs=[
            save_dir_text_box,
            saved_trajectory,
            original_image,
            current_block_index_slider,
            dilated_saved_sam_predicted_masks,
        ],
    )
    delete_current_trajectory_button.click(
        fn=clear_current_trajectory,
        inputs=[current_selected_mask_index_number, saved_trajectory],
        outputs=[saved_trajectory],
    )
    delete_all_trajectory_button.click(
        fn=clear_all_trajectories,
        inputs=[saved_trajectory],
        outputs=[saved_trajectory],
    )
def create_ui(
    text_dataset: Dataset,
    label_root: str | Path,
    output_dir: str | Path,
    sam_predictor: SamPredictor,
    stream_drag_inference: StreamDragInferenceWrapper,
):
    """Assemble the whole Gradio application and return the ``gr.Blocks`` object.

    The function creates the ``gr.State`` holders shared between sections,
    calls each ``create_*_ui`` builder in the order of the numbered steps shown
    to the user, and finally adds the "begin optimize" and "clear all" buttons.

    Args:
        text_dataset: Prompt dataset forwarded to ``create_generate_video_ui``.
        label_root: Root directory forwarded to ``create_generate_video_ui``.
        output_dir: Output directory forwarded to ``create_generate_video_ui``
            and to ``optimize_video``.
        sam_predictor: Predictor forwarded to ``create_sam_segmentation_ui``.
        stream_drag_inference: Inference wrapper used for generation and
            optimization; its ``reset()`` is called by the clear-all button.

    Returns:
        The constructed ``gr.Blocks`` demo (not yet launched).
    """
    with gr.Blocks() as demo:
        # State holders shared by the section builders below.
        video_path = gr.State(value=None)
        original_image = gr.State(value=None)
        movable_area_mask = gr.State(value=None)
        target_area_mask = gr.State(value=None)
        sam_click_points = gr.State(value=[])
        sam_saved_logits = gr.State(value=None)
        saved_sam_predicted_masks = gr.State(value=[])
        current_sam_predicted_mask = gr.State(value=None)
        dilated_current_sam_predicted_mask = gr.State(value=None)
        dilated_saved_sam_predicted_masks = gr.State(value=[])
        saved_trajectory = gr.State(value=MultiTrajectory())

        # Video generation section; it hands back the components that later
        # sections need to read from or write to.
        generate_video_components = create_generate_video_ui(
            label_root=label_root,
            text_dataset=text_dataset,
            video_path=video_path,
            stream_drag_inference=stream_drag_inference,
            output_dir=output_dir,
            original_image=original_image,
        )
        (
            prompt_index_number,
            save_dir_text_box,
            prompt_box,
            current_block_index_slider,
            _generate_block_number_slider,  # returned by the builder, unused here
        ) = generate_video_components

        # Movable / target area sections.
        create_movable_area_ui(movable_area_mask, original_image)
        create_target_area_ui(target_area_mask, original_image, movable_area_mask)

        # The segmentation and mask-management sections take the same eight
        # state holders; segmentation additionally needs the target mask and
        # the predictor.
        shared_sam_state = dict(
            original_image=original_image,
            movable_area_mask=movable_area_mask,
            sam_click_points=sam_click_points,
            sam_saved_logits=sam_saved_logits,
            current_sam_predicted_mask=current_sam_predicted_mask,
            saved_sam_predicted_masks=saved_sam_predicted_masks,
            dilated_current_sam_predicted_mask=dilated_current_sam_predicted_mask,
            dilated_saved_sam_predicted_masks=dilated_saved_sam_predicted_masks,
        )
        create_sam_segmentation_ui(
            target_area_mask=target_area_mask,
            sam_predictor=sam_predictor,
            **shared_sam_state,
        )
        current_selected_mask_index_number = create_sam_mask_management_ui(
            **shared_sam_state
        )

        # Trajectory drawing, management and listeners.
        trajectory_display_components = create_trajectory_display_ui(
            original_image=original_image,
            movable_area_mask=movable_area_mask,
            dilated_saved_sam_predicted_masks=dilated_saved_sam_predicted_masks,
            saved_trajectory=saved_trajectory,
            current_selected_mask_index_number=current_selected_mask_index_number,
        )
        drag_animation_select, trajectory_block_number_slider = (
            trajectory_display_components
        )
        create_trajectory_management_ui(
            save_dir_text_box=save_dir_text_box,
            original_image=original_image,
            current_block_index_slider=current_block_index_slider,
            saved_trajectory=saved_trajectory,
            dilated_saved_sam_predicted_masks=dilated_saved_sam_predicted_masks,
            current_selected_mask_index_number=current_selected_mask_index_number,
        )
        add_listeners_to_trajectory(
            saved_trajectory=saved_trajectory,
            prompt_box=prompt_box,
            trajectory_block_number_slider=trajectory_block_number_slider,
            drag_animation_select=drag_animation_select,
            movable_area_mask=movable_area_mask,
            dilated_saved_sam_predicted_masks=dilated_saved_sam_predicted_masks,
        )

        # Optimization trigger: the closure binds the model and output
        # directory, the listener supplies the three UI values.
        with gr.Row():
            begin_optimize_button = gr.Button(
                value="Step 14: Click Here to Begin Optimize, Wait for a Moment and the Dragged/Animated Video will be Displayed Above",
            )
        begin_optimize_button.click(
            fn=lambda prompt_index, start_block_index, multi_trajectory: optimize_video(
                stream_drag_inference_model=stream_drag_inference,
                output_dir=output_dir,
                prompt_index=prompt_index,
                start_block_index=start_block_index,
                multi_trajectory=multi_trajectory,
            ),
            inputs=[
                prompt_index_number,
                current_block_index_slider,
                saved_trajectory,
            ],
            outputs=[
                video_path,
                current_block_index_slider,
            ],
        )

        # Full reset between videos.
        with gr.Row():
            clear_all_button = gr.Button(
                value="Step 15: Remember to Click Here to Clear All Before Generation/Editing on Next Video, Otherwise the Previous KV Cache will Affect the Generation/Editing of Next Video",
            )

        def clear_all():
            """Reset the inference wrapper and return fresh values for the UI state.

            The returned tuple is positional: entry *i* is written to entry *i*
            of the ``outputs`` list passed to ``clear_all_button.click`` below.
            """
            stream_drag_inference.reset()
            # NOTE(review): the two dilated-mask states are not reset here;
            # confirm they are re-derived from the masks cleared below.
            fresh_values = (
                0,  # current_block_index_slider
                None,  # video_path
                None,  # original_image
                None,  # movable_area_mask
                None,  # target_area_mask
                [],  # sam_click_points
                None,  # sam_saved_logits
                [],  # saved_sam_predicted_masks
                None,  # current_sam_predicted_mask
                MultiTrajectory(),  # saved_trajectory
            )
            return fresh_values

        clear_all_button.click(
            fn=clear_all,
            outputs=[
                current_block_index_slider,
                video_path,
                original_image,
                movable_area_mask,
                target_area_mask,
                sam_click_points,
                sam_saved_logits,
                saved_sam_predicted_masks,
                current_sam_predicted_mask,
                saved_trajectory,
            ],
        )
    return demo
def download_required_files():
    """Fetch the model files the app needs, skipping anything already present.

    Three artifacts are checked, relative to the current working directory:

    1. ``checkpoints/self_forcing_dmd.pt`` from the ``gdhe17/Self-Forcing``
       Hugging Face repository (only ``checkpoints/*`` is downloaded).
    2. ``wan_models/Wan2.1-T2V-1.3B`` from ``Wan-AI/Wan2.1-T2V-1.3B``.
    3. ``sam_vit_h_4b8939.pth``, the SAM ViT-H checkpoint.

    Raises:
        Whatever ``snapshot_download`` or ``urllib.request.urlretrieve`` raise
        on failure; a failed SAM download leaves no file at the final path.
    """
    import urllib.request

    def fetch_snapshot(**snapshot_kwargs):
        # Imported lazily so that a start with everything already on disk
        # does not need to import huggingface_hub at all.
        from huggingface_hub import snapshot_download

        return snapshot_download(**snapshot_kwargs)

    # 1. Download "checkpoints" directory from gdhe17/Self-Forcing.
    # Test for the checkpoint file main() loads rather than for the directory:
    # an interrupted download can leave the directory behind without the file.
    self_forcing_checkpoint = os.path.join("checkpoints", "self_forcing_dmd.pt")
    if not os.path.exists(self_forcing_checkpoint):
        print("Downloading checkpoints from gdhe17/Self-Forcing...")
        fetch_snapshot(
            repo_id="gdhe17/Self-Forcing",
            allow_patterns=["checkpoints/*"],
            local_dir=".",
        )

    # 2. Download Wan-AI/Wan2.1-T2V-1.3B and place in wan_models/Wan2.1-T2V-1.3B
    wan_model_dir = os.path.join("wan_models", "Wan2.1-T2V-1.3B")
    if not os.path.exists(wan_model_dir):
        print("Downloading Wan-AI/Wan2.1-T2V-1.3B...")
        os.makedirs("wan_models", exist_ok=True)
        fetch_snapshot(
            repo_id="Wan-AI/Wan2.1-T2V-1.3B",
            local_dir=wan_model_dir,
        )

    # 3. Download SAM ViT-H checkpoint
    sam_checkpoint_path = "sam_vit_h_4b8939.pth"
    if not os.path.exists(sam_checkpoint_path):
        print("Downloading SAM ViT-H checkpoint...")
        # Stream into a side file and rename only after the transfer finished.
        # Writing straight to the final path would let an interrupted download
        # leave a truncated file that the existence check above then accepts
        # as complete on the next start.
        partial_path = sam_checkpoint_path + ".part"
        try:
            urllib.request.urlretrieve(
                "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
                partial_path,
            )
            os.replace(partial_path, sam_checkpoint_path)
        except BaseException:
            # Also covers Ctrl-C: drop the incomplete file, then re-raise.
            try:
                os.remove(partial_path)
            except OSError:
                pass
            raise
def main():
    """Download assets, load the models, build the UI and launch the server.

    Steps, in order: fetch required files, load the SAM ViT-H model onto CUDA
    and wrap it in a ``SamPredictor``, load the prompt dataset, compose the
    Hydra config ``self_forcing_dmd_vsink_stream_drag`` from ``configs``,
    construct the ``StreamDragInferenceWrapper``, create the per-run label and
    output directories, and launch the Gradio demo on ``0.0.0.0``.
    """
    download_required_files()

    # Segmentation model.
    sam_model = sam_model_registry["vit_h"](checkpoint="./sam_vit_h_4b8939.pth")
    sam_model.to(device="cuda")
    sam_predictor = SamPredictor(sam_model)

    seed = 42
    text_dataset = TextDataset(prompt_path="prompts/MovieGenVideoBench_extended.txt")

    # Clear any existing global Hydra state before calling initialize().
    hydra_state = GlobalHydra.instance()
    if hydra_state.is_initialized():
        hydra_state.clear()

    stream_config_name = "self_forcing_dmd_vsink_stream_drag"
    with initialize(version_base=None, config_path="configs"):
        stream_config = compose(config_name=stream_config_name)
    print(f"{stream_config = }")

    stream_drag_inference = StreamDragInferenceWrapper(
        stream_model_config=stream_config,
        checkpoint_path="./checkpoints/self_forcing_dmd.pt",
        total_generate_block_number=36,
        use_ema=True,
        seed=seed,
    )

    # Labels and outputs both go into a sub-directory named after the
    # config and seed.
    run_tag = f"{stream_config_name}-seed{seed}"
    label_save_dir = Path("./saved_labels") / run_tag
    label_save_dir.mkdir(parents=True, exist_ok=True)
    output_save_dir = Path("outputs-editing") / run_tag
    output_save_dir.mkdir(parents=True, exist_ok=True)

    demo = create_ui(
        text_dataset=text_dataset,
        label_root=label_save_dir,
        output_dir=output_save_dir,
        sam_predictor=sam_predictor,
        stream_drag_inference=stream_drag_inference,
    )
    demo.launch(server_name="0.0.0.0")
# Script entry point: start the app only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()