fffiloni commited on
Commit
c0f475e
·
verified ·
1 Parent(s): 8f3f345

full inference integration

Browse files
Files changed (1) hide show
  1. app_wip.py +192 -122
app_wip.py CHANGED
@@ -1,40 +1,27 @@
1
- import sys
2
  import os
 
3
  import uuid
4
- import subprocess
5
  from datetime import datetime
6
 
7
  import gradio as gr
 
 
 
 
 
8
  from huggingface_hub import snapshot_download
9
 
10
- # -------------------------------------------------------------------
11
- # (Optionnel) flash-attn : comme tu as déjà la bonne wheel dans
12
- # requirements.txt, on laisse commenté pour éviter des builds lents.
13
- # -------------------------------------------------------------------
14
- # def ensure_flash_attn():
15
- # try:
16
- # import flash_attn # noqa: F401
17
- # print("[init] flash-attn déjà installé")
18
- # except Exception as e:
19
- # print("[init] Installation de flash-attn (build from source)...", e, flush=True)
20
- # subprocess.run(
21
- # [
22
- # sys.executable,
23
- # "-m",
24
- # "pip",
25
- # "install",
26
- # "flash-attn==2.7.4.post1",
27
- # "--no-build-isolation",
28
- # ],
29
- # check=True,
30
- # )
31
- # import flash_attn # noqa: F401
32
- # print("[init] flash-attn OK")
33
-
34
- # ensure_flash_attn()
35
 
36
  # -------------------------------------------------------------------
37
- # Téléchargement des checkpoints (fait une fois au démarrage du Space)
38
  # -------------------------------------------------------------------
