Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """ | |
| Utility functions for the application | |
| Author: Shilpaj Bhalerao | |
| Date: Feb 26, 2025 | |
| """ | |
| import torch | |
| import gc | |
| import os | |
| from PIL import Image, ImageDraw, ImageFont | |
| from diffusers import StableDiffusionPipeline | |
| from transformers import CLIPTokenizer, CLIPTextModel | |
| # Disable HF transfer to avoid download issues | |
| os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0" | |
def load_models(device="cuda"):
    """
    Load the models needed for stable diffusion inference.

    Args:
        device (str): Preferred device ('cuda', 'mps', or 'cpu'). Falls back
            automatically to 'mps' or 'cpu' when CUDA is unavailable.

    Returns:
        tuple: (vae, tokenizer, text_encoder, unet, scheduler, pipe)
    """
    from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel

    # Resolve the actual device first: a request for CUDA on a machine
    # without it falls back to MPS, then CPU.
    if device == "cuda" and not torch.cuda.is_available():
        device = "mps" if torch.backends.mps.is_available() else "cpu"
    if device == "mps":
        os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
    print(f"Loading models on {device}...")

    # Autoencoder used to decode latents back into image space
    vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_safetensors=False)
    # Tokenizer and text encoder for prompt conditioning
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    # UNet that predicts noise in latent space
    unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", use_safetensors=False)
    # Noise scheduler matching SD v1 training configuration
    scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

    # Full pipeline, used for loading textual-inversion concepts.
    # BUGFIX: choose fp16 only when we will actually run on CUDA. The old
    # check (torch.cuda.is_available()) produced a half-precision pipeline
    # even when the caller explicitly requested CPU, breaking CPU inference.
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        use_safetensors=False
    )

    # Move everything onto the resolved device
    vae = vae.to(device)
    text_encoder = text_encoder.to(device)
    unet = unet.to(device)
    pipe = pipe.to(device)
    return vae, tokenizer, text_encoder, unet, scheduler, pipe
def clear_gpu_memory():
    """
    Release cached accelerator memory.

    Runs Python garbage collection between two CUDA cache flushes (so tensors
    freed by the collector are also returned to the driver), and additionally
    clears the MPS cache on Apple Silicon — the rest of this module supports
    MPS, but the old version only ever flushed CUDA. Safe to call when no
    accelerator is present.
    """
    torch.cuda.empty_cache()
    gc.collect()
    torch.cuda.empty_cache()
    # Mirror the CUDA flush for MPS backends (guarded: older torch builds
    # may lack torch.mps.empty_cache)
    if torch.backends.mps.is_available() and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
        torch.mps.empty_cache()
def set_timesteps(scheduler, num_inference_steps):
    """
    Configure the scheduler's timesteps for sampling.

    The resulting timestep tensor is cast to float32 as a workaround for
    MPS, which does not handle the scheduler's default float64 tensors.
    """
    scheduler.set_timesteps(num_inference_steps)
    timesteps = scheduler.timesteps
    scheduler.timesteps = timesteps.to(dtype=torch.float32)
def pil_to_latent(input_im, vae, device):
    """
    Encode a single PIL image into scaled VAE latents.

    Args:
        input_im: Input PIL image
        vae: VAE model whose encoder produces the latent distribution
        device: Device the VAE lives on

    Returns:
        Sampled latents from the VAE's encoder (batch of size 1)
    """
    from torchvision import transforms as tfms

    # PIL -> [0, 1] tensor with a leading batch dim, then rescale to [-1, 1]
    batched = tfms.ToTensor()(input_im).unsqueeze(0).to(device)
    with torch.no_grad():
        latent = vae.encode(batched * 2 - 1)
    # 0.18215 is the standard SD v1 latent scaling factor
    return 0.18215 * latent.latent_dist.sample()
def latents_to_pil(latents, vae):
    """
    Decode a batch of latents into a list of PIL images.

    Args:
        latents: SD-scaled latent tensor (batch first)
        vae: VAE model used for decoding

    Returns:
        list: One PIL image per batch element
    """
    # Undo the SD latent scaling factor before decoding
    unscaled = (1 / 0.18215) * latents
    with torch.no_grad():
        decoded = vae.decode(unscaled).sample
    # [-1, 1] -> [0, 1], then to HWC uint8 arrays on the CPU
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    arrays = decoded.detach().cpu().permute(0, 2, 3, 1).numpy()
    arrays = (arrays * 255).round().astype("uint8")
    return [Image.fromarray(arr) for arr in arrays]
