Krishnakanth1993 committed on
Commit
6baaa4a
·
verified ·
1 Parent(s): d3998d8

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +109 -36
app.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
  Multi-Style Image Generator with Ice Crystal Effects
3
- Hugging Face Spaces App
4
  """
5
 
6
  import torch
@@ -10,6 +10,8 @@ from PIL import Image
10
  from pathlib import Path
11
  from tqdm.auto import tqdm
12
  import gradio as gr
 
 
13
 
14
  from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
15
  from transformers import CLIPTextModel, CLIPTokenizer
@@ -120,7 +122,40 @@ def load_models():
120
  raise RuntimeError(f"Failed to load models: {e}")
121
 
122
 
123
- def generate_with_style(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  style_file,
125
  prompt,
126
  seed=42,
@@ -131,13 +166,20 @@ def generate_with_style(
131
  use_ice_crystal_guidance=False,
132
  ice_crystal_loss_scale=50,
133
  guidance_frequency=10,
134
- progress=None
135
  ):
136
- """Generate an image using a style embedding with optional ice crystal guidance."""
 
 
 
 
137
  global vae, tokenizer, text_encoder, unet, scheduler, device
138
 
139
  load_models()
140
 
 
 
 
141
  generator = torch.Generator(device=device).manual_seed(seed)
142
  learned_embeds_dict = torch.load(style_file, map_location=device, weights_only=True)
143
 
@@ -194,10 +236,7 @@ def generate_with_style(
194
  scheduler.set_timesteps(num_inference_steps)
195
  latents = latents * scheduler.init_noise_sigma
196
 
197
- for i, t in enumerate(tqdm(scheduler.timesteps, desc="Generating")):
198
- if progress:
199
- progress((i + 1) / num_inference_steps, f"Step {i + 1}/{num_inference_steps}")
200
-
201
  latent_model_input = torch.cat([latents] * 2)
202
  latent_model_input = scheduler.scale_model_input(latent_model_input, t)
203
 
@@ -231,21 +270,37 @@ def generate_with_style(
231
  torch.cuda.empty_cache()
232
 
233
  latents = scheduler.step(noise_pred, t, latents).prev_sample
234
-
235
- latents = 1 / 0.18215 * latents
236
-
237
- with torch.no_grad():
238
- image = vae.decode(latents).sample
239
-
240
- image = (image / 2 + 0.5).clamp(0, 1)
241
- image = image.cpu().permute(0, 2, 3, 1).numpy()
242
- image = (image[0] * 255).astype(np.uint8)
243
- image = Image.fromarray(image)
244
-
245
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
 
247
 
248
- def generate_image(
249
  prompt,
250
  style_choice,
251
  custom_embedding,
@@ -253,9 +308,9 @@ def generate_image(
253
  guidance_scale,
254
  use_ice_crystal,
255
  ice_crystal_intensity,
256
- progress=gr.Progress()
257
  ):
258
- """Main generation function for Gradio interface."""
259
 
260
  if custom_embedding is not None:
261
  style_file = custom_embedding
@@ -268,16 +323,18 @@ def generate_image(
268
  raise gr.Error(f"Style embedding file not found: {style_file}")
269
 
270
  try:
271
- image = generate_with_style(
272
  style_file=style_file,
273
  prompt=prompt,
274
  seed=int(seed),
275
  guidance_scale=guidance_scale,
276
  use_ice_crystal_guidance=use_ice_crystal,
277
  ice_crystal_loss_scale=ice_crystal_intensity,
278
- progress=progress
279
- )
280
- return image
 
 
281
  except Exception as e:
282
  raise gr.Error(f"Generation failed: {str(e)}")
283
 
@@ -294,12 +351,13 @@ with gr.Blocks(
294
  # Multi-Style Image Generator with Ice Crystal Effects
295
 
296
  Generate images using textual inversion style embeddings with optional ice crystal overlay effects.
 
297
 
298
  **Instructions:**
299
  1. Enter a prompt using `<style>` as placeholder (e.g., "A cat in the style of <style>")
300
  2. Select a predefined style OR upload your own `.bin` embedding file
301
  3. Optionally enable ice crystal effect for a crystalline overlay
302
- 4. Click Generate!
303
  """)
304
 
305
  with gr.Row():
@@ -353,27 +411,42 @@ with gr.Blocks(
353
  info="Higher = stronger crystal effect"
354
  )
355
 
 
 
 
 
 
 
 
 
 
 
356
  generate_btn = gr.Button("Generate", variant="primary", size="lg")
 
357
 
358
  with gr.Column(scale=1):
