# Reward-Forcing / app_wip.py
# (Web-page header residue, kept as comments so the file parses as Python.)
# fffiloni · "English version" · fa10854 (verified) · 9.31 kB
import os
import sys
import uuid
import shutil
import gradio as gr
import torch
from omegaconf import OmegaConf
from torchvision.io import write_video
from einops import rearrange
from huggingface_hub import snapshot_download
from pipeline import (
CausalDiffusionInferencePipeline,
CausalInferencePipeline,
)
from utils.dataset import TextDataset
from utils.misc import set_seed
from demo_utils.memory import get_cuda_free_memory_gb, DynamicSwapInstaller
# -------------------------------------------------------------------
# Fetch all checkpoint repositories once, when the Space starts
# (these calls run at module import time).
# -------------------------------------------------------------------
# Each entry is (Hub repo id, local target directory).
# NOTE(review): the Self-Forcing repo is downloaded into a *directory*
# that is named "ode_init.pt" -- confirm this layout is intended.
_CHECKPOINT_REPOS = (
    ("Wan-AI/Wan2.1-T2V-1.3B", "./checkpoints/Wan2.1-T2V-1.3B"),
    ("KlingTeam/VideoReward", "./checkpoints/Videoreward"),
    ("gdhe17/Self-Forcing", "./checkpoints/ode_init.pt"),
    (
        "JaydenLu666/Reward-Forcing-T2V-1.3B",
        "./checkpoints/Reward-Forcing-T2V-1.3B",
    ),
)
for _repo_id, _local_dir in _CHECKPOINT_REPOS:
    snapshot_download(repo_id=_repo_id, local_dir=_local_dir)
# === Paths ===
# Model configuration and generator checkpoint read by the inference code.
CONFIG_PATH = "configs/reward_forcing.yaml"
CHECKPOINT_PATH = "checkpoints/Reward-Forcing-T2V-1.3B/rewardforcing.pt"
# Where per-request prompt files are written / where videos are saved.
PROMPT_DIR = "prompts/gradio_inputs"
OUTPUT_ROOT = "videos"

# Create the working directories up front; existing ones are left alone.
for _work_dir in (PROMPT_DIR, OUTPUT_ROOT):
    os.makedirs(_work_dir, exist_ok=True)
