Krishnakanth1993's picture
Upload app.py with huggingface_hub
6baaa4a verified
"""
Multi-Style Image Generator with Ice Crystal Effects
Hugging Face Spaces App - With Diffusion Progress Streaming
"""
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.auto import tqdm
import gradio as gr
import io
import tempfile
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
from transformers import CLIPTextModel, CLIPTokenizer
# Model handles shared by every request. They start out as None and are
# populated a single time by load_models().
vae = tokenizer = text_encoder = unet = scheduler = device = None

# Dropdown label -> bundled textual-inversion embedding file. The paths are
# relative, so they resolve against the process's current working directory.
PREDEFINED_STYLES = dict(
    [
        ("8bit", "styles/8bit_learned_embeds.bin"),
        ("ahx_beta", "styles/ahx_beta_learned_embeds.bin"),
        ("dr_strange", "styles/dr_strangelearned_embeds.bin"),
        ("max_naylor", "styles/max_naylorlearned_embeds.bin"),
        ("smiling_friend", "styles/smiling-friend-style_learned_embeds.bin"),
    ]
)
def ice_crystal_loss(images):
    """
    Calculate loss to encourage TRANSPARENT ice crystal patterns as an overlay.

    Every term is a negated "reward", so minimising the total favours strong
    edges, bright pixels on those edges, high-frequency detail, less red than
    green/blue in bright areas (a cool tint), and local texture on edges.

    Args:
        images: batch of RGB images, shape (B, 3, H, W). The 0.5 brightness
            cut-off presumably assumes values in [0, 1].

    Returns:
        Scalar float32 tensor, differentiable w.r.t. ``images`` (the threshold
        masks are constants and carry no gradient).
    """
    # Compute in float32: with a half-precision input the epsilon below would
    # underflow and the many small means would lose precision. Gradients
    # still flow back through the cast.
    images = images.float()
    dtype, device = images.dtype, images.device

    sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                           dtype=dtype, device=device).view(1, 1, 3, 3)
    sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                           dtype=dtype, device=device).view(1, 1, 3, 3)
    # groups=3 applies the same kernel to each colour channel separately.
    edges_x = F.conv2d(images, sobel_x.repeat(3, 1, 1, 1), padding=1, groups=3)
    edges_y = F.conv2d(images, sobel_y.repeat(3, 1, 1, 1), padding=1, groups=3)
    # The epsilon keeps sqrt's backward finite where both responses are
    # exactly zero (flat regions). Without it autograd computes 0/0 = NaN,
    # which the convolution then spreads through the whole gradient.
    edge_magnitude = torch.sqrt(edges_x**2 + edges_y**2 + 1e-8)

    edge_threshold = 0.1
    strong_edges = torch.relu(edge_magnitude - edge_threshold)
    edge_loss = -strong_edges.mean()
    edge_mask = (edge_magnitude > edge_threshold).float()  # (B, 3, H, W)

    brightness = images.mean(dim=1, keepdim=True)  # (B, 1, H, W)
    selective_brightness = brightness * edge_mask
    brightness_loss = -selective_brightness.mean() * 0.3

    laplacian_kernel = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]],
                                    dtype=dtype, device=device).view(1, 1, 3, 3)
    high_freq = F.conv2d(images, laplacian_kernel.repeat(3, 1, 1, 1), padding=1, groups=3)
    high_freq_loss = -torch.abs(high_freq).mean() * 0.5

    r, g, b = images[:, 0], images[:, 1], images[:, 2]
    bright_mask = (brightness.squeeze(1) > 0.5).float()  # (B, H, W)
    # Penalise red relative to the green/blue average inside bright areas.
    cool_tone_loss = (r * bright_mask).mean() - ((b * bright_mask).mean() + (g * bright_mask).mean()) / 2
    cool_tone_loss = cool_tone_loss * 0.2

    kernel_size = 3
    local_mean = F.avg_pool2d(images, kernel_size, stride=1, padding=kernel_size // 2)
    local_variance = F.avg_pool2d((images - local_mean) ** 2, kernel_size, stride=1, padding=kernel_size // 2)
    # edge_mask already carries the channel dimension, so multiply
    # element-wise. An extra unsqueeze(1) would broadcast to (B, B, 3, H, W)
    # and mix different samples together whenever B > 1.
    texture_in_edges = local_variance * edge_mask
    texture_loss = -texture_in_edges.mean() * 0.5

    # Several terms were already scaled above, so e.g. brightness ends up
    # weighted 0.5 * 0.3 overall.
    total_loss = (
        3.0 * edge_loss +
        0.5 * brightness_loss +
        0.8 * high_freq_loss +
        0.2 * cool_tone_loss +
        1.0 * texture_loss
    )
    return total_loss
def load_models():
    """
    Load all models once and cache them globally.

    Populates the module globals ``vae``, ``tokenizer``, ``text_encoder``,
    ``unet``, ``scheduler`` and ``device`` from CompVis/stable-diffusion-v1-4.
    Weights are loaded as float16 on CUDA and float32 on CPU. Once loading
    has completed, later calls return immediately.

    Raises:
        RuntimeError: if any component fails to load; the underlying
            exception is chained as ``__cause__``.
    """
    global vae, tokenizer, text_encoder, unet, scheduler, device
    # ``scheduler`` is assigned last below, so a non-None scheduler means a
    # previous call finished loading everything. A partial failure leaves it
    # None and the next call retries from scratch.
    if vae is not None and scheduler is not None:
        return
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    model_id = "CompVis/stable-diffusion-v1-4"
    try:
        print("Loading models... (this may take a few minutes on CPU)")
        # Load with float16 on GPU, float32 on CPU
        dtype = torch.float16 if device == "cuda" else torch.float32
        vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=dtype).to(device)
        tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=dtype).to(device)
        unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=dtype).to(device)
        # Initialize scheduler
        scheduler = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000
        )
        print("Models loaded successfully!")
    except Exception as e:
        print(f"Error loading models: {e}")
        # Chain the original error so its traceback is not lost.
        raise RuntimeError(f"Failed to load models: {e}") from e
