Krishnakanth1993 committed on
Commit
d9b5109
·
verified ·
1 Parent(s): 5bcc41c

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +369 -0
app.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Multi-Style Image Generator with Ice Crystal Effects
3
+ Hugging Face Spaces App
4
+ """
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+ from PIL import Image
10
+ from pathlib import Path
11
+ from tqdm.auto import tqdm
12
+ import gradio as gr
13
+
14
+ from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
15
+ from transformers import CLIPTextModel, CLIPTokenizer
16
+
17
# Model handles shared by the whole app. They all start out empty and are
# filled in on first use by load_models().
vae = tokenizer = text_encoder = unet = scheduler = device = None

# Maps each selectable style name to the path of its embedding file.
# All bundled files live under the "styles/" directory.
PREDEFINED_STYLES = {
    style_name: f"styles/{file_name}"
    for style_name, file_name in (
        ("8bit", "8bit_learned_embeds.bin"),
        ("ahx_beta", "ahx_beta_learned_embeds.bin"),
        ("dr_strange", "dr_strangelearned_embeds.bin"),
        ("max_naylor", "max_naylorlearned_embeds.bin"),
        ("smiling_friend", "smiling-friend-style_learned_embeds.bin"),
    )
}
33
+
34
+
35
def ice_crystal_loss(images):
    """
    Calculate loss to encourage TRANSPARENT ice crystal patterns as an overlay.

    The loss is a weighted sum of five terms, each of which becomes more
    negative (i.e. "better") as the image gains the corresponding property:
    strong Sobel edges, brightness on those edges, high-frequency (Laplacian)
    detail, a blue/green bias over red in bright areas, and local variance
    inside edge regions.

    Args:
        images: Floating-point tensor indexed as (batch, 3, height, width).
            Exactly three channels are required by the grouped convolutions
            and the r/g/b split. The 0.1 and 0.5 thresholds suggest values
            roughly in [0, 1] -- TODO confirm against the caller.

    Returns:
        A scalar tensor holding the combined loss, differentiable with
        respect to ``images``.
    """
    def _per_channel_kernel(rows):
        # One copy of the 3x3 kernel per channel so that conv2d(groups=3)
        # filters every colour channel independently.
        kernel = torch.tensor(rows, dtype=images.dtype, device=images.device)
        return kernel.view(1, 1, 3, 3).repeat(3, 1, 1, 1)

    sobel_x = _per_channel_kernel([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]])
    sobel_y = _per_channel_kernel([[-1, -2, -1], [0, 0, 0], [1, 2, 1]])

    edges_x = F.conv2d(images, sobel_x, padding=1, groups=3)
    edges_y = F.conv2d(images, sobel_y, padding=1, groups=3)
    # The small epsilon keeps the gradient of sqrt finite where both edge
    # responses are exactly zero; without it the backward pass evaluates
    # 0 / 0 and returns NaN for perfectly flat regions.
    eps = torch.finfo(images.dtype).eps
    edge_magnitude = torch.sqrt(edges_x**2 + edges_y**2 + eps)

    # Reward edge strength above the threshold only.
    edge_threshold = 0.1
    strong_edges = torch.relu(edge_magnitude - edge_threshold)
    edge_loss = -strong_edges.mean()

    # Reward brightness, but only where an edge was detected.
    edge_mask = (edge_magnitude > edge_threshold).float()
    brightness = images.mean(dim=1, keepdim=True)
    selective_brightness = brightness * edge_mask
    brightness_loss = -selective_brightness.mean() * 0.3

    # Reward high-frequency detail via a 4-neighbour Laplacian.
    laplacian_kernel = _per_channel_kernel([[0, -1, 0], [-1, 4, -1], [0, -1, 0]])
    high_freq = F.conv2d(images, laplacian_kernel, padding=1, groups=3)
    high_freq_loss = -torch.abs(high_freq).mean() * 0.5

    # In bright pixels, penalise red and reward the blue/green average.
    r, g, b = images[:, 0], images[:, 1], images[:, 2]
    bright_mask = (brightness.squeeze(1) > 0.5).float()
    cool_tone_loss = (r * bright_mask).mean() - ((b * bright_mask).mean() + (g * bright_mask).mean()) / 2
    cool_tone_loss = cool_tone_loss * 0.2

    # Reward local variance (texture) inside edge regions.
    kernel_size = 3
    local_mean = F.avg_pool2d(images, kernel_size, stride=1, padding=kernel_size//2)
    local_variance = F.avg_pool2d((images - local_mean)**2, kernel_size, stride=1, padding=kernel_size//2)
    # edge_mask already has the same (batch, 3, H, W) layout as
    # local_variance. Adding an extra axis here would broadcast to
    # (batch, batch, 3, H, W) and mix samples when batch > 1.
    texture_in_edges = local_variance * edge_mask
    texture_loss = -texture_in_edges.mean() * 0.5

    total_loss = (
        3.0 * edge_loss +
        0.5 * brightness_loss +
        0.8 * high_freq_loss +
        0.2 * cool_tone_loss +
        1.0 * texture_loss
    )

    return total_loss
82
+
83
+
84
def load_models():
    """Load all models once and cache them globally.

    Populates the module-level ``vae``, ``tokenizer``, ``text_encoder``,
    ``unet``, ``scheduler`` and ``device`` names. Calling it again after a
    successful load is a no-op.

    Everything is loaded into local variables first and published to the
    globals in one step at the end. If any download or load raises, no global
    is left half-initialised and the next call retries from scratch instead
    of returning early with missing models.
    """
    global vae, tokenizer, text_encoder, unet, scheduler, device

    cached = (vae, tokenizer, text_encoder, unet, scheduler)
    if all(component is not None for component in cached):
        return

    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {target_device}")

    model_id = "CompVis/stable-diffusion-v1-4"

    print("Loading models...")
    loaded_vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(target_device)
    loaded_tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    loaded_text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(target_device)
    loaded_unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(target_device)

    loaded_scheduler = LMSDiscreteScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000
    )

    # Publish only after every component loaded successfully.
    device = target_device
    vae = loaded_vae
    tokenizer = loaded_tokenizer
    text_encoder = loaded_text_encoder
    unet = loaded_unet
    scheduler = loaded_scheduler

    print("Models loaded successfully!")
110
+
111
+
112
def generate_with_style(
    style_file,
    prompt,
    seed=42,
    num_inference_steps=50,
    guidance_scale=7.5,
    height=512,
    width=512,
    use_ice_crystal_guidance=False,
    ice_crystal_loss_scale=50,
    guidance_frequency=10,
    progress=None
):
    """Generate an image using a style embedding with optional ice crystal guidance.

    Args:
        style_file: Path to a learned-embedding file readable by
            ``torch.load``. It must hold a non-empty dict mapping a token
            string to its embedding; only the first entry is used.
        prompt: Text prompt. Every ``<style>`` occurrence is replaced by the
            token found in ``style_file``.
        seed: Seed for the random generator that draws the initial latents.
        num_inference_steps: Number of scheduler steps.
        guidance_scale: Classifier-free guidance weight.
        height: Output height in pixels; latents use ``height // 8``.
        width: Output width in pixels; latents use ``width // 8``.
        use_ice_crystal_guidance: When true, latents are periodically nudged
            along the negative gradient of ``ice_crystal_loss``.
        ice_crystal_loss_scale: Multiplier applied to the ice crystal loss.
        guidance_frequency: Apply the ice crystal step on every step whose
            index is a multiple of this value.
        progress: Optional callable invoked as ``progress(fraction, text)``
            after each step index is known.

    Returns:
        A ``PIL.Image.Image`` built from the decoded latents.

    Raises:
        ValueError: If ``style_file`` holds no embeddings, or the embedding
            length does not match the text encoder (apart from the
            1024 -> 768 truncation special case).
    """
    # The module-level model handles are only read here, so no ``global``
    # declaration is needed; load_models() is what assigns them.
    load_models()

    generator = torch.Generator(device=device).manual_seed(seed)
    learned_embeds_dict = torch.load(style_file, map_location=device, weights_only=True)

    # Fail with a clear message instead of an opaque IndexError /
    # AttributeError when the file does not contain a token -> tensor mapping.
    if not isinstance(learned_embeds_dict, dict) or not learned_embeds_dict:
        raise ValueError(f"No embeddings found in style file: {style_file}")

    style_token = next(iter(learned_embeds_dict))
    style_embedding = learned_embeds_dict[style_token].to(device)

    expected_dim = text_encoder.get_input_embeddings().weight.shape[1]

    if style_embedding.shape[0] != expected_dim:
        if style_embedding.shape[0] == 1024 and expected_dim == 768:
            # Only the first 768 values of a 1024-long embedding are kept.
            style_embedding = style_embedding[:768]
        else:
            raise ValueError(f"Cannot handle embedding dimension {style_embedding.shape[0]} -> {expected_dim}")

    # Register the style token once; later calls reuse the existing id.
    if style_token not in tokenizer.get_vocab():
        tokenizer.add_tokens([style_token])
        text_encoder.resize_token_embeddings(len(tokenizer))

    token_id = tokenizer.convert_tokens_to_ids(style_token)
    with torch.no_grad():
        text_encoder.get_input_embeddings().weight[token_id] = style_embedding

    final_prompt = prompt.replace("<style>", style_token)

    text_input = tokenizer(
        final_prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt"
    )

    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    # Empty prompt for the unconditional half of classifier-free guidance.
    uncond_input = tokenizer(
        [""],
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt"
    )

    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    latents = torch.randn(
        (1, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
        device=device
    )

    scheduler.set_timesteps(num_inference_steps)
    latents = latents * scheduler.init_noise_sigma

    for i, t in enumerate(tqdm(scheduler.timesteps, desc="Generating")):
        # Explicit None check: a progress object may be callable yet not
        # define a usable truth value (NOTE(review): gradio's Progress looks
        # like it defines __len__ -- confirm), so `if progress:` is unsafe.
        if progress is not None:
            progress((i + 1) / num_inference_steps, f"Step {i + 1}/{num_inference_steps}")

        # Duplicate the latents so one UNet pass covers both the
        # unconditional and the text-conditioned prediction.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        with torch.no_grad():
            noise_pred = unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings
            ).sample

        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        if use_ice_crystal_guidance and i % guidance_frequency == 0:
            if device == "cuda":
                torch.cuda.empty_cache()

            # noise_pred was computed under no_grad, so the gradient below
            # reaches the latents only through the x0 estimate and the VAE
            # decoder, not through the UNet.
            latents = latents.detach().requires_grad_()
            sigma = scheduler.sigmas[i]
            latents_x0 = latents - sigma * noise_pred

            with torch.cuda.amp.autocast(enabled=False):
                # 0.18215 is undone before decoding (presumably the VAE
                # latent scaling factor); / 2 + 0.5 shifts the output range.
                denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5

            loss = ice_crystal_loss(denoised_images) * ice_crystal_loss_scale
            cond_grad = torch.autograd.grad(loss, latents)[0]
            latents = latents.detach() - cond_grad * sigma**2

            # Drop the graph-holding tensors before the next UNet pass.
            del denoised_images, loss, cond_grad
            if device == "cuda":
                torch.cuda.empty_cache()

        latents = scheduler.step(noise_pred, t, latents).prev_sample

    latents = 1 / 0.18215 * latents

    with torch.no_grad():
        image = vae.decode(latents).sample

    # Rescale to [0, 1], move channels last, and convert to 8-bit.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    image = (image[0] * 255).astype(np.uint8)
    image = Image.fromarray(image)

    return image
235
+
236
+
237
def generate_image(
    prompt,
    style_choice,
    custom_embedding,
    seed,
    guidance_scale,
    use_ice_crystal,
    ice_crystal_intensity,
    progress=gr.Progress()
):
    """Main generation function for Gradio interface.

    Args:
        prompt: Prompt text, forwarded unchanged to ``generate_with_style``.
        style_choice: Key into ``PREDEFINED_STYLES``; used only when no
            custom embedding is supplied.
        custom_embedding: Path of an uploaded embedding file, or ``None``.
            When given, it takes precedence over ``style_choice``.
        seed: Seed value; converted with ``int()`` before use.
        guidance_scale: Forwarded as ``guidance_scale``.
        use_ice_crystal: Forwarded as ``use_ice_crystal_guidance``.
        ice_crystal_intensity: Forwarded as ``ice_crystal_loss_scale``.
        progress: Progress reporter forwarded to ``generate_with_style``.

    Returns:
        The image returned by ``generate_with_style``.

    Raises:
        gr.Error: If no usable style is selected, the embedding file does not
            exist, or generation fails for any reason.
    """

    if custom_embedding is not None:
        style_file = custom_embedding
    else:
        if style_choice not in PREDEFINED_STYLES:
            raise gr.Error("Please select a style or upload a custom embedding file.")
        style_file = PREDEFINED_STYLES[style_choice]

    if not Path(style_file).exists():
        raise gr.Error(f"Style embedding file not found: {style_file}")

    try:
        image = generate_with_style(
            style_file=style_file,
            prompt=prompt,
            seed=int(seed),
            guidance_scale=guidance_scale,
            use_ice_crystal_guidance=use_ice_crystal,
            ice_crystal_loss_scale=ice_crystal_intensity,
            progress=progress
        )
        return image
    except Exception as e:
        # UI boundary: surface any failure as a user-visible error, but keep
        # the original exception chained so the traceback shows the cause.
        raise gr.Error(f"Generation failed: {e}") from e
272
+
273
+
274
# Assemble the Gradio UI: controls in the left column, result on the right.
_app_theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="cyan")

with gr.Blocks(title="Multi-Style Image Generator", theme=_app_theme) as demo:
    gr.Markdown("""
    # Multi-Style Image Generator with Ice Crystal Effects

    Generate images using textual inversion style embeddings with optional ice crystal overlay effects.

    **Instructions:**
    1. Enter a prompt using `<style>` as placeholder (e.g., "A cat in the style of <style>")
    2. Select a predefined style OR upload your own `.bin` embedding file
    3. Optionally enable ice crystal effect for a crystalline overlay
    4. Click Generate!
    """)

    with gr.Row():
        # Left column: every input control plus the trigger button.
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="A mouse in the style of <style>",
                value="A mouse in the style of <style>",
                lines=2,
            )

            style_choice = gr.Dropdown(
                choices=list(PREDEFINED_STYLES.keys()),
                value="8bit",
                label="Predefined Style",
                info="Select a bundled style embedding",
            )

            custom_embedding = gr.File(
                label="Custom Embedding (Optional)",
                file_types=[".bin"],
                type="filepath",
            )

            with gr.Row():
                seed = gr.Number(label="Seed", value=42, precision=0)
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1.0,
                    maximum=20.0,
                    value=7.5,
                    step=0.5,
                )

            with gr.Accordion("Ice Crystal Effect", open=False):
                use_ice_crystal = gr.Checkbox(
                    label="Enable Ice Crystal Effect",
                    value=False,
                    info="Add crystalline overlay to the image",
                )
                ice_crystal_intensity = gr.Slider(
                    label="Ice Crystal Intensity",
                    minimum=30,
                    maximum=100,
                    value=50,
                    step=5,
                    info="Higher = stronger crystal effect",
                )

            generate_btn = gr.Button("Generate", variant="primary", size="lg")

        # Right column: the generated picture.
        with gr.Column(scale=1):
            output_image = gr.Image(label="Generated Image", type="pil")

    # The examples table and the button feed the same components, in the
    # same order as generate_image's parameters.
    _generation_inputs = [
        prompt,
        style_choice,
        custom_embedding,
        seed,
        guidance_scale,
        use_ice_crystal,
        ice_crystal_intensity,
    ]

    _example_rows = [
        ["A cat in the style of <style>", "8bit", None, 42, 7.5, False, 50],
        ["A mystical forest in the style of <style>", "dr_strange", None, 123, 7.5, False, 50],
        ["A portrait in the style of <style>", "max_naylor", None, 456, 7.5, True, 60],
    ]

    gr.Examples(examples=_example_rows, inputs=_generation_inputs)

    generate_btn.click(
        fn=generate_image,
        inputs=_generation_inputs,
        outputs=output_image,
    )

# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()