yashwantram commited on
Commit
4ce72cc
·
verified ·
1 Parent(s): 83e6a77

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +374 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
6
+ from transformers import CLIPTextModel, CLIPTokenizer
7
+ from tqdm.auto import tqdm
8
+ import os
9
+
10
+ # Set device
11
+ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
12
+
13
+ # Load models
14
+ print("Loading models...")
15
+ vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
16
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
17
+ text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
18
+ unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
19
+
20
+ vae = vae.to(torch_device)
21
+ text_encoder = text_encoder.to(torch_device)
22
+ unet = unet.to(torch_device)
23
+
24
+ # Scheduler
25
+ scheduler = LMSDiscreteScheduler(
26
+ beta_start=0.00085,
27
+ beta_end=0.012,
28
+ beta_schedule="scaled_linear",
29
+ num_train_timesteps=1000
30
+ )
31
+
32
+ # Style embeddings mapping (only 768-dimensional embeddings compatible with SD 1.4)
33
+ STYLE_EMBEDDINGS = {
34
+ "Bird Style": ("learned_embeds/bird-learned_embeds.bin", "<birb-style>"),
35
+ "Shigure UI Art": ("learned_embeds/shigure-ui-learned_embeds.bin", "<shigure-ui>"),
36
+ "Takuji Kawano Art": ("learned_embeds/takuji-kawano-learned_embeds.bin", "<takuji-kawano>"),
37
+ }
38
+
39
+ # Track which embeddings have been loaded
40
+ loaded_tokens = set()
41
+
42
+ def load_learned_embed_in_clip(learned_embeds_path, text_encoder, tokenizer, token):
43
+ """Load learned embedding into the text encoder (only once per token)"""
44
+ global loaded_tokens
45
+
46
+ # Skip if already loaded
47
+ if token in loaded_tokens:
48
+ return token
49
+
50
+ loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
51
+
52
+ # Get the embedding
53
+ if isinstance(loaded_learned_embeds, dict):
54
+ if token in loaded_learned_embeds:
55
+ trained_token = loaded_learned_embeds[token]
56
+ else:
57
+ # Take the first embedding
58
+ trained_token = list(loaded_learned_embeds.values())[0]
59
+ else:
60
+ trained_token = loaded_learned_embeds
61
+
62
+ # Verify dimensions match (768 for SD 1.4)
63
+ if trained_token.shape[0] != text_encoder.get_input_embeddings().weight.shape[1]:
64
+ raise ValueError(
65
+ f"Embedding dimension mismatch: {trained_token.shape[0]} vs "
66
+ f"{text_encoder.get_input_embeddings().weight.shape[1]}. "
67
+ f"This embedding is not compatible with SD 1.4."
68
+ )
69
+
70
+ # Add token to tokenizer
71
+ num_added_tokens = tokenizer.add_tokens(token)
72
+
73
+ # Resize token embeddings if we added a new token
74
+ if num_added_tokens > 0:
75
+ text_encoder.resize_token_embeddings(len(tokenizer))
76
+
77
+ # Get token id
78
+ token_id = tokenizer.convert_tokens_to_ids(token)
79
+
80
+ # Set the embedding
81
+ text_encoder.get_input_embeddings().weight.data[token_id] = trained_token
82
+
83
+ # Mark as loaded
84
+ loaded_tokens.add(token)
85
+
86
+ return token
87
+
88
+ def neon_cyberpunk_loss(img):
89
+ """
90
+ Custom loss to guide generation toward neon cyberpunk aesthetic:
91
+ - Vibrant neon colors (cyan, magenta, purple, pink)
92
+ - High saturation and contrast
93
+ - Dark backgrounds with bright highlights
94
+ - Futuristic vibe
95
+ """
96
+ # Extract RGB channels
97
+ r = img[:, 0]
98
+ g = img[:, 1]
99
+ b = img[:, 2]
100
+
101
+ # 1. Boost Neon Colors (Cyan, Magenta, Purple tones)
102
+ # Cyan: high G and B, low R
103
+ cyan_score = (g + b - r).clamp(0, 1).mean()
104
+ # Magenta: high R and B, low G
105
+ magenta_score = (r + b - g).clamp(0, 1).mean()
106
+ # Purple/Pink: high R and B
107
+ purple_score = (r * b).mean()
108
+
109
+ # Maximize neon color presence
110
+ neon_color_loss = -(cyan_score + magenta_score + purple_score) / 3
111
+
112
+ # 2. Increase Saturation (difference between channels)
113
+ saturation = torch.stack([r, g, b], dim=1).std(dim=1).mean()
114
+ saturation_loss = -saturation # maximize saturation
115
+
116
+ # 3. High Contrast (bright highlights on dark backgrounds)
117
+ contrast = img.std()
118
+ contrast_loss = -contrast # maximize contrast
119
+
120
+ # 4. Boost brightness of bright areas (neon glow effect)
121
+ brightness_mask = (img.mean(dim=1, keepdim=True) > 0.5).float()
122
+ bright_areas = (img * brightness_mask).mean()
123
+ brightness_loss = -bright_areas # maximize brightness in bright areas
124
+
125
+ # 5. Darken dark areas (cyberpunk has dark backgrounds)
126
+ dark_mask = (img.mean(dim=1, keepdim=True) < 0.5).float()
127
+ dark_areas = (img * dark_mask).mean()
128
+ darkness_loss = dark_areas # minimize brightness in dark areas
129
+
130
+ # Weighted combination for maximum visual impact
131
+ total = (
132
+ 2.0 * neon_color_loss + # Strong emphasis on neon colors
133
+ 1.5 * saturation_loss + # High saturation
134
+ 1.0 * contrast_loss + # Strong contrast
135
+ 0.8 * brightness_loss + # Bright neon highlights
136
+ 0.5 * darkness_loss # Dark backgrounds
137
+ )
138
+
139
+ return total
140
+
141
+ def generate_image(
142
+ prompt,
143
+ style_name,
144
+ seed,
145
+ apply_loss=False,
146
+ loss_scale=200,
147
+ height=512,
148
+ width=512,
149
+ num_inference_steps=50,
150
+ guidance_scale=8
151
+ ):
152
+ """Generate image with optional neon cyberpunk loss"""
153
+
154
+ # Load the style embedding
155
+ if style_name in STYLE_EMBEDDINGS:
156
+ embed_path, token_name = STYLE_EMBEDDINGS[style_name]
157
+ if os.path.exists(embed_path):
158
+ token = load_learned_embed_in_clip(embed_path, text_encoder, tokenizer, token=token_name)
159
+ # Add token to prompt
160
+ prompt = f"{prompt} in the style of {token}"
161
+
162
+ # Set seed
163
+ generator = torch.manual_seed(seed)
164
+
165
+ # Prepare text embeddings
166
+ text_input = tokenizer(
167
+ [prompt],
168
+ padding="max_length",
169
+ max_length=tokenizer.model_max_length,
170
+ truncation=True,
171
+ return_tensors="pt"
172
+ )
173
+
174
+ with torch.no_grad():
175
+ text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
176
+
177
+ # Unconditional embeddings for classifier-free guidance
178
+ max_length = text_input.input_ids.shape[-1]
179
+ uncond_input = tokenizer(
180
+ [""],
181
+ padding="max_length",
182
+ max_length=max_length,
183
+ return_tensors="pt"
184
+ )
185
+ with torch.no_grad():
186
+ uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
187
+
188
+ # Concatenate for classifier-free guidance
189
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
190
+
191
+ # Prepare latents
192
+ latents = torch.randn(
193
+ (1, unet.config.in_channels, height // 8, width // 8),
194
+ generator=generator,
195
+ ).to(torch_device)
196
+
197
+ # Set scheduler
198
+ scheduler.set_timesteps(num_inference_steps)
199
+ latents = latents * scheduler.init_noise_sigma
200
+
201
+ # Denoising loop
202
+ for i, t in enumerate(tqdm(scheduler.timesteps)):
203
+ # Expand latents for classifier-free guidance
204
+ latent_model_input = torch.cat([latents] * 2)
205
+ latent_model_input = scheduler.scale_model_input(latent_model_input, t)
206
+
207
+ # Predict noise residual
208
+ with torch.no_grad():
209
+ noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
210
+
211
+ # Perform guidance
212
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
213
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
214
+
215
+ # Apply loss every 5 steps if enabled
216
+ if apply_loss and i % 5 == 0:
217
+ # Compute what the image would look like (need gradients for loss)
218
+ latents_x0 = latents - (scheduler.sigmas[i] * noise_pred)
219
+ latents_x0 = latents_x0.detach().requires_grad_(True)
220
+
221
+ # Decode to image space (without no_grad so we can backprop)
222
+ denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
223
+
224
+ # Calculate loss
225
+ loss = neon_cyberpunk_loss(denoised_images) * loss_scale
226
+
227
+ # Get gradients
228
+ cond_grad = torch.autograd.grad(loss, latents_x0)[0]
229
+
230
+ # Modify noise prediction
231
+ noise_pred = noise_pred - (scheduler.sigmas[i] * cond_grad)
232
+
233
+ # Compute previous noisy sample
234
+ latents = scheduler.step(noise_pred, t, latents).prev_sample
235
+
236
+ # Decode latents to image
237
+ with torch.no_grad():
238
+ latents = 1 / 0.18215 * latents
239
+ image = vae.decode(latents).sample
240
+
241
+ # Convert to PIL
242
+ image = (image / 2 + 0.5).clamp(0, 1)
243
+ image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
244
+ image = (image * 255).round().astype("uint8")
245
+ pil_image = Image.fromarray(image[0])
246
+
247
+ return pil_image
248
+
249
+ def generate_comparison(prompt, style_name, seed):
250
+ """Generate comparison with and without neon cyberpunk loss"""
251
+
252
+ # Generate without loss
253
+ img_without = generate_image(
254
+ prompt=prompt,
255
+ style_name=style_name,
256
+ seed=seed,
257
+ apply_loss=False
258
+ )
259
+
260
+ # Generate with neon cyberpunk loss
261
+ img_with = generate_image(
262
+ prompt=prompt,
263
+ style_name=style_name,
264
+ seed=seed,
265
+ apply_loss=True,
266
+ loss_scale=200
267
+ )
268
+
269
+ return img_without, img_with
270
+
271
+ def generate_all_styles(prompt, seed1, seed2, seed3):
272
+ """Generate images for all 3 styles with comparison"""
273
+
274
+ styles = list(STYLE_EMBEDDINGS.keys())
275
+ seeds = [seed1, seed2, seed3]
276
+
277
+ results = []
278
+
279
+ for style, seed in zip(styles, seeds):
280
+ img_without, img_with = generate_comparison(prompt, style, seed)
281
+ results.extend([img_without, img_with])
282
+
283
+ return results
284
+
285
+ # Create Gradio interface
286
+ with gr.Blocks(title="Stable Diffusion with Neon Cyberpunk Loss", theme=gr.themes.Soft()) as demo:
287
+ gr.Markdown(
288
+ """
289
+ # 🌆 Stable Diffusion with Neon Cyberpunk Loss
290
+
291
+ This app demonstrates textual inversion with 3 different learned styles and applies a custom **Neon Cyberpunk Loss**
292
+ that transforms images into vibrant cyberpunk scenes with neon colors (cyan, magenta, purple), high saturation,
293
+ and dramatic contrast between dark backgrounds and bright neon highlights.
294
+
295
+ ## Features:
296
+ - **3 Different Styles**: Bird Style, Shigure UI Art, Takuji Kawano Art
297
+ - **Custom Neon Cyberpunk Loss**: Creates futuristic neon aesthetic with vibrant colors
298
+ - **Seed Control**: Different seeds for reproducible results
299
+
300
+ ⏱️ **Note**: This process can take up to 10 minutes to run. Perfect time to grab a coffee! ☕
301
+ """
302
+ )
303
+
304
+ with gr.Row():
305
+ with gr.Column():
306
+ prompt_input = gr.Textbox(
307
+ label="Prompt",
308
+ placeholder="Enter your prompt here...",
309
+ value="A beautiful landscape with mountains"
310
+ )
311
+
312
+ with gr.Row():
313
+ seed1 = gr.Number(label="Seed for Style 1 (Bird Style)", value=42, precision=0)
314
+ seed2 = gr.Number(label="Seed for Style 2 (Shigure UI)", value=123, precision=0)
315
+ seed3 = gr.Number(label="Seed for Style 3 (Takuji Kawano)", value=456, precision=0)
316
+
317
+ generate_btn = gr.Button("🎨 Generate All Comparisons", variant="primary", size="lg")
318
+
319
+ gr.Markdown("### Results: Left = Original | Right = With Neon Cyberpunk Loss")
320
+
321
+ with gr.Row():
322
+ gr.Markdown("#### Style 1: Bird Style")
323
+ with gr.Row():
324
+ out1_without = gr.Image(label="Original")
325
+ out1_with = gr.Image(label="Neon Cyberpunk")
326
+
327
+ with gr.Row():
328
+ gr.Markdown("#### Style 2: Shigure UI Art")
329
+ with gr.Row():
330
+ out2_without = gr.Image(label="Original")
331
+ out2_with = gr.Image(label="Neon Cyberpunk")
332
+
333
+ with gr.Row():
334
+ gr.Markdown("#### Style 3: Takuji Kawano Art")
335
+ with gr.Row():
336
+ out3_without = gr.Image(label="Original")
337
+ out3_with = gr.Image(label="Neon Cyberpunk")
338
+
339
+ # Connect the button
340
+ generate_btn.click(
341
+ fn=generate_all_styles,
342
+ inputs=[prompt_input, seed1, seed2, seed3],
343
+ outputs=[
344
+ out1_without, out1_with,
345
+ out2_without, out2_with,
346
+ out3_without, out3_with
347
+ ]
348
+ )
349
+
350
+ gr.Markdown(
351
+ """
352
+ ---
353
+ ### About the Neon Cyberpunk Loss
354
+
355
+ The **Neon Cyberpunk Loss** is a creative guidance technique that transforms images into futuristic cyberpunk scenes:
356
+ - **Neon Colors**: Maximizes cyan, magenta, and purple tones for that distinctive neon glow
357
+ - **High Saturation**: Boosts color vibrancy to create electric, vivid scenes
358
+ - **Dramatic Contrast**: Creates dark backgrounds with bright neon highlights
359
+ - **Glow Effect**: Enhances brightness in highlight areas while darkening shadows
360
+
361
+ This demonstrates how custom loss functions can dramatically alter the aesthetic and mood of generated images,
362
+ going far beyond simple color adjustments to create an entirely different visual style.
363
+
364
+ **Seeds Used**: Different seeds ensure variety across the three styles while maintaining reproducibility.
365
+
366
+ ### Assignment Info
367
+ - **Task**: Demonstrate 3 different styles with creative custom loss (not standard RGB)
368
+ - **Implementation**: Uses textual inversion embeddings + custom neon cyberpunk loss during inference
369
+ """
370
+ )
371
+
372
+ if __name__ == "__main__":
373
+ torch.manual_seed(1)
374
+ demo.launch(share=False, server_name="0.0.0.0", server_port=7860)
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ torch>=2.0.0
3
+ diffusers>=0.21.0
4
+ transformers>=4.30.0
5
+ accelerate>=0.20.0
6
+ safetensors>=0.3.1
7
+ Pillow>=9.5.0
8
+ numpy>=1.24.0
9
+ tqdm>=4.65.0