#!/usr/bin/env python3
"""
Utility functions for the application
Author: Shilpaj Bhalerao
Date: Feb 26, 2025
"""
import gc
import os

# Disable HF transfer to avoid download issues. huggingface_hub reads this
# variable once, when it is first imported (diffusers and transformers import
# it), so it must be set *before* those imports to have any effect.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"

import torch  # noqa: E402
from PIL import Image, ImageDraw, ImageFont  # noqa: E402
from diffusers import StableDiffusionPipeline  # noqa: E402
from transformers import CLIPTokenizer, CLIPTextModel  # noqa: E402

# Latent scaling factor of the Stable Diffusion v1 VAE; latents are multiplied
# by it after encoding and divided by it before decoding.
VAE_SCALING_FACTOR = 0.18215

# Textual-inversion concepts loaded by default by load_concept_library().
CONCEPT_REPOS = (
    "sd-concepts-library/dreams",
    "sd-concepts-library/midjourney-style",
    "sd-concepts-library/moebius",
    "sd-concepts-library/style-of-marc-allante",
    "sd-concepts-library/wlop-style",
)


def _resolve_device(device):
    """Return `device` as a string, falling back when the accelerator is unavailable."""
    device = str(device)
    if device.startswith("cuda") and not torch.cuda.is_available():
        device = "mps" if torch.backends.mps.is_available() else "cpu"
    if device == "mps" and not torch.backends.mps.is_available():
        device = "cpu"
    return device


def load_models(device="cuda"):
    """
    Load the necessary models for stable diffusion

    Args:
        device (str): Preferred device ('cuda', 'cuda:N', 'mps', or 'cpu').
            An unavailable CUDA device falls back to MPS, then CPU; an
            unavailable MPS device falls back to CPU.

    Returns:
        tuple: (vae, tokenizer, text_encoder, unet, scheduler, pipe)
    """
    from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel

    device = _resolve_device(device)
    if device == "mps":
        # NOTE(review): PyTorch appears to read this flag when torch is first
        # imported, so setting it here may not affect the current process —
        # set it before `import torch` if the CPU fallback is required.
        os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"

    print(f"Loading models on {device}...")

    # Autoencoder used to decode latents into image space
    vae = AutoencoderKL.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="vae", use_safetensors=False
    )

    # Tokenizer and text encoder used to turn prompts into embeddings
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # The UNet that predicts noise in latent space
    unet = UNet2DConditionModel.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="unet", use_safetensors=False
    )

    # The noise scheduler
    scheduler = LMSDiscreteScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
    )

    # Full pipeline, used for textual-inversion concept loading. Half precision
    # is chosen from the device actually in use: many CPU kernels do not
    # implement float16, so an explicit "cpu" request must get float32 even
    # when CUDA happens to be available.
    pipe_dtype = torch.float16 if device.startswith("cuda") else torch.float32
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=pipe_dtype,
        use_safetensors=False,
    )

    # Move models to device
    vae = vae.to(device)
    text_encoder = text_encoder.to(device)
    unet = unet.to(device)
    pipe = pipe.to(device)

    return vae, tokenizer, text_encoder, unet, scheduler, pipe


def clear_gpu_memory():
    """Run the garbage collector and release PyTorch's cached accelerator memory."""
    # Collect unreachable Python objects first so the tensors they hold are
    # freed before the allocator caches are emptied.
    gc.collect()
    torch.cuda.empty_cache()  # no-op when CUDA was never initialised
    mps = getattr(torch, "mps", None)
    if mps is not None and hasattr(mps, "empty_cache") and torch.backends.mps.is_available():
        mps.empty_cache()


def set_timesteps(scheduler, num_inference_steps):
    """Set timesteps for the scheduler with MPS compatibility fix"""
    scheduler.set_timesteps(num_inference_steps)
    # MPS has no float64 support; cast so the timesteps can live on MPS.
    scheduler.timesteps = scheduler.timesteps.to(torch.float32)


