|
|
""" |
|
|
Multi-Style Image Generator with Ice Crystal Effects |
|
|
Hugging Face Spaces App - With Diffusion Progress Streaming |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from pathlib import Path |
|
|
from tqdm.auto import tqdm |
|
|
import gradio as gr |
|
|
import io |
|
|
import tempfile |
|
|
|
|
|
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler |
|
|
from transformers import CLIPTextModel, CLIPTokenizer |
|
|
|
|
|
|
|
|
# Lazily-initialized Stable Diffusion components. All are populated once by
# load_models() and then shared (module-global cache) across generations;
# each stays None until the first call.
vae = None
tokenizer = None
text_encoder = None
unet = None
scheduler = None
device = None


# Textual-inversion style embeddings bundled with the Space.
# Keys are the UI dropdown labels; values are paths (relative to the app's
# working directory) to single-token learned-embedding .bin files.
PREDEFINED_STYLES = {
    "8bit": "styles/8bit_learned_embeds.bin",
    "ahx_beta": "styles/ahx_beta_learned_embeds.bin",
    "dr_strange": "styles/dr_strangelearned_embeds.bin",
    "max_naylor": "styles/max_naylorlearned_embeds.bin",
    "smiling_friend": "styles/smiling-friend-style_learned_embeds.bin"
}
|
|
|
|
|
|
|
|
def ice_crystal_loss(images):
    """
    Differentiable loss encouraging TRANSPARENT ice-crystal patterns as an
    overlay.

    Lower values reward strong edges, brightness concentrated on those
    edges, high-frequency (Laplacian) detail, cool blue/green tones in
    bright regions, and fine local texture along the edges.

    Args:
        images: (B, 3, H, W) float tensor, values nominally in [0, 1].

    Returns:
        Scalar tensor combining the five weighted terms (gradients flow
        back to ``images``).
    """
    def _depthwise_kernel(rows):
        # Build a 3x3 kernel replicated per channel for groups=3 conv.
        k = torch.tensor(rows, dtype=images.dtype, device=images.device)
        return k.view(1, 1, 3, 3).repeat(3, 1, 1, 1)

    # --- Edge strength via Sobel gradients, rewarded above a threshold ---
    grad_x = F.conv2d(images, _depthwise_kernel([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]),
                      padding=1, groups=3)
    grad_y = F.conv2d(images, _depthwise_kernel([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]),
                      padding=1, groups=3)
    grad_mag = torch.sqrt(grad_x**2 + grad_y**2)

    threshold = 0.1
    edge_term = -torch.relu(grad_mag - threshold).mean()

    # --- Brightness only where edges exist (keeps the overlay transparent) ---
    on_edge = (grad_mag > threshold).float()
    luminance = images.mean(dim=1, keepdim=True)
    bright_term = -(luminance * on_edge).mean() * 0.3

    # --- High-frequency detail (Laplacian response magnitude) ---
    laplacian = F.conv2d(images, _depthwise_kernel([[0, -1, 0], [-1, 4, -1], [0, -1, 0]]),
                         padding=1, groups=3)
    freq_term = -torch.abs(laplacian).mean() * 0.5

    # --- Cool tones: penalize red relative to blue/green in bright pixels ---
    red, green, blue = images[:, 0], images[:, 1], images[:, 2]
    is_bright = (luminance.squeeze(1) > 0.5).float()
    tone_term = (red * is_bright).mean() - ((blue * is_bright).mean() + (green * is_bright).mean()) / 2
    tone_term = tone_term * 0.2

    # --- Local texture (variance in a 3x3 window) concentrated on edges ---
    window = 3
    local_mean = F.avg_pool2d(images, window, stride=1, padding=window // 2)
    local_var = F.avg_pool2d((images - local_mean)**2, window, stride=1, padding=window // 2)
    texture_term = -(local_var * on_edge.unsqueeze(1)).mean() * 0.5

    return (
        3.0 * edge_term +
        0.5 * bright_term +
        0.8 * freq_term +
        0.2 * tone_term +
        1.0 * texture_term
    )
|
|
|
|
|
|
|
|
def load_models():
    """
    Load all Stable Diffusion components once and cache them globally.

    Populates the module globals (vae, tokenizer, text_encoder, unet,
    scheduler, device). Subsequent calls are no-ops once loading has
    completed. Weights are fp16 on CUDA and fp32 on CPU.

    Raises:
        RuntimeError: if any component fails to download or initialize
            (the original exception is chained as __cause__).
    """
    global vae, tokenizer, text_encoder, unet, scheduler, device

    # Already fully loaded — nothing to do.
    if vae is not None and scheduler is not None:
        return

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    model_id = "CompVis/stable-diffusion-v1-4"

    try:
        print("Loading models... (this may take a few minutes on CPU)")

        # Half precision only on GPU; many fp16 ops are slow/unsupported on CPU.
        dtype = torch.float16 if device == "cuda" else torch.float32

        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=dtype).to(device)
        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=dtype).to(device)
        unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=dtype).to(device)

        # Standard SD v1.x LMS scheduler configuration.
        scheduler = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000
        )

        print("Models loaded successfully!")

    except Exception as e:
        print(f"Error loading models: {e}")
        # Chain the cause so the full original traceback is preserved.
        raise RuntimeError(f"Failed to load models: {e}") from e
|
|
|
|
|
|
|
|
def decode_latents_to_image(latents_to_decode):
    """Decode a latent tensor into a PIL Image via the cached VAE."""
    global vae, device

    with torch.no_grad():
        # Undo the SD latent scaling factor, then decode to pixel space.
        rescaled = 1 / 0.18215 * latents_to_decode
        decoded = vae.decode(rescaled).sample

        # Map [-1, 1] -> [0, 1], move to CPU as HWC, convert to uint8.
        decoded = (decoded / 2 + 0.5).clamp(0, 1)
        array = decoded.cpu().permute(0, 2, 3, 1).numpy()
        first_frame = (array[0] * 255).astype(np.uint8)
        return Image.fromarray(first_frame)
|
|
|
|
|
|
|
|
def create_gif_from_frames(frames, output_path=None, duration=200):
    """
    Create an animated GIF from a list of PIL Images.

    Args:
        frames: sequence of PIL.Image frames; the first frame's save()
            writes the rest via ``append_images``.
        output_path: destination path; when None a fresh temporary .gif
            path is allocated.
        duration: per-frame display time in milliseconds.

    Returns:
        The path of the written GIF, or None when ``frames`` is empty.
    """
    if not frames:
        return None

    if output_path is None:
        # NamedTemporaryFile replaces the deprecated, race-prone
        # tempfile.mktemp: the file is created atomically and kept on disk
        # (delete=False) so PIL can write to the reserved path afterwards.
        with tempfile.NamedTemporaryFile(suffix='.gif', delete=False) as tmp:
            output_path = tmp.name

    frames[0].save(
        output_path,
        save_all=True,
        append_images=frames[1:],
        duration=duration,
        loop=0
    )
    return output_path
|
|
|
|
|
|
|
|
def generate_with_style_streaming(
    style_file,
    prompt,
    seed=42,
    num_inference_steps=50,
    guidance_scale=7.5,
    height=512,
    width=512,
    use_ice_crystal_guidance=False,
    ice_crystal_loss_scale=50,
    guidance_frequency=10,
    preview_frequency=5
):
    """
    Generate an image with streaming updates.

    Runs classifier-free-guidance diffusion with a textual-inversion style
    embedding injected into the tokenizer/text encoder, optionally steering
    the latents toward ice-crystal patterns via gradient guidance.

    Yields intermediate images during generation: each yield is a dict
    {"step", "total", "image", "gif"}; intermediate yields have gif=None,
    the final yield carries the finished image and the GIF path.

    Args:
        style_file: path to a textual-inversion .bin file mapping a single
            placeholder token to its embedding tensor.
        prompt: text prompt; every "<style>" occurrence is replaced by the
            style token from the file.
        seed: RNG seed for the initial latents.
        num_inference_steps: number of scheduler denoising steps.
        guidance_scale: classifier-free guidance strength.
        height, width: output size in pixels; divided by 8 for the latent
            grid (assumes both are multiples of 8 — TODO confirm callers).
        use_ice_crystal_guidance: enable gradient-based steering toward
            ice_crystal_loss.
        ice_crystal_loss_scale: multiplier applied to the guidance loss.
        guidance_frequency: apply guidance every N steps.
        preview_frequency: decode and yield a preview every N steps.
    """
    global vae, tokenizer, text_encoder, unet, scheduler, device

    load_models()

    # Collected preview frames, later assembled into a progress GIF.
    frames = []

    generator = torch.Generator(device=device).manual_seed(seed)
    # weights_only=True restricts torch.load to plain tensors — safer for
    # user-uploaded embedding files.
    learned_embeds_dict = torch.load(style_file, map_location=device, weights_only=True)

    # Textual-inversion files contain a single placeholder-token -> vector entry.
    style_token = list(learned_embeds_dict.keys())[0]
    style_embedding = learned_embeds_dict[style_token].to(device)

    # Width of the text encoder's token-embedding vectors.
    expected_dim = text_encoder.get_input_embeddings().weight.shape[1]

    if style_embedding.shape[0] != expected_dim:
        if style_embedding.shape[0] == 1024 and expected_dim == 768:
            # NOTE(review): truncating a 1024-d (SD2-style) embedding to the
            # first 768 dims is lossy — the style may not transfer faithfully.
            style_embedding = style_embedding[:768]
        else:
            raise ValueError(f"Cannot handle embedding dimension {style_embedding.shape[0]} -> {expected_dim}")

    # Register the placeholder token (once) and grow the embedding table.
    if style_token not in tokenizer.get_vocab():
        tokenizer.add_tokens([style_token])
        text_encoder.resize_token_embeddings(len(tokenizer))

    # Write the learned vector into the encoder's embedding row.
    token_id = tokenizer.convert_tokens_to_ids(style_token)
    with torch.no_grad():
        text_encoder.get_input_embeddings().weight[token_id] = style_embedding

    final_prompt = prompt.replace("<style>", style_token)

    # Conditional text embeddings for the prompt.
    text_input = tokenizer(
        final_prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt"
    )

    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    # Unconditional (empty prompt) embeddings for classifier-free guidance.
    uncond_input = tokenizer(
        [""],
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt"
    )

    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

    # Batch [uncond, cond] so a single UNet pass serves both CFG branches.
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Initial noise in latent space (1/8 of pixel resolution per axis).
    # NOTE(review): created in default float32; with fp16 models on CUDA this
    # may mismatch the UNet's dtype — confirm on a GPU runtime.
    latents = torch.randn(
        (1, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
        device=device
    )

    scheduler.set_timesteps(num_inference_steps)
    # LMS expects the initial sample scaled by its starting sigma.
    latents = latents * scheduler.init_noise_sigma

    for i, t in enumerate(scheduler.timesteps):
        # Duplicate latents for the [uncond, cond] batch and apply the
        # scheduler's per-timestep input scaling.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        with torch.no_grad():
            noise_pred = unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings
            ).sample

        # Classifier-free guidance: push prediction away from unconditional.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        if use_ice_crystal_guidance and i % guidance_frequency == 0:
            if device == "cuda":
                torch.cuda.empty_cache()

            # Re-enable grad on the latents so the loss gradient can be taken.
            latents = latents.detach().requires_grad_()
            sigma = scheduler.sigmas[i]
            # One-step estimate of the denoised sample x0 at the current sigma.
            latents_x0 = latents - sigma * noise_pred

            # NOTE(review): torch.cuda.amp.autocast is deprecated in favor of
            # torch.amp.autocast("cuda", ...) on newer torch — confirm version.
            with torch.cuda.amp.autocast(enabled=False):
                # Decode to [0, 1]-ish pixel space for the perceptual loss.
                denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5

            loss = ice_crystal_loss(denoised_images) * ice_crystal_loss_scale
            cond_grad = torch.autograd.grad(loss, latents)[0]
            # Gradient-descent nudge on the latents, scaled by sigma^2.
            latents = latents.detach() - cond_grad * sigma**2

            # Free the decode graph before the next UNet pass to cap VRAM.
            del denoised_images, loss, cond_grad
            if device == "cuda":
                torch.cuda.empty_cache()

        latents = scheduler.step(noise_pred, t, latents).prev_sample

        # Periodic preview (plus the last step) for live streaming.
        if i % preview_frequency == 0 or i == num_inference_steps - 1:
            preview_image = decode_latents_to_image(latents)
            frames.append(preview_image)

            yield {
                "step": i + 1,
                "total": num_inference_steps,
                "image": preview_image,
                "gif": None
            }

    final_image = decode_latents_to_image(latents)
    frames.append(final_image)

    gif_path = create_gif_from_frames(frames, duration=300)

    yield {
        "step": num_inference_steps,
        "total": num_inference_steps,
        "image": final_image,
        "gif": gif_path
    }
|
|
|
|
|
|
|
|
def generate_image_streaming(
    prompt,
    style_choice,
    custom_embedding,
    seed,
    guidance_scale,
    use_ice_crystal,
    ice_crystal_intensity,
    preview_frequency
):
    """
    Streaming generation function for the Gradio interface.

    Resolves the style embedding (uploaded file takes precedence over the
    dropdown), then relays updates from generate_with_style_streaming as
    (image, gif_path, status) tuples for the three output components.
    """
    # An uploaded embedding always wins over the predefined dropdown choice.
    if custom_embedding is not None:
        style_file = custom_embedding
    elif style_choice in PREDEFINED_STYLES:
        style_file = PREDEFINED_STYLES[style_choice]
    else:
        raise gr.Error("Please select a style or upload a custom embedding file.")

    if not Path(style_file).exists():
        raise gr.Error(f"Style embedding file not found: {style_file}")

    try:
        stream = generate_with_style_streaming(
            style_file=style_file,
            prompt=prompt,
            seed=int(seed),
            guidance_scale=guidance_scale,
            use_ice_crystal_guidance=use_ice_crystal,
            ice_crystal_loss_scale=ice_crystal_intensity,
            preview_frequency=int(preview_frequency)
        )
        for update in stream:
            progress = f"Step {update['step']}/{update['total']}"
            yield update["image"], update["gif"], progress
    except Exception as e:
        # Surface any failure in the UI as a Gradio error toast.
        raise gr.Error(f"Generation failed: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
# --- Gradio UI, assembled at import time so `demo` is importable by hosts ---
with gr.Blocks(
    title="Multi-Style Image Generator",
    theme=gr.themes.Soft(
        primary_hue="indigo",
        secondary_hue="cyan"
    )
) as demo:
    gr.Markdown("""
    # Multi-Style Image Generator with Ice Crystal Effects

    Generate images using textual inversion style embeddings with optional ice crystal overlay effects.
    **Now with live diffusion progress streaming!**

    **Instructions:**
    1. Enter a prompt using `<style>` as placeholder (e.g., "A cat in the style of <style>")
    2. Select a predefined style OR upload your own `.bin` embedding file
    3. Optionally enable ice crystal effect for a crystalline overlay
    4. Click Generate and watch the image evolve!
    """)

    with gr.Row():
        # Left column: all generation inputs and controls.
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="A mouse in the style of <style>",
                value="A mouse in the style of <style>",
                lines=2
            )

            # Bundled styles; keys must match PREDEFINED_STYLES.
            style_choice = gr.Dropdown(
                choices=list(PREDEFINED_STYLES.keys()),
                value="8bit",
                label="Predefined Style",
                info="Select a bundled style embedding"
            )

            # Optional user upload; overrides the dropdown when provided.
            custom_embedding = gr.File(
                label="Custom Embedding (Optional)",
                file_types=[".bin"],
                type="filepath"
            )

            with gr.Row():
                seed = gr.Number(
                    label="Seed",
                    value=42,
                    precision=0
                )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1.0,
                    maximum=20.0,
                    value=7.5,
                    step=0.5
                )

            # Ice-crystal guidance controls (collapsed by default).
            with gr.Accordion("Ice Crystal Effect", open=False):
                use_ice_crystal = gr.Checkbox(
                    label="Enable Ice Crystal Effect",
                    value=False,
                    info="Add crystalline overlay to the image"
                )
                ice_crystal_intensity = gr.Slider(
                    label="Ice Crystal Intensity",
                    minimum=30,
                    maximum=100,
                    value=50,
                    step=5,
                    info="Higher = stronger crystal effect"
                )

            # How often intermediate previews are decoded and streamed.
            with gr.Accordion("Streaming Settings", open=True):
                preview_frequency = gr.Slider(
                    label="Preview Frequency",
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    info="Show preview every N steps (lower = more updates, slower)"
                )

            generate_btn = gr.Button("Generate", variant="primary", size="lg")
            status_text = gr.Textbox(label="Status", interactive=False, value="Ready")

        # Right column: live preview image and the finished progress GIF.
        with gr.Column(scale=1):
            output_image = gr.Image(
                label="Live Preview / Final Image",
                type="pil"
            )
            output_gif = gr.File(
                label="Diffusion Progress GIF (available after generation)",
                type="filepath"
            )

    # Clickable example rows; order matches generate_image_streaming's inputs.
    gr.Examples(
        examples=[
            ["A cat in the style of <style>", "8bit", None, 42, 7.5, False, 50, 5],
            ["A mystical forest in the style of <style>", "dr_strange", None, 123, 7.5, False, 50, 5],
            ["A portrait in the style of <style>", "max_naylor", None, 456, 7.5, True, 60, 5],
        ],
        inputs=[prompt, style_choice, custom_embedding, seed, guidance_scale, use_ice_crystal, ice_crystal_intensity, preview_frequency],
    )

    # Streaming handler: each yield updates (image, gif, status) in place.
    generate_btn.click(
        fn=generate_image_streaming,
        inputs=[prompt, style_choice, custom_embedding, seed, guidance_scale, use_ice_crystal, ice_crystal_intensity, preview_frequency],
        outputs=[output_image, output_gif, status_text]
    )
|
|
|
|
|
# Launch the Gradio server when run directly (how Spaces starts the app).
if __name__ == "__main__":
    demo.launch()
|
|
|