HSinghHuggingFace commited on
Commit
2c843c7
Β·
1 Parent(s): eff7e87

stable diffusion image generator

Browse files
README.md CHANGED
@@ -5,10 +5,94 @@ colorFrom: blue
5
  colorTo: purple
6
  sdk: streamlit
7
  sdk_version: 1.42.2
8
- app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
- short_description: Stable-Diffusion-Image-Generator
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  colorTo: purple
6
  sdk: streamlit
7
  sdk_version: 1.42.2
8
+ app_file: src/app.py
9
  pinned: false
10
  license: apache-2.0
11
+ short_description: Transform your ideas into artistic masterpieces using Stable Diffusion with custom style embeddings
12
  ---
13
 
14
+ # 🎨 AI Style Transfer Studio
15
+
16
+ Transform your ideas into artistic masterpieces using Stable Diffusion with custom style embeddings.
17
+
18
+ ## πŸš€ Features
19
+
20
+ - Multiple pre-trained style embeddings (Dhoni, Mickey Mouse, Balloon, Lion King, Rose Flower)
21
+ - Advanced color enhancement technology
22
+ - User-friendly Streamlit interface
23
+ - Real-time image generation
24
+ - Example gallery with style comparisons
25
+
26
+ ## πŸ› οΈ Local Setup
27
+
28
+ 1. Clone the repository:
29
+ ```bash
30
+ git clone https://github.com/yourusername/stable-diffusion-image-generator.git
31
+ cd stable-diffusion-image-generator
32
+ ```
33
+
34
+ 2. Create and activate a virtual environment (recommended):
35
+ ```bash
36
+ python -m venv venv
37
+ # On Windows
38
+ venv\Scripts\activate
39
+ # On Unix or MacOS
40
+ source venv/bin/activate
41
+ ```
42
+
43
+ 3. Install dependencies:
44
+ ```bash
45
+ pip install -r requirements.txt
46
+ ```
47
+
48
+ 4. Run the Streamlit app:
49
+ ```bash
50
+ streamlit run src/app.py
51
+ ```
52
+
53
+ The app will open in your default web browser at `http://localhost:8501`
54
+
55
+ ## 🌐 Deploying to Hugging Face Spaces
56
+
57
+ 1. Create a new Space on Hugging Face:
58
+ - Go to https://huggingface.co/spaces
59
+ - Click "Create new Space"
60
+ - Choose "Streamlit" as the SDK
61
+ - Set the Space name and visibility
62
+
63
+ 2. Push your code to Hugging Face:
64
+ ```bash
65
+ git add .
66
+ git commit -m "Initial commit"
67
+ git remote add space https://huggingface.co/spaces/yourusername/your-space-name
68
+ git push space main
69
+ ```
70
+
71
+ 3. The deployment will start automatically. Monitor the build logs on your Space's page.
72
+
73
+ ## 🎯 Usage
74
+
75
+ 1. Enter your creative prompt in the text area
76
+ 2. Select a style from the available options
77
+ 3. Click "Generate Artwork"
78
+ 4. View both the original and color-enhanced versions of your creation
79
+
80
+ ## πŸ“ Requirements
81
+
82
+ - Python 3.8+
83
+ - CUDA-capable GPU (recommended)
84
+ - 8GB+ RAM
85
+
86
+ ## πŸ”‘ Environment Variables
87
+
88
+ No additional environment variables are required for basic usage.
89
+
90
+ ## πŸ“„ License
91
+
92
+ This project is licensed under the Apache 2.0 License.
93
+
94
+ ## πŸ™ Acknowledgments
95
+
96
+ - [Stable Diffusion](https://github.com/CompVis/stable-diffusion) for the base model
97
+ - [Hugging Face](https://huggingface.co/) for model hosting and Spaces
98
+ - [Streamlit](https://streamlit.io/) for the web interface
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ diffusers>=0.19.0
3
+ transformers>=4.30.0
4
+ accelerator>=0.21.0
5
+ streamlit>=1.24.0
6
+ Pillow>=9.5.0
7
+ numpy>=1.24.0
8
+ pathlib>=1.0.1
9
+ tqdm>=4.65.0
10
+ huggingface-hub>=0.16.0
src/app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from utils.style_generator import StyleTransfer
3
+ from utils.ui_components import (
4
+ setup_page_config,
5
+ apply_custom_css,
6
+ render_header,
7
+ render_controls,
8
+ render_image_columns,
9
+ render_example_gallery,
10
+ render_info_sections
11
+ )
12
+
13
+ # Initialize the application
14
+ setup_page_config()
15
+ apply_custom_css()
16
+ render_header()
17
+
18
+ # Initialize session state
19
+ if 'generator' not in st.session_state:
20
+ st.session_state.generator = StyleTransfer.get_instance()
21
+ if not st.session_state.generator.is_initialized:
22
+ st.session_state.generator.initialize_pipeline()
23
+
24
+ # Render controls and handle user input
25
+ prompt, selected_style = render_controls(st.session_state.generator.style_names)
26
+
27
+ if st.sidebar.button("πŸš€ Generate Artwork", use_container_width=True):
28
+ if prompt:
29
+ try:
30
+ with st.spinner("Generating your artwork..."):
31
+ base_image, enhanced_image = st.session_state.generator.generate_artwork(prompt, selected_style)
32
+
33
+ # Store images in session state
34
+ st.session_state.base_image = base_image
35
+ st.session_state.enhanced_image = enhanced_image
36
+ except Exception as e:
37
+ st.error(f"Error: {str(e)}")
38
+ else:
39
+ st.warning("Please enter a prompt first!")
40
+
41
+ # Display generated images
42
+ render_image_columns(
43
+ base_image=st.session_state.get('base_image'),
44
+ enhanced_image=st.session_state.get('enhanced_image')
45
+ )
46
+
47
+ # Render example gallery and information sections
48
+ render_example_gallery()
49
+ render_info_sections()
src/utils/style_generator.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import StableDiffusionPipeline
3
+ from torch import autocast
4
+ from pathlib import Path
5
+ import traceback
6
+
7
+ class StyleTransfer:
8
+ _instance = None
9
+
10
+ @classmethod
11
+ def get_instance(cls):
12
+ if cls._instance is None:
13
+ cls._instance = cls()
14
+ return cls._instance
15
+
16
+ def __init__(self):
17
+ self.pipeline = None
18
+ self.style_tokens = []
19
+ self.styles = [
20
+ "dhoni",
21
+ "mickey_mouse",
22
+ "balloon",
23
+ "lion_king",
24
+ "rose_flower"
25
+ ]
26
+ self.style_names = [
27
+ "Dhoni Style",
28
+ "Mickey Mouse Style",
29
+ "Balloon Style",
30
+ "Lion King Style",
31
+ "Rose Flower Style"
32
+ ]
33
+ self.is_initialized = False
34
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
35
+ if self.device == "cpu":
36
+ print("NVIDIA GPU not found. Running on CPU (this will be slower)")
37
+
38
+ def initialize_pipeline(self):
39
+ if self.is_initialized:
40
+ return
41
+
42
+ try:
43
+ print("Initializing Stable Diffusion model...")
44
+ model_id = "runwayml/stable-diffusion-v1-5"
45
+ self.pipeline = StableDiffusionPipeline.from_pretrained(
46
+ model_id,
47
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
48
+ safety_checker=None
49
+ )
50
+ self.pipeline = self.pipeline.to(self.device)
51
+
52
+ # Load style embeddings from current directory
53
+ current_dir = Path(__file__).parent.parent
54
+
55
+ for style, style_name in zip(self.styles, self.style_names):
56
+ style_path = current_dir / f"{style}.bin"
57
+ if not style_path.exists():
58
+ raise FileNotFoundError(f"Style embedding not found: {style_path}")
59
+
60
+ print(f"Loading style: {style_name}")
61
+ token = self._load_style_embedding(str(style_path))
62
+ self.style_tokens.append(token)
63
+ print(f"βœ“ Loaded style: {style_name}")
64
+
65
+ self.is_initialized = True
66
+ print(f"Model initialization complete! Using device: {self.device}")
67
+
68
+ except Exception as e:
69
+ print(f"Error during initialization: {str(e)}")
70
+ print(traceback.format_exc())
71
+ raise
72
+
73
+ def _load_style_embedding(self, embedding_path, token=None):
74
+ loaded_embeds = torch.load(embedding_path, map_location="cpu")
75
+ trained_token = list(loaded_embeds.keys())[0]
76
+ embeds = loaded_embeds[trained_token]
77
+
78
+ # Get the expected dimension from the text encoder
79
+ expected_dim = self.pipeline.text_encoder.get_input_embeddings().weight.shape[1]
80
+ current_dim = embeds.shape[0]
81
+
82
+ # Resize embeddings if dimensions don't match
83
+ if current_dim != expected_dim:
84
+ print(f"Resizing embedding from {current_dim} to {expected_dim}")
85
+ if current_dim > expected_dim:
86
+ embeds = embeds[:expected_dim]
87
+ else:
88
+ embeds = torch.cat([embeds, torch.zeros(expected_dim - current_dim)], dim=0)
89
+
90
+ # Reshape to match expected dimensions
91
+ embeds = embeds.unsqueeze(0) # Add batch dimension
92
+
93
+ # Cast to dtype of text_encoder
94
+ dtype = self.pipeline.text_encoder.get_input_embeddings().weight.dtype
95
+ embeds = embeds.to(dtype)
96
+
97
+ # Add the token in tokenizer
98
+ token = token if token is not None else trained_token
99
+ self.pipeline.tokenizer.add_tokens(token)
100
+
101
+ # Resize the token embeddings
102
+ self.pipeline.text_encoder.resize_token_embeddings(len(self.pipeline.tokenizer))
103
+
104
+ # Get the id for the token and assign the embeds
105
+ token_id = self.pipeline.tokenizer.convert_tokens_to_ids(token)
106
+ self.pipeline.text_encoder.get_input_embeddings().weight.data[token_id] = embeds[0]
107
+ return token
108
+
109
+ def generate_artwork(self, prompt, selected_style):
110
+ try:
111
+ # Find the index of the selected style
112
+ style_idx = self.style_names.index(selected_style)
113
+
114
+ # Generate single image with selected style
115
+ styled_prompt = f"{prompt}, {self.style_tokens[style_idx]}"
116
+
117
+ # Set seed for reproducibility
118
+ generator_seed = 42
119
+ torch.manual_seed(generator_seed)
120
+ if self.device == "cuda":
121
+ torch.cuda.manual_seed(generator_seed)
122
+
123
+ # Generate base image
124
+ with autocast(self.device):
125
+ base_image = self.pipeline(
126
+ styled_prompt,
127
+ num_inference_steps=50,
128
+ guidance_scale=7.5,
129
+ generator=torch.Generator(self.device).manual_seed(generator_seed)
130
+ ).images[0]
131
+
132
+ # Generate same image with color enhancement
133
+ with autocast(self.device):
134
+ enhanced_image = self.pipeline(
135
+ styled_prompt,
136
+ num_inference_steps=50,
137
+ guidance_scale=7.5,
138
+ callback=self._enhance_colors,
139
+ callback_steps=5,
140
+ generator=torch.Generator(self.device).manual_seed(generator_seed)
141
+ ).images[0]
142
+
143
+ return base_image, enhanced_image
144
+
145
+ except Exception as e:
146
+ print(f"Error in generate_artwork: {e}")
147
+ raise
148
+
149
+ def _enhance_colors(self, i, t, latents):
150
+ if i % 5 == 0: # Apply enhancement every 5 steps
151
+ try:
152
+ # Create a copy that requires gradients
153
+ latents_copy = latents.detach().clone()
154
+ latents_copy.requires_grad_(True)
155
+
156
+ # Compute color distance loss
157
+ loss = self._calculate_color_distance(latents_copy)
158
+
159
+ # Compute gradients
160
+ if loss.requires_grad:
161
+ grads = torch.autograd.grad(
162
+ outputs=loss,
163
+ inputs=latents_copy,
164
+ allow_unused=True,
165
+ retain_graph=False
166
+ )[0]
167
+
168
+ if grads is not None:
169
+ # Apply gradients to original latents
170
+ return latents - 0.1 * grads.detach()
171
+
172
+ except Exception as e:
173
+ print(f"Error in color enhancement: {e}")
174
+
175
+ return latents
176
+
177
+ def _calculate_color_distance(self, images):
178
+ # Ensure we're working with gradients
179
+ if not images.requires_grad:
180
+ images = images.detach().requires_grad_(True)
181
+
182
+ # Convert to float32 and normalize
183
+ images = images.float() / 2 + 0.5
184
+
185
+ # Get RGB channels
186
+ red = images[:,0:1]
187
+ green = images[:,1:2]
188
+ blue = images[:,2:3]
189
+
190
+ # Calculate color distances using L2 norm
191
+ rg_distance = ((red - green) ** 2).mean()
192
+ rb_distance = ((red - blue) ** 2).mean()
193
+ gb_distance = ((green - blue) ** 2).mean()
194
+
195
+ return (rg_distance + rb_distance + gb_distance) * 100 # Scale up the loss
src/utils/ui_components.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from pathlib import Path
3
+
4
+ def setup_page_config():
5
+ st.set_page_config(
6
+ page_title="AI Style Transfer Studio",
7
+ page_icon="🎨",
8
+ layout="wide"
9
+ )
10
+
11
+ def apply_custom_css():
12
+ st.markdown("""
13
+ <style>
14
+ .stApp {
15
+ background-color: #1f2937;
16
+ }
17
+ .stMarkdown {
18
+ color: #f3f4f6;
19
+ }
20
+ .stButton > button {
21
+ background-color: #6366F1;
22
+ color: white;
23
+ }
24
+ .stButton > button:hover {
25
+ background-color: #4F46E5;
26
+ }
27
+ .dark-theme {
28
+ background-color: #111827;
29
+ border-radius: 10px;
30
+ padding: 20px;
31
+ margin: 10px 0;
32
+ border: 1px solid #374151;
33
+ }
34
+ </style>
35
+ """, unsafe_allow_html=True)
36
+
37
+ def render_header():
38
+ st.markdown("""
39
+ <div class="dark-theme" style="text-align: center;">
40
+ <h1>🎨 AI Style Transfer Studio</h1>
41
+ <h3>Transform your ideas into artistic masterpieces</h3>
42
+ </div>
43
+ """, unsafe_allow_html=True)
44
+
45
+ def render_controls(style_names):
46
+ with st.sidebar:
47
+ st.markdown("## 🎯 Controls")
48
+
49
+ prompt = st.text_area(
50
+ "What would you like to create?",
51
+ placeholder="e.g., a soccer player celebrating a goal",
52
+ height=100
53
+ )
54
+
55
+ selected_style = st.radio(
56
+ "Choose Your Style",
57
+ style_names,
58
+ index=0
59
+ )
60
+
61
+ return prompt, selected_style
62
+
63
+ def render_image_columns(base_image=None, enhanced_image=None):
64
+ col1, col2 = st.columns(2)
65
+
66
+ with col1:
67
+ st.markdown("### Original Style")
68
+ if base_image:
69
+ st.image(base_image, use_column_width=True)
70
+
71
+ with col2:
72
+ st.markdown("### Color Enhanced")
73
+ if enhanced_image:
74
+ st.image(enhanced_image, use_column_width=True)
75
+
76
+ def render_example_gallery():
77
+ st.markdown("""
78
+ <div class="dark-theme">
79
+ <h2>πŸŽ† Example Gallery</h2>
80
+ <p>Compare original and enhanced versions for each style:</p>
81
+ </div>
82
+ """, unsafe_allow_html=True)
83
+
84
+ try:
85
+ output_dir = Path("Outputs")
86
+ original_dir = output_dir
87
+ enhanced_dir = output_dir / "Color_Enhanced"
88
+
89
+ if enhanced_dir.exists():
90
+ original_images = {
91
+ Path(f).stem.split('_example')[0]: f
92
+ for f in original_dir.glob("*.webp")
93
+ if '_example' in f.name
94
+ }
95
+ enhanced_images = {
96
+ Path(f).stem.split('_example')[0]: f
97
+ for f in enhanced_dir.glob("*.webp")
98
+ if '_example' in f.name
99
+ }
100
+
101
+ styles = [
102
+ ("ronaldo", "Ronaldo Style"),
103
+ ("canna_lily", "Canna Lily"),
104
+ ("three_stooges", "Three Stooges"),
105
+ ("pop_art", "Pop Art"),
106
+ ("bird_style", "Bird Style")
107
+ ]
108
+
109
+ for style_key, style_name in styles:
110
+ if style_key in original_images and style_key in enhanced_images:
111
+ st.markdown(f"### {style_name}")
112
+ col1, col2 = st.columns(2)
113
+
114
+ with col1:
115
+ st.image(
116
+ str(original_images[style_key]),
117
+ caption="Original",
118
+ use_column_width=True
119
+ )
120
+ with col2:
121
+ st.image(
122
+ str(enhanced_images[style_key]),
123
+ caption="Color Enhanced",
124
+ use_column_width=True
125
+ )
126
+ st.markdown("<hr>", unsafe_allow_html=True)
127
+
128
+ except Exception as e:
129
+ st.error(f"Error loading example gallery: {str(e)}")
130
+
131
+ def render_info_sections():
132
+ col1, col2 = st.columns(2)
133
+
134
+ with col1:
135
+ st.markdown("""
136
+ <div class="dark-theme">
137
+ <h2>🎨 Style Guide</h2>
138
+ <table>
139
+ <tr>
140
+ <th>Style</th>
141
+ <th>Best For</th>
142
+ </tr>
143
+ <tr>
144
+ <td><strong>Dhoni Style</strong></td>
145
+ <td>Cricket scenes, sports action, victory celebrations</td>
146
+ </tr>
147
+ <tr>
148
+ <td><strong>Mickey Mouse Style</strong></td>
149
+ <td>Cartoon characters, playful scenes, whimsical art</td>
150
+ </tr>
151
+ <tr>
152
+ <td><strong>Balloon Style</strong></td>
153
+ <td>Festive scenes, colorful celebrations, light and airy compositions</td>
154
+ </tr>
155
+ <tr>
156
+ <td><strong>Lion King Style</strong></td>
157
+ <td>Animal portraits, majestic scenes, dramatic landscapes</td>
158
+ </tr>
159
+ <tr>
160
+ <td><strong>Rose Flower Style</strong></td>
161
+ <td>Floral art, romantic scenes, delicate compositions</td>
162
+ </tr>
163
+ </table>
164
+ <em>Choose the style that best matches your creative vision</em>
165
+ </div>
166
+ """, unsafe_allow_html=True)
167
+
168
+ with col2:
169
+ st.markdown("""
170
+ <div class="dark-theme">
171
+ <h2>πŸ” Color Enhancement Technology</h2>
172
+ <p>Our advanced color processing uses distance loss to maximize the distinction between color channels,
173
+ resulting in more vibrant and visually striking images. This technique helps to:</p>
174
+ <ul>
175
+ <li>Enhance color separation</li>
176
+ <li>Improve visual contrast</li>
177
+ <li>Create more dynamic compositions</li>
178
+ <li>Preserve artistic style while boosting vibrancy</li>
179
+ </ul>
180
+ </div>
181
+ """, unsafe_allow_html=True)
style_embeddings/balloon.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5916ba4a9c011cb7f04df4501b20307b05b115c1aafacd538439db055790e6e1
3
+ size 151785628
style_embeddings/dhoni.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb3894eb1e73b4ee7b22806c4bc74dd1177188e3282d8fe7968aa281de8b2119
3
+ size 151785554
style_embeddings/lion_king.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a7a97e656141710692e65655a6992ddfa783c08f3e80584c4ed4933a8a3471b
3
+ size 151785638
style_embeddings/mickey_mouse.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1b576e7a808d880786b0c155e249c18473512ae3c16a6fe23419f586247c2406
3
+ size 151785717
style_embeddings/rose_flower.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:88020c89e2fb6ee4e1d89eb55f08ee9762850f46c3b7d6f19dc665ba961aad6c
3
+ size 151785712