def decode_latents_to_image(latents_to_decode):
    """
    Decode latents to a PIL Image using the global VAE.

    Args:
        latents_to_decode: latent batch as produced by the diffusion loop.
            Only the first element of the batch is converted.

    Returns:
        PIL.Image.Image built from an (H, W, 3) uint8 array.
    """
    with torch.no_grad():
        # Undo the SD v1 latent scaling before decoding.
        latents_scaled = (1 / 0.18215) * latents_to_decode
        # load_models() keeps the VAE in float16 on CUDA; match its dtype so
        # float32 latents don't trigger a dtype-mismatch error in decode().
        latents_scaled = latents_scaled.to(vae.dtype)
        image = vae.decode(latents_scaled).sample
        # Map the decoder output from [-1, 1] to [0, 1].
        image = (image / 2 + 0.5).clamp(0, 1)
        # Upcast before leaving torch so half-precision outputs convert cleanly.
        image = image.float().cpu().permute(0, 2, 3, 1).numpy()
    image = (image[0] * 255).astype(np.uint8)
    return Image.fromarray(image)
def create_gif_from_frames(frames, output_path=None, duration=200):
    """
    Create an animated GIF from a list of PIL Images.

    Args:
        frames: sequence of images; ``frames[0].save`` writes the file and the
            remaining frames are appended to it.
        output_path: destination path. When None, a new uniquely named
            ``.gif`` file is created in the system temp directory; it is not
            deleted automatically.
        duration: per-frame display time passed to Pillow (milliseconds).

    Returns:
        The path written to, or None when ``frames`` is empty.
    """
    if not frames:
        return None
    if output_path is None:
        # NamedTemporaryFile creates the file atomically. The deprecated
        # tempfile.mktemp only returns a name, which leaves a race window
        # where another process could claim the same path.
        with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp:
            output_path = tmp.name
    # loop=0 makes the animation repeat forever.
    frames[0].save(
        output_path,
        save_all=True,
        append_images=frames[1:],
        duration=duration,
        loop=0
    )
    return output_path
