import os
import spaces
import shutil
import subprocess
import sys
import copy
import random
import tempfile
import warnings
import time
import gc
import json
import uuid
from tqdm import tqdm
import cv2
import numpy as np
import torch
import torch._dynamo
from huggingface_hub import list_models, HfApi
from torch.nn import functional as F
from PIL import Image, ImageDraw

import gradio as gr
from diffusers import (
    FlowMatchEulerDiscreteScheduler,
    SASolverScheduler,
    DEISMultistepScheduler,
    DPMSolverMultistepInverseScheduler,
    UniPCMultistepScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
)
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
from diffusers.utils.export_utils import export_to_video

from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, Int8WeightOnlyConfig
import aoti

os.environ["TOKENIZERS_PARALLELISM"] = "true"
warnings.filterwarnings("ignore")

IS_ZERO_GPU = bool(os.getenv("SPACES_ZERO_GPU"))

if IS_ZERO_GPU:
    print("Loading...")
    # Clear stale ZeroGPU offload data left by a previous run. A shell is needed
    # here for the glob expansion; the command string is constant.
    subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)

# --- FRAME EXTRACTION JS & LOGIC ---

# JS to grab timestamp from the output video
get_timestamp_js = """
function() {
    // Select the video element specifically inside the component with id 'generated-video'
    const video = document.querySelector('#generated-video video');

    if (video) {
        console.log("Video found! Time: " + video.currentTime);
        return video.currentTime;
    } else {
        console.log("No video element found.");
        return 0;
    }
}
"""


def extract_frame(video_path, timestamp):
    """Return the RGB frame of *video_path* at *timestamp* seconds.

    Args:
        video_path: Path of a video file readable by OpenCV. A falsy value
            (empty output component) yields ``None``.
        timestamp: Playback position in seconds, as reported by the browser.
            ``None``, non-numeric, non-finite or negative values are treated
            as 0 (the first frame).

    Returns:
        The decoded frame as an RGB numpy array, or ``None`` when there is no
        video, it cannot be opened, or the frame cannot be read.
    """
    # Safety check: if no video is present
    if not video_path:
        return None

    try:
        seconds = float(timestamp)
    except (TypeError, ValueError):
        seconds = 0.0
    if not np.isfinite(seconds) or seconds < 0:
        seconds = 0.0

    print(f"Extracting frame at timestamp: {seconds}")

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None

        # Some containers report 0 FPS / 0 frames; seek to the first frame in
        # that case instead of computing a bogus (possibly negative) index.
        fps = cap.get(cv2.CAP_PROP_FPS) or 0.0
        target_frame_num = int(seconds * fps) if fps > 0 else 0

        # Cap to the last frame to prevent errors at the very end of the video
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames > 0:
            target_frame_num = min(target_frame_num, total_frames - 1)

        cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame_num)
        ret, frame = cap.read()
    finally:
        # Release the capture on every path, including early returns.
        cap.release()

    if not ret:
        return None
    # Convert from BGR (OpenCV) to RGB (Gradio).
    # Gradio Image component handles Numpy array -> PIL conversion automatically
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

# --- END FRAME EXTRACTION LOGIC ---


# --- EXAMPLES: PLACEHOLDER IMAGE GENERATION ---
EXAMPLES_DIR = "examples"
EXAMPLES_JSON_PATH = "examples_config.json"
os.makedirs(EXAMPLES_DIR, exist_ok=True)


def _make_placeholder(label: str, filename: str, color):
    """Create a 512x512 solid-colour PNG with *label* centred, unless it exists.

    Args:
        label: Text drawn in white at the centre of the image.
        filename: File name relative to EXAMPLES_DIR.
        color: RGB triple used as the background colour.

    Returns:
        The path ``EXAMPLES_DIR/filename``. An existing file is never overwritten.
    """
    path = os.path.join(EXAMPLES_DIR, filename)
    if os.path.exists(path):
        return path
    img = Image.new("RGB", (512, 512), tuple(color))
    draw = ImageDraw.Draw(img)
    # Simple centered label (default PIL font, no extra deps)
    try:
        bbox = draw.textbbox((0, 0), label)
        w = bbox[2] - bbox[0]
        h = bbox[3] - bbox[1]
    except Exception:
        # Older Pillow without textbbox: rough estimate for the default font.
        w, h = (len(label) * 6, 11)
    draw.text(((512 - w) / 2, (512 - h) / 2), label, fill=(255, 255, 255))
    img.save(path)
    return path


# Default examples — used when examples_config.json is missing or invalid.
DEFAULT_EXAMPLES = [
    {"filename": "portrait_a.png", "color": [62, 92, 138], "label": "Portrait A", "prompt": "subtle cinematic motion, gentle head turn", "style": "Static-Cam"},
    {"filename": "portrait_b.png", "color": [138, 75, 62], "label": "Portrait B", "prompt": "slow dolly in toward subject, smooth animation", "style": "Dolly-In"},
    {"filename": "landscape_a.png", "color": [60, 110, 80], "label": "Landscape A", "prompt": "camera pans right across the scene, smooth motion", "style": "Pan-Right"},
    {"filename": "landscape_b.png", "color": [110, 90, 60], "label": "Landscape B", "prompt": "slow zoom out reveals wider environment", "style": "Zoom-Out"},
    {"filename": "hair_blow.png", "color": [95, 70, 130], "label": "Hair Blow", "prompt": "hair flows in the wind, soft natural motion", "style": "Wind-Blow"},
    {"filename": "water.png", "color": [50, 100, 130], "label": "Water", "prompt": "water ripples gently, reflections shimmer", "style": "Water-Flow"},
    {"filename": "cinematic.png", "color": [40, 40, 60], "label": "Cinematic", "prompt": "cinematic motion, soft camera drift, film grain", "style": "Cinematic"},
    {"filename": "shake.png", "color": [130, 60, 60], "label": "Action", "prompt": "handheld camera shake, dynamic motion", "style": "Camera-Shake"},
]

