Spaces:
Sleeping
Sleeping
| import torch | |
| from diffusers import StableDiffusionPipeline | |
| from torch import autocast | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| import os | |
| from pathlib import Path | |
| import traceback | |
| import glob | |
| from PIL import Image | |
# Reuse the same load_learned_embed_in_clip and Distance_loss functions
def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token=None):
    """Load a textual-inversion embedding and register it with the tokenizer/encoder.

    Parameters
    ----------
    learned_embeds_path : str
        Path to a ``.bin`` file containing ``{placeholder_token: embedding_tensor}``.
    text_encoder :
        Model exposing ``get_input_embeddings()`` / ``resize_token_embeddings()``
        (e.g. transformers CLIPTextModel) whose embedding table receives the vector.
    tokenizer :
        Tokenizer gaining the placeholder token (``add_tokens`` / ``convert_tokens_to_ids``).
    token : str, optional
        Override for the placeholder token; defaults to the token stored in the file.

    Returns
    -------
    str
        The token that was registered.

    Raises
    ------
    ValueError
        If the token already exists in the tokenizer vocabulary.
    """
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
    trained_token = list(loaded_learned_embeds.keys())[0]
    embeds = loaded_learned_embeds[trained_token]

    # Match the embedding width expected by the text encoder; checkpoints trained
    # against a different CLIP variant may have a different dimensionality.
    expected_dim = text_encoder.get_input_embeddings().weight.shape[1]
    current_dim = embeds.shape[0]
    if current_dim != expected_dim:
        print(f"Resizing embedding from {current_dim} to {expected_dim}")
        if current_dim > expected_dim:
            embeds = embeds[:expected_dim]  # truncate extra dimensions
        else:
            # pad the tail with zeros
            embeds = torch.cat([embeds, torch.zeros(expected_dim - current_dim)], dim=0)

    embeds = embeds.unsqueeze(0)  # add batch dimension
    # Cast to the dtype of the encoder's embedding table (e.g. fp16 on GPU).
    dtype = text_encoder.get_input_embeddings().weight.dtype
    embeds = embeds.to(dtype)

    # Add the token to the tokenizer.
    token = token if token is not None else trained_token
    num_added_tokens = tokenizer.add_tokens(token)
    # Bug fix: the original ignored this count, silently overwriting the
    # embedding of a pre-existing token. Fail loudly instead.
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the token {token}. Please pass a different `token`."
        )

    # Grow the embedding table and write the learned vector into the new slot.
    text_encoder.resize_token_embeddings(len(tokenizer))
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds[0]
    return token
def Distance_loss(images):
    """Color-separation loss: mean squared pairwise distance between the
    R, G and B channels of ``images`` (assumed in [-1, 1]), scaled by 100."""
    # Make sure autograd can differentiate through the computation.
    if not images.requires_grad:
        images = images.detach().requires_grad_(True)

    # Map from [-1, 1] to [0, 1] in float32.
    normalized = images.float() / 2 + 0.5

    # Split out the three color channels, keeping the channel axis.
    red, green, blue = (normalized[:, c:c + 1] for c in range(3))

    # Mean squared distance for every channel pair.
    pair_terms = [((a - b) ** 2).mean() for a, b in ((red, green), (red, blue), (green, blue))]
    return sum(pair_terms) * 100  # scale up the loss
