Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel | |
| from transformers import CLIPTextModel, CLIPTokenizer | |
| from tqdm.auto import tqdm | |
| import os | |
# Prefer the GPU when PyTorch can see one; otherwise everything runs on the CPU.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# Stable Diffusion v1.4 components: VAE + UNet from the SD repo, and the
# CLIP ViT-L/14 tokenizer/text encoder from OpenAI's checkpoint.
print("Loading models...")
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

# Move the three networks onto the chosen device (in the order vae, encoder, unet).
vae, text_encoder, unet = (model.to(torch_device) for model in (vae, text_encoder, unet))

# LMS sampler configured with the scaled-linear beta schedule over 1000 training steps.
scheduler = LMSDiscreteScheduler(
    num_train_timesteps=1000,
    beta_schedule="scaled_linear",
    beta_start=0.00085,
    beta_end=0.012,
)
# UI style name -> (learned-embedding file, placeholder token to put in prompts).
# Only 768-dimensional embeddings are listed, the width SD 1.4's CLIP encoder uses.
STYLE_EMBEDDINGS = {
    display_name: (f"learned_embeds/{file_stem}-learned_embeds.bin", placeholder)
    for display_name, file_stem, placeholder in (
        ("Bird Style", "bird", "<birb-style>"),
        ("Shigure UI Art", "shigure-ui", "<shigure-ui>"),
        ("Takuji Kawano Art", "takuji-kawano", "<takuji-kawano>"),
    )
}

# Placeholder tokens already injected into the tokenizer/text encoder this process.
loaded_tokens = set()
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token):
    """Register a textual-inversion embedding as a token in the CLIP text encoder.

    Each token is loaded at most once per process. If ``token`` is already in
    the module-level ``loaded_tokens`` set, the call returns immediately
    without reading the file or touching the tokenizer/encoder.

    Args:
        learned_embeds_path: Path to a ``torch.save``-d file containing either
            a tensor or a dict mapping token strings to tensors.
        text_encoder: Text model whose input-embedding matrix receives the vector.
        tokenizer: Tokenizer that ``token`` is added to.
        token: Placeholder string (e.g. ``"<birb-style>"``) to register.

    Returns:
        ``token``, unchanged.

    Raises:
        ValueError: If the file holds an empty dict or a non-tensor entry, or if
            the vector's width does not match the encoder's embedding width.
    """
    if token in loaded_tokens:
        return token

    # NOTE: torch.load unpickles the file; only use it on trusted, bundled files.
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

    if isinstance(loaded_learned_embeds, dict):
        if not loaded_learned_embeds:
            raise ValueError(f"No embeddings found in {learned_embeds_path}")
        if token in loaded_learned_embeds:
            trained_token = loaded_learned_embeds[token]
        else:
            # The file was saved under a different placeholder name; use its first entry.
            trained_token = next(iter(loaded_learned_embeds.values()))
    else:
        trained_token = loaded_learned_embeds

    if not torch.is_tensor(trained_token):
        raise ValueError(
            f"Unsupported embedding format in {learned_embeds_path}: "
            f"expected a tensor, got {type(trained_token).__name__}"
        )

    hidden_size = text_encoder.get_input_embeddings().weight.shape[1]
    # Accept a single vector stored with a leading batch axis, i.e. shape (1, dim).
    if trained_token.dim() == 2 and trained_token.shape[0] == 1:
        trained_token = trained_token[0]
    if trained_token.dim() != 1 or trained_token.shape[0] != hidden_size:
        raise ValueError(
            f"Embedding dimension mismatch: {list(trained_token.shape)} vs "
            f"{hidden_size}. "
            f"This embedding is not compatible with SD 1.4."
        )

    # Grow the vocabulary (and the embedding matrix) only if the token is new.
    num_added_tokens = tokenizer.add_tokens(token)
    if num_added_tokens > 0:
        text_encoder.resize_token_embeddings(len(tokenizer))

    token_id = tokenizer.convert_tokens_to_ids(token)
    # Re-fetch after a possible resize, and match the weight's device/dtype explicitly.
    weight = text_encoder.get_input_embeddings().weight
    weight.data[token_id] = trained_token.to(device=weight.device, dtype=weight.dtype)

    loaded_tokens.add(token)
    return token
