Spaces:
Running on Zero
Running on Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -84,13 +84,9 @@ class WaveCollapseTracker:
|
|
| 84 |
self.snapshot = None
|
| 85 |
|
| 86 |
def callback(self, pipe, step_index, timestep, callback_kwargs):
|
| 87 |
-
# 1. Trigger the step timer
|
| 88 |
self.timer.step(step_index + 1)
|
| 89 |
-
|
| 90 |
-
# 2. Extract current latents
|
| 91 |
latents = callback_kwargs["latents"]
|
| 92 |
|
| 93 |
-
# 3. Wave Collapse Math
|
| 94 |
if self.prev_latents is not None:
|
| 95 |
delta = (latents - self.prev_latents).abs().mean(dim=1, keepdim=True)
|
| 96 |
new_settled = delta < self.epsilon
|
|
@@ -159,6 +155,7 @@ def infer(
|
|
| 159 |
try:
|
| 160 |
torch.randn_like = _locked_randn_like
|
| 161 |
|
|
|
|
| 162 |
if auto_anti_prompt and prompt:
|
| 163 |
pos_hidden, pos_pooled = encode_prompt_sdxl(pipe, prompt, device)
|
| 164 |
neg_hidden = -pos_hidden
|
|
@@ -166,7 +163,7 @@ def infer(
|
|
| 166 |
|
| 167 |
wave_tracker.timer.start()
|
| 168 |
t_gen_start = time.time()
|
| 169 |
-
|
| 170 |
prompt_embeds=pos_hidden,
|
| 171 |
negative_prompt_embeds=neg_hidden,
|
| 172 |
pooled_prompt_embeds=pos_pooled,
|
|
@@ -176,12 +173,13 @@ def infer(
|
|
| 176 |
width=width,
|
| 177 |
height=height,
|
| 178 |
generator=generator,
|
|
|
|
| 179 |
callback_on_step_end=wave_tracker.callback,
|
| 180 |
-
)
|
| 181 |
else:
|
| 182 |
wave_tracker.timer.start()
|
| 183 |
t_gen_start = time.time()
|
| 184 |
-
|
| 185 |
prompt=prompt,
|
| 186 |
negative_prompt=negative_prompt if negative_prompt else None,
|
| 187 |
guidance_scale=guidance_scale,
|
|
@@ -189,29 +187,32 @@ def infer(
|
|
| 189 |
width=width,
|
| 190 |
height=height,
|
| 191 |
generator=generator,
|
|
|
|
| 192 |
callback_on_step_end=wave_tracker.callback,
|
| 193 |
-
)
|
| 194 |
|
| 195 |
t_gen_end = time.time()
|
|
|
|
| 196 |
|
| 197 |
-
# --- Decode the Wave Collapse
|
| 198 |
-
collapse_image = None
|
| 199 |
if wave_tracker.snapshot is not None and wave_tracker.cumulative_mask is not None:
|
| 200 |
-
unsettled = (~wave_tracker.cumulative_mask).expand_as(
|
| 201 |
-
final_snapshot = (wave_tracker.snapshot * (1.0 - unsettled)) + (
|
|
|
|
|
|
|
| 202 |
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
|
| 210 |
total_time = t_gen_end - t_start
|
| 211 |
step_summary = timer.summary()
|
| 212 |
status = f"{'CLIP Mirror ON' if auto_anti_prompt else 'Standard CFG'} | {step_summary} | Total: {total_time:.1f}s"
|
| 213 |
|
| 214 |
-
return
|
| 215 |
|
| 216 |
finally:
|
| 217 |
torch.randn_like = _original_randn_like
|
|
@@ -263,7 +264,6 @@ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
|
|
| 263 |
|
| 264 |
with gr.Column(scale=1):
|
| 265 |
output_image = gr.Image(label="Final Generated Image")
|
| 266 |
-
collapse_map = gr.Image(label="Wave Collapse Visualization")
|
| 267 |
|
| 268 |
with gr.Row():
|
| 269 |
output_seed = gr.Textbox(label="Used Seed", interactive=False)
|
|
@@ -276,7 +276,7 @@ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
|
|
| 276 |
width, height, guidance_scale, num_inference_steps,
|
| 277 |
auto_anti_prompt, epsilon
|
| 278 |
],
|
| 279 |
-
outputs=[output_image,
|
| 280 |
)
|
| 281 |
|
| 282 |
if __name__ == "__main__":
|
|
|
|
| 84 |
self.snapshot = None
|
| 85 |
|
| 86 |
def callback(self, pipe, step_index, timestep, callback_kwargs):
|
|
|
|
| 87 |
self.timer.step(step_index + 1)
|
|
|
|
|
|
|
| 88 |
latents = callback_kwargs["latents"]
|
| 89 |
|
|
|
|
| 90 |
if self.prev_latents is not None:
|
| 91 |
delta = (latents - self.prev_latents).abs().mean(dim=1, keepdim=True)
|
| 92 |
new_settled = delta < self.epsilon
|
|
|
|
| 155 |
try:
|
| 156 |
torch.randn_like = _locked_randn_like
|
| 157 |
|
| 158 |
+
# We pass output_type="latent" so the pipeline stops before the VAE decode
|
| 159 |
if auto_anti_prompt and prompt:
|
| 160 |
pos_hidden, pos_pooled = encode_prompt_sdxl(pipe, prompt, device)
|
| 161 |
neg_hidden = -pos_hidden
|
|
|
|
| 163 |
|
| 164 |
wave_tracker.timer.start()
|
| 165 |
t_gen_start = time.time()
|
| 166 |
+
pipeline_output = pipe(
|
| 167 |
prompt_embeds=pos_hidden,
|
| 168 |
negative_prompt_embeds=neg_hidden,
|
| 169 |
pooled_prompt_embeds=pos_pooled,
|
|
|
|
| 173 |
width=width,
|
| 174 |
height=height,
|
| 175 |
generator=generator,
|
| 176 |
+
output_type="latent",
|
| 177 |
callback_on_step_end=wave_tracker.callback,
|
| 178 |
+
)
|
| 179 |
else:
|
| 180 |
wave_tracker.timer.start()
|
| 181 |
t_gen_start = time.time()
|
| 182 |
+
pipeline_output = pipe(
|
| 183 |
prompt=prompt,
|
| 184 |
negative_prompt=negative_prompt if negative_prompt else None,
|
| 185 |
guidance_scale=guidance_scale,
|
|
|
|
| 187 |
width=width,
|
| 188 |
height=height,
|
| 189 |
generator=generator,
|
| 190 |
+
output_type="latent",
|
| 191 |
callback_on_step_end=wave_tracker.callback,
|
| 192 |
+
)
|
| 193 |
|
| 194 |
t_gen_end = time.time()
|
| 195 |
+
final_latents = pipeline_output.images
|
| 196 |
|
| 197 |
+
# --- Decode the Accumulated Wave Collapse Master Snapshot ---
|
|
|
|
| 198 |
if wave_tracker.snapshot is not None and wave_tracker.cumulative_mask is not None:
|
| 199 |
+
unsettled = (~wave_tracker.cumulative_mask).expand_as(final_latents).to(final_latents.dtype)
|
| 200 |
+
final_snapshot = (wave_tracker.snapshot * (1.0 - unsettled)) + (final_latents * unsettled)
|
| 201 |
+
else:
|
| 202 |
+
final_snapshot = final_latents
|
| 203 |
|
| 204 |
+
with torch.no_grad():
|
| 205 |
+
# Upcast to prevent VAE black-screen bug
|
| 206 |
+
pipe.vae.to(dtype=torch.float32)
|
| 207 |
+
final_snapshot_fp32 = (final_snapshot / pipe.vae.config.scaling_factor).to(torch.float32)
|
| 208 |
+
collapse_tensor = pipe.vae.decode(final_snapshot_fp32, return_dict=False)[0]
|
| 209 |
+
final_image = pipe.image_processor.postprocess(collapse_tensor, output_type="pil")[0]
|
| 210 |
|
| 211 |
total_time = t_gen_end - t_start
|
| 212 |
step_summary = timer.summary()
|
| 213 |
status = f"{'CLIP Mirror ON' if auto_anti_prompt else 'Standard CFG'} | {step_summary} | Total: {total_time:.1f}s"
|
| 214 |
|
| 215 |
+
return final_image, seed, status
|
| 216 |
|
| 217 |
finally:
|
| 218 |
torch.randn_like = _original_randn_like
|
|
|
|
| 264 |
|
| 265 |
with gr.Column(scale=1):
|
| 266 |
output_image = gr.Image(label="Final Generated Image")
|
|
|
|
| 267 |
|
| 268 |
with gr.Row():
|
| 269 |
output_seed = gr.Textbox(label="Used Seed", interactive=False)
|
|
|
|
| 276 |
width, height, guidance_scale, num_inference_steps,
|
| 277 |
auto_anti_prompt, epsilon
|
| 278 |
],
|
| 279 |
+
outputs=[output_image, output_seed, status_display],
|
| 280 |
)
|
| 281 |
|
| 282 |
if __name__ == "__main__":
|