359
  output_image = gr.Image(
360
- label="Generated Image",
361
  type="pil"
362
  )
 
 
 
 
363
 
364
  gr.Examples(
365
  examples=[
366
- ["A cat in the style of <style>", "8bit", None, 42, 7.5, False, 50],
367
- ["A mystical forest in the style of <style>", "dr_strange", None, 123, 7.5, False, 50],
368
- ["A portrait in the style of <style>", "max_naylor", None, 456, 7.5, True, 60],
369
  ],
370
- inputs=[prompt, style_choice, custom_embedding, seed, guidance_scale, use_ice_crystal, ice_crystal_intensity],
371
  )
372
 
373
  generate_btn.click(
374
- fn=generate_image,
375
- inputs=[prompt, style_choice, custom_embedding, seed, guidance_scale, use_ice_crystal, ice_crystal_intensity],
376
- outputs=output_image
377
  )
378
 
379
  if __name__ == "__main__":
 
1
  """
2
  Multi-Style Image Generator with Ice Crystal Effects
3
+ Hugging Face Spaces App - With Diffusion Progress Streaming
4
  """
5
 
6
  import torch
 
10
  from pathlib import Path
11
  from tqdm.auto import tqdm
12
  import gradio as gr
13
+ import io
14
+ import tempfile
15
 
16
  from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
17
  from transformers import CLIPTextModel, CLIPTokenizer
 
122
  raise RuntimeError(f"Failed to load models: {e}")
123
 
124
 
125
def decode_latents_to_image(latents_to_decode):
    """Decode a batch of latents into a single PIL Image.

    Args:
        latents_to_decode: Latent tensor from the diffusion loop
            (assumed shape (1, C, H, W) — only the first batch item
            is converted; TODO confirm against caller).

    Returns:
        PIL.Image.Image: The decoded RGB image.
    """
    # Only `vae` is read here; the original also declared `device`
    # global, but it was never used in this function.
    global vae

    with torch.no_grad():
        # 0.18215 is the Stable Diffusion VAE scaling factor; undo it
        # before decoding back to pixel space.
        latents_scaled = 1 / 0.18215 * latents_to_decode
        image = vae.decode(latents_scaled).sample

    # Map from [-1, 1] to [0, 1], then to uint8 HWC for PIL.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    image = (image[0] * 255).astype(np.uint8)
    return Image.fromarray(image)
137
+
138
+
139
def create_gif_from_frames(frames, output_path=None, duration=200):
    """Create an animated GIF from a list of PIL Images.

    Args:
        frames: List of PIL.Image.Image frames, in display order.
        output_path: Destination path for the GIF; a secure temporary
            file is created when omitted.
        duration: Per-frame display time in milliseconds.

    Returns:
        str or None: Path to the written GIF, or None if `frames`
        is empty.
    """
    if not frames:
        return None

    if output_path is None:
        # NamedTemporaryFile replaces the deprecated, race-prone
        # tempfile.mktemp(); delete=False keeps the file so the
        # caller (and Gradio) can serve it afterwards.
        with tempfile.NamedTemporaryFile(suffix='.gif', delete=False) as tmp:
            output_path = tmp.name

    # Save the first frame and append the rest as animation frames;
    # loop=0 makes the GIF repeat forever.
    frames[0].save(
        output_path,
        save_all=True,
        append_images=frames[1:],
        duration=duration,
        loop=0
    )
    return output_path
156
+
157
+
158
+ def generate_with_style_streaming(
159
  style_file,
160
  prompt,
161
  seed=42,
 
166
  use_ice_crystal_guidance=False,
167
  ice_crystal_loss_scale=50,
168
  guidance_frequency=10,
169
+ preview_frequency=5
170
  ):
171
+ """
172
+ Generate an image with streaming updates.
173
+ Yields intermediate images during generation.
174
+ Returns final image and GIF path at the end.
175
+ """
176
  global vae, tokenizer, text_encoder, unet, scheduler, device
177
 
178
  load_models()
179
 
180
+ # Collect frames for GIF
181
+ frames = []
182
+
183
  generator = torch.Generator(device=device).manual_seed(seed)
184
  learned_embeds_dict = torch.load(style_file, map_location=device, weights_only=True)
185
 
 
236
  scheduler.set_timesteps(num_inference_steps)
237
  latents = latents * scheduler.init_noise_sigma
238
 
239
+ for i, t in enumerate(scheduler.timesteps):
 
 
 
240
  latent_model_input = torch.cat([latents] * 2)
241
  latent_model_input = scheduler.scale_model_input(latent_model_input, t)
242
 
 
270
  torch.cuda.empty_cache()
271
 
