rahul7star committed on
Commit
cbb491e
·
verified ·
1 Parent(s): 66aa19e

Update app_quant_latent.py

Browse files
Files changed (1) hide show
  1. app_quant_latent.py +70 -100
app_quant_latent.py CHANGED
@@ -250,107 +250,77 @@ log_system_stats("AFTER PIPELINE BUILD")
250
 
251
  @spaces.GPU
252
  def generate_image(prompt, height, width, steps, seed):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
- try:
255
- # -----------------------------
256
- # 1) SEED + LATENT INIT
257
- # -----------------------------
258
- generator = torch.Generator("cuda").manual_seed(seed)
259
-
260
- # Unet input size = (B, C, H/8, W/8)
261
- latent_shape = (
262
- 1,
263
- pipe.unet.config.in_channels,
264
- height // 8,
265
- width // 8
266
- )
267
-
268
- latents = torch.randn(latent_shape, generator=generator, device="cuda")
269
- latents = latents * pipe.scheduler.init_noise_sigma
270
-
271
- latent_history = []
272
- log(f"Latent shape: {latent_shape}")
273
-
274
- # -----------------------------
275
- # 2) Text Embeddings
276
- # -----------------------------
277
- text_inputs = pipe.tokenizer(
278
- prompt,
279
- return_tensors="pt",
280
- padding="max_length",
281
- truncation=True,
282
- max_length=pipe.tokenizer.model_max_length,
283
- ).to("cuda")
284
-
285
- text_embeddings = pipe.text_encoder(text_inputs.input_ids)[0]
286
-
287
- # -----------------------------
288
- # 3) Scheduler timesteps
289
- # -----------------------------
290
- pipe.scheduler.set_timesteps(steps, device="cuda")
291
- timesteps = pipe.scheduler.timesteps
292
-
293
- # -----------------------------
294
- # 4) MANUAL DIFFUSION LOOP
295
- # -----------------------------
296
- for i, t in enumerate(timesteps):
297
- with torch.no_grad():
298
-
299
- # Forward UNET
300
- noise_pred = pipe.unet(
301
- latents,
302
- t,
303
- encoder_hidden_states=text_embeddings
304
- ).sample
305
-
306
- # Save latent copy
307
- latent_history.append(
308
- latents.detach().clone().to("cpu")
309
- )
310
-
311
- # Log GPU
312
- gpu_gb = torch.cuda.memory_allocated() / 1e9
313
- log(f"Step {i+1}/{steps} | t={int(t)} | GPU={gpu_gb:.2f} GB")
314
-
315
- # Scheduler update
316
- latents = pipe.scheduler.step(
317
- noise_pred,
318
- t,
319
- latents
320
- ).prev_sample
321
-
322
- # -----------------------------
323
- # 5) FINAL DECODE (VAE)
324
- # -----------------------------
325
- with torch.no_grad():
326
- latents_final = latents / pipe.vae.config.scaling_factor
327
- image = pipe.vae.decode(latents_final).sample[0]
328
-
329
- # Convert to PIL
330
- final_image = pipe.image_processor.postprocess(
331
- image.unsqueeze(0),
332
- output_type="pil"
333
- )[0]
334
-
335
- log("✅ Inference finished.")
336
- log_system_stats("AFTER INFERENCE")
337
-
338
- # -----------------------------
339
- # Convert latent_history to images for gallery
340
- # -----------------------------
341
- latent_imgs = []
342
- for lat in latent_history:
343
- # Normalize each latent step into a displayable grayscale image
344
- lat_img = lat[0, 0].cpu().numpy()
345
- lat_img = (lat_img - lat_img.min()) / (lat_img.max() - lat_img.min() + 1e-8)
346
- lat_img = (lat_img * 255).astype("uint8")
347
- latent_imgs.append(Image.fromarray(lat_img))
348
-
349
- return final_image, latent_imgs, LOGS
350
-
351
- except Exception as e:
352
- log(f"❌ Inference error: {e}")
353
- return None, None, LOGS
354
 
355
 
356
 
 
250
 
251
  @spaces.GPU
252
  def generate_image(prompt, height, width, steps, seed):
253
+
254
+ try:
255
+ generator = torch.Generator(device).manual_seed(int(seed))
256
+ latent_history = []
257
+
258
+ # callback signature expected by ZImagePipeline:
259
+ # callback_on_step_end(self_pipeline, step_index, timestep, callback_kwargs_dict)
260
+ def save_latents(self_pipeline, step_idx, timestep, callback_kwargs):
261
+ # callback_kwargs contains tensor inputs specified by
262
+ # callback_on_step_end_tensor_inputs (defaults to ["latents"])
263
+ try:
264
+ lat = callback_kwargs.get("latents", None)
265
+ if lat is not None:
266
+ # store CPU copy to avoid holding GPU memory
267
+ latent_history.append(lat.detach().clone().cpu())
268
+ # we must return a dict (may include overrides), here no overrides:
269
+ return {}
270
+ except Exception as e:
271
+ log(f"⚠️ save_latents error: {e}")
272
+ return {}
273
+
274
+ # Run pipeline once, using the pipeline's callback mechanism
275
+ out = pipe(
276
+ prompt=prompt,
277
+ height=height,
278
+ width=width,
279
+ num_inference_steps=steps,
280
+ guidance_scale=0.0,
281
+ generator=generator,
282
+ callback_on_step_end=save_latents,
283
+ callback_on_step_end_tensor_inputs=["latents"], # ensure latents passed to callback
284
+ )
285
+
286
+ # out is a ZImagePipelineOutput; pipeline already postprocessed images
287
+ final_image = out.images[0] if hasattr(out, "images") and len(out.images) > 0 else out
288
+
289
+ # Convert saved latents into displayable images (use same postprocessing as pipeline)
290
+ latent_images = []
291
+ try:
292
+ # Determine decode device and dtype
293
+ vae = pipe.vae
294
+ img_proc = pipe.image_processor
295
+ vae_device = vae.device if hasattr(vae, "device") else device
296
+
297
+ for i, lat_cpu in enumerate(latent_history):
298
+ try:
299
+ # move to vae device and dtype
300
+ lat = lat_cpu.to(vae_device).to(vae.dtype)
301
+
302
+ # pipeline used this transform before decoding:
303
+ lat = (lat / vae.config.scaling_factor) + getattr(vae.config, "shift_factor", 0.0)
304
+
305
+ # decode: vae.decode returns (batch, C, H, W)
306
+ img_tensor = vae.decode(lat, return_dict=False)[0]
307
+
308
+ # postprocess with pipeline's image processor to PIL
309
+ pil = img_proc.postprocess(img_tensor.unsqueeze(0), output_type="pil")[0]
310
+ latent_images.append(pil)
311
+ except Exception as e:
312
+ log(f"⚠️ Failed to decode latent step {i}: {e}")
313
+ except Exception as e:
314
+ log(f"⚠️ Error while converting latents: {e}")
315
+
316
+ log("✅ Inference finished.")
317
+ log_system_stats("AFTER INFERENCE")
318
+
319
+ return final_image, latent_images, LOGS
320
 
321
+ except Exception as e:
322
+ log(f"❌ Inference error: {e}")
323
+ return None, [], LOGS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
 
325
 
326