fffiloni commited on
Commit
e3f74ba
·
verified ·
1 Parent(s): d45d065

Update app_wip.py

Browse files
Files changed (1) hide show
  1. app_wip.py +78 -53
app_wip.py CHANGED
@@ -51,6 +51,70 @@ OUTPUT_ROOT = "videos"
51
  os.makedirs(PROMPT_DIR, exist_ok=True)
52
  os.makedirs(OUTPUT_ROOT, exist_ok=True)
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  def reward_forcing_inference(
56
  prompt_txt_path: str,
@@ -61,55 +125,17 @@ def reward_forcing_inference(
61
  ):
62
  """
63
  Version inline / simplifiée de inference.py :
64
- - single GPU
65
  - T2V uniquement
66
  - 1 fichier .txt = n prompts (mais on retourne la 1ère vidéo)
67
  """
68
  logs = ""
69
 
70
- # --------------------- Device & seed ---------------------
71
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
72
- set_seed(0)
73
-
74
- free_vram = get_cuda_free_memory_gb(device)
75
- logs += f"Free VRAM {free_vram} GB\n"
76
- low_memory = free_vram < 40
77
-
78
- torch.set_grad_enabled(False)
79
-
80
- # --------------------- Phase 1 : init modèle / config ---------------------
81
- progress(0.05, desc="Initialisation : chargement de la config")
82
- logs += "Chargement de la config...\n"
83
- config = OmegaConf.load(CONFIG_PATH)
84
- default_config = OmegaConf.load("configs/default_config.yaml")
85
- config = OmegaConf.merge(default_config, config)
86
-
87
- progress(0.15, desc="Initialisation : création de la pipeline")
88
- logs += "Initialisation de la pipeline...\n"
89
- if hasattr(config, "denoising_step_list"):
90
- pipeline = CausalInferencePipeline(config, device=device)
91
- else:
92
- pipeline = CausalDiffusionInferencePipeline(config, device=device)
93
-
94
- progress(0.35, desc="Initialisation : chargement du checkpoint")
95
- logs += "Chargement des poids du checkpoint...\n"
96
- state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
97
- pipeline.generator.load_state_dict(state_dict)
98
- checkpoint_step = os.path.basename(os.path.dirname(CHECKPOINT_PATH))
99
- checkpoint_step = checkpoint_step.split("_")[-1]
100
-
101
- progress(0.55, desc="Initialisation : placement sur le device")
102
- logs += "Placement du modèle sur le device...\n"
103
- pipeline = pipeline.to(dtype=torch.bfloat16)
104
- if low_memory:
105
- DynamicSwapInstaller.install_model(pipeline.text_encoder, device=device)
106
- else:
107
- pipeline.text_encoder.to(device=device)
108
- pipeline.generator.to(device=device)
109
- pipeline.vae.to(device=device)
110
 
111
  # --------------------- Dataset / DataLoader ---------------------
112
- progress(0.65, desc="Préparation du dataset")
113
  logs += "Préparation du dataset (TextDataset)...\n"
114
  dataset = TextDataset(prompt_path=prompt_txt_path, extended_prompt_path=None)
115
  num_prompts = len(dataset)
@@ -123,7 +149,7 @@ def reward_forcing_inference(
123
  )
124
 
125
  # --------------------- Output folder (on le vide) ---------------------
126
- progress(0.7, desc="Nettoyage du dossier de sortie")
127
  output_folder = os.path.join(
128
  output_root, f"rewardforcing-{num_output_frames}f", checkpoint_step
129
  )
@@ -131,8 +157,7 @@ def reward_forcing_inference(
131
  os.makedirs(output_folder, exist_ok=True)
132
  logs += f"Dossier de sortie: {output_folder}\n"
133
 
134
- # --------------------- Phase 2 : boucle d'inférence ---------------------
135
- # Ici on peut utiliser progress.tqdm sur la boucle dataloader
136
  for i, batch_data in progress.tqdm(
137
  enumerate(dataloader),
138
  total=num_prompts,
@@ -151,7 +176,7 @@ def reward_forcing_inference(
151
 
152
  all_video = []
153
 
154
- # TEXT-TO-VIDEO uniquement (pas d'I2V ici)
155
  prompt = batch["prompts"][0]
156
  extended_prompt = batch.get("extended_prompts", [None])[0]
157
  if extended_prompt is not None:
@@ -163,7 +188,7 @@ def reward_forcing_inference(
163
 
164
  sampled_noise = torch.randn(
165
  [1, num_output_frames, 16, 60, 104],
166
- device=device,
167
  dtype=torch.bfloat16,
168
  )
169
 
@@ -247,15 +272,15 @@ def gradio_generate(
247
  # UI Gradio
248
  # -------------------------------------------------------------------
249
 
250
- with gr.Blocks(title="Reward Forcing T2V Demo (inline inference)") as demo:
251
  gr.Markdown(
252
  """
253
- # 🎬 Reward Forcing – Text-to-Video (inline)
254
 
255
- Cette version appelle directement la logique d'inférence en Python,
256
- ce qui permet à Gradio de suivre :
257
- - l'initialisation du modèle (via `progress(...)`)
258
- - la boucle de génération (via `progress.tqdm(...)`)
259
  """
260
  )
261
 
 
51
  os.makedirs(PROMPT_DIR, exist_ok=True)
52
  os.makedirs(OUTPUT_ROOT, exist_ok=True)
53
 
54
+ # === Globals pour le cache du modèle ===
55
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56
+ PIPELINE = None
57
+ LOW_MEMORY = None
58
+ CHECKPOINT_STEP = None
59
+
60
+
61
+ def load_pipeline(progress: gr.Progress):
62
+ """
63
+ Charge la config + pipeline + checkpoint + placement device une seule fois.
64
+ Utilise progress.tqdm pour afficher plusieurs étapes la 1ère fois.
65
+ """
66
+ global PIPELINE, LOW_MEMORY, CHECKPOINT_STEP
67
+
68
+ logs = ""
69
+
70
+ # Si déjà chargé, on ne refait rien de lourd
71
+ if PIPELINE is not None:
72
+ progress(0.1, desc="Modèle déjà initialisé (cache)")
73
+ logs += "Modèle déjà initialisé, réutilisation du cache.\n"
74
+ return PIPELINE, LOW_MEMORY, CHECKPOINT_STEP, logs
75
+
76
+ # ---- Première initialisation lourde ----
77
+ set_seed(0)
78
+ free_vram = get_cuda_free_memory_gb(DEVICE)
79
+ LOW_MEMORY = free_vram < 40
80
+ logs += f"Free VRAM {free_vram} GB\n"
81
+
82
+ steps = range(4)
83
+ for step in progress.tqdm(steps, desc="Initialisation du modèle", unit="étape"):
84
+ if step == 0:
85
+ logs += "Étape 1/4 : Chargement de la config...\n"
86
+ config = OmegaConf.load(CONFIG_PATH)
87
+ default_config = OmegaConf.load("configs/default_config.yaml")
88
+ config = OmegaConf.merge(default_config, config)
89
+
90
+ elif step == 1:
91
+ logs += "Étape 2/4 : Création de la pipeline...\n"
92
+ if hasattr(config, "denoising_step_list"):
93
+ PIPELINE = CausalInferencePipeline(config, device=DEVICE)
94
+ else:
95
+ PIPELINE = CausalDiffusionInferencePipeline(config, device=DEVICE)
96
+
97
+ elif step == 2:
98
+ logs += "Étape 3/4 : Chargement des poids du checkpoint...\n"
99
+ state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
100
+ PIPELINE.generator.load_state_dict(state_dict)
101
+ ckpt_dir = os.path.dirname(CHECKPOINT_PATH)
102
+ CHECKPOINT_STEP = os.path.basename(ckpt_dir)
103
+ CHECKPOINT_STEP = CHECKPOINT_STEP.split("_")[-1]
104
+
105
+ elif step == 3:
106
+ logs += "Étape 4/4 : Placement du modèle sur le device...\n"
107
+ PIPELINE = PIPELINE.to(dtype=torch.bfloat16)
108
+ if LOW_MEMORY:
109
+ DynamicSwapInstaller.install_model(PIPELINE.text_encoder, device=DEVICE)
110
+ else:
111
+ PIPELINE.text_encoder.to(device=DEVICE)
112
+ PIPELINE.generator.to(device=DEVICE)
113
+ PIPELINE.vae.to(device=DEVICE)
114
+
115
+ logs += "Initialisation du modèle terminée ✅\n"
116
+ return PIPELINE, LOW_MEMORY, CHECKPOINT_STEP, logs
117
+
118
 
119
  def reward_forcing_inference(
120
  prompt_txt_path: str,
 
125
  ):
126
  """
127
  Version inline / simplifiée de inference.py :
 
128
  - T2V uniquement
129
  - 1 fichier .txt = n prompts (mais on retourne la 1ère vidéo)
130
  """
131
  logs = ""
132
 
133
+ # --------------------- Load / cache pipeline ---------------------
134
+ pipeline, low_memory, checkpoint_step, init_logs = load_pipeline(progress)
135
+ logs += init_logs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
  # --------------------- Dataset / DataLoader ---------------------
138
+ progress(0.7, desc="Préparation du dataset")
139
  logs += "Préparation du dataset (TextDataset)...\n"
140
  dataset = TextDataset(prompt_path=prompt_txt_path, extended_prompt_path=None)
141
  num_prompts = len(dataset)
 
149
  )
150
 
151
  # --------------------- Output folder (on le vide) ---------------------
152
+ progress(0.8, desc="Nettoyage du dossier de sortie")
153
  output_folder = os.path.join(
154
  output_root, f"rewardforcing-{num_output_frames}f", checkpoint_step
155
  )
 
157
  os.makedirs(output_folder, exist_ok=True)
158
  logs += f"Dossier de sortie: {output_folder}\n"
159
 
160
+ # --------------------- Boucle d'inférence (tqdm) ---------------------
 
161
  for i, batch_data in progress.tqdm(
162
  enumerate(dataloader),
163
  total=num_prompts,
 
176
 
177
  all_video = []
178
 
179
+ # TEXT-TO-VIDEO uniquement
180
  prompt = batch["prompts"][0]
181
  extended_prompt = batch.get("extended_prompts", [None])[0]
182
  if extended_prompt is not None:
 
188
 
189
  sampled_noise = torch.randn(
190
  [1, num_output_frames, 16, 60, 104],
191
+ device=DEVICE,
192
  dtype=torch.bfloat16,
193
  )
194
 
 
272
  # UI Gradio
273
  # -------------------------------------------------------------------
274
 
275
+ with gr.Blocks(title="Reward Forcing T2V Demo (inline, cached)") as demo:
276
  gr.Markdown(
277
  """
278
+ # 🎬 Reward Forcing – Text-to-Video (inline & cached)
279
 
280
+ Cette version :
281
+ - Charge et initialise le modèle **une seule fois** (cache global)
282
+ - Affiche une barre `tqdm` multi-étapes pour l'initialisation la 1ère fois
283
+ - Affiche une barre `tqdm` pour la génération vidéo (1 step / prompt)
284
  """
285
  )
286