def _load_style_token(style_file):
    """Install the learned embedding from ``style_file`` into the global tokenizer/text encoder and return its token."""
    learned_embeds_dict = torch.load(style_file, map_location=device, weights_only=True)
    if not isinstance(learned_embeds_dict, dict) or not learned_embeds_dict:
        raise ValueError(f"Embedding file {style_file} must contain a non-empty dict of token -> embedding")
    # Only the first entry is used; its key is the placeholder token string.
    style_token = next(iter(learned_embeds_dict))
    style_embedding = learned_embeds_dict[style_token].to(device)
    # Accept a single vector stored as shape (1, dim) as well as (dim,).
    if style_embedding.dim() == 2 and style_embedding.shape[0] == 1:
        style_embedding = style_embedding[0]
    expected_dim = text_encoder.get_input_embeddings().weight.shape[1]
    if style_embedding.shape[0] != expected_dim:
        if style_embedding.shape[0] == 1024 and expected_dim == 768:
            # Lossy fallback: keep only the first 768 components.
            style_embedding = style_embedding[:768]
        else:
            raise ValueError(f"Cannot handle embedding dimension {style_embedding.shape[0]} -> {expected_dim}")
    if style_token not in tokenizer.get_vocab():
        tokenizer.add_tokens([style_token])
        text_encoder.resize_token_embeddings(len(tokenizer))
    token_id = tokenizer.convert_tokens_to_ids(style_token)
    # NOTE(review): this mutates the shared, process-wide text encoder, so the
    # embedding persists for later requests. A token string that already
    # exists in the vocabulary has its original embedding overwritten.
    with torch.no_grad():
        text_encoder.get_input_embeddings().weight[token_id] = style_embedding
    return style_token


def _encode_prompt(text):
    """Return text-encoder hidden states for ``text`` padded/truncated to the tokenizer's max length."""
    text_input = tokenizer(
        text,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt"
    )
    with torch.no_grad():
        return text_encoder(text_input.input_ids.to(device))[0]


def _apply_ice_crystal_guidance(latents, noise_pred, sigma, loss_scale):
    """Step ``latents`` against the gradient of ice_crystal_loss on the predicted clean image; return detached latents."""
    if device == "cuda":
        torch.cuda.empty_cache()
    latents = latents.detach().requires_grad_()
    # Predicted clean latents under the scheduler's sigma parameterisation.
    latents_x0 = latents - sigma * noise_pred
    with torch.autocast(device_type=device, enabled=False):
        decoded = vae.decode((1 / 0.18215) * latents_x0).sample
    # Score in float32 even when the VAE runs in float16.
    denoised_images = decoded.float() / 2 + 0.5
    loss = ice_crystal_loss(denoised_images) * loss_scale
    cond_grad = torch.autograd.grad(loss, latents)[0]
    latents = latents.detach() - cond_grad * sigma**2
    del decoded, denoised_images, loss, cond_grad
    if device == "cuda":
        torch.cuda.empty_cache()
    return latents


