fffiloni commited on
Commit
d45d065
·
verified ·
1 Parent(s): 0d629ec

Update app_wip.py

Browse files
Files changed (1) hide show
  1. app_wip.py +48 -42
app_wip.py CHANGED
@@ -2,12 +2,10 @@ import os
2
  import sys
3
  import uuid
4
  import shutil
5
- from datetime import datetime
6
 
7
  import gradio as gr
8
  import torch
9
  from omegaconf import OmegaConf
10
- from tqdm import tqdm
11
  from torchvision.io import write_video
12
  from einops import rearrange
13
  from huggingface_hub import snapshot_download
@@ -79,40 +77,39 @@ def reward_forcing_inference(
79
 
80
  torch.set_grad_enabled(False)
81
 
82
- # --------------------- BARRE 1 : init modèle / config ---------------------
83
- # 4 étapes : config, pipeline, checkpoint, move to device
84
- with progress.tqdm(total=4, desc="Initialisation du modèle", unit="step") as pbar:
85
- logs += "Chargement de la config...\n"
86
- config = OmegaConf.load(CONFIG_PATH)
87
- default_config = OmegaConf.load("configs/default_config.yaml")
88
- config = OmegaConf.merge(default_config, config)
89
- pbar.update(1)
90
-
91
- logs += "Initialisation de la pipeline...\n"
92
- if hasattr(config, "denoising_step_list"):
93
- pipeline = CausalInferencePipeline(config, device=device)
94
- else:
95
- pipeline = CausalDiffusionInferencePipeline(config, device=device)
96
- pbar.update(1)
97
-
98
- logs += "Chargement des poids du checkpoint...\n"
99
- state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
100
- pipeline.generator.load_state_dict(state_dict)
101
- checkpoint_step = os.path.basename(os.path.dirname(CHECKPOINT_PATH))
102
- checkpoint_step = checkpoint_step.split("_")[-1]
103
- pbar.update(1)
104
-
105
- logs += "Placement du modèle sur le device...\n"
106
- pipeline = pipeline.to(dtype=torch.bfloat16)
107
- if low_memory:
108
- DynamicSwapInstaller.install_model(pipeline.text_encoder, device=device)
109
- else:
110
- pipeline.text_encoder.to(device=device)
111
- pipeline.generator.to(device=device)
112
- pipeline.vae.to(device=device)
113
- pbar.update(1)
114
 
115
  # --------------------- Dataset / DataLoader ---------------------
 
116
  logs += "Préparation du dataset (TextDataset)...\n"
117
  dataset = TextDataset(prompt_path=prompt_txt_path, extended_prompt_path=None)
118
  num_prompts = len(dataset)
@@ -121,15 +118,21 @@ def reward_forcing_inference(
121
  from torch.utils.data import DataLoader, SequentialSampler
122
 
123
  sampler = SequentialSampler(dataset)
124
- dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False)
 
 
125
 
126
  # --------------------- Output folder (on le vide) ---------------------
127
- output_folder = os.path.join(output_root, f"rewardforcing-{num_output_frames}f", checkpoint_step)
 
 
 
128
  shutil.rmtree(output_folder, ignore_errors=True)
129
  os.makedirs(output_folder, exist_ok=True)
130
  logs += f"Dossier de sortie: {output_folder}\n"
131
 
132
- # --------------------- BARRE 2 : boucle d'inférence ---------------------
 
133
  for i, batch_data in progress.tqdm(
134
  enumerate(dataloader),
135
  total=num_prompts,
@@ -190,13 +193,17 @@ def reward_forcing_inference(
190
  output_path = os.path.join(output_folder, f"{safe_name}.mp4")
191
  write_video(output_path, video[0], fps=16)
192
  logs += f"Vidéo enregistrée: {output_path}\n"
 
 
193
  return output_path, logs
194
 
195
  logs += "[WARN] Aucune vidéo générée dans la boucle.\n"
196
  return None, logs
197
 
198
 
199
- def gradio_generate(prompt: str, duration: str, use_ema: bool, progress=gr.Progress(track_tqdm=True)):
 
 
200
  """
201
  Fonction appelée par Gradio :
202
  - écrit le prompt dans un .txt
@@ -219,7 +226,6 @@ def gradio_generate(prompt: str, duration: str, use_ema: bool, progress=gr.Progr
219
  with open(prompt_path, "w", encoding="utf-8") as f:
220
  f.write(prompt.strip() + "\n")
221
 
222
- # Appel de la fonction d'inférence inline
223
  video_path, logs = reward_forcing_inference(
224
  prompt_txt_path=prompt_path,
225
  num_output_frames=num_output_frames,
@@ -247,9 +253,9 @@ with gr.Blocks(title="Reward Forcing T2V Demo (inline inference)") as demo:
247
  # 🎬 Reward Forcing – Text-to-Video (inline)
248
 
249
  Cette version appelle directement la logique d'inférence en Python,
250
- ce qui permet à Gradio de suivre les `tqdm` :
251
- - Initialisation du modèle
252
- - Génération vidéo
253
  """
254
  )
255
 
 
2
  import sys
3
  import uuid
4
  import shutil
 
5
 
6
  import gradio as gr
7
  import torch
8
  from omegaconf import OmegaConf
 
9
  from torchvision.io import write_video
10
  from einops import rearrange
11
  from huggingface_hub import snapshot_download
 
77
 
78
  torch.set_grad_enabled(False)
79
 
80
+ # --------------------- Phase 1 : init modèle / config ---------------------
81
+ progress(0.05, desc="Initialisation : chargement de la config")
82
+ logs += "Chargement de la config...\n"
83
+ config = OmegaConf.load(CONFIG_PATH)
84
+ default_config = OmegaConf.load("configs/default_config.yaml")
85
+ config = OmegaConf.merge(default_config, config)
86
+
87
+ progress(0.15, desc="Initialisation : création de la pipeline")
88
+ logs += "Initialisation de la pipeline...\n"
89
+ if hasattr(config, "denoising_step_list"):
90
+ pipeline = CausalInferencePipeline(config, device=device)
91
+ else:
92
+ pipeline = CausalDiffusionInferencePipeline(config, device=device)
93
+
94
+ progress(0.35, desc="Initialisation : chargement du checkpoint")
95
+ logs += "Chargement des poids du checkpoint...\n"
96
+ state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
97
+ pipeline.generator.load_state_dict(state_dict)
98
+ checkpoint_step = os.path.basename(os.path.dirname(CHECKPOINT_PATH))
99
+ checkpoint_step = checkpoint_step.split("_")[-1]
100
+
101
+ progress(0.55, desc="Initialisation : placement sur le device")
102
+ logs += "Placement du modèle sur le device...\n"
103
+ pipeline = pipeline.to(dtype=torch.bfloat16)
104
+ if low_memory:
105
+ DynamicSwapInstaller.install_model(pipeline.text_encoder, device=device)
106
+ else:
107
+ pipeline.text_encoder.to(device=device)
108
+ pipeline.generator.to(device=device)
109
+ pipeline.vae.to(device=device)
 
 
110
 
111
  # --------------------- Dataset / DataLoader ---------------------
112
+ progress(0.65, desc="Préparation du dataset")
113
  logs += "Préparation du dataset (TextDataset)...\n"
114
  dataset = TextDataset(prompt_path=prompt_txt_path, extended_prompt_path=None)
115
  num_prompts = len(dataset)
 
118
  from torch.utils.data import DataLoader, SequentialSampler
119
 
120
  sampler = SequentialSampler(dataset)
121
+ dataloader = DataLoader(
122
+ dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False
123
+ )
124
 
125
  # --------------------- Output folder (on le vide) ---------------------
126
+ progress(0.7, desc="Nettoyage du dossier de sortie")
127
+ output_folder = os.path.join(
128
+ output_root, f"rewardforcing-{num_output_frames}f", checkpoint_step
129
+ )
130
  shutil.rmtree(output_folder, ignore_errors=True)
131
  os.makedirs(output_folder, exist_ok=True)
132
  logs += f"Dossier de sortie: {output_folder}\n"
133
 
134
+ # --------------------- Phase 2 : boucle d'inférence ---------------------
135
+ # Ici on peut utiliser progress.tqdm sur la boucle dataloader
136
  for i, batch_data in progress.tqdm(
137
  enumerate(dataloader),
138
  total=num_prompts,
 
193
  output_path = os.path.join(output_folder, f"{safe_name}.mp4")
194
  write_video(output_path, video[0], fps=16)
195
  logs += f"Vidéo enregistrée: {output_path}\n"
196
+
197
+ progress(1.0, desc="Terminé ✅")
198
  return output_path, logs
199
 
200
  logs += "[WARN] Aucune vidéo générée dans la boucle.\n"
201
  return None, logs
202
 
203
 
204
+ def gradio_generate(
205
+ prompt: str, duration: str, use_ema: bool, progress=gr.Progress(track_tqdm=True)
206
+ ):
207
  """
208
  Fonction appelée par Gradio :
209
  - écrit le prompt dans un .txt
 
226
  with open(prompt_path, "w", encoding="utf-8") as f:
227
  f.write(prompt.strip() + "\n")
228
 
 
229
  video_path, logs = reward_forcing_inference(
230
  prompt_txt_path=prompt_path,
231
  num_output_frames=num_output_frames,
 
253
  # 🎬 Reward Forcing – Text-to-Video (inline)
254
 
255
  Cette version appelle directement la logique d'inférence en Python,
256
+ ce qui permet à Gradio de suivre :
257
+ - l'initialisation du modèle (via `progress(...)`)
258
+ - la boucle de génération (via `progress.tqdm(...)`)
259
  """
260
  )
261