def image_grid(imgs, rows, cols, labels=None):
    """
    Create a grid of images with optional labels.

    Args:
        imgs (list): List of PIL images to be arranged in a grid
        rows (int): Number of rows in the grid
        cols (int): Number of columns in the grid
        labels (list, optional): List of label strings for each image

    Returns:
        PIL.Image: A single image with all input images arranged in a grid
            and labeled
    """
    assert len(imgs) == rows*cols, f"Number of images ({len(imgs)}) must equal rows*cols ({rows*cols})"
    w, h = imgs[0].size

    # Bottom padding for labels. (Previously this value was computed but
    # never used — the size expression duplicated the same conditional.)
    label_height = 30 if labels else 0
    grid = Image.new('RGB', size=(cols*w, rows*h + label_height))

    # Paste images into their grid cells
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))

    # Add labels if provided
    if labels:
        assert len(labels) == len(imgs), "Number of labels must match number of images"
        draw = ImageDraw.Draw(grid)
        # Try a standard font, fall back to PIL's built-in default
        try:
            font = ImageFont.truetype("arial.ttf", 14)
        except IOError:
            font = ImageFont.load_default()
        for i, label in enumerate(labels):
            # Position text near the bottom edge of each image
            x = (i % cols) * w + 10
            y = (i // cols + 1) * h - 5
            # White outline drawn in four diagonal offsets for visibility...
            for dx, dy in [(1, 1), (-1, -1), (1, -1), (-1, 1)]:
                draw.text((x + dx, y + dy), label, fill=(255, 255, 255), font=font)
            # ...then the main text in black on top
            draw.text((x, y), label, fill=(0, 0, 0), font=font)
    return grid
def vignette_loss(images, vignette_strength=3.0, color_shift=(1.0, 0.5, 0.0)):
    """
    Loss encouraging a strong vignette (dark corners) and a color shift
    toward `color_shift` in the image center.

    Args:
        images: Batch of decoded images, shape (B, 3, H, W), values in [0, 1]
        vignette_strength: Edge-darkening strength (higher = more dramatic)
        color_shift: RGB color the center is pulled toward [r, g, b]

    Returns:
        torch.Tensor: Scalar MSE between `images` and the vignetted target
    """
    batch_size, channels, height, width = images.shape

    # Coordinate grid in [-1, 1], centered on the image
    y = torch.linspace(-1, 1, height, device=images.device).view(-1, 1).repeat(1, width)
    x = torch.linspace(-1, 1, width, device=images.device).view(1, -1).repeat(height, 1)
    # Radius from center, normalized to ~[0, 1] (1.414 ≈ sqrt(2) at corners)
    radius = torch.sqrt(x.pow(2) + y.pow(2)) / 1.414

    # Vignette mask: bright at the center, exponentially darker at the edges
    vignette = torch.exp(-vignette_strength * radius)                  # (H, W)

    # Center mask: 1 at the center, ~0 at corners; squared for a sharper falloff
    center_mask = torch.pow(1.0 - radius, 2.0).unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
    color_tensor = torch.tensor(color_shift, dtype=torch.float32, device=images.device).view(1, 3, 1, 1)

    # Target = vignetted image, shifted toward color_tensor near the center.
    # BUGFIX: the old per-channel loop broadcast (B, H, W) against
    # (1, 1, H, W), producing a (1, B, H, W) offset; `.squeeze(1)` only
    # rescued batch size 1, and any batch > 1 raised a RuntimeError on the
    # in-place assignment. Plain NCHW broadcasting handles every batch size
    # and is identical to the loop for B == 1.
    target = images * vignette + (color_tensor - images) * center_mask

    # How far the current image is from our vignetted/color-shifted target
    return torch.pow(images - target, 2).mean()
def get_concept_embedding(concept_text, tokenizer, text_encoder, device):
    """
    Produce the CLIP text embedding for a concept described in text.

    Args:
        concept_text (str): Text description of the concept (e.g., "sketch painting")
        tokenizer: CLIP tokenizer
        text_encoder: CLIP text encoder
        device: Device to run on

    Returns:
        torch.Tensor: CLIP embedding for the concept
    """
    # Tokenize, padding/truncating to the model's maximum sequence length
    encoded = tokenizer(
        concept_text,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt"
    )
    token_ids = encoded.input_ids.to(device)

    # Run the text encoder without gradients; element [0] of its output
    # is the hidden-state embedding
    with torch.no_grad():
        embedding = text_encoder(token_ids)[0]
    return embedding
def load_concept_library(pipe):
    """
    Load textual inversion concepts from the SD concept library.

    Args:
        pipe: StableDiffusionPipeline

    Returns:
        tuple: (dict mapping concept token -> embedding tensor,
                list of concept token strings)
    """
    # Pull each textual-inversion concept into the pipeline
    concept_repos = [
        "sd-concepts-library/dreams",
        "sd-concepts-library/midjourney-style",
        "sd-concepts-library/moebius",
        "sd-concepts-library/style-of-marc-allante",
        "sd-concepts-library/wlop-style",
    ]
    for repo in concept_repos:
        pipe.load_textual_inversion(repo)

    # Placeholder tokens registered by the concepts above (same order)
    tokens = ['<meeg>', '<midjourney-style>', '<moebius>', '<Marc_Allante>', '<wlop-style>']
    token_ids = pipe.tokenizer.convert_tokens_to_ids(tokens)
    embeddings = pipe.text_encoder.get_input_embeddings().weight[token_ids].detach().cpu()

    # Map each token to its embedding row
    learned_embeds = {token: embeddings[i] for i, token in enumerate(tokens)}

    # Persist the embeddings for future runs
    torch.save(learned_embeds, "learned_embeds.bin")
    print(f"Saved embeddings for tokens: {', '.join(tokens)}")
    return learned_embeds, tokens