def neon_cyberpunk_loss(img):
    """Score an image batch against a neon-cyberpunk look (lower is "more neon").

    Five weighted terms are summed:
      * neon colour presence: cyan (G+B over R), magenta (R+B over G), purple (R*B)
      * saturation: per-pixel standard deviation across the R/G/B channels
      * contrast: standard deviation of the whole tensor
      * highlights: mean of pixels whose channel-mean exceeds 0.5 (rewarded)
      * shadows: mean of pixels whose channel-mean is below 0.5 (penalised)

    Args:
        img: Tensor indexed as ``img[:, channel, ...]``, with channels 0/1/2 read
            as R/G/B. Values are presumably near [0, 1]; the caller maps VAE
            output with ``/ 2 + 0.5``. They are not clamped here.

    Returns:
        A scalar tensor; more negative means closer to the target aesthetic.
    """
    red, green, blue = img[:, 0], img[:, 1], img[:, 2]

    # Colour terms; each is negated so that minimising the loss increases them.
    cyan = (green + blue - red).clamp(0, 1).mean()
    magenta = (red + blue - green).clamp(0, 1).mean()
    purple = (red * blue).mean()
    neon_term = -(cyan + magenta + purple) / 3

    saturation_term = -torch.stack([red, green, blue], dim=1).std(dim=1).mean()
    contrast_term = -img.std()

    # The brightness masks come from comparisons, so they carry no gradient.
    intensity = img.mean(dim=1, keepdim=True)
    highlight_term = -(img * (intensity > 0.5).float()).mean()
    shadow_term = (img * (intensity < 0.5).float()).mean()

    weighted_terms = (
        (2.0, neon_term),        # strongest push: neon hues
        (1.5, saturation_term),  # vivid colour
        (1.0, contrast_term),    # overall contrast
        (0.8, highlight_term),   # glowing highlights
        (0.5, shadow_term),      # dark backgrounds
    )
    return sum(weight * term for weight, term in weighted_terms)
# Latents are divided by this factor before being passed to the VAE decoder.
_VAE_SCALING_FACTOR = 0.18215


