fffiloni commited on
Commit
0d629ec
·
verified ·
1 Parent(s): c0f475e

Update app_wip.py

Browse files
Files changed (1) hide show
  1. app_wip.py +43 -50
app_wip.py CHANGED
@@ -21,7 +21,7 @@ from utils.misc import set_seed
21
  from demo_utils.memory import get_cuda_free_memory_gb, DynamicSwapInstaller
22
 
23
  # -------------------------------------------------------------------
24
- # Téléchargement des checkpoints (comme dans ton app_wip)
25
  # -------------------------------------------------------------------
26
  snapshot_download(
27
  repo_id="Wan-AI/Wan2.1-T2V-1.3B",
@@ -62,10 +62,10 @@ def reward_forcing_inference(
62
  progress: gr.Progress,
63
  ):
64
  """
65
- Version simplifiée / inline d'inference.py
66
  - single GPU
67
  - T2V uniquement
68
- - 1 prompt par fichier .txt
69
  """
70
  logs = ""
71
 
@@ -79,42 +79,45 @@ def reward_forcing_inference(
79
 
80
  torch.set_grad_enabled(False)
81
 
82
- # --------------------- Config & pipeline ---------------------
83
- logs += "Chargement de la config...\n"
84
- progress(0.05, desc="Chargement de la config")
85
- config = OmegaConf.load(CONFIG_PATH)
86
- default_config = OmegaConf.load("configs/default_config.yaml")
87
- config = OmegaConf.merge(default_config, config)
88
-
89
- if hasattr(config, "denoising_step_list"):
90
- pipeline = CausalInferencePipeline(config, device=device)
91
- else:
92
- pipeline = CausalDiffusionInferencePipeline(config, device=device)
93
-
94
- logs += "Chargement des poids du checkpoint...\n"
95
- progress(0.1, desc="Chargement du checkpoint")
96
- state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
97
- pipeline.generator.load_state_dict(state_dict)
98
- checkpoint_step = os.path.basename(os.path.dirname(CHECKPOINT_PATH))
99
- checkpoint_step = checkpoint_step.split("_")[-1]
100
-
101
- pipeline = pipeline.to(dtype=torch.bfloat16)
102
- if low_memory:
103
- DynamicSwapInstaller.install_model(pipeline.text_encoder, device=device)
104
- else:
105
- pipeline.text_encoder.to(device=device)
106
- pipeline.generator.to(device=device)
107
- pipeline.vae.to(device=device)
 
 
 
 
 
 
108
 
109
  # --------------------- Dataset / DataLoader ---------------------
110
  logs += "Préparation du dataset (TextDataset)...\n"
111
- progress(0.15, desc="Préparation du dataset")
112
-
113
  dataset = TextDataset(prompt_path=prompt_txt_path, extended_prompt_path=None)
114
  num_prompts = len(dataset)
115
  logs += f"Number of prompts: {num_prompts}\n"
116
 
117
- # On ne supporte que batch_size=1 ici
118
  from torch.utils.data import DataLoader, SequentialSampler
119
 
120
  sampler = SequentialSampler(dataset)
@@ -126,10 +129,7 @@ def reward_forcing_inference(
126
  os.makedirs(output_folder, exist_ok=True)
127
  logs += f"Dossier de sortie: {output_folder}\n"
128
 
129
- progress(0.2, desc="Démarrage de l'inférence")
130
-
131
- # --------------------- Boucle d'inférence ---------------------
132
- # On tracke le tqdm de la boucle avec le Progress Gradio
133
  for i, batch_data in progress.tqdm(
134
  enumerate(dataloader),
135
  total=num_prompts,
@@ -138,7 +138,7 @@ def reward_forcing_inference(
138
  ):
139
  idx = batch_data["idx"].item()
140
 
141
- # batch_size=1 -> on simplifie
142
  if isinstance(batch_data, dict):
143
  batch = batch_data
144
  elif isinstance(batch_data, list):
@@ -148,7 +148,7 @@ def reward_forcing_inference(
148
 
149
  all_video = []
150
 
151
- # TEXT-TO-VIDEO uniquement (pas d'I2V)
152
  prompt = batch["prompts"][0]
153
  extended_prompt = batch.get("extended_prompts", [None])[0]
154
  if extended_prompt is not None:
@@ -165,7 +165,6 @@ def reward_forcing_inference(
165
  )
166
 
167
  logs += f"Génération pour le prompt: {prompt[:80]}...\n"
168
- progress(0.4, desc="Sampling latents")
169
 
170
  # Appel au pipeline
171
  video, latents = pipeline.inference(
@@ -176,8 +175,6 @@ def reward_forcing_inference(
176
  low_memory=low_memory,
177
  )
178
 
179
- progress(0.7, desc="Décodage et écriture vidéo")
180
-
181
  current_video = rearrange(video, "b t c h w -> b t h w c").cpu()
182
  all_video.append(current_video)
183
 
@@ -186,19 +183,15 @@ def reward_forcing_inference(
186
  # Clear VAE cache
187
  pipeline.vae.model.clear_cache()
188
 
189
- # Sauvegarde vidéo
190
  if idx < num_prompts:
191
  model = "regular" if not use_ema else "ema"
192
- # pour éviter des noms chelous, on tronque le prompt
193
  safe_name = prompt[:50].replace("/", "_").replace("\\", "_")
194
  output_path = os.path.join(output_folder, f"{safe_name}.mp4")
195
  write_video(output_path, video[0], fps=16)
196
  logs += f"Vidéo enregistrée: {output_path}\n"
197
-
198
- # On retourne la première vidéo (une seule dans ton cas)
199
  return output_path, logs
200
 
201
- # Si on sort de la boucle sans rien (cas improbable ici)
202
  logs += "[WARN] Aucune vidéo générée dans la boucle.\n"
203
  return None, logs
204
 
@@ -226,8 +219,7 @@ def gradio_generate(prompt: str, duration: str, use_ema: bool, progress=gr.Progr
226
  with open(prompt_path, "w", encoding="utf-8") as f:
227
  f.write(prompt.strip() + "\n")
228
 
229
- progress(0.01, desc="Préparation de l'inférence")
230
-
231
  video_path, logs = reward_forcing_inference(
232
  prompt_txt_path=prompt_path,
233
  num_output_frames=num_output_frames,
@@ -242,7 +234,6 @@ def gradio_generate(prompt: str, duration: str, use_ema: bool, progress=gr.Progr
242
  "Regarde les logs ci-dessous pour voir ce qui a coincé."
243
  )
244
 
245
- progress(1.0, desc="Terminé ✅")
246
  return video_path, logs
247
 
248
 
@@ -256,7 +247,9 @@ with gr.Blocks(title="Reward Forcing T2V Demo (inline inference)") as demo:
256
  # 🎬 Reward Forcing – Text-to-Video (inline)
257
 
258
  Cette version appelle directement la logique d'inférence en Python,
259
- ce qui permet à Gradio de suivre le `tqdm` et d'afficher une barre de progression.
 
 
260
  """
261
  )
262
 
 
21
  from demo_utils.memory import get_cuda_free_memory_gb, DynamicSwapInstaller
22
 
23
  # -------------------------------------------------------------------
24
+ # Téléchargement des checkpoints (une fois au démarrage du Space)
25
  # -------------------------------------------------------------------
26
  snapshot_download(
27
  repo_id="Wan-AI/Wan2.1-T2V-1.3B",
 
62
  progress: gr.Progress,
63
  ):
64
  """
65
+ Version inline / simplifiée de inference.py :
66
  - single GPU
67
  - T2V uniquement
68
+ - 1 fichier .txt = n prompts (mais on retourne la 1ère vidéo)
69
  """
70
  logs = ""
71
 
 
79
 
80
  torch.set_grad_enabled(False)
81
 
82
+ # --------------------- BARRE 1 : init modèle / config ---------------------
83
+ # 4 étapes : config, pipeline, checkpoint, move to device
84
+ with progress.tqdm(total=4, desc="Initialisation du modèle", unit="step") as pbar:
85
+ logs += "Chargement de la config...\n"
86
+ config = OmegaConf.load(CONFIG_PATH)
87
+ default_config = OmegaConf.load("configs/default_config.yaml")
88
+ config = OmegaConf.merge(default_config, config)
89
+ pbar.update(1)
90
+
91
+ logs += "Initialisation de la pipeline...\n"
92
+ if hasattr(config, "denoising_step_list"):
93
+ pipeline = CausalInferencePipeline(config, device=device)
94
+ else:
95
+ pipeline = CausalDiffusionInferencePipeline(config, device=device)
96
+ pbar.update(1)
97
+
98
+ logs += "Chargement des poids du checkpoint...\n"
99
+ state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
100
+ pipeline.generator.load_state_dict(state_dict)
101
+ checkpoint_step = os.path.basename(os.path.dirname(CHECKPOINT_PATH))
102
+ checkpoint_step = checkpoint_step.split("_")[-1]
103
+ pbar.update(1)
104
+
105
+ logs += "Placement du modèle sur le device...\n"
106
+ pipeline = pipeline.to(dtype=torch.bfloat16)
107
+ if low_memory:
108
+ DynamicSwapInstaller.install_model(pipeline.text_encoder, device=device)
109
+ else:
110
+ pipeline.text_encoder.to(device=device)
111
+ pipeline.generator.to(device=device)
112
+ pipeline.vae.to(device=device)
113
+ pbar.update(1)
114
 
115
  # --------------------- Dataset / DataLoader ---------------------
116
  logs += "Préparation du dataset (TextDataset)...\n"
 
 
117
  dataset = TextDataset(prompt_path=prompt_txt_path, extended_prompt_path=None)
118
  num_prompts = len(dataset)
119
  logs += f"Number of prompts: {num_prompts}\n"
120
 
 
121
  from torch.utils.data import DataLoader, SequentialSampler
122
 
123
  sampler = SequentialSampler(dataset)
 
129
  os.makedirs(output_folder, exist_ok=True)
130
  logs += f"Dossier de sortie: {output_folder}\n"
131
 
132
+ # --------------------- BARRE 2 : boucle d'inférence ---------------------
 
 
 
133
  for i, batch_data in progress.tqdm(
134
  enumerate(dataloader),
135
  total=num_prompts,
 
138
  ):
139
  idx = batch_data["idx"].item()
140
 
141
+ # Unpack batch
142
  if isinstance(batch_data, dict):
143
  batch = batch_data
144
  elif isinstance(batch_data, list):
 
148
 
149
  all_video = []
150
 
151
+ # TEXT-TO-VIDEO uniquement (pas d'I2V ici)
152
  prompt = batch["prompts"][0]
153
  extended_prompt = batch.get("extended_prompts", [None])[0]
154
  if extended_prompt is not None:
 
165
  )
166
 
167
  logs += f"Génération pour le prompt: {prompt[:80]}...\n"
 
168
 
169
  # Appel au pipeline
170
  video, latents = pipeline.inference(
 
175
  low_memory=low_memory,
176
  )
177
 
 
 
178
  current_video = rearrange(video, "b t c h w -> b t h w c").cpu()
179
  all_video.append(current_video)
180
 
 
183
  # Clear VAE cache
184
  pipeline.vae.model.clear_cache()
185
 
186
+ # Sauvegarde vidéo (on retourne la 1ère vidéo)
187
  if idx < num_prompts:
188
  model = "regular" if not use_ema else "ema"
 
189
  safe_name = prompt[:50].replace("/", "_").replace("\\", "_")
190
  output_path = os.path.join(output_folder, f"{safe_name}.mp4")
191
  write_video(output_path, video[0], fps=16)
192
  logs += f"Vidéo enregistrée: {output_path}\n"
 
 
193
  return output_path, logs
194
 
 
195
  logs += "[WARN] Aucune vidéo générée dans la boucle.\n"
196
  return None, logs
197
 
 
219
  with open(prompt_path, "w", encoding="utf-8") as f:
220
  f.write(prompt.strip() + "\n")
221
 
222
+ # Appel de la fonction d'inférence inline
 
223
  video_path, logs = reward_forcing_inference(
224
  prompt_txt_path=prompt_path,
225
  num_output_frames=num_output_frames,
 
234
  "Regarde les logs ci-dessous pour voir ce qui a coincé."
235
  )
236
 
 
237
  return video_path, logs
238
 
239
 
 
247
  # 🎬 Reward Forcing – Text-to-Video (inline)
248
 
249
  Cette version appelle directement la logique d'inférence en Python,
250
+ ce qui permet à Gradio de suivre les `tqdm` :
251
+ - Initialisation du modèle
252
+ - Génération vidéo
253
  """
254
  )
255