KiwiEdit / app.py
linyq's picture
Update app.py
a7fb4a8 verified
import spaces
import os
import tempfile
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union
import gradio as gr
import numpy as np
import torch
from PIL import Image
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video
from torchvision.io import read_video
os.environ["TOKENIZERS_PARALLELISM"] = "false"
DEFAULT_MODEL = "linyq/kiwi-edit-5b-instruct-reference-diffusers"
MODEL_CHOICES = [
"linyq/kiwi-edit-5b-instruct-only-diffusers",
"linyq/kiwi-edit-5b-reference-only-diffusers",
"linyq/kiwi-edit-5b-instruct-reference-diffusers",
]
_PIPELINE_CACHE = {}
APP_ROOT = Path(__file__).resolve().parent
print(f"App root: {APP_ROOT}")
def _device_and_dtype() -> Tuple[str, torch.dtype]:
if torch.cuda.is_available():
# float16 is broadly supported on consumer GPUs in Spaces.
return "cuda:0", torch.bfloat16
return "cpu", torch.float32
def _safe_int(v: Union[int, float]) -> int:
return int(v)
def load_video_frames(video_path: str, max_frames: int, max_pixels: int) -> List[Image.Image]:
vframes, _, _ = read_video(video_path, pts_unit="sec")
frames: List[Image.Image] = []
for i in range(min(len(vframes), max_frames)):
img = Image.fromarray(vframes[i].numpy())
w, h = img.size
scale = min(1.0, (max_pixels / (w * h)) ** 0.5)
new_w = max((int(w * scale) // 32) * 32, 32)
new_h = max((int(h * scale) // 32) * 32, 32)
img = img.resize((new_w, new_h), Image.LANCZOS)
frames.append(img)
return frames
def load_pipeline(model_id: str):
device, dtype = _device_and_dtype()
cache_key = (model_id, device, str(dtype))
if cache_key in _PIPELINE_CACHE:
return _PIPELINE_CACHE[cache_key], device
pipe = DiffusionPipeline.from_pretrained(model_id, trust_remote_code=True)
pipe.to(device, dtype=dtype)
_PIPELINE_CACHE[cache_key] = pipe
return pipe, device
def make_side_by_side(video1: Sequence[Image.Image], video2: Sequence, bg=(0, 0, 0)) -> List[np.ndarray]:
cat_video = []
for img1, img2 in zip(video1, video2):
w1, h1 = img1.size
w2, h2 = img2.size
H = max(h1, h2)
W = w1 + w2
canvas = Image.new("RGB", (W, H), bg)
canvas.paste(img1, (0, 0))
canvas.paste(img2, (w1, 0))
cat_video.append(canvas)
return cat_video
def _asset_path(*parts: str) -> Optional[str]:
asset = APP_ROOT.joinpath(*parts)
return str(asset) if asset.exists() else None
TASK_ENHANCED_PROMPT = {
"global_style": [
"Ensure seamless temporal consistency across all frames of the video.",
"Retain the original motion, character actions, and camera movements throughout the sequence.",
"Preserve the video's narrative flow and structural coherence to keep the original intent intact.",
"Maintain strict frame-by-frame consistency to ensure visual harmony.",
"Eliminate flickering or abrupt style changes between consecutive frames.",
"The model must preserve the dynamic interplay of light and shadow from the source footage.",
],
"local_change": [
"Ensure the object maintains the exact same position and pose within the video scene.",
"The modified element must stay aligned with the subject's original physical orientation.",
"Keep the same pose and position for the subject throughout the entire video.",
"The new attire or object must fit the subject's pose and spatial coordinates perfectly.",
"Maintain the original object's dimensions and perspective during the replacement process.",
],
"background_change": [
"The subject in the foreground must remain perfectly still throughout the video.",
"The person and any foreground objects should stay static and unchanged.",
"Ensure the foreground subject remains perfectly still while the background transforms.",
"Include subtle movements of environmental elements, such as shifting sunlight and shadows.",
"Transform the background into a dynamic scene without altering the narrative flow of the foreground.",
],
"local_remove": [
"The background must be reconstructed with temporal consistency to match the original context.",
"All other video content must remain entirely unchanged after the object is removed.",
"Perform the removal using temporally consistent background inpainting techniques.",
"Ensure the background is inpainted smoothly across all frames to avoid visual artifacts.",
"The removal of the subject must leave the surrounding environment structurally intact.",
],
"local_add": [
"The added object must be perfectly tracked to the specified surface as the camera moves.",
"Maintain consistent shadows and lighting for the added object across all frames.",
"All other parts of the video must remain unchanged after the new object is overlaid.",
"Reflections and shadows must dynamically adapt to the changing light in the environment.",
"The added subject should exhibit subtle natural movements to enhance realism.",
"Ensure the object remains fixed relative to its anchor point as the camera pans or zooms.",
],
}
TASK_LABELS = {
"global_style": "Global Style",
"background_change": "Background Change",
"local_remove": "Local Remove",
"local_add": "Local Add",
"local_change": "Local Change / Replace",
}
PRESET_EXAMPLES = {
"Style": {
"preview_gif": _asset_path("examples", "0007_global_style_Apply_the_dynamic_ae_concat.gif"),
"source_video": _asset_path("examples", "0007_global_style_Apply_the_dynamic_ae_raw.mp4"),
"prompt": "Apply the dynamic aesthetic of abstract art to this video, maintaining strict temporal consistency across all frames. The result should exude the spontaneity and emotional depth of abstract expressionism, with synchronized brushstroke patterns and color transitions that align with the original narrative flow. Preserve all original motion, character movements, and camera dynamics, ensuring no frame deviates from the video’s intended rhythm.",
"ref_image": None,
"task_type": "global_style",
},
"Replace": {
"preview_gif": _asset_path("examples", "0083_local_change_Replace_the_sofa_wit_70_concat.gif"),
"source_video": _asset_path("examples", "0083_local_change_Replace_the_sofa_wit2_raw.mp4"),
"prompt": "Replace the sofa with a classic brown leather sofa with visible stitching, ensuring it maintains the sofa's exact position and pose within the scene.",
"ref_image": None,
"task_type": "local_change",
},
"Add": {
"preview_gif": _asset_path("examples", "0095_local_change_Add_a_classic_brown_concat.gif"),
"source_video": _asset_path("examples", "0095_local_change_Add_a_classic_brown_raw.mp4"),
"prompt": "Add a classic brown fedora hat to the boy's head, maintaining the same position and pose within the video scene.",
"ref_image": None,
"task_type": "local_add",
},
"Remove": {
"preview_gif": _asset_path("examples", "0191_local_remove_Remove_the_person_we_concat.gif"),
"source_video": _asset_path("examples", "0191_local_remove_Remove_the_person_we_raw.mp4"),
"prompt": "Remove the person wearing a light blue shirt and dark pants from the entire video sequence. The background must be reconstructed with temporal consistency, and all other video content must remain unchanged.",
"ref_image": None,
"task_type": "local_remove",
},
"Background Replace": {
"preview_gif": _asset_path("examples", "0145_background_change_Replace_the_backgrou_concat.gif"),
"source_video": _asset_path("examples", "0145_background_change_Replace_the_backgrou_raw.mp4"),
"prompt": "Replace the background with a lively urban rooftop garden scene during winter. Include subtle movement of leaves in a gentle breeze, distant city traffic sounds implied by soft light flickers, and shifting sunlight casting dynamic shadows. The deer remains perfectly still.",
"ref_image": None,
"task_type": "background_change",
},
"Subject Reference": {
"preview_gif": _asset_path("examples", "0125_background_change_Replace_the_backgrou_concat.gif"),
"source_video": _asset_path("examples", "0125_background_change_Replace_the_backgrou_raw.mp4"),
"prompt": "Add a pair of iconic red heart-shaped sunglasses to the girl's face. It must be tracked and integrated realistically and consistently across all frames, without altering any other video content.",
"ref_image": _asset_path("examples", "41_shape_heart_sunglasses_1328_1328_1.png"),
"task_type": "local_add",
},
"Background Reference": {
"preview_gif": _asset_path("examples", "1_Replace_th_gym-ball_concat.gif"),
"source_video": _asset_path("examples", "1_Replace_th_gym-ball_raw.mp4"),
"prompt": "Replace the background with a Chinese ink painting, featuring a large golden mountain peak rising above swirling clouds, ensuring it appears in the same position and pose within the video scene.",
"ref_image": _asset_path("examples", "0_mountain_ink_1664_928_0.png"),
"task_type": "background_change",
},
}
def _format_task_tips(task_type: str) -> str:
tips = TASK_ENHANCED_PROMPT.get(task_type, [])
task_label = TASK_LABELS.get(task_type, task_type)
if not tips:
return f"**Prompt Tips · {task_label}**\n- No tips available."
return f"**Prompt Tips · {task_label}**\n" + "\n".join(f"- {tip}" for tip in tips)
def _preset_gallery_items():
items = []
for name, cfg in PRESET_EXAMPLES.items():
preview = cfg.get("preview_gif")
if preview:
items.append((preview, name))
return items
def _empty_tips() -> str:
return (
"**Prompt Tips**\n"
"- Select a example thumbnail to auto-fill all inputs.\n"
"- Or write your own prompt and enhance the prompt by below category."
)
def load_preset_by_index(evt: gr.SelectData):
if evt is None or evt.index is None:
return None, "", None, "global_style", _empty_tips(), "No preset selected."
idx = evt.index[0] if isinstance(evt.index, tuple) else evt.index
preset_names = [name for name, cfg in PRESET_EXAMPLES.items() if cfg.get("preview_gif")]
if not isinstance(idx, int) or idx < 0 or idx >= len(preset_names):
return None, "", None, "global_style", _empty_tips(), "No preset selected."
return load_preset_example(preset_names[idx])
def load_preset_example(example_name: str):
cfg = PRESET_EXAMPLES.get(example_name)
if not cfg:
return None, "", None, "global_style", _format_task_tips("global_style"), "Preset not found."
task_type = cfg["task_type"]
status_text = f"Loaded preset: {example_name}"
return (
cfg["source_video"],
cfg["prompt"],
cfg["ref_image"],
task_type,
_format_task_tips(task_type),
status_text,
)
@spaces.GPU(duration=150)
def run_edit(
source_video,
prompt: str,
ref_image: Optional[str],
model_id: str,
max_frames: int,
steps: int,
seed: int,
):
if isinstance(source_video, dict):
source_video = source_video.get("path")
if not source_video:
raise gr.Error("Please upload a source video.")
if not prompt or not prompt.strip():
raise gr.Error("Please enter an edit prompt.")
max_frames = _safe_int(max_frames)
steps = _safe_int(steps)
seed = _safe_int(seed)
source_frames = load_video_frames(source_video, max_frames=max_frames, max_pixels=921600)
if not source_frames:
raise gr.Error("Could not read frames from this video.")
pipe, device = load_pipeline(model_id)
ref = None
if ref_image and os.path.exists(ref_image):
ref = [Image.open(ref_image)]
height, width = source_frames[0].size[1], source_frames[0].size[0]
edited_frames = pipe(
prompt=prompt.strip(),
source_video=source_frames,
ref_image=ref,
height=height,
width=width,
num_frames=min(len(source_frames), max_frames),
num_inference_steps=steps,
# guidance_scale=guidance_scale,
seed=seed,
tiled=True,
)
out_dir = tempfile.mkdtemp(prefix="kiwi_edit_")
os.makedirs(out_dir, exist_ok=True)
edited_path = os.path.join(out_dir, "edited.mp4")
compare_path = os.path.join(out_dir, "source_vs_edited.mp4")
export_to_video(edited_frames, edited_path, fps=15)
compare_frames = make_side_by_side(source_frames, edited_frames)
if compare_frames:
export_to_video(compare_frames, compare_path, fps=15)
else:
compare_path = None
status = f"Done on {device} | {len(source_frames)} frames | {width}x{height}"
if not device.startswith("cuda"):
status += " (CPU mode can be very slow)"
return edited_path, compare_path, status
CUSTOM_CSS = """
#kiwi-app {
max-width: 1200px;
margin: 0 auto;
}
.example-gallery-pane {
height: 224px;
max-height: 224px;
overflow-y: auto;
border: 1px solid #e6e9ef;
border-radius: 10px;
padding: 8px;
}
#preset-gallery img {
max-height: 224px !important;
object-fit: contain !important;
}
#preset-gallery .grid-wrap,
#preset-gallery .grid-container {
gap: 4px !important;
}
#preset-gallery .thumbnail-item,
#preset-gallery [data-testid="gallery-item"] {
margin-top: 0 !important;
margin-bottom: 0 !important;
padding-top: 2px !important;
padding-bottom: 2px !important;
}
#preset-gallery figcaption {
margin-top: 2px !important;
margin-bottom: 0 !important;
padding-top: 0 !important;
padding-bottom: 0 !important;
}
.kiwi-caption p {
color: #5f6368;
margin-top: 4px;
font-size: 0.94rem;
}
.my-video { max-height: 480px;}
"""
with gr.Blocks(
title="Kiwi-Edit Diffusers Demo",
theme=gr.themes.Soft(primary_hue="emerald"),
css=CUSTOM_CSS,
) as demo:
with gr.Column(elem_id="kiwi-app"):
gr.Markdown(
"## Kiwi-Edit Video Editor\n"
"Upload a video, write a prompt, and run video editing with optional reference image guidance.",
elem_classes=["kiwi-caption"],
)
with gr.Row():
with gr.Column(scale=4):
gr.Markdown("### Example Library")
with gr.Group(elem_classes=["example-gallery-pane"]):
preset_gallery = gr.Gallery(
value=_preset_gallery_items(),
elem_id="preset-gallery",
columns=1,
rows=7,
object_fit="contain",
allow_preview=False,
height=320,
label="Examples (Click a thumbnail to auto-fill)",
)
task_type = gr.Radio(
choices=[(label, key) for key, label in TASK_LABELS.items()],
value="global_style",
label="Prompt Tip Category",
)
prompt_tips = gr.Markdown(value=_empty_tips())
with gr.Column(scale=4):
gr.Markdown("### Inputs")
source_video = gr.Video(label="Source Video", container=True)
prompt = gr.Textbox(
label="Edit Prompt",
lines=4,
placeholder="Describe what should change and what should stay unchanged (Enhanced Prompts are listed in the 'Prompt Tips' panel).",
)
ref_image = gr.Image(
type="filepath",
label="Reference Image (Optional)",
)
with gr.Accordion("Advanced Settings", open=True):
model_id = gr.Dropdown(MODEL_CHOICES, value=DEFAULT_MODEL, label="Model")
max_frames = gr.Slider(8, 81, value=81, step=1, label="Max Frames")
steps = gr.Slider(10, 80, value=50, step=1, label="Inference Steps")
seed = gr.Number(value=0, precision=0, label="Seed")
run_btn = gr.Button("Run Edit", variant="primary")
with gr.Column(scale=4):
gr.Markdown("### Outputs")
edited_video = gr.Video(label="Edited Video", container=True)
compare_video = gr.Video(label="Side-by-Side (Source | Edited)", container=True)
status = gr.Textbox(label="Status", interactive=False)
preset_gallery.select(
fn=load_preset_by_index,
outputs=[source_video, prompt, ref_image, task_type, prompt_tips, status],
)
task_type.change(fn=_format_task_tips, inputs=[task_type], outputs=[prompt_tips])
run_btn.click(
fn=run_edit,
inputs=[source_video, prompt, ref_image, model_id, max_frames, steps, seed],
outputs=[edited_video, compare_video, status],
)
if __name__ == "__main__":
demo.queue().launch()