# Reward-Forcing / app_wip.py
# NOTE(review): the lines below were Hugging Face file-viewer residue captured
# by a copy-paste ("fffiloni's picture / Update app_wip.py / fffb44f verified /
# raw / history blame / 4.84 kB"). They are not Python and broke the file;
# preserved here as a comment so the module can actually be imported.
import sys
import subprocess
def ensure_flash_attn():
try:
import flash_attn # noqa: F401
print("[init] flash-attn déjà installé")
except Exception as e:
print("[init] Installation de flash-attn (build from source)...", e, flush=True)
subprocess.run(
[
sys.executable,
"-m",
"pip",
"install",
"flash-attn==2.7.4.post1",
"--no-build-isolation",
],
check=True,
)
import flash_attn # noqa: F401
print("[init] flash-attn OK")
ensure_flash_attn()
from huggingface_hub import snapshot_download
snapshot_download(
repo_id='Wan-AI/Wan2.1-T2V-1.3B',
local_dir='./checkpoints/Wan2.1-T2V-1.3B'
)
snapshot_download(
repo_id='KlingTeam/VideoReward',
local_dir='./checkpoints/Videoreward'
)
snapshot_download(
repo_id='gdhe17/Self-Forcing',
local_dir='./checkpoints/ode_init.pt'
)
snapshot_download(
repo_id='JaydenLu666/Reward-Forcing-T2V-1.3B',
local_dir='./checkpoints/Reward-Forcing-T2V-1.3B'
)
import os
import uuid
import subprocess
from datetime import datetime
import gradio as gr
# === Chemins à adapter si besoin ===
CONFIG_PATH = "configs/reward_forcing.yaml"
CHECKPOINT_PATH = "checkpoints/Reward-Forcing-T2V-1.3B/rewardforcing.pt"
PROMPT_DIR = "prompts/gradio_inputs"
OUTPUT_ROOT = "videos/gradio_outputs"
os.makedirs(PROMPT_DIR, exist_ok=True)
os.makedirs(OUTPUT_ROOT, exist_ok=True)
def run_inference(prompt: str, duration: str, use_ema: bool):
"""
1. Écrit le prompt dans un fichier .txt
2. Lance inference.py avec ce fichier comme --data_path
3. Retourne le chemin de la première vidéo .mp4 générée + les logs
"""
if not prompt or not prompt.strip():
raise gr.Error("Veuillez entrer un prompt texte 🙂")
# 1) On mappe la durée choisie → num_output_frames
if duration == "5s (21 frames)":
num_output_frames = 21
else: # "30s (120 frames)"
num_output_frames = 120
# 2) Fichier .txt temporaire pour le prompt
prompt_id = uuid.uuid4().hex[:8]
prompt_path = os.path.join(PROMPT_DIR, f"prompt_{prompt_id}.txt")
with open(prompt_path, "w", encoding="utf-8") as f:
# TextDataset lit juste chaque ligne comme un prompt
f.write(prompt.strip() + "\n")
# 3) Dossier de sortie unique pour cette génération
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
output_folder = os.path.join(OUTPUT_ROOT, f"{ts}_{prompt_id}")
os.makedirs(output_folder, exist_ok=True)
# 4) Commande inference.py
cmd = [
"python",
"inference.py",
"--num_output_frames", str(num_output_frames),
"--config_path", CONFIG_PATH,
"--checkpoint_path", CHECKPOINT_PATH,
"--output_folder", output_folder,
"--data_path", prompt_path,
"--num_samples", "1",
]
if use_ema:
cmd.append("--use_ema")
result = subprocess.run(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
logs = result.stdout
print(logs)
# 5) On récupère la première vidéo produite
mp4s = [f for f in os.listdir(output_folder) if f.lower().endswith(".mp4")]
if not mp4s:
raise gr.Error(
"Aucune vidéo trouvée dans le dossier de sortie.\n"
"Regarde les logs ci-dessous pour voir ce qui a coincé."
)
mp4s.sort()
video_path = os.path.join(output_folder, mp4s[0])
return video_path, logs
with gr.Blocks(title="Reward Forcing T2V Demo") as demo:
gr.Markdown(
"""
# 🎬 Reward Forcing – Text-to-Video
Entrez un prompt texte, on génère un fichier `.txt` en interne
puis on lance `inference.py` avec ce fichier comme `--data_path`.
"""
)
with gr.Row():
prompt_in = gr.Textbox(
label="Prompt",
placeholder="Ex: A cinematic shot of a spaceship flying above a neon city at night...",
lines=4,
)
with gr.Row():
duration = gr.Radio(
["5s (21 frames)", "30s (120 frames)"],
value="5s (21 frames)",
label="Durée",
)
use_ema = gr.Checkbox(value=True, label="Utiliser les poids EMA (--use_ema)")
generate_btn = gr.Button("🚀 Générer la vidéo", variant="primary")
with gr.Row():
video_out = gr.Video(label="Vidéo générée")
logs_out = gr.Textbox(
label="Logs de inference.py",
lines=10,
interactive=False,
)
generate_btn.click(
fn=run_inference,
inputs=[prompt_in, duration, use_ema],
outputs=[video_out, logs_out],
)
if __name__ == "__main__":
demo.launch()