def _encode_prompt(prompt):
    """Return ``cat([unconditional, conditional])`` text embeddings for classifier-free guidance."""
    text_input = tokenizer(
        [prompt],
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    uncond_input = tokenizer(
        [""],
        padding="max_length",
        max_length=text_input.input_ids.shape[-1],
        return_tensors="pt",
    )
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
    # Unconditional first: the denoising loop splits the UNet output with .chunk(2) in this order.
    return torch.cat([uncond_embeddings, text_embeddings])


def _guide_noise_pred(latents, noise_pred, sigma, loss_scale):
    """Adjust ``noise_pred`` so the implied clean image has a lower neon_cyberpunk_loss."""
    # With the sigma-scaled (LMS) parameterisation, x0 = x_t - sigma * eps.
    latents_x0 = (latents - sigma * noise_pred).detach().requires_grad_(True)
    # Decode with autograd enabled so the loss can be differentiated w.r.t. x0.
    denoised_images = vae.decode((1 / _VAE_SCALING_FACTOR) * latents_x0).sample / 2 + 0.5
    loss = neon_cyberpunk_loss(denoised_images) * loss_scale
    cond_grad = torch.autograd.grad(loss, latents_x0)[0]
    # eps' = eps + sigma * grad  =>  x0' = x0 - sigma**2 * grad: a descent step on the loss.
    # (Subtracting here, as before, moved x0 *up* the gradient and amplified the loss.)
    return noise_pred + sigma * cond_grad


def _latents_to_pil(latents):
    """Decode latents with the VAE and return the first image of the batch as a PIL image."""
    with torch.no_grad():
        image = vae.decode((1 / _VAE_SCALING_FACTOR) * latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    image = (image * 255).round().astype("uint8")
    return Image.fromarray(image[0])


def generate_image(
    prompt,
    style_name,
    seed,
    apply_loss=False,
    loss_scale=200,
    height=512,
    width=512,
    num_inference_steps=50,
    guidance_scale=8
):
    """Generate one image, optionally steered toward the neon-cyberpunk look.

    Args:
        prompt: Text prompt. If ``style_name`` is a key of ``STYLE_EMBEDDINGS``
            and its embedding file exists, " in the style of <token>" is appended.
        style_name: Key into ``STYLE_EMBEDDINGS``; other names run unstyled.
        seed: Int-convertible seed for the initial latent noise.
        apply_loss: If True, apply neon_cyberpunk_loss guidance on every 5th step.
        loss_scale: Multiplier applied to the guidance loss.
        height: Output height in pixels; the latent grid is ``height // 8``.
        width: Output width in pixels; the latent grid is ``width // 8``.
        num_inference_steps: Number of scheduler steps.
        guidance_scale: Classifier-free guidance weight.

    Returns:
        A PIL ``Image`` for the first (only) image in the batch.
    """
    if style_name in STYLE_EMBEDDINGS:
        embed_path, token_name = STYLE_EMBEDDINGS[style_name]
        if os.path.exists(embed_path):
            token = load_learned_embed_in_clip(embed_path, text_encoder, tokenizer, token=token_name)
            prompt = f"{prompt} in the style of {token}"
        else:
            print(f"Warning: embedding file not found: {embed_path}; generating without a style token")

    # A dedicated CPU generator gives the same noise as torch.manual_seed(seed)
    # without reseeding PyTorch's global RNG on every request.
    generator = torch.Generator().manual_seed(int(seed))

    text_embeddings = _encode_prompt(prompt)

    latents = torch.randn(
        (1, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to(torch_device)

    scheduler.set_timesteps(num_inference_steps)
    latents = latents * scheduler.init_noise_sigma

    for i, t in enumerate(tqdm(scheduler.timesteps)):
        # Batch the unconditional and conditional passes together.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # The VAE decode + backward is expensive, so guidance runs only every 5th step.
        if apply_loss and i % 5 == 0:
            noise_pred = _guide_noise_pred(latents, noise_pred, scheduler.sigmas[i], loss_scale)

        latents = scheduler.step(noise_pred, t, latents).prev_sample

    return _latents_to_pil(latents)
def generate_comparison(prompt, style_name, seed):
    """Render one prompt/style/seed twice, first unguided and then neon-guided.

    Returns:
        ``(img_without, img_with)``: the ``generate_image`` result with
        ``apply_loss=False`` followed by the one with ``apply_loss=True,
        loss_scale=200``.
    """
    common = dict(prompt=prompt, style_name=style_name, seed=seed)
    baseline = generate_image(**common, apply_loss=False)
    neon = generate_image(**common, apply_loss=True, loss_scale=200)
    return baseline, neon
def generate_all_styles(prompt, seed1, seed2, seed3):
    """Run ``generate_comparison`` for each configured style, pairing styles with seeds in order.

    Returns:
        A flat list ``[style1_plain, style1_neon, style2_plain, ...]``, in the
        order the Gradio output components expect. ``zip`` stops at whichever
        of the style list or the three seeds is shorter.
    """
    style_seeds = zip(STYLE_EMBEDDINGS, (seed1, seed2, seed3))
    return [
        picture
        for style, seed in style_seeds
        for picture in generate_comparison(prompt, style, seed)
    ]
# Create Gradio interface.
# The Markdown texts keep blank lines between blocks. Without them, a paragraph
# that directly follows a list item renders inside that bullet, and consecutive
# lines merge into a single paragraph.
_INTRO_MARKDOWN = """
# 🌆 Stable Diffusion with Neon Cyberpunk Loss

This app demonstrates textual inversion with 3 different learned styles and applies a custom **Neon Cyberpunk Loss**
that transforms images into vibrant cyberpunk scenes with neon colors (cyan, magenta, purple), high saturation,
and dramatic contrast between dark backgrounds and bright neon highlights.

## Features:

- **3 Different Styles**: Bird Style, Shigure UI Art, Takuji Kawano Art
- **Custom Neon Cyberpunk Loss**: Creates futuristic neon aesthetic with vibrant colors
- **Seed Control**: Different seeds for reproducible results

⏱️ **Note**: This process can take up to 10 minutes to run. Perfect time to grab a coffee! ☕
"""

_ABOUT_MARKDOWN = """
---

### About the Neon Cyberpunk Loss

The **Neon Cyberpunk Loss** is a creative guidance technique that transforms images into futuristic cyberpunk scenes:

- **Neon Colors**: Maximizes cyan, magenta, and purple tones for that distinctive neon glow
- **High Saturation**: Boosts color vibrancy to create electric, vivid scenes
- **Dramatic Contrast**: Creates dark backgrounds with bright neon highlights
- **Glow Effect**: Enhances brightness in highlight areas while darkening shadows

This demonstrates how custom loss functions can dramatically alter the aesthetic and mood of generated images,
going far beyond simple color adjustments to create an entirely different visual style.

**Seeds Used**: Different seeds ensure variety across the three styles while maintaining reproducibility.

### Assignment Info

- **Task**: Demonstrate 3 different styles with creative custom loss (not standard RGB)
- **Implementation**: Uses textual inversion embeddings + custom neon cyberpunk loss during inference
"""

with gr.Blocks(title="Stable Diffusion with Neon Cyberpunk Loss", theme=gr.themes.Soft()) as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    # Inputs: one prompt shared by all styles, plus one seed per style.
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here...",
                value="A beautiful landscape with mountains"
            )
            with gr.Row():
                seed1 = gr.Number(label="Seed for Style 1 (Bird Style)", value=42, precision=0)
                seed2 = gr.Number(label="Seed for Style 2 (Shigure UI)", value=123, precision=0)
                seed3 = gr.Number(label="Seed for Style 3 (Takuji Kawano)", value=456, precision=0)
            generate_btn = gr.Button("🎨 Generate All Comparisons", variant="primary", size="lg")

    # Outputs: an (original, neon-guided) image pair per style.
    gr.Markdown("### Results: Left = Original | Right = With Neon Cyberpunk Loss")

    with gr.Row():
        gr.Markdown("#### Style 1: Bird Style")
    with gr.Row():
        out1_without = gr.Image(label="Original")
        out1_with = gr.Image(label="Neon Cyberpunk")

    with gr.Row():
        gr.Markdown("#### Style 2: Shigure UI Art")
    with gr.Row():
        out2_without = gr.Image(label="Original")
        out2_with = gr.Image(label="Neon Cyberpunk")

    with gr.Row():
        gr.Markdown("#### Style 3: Takuji Kawano Art")
    with gr.Row():
        out3_without = gr.Image(label="Original")
        out3_with = gr.Image(label="Neon Cyberpunk")

    # The output order must match the flat list returned by generate_all_styles.
    generate_btn.click(
        fn=generate_all_styles,
        inputs=[prompt_input, seed1, seed2, seed3],
        outputs=[
            out1_without, out1_with,
            out2_without, out2_with,
            out3_without, out3_with
        ]
    )

    gr.Markdown(_ABOUT_MARKDOWN)
if __name__ == "__main__":
    # Seed PyTorch's global RNG once at startup, then serve the UI on all interfaces.
    torch.manual_seed(1)
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860, "share": False}
    demo.launch(**launch_options)