# Fallback background colour for rows that do not specify a valid one.
_DEFAULT_EXAMPLE_COLOR = [80, 90, 120]


def _normalize_example_row(row):
    """Return a defaults-filled copy of an example *row*, or None if unusable.

    A row is usable when ``filename``, ``prompt`` and ``style`` are non-empty
    strings. A missing or empty ``label`` falls back to the filename. A missing
    or malformed ``color`` falls back to _DEFAULT_EXAMPLE_COLOR.
    """
    if not all(isinstance(row.get(key), str) and row.get(key) for key in ("filename", "prompt", "style")):
        return None
    fixed = dict(row)
    if not isinstance(fixed.get("label"), str) or not fixed["label"]:
        fixed["label"] = fixed["filename"]
    color = fixed.get("color")
    if not (isinstance(color, (list, tuple)) and len(color) == 3 and all(isinstance(c, int) for c in color)):
        fixed["color"] = list(_DEFAULT_EXAMPLE_COLOR)
    return fixed


def load_examples_config():
    """Load example rows from EXAMPLES_JSON_PATH, else fall back to DEFAULT_EXAMPLES.

    Startup code indexes ``label``, ``filename``, ``color`` and ``style`` on
    every row. So rows lacking a required key are skipped with a warning, and
    missing optional keys are filled in. DEFAULT_EXAMPLES is returned when the
    file is absent, unreadable, not a list of dicts, or has no usable rows.
    """
    if os.path.exists(EXAMPLES_JSON_PATH):
        try:
            with open(EXAMPLES_JSON_PATH, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception as e:
            print(f"[examples] Failed to load {EXAMPLES_JSON_PATH}: {e}")
        else:
            if isinstance(data, list) and all(isinstance(r, dict) for r in data):
                rows = []
                for i, raw in enumerate(data, 1):
                    fixed = _normalize_example_row(raw)
                    if fixed is None:
                        print(f"[examples] Skipping row #{i}: needs non-empty filename, prompt and style")
                        continue
                    rows.append(fixed)
                if rows:
                    return rows
    return DEFAULT_EXAMPLES


def write_examples_config(data):
    """Atomically write *data* to EXAMPLES_JSON_PATH as indented UTF-8 JSON.

    The data is dumped to a temp file in the same directory and then swapped
    in with ``os.replace``. A serialization error therefore leaves the
    previous config intact instead of a truncated file.
    """
    target_dir = os.path.dirname(os.path.abspath(EXAMPLES_JSON_PATH))
    fd, tmp_path = tempfile.mkstemp(prefix=".examples_config.", suffix=".json", dir=target_dir)
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        os.replace(tmp_path, EXAMPLES_JSON_PATH)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def examples_to_samples(rows):
    """Convert config rows to gr.Examples sample triplets (image path, prompt, style)."""
    return [
        [os.path.join(EXAMPLES_DIR, row["filename"]), row["prompt"], row["style"]]
        for row in rows
    ]


def _format_row_label(row, idx):
    """Human-readable selector label: '#<1-based idx> <label> · <style>'."""
    label = row.get("label") or row.get("filename") or "(unnamed)"
    style = row.get("style") or "-"
    return f"#{idx+1} {label} · {style}"


def _row_choices(state):
    """Build (display_label, index) tuples for the row selector."""
    return [(_format_row_label(r, i), i) for i, r in enumerate(state or [])]


def _slot_to_inputs(state, idx):
    """Return (image_path, label, prompt, style) for the row at idx.

    Out-of-range or None indices yield (None, "", "", ""). image_path is None
    when the row's file does not exist under EXAMPLES_DIR.
    """
    if not state or idx is None or idx < 0 or idx >= len(state):
        return None, "", "", ""
    row = state[idx]
    img_path = None
    if row.get("filename"):
        candidate = os.path.join(EXAMPLES_DIR, row["filename"])
        if os.path.exists(candidate):
            img_path = candidate
    return img_path, row.get("label", ""), row.get("prompt", ""), row.get("style", "")


# Loaded once at startup. The in-app editor works on per-session copies and
# writes EXAMPLES_JSON_PATH; it does not mutate this list.
_examples_config = load_examples_config()

# Generate placeholder images for any filenames that don't yet exist.
for _row in _examples_config:
    _make_placeholder(_row["label"], _row["filename"], _row["color"])

MOTION_STYLE_CHOICES = sorted({row["style"] for row in _examples_config})
DEFAULT_MOTION_STYLE = "Cinematic" if "Cinematic" in MOTION_STYLE_CHOICES else (MOTION_STYLE_CHOICES[0] if MOTION_STYLE_CHOICES else "Cinematic")
# --- END EXAMPLES SETUP ---


def clear_vram():
    """Run the Python GC and release cached CUDA allocator blocks."""
    gc.collect()
    torch.cuda.empty_cache()


# RIFE
if not os.path.exists("RIFEv4.26_0921.zip"):
    print("Downloading RIFE Model...")
    subprocess.run([
        "wget", "-q",
        "https://huggingface.co/r3gm/RIFE/resolve/main/RIFEv4.26_0921.zip",
        "-O", "RIFEv4.26_0921.zip"
    ], check=True)
    subprocess.run(["unzip", "-o", "RIFEv4.26_0921.zip"], check=True)

# sys.path.append(os.getcwd())

from train_log.RIFE_HDv3 import Model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rife_model = Model()
rife_model.load_model("train_log", -1)
rife_model.eval()


@torch.no_grad()
def interpolate_bits(frames_np, multiplier=2, scale=1.0):
    """
    Interpolation maintaining Numpy Float 0-1 format.

    Args:
        frames_np: Numpy Array (Time, Height, Width, Channels) - Float32 [0.0, 1.0],
            or a list of (Height, Width, Channels) arrays.
        multiplier: int (2, 4, 8). Values < 2 disable interpolation.
        scale: RIFE inference scale; also controls the padding granularity.

    Returns:
        List of Numpy Arrays (Height, Width, Channels) - Float32 [0.0, 1.0].
        For T input frames and multiplier m >= 2, (T - 1) * m + 1 frames.
        Empty input returns [].
    """
    # Empty input: nothing to do (frames_np[0] below would raise IndexError).
    if len(frames_np) == 0:
        return []

    # Handle input shape
    if isinstance(frames_np, list):
        # Grab dims from the first frame of a list input
        T = len(frames_np)
        H, W, C = frames_np[0].shape
    else:
        T, H, W, C = frames_np.shape

    # 1. No Interpolation Case
    if multiplier < 2:
        # Just convert 4D array to list of 3D arrays
        if isinstance(frames_np, np.ndarray):
            return list(frames_np)
        return frames_np

    n_interp = multiplier - 1

    # Pre-calc padding: pad H and W up to a multiple of max(128, 128 / scale)
    tmp = max(128, int(128 / scale))
    ph = ((H - 1) // tmp + 1) * tmp
    pw = ((W - 1) // tmp + 1) * tmp
    padding = (0, pw - W, 0, ph - H)

    # Helper: Numpy (H, W, C) Float -> Tensor (1, C, H, W) Half
    def to_tensor(frame_np):
        # frame_np is float32 0-1
        t = torch.from_numpy(frame_np).to(device)
        # HWC -> CHW
        t = t.permute(2, 0, 1).unsqueeze(0)
        return F.pad(t, padding).half()

    # Helper: Tensor (1, C, H, W) Half -> Numpy (H, W, C) Float
    def from_tensor(tensor):
        # Crop padding
        t = tensor[0, :, :H, :W]
        # CHW -> HWC
        t = t.permute(1, 2, 0)
        # Keep as float32, range 0-1
        return t.float().cpu().numpy()

    def make_inference(I0, I1, n):
        # Newer RIFE models accept an arbitrary timestep; older ones only do
        # midpoints, so recurse on halves.
        if rife_model.version >= 3.9:
            res = []
            for i in range(n):
                res.append(rife_model.inference(I0, I1, (i+1) * 1. / (n+1), scale))
            return res
        else:
            middle = rife_model.inference(I0, I1, scale)
            if n == 1:
                return [middle]
            first_half = make_inference(I0, middle, n=n//2)
            second_half = make_inference(middle, I1, n=n//2)
            if n % 2:
                return [*first_half, middle, *second_half]
            else:
                return [*first_half, *second_half]

    output_frames = []

    # Pre-bind so the cleanup `del` below is safe for single-frame input,
    # where the loop body never runs.
    I0 = None
    mid_tensors = []

    # Process Frames
    # Load first frame into GPU
    I1 = to_tensor(frames_np[0])

    total_steps = T - 1

    with tqdm(total=total_steps, desc="Interpolating", unit="frame") as pbar:

        for i in range(total_steps):
            I0 = I1
            # Add original frame to output
            output_frames.append(from_tensor(I0))

            # Load next frame
            I1 = to_tensor(frames_np[i+1])

            # Generate intermediate frames
            mid_tensors = make_inference(I0, I1, n_interp)

            # Append intermediate frames
            for mid in mid_tensors:
                output_frames.append(from_tensor(mid))

            # Progress is reported in batches of 50 to limit UI update overhead.
            if (i + 1) % 50 == 0:
                pbar.update(50)
        pbar.update(total_steps % 50)

    # Add the very last frame
    output_frames.append(from_tensor(I1))

    # Cleanup
    del I0, I1, mid_tensors
    torch.cuda.empty_cache()

    return output_frames


# WAN

# ORG_NAME = "TestOrganizationPleaseIgnore"
MODEL_ID = "TestOrganizationPleaseIgnore/WAMU-Merge-VisualEffects_WAN2.2_I2V_LIGHTNING"  # "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
# MODEL_ID = os.getenv("REPO_ID") or random.choice(
#     list(list_models(author=ORG_NAME, filter='diffusers:WanImageToVideoPipeline'))
# ).modelId

# CACHE_DIR = os.path.expanduser("~/.cache/huggingface/")

# Each entry is fused into the base model at startup: the "high_tr" weights go
# into `transformer`, the "low_tr" weights into `transformer_2`.
LORA_MODELS = [
    # {
    #     "repo_id": "exampleuser/example_lora_1",
    #     "high_tr": "example_lora_1_high.safetensors",
    #     "low_tr": "example_lora_1_low.safetensors",
    #     "high_scale": 0.5,
    #     "low_scale": 0.5
    # },
    # {
    #     "repo_id": "exampleuser/example_lora_2",
    #     "high_tr": "subfolder/example_lora_2_high.safetensors",
    #     "low_tr": "subfolder/example_lora_2_low.safetensors",
    #     "high_scale": 0.4,
    #     "low_scale": 0.4
    # },
]

MAX_DIM = 832
MIN_DIM = 480
SQUARE_DIM = 640
MULTIPLE_OF = 16
MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 160
MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)

SCHEDULER_MAP = {
    "FlowMatchEulerDiscrete": FlowMatchEulerDiscreteScheduler,
    "SASolver": SASolverScheduler,
    "DEISMultistep": DEISMultistepScheduler,
    "DPMSolverMultistepInverse": DPMSolverMultistepInverseScheduler,
    "UniPCMultistep": UniPCMultistepScheduler,
    "DPMSolverMultistep": DPMSolverMultistepScheduler,
    "DPMSolverSinglestep": DPMSolverSinglestepScheduler,
}

pipe = WanImageToVideoPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
).to('cuda')
# Pristine copy so per-request scheduler swaps can always be undone.
original_scheduler = copy.deepcopy(pipe.scheduler)

for i, lora in enumerate(LORA_MODELS):
    name_high_tr = lora["high_tr"].split(".")[0].split("/")[-1] + "Hh"
    name_low_tr = lora["low_tr"].split(".")[0].split("/")[-1] + "Ll"

    try:
        pipe.load_lora_weights(
            lora["repo_id"],
            weight_name=lora["high_tr"],
            adapter_name=name_high_tr
        )

        kwargs_lora = {"load_into_transformer_2": True}
        pipe.load_lora_weights(
            lora["repo_id"],
            weight_name=lora["low_tr"],
            adapter_name=name_low_tr,
            **kwargs_lora
        )

        pipe.set_adapters([name_high_tr, name_low_tr], adapter_weights=[1.0, 1.0])

        pipe.fuse_lora(adapter_names=[name_high_tr], lora_scale=lora["high_scale"], components=["transformer"])
        pipe.fuse_lora(adapter_names=[name_low_tr], lora_scale=lora["low_scale"], components=["transformer_2"])

        pipe.unload_lora_weights()

        print(f"Applied: {lora['high_tr']}, hs={lora['high_scale']}/ls={lora['low_scale']}, {i+1}/{len(LORA_MODELS)}")
    except Exception as e:
        print("Error:", str(e))
        print("Failed LoRA:", name_high_tr)
        pipe.unload_lora_weights()

# if os.path.exists(CACHE_DIR):
#     shutil.rmtree(CACHE_DIR)
#     print("Deleted Hugging Face cache.")
# else:
#     print("No hub cache found.")

quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
torch._dynamo.reset()
quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
torch._dynamo.reset()
quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
torch._dynamo.reset()
# Load pre-compiled transformer blocks for both experts. The 'fp8da' variant
# presumably matches the Float8DynamicActivation quantization applied above.
aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')

# pipe.vae.enable_slicing()
# pipe.vae.enable_tiling()

default_prompt_i2v = "make , cinematic motion, smooth animation"
# Common Wan negative prompt (Chinese). Roughly: garish colours, overexposed,
# static, blurry details, subtitles, painting-like, grey cast, worst/low
# quality, JPEG artifacts, ugly/deformed, extra or fused fingers, badly drawn
# hands/faces, cluttered background, three legs, crowded background,
# walking backwards.
default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"

# Motion-style → prompt suffix mapping. These hints are appended to the user prompt
# (after a space) to nudge the model toward the chosen motion.
MOTION_STYLE_HINTS = {
    "Static-Cam": "static camera, locked-off shot",
    "Dolly-In": "slow dolly in, smooth forward motion",
    "Pan-Right": "camera pans right, steady horizontal motion",
    "Zoom-Out": "slow zoom out, revealing wider scene",
    "Cinematic": "cinematic motion, soft camera drift, film grain",
    "Wind-Blow": "wind blowing, natural soft motion",
    "Water-Flow": "gentle water flow, shimmering reflections",
    "Camera-Shake": "handheld camera shake, dynamic motion",
}


def model_title():
    """Markdown heading linking to the model repo in MODEL_ID."""
    repo_name = MODEL_ID.split('/')[-1].replace("_", " ")
    url = f"https://huggingface.co/{MODEL_ID}"
    return f"## This space is currently running [{repo_name}]({url}) 🐢"


def resize_image(image: Image.Image) -> Image.Image:
    """Resize (and centre-crop extreme aspect ratios) to model-friendly dimensions.

    Square images become SQUARE_DIM x SQUARE_DIM. Otherwise the aspect ratio is
    clamped to [MIN_DIM/MAX_DIM, MAX_DIM/MIN_DIM] by centre-cropping. The longer
    side targets MAX_DIM. Both sides are rounded to MULTIPLE_OF and clamped to
    [MIN_DIM, MAX_DIM].
    """
    width, height = image.size
    if width == height:
        return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)

    aspect_ratio = width / height
    MAX_ASPECT_RATIO = MAX_DIM / MIN_DIM
    MIN_ASPECT_RATIO = MIN_DIM / MAX_DIM

    image_to_resize = image

    if aspect_ratio > MAX_ASPECT_RATIO:
        # Too wide: crop the sides.
        target_w, target_h = MAX_DIM, MIN_DIM
        crop_width = int(round(height * MAX_ASPECT_RATIO))
        left = (width - crop_width) // 2
        image_to_resize = image.crop((left, 0, left + crop_width, height))
    elif aspect_ratio < MIN_ASPECT_RATIO:
        # Too tall: crop top and bottom.
        target_w, target_h = MIN_DIM, MAX_DIM
        crop_height = int(round(width / MIN_ASPECT_RATIO))
        top = (height - crop_height) // 2
        image_to_resize = image.crop((0, top, width, top + crop_height))
    else:
        if width > height:
            target_w = MAX_DIM
            target_h = int(round(target_w / aspect_ratio))
        else:
            target_h = MAX_DIM
            target_w = int(round(target_h * aspect_ratio))

    final_w = round(target_w / MULTIPLE_OF) * MULTIPLE_OF
    final_h = round(target_h / MULTIPLE_OF) * MULTIPLE_OF

    final_w = max(MIN_DIM, min(MAX_DIM, final_w))
    final_h = max(MIN_DIM, min(MAX_DIM, final_h))

    return image_to_resize.resize((final_w, final_h), Image.LANCZOS)


def resize_and_crop_to_match(target_image, reference_image):
    """Scale *target_image* to cover *reference_image*'s size, then centre-crop to it."""
    ref_width, ref_height = reference_image.size
    target_width, target_height = target_image.size
    scale = max(ref_width / target_width, ref_height / target_height)
    new_width, new_height = int(target_width * scale), int(target_height * scale)
    resized = target_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
    left, top = (new_width - ref_width) // 2, (new_height - ref_height) // 2
    return resized.crop((left, top, left + ref_width, top + ref_height))


def get_num_frames(duration_seconds: float):
    """Frame count for *duration_seconds* at FIXED_FPS, clamped to the model range, plus 1."""
    return 1 + int(np.clip(
        int(round(duration_seconds * FIXED_FPS)),
        MIN_FRAMES_MODEL,
        MAX_FRAMES_MODEL,
    ))


def get_inference_duration(
    resized_image,
    processed_last_image,
    prompt,
    steps,
    negative_prompt,
    num_frames,
    guidance_scale,
    guidance_scale_2,
    current_seed,
    scheduler_name,
    flow_shift,
    frame_multiplier,
    quality,
    duration_seconds,
    progress
):
    """Heuristic GPU-seconds estimate for a run_inference call with the same arguments.

    Scales a 15 s/step baseline (81 frames at 832x624) by pixel-frame volume^1.5.
    The result is multiplied by 1.8 when CFG is active (guidance_scale > 1), and
    time for RIFE interpolation is added when frame_multiplier exceeds FIXED_FPS.
    """
    BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
    BASE_STEP_DURATION = 15
    width, height = resized_image.size
    factor = num_frames * width * height / BASE_FRAMES_HEIGHT_WIDTH
    step_duration = BASE_STEP_DURATION * factor ** 1.5
    gen_time = int(steps) * step_duration

    if guidance_scale > 1:
        gen_time = gen_time * 1.8

    frame_factor = frame_multiplier // FIXED_FPS
    if frame_factor > 1:
        total_out_frames = (num_frames * frame_factor) - num_frames
        inter_time = (total_out_frames * 0.02)
        gen_time += inter_time

    return 10 + gen_time


# NOTE(review): get_inference_duration() takes the same arguments as this
# function and could be passed as `duration=get_inference_duration` for a
# size-aware ZeroGPU budget; the budget is currently fixed at 60 s.
@spaces.GPU(duration=60)
def run_inference(
    resized_image,
    processed_last_image,
    prompt,
    steps,
    negative_prompt,
    num_frames,
    guidance_scale,
    guidance_scale_2,
    current_seed,
    scheduler_name,
    flow_shift,
    frame_multiplier,
    quality,
    duration_seconds,
    progress=gr.Progress(track_tqdm=True),
):
    """Run the Wan pipeline (plus optional RIFE interpolation) and export an MP4.

    Returns:
        (video_path, task_name): path of a temporary .mp4 file, and a short
        random id used in log lines.
    """
    # Unknown names fall back to the pipeline's own scheduler class instead of
    # failing on `None.__name__`.
    scheduler_class = SCHEDULER_MAP.get(scheduler_name, type(original_scheduler))
    current_config = pipe.scheduler.config
    # FlowMatchEulerDiscreteScheduler stores the shift under "shift", the others
    # under "flow_shift". Read whichever exists; defaulting to the literal
    # string "shift" made the comparison always unequal.
    current_shift = current_config.get("flow_shift", current_config.get("shift"))
    if scheduler_class.__name__ != current_config._class_name or flow_shift != current_shift:
        config = copy.deepcopy(original_scheduler.config)
        if scheduler_class == FlowMatchEulerDiscreteScheduler:
            config['shift'] = flow_shift
        else:
            config['flow_shift'] = flow_shift
        pipe.scheduler = scheduler_class.from_config(config)

    clear_vram()

    task_name = str(uuid.uuid4())[:8]
    print(f"Generating {num_frames} frames, task: {task_name}, {duration_seconds}, {resized_image.size}")
    start = time.time()
    try:
        result = pipe(
            image=resized_image,
            last_image=processed_last_image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=resized_image.height,
            width=resized_image.width,
            num_frames=num_frames,
            guidance_scale=float(guidance_scale),
            guidance_scale_2=float(guidance_scale_2),
            num_inference_steps=int(steps),
            generator=torch.Generator(device="cuda").manual_seed(current_seed),
            output_type="np"
        )
    finally:
        # Restore the default scheduler even if generation raised.
        pipe.scheduler = original_scheduler
    print("gen time passed:", time.time() - start)

    raw_frames_np = result.frames[0]  # Returns (T, H, W, C) float32

    frame_factor = frame_multiplier // FIXED_FPS
    if frame_factor > 1:
        start = time.time()
        print(f"Processing frames (RIFE Multiplier: {frame_factor}x)...")
        rife_model.device()
        rife_model.flownet = rife_model.flownet.half()
        final_frames = interpolate_bits(raw_frames_np, multiplier=int(frame_factor))
        print("Interpolation time passed:", time.time() - start)
    else:
        final_frames = list(raw_frames_np)

    final_fps = FIXED_FPS * int(frame_factor)

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name

    start = time.time()
    with tqdm(total=3, desc="Rendering Media", unit="clip") as pbar:
        pbar.update(2)
        export_to_video(final_frames, video_path, fps=final_fps, quality=quality)
        pbar.update(1)
    print(f"Export time passed, {final_fps} FPS:", time.time() - start)

    return video_path, task_name


def generate_video(
    input_image,
    last_image,
    prompt,
    motion_style=DEFAULT_MOTION_STYLE,
    steps=4,
    negative_prompt=default_negative_prompt,
    duration_seconds=MAX_DURATION,
    guidance_scale=1,
    guidance_scale_2=1,
    seed=42,
    randomize_seed=False,
    quality=5,
    scheduler="UniPCMultistep",
    flow_shift=6.0,
    frame_multiplier=16,
    video_component=True,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Generate a video from an input image using the Wan 2.2 14B I2V model with Lightning LoRA.

    Args:
        input_image: PIL image used as the first frame (required).
        last_image: Optional PIL image used as the last frame.
        prompt: Text prompt; the motion-style hint is appended when not already present.
        motion_style: Key of MOTION_STYLE_HINTS; unknown values add no hint.
        steps: Number of inference steps.
        negative_prompt: Negative prompt (used when a guidance scale > 1).
        duration_seconds: Clip length in seconds at 16 fps, clamped to the model range.
        guidance_scale: CFG scale for the high-noise stage.
        guidance_scale_2: CFG scale for the low-noise stage.
        seed: Seed used when randomize_seed is False.
        randomize_seed: Pick a random seed instead of `seed`.
        quality: Export quality (1-10).
        scheduler: Name of a scheduler in SCHEDULER_MAP.
        flow_shift: Flow-matching shift passed to the scheduler config.
        frame_multiplier: Output FPS; multiples of 16 enable RIFE interpolation.
        video_component: When False, the first output (video player) is None.

    Returns:
        (video_path or None, video_path, seed_used)
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")
    if scheduler not in SCHEDULER_MAP:
        raise gr.Error(f"Unknown scheduler: {scheduler}")

    # Treat a missing prompt as empty rather than failing on None.strip().
    prompt = prompt or ""

    # Append motion-style hint to the prompt if a known style was chosen
    style_hint = MOTION_STYLE_HINTS.get(motion_style)
    if style_hint and style_hint.lower() not in prompt.lower():
        prompt = f"{prompt.strip()} {style_hint}".strip()

    num_frames = get_num_frames(duration_seconds)
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

    # Normalize to 3-channel RGB: RGBA/L/P uploads would otherwise reach the
    # pipeline with the wrong channel count.
    resized_image = resize_image(input_image.convert("RGB"))

    processed_last_image = None
    if last_image is not None:
        processed_last_image = resize_and_crop_to_match(last_image.convert("RGB"), resized_image)

    video_path, task_n = run_inference(
        resized_image,
        processed_last_image,
        prompt,
        steps,
        negative_prompt,
        num_frames,
        guidance_scale,
        guidance_scale_2,
        current_seed,
        scheduler,
        flow_shift,
        frame_multiplier,
        quality,
        duration_seconds,
        progress,
    )
    print(f"GPU complete: {task_n}")

    return (video_path if video_component else None), video_path, current_seed


# Hides the timestamp number box while keeping it in the DOM so the JS grab
# handler can write to it.
CSS = """
#hidden-timestamp {
    opacity: 0;
    height: 0px;
    width: 0px;
    margin: 0px;
    padding: 0px;
    overflow: hidden;
    position: absolute;
    pointer-events: none;
}
"""


with gr.Blocks(delete_cache=(3600, 10800)) as demo:
    # gr.Markdown(model_title())
    gr.Markdown("Runs in ~50s on MIG'd H200 for a 4s output")
    with gr.Row():
        with gr.Column():
            input_image_component = gr.Image(type="pil", label="Input Image", sources=["upload", "clipboard"])
            last_image_component = gr.Image(type="pil", label="Last Image (Optional)", sources=["upload", "clipboard"])
            prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
            motion_style_input = gr.Dropdown(
                label="Motion Style",
                choices=MOTION_STYLE_CHOICES,
                value=DEFAULT_MOTION_STYLE,
                allow_custom_value=True,
                info="Adds a motion hint to the prompt. Used by the Examples below.",
            )
            duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=4, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")

            with gr.Accordion("Advanced Settings", open=False):
                frame_multi = gr.Dropdown(
                    choices=[FIXED_FPS, FIXED_FPS*2, FIXED_FPS*4, FIXED_FPS*8],
                    value=FIXED_FPS,
                    label="Video Fluidity"
                )
                negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, info="Used if any Guidance Scale > 1.", lines=3)
                quality_slider = gr.Slider(minimum=1, maximum=10, step=1, value=6, label="Video Quality", info="If set to 10, the generated video may be too large and won't play in the Gradio preview.")
                seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
                randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
                steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=4, label="Inference Steps")
                guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale - high noise stage", info="Values above 1 increase GPU usage and may take longer to process.")
                guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale 2 - low noise stage")
                scheduler_dropdown = gr.Dropdown(
                    label="Scheduler",
                    choices=list(SCHEDULER_MAP.keys()),
                    value="UniPCMultistep",
                    info="Select a custom scheduler."
                )
                flow_shift_slider = gr.Slider(minimum=0.5, maximum=15.0, step=0.1, value=3.0, label="Flow Shift")
                play_result_video = gr.Checkbox(label="Display result", value=True, interactive=True)

            generate_button = gr.Button("Generate Video", variant="primary")

        with gr.Column():
            # ASSIGNED elem_id="generated-video" so JS can find it
            video_output = gr.Video(label="Generated Video", autoplay=True, sources=["upload"], buttons=["download", "share"], interactive=True, elem_id="generated-video")

            # --- Frame Grabbing UI ---
            with gr.Row():
                grab_frame_btn = gr.Button("📸 Use Current Frame as Input", variant="secondary")
                timestamp_box = gr.Number(value=0, label="Timestamp", visible=True, elem_id="hidden-timestamp")
            # -------------------------

            file_output = gr.File(label="Download Video")

    # --- Examples table (hidden until password unlock — same gate as the admin panel)
    examples_group = gr.Group(visible=False)
    with examples_group:
        examples_widget = gr.Examples(
            examples=examples_to_samples(_examples_config),
            inputs=[input_image_component, prompt_input, motion_style_input],
            label="Examples",
            examples_per_page=8,
        )
    # ---------------------------------------------------------------------

    # --- Admin: Edit Examples panel (hidden until password unlock) ---
    EDIT_PASSWORD = os.getenv("EDIT_PASSWORD")
    HF_TOKEN_ENV = os.getenv("HF_TOKEN")
    SPACE_ID_ENV = os.getenv("SPACE_ID")

    if EDIT_PASSWORD:
        # Discreet password input — the only visible hint of an admin feature.
        # Press Enter to unlock; once unlocked, this field disappears and the
        # admin panel reveals itself for the rest of the session.
        with gr.Row():
            admin_pwd_entry = gr.Textbox(
                placeholder="🔑 admin password",
                type="password",
                show_label=False,
                max_lines=1,
                container=False,
                scale=1,
            )
            admin_pwd_msg = gr.Markdown(visible=False)

        admin_panel = gr.Group(visible=False)
        with admin_panel:
            gr.Markdown("### 🔒 Edit Examples (admin)")
            gr.Markdown(
                "Edit each row directly in the fields below — upload an image, type "
                "label / prompt / motion style — then Apply to that row. Click "
                "**Save All Changes** to commit everything back to the repo."
            )

            # State holding the current (possibly unsaved) examples list.
            examples_state = gr.State(value=[dict(r) for r in _examples_config])

            _initial_idx = 0 if _examples_config else None
            _init_img, _init_label, _init_prompt, _init_style = _slot_to_inputs(
                _examples_config, _initial_idx
            )

            row_selector = gr.Radio(
                choices=_row_choices(_examples_config),
                value=_initial_idx,
                label="Select a row to edit",
                info="Pick a row, edit the fields below, then click Apply.",
            )

            with gr.Row():
                add_row_btn = gr.Button("➕ Add new row", variant="secondary")
                remove_row_btn = gr.Button("🗑️ Remove selected row", variant="stop")

            gr.Markdown("---")
            gr.Markdown("**Edit the selected row:**")

            with gr.Row():
                with gr.Column(scale=1, min_width=220):
                    img_field = gr.Image(
                        value=_init_img,
                        type="filepath",
                        label="Image (upload or paste)",
                        sources=["upload", "clipboard"],
                        height=220,
                    )
                with gr.Column(scale=2):
                    label_field = gr.Textbox(
                        value=_init_label,
                        label="Label",
                        info="Short title shown in the row selector.",
                        max_lines=1,
                    )
                    prompt_field = gr.Textbox(
                        value=_init_prompt,
                        label="Prompt",
                        info="Text inserted into the Prompt box when the user clicks this example.",
                        lines=3,
                    )
                    style_field = gr.Textbox(
                        value=_init_style,
                        label="Motion Style",
                        info="e.g. Static-Cam, Dolly-In, Pan-Right, Cinematic, Wind-Blow",
                        max_lines=1,
                    )

            apply_btn = gr.Button("✅ Apply to selected row", variant="primary")
            apply_status = gr.Markdown()

            gr.Markdown("---")
            save_all_btn = gr.Button(
                "💾 Save All Changes & Commit to Repo", variant="primary"
            )
            save_all_status = gr.Markdown()

        # === Unlock handler (single-shot, hides the password field) ===
        def _try_unlock(pwd):
            """Reveal the examples + admin panel on a correct password."""
            import hmac  # constant-time comparison for the admin secret

            if pwd and hmac.compare_digest(pwd.encode("utf-8"), EDIT_PASSWORD.encode("utf-8")):
                # Correct: reveal Examples + admin panel, hide the password field.
                return (
                    gr.update(visible=True),             # examples_group
                    gr.update(visible=True),             # admin_panel
                    gr.update(value="", visible=False),  # admin_pwd_entry (hide)
                    gr.update(value="", visible=False),  # admin_pwd_msg
                )
            elif pwd:
                # Wrong: everything stays hidden, clear value, show brief error.
                return (
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(value=""),
                    gr.update(value="❌ Wrong password.", visible=True),
                )
            # Empty submit: do nothing
            return (
                gr.update(visible=False),
                gr.update(visible=False),
                gr.update(),
                gr.update(visible=False),
            )

        admin_pwd_entry.submit(
            fn=_try_unlock,
            inputs=[admin_pwd_entry],
            outputs=[examples_group, admin_panel, admin_pwd_entry, admin_pwd_msg],
        )

        # === Editor handlers — no password re-check needed
        # (panel is only reachable after a correct password unlock)
        def _on_select_row(idx, state):
            """Load the selected row's values into the edit fields."""
            img, label, prompt, style = _slot_to_inputs(state, idx)
            return (
                gr.update(value=img),
                gr.update(value=label),
                gr.update(value=prompt),
                gr.update(value=style),
            )

        row_selector.change(
            fn=_on_select_row,
            inputs=[row_selector, examples_state],
            outputs=[img_field, label_field, prompt_field, style_field],
        )

        def _on_apply(img_path, label, prompt, style, idx, state):
            """Write the edit fields into row *idx* of a copy of *state*."""
            import filecmp  # detect an unchanged image handed back from Gradio's cache

            if idx is None:
                return state, gr.update(), "❌ No row selected. Click 'Add new row' first."
            if idx < 0 or idx >= len(state):
                return state, gr.update(), "❌ Invalid row index."

            new_state = [dict(r) for r in state]
            row = dict(new_state[idx])

            if img_path:
                abs_examples = os.path.abspath(EXAMPLES_DIR)
                abs_img = os.path.abspath(img_path)
                current = os.path.join(EXAMPLES_DIR, row["filename"]) if row.get("filename") else None
                if abs_img.startswith(abs_examples + os.sep):
                    # Already inside EXAMPLES_DIR: just point the row at it.
                    row["filename"] = os.path.relpath(abs_img, abs_examples)
                elif current and os.path.exists(current) and filecmp.cmp(img_path, current, shallow=False):
                    # Same bytes as the row's current image (Gradio returns a
                    # cached copy of the displayed value): don't duplicate it.
                    pass
                else:
                    # New upload: copy it into EXAMPLES_DIR with a unique name.
                    ext = os.path.splitext(img_path)[1].lower() or ".png"
                    if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"):
                        ext = ".png"
                    new_fname = f"example_{uuid.uuid4().hex[:8]}{ext}"
                    dest = os.path.join(EXAMPLES_DIR, new_fname)
                    try:
                        shutil.copy(img_path, dest)
                        row["filename"] = new_fname
                    except Exception as e:
                        return state, gr.update(), f"❌ Image copy failed: {e}"

            # Empty fields keep the row's previous value.
            row["label"] = (label or "").strip() or row.get("label", "")
            row["prompt"] = (prompt or "").strip() or row.get("prompt", "")
            row["style"] = (style or "").strip() or row.get("style", "")
            row.setdefault("color", [80, 90, 120])

            new_state[idx] = row
            return (
                new_state,
                gr.update(choices=_row_choices(new_state), value=idx),
                f"✅ Applied to row #{idx+1}. Click 'Save All Changes' to commit.",
            )

        apply_btn.click(
            fn=_on_apply,
            inputs=[
                img_field, label_field, prompt_field, style_field,
                row_selector, examples_state,
            ],
            outputs=[examples_state, row_selector, apply_status],
        )

        def _on_add(state):
            """Append a placeholder row and select it."""
            new_row = {
                "filename": f"new_{uuid.uuid4().hex[:6]}.png",
                "color": [80, 90, 120],
                "label": "New Example",
                "prompt": "describe the motion you want",
                "style": "Cinematic",
            }
            _make_placeholder(new_row["label"], new_row["filename"], new_row["color"])
            new_state = [dict(r) for r in (state or [])] + [new_row]
            new_idx = len(new_state) - 1
            img, label, prompt, style = _slot_to_inputs(new_state, new_idx)
            return (
                new_state,
                gr.update(choices=_row_choices(new_state), value=new_idx),
                gr.update(value=img),
                gr.update(value=label),
                gr.update(value=prompt),
                gr.update(value=style),
                f"➕ Added row #{new_idx+1}. Edit the fields and click Apply.",
            )

        add_row_btn.click(
            fn=_on_add,
            inputs=[examples_state],
            outputs=[
                examples_state, row_selector,
                img_field, label_field, prompt_field, style_field,
                apply_status,
            ],
        )

        def _on_remove(idx, state):
            """Drop row *idx* (image file is left on disk) and select a neighbour."""
            if not state:
                return state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), "❌ Nothing to remove."
            if idx is None or idx < 0 or idx >= len(state):
                return state, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), "❌ Select a row first."
            new_state = [dict(r) for r in state]
            removed = new_state.pop(idx)
            new_idx = min(idx, len(new_state) - 1) if new_state else None
            img, label, prompt, style = _slot_to_inputs(new_state, new_idx)
            return (
                new_state,
                gr.update(choices=_row_choices(new_state), value=new_idx),
                gr.update(value=img),
                gr.update(value=label),
                gr.update(value=prompt),
                gr.update(value=style),
                f"🗑️ Removed '{removed.get('filename', '?')}'.",
            )

        remove_row_btn.click(
            fn=_on_remove,
            inputs=[row_selector, examples_state],
            outputs=[
                examples_state, row_selector,
                img_field, label_field, prompt_field, style_field,
                apply_status,
            ],
        )

        def _on_save_all(state):
            """Validate, write the JSON config, and (if configured) commit to the Space repo."""
            if not state:
                return "❌ Examples list is empty.", gr.update()

            # Final validation + default fixups.
            clean = []
            for i, row in enumerate(state, 1):
                if not row.get("filename"):
                    return f"❌ Row #{i}: missing filename (upload an image or Apply).", gr.update()
                if not row.get("prompt"):
                    return f"❌ Row #{i}: prompt is empty.", gr.update()
                if not row.get("style"):
                    return f"❌ Row #{i}: style is empty.", gr.update()
                fixed = dict(row)
                # Also replace empty values, which setdefault would keep.
                if not fixed.get("label"):
                    fixed["label"] = fixed["filename"]
                if not fixed.get("color"):
                    fixed["color"] = [80, 90, 120]
                clean.append(fixed)

            try:
                write_examples_config(clean)
            except Exception as e:
                return f"❌ Failed to write JSON: {e}", gr.update()

            # Ensure every referenced image exists (create placeholders for missing ones)
            for row in clean:
                _make_placeholder(row["label"], row["filename"], row["color"])

            new_samples = examples_to_samples(clean)
            msgs = ["✅ Saved locally"]

            if HF_TOKEN_ENV and SPACE_ID_ENV:
                try:
                    from huggingface_hub import CommitOperationAdd
                    operations = [
                        CommitOperationAdd(
                            path_in_repo=EXAMPLES_JSON_PATH,
                            path_or_fileobj=EXAMPLES_JSON_PATH,
                        )
                    ]
                    for row in clean:
                        local_img = os.path.join(EXAMPLES_DIR, row["filename"])
                        if os.path.exists(local_img):
                            operations.append(CommitOperationAdd(
                                path_in_repo=f"{EXAMPLES_DIR}/{row['filename']}",
                                path_or_fileobj=local_img,
                            ))
                    api = HfApi(token=HF_TOKEN_ENV)
                    api.create_commit(
                        repo_id=SPACE_ID_ENV,
                        repo_type="space",
                        operations=operations,
                        commit_message="Update examples via in-app editor",
                    )
                    msgs.append("✅ Committed to repo (Space will rebuild)")
                except Exception as e:
                    msgs.append(f"⚠️ Repo commit failed: {e}")
            else:
                msgs.append("⚠️ HF_TOKEN/SPACE_ID not set — changes will reset on rebuild")

            return " · ".join(msgs), gr.update(samples=new_samples)

        save_all_btn.click(
            fn=_on_save_all,
            inputs=[examples_state],
            outputs=[save_all_status, examples_widget.dataset],
        )
    # ---------------------------------

    ui_inputs = [
        input_image_component, last_image_component, prompt_input,
        motion_style_input,
        steps_slider, negative_prompt_input, duration_seconds_input,
        guidance_scale_input, guidance_scale_2_input, seed_input, randomize_seed_checkbox,
        quality_slider, scheduler_dropdown, flow_shift_slider, frame_multi,
        play_result_video
    ]

    generate_button.click(
        fn=generate_video,
        inputs=ui_inputs,
        outputs=[video_output, file_output, seed_input]
    )

    # --- Frame Grabbing Events ---
    # 1. Click button -> JS runs -> puts the player's current time in the hidden number box
    # 2. .then() -> Python runs -> puts that frame in Input Image.
    #    Chaining (instead of timestamp_box.change) makes every click extract a
    #    frame, even when the timestamp is unchanged (e.g. 0 before playback).
    grab_frame_btn.click(
        fn=None,
        inputs=None,
        outputs=[timestamp_box],
        js=get_timestamp_js
    ).then(
        fn=extract_frame,
        inputs=[video_output, timestamp_box],
        outputs=[input_image_component]
    )

if __name__ == "__main__":
    demo.queue().launch(
        mcp_server=True,
        css=CSS,
        show_error=True,
    )