Spaces:
Paused
Paused
Update app_wip.py
Browse files- app_wip.py +48 -42
app_wip.py
CHANGED
|
@@ -2,12 +2,10 @@ import os
|
|
| 2 |
import sys
|
| 3 |
import uuid
|
| 4 |
import shutil
|
| 5 |
-
from datetime import datetime
|
| 6 |
|
| 7 |
import gradio as gr
|
| 8 |
import torch
|
| 9 |
from omegaconf import OmegaConf
|
| 10 |
-
from tqdm import tqdm
|
| 11 |
from torchvision.io import write_video
|
| 12 |
from einops import rearrange
|
| 13 |
from huggingface_hub import snapshot_download
|
|
@@ -79,40 +77,39 @@ def reward_forcing_inference(
|
|
| 79 |
|
| 80 |
torch.set_grad_enabled(False)
|
| 81 |
|
| 82 |
-
# ---------------------
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
pipeline.vae.to(device=device)
|
| 113 |
-
pbar.update(1)
|
| 114 |
|
| 115 |
# --------------------- Dataset / DataLoader ---------------------
|
|
|
|
| 116 |
logs += "Préparation du dataset (TextDataset)...\n"
|
| 117 |
dataset = TextDataset(prompt_path=prompt_txt_path, extended_prompt_path=None)
|
| 118 |
num_prompts = len(dataset)
|
|
@@ -121,15 +118,21 @@ def reward_forcing_inference(
|
|
| 121 |
from torch.utils.data import DataLoader, SequentialSampler
|
| 122 |
|
| 123 |
sampler = SequentialSampler(dataset)
|
| 124 |
-
dataloader = DataLoader(
|
|
|
|
|
|
|
| 125 |
|
| 126 |
# --------------------- Output folder (on le vide) ---------------------
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
| 128 |
shutil.rmtree(output_folder, ignore_errors=True)
|
| 129 |
os.makedirs(output_folder, exist_ok=True)
|
| 130 |
logs += f"Dossier de sortie: {output_folder}\n"
|
| 131 |
|
| 132 |
-
# ---------------------
|
|
|
|
| 133 |
for i, batch_data in progress.tqdm(
|
| 134 |
enumerate(dataloader),
|
| 135 |
total=num_prompts,
|
|
@@ -190,13 +193,17 @@ def reward_forcing_inference(
|
|
| 190 |
output_path = os.path.join(output_folder, f"{safe_name}.mp4")
|
| 191 |
write_video(output_path, video[0], fps=16)
|
| 192 |
logs += f"Vidéo enregistrée: {output_path}\n"
|
|
|
|
|
|
|
| 193 |
return output_path, logs
|
| 194 |
|
| 195 |
logs += "[WARN] Aucune vidéo générée dans la boucle.\n"
|
| 196 |
return None, logs
|
| 197 |
|
| 198 |
|
| 199 |
-
def gradio_generate(
|
|
|
|
|
|
|
| 200 |
"""
|
| 201 |
Fonction appelée par Gradio :
|
| 202 |
- écrit le prompt dans un .txt
|
|
@@ -219,7 +226,6 @@ def gradio_generate(prompt: str, duration: str, use_ema: bool, progress=gr.Progr
|
|
| 219 |
with open(prompt_path, "w", encoding="utf-8") as f:
|
| 220 |
f.write(prompt.strip() + "\n")
|
| 221 |
|
| 222 |
-
# Appel de la fonction d'inférence inline
|
| 223 |
video_path, logs = reward_forcing_inference(
|
| 224 |
prompt_txt_path=prompt_path,
|
| 225 |
num_output_frames=num_output_frames,
|
|
@@ -247,9 +253,9 @@ with gr.Blocks(title="Reward Forcing T2V Demo (inline inference)") as demo:
|
|
| 247 |
# 🎬 Reward Forcing – Text-to-Video (inline)
|
| 248 |
|
| 249 |
Cette version appelle directement la logique d'inférence en Python,
|
| 250 |
-
ce qui permet à Gradio de suivre
|
| 251 |
-
-
|
| 252 |
-
-
|
| 253 |
"""
|
| 254 |
)
|
| 255 |
|
|
|
|
| 2 |
import sys
|
| 3 |
import uuid
|
| 4 |
import shutil
|
|
|
|
| 5 |
|
| 6 |
import gradio as gr
|
| 7 |
import torch
|
| 8 |
from omegaconf import OmegaConf
|
|
|
|
| 9 |
from torchvision.io import write_video
|
| 10 |
from einops import rearrange
|
| 11 |
from huggingface_hub import snapshot_download
|
|
|
|
| 77 |
|
| 78 |
torch.set_grad_enabled(False)
|
| 79 |
|
| 80 |
+
# --------------------- Phase 1 : init modèle / config ---------------------
|
| 81 |
+
progress(0.05, desc="Initialisation : chargement de la config")
|
| 82 |
+
logs += "Chargement de la config...\n"
|
| 83 |
+
config = OmegaConf.load(CONFIG_PATH)
|
| 84 |
+
default_config = OmegaConf.load("configs/default_config.yaml")
|
| 85 |
+
config = OmegaConf.merge(default_config, config)
|
| 86 |
+
|
| 87 |
+
progress(0.15, desc="Initialisation : création de la pipeline")
|
| 88 |
+
logs += "Initialisation de la pipeline...\n"
|
| 89 |
+
if hasattr(config, "denoising_step_list"):
|
| 90 |
+
pipeline = CausalInferencePipeline(config, device=device)
|
| 91 |
+
else:
|
| 92 |
+
pipeline = CausalDiffusionInferencePipeline(config, device=device)
|
| 93 |
+
|
| 94 |
+
progress(0.35, desc="Initialisation : chargement du checkpoint")
|
| 95 |
+
logs += "Chargement des poids du checkpoint...\n"
|
| 96 |
+
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
|
| 97 |
+
pipeline.generator.load_state_dict(state_dict)
|
| 98 |
+
checkpoint_step = os.path.basename(os.path.dirname(CHECKPOINT_PATH))
|
| 99 |
+
checkpoint_step = checkpoint_step.split("_")[-1]
|
| 100 |
+
|
| 101 |
+
progress(0.55, desc="Initialisation : placement sur le device")
|
| 102 |
+
logs += "Placement du modèle sur le device...\n"
|
| 103 |
+
pipeline = pipeline.to(dtype=torch.bfloat16)
|
| 104 |
+
if low_memory:
|
| 105 |
+
DynamicSwapInstaller.install_model(pipeline.text_encoder, device=device)
|
| 106 |
+
else:
|
| 107 |
+
pipeline.text_encoder.to(device=device)
|
| 108 |
+
pipeline.generator.to(device=device)
|
| 109 |
+
pipeline.vae.to(device=device)
|
|
|
|
|
|
|
| 110 |
|
| 111 |
# --------------------- Dataset / DataLoader ---------------------
|
| 112 |
+
progress(0.65, desc="Préparation du dataset")
|
| 113 |
logs += "Préparation du dataset (TextDataset)...\n"
|
| 114 |
dataset = TextDataset(prompt_path=prompt_txt_path, extended_prompt_path=None)
|
| 115 |
num_prompts = len(dataset)
|
|
|
|
| 118 |
from torch.utils.data import DataLoader, SequentialSampler
|
| 119 |
|
| 120 |
sampler = SequentialSampler(dataset)
|
| 121 |
+
dataloader = DataLoader(
|
| 122 |
+
dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False
|
| 123 |
+
)
|
| 124 |
|
| 125 |
# --------------------- Output folder (on le vide) ---------------------
|
| 126 |
+
progress(0.7, desc="Nettoyage du dossier de sortie")
|
| 127 |
+
output_folder = os.path.join(
|
| 128 |
+
output_root, f"rewardforcing-{num_output_frames}f", checkpoint_step
|
| 129 |
+
)
|
| 130 |
shutil.rmtree(output_folder, ignore_errors=True)
|
| 131 |
os.makedirs(output_folder, exist_ok=True)
|
| 132 |
logs += f"Dossier de sortie: {output_folder}\n"
|
| 133 |
|
| 134 |
+
# --------------------- Phase 2 : boucle d'inférence ---------------------
|
| 135 |
+
# Ici on peut utiliser progress.tqdm sur la boucle dataloader
|
| 136 |
for i, batch_data in progress.tqdm(
|
| 137 |
enumerate(dataloader),
|
| 138 |
total=num_prompts,
|
|
|
|
| 193 |
output_path = os.path.join(output_folder, f"{safe_name}.mp4")
|
| 194 |
write_video(output_path, video[0], fps=16)
|
| 195 |
logs += f"Vidéo enregistrée: {output_path}\n"
|
| 196 |
+
|
| 197 |
+
progress(1.0, desc="Terminé ✅")
|
| 198 |
return output_path, logs
|
| 199 |
|
| 200 |
logs += "[WARN] Aucune vidéo générée dans la boucle.\n"
|
| 201 |
return None, logs
|
| 202 |
|
| 203 |
|
| 204 |
+
def gradio_generate(
|
| 205 |
+
prompt: str, duration: str, use_ema: bool, progress=gr.Progress(track_tqdm=True)
|
| 206 |
+
):
|
| 207 |
"""
|
| 208 |
Fonction appelée par Gradio :
|
| 209 |
- écrit le prompt dans un .txt
|
|
|
|
| 226 |
with open(prompt_path, "w", encoding="utf-8") as f:
|
| 227 |
f.write(prompt.strip() + "\n")
|
| 228 |
|
|
|
|
| 229 |
video_path, logs = reward_forcing_inference(
|
| 230 |
prompt_txt_path=prompt_path,
|
| 231 |
num_output_frames=num_output_frames,
|
|
|
|
| 253 |
# 🎬 Reward Forcing – Text-to-Video (inline)
|
| 254 |
|
| 255 |
Cette version appelle directement la logique d'inférence en Python,
|
| 256 |
+
ce qui permet à Gradio de suivre :
|
| 257 |
+
- l'initialisation du modèle (via `progress(...)`)
|
| 258 |
+
- la boucle de génération (via `progress.tqdm(...)`)
|
| 259 |
"""
|
| 260 |
)
|
| 261 |
|