def pil_to_latent(input_im, vae, device):
    """
    Convert the image to latents

    Args:
        input_im: Input PIL image (non-RGB modes are converted to RGB)
        vae: VAE model
        device: Device to run on

    Returns:
        Scaled latents sampled from the VAE encoder's latent distribution
        (batch of one, e.g. (1, 4, 64, 64) for a 512x512 image)
    """
    from torchvision import transforms as tfms

    # ToTensor keeps the image's channel count; the SD VAE encoder takes RGB,
    # so RGBA / greyscale inputs must be converted first.
    if isinstance(input_im, Image.Image) and input_im.mode != "RGB":
        input_im = input_im.convert("RGB")

    # [0, 1] -> [-1, 1], the pixel range the VAE was trained on
    pixels = tfms.ToTensor()(input_im).unsqueeze(0).to(device) * 2 - 1
    with torch.no_grad():
        latent = vae.encode(pixels)
    return VAE_SCALING_FACTOR * latent.latent_dist.sample()


def latents_to_pil(latents, vae):
    """
    Convert the latents to images

    Args:
        latents: Latent tensor (batch first)
        vae: VAE model

    Returns:
        list: PIL images, one per latent in the batch
    """
    latents = (1 / VAE_SCALING_FACTOR) * latents
    with torch.no_grad():
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    # Cast to float32 before .numpy(): numpy has no bfloat16, and float16
    # is too coarse for the *255 rounding below.
    arrays = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
    arrays = (arrays * 255).round().astype("uint8")
    return [Image.fromarray(arr) for arr in arrays]


def image_grid(imgs, rows, cols, labels=None):
    """
    Create a grid of images with optional labels.

    Args:
        imgs (list): List of PIL images to be arranged in a grid (all assumed
            to be the size of the first image)
        rows (int): Number of rows in the grid
        cols (int): Number of columns in the grid
        labels (list, optional): List of label strings for each image; when
            given, a 30-pixel label strip is reserved below every row

    Returns:
        PIL.Image: A single image with all input images arranged in a grid and labeled
    """
    assert len(imgs) == rows * cols, f"Number of images ({len(imgs)}) must equal rows*cols ({rows*cols})"
    if labels:
        assert len(labels) == len(imgs), "Number of labels must match number of images"

    w, h = imgs[0].size
    label_height = 30 if labels else 0
    cell_h = h + label_height  # image plus its label strip
    grid = Image.new('RGB', size=(cols * w, rows * cell_h))

    # Paste images
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * cell_h))

    # Add labels if provided
    if labels:
        draw = ImageDraw.Draw(grid)

        # Try to use a standard font, fall back to default if not available
        try:
            font = ImageFont.truetype("arial.ttf", 14)
        except OSError:
            font = ImageFont.load_default()

        for i, label in enumerate(labels):
            row, col = divmod(i, cols)
            # Inside this row's label strip, just under the image
            x = col * w + 10
            y = row * cell_h + h + 8

            # Black text with a white outline (drawn at diagonal offsets) for visibility
            for dx, dy in ((1, 1), (-1, -1), (1, -1), (-1, 1)):
                draw.text((x + dx, y + dy), label, fill=(255, 255, 255), font=font)
            draw.text((x, y), label, fill=(0, 0, 0), font=font)

    return grid


