Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """ | |
| Gradio Application for Stable Diffusion | |
| Author: Shilpaj Bhalerao | |
| Date: Feb 26, 2025 | |
| """ | |
| import gc | |
| import os | |
| import torch | |
| import gradio as gr | |
| # import spaces | |
| from tqdm.auto import tqdm | |
| from PIL import Image | |
| from utils import ( | |
| load_models, clear_gpu_memory, set_timesteps, latents_to_pil, | |
| vignette_loss, get_concept_embedding, image_grid | |
| ) | |
| # Remove this import to avoid the cached_download error | |
| # from diffusers import StableDiffusionPipeline | |
| def generate_latents(prompt, seed, num_inference_steps, guidance_scale, vignette_loss_scale, concept, concept_strength, height, width): | |
| """ | |
| Function to generate latents from the UNet | |
| :param seed_number: Seed | |
| :param prompt: Text prompt | |
| :param concept: Concept to influence generation (optional) | |
| :param concept_strength: How strongly to apply the concept (0.0-1.0) | |
| :return: Latents of the UNet. This will be passed to the VAE to generate the image | |
| """ | |
| global art_concepts | |
| # Batch size | |
| batch_size = 1 | |
| # Set the seed | |
| generator = torch.manual_seed(seed) | |
| # Prep text | |
| text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt") | |
| with torch.no_grad(): | |
| text_embeddings = text_encoder(text_input.input_ids.to(device))[0] | |
| # Get the concept embedding | |
| concept_embedding = art_concepts[concept] | |
| # Apply concept embedding influence if provided | |
| if concept_embedding is not None and concept_strength > 0: | |
| # Fix the dimension mismatch by adding a batch dimension to concept_embedding if needed | |
| if len(concept_embedding.shape) == 2 and len(text_embeddings.shape) == 3: | |
| # Add batch dimension to concept_embedding to match text_embeddings | |
| concept_embedding = concept_embedding.unsqueeze(0) | |
| # Create weighted blend between original text embedding and concept | |
| if text_embeddings.shape == concept_embedding.shape: | |
| # Interpolate between text embeddings and concept | |
| text_embeddings = (1 - concept_strength) * text_embeddings + concept_strength * concept_embedding | |
| print(f"Successfully applied concept with strength {concept_strength}") | |
| else: | |
| print(f"Warning: Shapes still incompatible after adjustment. Concept: {concept_embedding.shape}, Text: {text_embeddings.shape}") | |
| # And the uncond. input as before: | |
| max_length = text_input.input_ids.shape[-1] | |
| uncond_input = tokenizer( | |
| [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" | |
| ) | |
| with torch.no_grad(): | |
| uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0] | |
| text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) | |
| # Prep Scheduler | |
| set_timesteps(scheduler, num_inference_steps) | |
| # Prep latents | |
| latents = torch.randn( | |
| (batch_size, unet.in_channels, height // 8, width // 8), | |
| generator=generator, | |
| ) | |
| latents = latents.to(device) | |
| latents = latents * scheduler.init_noise_sigma | |
| # Loop | |
| for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)): | |
| # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. | |
| latent_model_input = torch.cat([latents] * 2) | |
| sigma = scheduler.sigmas[i] | |
| latent_model_input = scheduler.scale_model_input(latent_model_input, t) | |
| # predict the noise residual | |
| with torch.no_grad(): | |
| noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] | |
| # perform CFG | |
| noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) | |
| noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) | |
| #### ADDITIONAL GUIDANCE ### | |
| if i%5 == 0: | |
| # Requires grad on the latents | |
| latents = latents.detach().requires_grad_() | |
| # Get the predicted x0: | |
| latents_x0 = latents - sigma * noise_pred | |
| # latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample | |
| # Decode to image space | |
| denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1) | |
| # Calculate loss | |
| loss = vignette_loss(denoised_images) * vignette_loss_scale | |
| # Occasionally print it out | |
| if i%10==0: | |
| print(i, 'loss:', loss.item()) | |
| # Get gradient | |
| cond_grad = torch.autograd.grad(loss, latents)[0] | |
| # Modify the latents based on this gradient | |
| latents = latents.detach() - cond_grad * sigma**2 | |
| # Now step with scheduler | |
| latents = scheduler.step(noise_pred, t, latents).prev_sample | |
| return latents | |
def generate_image(prompt, seed=42, num_inference_steps=30, guidance_scale=7.5,
                   vignette_loss_scale=0.0, concept="none", concept_strength=0.5, height=512, width=512):
    """
    Produce a single image for ``prompt`` and wrap it in a 1x1 image grid.

    Every argument is handed unchanged to ``generate_latents``; the returned
    latents are decoded with the module-level ``vae``.
    """
    sampled = generate_latents(
        prompt, seed, num_inference_steps, guidance_scale,
        vignette_loss_scale, concept, concept_strength, height, width,
    )
    return image_grid(latents_to_pil(sampled, vae), 1, 1, None)