def reward_forcing_inference(
    prompt_txt_path: str,
    num_output_frames: int,
    use_ema: bool,
    output_root: str,
    progress: gr.Progress,
):
    """
    Generate a video for the first prompt of a prompt file.

    Inline / simplified version of inference.py:
    - single GPU
    - text-to-video only
    - one .txt file = N prompts, but only the first generated video is
      returned (the function returns right after saving it)

    The pipeline is built and the checkpoint is loaded on every call.

    Args:
        prompt_txt_path: Path of the prompt file handed to ``TextDataset``.
        num_output_frames: Size of the frame dimension of the sampled noise
            tensor (presumably latent frames, not video frames -- TODO
            confirm against the pipeline).
        use_ema: Accepted for interface compatibility but not read by this
            function: the checkpoint state dict is loaded as-is.
            NOTE(review): confirm whether EMA weights should be selected
            here when this flag is set.
        output_root: Directory under which the per-run output folder is
            (re)created.
        progress: Gradio progress tracker; called directly for the init
            stages and through ``progress.tqdm`` for the generation loop.

    Returns:
        A ``(video_path, logs)`` tuple. ``video_path`` is the path of the
        written .mp4 file, or ``None`` when the loop finished without
        saving anything. ``logs`` is a newline-terminated text log.
    """
    from torch.utils.data import DataLoader, SequentialSampler

    log_lines = []

    def render_logs():
        # Every entry becomes one newline-terminated line.
        return "".join(f"{line}\n" for line in log_lines)

    # --------------------- Device & randomness ---------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(0)
    # NOTE(review): behaviour of get_cuda_free_memory_gb on a CPU device is
    # not visible here -- verify before running without CUDA.
    free_vram = get_cuda_free_memory_gb(device)
    log_lines.append(f"Free VRAM {free_vram} GB")
    low_memory = free_vram < 40
    # Inference only: gradients are never needed.
    torch.set_grad_enabled(False)

    # --------------------- Stage 1: model & config init ---------------------
    progress(0.05, desc="Init: loading config")
    log_lines.append("Loading config...")
    config = OmegaConf.load(CONFIG_PATH)
    default_config = OmegaConf.load("configs/default_config.yaml")
    # Values from CONFIG_PATH override the defaults.
    config = OmegaConf.merge(default_config, config)

    progress(0.15, desc="Init: creating pipeline")
    log_lines.append("Creating pipeline...")
    if hasattr(config, "denoising_step_list"):
        # few-step sampling pipeline
        pipeline = CausalInferencePipeline(config, device=device)
    else:
        # full diffusion pipeline
        pipeline = CausalDiffusionInferencePipeline(config, device=device)

    progress(0.35, desc="Init: loading checkpoint")
    log_lines.append("Loading checkpoint weights...")
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
    pipeline.generator.load_state_dict(state_dict)
    # Last "_"-separated token of the checkpoint's parent directory name;
    # used below as the innermost output sub-folder.
    checkpoint_step = os.path.basename(os.path.dirname(CHECKPOINT_PATH))
    checkpoint_step = checkpoint_step.split("_")[-1]

    progress(0.55, desc="Init: moving model to device")
    log_lines.append("Moving model to device...")
    pipeline = pipeline.to(dtype=torch.bfloat16)
    if low_memory:
        # Below the 40 GB threshold the text encoder is handed to the
        # dynamic swap installer instead of being moved to the device.
        DynamicSwapInstaller.install_model(pipeline.text_encoder, device=device)
    else:
        pipeline.text_encoder.to(device=device)
    pipeline.generator.to(device=device)
    pipeline.vae.to(device=device)

    # --------------------- Dataset setup ---------------------
    progress(0.65, desc="Preparing dataset")
    log_lines.append("Preparing dataset (TextDataset)...")
    dataset = TextDataset(prompt_path=prompt_txt_path, extended_prompt_path=None)
    num_prompts = len(dataset)
    log_lines.append(f"Number of prompts: {num_prompts}")
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        sampler=SequentialSampler(dataset),
        num_workers=0,
        drop_last=False,
    )

    # --------------------- Make a clean output directory ---------------------
    progress(0.7, desc="Cleaning output folder")
    output_folder = os.path.join(
        output_root, f"rewardforcing-{num_output_frames}f", checkpoint_step
    )
    shutil.rmtree(output_folder, ignore_errors=True)
    os.makedirs(output_folder, exist_ok=True)
    log_lines.append(f"Output directory: {output_folder}")

    # --------------------- Stage 2: inference loop ---------------------
    # Gradio can track tqdm progress on iterable loops
    for batch_data in progress.tqdm(
        dataloader,
        total=num_prompts,
        desc="Video generation",
        unit="prompt",
    ):
        # Normalise the batch container *before* reading any field from it,
        # so a list-wrapped batch does not fail on the "idx" lookup.
        batch = batch_data[0] if isinstance(batch_data, list) else batch_data
        idx = batch["idx"].item()

        # TEXT-TO-VIDEO only (no I2V here): prefer the extended prompt
        # when the dataset provides one.
        prompt = batch["prompts"][0]
        extended_prompt = batch.get("extended_prompts", [None])[0]
        prompts = [prompt if extended_prompt is None else extended_prompt]

        # Noise tensor; the trailing dims are the hard-coded latent shape
        # this demo uses (assumed to match the Wan2.1 model -- TODO confirm).
        sampled_noise = torch.randn(
            [1, num_output_frames, 16, 60, 104],
            device=device,
            dtype=torch.bfloat16,
        )
        log_lines.append(f"Generating for prompt: {prompt[:80]}...")

        # Run inference; the returned latents are not needed here.
        video, _ = pipeline.inference(
            noise=sampled_noise,
            text_prompts=prompts,
            return_latents=True,
            initial_latent=None,
            low_memory=low_memory,
        )
        # Channels-last layout on the CPU, scaled by 255 before being
        # handed to write_video.
        video = 255.0 * rearrange(video, "b t c h w -> b t h w c").cpu()

        # free VAE cache between clips
        pipeline.vae.model.clear_cache()

        # Save the clip and stop: only the first saved video is returned.
        if idx < num_prompts:
            # Path separators in the prompt would otherwise create
            # sub-directories in the output path.
            safe_name = prompt[:50].replace("/", "_").replace("\\", "_")
            output_path = os.path.join(output_folder, f"{safe_name}.mp4")
            write_video(output_path, video[0], fps=16)
            log_lines.append(f"Saved video: {output_path}")
            progress(1.0, desc="Done ✅")
            return output_path, render_logs()

    log_lines.append("[WARN] No video generated in loop.")
    return None, render_logs()