272
  latents = scheduler.step(noise_pred, t, latents).prev_sample
273
+
274
+ # Decode and yield intermediate preview every N steps
275
+ if i % preview_frequency == 0 or i == num_inference_steps - 1:
276
+ preview_image = decode_latents_to_image(latents)
277
+ frames.append(preview_image)
278
+
279
+ # Yield progress update: (step, total, current_image, gif_path)
280
+ yield {
281
+ "step": i + 1,
282
+ "total": num_inference_steps,
283
+ "image": preview_image,
284
+ "gif": None # GIF not ready yet
285
+ }
286
+
287
+ # Final decode
288
+ final_image = decode_latents_to_image(latents)
289
+ frames.append(final_image)
290
+
291
+ # Create GIF from all frames
292
+ gif_path = create_gif_from_frames(frames, duration=300)
293
+
294
+ # Yield final result
295
+ yield {
296
+ "step": num_inference_steps,
297
+ "total": num_inference_steps,
298
+ "image": final_image,
299
+ "gif": gif_path
300
+ }
301
 
302
 
303
+ def generate_image_streaming(
304
  prompt,
305
  style_choice,
306
  custom_embedding,
 
308
  guidance_scale,
309
  use_ice_crystal,
310
  ice_crystal_intensity,
311
+ preview_frequency
312
  ):
313
+ """Streaming generation function for Gradio interface."""
314
 
315
  if custom_embedding is not None:
316
  style_file = custom_embedding
 
323
  raise gr.Error(f"Style embedding file not found: {style_file}")
324
 
325
  try:
326
+ for update in generate_with_style_streaming(
327
  style_file=style_file,
328
  prompt=prompt,
329
  seed=int(seed),
330
  guidance_scale=guidance_scale,
331
  use_ice_crystal_guidance=use_ice_crystal,
332
  ice_crystal_loss_scale=ice_crystal_intensity,
333
+ preview_frequency=int(preview_frequency)
334
+ ):
335
+ status = f"Step {update['step']}/{update['total']}"
336
+ yield update["image"], update["gif"], status
337
+
338
  except Exception as e:
339
  raise gr.Error(f"Generation failed: {str(e)}")
340
 
 
351
  # Multi-Style Image Generator with Ice Crystal Effects
352
 
353
  Generate images using textual inversion style embeddings with optional ice crystal overlay effects.
354
+ **Now with live diffusion progress streaming!**
355
 
356
  **Instructions:**
357
  1. Enter a prompt using `<style>` as placeholder (e.g., "A cat in the style of <style>")
358
  2. Select a predefined style OR upload your own `.bin` embedding file
359
  3. Optionally enable ice crystal effect for a crystalline overlay
360
+ 4. Click Generate and watch the image evolve!
361
  """)
362
 
363
  with gr.Row():
 
411
  info="Higher = stronger crystal effect"
412
  )
413
 
414
+ with gr.Accordion("Streaming Settings", open=True):
415
+ preview_frequency = gr.Slider(
416
+ label="Preview Frequency",
417
+ minimum=1,
418
+ maximum=10,
419
+ value=5,
420
+ step=1,
421
+ info="Show preview every N steps (lower = more updates, slower)"
422
+ )
423
+
424
  generate_btn = gr.Button("Generate", variant="primary", size="lg")
425
+ status_text = gr.Textbox(label="Status", interactive=False, value="Ready")
426
 
427
  with gr.Column(scale=1):
428
  output_image = gr.Image(
429
+ label="Live Preview / Final Image",
430
  type="pil"
431
  )
432
+ output_gif = gr.File(
433
+ label="Diffusion Progress GIF (available after generation)",
434
+ type="filepath"
435
+ )
436
 
437
  gr.Examples(
438
  examples=[
439
+ ["A cat in the style of <style>", "8bit", None, 42, 7.5, False, 50, 5],
440
+ ["A mystical forest in the style of <style>", "dr_strange", None, 123, 7.5, False, 50, 5],
441
+ ["A portrait in the style of <style>", "max_naylor", None, 456, 7.5, True, 60, 5],
442
  ],
443
+ inputs=[prompt, style_choice, custom_embedding, seed, guidance_scale, use_ice_crystal, ice_crystal_intensity, preview_frequency],
444
  )
445
 
446
  generate_btn.click(
447
+ fn=generate_image_streaming,
448
+ inputs=[prompt, style_choice, custom_embedding, seed, guidance_scale, use_ice_crystal, ice_crystal_intensity, preview_frequency],
449
+ outputs=[output_image, output_gif, status_text]
450
  )
451
 
452
  if __name__ == "__main__":