class StyleGenerator:
    """Singleton wrapper around a Stable Diffusion pipeline pre-loaded with
    several textual-inversion style embeddings.

    Obtain via ``StyleGenerator.get_instance()`` and call
    :meth:`initialize_model` once before generating.
    """

    _instance = None  # process-wide shared instance

    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating it on first use.

        Bug fix: this method was missing ``@classmethod``, so the call
        ``StyleGenerator.get_instance()`` used elsewhere in this file raised
        ``TypeError`` (no ``cls`` bound).
        """
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        self.pipe = None        # StableDiffusionPipeline, created in initialize_model()
        self.style_tokens = []  # placeholder tokens registered per style
        # File stems of the embedding files (<stem>.bin) expected next to this script.
        self.styles = [
            "ronaldo",
            "canna-lily-flowers102",
            "threestooges",
            "pop_art",
            "bird_style"
        ]
        # Human-readable labels, index-aligned with self.styles.
        self.style_names = [
            "Ronaldo",
            "Canna Lily",
            "Three Stooges",
            "Pop Art",
            "Bird Style"
        ]
        self.is_initialized = False
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.device == "cpu":
            print("NVIDIA GPU not found. Running on CPU (this will be slower)")

    def initialize_model(self):
        """Download the SD 1.5 pipeline and register every style embedding.

        Idempotent: returns immediately when already initialized.
        Raises FileNotFoundError if a style's ``.bin`` file is missing;
        any failure is logged with a traceback and re-raised.
        """
        if self.is_initialized:
            return
        try:
            print("Initializing Stable Diffusion model...")
            model_id = "runwayml/stable-diffusion-v1-5"
            self.pipe = StableDiffusionPipeline.from_pretrained(
                model_id,
                # fp16 halves GPU memory; CPU inference requires fp32.
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                safety_checker=None
            )
            self.pipe = self.pipe.to(self.device)
            # Load style embeddings from the directory containing this script.
            current_dir = Path(__file__).parent
            for style, style_name in zip(self.styles, self.style_names):
                style_path = current_dir / f"{style}.bin"
                if not style_path.exists():
                    raise FileNotFoundError(f"Style embedding not found: {style_path}")
                print(f"Loading style: {style_name}")
                token = load_learned_embed_in_clip(str(style_path), self.pipe.text_encoder, self.pipe.tokenizer)
                self.style_tokens.append(token)
                print(f"β Loaded style: {style_name}")
            self.is_initialized = True
            print(f"Model initialization complete! Using device: {self.device}")
        except Exception as e:
            print(f"Error during initialization: {str(e)}")
            print(traceback.format_exc())
            raise

    def generate_single_style(self, prompt, selected_style):
        """Generate two images for ``prompt``: the plain styled output and the
        same seed re-run with the Distance_loss callback.

        Parameters
        ----------
        prompt : str
            User prompt; the style's placeholder token is appended.
        selected_style : int
            Index into ``self.style_tokens`` (gr.Radio with type="index").

        Returns
        -------
        tuple
            ``(base_image, loss_image)`` as PIL images.
        """
        try:
            # Simplification: the original computed
            # self.style_names.index(self.style_names[selected_style]),
            # which is just the index itself (the names are unique).
            style_idx = selected_style
            styled_prompt = f"{prompt}, {self.style_tokens[style_idx]}"
            # Fixed seed so both runs are reproducible and comparable.
            generator_seed = 42
            torch.manual_seed(generator_seed)
            if self.device == "cuda":
                torch.cuda.manual_seed(generator_seed)
            # Generate base image
            with autocast(self.device):
                base_image = self.pipe(
                    styled_prompt,
                    num_inference_steps=50,
                    guidance_scale=7.5,
                    generator=torch.Generator(self.device).manual_seed(generator_seed)
                ).images[0]
            # Generate the same image with the loss-guided callback
            with autocast(self.device):
                loss_image = self.pipe(
                    styled_prompt,
                    num_inference_steps=50,
                    guidance_scale=7.5,
                    callback=self.callback_fn,
                    callback_steps=5,
                    generator=torch.Generator(self.device).manual_seed(generator_seed)
                ).images[0]
            return base_image, loss_image
        except Exception as e:
            print(f"Error in generate_single_style: {e}")
            raise

    def callback_fn(self, i, t, latents):
        """Denoising-step callback: every 5 steps take one gradient-descent
        step on Distance_loss with respect to the latents.

        NOTE(review): diffusers' legacy ``callback`` ignores the return value,
        so the original version (which only returned modified latents) never
        affected the output. The update is now applied to ``latents`` in
        place, which does reach the pipeline because it passes its live
        latents tensor — confirm against the installed diffusers version.

        NOTE(review): latents have 4 channels, not RGB; Distance_loss reads
        channels 0-2 as if they were colors — presumably an intentional
        approximation.
        """
        if i % 5 == 0:  # Apply loss every 5 steps
            try:
                # Fresh leaf tensor so autograd can differentiate the loss.
                latents_copy = latents.detach().clone().requires_grad_(True)
                loss = Distance_loss(latents_copy)
                grads = torch.autograd.grad(
                    outputs=loss,
                    inputs=latents_copy,
                    allow_unused=True,
                    retain_graph=False
                )[0]
                if grads is not None:
                    # In-place gradient step (0.1 = step size).
                    latents.data.sub_(0.1 * grads.detach())
            except Exception as e:
                print(f"Error in callback: {e}")
        return latents
def generate_single_style(prompt, selected_style):
    """Gradio click handler: produce base and color-enhanced images for one style.

    Returns a 3-element list matching the outputs wiring:
    [error_message update, original image, loss image]. On failure the error
    text is surfaced in the UI and both images are None.
    """
    try:
        style_generator = StyleGenerator.get_instance()
        if not style_generator.is_initialized:
            style_generator.initialize_model()
        base_image, loss_image = style_generator.generate_single_style(prompt, selected_style)
        hide_error = gr.update(visible=False)  # error_message
        return [hide_error, base_image, loss_image]
    except Exception as e:
        print(f"Error in generate_single_style: {e}")
        show_error = gr.update(value=f"Error: {str(e)}", visible=True)  # error_message
        return [show_error, None, None]  # no images on failure
# Add at the start of your script
def debug_image_paths():
    """Print diagnostics about the example-image directory used by the gallery."""
    enhanced_dir = Path("Outputs") / "Color_Enhanced"
    print(f"\nChecking image paths:")
    print(f"Current working directory: {Path.cwd()}")
    print(f"Looking for images in: {enhanced_dir.absolute()}")
    if not enhanced_dir.exists():
        print("\nDirectory not found!")
        return
    print("\nFound files:")
    for webp_file in enhanced_dir.glob("*.webp"):
        print(f"- {webp_file.name}")

# Call this function before creating the interface
debug_image_paths()
# Create a more beautiful interface with custom styling
# NOTE(review): layout reconstructed from a whitespace-mangled source; the
# nesting below follows standard gr.Blocks conventions — verify against the
# deployed app.
with gr.Blocks(css="""
    .gradio-container {
        background-color: #1f2937 !important;
    }
    .dark-theme {
        background-color: #111827;
        border-radius: 10px;
        padding: 20px;
        margin: 10px;
        border: 1px solid #374151;
        color: #f3f4f6;
    }
    /* Enhanced Tab Styling */
    .tabs.svelte-710i53 {
        margin-bottom: 0 !important;
    }
    .tab-nav.svelte-710i53 {
        background: transparent !important;
        border: none !important;
        padding: 12px 24px !important;
        margin: 0 2px !important;
        color: #9CA3AF !important;
        font-weight: 500 !important;
        transition: all 0.2s ease !important;
        border-bottom: 2px solid transparent !important;
    }
    .tab-nav.svelte-710i53.selected {
        background: transparent !important;
        color: #F3F4F6 !important;
        border-bottom: 2px solid #6366F1 !important;
    }
    .tab-nav.svelte-710i53:hover {
        color: #F3F4F6 !important;
        border-bottom: 2px solid #4F46E5 !important;
    }
