rahul7star commited on
Commit
66aa19e
·
verified ·
1 Parent(s): 2e8da0d

Update app_quant_latent.py

Browse files
Files changed (1) hide show
  1. app_quant_latent.py +98 -42
app_quant_latent.py CHANGED
@@ -247,55 +247,111 @@ log_system_stats("AFTER PIPELINE BUILD")
247
 
248
 
249
 
 
250
  @spaces.GPU
251
  def generate_image(prompt, height, width, steps, seed):
252
- global latent_history
253
- latent_history = [] # reset every run
254
-
255
- generator = torch.Generator("cuda").manual_seed(int(seed))
256
-
257
- logs = []
258
- def log(msg):
259
- logs.append(msg)
260
-
261
- # Run pipeline manually step by step
262
- out = pipe(
263
- prompt=prompt,
264
- height=height,
265
- width=width,
266
- num_inference_steps=steps,
267
- generator=generator,
268
- output_type="latent"
269
- )
270
-
271
- latents = out.latents
272
-
273
- # Denoising loop - MANUAL callback
274
- for i, t in enumerate(pipe.scheduler.timesteps):
275
- latents = pipe.unet(latents, t, encoder_hidden_states=out.prompt_embeds).sample
276
 
277
- # Store cloned latent
278
- latent_history.append(latents.detach().cpu().clone())
279
-
280
- # Log GPU memory
281
- gpu = torch.cuda.memory_allocated() / 1e9
282
- log(f"Step {i+1}/{steps} — GPU: {gpu:.2f} GB")
 
 
 
 
 
 
 
283
 
284
- # Step scheduler
285
- latents = pipe.scheduler.step(latents, timestep=t).prev_sample
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
- # Decode final image
288
- final_image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample[0]
289
- final_image = (final_image / 2 + 0.5).clamp(0,1).cpu().permute(1,2,0).numpy()
290
 
291
- # Convert latents to preview images
292
- latent_imgs = []
293
- for l in latent_history:
294
- img = pipe.vae.decode(l / pipe.vae.config.scaling_factor).sample[0]
295
- img = (img / 2 + 0.5).clamp(0,1).cpu().permute(1,2,0).numpy()
296
- latent_imgs.append(img)
297
 
298
- return final_image, latent_imgs, "\n".join(logs)
299
 
300
 
301
  # ============================================================
 
247
 
248
 
249
 
250
+
251
  @spaces.GPU
252
  def generate_image(prompt, height, width, steps, seed):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
+ try:
255
+ # -----------------------------
256
+ # 1) SEED + LATENT INIT
257
+ # -----------------------------
258
+ generator = torch.Generator("cuda").manual_seed(seed)
259
+
260
+ # Unet input size = (B, C, H/8, W/8)
261
+ latent_shape = (
262
+ 1,
263
+ pipe.unet.config.in_channels,
264
+ height // 8,
265
+ width // 8
266
+ )
267
 
268
+ latents = torch.randn(latent_shape, generator=generator, device="cuda")
269
+ latents = latents * pipe.scheduler.init_noise_sigma
270
+
271
+ latent_history = []
272
+ log(f"Latent shape: {latent_shape}")
273
+
274
+ # -----------------------------
275
+ # 2) Text Embeddings
276
+ # -----------------------------
277
+ text_inputs = pipe.tokenizer(
278
+ prompt,
279
+ return_tensors="pt",
280
+ padding="max_length",
281
+ truncation=True,
282
+ max_length=pipe.tokenizer.model_max_length,
283
+ ).to("cuda")
284
+
285
+ text_embeddings = pipe.text_encoder(text_inputs.input_ids)[0]
286
+
287
+ # -----------------------------
288
+ # 3) Scheduler timesteps
289
+ # -----------------------------
290
+ pipe.scheduler.set_timesteps(steps, device="cuda")
291
+ timesteps = pipe.scheduler.timesteps
292
+
293
+ # -----------------------------
294
+ # 4) MANUAL DIFFUSION LOOP
295
+ # -----------------------------
296
+ for i, t in enumerate(timesteps):
297
+ with torch.no_grad():
298
+
299
+ # Forward UNET
300
+ noise_pred = pipe.unet(
301
+ latents,
302
+ t,
303
+ encoder_hidden_states=text_embeddings
304
+ ).sample
305
+
306
+ # Save latent copy
307
+ latent_history.append(
308
+ latents.detach().clone().to("cpu")
309
+ )
310
+
311
+ # Log GPU
312
+ gpu_gb = torch.cuda.memory_allocated() / 1e9
313
+ log(f"Step {i+1}/{steps} | t={int(t)} | GPU={gpu_gb:.2f} GB")
314
+
315
+ # Scheduler update
316
+ latents = pipe.scheduler.step(
317
+ noise_pred,
318
+ t,
319
+ latents
320
+ ).prev_sample
321
+
322
+ # -----------------------------
323
+ # 5) FINAL DECODE (VAE)
324
+ # -----------------------------
325
+ with torch.no_grad():
326
+ latents_final = latents / pipe.vae.config.scaling_factor
327
+ image = pipe.vae.decode(latents_final).sample[0]
328
+
329
+ # Convert to PIL
330
+ final_image = pipe.image_processor.postprocess(
331
+ image.unsqueeze(0),
332
+ output_type="pil"
333
+ )[0]
334
+
335
+ log("✅ Inference finished.")
336
+ log_system_stats("AFTER INFERENCE")
337
+
338
+ # -----------------------------
339
+ # Convert latent_history to images for gallery
340
+ # -----------------------------
341
+ latent_imgs = []
342
+ for lat in latent_history:
343
+ # Normalize each latent step into a displayable grayscale image
344
+ lat_img = lat[0, 0].cpu().numpy()
345
+ lat_img = (lat_img - lat_img.min()) / (lat_img.max() - lat_img.min() + 1e-8)
346
+ lat_img = (lat_img * 255).astype("uint8")
347
+ latent_imgs.append(Image.fromarray(lat_img))
348
 
349
+ return final_image, latent_imgs, LOGS
 
 
350
 
351
+ except Exception as e:
352
+ log(f"❌ Inference error: {e}")
353
+ return None, None, LOGS
 
 
 
354
 
 
355
 
356
 
357
  # ============================================================