| def generate_style_images(prompt, num_inference_steps=30, guidance_scale=7.5, | |
| vignette_loss_scale=0.0, concept_strength=0.5, height=512, width=512): | |
| """ | |
| Function to generate images of all the styles | |
| """ | |
| global art_concepts, vae | |
| seed_list = [2000, 1000, 500, 600, 100] | |
| latents_collect = [] | |
| concept_labels = [] | |
| # Load and remove the "none" element | |
| concepts_list = list(art_concepts.keys()) | |
| concepts_list.remove("none") | |
| for seed_no, concept in zip(seed_list, concepts_list): | |
| # Clear the CUDA cache | |
| torch.cuda.empty_cache() | |
| gc.collect() | |
| torch.cuda.empty_cache() | |
| print(f"Generating image with concept '{concept}' at strength {concept_strength}") | |
| # Generate latents using the concept embedding | |
| latents = generate_latents(prompt, seed_no, num_inference_steps, guidance_scale, vignette_loss_scale, concept, concept_strength, height, width) | |
| latents_collect.append(latents) | |
| concept_labels.append(f"{concept} ({concept_strength})") | |
| # Show results | |
| latents_collect = torch.vstack(latents_collect) | |
| images = latents_to_pil(latents_collect, vae) | |
| return image_grid(images, 1, len(seed_list), concept_labels) | |
| # Define Gradio interface | |
| # @spaces.GPU(enable_queue=False) | |
| def create_demo(): | |
| with gr.Blocks(title="Guided Stable Diffusion with Styles") as demo: | |
| gr.Markdown("# Guided Stable Diffusion with Styles") | |
| with gr.Tab("Single Image Generation"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| all_styles = ["none"] + list(art_concepts.keys()) | |
| all_styles.remove("none") # Remove "none" to avoid duplication | |
| all_styles = ["none"] + all_styles # Add it back at the beginning | |
| prompt = gr.Textbox(label="Prompt", placeholder="A cat sitting on a chair") | |
| seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=1000) | |
| concept_style = gr.Dropdown(choices=all_styles, label="Style Concept", value="none") | |
| concept_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label="Concept Strength", value=0.5) | |
| num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) | |
| height = gr.Slider(minimum=256, maximum=1024, step=1, label="Height", value=512) | |
| width = gr.Slider(minimum=256, maximum=1024, step=1, label="Width", value=512) | |
| guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.1, label="Guidance Scale", value=8.0) | |
| vignette_loss_scale = gr.Slider(minimum=0.0, maximum=100.0, step=1.0, label="Vignette Loss Scale", value=70.0) | |
| generate_btn = gr.Button("Generate Image") | |
| with gr.Column(): | |
| output_image = gr.Image(label="Generated Image", type="pil") | |
| with gr.Tab("Style Grid"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| grid_prompt = gr.Textbox(label="Prompt", placeholder="A dog running in the park") | |
| grid_num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30) | |
| grid_guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.1, label="Guidance Scale", value=8.0) | |
| grid_vignette_loss_scale = gr.Slider(minimum=0.0, maximum=100.0, step=1.0, label="Vignette Loss Scale", value=70.0) | |
| grid_concept_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label="Concept Strength", value=0.5) | |
| grid_generate_btn = gr.Button("Generate Style Grid") | |
| with gr.Column(): | |
| output_grid = gr.Image(label="Style Grid", type="pil") | |
| # Set up event handlers | |
| generate_btn.click( | |
| generate_image, | |
| inputs=[prompt, seed, num_inference_steps, guidance_scale, | |
| vignette_loss_scale, concept_style, concept_strength, height, width], | |
| outputs=output_image | |
| ) | |
| grid_generate_btn.click( | |
| generate_style_images, | |
| inputs=[grid_prompt, grid_num_inference_steps, | |
| grid_guidance_scale, grid_vignette_loss_scale, grid_concept_strength], | |
| outputs=output_grid | |
| ) | |
| return demo | |
# Launch the app
if __name__ == "__main__":
    # Prefer CUDA, then Apple MPS, then fall back to the CPU
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    if device == "mps":
        # Let ops lacking an MPS kernel run on the CPU instead.
        # NOTE(review): PyTorch may only read this variable when torch is first
        # imported; if the fallback has no effect, set it before `import torch`.
        os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"

    vae, tokenizer, text_encoder, unet, scheduler, pipe = load_models(device=device)

    # Text prompts whose embeddings define each selectable art style
    style_prompts = {
        "sketch_painting": "a sketch painting, pencil drawing, hand-drawn illustration",
        "oil_painting": "an oil painting, textured canvas, painterly technique",
        "watercolor": "a watercolor painting, fluid, soft edges",
        "digital_art": "digital art, computer generated, precise details",
        "comic_book": "comic book style, ink outlines, cel shading",
    }
    art_concepts = {
        style_name: get_concept_embedding(style_text, tokenizer, text_encoder, device)
        for style_name, style_text in style_prompts.items()
    }
    # "none" selects plain generation without any style concept
    art_concepts["none"] = None

    demo = create_demo()
    demo.launch(debug=True)