""") as iface:
    # Header section
    gr.Markdown(
        """
        <div class="dark-theme" style="text-align: center;">
        # π¨ AI Style Transfer Studio
        ### Transform your ideas into artistic masterpieces
        </div>
        """
    )
    # Controls section: prompt, style picker, generate button, status widgets.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## π― Controls")
            prompt = gr.Textbox(
                label="What would you like to create?",
                placeholder="e.g., a soccer player celebrating a goal",
                lines=3
            )
            # type="index" means the handler receives an int position,
            # which generate_single_style expects.
            style_radio = gr.Radio(
                choices=[
                    "Ronaldo Style",
                    "Canna Lily",
                    "Three Stooges",
                    "Pop Art",
                    "Bird Style"
                ],
                label="Choose Your Style",
                value="Ronaldo Style",
                type="index"
            )
            generate_btn = gr.Button(
                "π Generate Artwork",
                variant="primary",
                size="lg"
            )
            # Hidden until an error occurs (toggled by generate_single_style).
            error_message = gr.Markdown(visible=False)
            style_description = gr.Markdown()
    # Generated Images: side-by-side base vs. color-enhanced output.
    with gr.Row():
        with gr.Column():
            original_image = gr.Image(
                label="Original Style",
                show_label=True,
                height=300
            )
        with gr.Column():
            loss_image = gr.Image(
                label="Color Enhanced",
                show_label=True,
                height=300
            )
    # Example Gallery
    gr.Markdown(
        """
        <div class="dark-theme">
        ## π Example Gallery
        Compare original and enhanced versions for each style:
        </div>
        """
    )
    # Example Images: pair up pre-rendered *_example.webp files by style stem.
    with gr.Row():
        try:
            output_dir = Path("Outputs")
            original_dir = output_dir
            enhanced_dir = output_dir / "Color_Enhanced"
            if enhanced_dir.exists():
                # Map style stem (text before "_example") -> file path.
                original_images = {
                    Path(f).stem.split('_example')[0]: f
                    for f in original_dir.glob("*.webp")
                    if '_example' in f.name
                }
                enhanced_images = {
                    Path(f).stem.split('_example')[0]: f
                    for f in enhanced_dir.glob("*.webp")
                    if '_example' in f.name
                }
                styles = [
                    ("ronaldo", "Ronaldo Style"),
                    ("canna_lily", "Canna Lily"),
                    ("three_stooges", "Three Stooges"),
                    ("pop_art", "Pop Art"),
                    ("bird_style", "Bird Style")
                ]
                # Create a grid of all styles; skip styles missing either file.
                for style_key, style_name in styles:
                    if style_key in original_images and style_key in enhanced_images:
                        with gr.Row():
                            gr.Markdown(f"### {style_name}")
                        with gr.Row():
                            with gr.Column(scale=1):
                                gr.Image(
                                    value=str(original_images[style_key]),
                                    label="Original",
                                    show_label=True,
                                    height=180
                                )
                            with gr.Column(scale=1):
                                gr.Image(
                                    value=str(enhanced_images[style_key]),
                                    label="Color Enhanced",
                                    show_label=True,
                                    height=180
                                )
                        # Add a small spacing between styles
                        gr.Markdown("<div style='margin: 10px 0;'></div>")
        except Exception as e:
            # Gallery is best-effort: log and show the error, don't crash the app.
            print(f"Error in example gallery: {e}")
            gr.Markdown(f"Error loading example gallery: {str(e)}")
    # Info section: static style guide and technology blurb.
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                <div class="dark-theme">
                ## π¨ Style Guide
                | Style | Best For |
                |-------|----------|
                | **Ronaldo Style** | Dynamic sports scenes, action shots, celebrations |
                | **Canna Lily** | Natural scenes, floral compositions, garden imagery |
                | **Three Stooges** | Comedy, humor, expressive character portraits |
                | **Pop Art** | Vibrant artwork, bold colors, stylized designs |
                | **Bird Style** | Wildlife, nature scenes, peaceful landscapes |
                *Choose the style that best matches your creative vision*
                </div>
                """
            )
        with gr.Column():
            gr.Markdown(
                """
                <div class="dark-theme">
                ## π Color Enhancement Technology
                Our advanced color processing uses distance loss to enhance your images:
                ### π Color Dynamics
                - **Vibrancy**: Intensifies colors naturally
                - **Contrast**: Improves depth and definition
                - **Balance**: Optimizes color relationships
                ### π¨ Technical Features
                - **Channel Separation**: RGB optimization
                - **Loss Function**: Mathematical color enhancement
                - **Real-time Processing**: Dynamic adjustments
                ### β¨ Benefits
                - Richer, more vivid colors
                - Clearer color boundaries
                - Reduced color muddiness
                - Enhanced artistic impact
                <small>*Our color distance loss technology mathematically optimizes RGB channel relationships*</small>
                </div>
                """
            )
    # Update style description on change
    def update_style_description(style_idx):
        """Return a markdown snippet describing the style at ``style_idx``."""
        descriptions = [
            "Perfect for capturing dynamic sports moments and celebrations",
            "Ideal for creating beautiful natural and floral compositions",
            "Great for adding humor and expressiveness to your scenes",
            "Transform your ideas into vibrant pop art masterpieces",
            "Specialized in capturing the beauty of nature and wildlife"
        ]
        styles = ["Ronaldo Style", "Canna Lily", "Three Stooges", "Pop Art", "Bird Style"]
        return f"### Selected Style: {styles[style_idx]}\n{descriptions[style_idx]}"
    # Wire events: radio selection updates the description text,
    # button click runs generation.
    style_radio.change(
        fn=update_style_description,
        inputs=style_radio,
        outputs=style_description
    )
    generate_btn.click(
        fn=generate_single_style,
        inputs=[prompt, style_radio],
        outputs=[error_message, original_image, loss_image]
    )
# Launch the app
if __name__ == "__main__":
    iface.launch(
        share=True,
        show_error=True
    )