def generate_with_style_streaming(
    style_file,
    prompt,
    seed=42,
    num_inference_steps=50,
    guidance_scale=7.5,
    height=512,
    width=512,
    use_ice_crystal_guidance=False,
    ice_crystal_loss_scale=50,
    guidance_frequency=10,
    preview_frequency=5
):
    """
    Generate an image with streaming updates.

    Args:
        style_file: path to a textual-inversion embedding file; ``<style>`` in
            ``prompt`` is replaced by that file's token.
        prompt: text prompt.
        seed: RNG seed for the initial latents.
        num_inference_steps: number of scheduler steps.
        guidance_scale: classifier-free guidance weight.
        height, width: output size in pixels (latents are 1/8 of this).
        use_ice_crystal_guidance: apply the ice-crystal loss gradient.
        ice_crystal_loss_scale: multiplier on that loss.
        guidance_frequency: apply the ice-crystal step every N steps (min 1).
        preview_frequency: decode and yield a preview every N steps (min 1);
            the last step is always previewed.

    Yields:
        dicts with keys ``step``, ``total``, ``image`` (PIL Image) and ``gif``.
        ``gif`` is None for intermediate previews. It is the GIF path in the
        final dict, which repeats the last image.
    """
    load_models()
    # Guard the modulo checks below against zero or negative frequencies.
    preview_frequency = max(1, int(preview_frequency))
    guidance_frequency = max(1, int(guidance_frequency))
    # Collect frames for GIF
    frames = []
    generator = torch.Generator(device=device).manual_seed(seed)

    style_token = _load_style_token(style_file)
    final_prompt = prompt.replace("<style>", style_token)
    # Unconditional embeddings first, matching the chunk(2) order below.
    text_embeddings = torch.cat([_encode_prompt(""), _encode_prompt(final_prompt)])

    latents = torch.randn(
        (1, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
        device=device
    )
    # load_models() uses float16 weights on CUDA; match them so the UNet and
    # VAE don't reject float32 input (a no-op on CPU).
    latents = latents.to(unet.dtype)
    scheduler.set_timesteps(num_inference_steps)
    latents = latents * scheduler.init_noise_sigma

    current_image = None  # decoded image of the current latents, if any
    for i, t in enumerate(scheduler.timesteps):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        with torch.no_grad():
            noise_pred = unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings
            ).sample
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        if use_ice_crystal_guidance and i % guidance_frequency == 0:
            latents = _apply_ice_crystal_guidance(
                latents, noise_pred, scheduler.sigmas[i], ice_crystal_loss_scale
            )

        latents = scheduler.step(noise_pred, t, latents).prev_sample
        current_image = None
        # Decode and yield intermediate preview every N steps
        if i % preview_frequency == 0 or i == num_inference_steps - 1:
            current_image = decode_latents_to_image(latents)
            frames.append(current_image)
            yield {
                "step": i + 1,
                "total": num_inference_steps,
                "image": current_image,
                "gif": None  # GIF not ready yet
            }

    # The last step is normally previewed above; only decode again if the
    # final latents have not been decoded yet (avoids a costly duplicate
    # VAE pass and a duplicated last GIF frame).
    if current_image is None:
        current_image = decode_latents_to_image(latents)
        frames.append(current_image)
    gif_path = create_gif_from_frames(frames, duration=300)
    yield {
        "step": num_inference_steps,
        "total": num_inference_steps,
        "image": current_image,
        "gif": gif_path
    }
def generate_image_streaming(
    prompt,
    style_choice,
    custom_embedding,
    seed,
    guidance_scale,
    use_ice_crystal,
    ice_crystal_intensity,
    preview_frequency
):
    """
    Streaming generation function for Gradio interface.

    Picks the uploaded embedding when one is given, otherwise the predefined
    style, and relays generate_with_style_streaming's progress.

    Yields:
        (image, gif_path, status) tuples for the image, GIF file and status
        outputs; ``gif_path`` is None until the final update.

    Raises:
        gr.Error: when no usable embedding file is available or generation fails.
    """
    if custom_embedding is not None:
        style_file = custom_embedding
    else:
        if style_choice not in PREDEFINED_STYLES:
            raise gr.Error("Please select a style or upload a custom embedding file.")
        style_file = PREDEFINED_STYLES[style_choice]
    # Validate up front so users get a clear message instead of a torch.load error.
    if not Path(style_file).exists():
        raise gr.Error(f"Style embedding file not found: {style_file}")
    try:
        for update in generate_with_style_streaming(
            style_file=style_file,
            prompt=prompt,
            seed=int(seed),
            guidance_scale=guidance_scale,
            use_ice_crystal_guidance=use_ice_crystal,
            ice_crystal_loss_scale=ice_crystal_intensity,
            preview_frequency=int(preview_frequency)
        ):
            status = f"Step {update['step']}/{update['total']}"
            yield update["image"], update["gif"], status
    except Exception as e:
        # Chain the cause so server logs keep the full traceback.
        raise gr.Error(f"Generation failed: {str(e)}") from e