def gradio_generate(
    prompt: str, duration: str, use_ema: bool, progress=gr.Progress(track_tqdm=True)
):
    """
    Gradio callback: turn a text prompt into a generated video.

    Steps:
    - validates the prompt
    - writes it to a temporary .txt file in PROMPT_DIR
    - runs reward_forcing_inference on that file
    - removes the temporary file again
    - returns video path + logs

    Args:
        prompt: Text typed by the user.
        duration: Selected duration label; "5s (21 frames)" maps to 21
            frames, any other value maps to 120.
        use_ema: Forwarded unchanged to reward_forcing_inference.
        progress: Gradio progress tracker (injected by Gradio through the
            default value).

    Returns:
        A ``(video_path, logs)`` tuple for the video and log outputs.

    Raises:
        gr.Error: If the prompt is empty/blank, or if no video file was
            produced.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please type a text prompt 🙂")

    # Duration -> number of latent timesteps
    num_output_frames = 21 if duration == "5s (21 frames)" else 120

    # The prompt file holds one prompt per line and only the first video is
    # returned, so a multi-line entry would lose everything after its first
    # line. Collapse it into a single line instead.
    prompt_line = " ".join(
        part.strip() for part in prompt.splitlines() if part.strip()
    )

    os.makedirs(PROMPT_DIR, exist_ok=True)
    prompt_id = uuid.uuid4().hex[:8]
    prompt_path = os.path.join(PROMPT_DIR, f"prompt_{prompt_id}.txt")
    with open(prompt_path, "w", encoding="utf-8") as f:
        f.write(prompt_line + "\n")

    try:
        video_path, logs = reward_forcing_inference(
            prompt_txt_path=prompt_path,
            num_output_frames=num_output_frames,
            use_ema=use_ema,
            output_root=OUTPUT_ROOT,
            progress=progress,
        )
    finally:
        # The prompt file is only an input of this one run; delete it so the
        # files do not pile up. Best effort: a failed cleanup must not hide
        # the inference result or its exception.
        try:
            os.remove(prompt_path)
        except OSError:
            pass

    if video_path is None or not os.path.exists(video_path):
        raise gr.Error(
            "No video generated.\n"
            "Check the logs below for errors."
        )
    return video_path, logs
# -------------------------------------------------------------------
# Gradio UI: prompt + options in, video + logs out
# -------------------------------------------------------------------
demo = gr.Blocks(title="Reward Forcing T2V Demo (inline inference)")

with demo:
    # Intro text shown at the top of the page.
    gr.Markdown(
        """
# 🎬 Reward Forcing – Text-to-Video (inline)
This version directly calls the inference logic in Python,
allowing Gradio to track:
- model initialization via `progress(...)`
- video generation progress via `progress.tqdm(...)`
"""
    )

    # --- Inputs ---
    with gr.Row():
        prompt_in = gr.Textbox(
            label="Prompt",
            placeholder="Ex: A cinematic shot of a spaceship flying above a neon city at night...",
            lines=4,
        )

    with gr.Row():
        # The labels below are compared verbatim in gradio_generate.
        duration = gr.Radio(
            ["5s (21 frames)", "30s (120 frames)"],
            value="5s (21 frames)",
            label="Duration",
        )
        use_ema = gr.Checkbox(value=True, label="Use EMA weights (--use_ema)")

    generate_btn = gr.Button("🚀 Generate Video", variant="primary")

    # --- Outputs ---
    with gr.Row():
        video_out = gr.Video(label="Generated Video")
        logs_out = gr.Textbox(
            label="Logs",
            lines=12,
            interactive=False,
        )

    # Wire the button to the generation callback.
    generate_btn.click(
        fn=gradio_generate,
        inputs=[prompt_in, duration, use_ema],
        outputs=[video_out, logs_out],
    )

# Enable the request queue before launching.
demo.queue()

if __name__ == "__main__":
    demo.launch()