39
  snapshot_download(
40
  repo_id="Wan-AI/Wan2.1-T2V-1.3B",
@@ -61,133 +48,215 @@ CONFIG_PATH = "configs/reward_forcing.yaml"
61
  CHECKPOINT_PATH = "checkpoints/Reward-Forcing-T2V-1.3B/rewardforcing.pt"
62
 
63
  PROMPT_DIR = "prompts/gradio_inputs"
64
- # on garde OUTPUT_ROOT mais on va aussi coller au README pour l'output
65
  OUTPUT_ROOT = "videos"
66
 
67
  os.makedirs(PROMPT_DIR, exist_ok=True)
68
  os.makedirs(OUTPUT_ROOT, exist_ok=True)
69
 
70
 
71
- def run_inference(prompt: str, duration: str, use_ema: bool):
 
 
 
 
 
 
72
  """
73
- 1. Écrit le prompt dans un fichier .txt
74
- 2. Lance inference.py avec ce fichier comme --data_path
75
- 3. Retourne le chemin de la vidéo .mp4 générée + les logs
 
76
  """
77
- import glob
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  if not prompt or not prompt.strip():
80
  raise gr.Error("Veuillez entrer un prompt texte 🙂")
81
 
82
- # 1) Durée -> num_output_frames + dossier conforme au README
83
  if duration == "5s (21 frames)":
84
  num_output_frames = 21
85
- output_folder = os.path.join(OUTPUT_ROOT, "rewardforcing-5s")
86
- else: # "30s (120 frames)"
87
  num_output_frames = 120
88
- output_folder = os.path.join(OUTPUT_ROOT, "rewardforcing-30s")
89
 
90
- os.makedirs(output_folder, exist_ok=True)
91
 
92
- # 2) Fichier .txt temporaire pour le prompt
93
  prompt_id = uuid.uuid4().hex[:8]
94
  prompt_path = os.path.join(PROMPT_DIR, f"prompt_{prompt_id}.txt")
95
-
96
  with open(prompt_path, "w", encoding="utf-8") as f:
97
- # TextDataset lit chaque ligne comme un prompt
98
  f.write(prompt.strip() + "\n")
99
 
100
- # 3) On sauve la liste des vidéos AVANT l'inférence
101
- cwd = os.path.dirname(os.path.abspath(__file__))
102
- before_mp4s = set(
103
- os.path.relpath(p, cwd)
104
- for p in glob.glob(os.path.join(cwd, "videos", "**", "*.mp4"), recursive=True)
105
- )
106
 
107
- # 4) Commande inference.py
108
- cmd = [
109
- sys.executable,
110
- "inference.py",
111
- "--num_output_frames",
112
- str(num_output_frames),
113
- "--config_path",
114
- CONFIG_PATH,
115
- "--checkpoint_path",
116
- CHECKPOINT_PATH,
117
- "--output_folder",
118
- output_folder,
119
- "--data_path",
120
- prompt_path,
121
- "--num_samples",
122
- "1",
123
- ]
124
- if use_ema:
125
- cmd.append("--use_ema")
126
-
127
- result = subprocess.run(
128
- cmd,
129
- stdout=subprocess.PIPE,
130
- stderr=subprocess.STDOUT,
131
- text=True,
132
- cwd=cwd, # important sur les Spaces
133
  )
134
 
135
- logs = result.stdout
136
- print(logs)
137
-
138
- # 5) Si inference.py a planté, on remonte l'erreur
139
- if result.returncode != 0:
140
  raise gr.Error(
141
- f"inference.py a retourné un code d'erreur ({result.returncode}).\n\n"
142
- "Regarde les logs ci-dessous pour les détails."
143
- )
144
-
145
- # 6) On regarde les vidéos APRÈS l'inférence
146
- after_mp4s_abs = glob.glob(os.path.join(cwd, "videos", "**", "*.mp4"), recursive=True)
147
- after_mp4s = set(os.path.relpath(p, cwd) for p in after_mp4s_abs)
148
-
149
- new_mp4s = list(after_mp4s - before_mp4s)
150
-
151
- # Debug : log de tout ce qui a été trouvé
152
- logs += "\n\n[DEBUG] Fichiers .mp4 AVANT:\n"
153
- logs += "\n".join(sorted(before_mp4s)) if before_mp4s else "[aucun]\n"
154
- logs += "\n\n[DEBUG] Fichiers .mp4 APRÈS:\n"
155
- logs += "\n".join(sorted(after_mp4s)) if after_mp4s else "[aucun]\n"
156
-
157
- if not new_mp4s:
158
- # Pas de nouvelle vidéo détectée. En dernier recours,
159
- # on prend la plus récente dans tout `videos/` si elle existe.
160
- if after_mp4s_abs:
161
- after_mp4s_abs.sort(key=os.path.getmtime, reverse=True)
162
- fallback_video = after_mp4s_abs[0]
163
- logs += (
164
- "\n\n[WARN] Aucune nouvelle vidéo détectée, "
165
- "on utilise la plus récente trouvée: "
166
- f"{os.path.relpath(fallback_video, cwd)}"
167
- )
168
- return fallback_video, logs
169
-
170
- # Vraiment aucune vidéo
171
- raise gr.Error(
172
- "Aucune vidéo .mp4 trouvée dans le dossier de sortie.\n"
173
  "Regarde les logs ci-dessous pour voir ce qui a coincé."
174
  )
175
 
176
- # On prend la nouvelle vidéo la plus récente
177
- new_mp4s_abs = [os.path.join(cwd, p) for p in new_mp4s]
178
- new_mp4s_abs.sort(key=os.path.getmtime, reverse=True)
179
- video_path = new_mp4s_abs[0]
180
-
181
  return video_path, logs
182
 
183
 
184
- with gr.Blocks(title="Reward Forcing T2V Demo") as demo:
 
 
 
 
185
  gr.Markdown(
186
  """
187
- # 🎬 Reward Forcing – Text-to-Video
188
 
189
- Entrez un prompt texte, on génère un fichier `.txt` en interne
190
- puis on lance `inference.py` avec ce fichier comme `--data_path`.
191
  """
192
  )
193
 
@@ -211,16 +280,17 @@ with gr.Blocks(title="Reward Forcing T2V Demo") as demo:
211
  with gr.Row():
212
  video_out = gr.Video(label="Vidéo générée")
213
  logs_out = gr.Textbox(
214
- label="Logs de inference.py",
215
  lines=12,
216
  interactive=False,
217
  )
218
 
219
  generate_btn.click(
220
- fn=run_inference,
221
  inputs=[prompt_in, duration, use_ema],
222
  outputs=[video_out, logs_out],
223
  )
224
 
 
225
  if __name__ == "__main__":
226
  demo.launch()
 
 
1
  import os
2
+ import sys
3
  import uuid
4
+ import shutil
5
  from datetime import datetime
6
 
7
  import gradio as gr
8
+ import torch
9
+ from omegaconf import OmegaConf
10
+ from tqdm import tqdm
11
+ from torchvision.io import write_video
12
+ from einops import rearrange
13
  from huggingface_hub import snapshot_download
14
 
15
+ from pipeline import (
16
+ CausalDiffusionInferencePipeline,
17
+ CausalInferencePipeline,
18
+ )
19
+ from utils.dataset import TextDataset
20
+ from utils.misc import set_seed
21
+ from demo_utils.memory import get_cuda_free_memory_gb, DynamicSwapInstaller
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  # -------------------------------------------------------------------
24
+ # Téléchargement des checkpoints (comme dans ton app_wip)
25
  # -------------------------------------------------------------------
26
  snapshot_download(
27
  repo_id="Wan-AI/Wan2.1-T2V-1.3B",
 
48
  CHECKPOINT_PATH = "checkpoints/Reward-Forcing-T2V-1.3B/rewardforcing.pt"
49
 
50
  PROMPT_DIR = "prompts/gradio_inputs"
 
51
  OUTPUT_ROOT = "videos"
52
 
53
  os.makedirs(PROMPT_DIR, exist_ok=True)
54
  os.makedirs(OUTPUT_ROOT, exist_ok=True)
55
 
56
 
57
+ def reward_forcing_inference(
58
+ prompt_txt_path: str,
59
+ num_output_frames: int,
60
+ use_ema: bool,
61
+ output_root: str,
62
+ progress: gr.Progress,
63
+ ):
64
  """
65
+ Version simplifiée / inline d'inference.py
66
+ - single GPU
67
+ - T2V uniquement
68
+ - 1 prompt par fichier .txt
69
  """
70
+ logs = ""
71
+
72
+ # --------------------- Device & seed ---------------------
73
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
74
+ set_seed(0)
75
+
76
+ free_vram = get_cuda_free_memory_gb(device)
77
+ logs += f"Free VRAM {free_vram} GB\n"
78
+ low_memory = free_vram < 40
79
+
80
+ torch.set_grad_enabled(False)
81
+
82
+ # --------------------- Config & pipeline ---------------------
83
+ logs += "Chargement de la config...\n"
84
+ progress(0.05, desc="Chargement de la config")
85
+ config = OmegaConf.load(CONFIG_PATH)
86
+ default_config = OmegaConf.load("configs/default_config.yaml")
87
+ config = OmegaConf.merge(default_config, config)
88
+
89
+ if hasattr(config, "denoising_step_list"):
90
+ pipeline = CausalInferencePipeline(config, device=device)
91
+ else:
92
+ pipeline = CausalDiffusionInferencePipeline(config, device=device)
93
+
94
+ logs += "Chargement des poids du checkpoint...\n"
95
+ progress(0.1, desc="Chargement du checkpoint")
96
+ state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
97
+ pipeline.generator.load_state_dict(state_dict)
98
+ checkpoint_step = os.path.basename(os.path.dirname(CHECKPOINT_PATH))
99
+ checkpoint_step = checkpoint_step.split("_")[-1]
100
+
101
+ pipeline = pipeline.to(dtype=torch.bfloat16)
102
+ if low_memory:
103
+ DynamicSwapInstaller.install_model(pipeline.text_encoder, device=device)
104
+ else:
105
+ pipeline.text_encoder.to(device=device)
106
+ pipeline.generator.to(device=device)
107
+ pipeline.vae.to(device=device)
108
+
109
+ # --------------------- Dataset / DataLoader ---------------------
110
+ logs += "Préparation du dataset (TextDataset)...\n"
111
+ progress(0.15, desc="Préparation du dataset")
112
+
113
+ dataset = TextDataset(prompt_path=prompt_txt_path, extended_prompt_path=None)
114
+ num_prompts = len(dataset)
115
+ logs += f"Number of prompts: {num_prompts}\n"
116
+
117
+ # On ne supporte que batch_size=1 ici
118
+ from torch.utils.data import DataLoader, SequentialSampler
119
+
120
+ sampler = SequentialSampler(dataset)
121
+ dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False)
122
+
123
+ # --------------------- Output folder (on le vide) ---------------------
124
+ output_folder = os.path.join(output_root, f"rewardforcing-{num_output_frames}f", checkpoint_step)
125
+ shutil.rmtree(output_folder, ignore_errors=True)
126
+ os.makedirs(output_folder, exist_ok=True)
127
+ logs += f"Dossier de sortie: {output_folder}\n"
128
+
129
+ progress(0.2, desc="Démarrage de l'inférence")
130
+
131
+ # --------------------- Boucle d'inférence ---------------------
132
+ # On tracke le tqdm de la boucle avec le Progress Gradio
133
+ for i, batch_data in progress.tqdm(
134
+ enumerate(dataloader),
135
+ total=num_prompts,
136
+ desc="Génération vidéo",
137
+ unit="prompt",
138
+ ):
139
+ idx = batch_data["idx"].item()
140
+
141
+ # batch_size=1 -> on simplifie
142
+ if isinstance(batch_data, dict):
143
+ batch = batch_data
144
+ elif isinstance(batch_data, list):
145
+ batch = batch_data[0]
146
+ else:
147
+ batch = batch_data
148
+
149
+ all_video = []
150
+
151
+ # TEXT-TO-VIDEO uniquement (pas d'I2V)
152
+ prompt = batch["prompts"][0]
153
+ extended_prompt = batch.get("extended_prompts", [None])[0]
154
+ if extended_prompt is not None:
155
+ prompts = [extended_prompt]
156
+ else:
157
+ prompts = [prompt]
158
+
159
+ initial_latent = None
160
+
161
+ sampled_noise = torch.randn(
162
+ [1, num_output_frames, 16, 60, 104],
163
+ device=device,
164
+ dtype=torch.bfloat16,
165
+ )
166
+
167
+ logs += f"Génération pour le prompt: {prompt[:80]}...\n"
168
+ progress(0.4, desc="Sampling latents")
169
+
170
+ # Appel au pipeline
171
+ video, latents = pipeline.inference(
172
+ noise=sampled_noise,
173
+ text_prompts=prompts,
174
+ return_latents=True,
175
+ initial_latent=initial_latent,
176
+ low_memory=low_memory,
177
+ )
178
+
179
+ progress(0.7, desc="Décodage et écriture vidéo")
180
+
181
+ current_video = rearrange(video, "b t c h w -> b t h w c").cpu()
182
+ all_video.append(current_video)
183
+
184
+ video = 255.0 * torch.cat(all_video, dim=1)
185
+
186
+ # Clear VAE cache
187
+ pipeline.vae.model.clear_cache()
188
+
189
+ # Sauvegarde vidéo
190
+ if idx < num_prompts:
191
+ model = "regular" if not use_ema else "ema"
192
+ # pour éviter des noms chelous, on tronque le prompt
193
+ safe_name = prompt[:50].replace("/", "_").replace("\\", "_")
194
+ output_path = os.path.join(output_folder, f"{safe_name}.mp4")
195
+ write_video(output_path, video[0], fps=16)
196
+ logs += f"Vidéo enregistrée: {output_path}\n"
197
+
198
+ # On retourne la première vidéo (une seule dans ton cas)
199
+ return output_path, logs
200
 
201
+ # Si on sort de la boucle sans rien (cas improbable ici)
202
+ logs += "[WARN] Aucune vidéo générée dans la boucle.\n"
203
+ return None, logs
204
+
205
+
206
+ def gradio_generate(prompt: str, duration: str, use_ema: bool, progress=gr.Progress(track_tqdm=True)):
207
+ """
208
+ Fonction appelée par Gradio :
209
+ - écrit le prompt dans un .txt
210
+ - appelle reward_forcing_inference
211
+ - retourne (video_path, logs)
212
+ """
213
  if not prompt or not prompt.strip():
214
  raise gr.Error("Veuillez entrer un prompt texte 🙂")
215
 
216
+ # Durée -> frames
217
  if duration == "5s (21 frames)":
218
  num_output_frames = 21
219
+ else:
 
220
  num_output_frames = 120
 
221
 
222
+ os.makedirs(PROMPT_DIR, exist_ok=True)
223
 
 
224
  prompt_id = uuid.uuid4().hex[:8]
225
  prompt_path = os.path.join(PROMPT_DIR, f"prompt_{prompt_id}.txt")
 
226
  with open(prompt_path, "w", encoding="utf-8") as f:
 
227
  f.write(prompt.strip() + "\n")
228
 
229
+ progress(0.01, desc="Préparation de l'inférence")
 
 
 
 
 
230
 
231
+ video_path, logs = reward_forcing_inference(
232
+ prompt_txt_path=prompt_path,
233
+ num_output_frames=num_output_frames,
234
+ use_ema=use_ema,
235
+ output_root=OUTPUT_ROOT,
236
+ progress=progress,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  )
238
 
239
+ if video_path is None or not os.path.exists(video_path):
 
 
 
 
240
  raise gr.Error(
241
+ "Aucune vidéo trouvée après l'inférence.\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  "Regarde les logs ci-dessous pour voir ce qui a coincé."
243
  )
244
 
245
+ progress(1.0, desc="Terminé ✅")
 
 
 
 
246
  return video_path, logs
247
 
248
 
249
+ # -------------------------------------------------------------------
250
+ # UI Gradio
251
+ # -------------------------------------------------------------------
252
+
253
+ with gr.Blocks(title="Reward Forcing T2V Demo (inline inference)") as demo:
254
  gr.Markdown(
255
  """
256
+ # 🎬 Reward Forcing – Text-to-Video (inline)
257
 
258
+ Cette version appelle directement la logique d'inférence en Python,
259
+ ce qui permet à Gradio de suivre le `tqdm` et d'afficher une barre de progression.
260
  """
261
  )
262
 
 
280
  with gr.Row():
281
  video_out = gr.Video(label="Vidéo générée")
282
  logs_out = gr.Textbox(
283
+ label="Logs",
284
  lines=12,
285
  interactive=False,
286
  )
287
 
288
  generate_btn.click(
289
+ fn=gradio_generate,
290
  inputs=[prompt_in, duration, use_ema],
291
  outputs=[video_out, logs_out],
292
  )
293
 
294
+ demo.queue()
295
  if __name__ == "__main__":
296
  demo.launch()