# Build the Gradio interface
# The nesting of the `with` blocks defines the page layout: each component
# is placed in whichever Row/Column/Accordion is open when it is created, so
# statement order here is significant.
with gr.Blocks(
    title="Multi-Style Image Generator",
    theme=gr.themes.Soft(
        primary_hue="indigo",
        secondary_hue="cyan"
    )
) as demo:
    gr.Markdown("""
# Multi-Style Image Generator with Ice Crystal Effects
Generate images using textual inversion style embeddings with optional ice crystal overlay effects.
**Now with live diffusion progress streaming!**
**Instructions:**
1. Enter a prompt using `<style>` as placeholder (e.g., "A cat in the style of <style>")
2. Select a predefined style OR upload your own `.bin` embedding file
3. Optionally enable ice crystal effect for a crystalline overlay
4. Click Generate and watch the image evolve!
""")
    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="A mouse in the style of <style>",
                value="A mouse in the style of <style>",
                lines=2
            )
            # Keys of PREDEFINED_STYLES; ignored when a custom file is uploaded.
            style_choice = gr.Dropdown(
                choices=list(PREDEFINED_STYLES.keys()),
                value="8bit",
                label="Predefined Style",
                info="Select a bundled style embedding"
            )
            # type="filepath" hands the callback a path string (or None).
            custom_embedding = gr.File(
                label="Custom Embedding (Optional)",
                file_types=[".bin"],
                type="filepath"
            )
            with gr.Row():
                seed = gr.Number(
                    label="Seed",
                    value=42,
                    precision=0
                )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1.0,
                    maximum=20.0,
                    value=7.5,
                    step=0.5
                )
            with gr.Accordion("Ice Crystal Effect", open=False):
                use_ice_crystal = gr.Checkbox(
                    label="Enable Ice Crystal Effect",
                    value=False,
                    info="Add crystalline overlay to the image"
                )
                # Forwarded to the generator as ice_crystal_loss_scale.
                ice_crystal_intensity = gr.Slider(
                    label="Ice Crystal Intensity",
                    minimum=30,
                    maximum=100,
                    value=50,
                    step=5,
                    info="Higher = stronger crystal effect"
                )
            with gr.Accordion("Streaming Settings", open=True):
                # Each preview costs a full VAE decode, hence "slower".
                preview_frequency = gr.Slider(
                    label="Preview Frequency",
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    info="Show preview every N steps (lower = more updates, slower)"
                )
            generate_btn = gr.Button("Generate", variant="primary", size="lg")
            status_text = gr.Textbox(label="Status", interactive=False, value="Ready")
        # Right column: live preview and the GIF produced at the end.
        with gr.Column(scale=1):
            output_image = gr.Image(
                label="Live Preview / Final Image",
                type="pil"
            )
            output_gif = gr.File(
                label="Diffusion Progress GIF (available after generation)",
                type="filepath"
            )
    # Example rows are positional and must follow the order of `inputs`.
    gr.Examples(
        examples=[
            ["A cat in the style of <style>", "8bit", None, 42, 7.5, False, 50, 5],
            ["A mystical forest in the style of <style>", "dr_strange", None, 123, 7.5, False, 50, 5],
            ["A portrait in the style of <style>", "max_naylor", None, 456, 7.5, True, 60, 5],
        ],
        inputs=[prompt, style_choice, custom_embedding, seed, guidance_scale, use_ice_crystal, ice_crystal_intensity, preview_frequency],
    )
    # generate_image_streaming yields (image, gif, status) tuples, which are
    # streamed into these three outputs in order.
    generate_btn.click(
        fn=generate_image_streaming,
        inputs=[prompt, style_choice, custom_embedding, seed, guidance_scale, use_ice_crystal, ice_crystal_intensity, preview_frequency],
        outputs=[output_image, output_gif, status_text]
    )
# Launch the server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()