def vignette_loss(images, vignette_strength=3.0, color_shift=(1.0, 0.5, 0.0)):
    """
    Creates a strong vignette effect (dark corners) and color shift.

    Args:
        images: Batch of images from VAE decoder, shape (B, C, H, W), range 0-1
        vignette_strength: How strong the darkening effect is (higher = more dramatic)
        color_shift: Color (one value per channel, e.g. [r, g, b]) to shift the center toward

    Returns:
        torch.Tensor: Scalar loss — mean squared difference between the images
        and their vignetted, color-shifted target

    Raises:
        ValueError: If len(color_shift) does not equal the channel count C.
    """
    _, channels, height, width = images.shape
    device = images.device

    # Coordinate axes spanning [-1, 1]; broadcasting (H, 1) with (1, W) forms the grid
    ys = torch.linspace(-1, 1, height, device=device).view(-1, 1)
    xs = torch.linspace(-1, 1, width, device=device).view(1, -1)

    # Distance from the center, divided by ~sqrt(2) so the corners land at ~1
    radius = torch.sqrt(xs.pow(2) + ys.pow(2)) / 1.414  # (H, W)

    # Vignette mask: 1 at the center, decaying toward the edges
    vignette = torch.exp(-vignette_strength * radius)
    # Color-shift weight: strongest at the center; squared for a sharper falloff
    center_mask = (1.0 - radius).pow(2.0)

    color = torch.tensor(color_shift, dtype=torch.float32, device=device)
    if color.numel() != channels:
        raise ValueError(
            f"color_shift has {color.numel()} values but images have {channels} channels"
        )
    color = color.view(1, channels, 1, 1)

    # Target = vignetted image, pulled toward `color` in the center. The target
    # is intentionally not detached, matching the original formulation.
    target = images * vignette + (color - images) * center_mask

    # Calculate loss - how different current image is from our target
    return (images - target).pow(2).mean()


def get_concept_embedding(concept_text, tokenizer, text_encoder, device):
    """
    Generate CLIP embedding for a concept described in text

    Args:
        concept_text (str): Text description of the concept (e.g., "sketch painting")
        tokenizer: CLIP tokenizer
        text_encoder: CLIP text encoder
        device: Device to run on

    Returns:
        torch.Tensor: The text encoder's first output for the padded prompt
        (for CLIPTextModel, the last hidden state of shape
        (1, tokenizer.model_max_length, hidden_size))
    """
    # Tokenize, padding/truncating to the encoder's fixed context length
    concept_tokens = tokenizer(
        concept_text,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids.to(device)

    with torch.no_grad():
        concept_embedding = text_encoder(concept_tokens)[0]

    return concept_embedding


def load_concept_library(pipe, concepts=CONCEPT_REPOS, save_path="learned_embeds.bin"):
    """
    Load textual inversion concepts from the SD concept library

    Args:
        pipe: StableDiffusionPipeline whose tokenizer / text encoder receive
            the new tokens
        concepts: Iterable of concept repos / paths accepted by
            `pipe.load_textual_inversion` (defaults to CONCEPT_REPOS)
        save_path: File to write the token -> embedding dict to with
            `torch.save`; pass None to skip saving

    Returns:
        tuple: (learned_embeds, tokens) where `learned_embeds` maps each
        placeholder token to its embedding (on CPU) and `tokens` lists the
        tokens in load order

    Raises:
        RuntimeError: If loading a concept does not add a token to the tokenizer.
    """
    tokenizer = pipe.tokenizer
    tokens = []
    token_ids = []

    for repo in concepts:
        # The placeholder token name is defined by each concept's embedding
        # file, so read it from the tokenizer instead of hard-coding it.
        known = set(tokenizer.get_added_vocab())
        pipe.load_textual_inversion(repo)
        added = {
            tok: idx
            for tok, idx in tokenizer.get_added_vocab().items()
            if tok not in known
        }
        if not added:
            raise RuntimeError(f"Loading {repo!r} did not add a token to the tokenizer")
        # Single-vector concepts add exactly one token; if several were added,
        # assume the lowest id is the base placeholder token.
        token = min(added, key=added.get)
        tokens.append(token)
        token_ids.append(added[token])

    # Extract the embeddings from the pipeline
    embeddings = pipe.text_encoder.get_input_embeddings().weight[token_ids].detach().cpu()
    learned_embeds = dict(zip(tokens, embeddings))

    # Save the embeddings for future use
    if save_path is not None:
        torch.save(learned_embeds, save_path)
        print(f"Saved embeddings for tokens: {', '.join(tokens)}")

    return learned_embeds, tokens