Spaces:
Paused
Paused
back to previous app steady
Browse files- app_wip.py +53 -78
app_wip.py
CHANGED
|
@@ -51,70 +51,6 @@ OUTPUT_ROOT = "videos"
|
|
| 51 |
os.makedirs(PROMPT_DIR, exist_ok=True)
|
| 52 |
os.makedirs(OUTPUT_ROOT, exist_ok=True)
|
| 53 |
|
| 54 |
-
# === Module-level cache for the model ===
# PIPELINE doubles as the "already initialised" flag: it is only assigned once
# every initialisation step of load_pipeline() has succeeded.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PIPELINE = None
LOW_MEMORY = None
CHECKPOINT_STEP = None


def load_pipeline(progress: gr.Progress):
    """
    Load config, pipeline and checkpoint and place the model on DEVICE, once.

    The first successful call performs the heavy initialisation in four
    steps reported through ``progress.tqdm`` and stores the result in the
    module-level cache (PIPELINE, LOW_MEMORY, CHECKPOINT_STEP). Later calls
    return the cached objects immediately.

    The cache is only published after all four steps have completed. If any
    step raises (missing checkpoint, out-of-memory, ...), the globals are left
    untouched so that the next call retries from scratch instead of handing
    back a pipeline without weights or on the wrong device.

    Args:
        progress: Gradio progress tracker; called directly on the cached path
            and used as ``progress.tqdm(...)`` during initialisation.

    Returns:
        A tuple ``(pipeline, low_memory, checkpoint_step, logs)`` where
        ``low_memory`` is True when the reported free VRAM is below 40,
        ``checkpoint_step`` is the last ``_``-separated token of the name of
        the directory containing CHECKPOINT_PATH, and ``logs`` is the
        human-readable log text accumulated during this call.
    """
    global PIPELINE, LOW_MEMORY, CHECKPOINT_STEP

    logs = ""

    # Fast path: everything was already initialised by a previous call.
    if PIPELINE is not None:
        progress(0.1, desc="Modèle déjà initialisé (cache)")
        logs += "Modèle déjà initialisé, réutilisation du cache.\n"
        return PIPELINE, LOW_MEMORY, CHECKPOINT_STEP, logs

    # ---- First, heavy initialisation ----
    set_seed(0)
    free_vram = get_cuda_free_memory_gb(DEVICE)
    # Threshold is compared against the value logged as GB just below.
    low_memory = free_vram < 40
    logs += f"Free VRAM {free_vram} GB\n"

    # Work on locals only; the globals are assigned at the very end.
    config = None
    pipeline = None
    checkpoint_step = None

    for step in progress.tqdm(range(4), desc="Initialisation du modèle", unit="étape"):
        if step == 0:
            logs += "Étape 1/4 : Chargement de la config...\n"
            config = OmegaConf.load(CONFIG_PATH)
            default_config = OmegaConf.load("configs/default_config.yaml")
            # Values from CONFIG_PATH override the defaults.
            config = OmegaConf.merge(default_config, config)

        elif step == 1:
            logs += "Étape 2/4 : Création de la pipeline...\n"
            # The presence of a denoising step list selects the pipeline class.
            if hasattr(config, "denoising_step_list"):
                pipeline = CausalInferencePipeline(config, device=DEVICE)
            else:
                pipeline = CausalDiffusionInferencePipeline(config, device=DEVICE)

        elif step == 2:
            logs += "Étape 3/4 : Chargement des poids du checkpoint...\n"
            # NOTE(review): torch.load unpickles the file; CHECKPOINT_PATH is
            # assumed to point at a trusted checkpoint - confirm.
            state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
            pipeline.generator.load_state_dict(state_dict)
            ckpt_dir = os.path.dirname(CHECKPOINT_PATH)
            checkpoint_step = os.path.basename(ckpt_dir).split("_")[-1]

        elif step == 3:
            logs += "Étape 4/4 : Placement du modèle sur le device...\n"
            pipeline = pipeline.to(dtype=torch.bfloat16)
            if low_memory:
                # Presumably keeps the text encoder off-device and swaps it in
                # on demand - verify against DynamicSwapInstaller.
                DynamicSwapInstaller.install_model(pipeline.text_encoder, device=DEVICE)
            else:
                pipeline.text_encoder.to(device=DEVICE)
            pipeline.generator.to(device=DEVICE)
            pipeline.vae.to(device=DEVICE)

    # Publish the cache only now that every step succeeded. PIPELINE goes last
    # because it is the value tested by the fast path above.
    LOW_MEMORY = low_memory
    CHECKPOINT_STEP = checkpoint_step
    PIPELINE = pipeline

    logs += "Initialisation du modèle terminée ✅\n"
    return PIPELINE, LOW_MEMORY, CHECKPOINT_STEP, logs
