Spaces:
Paused
Paused
| import sys | |
| import subprocess | |
def ensure_flash_attn():
    """Make flash-attn importable, building it from source when missing.

    Tries the import first; on any failure, installs the pinned release via
    pip (no build isolation, so the already-installed torch is used for the
    CUDA build) and imports again to confirm the install worked.
    """
    try:
        import flash_attn  # noqa: F401
    except Exception as err:
        print("[init] Installation de flash-attn (build from source)...", err, flush=True)
        install_cmd = [
            sys.executable, "-m", "pip", "install",
            "flash-attn==2.7.4.post1",
            "--no-build-isolation",
        ]
        # check=True: abort startup loudly if the source build fails.
        subprocess.run(install_cmd, check=True)
        import flash_attn  # noqa: F401
        print("[init] flash-attn OK")
    else:
        print("[init] flash-attn déjà installé")
# Build/verify flash-attn before anything touches the model stack.
ensure_flash_attn()

from huggingface_hub import snapshot_download

# All (repo_id, local_dir) pairs the app needs fetched before start-up.
# NOTE(review): './checkpoints/ode_init.pt' is named like a file but is used
# as a snapshot *directory* — confirm downstream code expects a directory.
# NOTE(review): dir casing 'Videoreward' differs from repo 'VideoReward' —
# kept as-is; verify against whatever path the reward model loader uses.
for _repo_id, _local_dir in [
    ('Wan-AI/Wan2.1-T2V-1.3B', './checkpoints/Wan2.1-T2V-1.3B'),
    ('KlingTeam/VideoReward', './checkpoints/Videoreward'),
    ('gdhe17/Self-Forcing', './checkpoints/ode_init.pt'),
    ('JaydenLu666/Reward-Forcing-T2V-1.3B', './checkpoints/Reward-Forcing-T2V-1.3B'),
]:
    snapshot_download(repo_id=_repo_id, local_dir=_local_dir)
| import os | |
| import uuid | |
| import subprocess | |
| from datetime import datetime | |
| import gradio as gr | |
# === Paths — adjust here if the repo layout changes ===
CONFIG_PATH = "configs/reward_forcing.yaml"
CHECKPOINT_PATH = "checkpoints/Reward-Forcing-T2V-1.3B/rewardforcing.pt"
PROMPT_DIR = "prompts/gradio_inputs"
OUTPUT_ROOT = "videos/gradio_outputs"

# Both working directories must exist before the first generation request.
for _dir in (PROMPT_DIR, OUTPUT_ROOT):
    os.makedirs(_dir, exist_ok=True)
def run_inference(prompt: str, duration: str, use_ema: bool):
    """Generate a video from a text prompt by shelling out to inference.py.

    1. Writes the prompt to a one-line .txt file (TextDataset reads one
       prompt per line).
    2. Runs inference.py with that file as --data_path.
    3. Returns (path of the first generated .mp4, captured process logs).

    Raises:
        gr.Error: if the prompt is empty, or no .mp4 was produced.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Veuillez entrer un prompt texte 🙂")

    # Map the chosen duration label to a frame count.
    if duration == "5s (21 frames)":
        num_output_frames = 21
    else:  # "30s (120 frames)"
        num_output_frames = 120

    # Unique prompt file so concurrent requests never clobber each other.
    prompt_id = uuid.uuid4().hex[:8]
    prompt_path = os.path.join(PROMPT_DIR, f"prompt_{prompt_id}.txt")
    with open(prompt_path, "w", encoding="utf-8") as f:
        # TextDataset treats each line as one prompt.
        f.write(prompt.strip() + "\n")

    # Unique output folder per generation (timestamp + prompt id).
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_folder = os.path.join(OUTPUT_ROOT, f"{ts}_{prompt_id}")
    os.makedirs(output_folder, exist_ok=True)

    # FIX: use the current interpreter (sys.executable) instead of whatever
    # "python" resolves to on PATH, so inference.py runs in the same
    # environment (venv/conda) as this app.
    cmd = [
        sys.executable,
        "inference.py",
        "--num_output_frames", str(num_output_frames),
        "--config_path", CONFIG_PATH,
        "--checkpoint_path", CHECKPOINT_PATH,
        "--output_folder", output_folder,
        "--data_path", prompt_path,
        "--num_samples", "1",
    ]
    if use_ema:
        cmd.append("--use_ema")

    # Merge stderr into stdout so the UI shows one coherent log stream.
    # Deliberately no check=True: even when the subprocess fails we fall
    # through so the logs reach the user; the missing-mp4 check below is
    # what surfaces the failure.
    result = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    logs = result.stdout
    print(logs)

    # Pick the lexicographically first video produced in the output folder.
    mp4s = sorted(f for f in os.listdir(output_folder) if f.lower().endswith(".mp4"))
    if not mp4s:
        raise gr.Error(
            "Aucune vidéo trouvée dans le dossier de sortie.\n"
            "Regarde les logs ci-dessous pour voir ce qui a coincé."
        )
    video_path = os.path.join(output_folder, mp4s[0])
    return video_path, logs
# Gradio UI: prompt + options in, generated video + subprocess logs out.
# NOTE(review): indentation was lost in the paste this file came from; the
# widget nesting below (use_ema inside the same Row as duration, logs_out in
# the same Row as video_out) is the most plausible reconstruction — confirm
# against the original layout.
with gr.Blocks(title="Reward Forcing T2V Demo") as demo:
    gr.Markdown(
        """
        # 🎬 Reward Forcing – Text-to-Video
        Entrez un prompt texte, on génère un fichier `.txt` en interne
        puis on lance `inference.py` avec ce fichier comme `--data_path`.
        """
    )
    with gr.Row():
        # Free-form text prompt fed to run_inference().
        prompt_in = gr.Textbox(
            label="Prompt",
            placeholder="Ex: A cinematic shot of a spaceship flying above a neon city at night...",
            lines=4,
        )
    with gr.Row():
        # These labels must match the exact strings run_inference() compares
        # against when mapping duration -> num_output_frames.
        duration = gr.Radio(
            ["5s (21 frames)", "30s (120 frames)"],
            value="5s (21 frames)",
            label="Durée",
        )
        use_ema = gr.Checkbox(value=True, label="Utiliser les poids EMA (--use_ema)")
    generate_btn = gr.Button("🚀 Générer la vidéo", variant="primary")
    with gr.Row():
        # Outputs: the generated video and the raw inference.py log stream.
        video_out = gr.Video(label="Vidéo générée")
        logs_out = gr.Textbox(
            label="Logs de inference.py",
            lines=10,
            interactive=False,
        )
    # Wire the button to the inference runner.
    generate_btn.click(
        fn=run_inference,
        inputs=[prompt_in, duration, use_ema],
        outputs=[video_out, logs_out],
    )

if __name__ == "__main__":
    demo.launch()