""" Multi-Style Image Generator with Ice Crystal Effects Hugging Face Spaces App - With Diffusion Progress Streaming """ import torch import torch.nn.functional as F import numpy as np from PIL import Image from pathlib import Path from tqdm.auto import tqdm import gradio as gr import io import tempfile from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler from transformers import CLIPTextModel, CLIPTokenizer # Global variables for models (will be loaded once) vae = None tokenizer = None text_encoder = None unet = None scheduler = None device = None # Predefined styles mapping PREDEFINED_STYLES = { "8bit": "styles/8bit_learned_embeds.bin", "ahx_beta": "styles/ahx_beta_learned_embeds.bin", "dr_strange": "styles/dr_strangelearned_embeds.bin", "max_naylor": "styles/max_naylorlearned_embeds.bin", "smiling_friend": "styles/smiling-friend-style_learned_embeds.bin" } def ice_crystal_loss(images): """ Calculate loss to encourage TRANSPARENT ice crystal patterns as an overlay. """ sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=images.dtype, device=images.device).view(1, 1, 3, 3) sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=images.dtype, device=images.device).view(1, 1, 3, 3) edges_x = F.conv2d(images, sobel_x.repeat(3, 1, 1, 1), padding=1, groups=3) edges_y = F.conv2d(images, sobel_y.repeat(3, 1, 1, 1), padding=1, groups=3) edge_magnitude = torch.sqrt(edges_x**2 + edges_y**2) edge_threshold = 0.1 strong_edges = torch.relu(edge_magnitude - edge_threshold) edge_loss = -strong_edges.mean() edge_mask = (edge_magnitude > edge_threshold).float() brightness = images.mean(dim=1, keepdim=True) selective_brightness = brightness * edge_mask brightness_loss = -selective_brightness.mean() * 0.3 laplacian_kernel = torch.tensor([[0, -1, 0], [-1, 4, -1], [0, -1, 0]], dtype=images.dtype, device=images.device).view(1, 1, 3, 3) high_freq = F.conv2d(images, laplacian_kernel.repeat(3, 1, 1, 1), padding=1, groups=3) high_freq_loss = -torch.abs(high_freq).mean() * 0.5 r, g, b = images[:, 0], images[:, 1], images[:, 2] bright_mask = (brightness.squeeze(1) > 0.5).float() cool_tone_loss = (r * bright_mask).mean() - ((b * bright_mask).mean() + (g * bright_mask).mean()) / 2 cool_tone_loss = cool_tone_loss * 0.2 kernel_size = 3 local_mean = F.avg_pool2d(images, kernel_size, stride=1, padding=kernel_size//2) local_variance = F.avg_pool2d((images - local_mean)**2, kernel_size, stride=1, padding=kernel_size//2) texture_in_edges = local_variance * edge_mask.unsqueeze(1) texture_loss = -texture_in_edges.mean() * 0.5 total_loss = ( 3.0 * edge_loss + 0.5 * brightness_loss + 0.8 * high_freq_loss + 0.2 * cool_tone_loss + 1.0 * texture_loss ) return total_loss def load_models(): """Load all models once and cache them globally.""" global vae, tokenizer, text_encoder, unet, scheduler, device # Check if already loaded if vae is not None and scheduler is not None: return device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") model_id = "CompVis/stable-diffusion-v1-4" try: print("Loading models... (this may take a few minutes on CPU)") # Load with float16 on GPU, float32 on CPU dtype = torch.float16 if device == "cuda" else torch.float32 vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=dtype).to(device) tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer") text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=dtype).to(device) unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", torch_dtype=dtype).to(device) # Initialize scheduler scheduler = LMSDiscreteScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 ) print("Models loaded successfully!") except Exception as e: print(f"Error loading models: {e}") raise RuntimeError(f"Failed to load models: {e}") def decode_latents_to_image(latents_to_decode): """Decode latents to PIL Image.""" global vae, device with torch.no_grad(): latents_scaled = 1 / 0.18215 * latents_to_decode image = vae.decode(latents_scaled).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() image = (image[0] * 255).astype(np.uint8) return Image.fromarray(image) def create_gif_from_frames(frames, output_path=None, duration=200): """Create an animated GIF from a list of PIL Images.""" if not frames: return None if output_path is None: output_path = tempfile.mktemp(suffix='.gif') # Save as GIF frames[0].save( output_path, save_all=True, append_images=frames[1:], duration=duration, loop=0 ) return output_path def generate_with_style_streaming( style_file, prompt, seed=42, num_inference_steps=50, guidance_scale=7.5, height=512, width=512, use_ice_crystal_guidance=False, ice_crystal_loss_scale=50, guidance_frequency=10, preview_frequency=5 ): """ Generate an image with streaming updates. Yields intermediate images during generation. Returns final image and GIF path at the end. """ global vae, tokenizer, text_encoder, unet, scheduler, device load_models() # Collect frames for GIF frames = [] generator = torch.Generator(device=device).manual_seed(seed) learned_embeds_dict = torch.load(style_file, map_location=device, weights_only=True) style_token = list(learned_embeds_dict.keys())[0] style_embedding = learned_embeds_dict[style_token].to(device) expected_dim = text_encoder.get_input_embeddings().weight.shape[1] if style_embedding.shape[0] != expected_dim: if style_embedding.shape[0] == 1024 and expected_dim == 768: style_embedding = style_embedding[:768] else: raise ValueError(f"Cannot handle embedding dimension {style_embedding.shape[0]} -> {expected_dim}") if style_token not in tokenizer.get_vocab(): tokenizer.add_tokens([style_token]) text_encoder.resize_token_embeddings(len(tokenizer)) token_id = tokenizer.convert_tokens_to_ids(style_token) with torch.no_grad(): text_encoder.get_input_embeddings().weight[token_id] = style_embedding final_prompt = prompt.replace("