|
| 117 |
-
|
| 118 |
|
| 119 |
def reward_forcing_inference(
|
| 120 |
prompt_txt_path: str,
|
|
@@ -125,17 +61,55 @@ def reward_forcing_inference(
|
|
| 125 |
):
|
| 126 |
"""
|
| 127 |
Version inline / simplifiée de inference.py :
|
|
|
|
| 128 |
- T2V uniquement
|
| 129 |
- 1 fichier .txt = n prompts (mais on retourne la 1ère vidéo)
|
| 130 |
"""
|
| 131 |
logs = ""
|
| 132 |
|
| 133 |
-
# ---------------------
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
# --------------------- Dataset / DataLoader ---------------------
|
| 138 |
-
progress(0.
|
| 139 |
logs += "Préparation du dataset (TextDataset)...\n"
|
| 140 |
dataset = TextDataset(prompt_path=prompt_txt_path, extended_prompt_path=None)
|
| 141 |
num_prompts = len(dataset)
|
|
@@ -149,7 +123,7 @@ def reward_forcing_inference(
|
|
| 149 |
)
|
| 150 |
|
| 151 |
# --------------------- Output folder (on le vide) ---------------------
|
| 152 |
-
progress(0.
|
| 153 |
output_folder = os.path.join(
|
| 154 |
output_root, f"rewardforcing-{num_output_frames}f", checkpoint_step
|
| 155 |
)
|
|
@@ -157,7 +131,8 @@ def reward_forcing_inference(
|
|
| 157 |
os.makedirs(output_folder, exist_ok=True)
|
| 158 |
logs += f"Dossier de sortie: {output_folder}\n"
|
| 159 |
|
| 160 |
-
# ---------------------
|
|
|
|
| 161 |
for i, batch_data in progress.tqdm(
|
| 162 |
enumerate(dataloader),
|
| 163 |
total=num_prompts,
|
|
@@ -176,7 +151,7 @@ def reward_forcing_inference(
|
|
| 176 |
|
| 177 |
all_video = []
|
| 178 |
|
| 179 |
-
# TEXT-TO-VIDEO uniquement
|
| 180 |
prompt = batch["prompts"][0]
|
| 181 |
extended_prompt = batch.get("extended_prompts", [None])[0]
|
| 182 |
if extended_prompt is not None:
|
|
@@ -188,7 +163,7 @@ def reward_forcing_inference(
|
|
| 188 |
|
| 189 |
sampled_noise = torch.randn(
|
| 190 |
[1, num_output_frames, 16, 60, 104],
|
| 191 |
-
device=
|
| 192 |
dtype=torch.bfloat16,
|
| 193 |
)
|
| 194 |
|
|
@@ -272,15 +247,15 @@ def gradio_generate(
|
|
| 272 |
# UI Gradio
|
| 273 |
# -------------------------------------------------------------------
|
| 274 |
|
| 275 |
-
with gr.Blocks(title="Reward Forcing T2V Demo (inline
|
| 276 |
gr.Markdown(
|
| 277 |
"""
|
| 278 |
-
# 🎬 Reward Forcing – Text-to-Video (inline
|
| 279 |
|
| 280 |
-
Cette version
|
| 281 |
-
|
| 282 |
-
-
|
| 283 |
-
-
|
| 284 |
"""
|
| 285 |
)
|
| 286 |
|
|
|
|
| 51 |
os.makedirs(PROMPT_DIR, exist_ok=True)
|
| 52 |
os.makedirs(OUTPUT_ROOT, exist_ok=True)
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
def reward_forcing_inference(
|
| 56 |
prompt_txt_path: str,
|
|
|
|
| 61 |
):
|
| 62 |
"""
|
| 63 |
Version inline / simplifiée de inference.py :
|
| 64 |
+
- single GPU
|
| 65 |
- T2V uniquement
|
| 66 |
- 1 fichier .txt = n prompts (mais on retourne la 1ère vidéo)
|
| 67 |
"""
|
| 68 |
logs = ""
|
| 69 |
|
| 70 |
+
# --------------------- Device & seed ---------------------
|
| 71 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 72 |
+
set_seed(0)
|
| 73 |
+
|
| 74 |
+
free_vram = get_cuda_free_memory_gb(device)
|
| 75 |
+
logs += f"Free VRAM {free_vram} GB\n"
|
| 76 |
+
low_memory = free_vram < 40
|
| 77 |
+
|
| 78 |
+
torch.set_grad_enabled(False)
|
| 79 |
+
|
| 80 |
+
# --------------------- Phase 1 : init modèle / config ---------------------
|
| 81 |
+
progress(0.05, desc="Initialisation : chargement de la config")
|
| 82 |
+
logs += "Chargement de la config...\n"
|
| 83 |
+
config = OmegaConf.load(CONFIG_PATH)
|
| 84 |
+
default_config = OmegaConf.load("configs/default_config.yaml")
|
| 85 |
+
config = OmegaConf.merge(default_config, config)
|
| 86 |
+
|
| 87 |
+
progress(0.15, desc="Initialisation : création de la pipeline")
|
| 88 |
+
logs += "Initialisation de la pipeline...\n"
|
| 89 |
+
if hasattr(config, "denoising_step_list"):
|
| 90 |
+
pipeline = CausalInferencePipeline(config, device=device)
|
| 91 |
+
else:
|
| 92 |
+
pipeline = CausalDiffusionInferencePipeline(config, device=device)
|
| 93 |
+
|
| 94 |
+
progress(0.35, desc="Initialisation : chargement du checkpoint")
|
| 95 |
+
logs += "Chargement des poids du checkpoint...\n"
|
| 96 |
+
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
|
| 97 |
+
pipeline.generator.load_state_dict(state_dict)
|
| 98 |
+
checkpoint_step = os.path.basename(os.path.dirname(CHECKPOINT_PATH))
|
| 99 |
+
checkpoint_step = checkpoint_step.split("_")[-1]
|
| 100 |
+
|
| 101 |
+
progress(0.55, desc="Initialisation : placement sur le device")
|
| 102 |
+
logs += "Placement du modèle sur le device...\n"
|
| 103 |
+
pipeline = pipeline.to(dtype=torch.bfloat16)
|
| 104 |
+
if low_memory:
|
| 105 |
+
DynamicSwapInstaller.install_model(pipeline.text_encoder, device=device)
|
| 106 |
+
else:
|
| 107 |
+
pipeline.text_encoder.to(device=device)
|
| 108 |
+
pipeline.generator.to(device=device)
|
| 109 |
+
pipeline.vae.to(device=device)
|
| 110 |
|
| 111 |
# --------------------- Dataset / DataLoader ---------------------
|
| 112 |
+
progress(0.65, desc="Préparation du dataset")
|
| 113 |
logs += "Préparation du dataset (TextDataset)...\n"
|
| 114 |
dataset = TextDataset(prompt_path=prompt_txt_path, extended_prompt_path=None)
|
| 115 |
num_prompts = len(dataset)
|
|
|
|
| 123 |
)
|
| 124 |
|
| 125 |
# --------------------- Output folder (on le vide) ---------------------
|
| 126 |
+
progress(0.7, desc="Nettoyage du dossier de sortie")
|
| 127 |
output_folder = os.path.join(
|
| 128 |
output_root, f"rewardforcing-{num_output_frames}f", checkpoint_step
|
| 129 |
)
|
|
|
|
| 131 |
os.makedirs(output_folder, exist_ok=True)
|
| 132 |
logs += f"Dossier de sortie: {output_folder}\n"
|
| 133 |
|
| 134 |
+
# --------------------- Phase 2 : boucle d'inférence ---------------------
|
| 135 |
+
# Ici on peut utiliser progress.tqdm sur la boucle dataloader
|
| 136 |
for i, batch_data in progress.tqdm(
|
| 137 |
enumerate(dataloader),
|
| 138 |
total=num_prompts,
|
|
|
|
| 151 |
|
| 152 |
all_video = []
|
| 153 |
|
| 154 |
+
# TEXT-TO-VIDEO uniquement (pas d'I2V ici)
|
| 155 |
prompt = batch["prompts"][0]
|
| 156 |
extended_prompt = batch.get("extended_prompts", [None])[0]
|
| 157 |
if extended_prompt is not None:
|
|
|
|
| 163 |
|
| 164 |
sampled_noise = torch.randn(
|
| 165 |
[1, num_output_frames, 16, 60, 104],
|
| 166 |
+
device=device,
|
| 167 |
dtype=torch.bfloat16,
|
| 168 |
)
|
| 169 |
|
|
|
|
| 247 |
# UI Gradio
|
| 248 |
# -------------------------------------------------------------------
|
| 249 |
|
| 250 |
+
with gr.Blocks(title="Reward Forcing T2V Demo (inline inference)") as demo:
|
| 251 |
gr.Markdown(
|
| 252 |
"""
|
| 253 |
+
# 🎬 Reward Forcing – Text-to-Video (inline)
|
| 254 |
|
| 255 |
+
Cette version appelle directement la logique d'inférence en Python,
|
| 256 |
+
ce qui permet à Gradio de suivre :
|
| 257 |
+
- l'initialisation du modèle (via `progress(...)`)
|
| 258 |
+
- la boucle de génération (via `progress.tqdm(...)`)
|
| 259 |
"""
|
| 260 |
)
|
| 261 |
|