klyfff committed on
Commit
22da5c3
·
verified ·
1 Parent(s): d38665e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -21
app.py CHANGED
@@ -84,13 +84,9 @@ class WaveCollapseTracker:
84
  self.snapshot = None
85
 
86
  def callback(self, pipe, step_index, timestep, callback_kwargs):
87
- # 1. Trigger the step timer
88
  self.timer.step(step_index + 1)
89
-
90
- # 2. Extract current latents
91
  latents = callback_kwargs["latents"]
92
 
93
- # 3. Wave Collapse Math
94
  if self.prev_latents is not None:
95
  delta = (latents - self.prev_latents).abs().mean(dim=1, keepdim=True)
96
  new_settled = delta < self.epsilon
@@ -159,6 +155,7 @@ def infer(
159
  try:
160
  torch.randn_like = _locked_randn_like
161
 
 
162
  if auto_anti_prompt and prompt:
163
  pos_hidden, pos_pooled = encode_prompt_sdxl(pipe, prompt, device)
164
  neg_hidden = -pos_hidden
@@ -166,7 +163,7 @@ def infer(
166
 
167
  wave_tracker.timer.start()
168
  t_gen_start = time.time()
169
- image = pipe(
170
  prompt_embeds=pos_hidden,
171
  negative_prompt_embeds=neg_hidden,
172
  pooled_prompt_embeds=pos_pooled,
@@ -176,12 +173,13 @@ def infer(
176
  width=width,
177
  height=height,
178
  generator=generator,
 
179
  callback_on_step_end=wave_tracker.callback,
180
- ).images[0]
181
  else:
182
  wave_tracker.timer.start()
183
  t_gen_start = time.time()
184
- image = pipe(
185
  prompt=prompt,
186
  negative_prompt=negative_prompt if negative_prompt else None,
187
  guidance_scale=guidance_scale,
@@ -189,29 +187,32 @@ def infer(
189
  width=width,
190
  height=height,
191
  generator=generator,
 
192
  callback_on_step_end=wave_tracker.callback,
193
- ).images[0]
194
 
195
  t_gen_end = time.time()
 
196
 
197
- # --- Decode the Wave Collapse Map ---
198
- collapse_image = None
199
  if wave_tracker.snapshot is not None and wave_tracker.cumulative_mask is not None:
200
- unsettled = (~wave_tracker.cumulative_mask).expand_as(wave_tracker.prev_latents).to(wave_tracker.prev_latents.dtype)
201
- final_snapshot = (wave_tracker.snapshot * (1.0 - unsettled)) + (wave_tracker.prev_latents * unsettled)
 
 
202
 
203
- with torch.no_grad():
204
- # Upcast to prevent VAE black-screen bug
205
- pipe.vae.to(dtype=torch.float32)
206
- final_snapshot = (final_snapshot / pipe.vae.config.scaling_factor).to(torch.float32)
207
- collapse_tensor = pipe.vae.decode(final_snapshot, return_dict=False)[0]
208
- collapse_image = pipe.image_processor.postprocess(collapse_tensor, output_type="pil")[0]
209
 
210
  total_time = t_gen_end - t_start
211
  step_summary = timer.summary()
212
  status = f"{'CLIP Mirror ON' if auto_anti_prompt else 'Standard CFG'} | {step_summary} | Total: {total_time:.1f}s"
213
 
214
- return image, collapse_image, seed, status
215
 
216
  finally:
217
  torch.randn_like = _original_randn_like
@@ -263,7 +264,6 @@ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
263
 
264
  with gr.Column(scale=1):
265
  output_image = gr.Image(label="Final Generated Image")
266
- collapse_map = gr.Image(label="Wave Collapse Visualization")
267
 
268
  with gr.Row():
269
  output_seed = gr.Textbox(label="Used Seed", interactive=False)
@@ -276,7 +276,7 @@ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
276
  width, height, guidance_scale, num_inference_steps,
277
  auto_anti_prompt, epsilon
278
  ],
279
- outputs=[output_image, collapse_map, output_seed, status_display],
280
  )
281
 
282
  if __name__ == "__main__":
 
84
  self.snapshot = None
85
 
86
  def callback(self, pipe, step_index, timestep, callback_kwargs):
 
87
  self.timer.step(step_index + 1)
 
 
88
  latents = callback_kwargs["latents"]
89
 
 
90
  if self.prev_latents is not None:
91
  delta = (latents - self.prev_latents).abs().mean(dim=1, keepdim=True)
92
  new_settled = delta < self.epsilon
 
155
  try:
156
  torch.randn_like = _locked_randn_like
157
 
158
+ # We pass output_type="latent" so the pipeline stops before the VAE decode
159
  if auto_anti_prompt and prompt:
160
  pos_hidden, pos_pooled = encode_prompt_sdxl(pipe, prompt, device)
161
  neg_hidden = -pos_hidden
 
163
 
164
  wave_tracker.timer.start()
165
  t_gen_start = time.time()
166
+ pipeline_output = pipe(
167
  prompt_embeds=pos_hidden,
168
  negative_prompt_embeds=neg_hidden,
169
  pooled_prompt_embeds=pos_pooled,
 
173
  width=width,
174
  height=height,
175
  generator=generator,
176
+ output_type="latent",
177
  callback_on_step_end=wave_tracker.callback,
178
+ )
179
  else:
180
  wave_tracker.timer.start()
181
  t_gen_start = time.time()
182
+ pipeline_output = pipe(
183
  prompt=prompt,
184
  negative_prompt=negative_prompt if negative_prompt else None,
185
  guidance_scale=guidance_scale,
 
187
  width=width,
188
  height=height,
189
  generator=generator,
190
+ output_type="latent",
191
  callback_on_step_end=wave_tracker.callback,
192
+ )
193
 
194
  t_gen_end = time.time()
195
+ final_latents = pipeline_output.images
196
 
197
+ # --- Decode the Accumulated Wave Collapse Master Snapshot ---
 
198
  if wave_tracker.snapshot is not None and wave_tracker.cumulative_mask is not None:
199
+ unsettled = (~wave_tracker.cumulative_mask).expand_as(final_latents).to(final_latents.dtype)
200
+ final_snapshot = (wave_tracker.snapshot * (1.0 - unsettled)) + (final_latents * unsettled)
201
+ else:
202
+ final_snapshot = final_latents
203
 
204
+ with torch.no_grad():
205
+ # Upcast to prevent VAE black-screen bug
206
+ pipe.vae.to(dtype=torch.float32)
207
+ final_snapshot_fp32 = (final_snapshot / pipe.vae.config.scaling_factor).to(torch.float32)
208
+ collapse_tensor = pipe.vae.decode(final_snapshot_fp32, return_dict=False)[0]
209
+ final_image = pipe.image_processor.postprocess(collapse_tensor, output_type="pil")[0]
210
 
211
  total_time = t_gen_end - t_start
212
  step_summary = timer.summary()
213
  status = f"{'CLIP Mirror ON' if auto_anti_prompt else 'Standard CFG'} | {step_summary} | Total: {total_time:.1f}s"
214
 
215
+ return final_image, seed, status
216
 
217
  finally:
218
  torch.randn_like = _original_randn_like
 
264
 
265
  with gr.Column(scale=1):
266
  output_image = gr.Image(label="Final Generated Image")
 
267
 
268
  with gr.Row():
269
  output_seed = gr.Textbox(label="Used Seed", interactive=False)
 
276
  width, height, guidance_scale, num_inference_steps,
277
  auto_anti_prompt, epsilon
278
  ],
279
+ outputs=[output_image, output_seed, status_display],
280
  )
281
 
282
  if __name__ == "__main__":