Spaces:
Runtime error
Runtime error
File size: 9,607 Bytes
74fa5e8 483825e 74fa5e8 483825e 74fa5e8 483825e 74fa5e8 483825e 74fa5e8 483825e 74fa5e8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 |
#!/usr/bin/env python3
"""
Utility functions for the application
Author: Shilpaj Bhalerao
Date: Feb 26, 2025
"""
import torch
import gc
import os
from PIL import Image, ImageDraw, ImageFont
from diffusers import StableDiffusionPipeline
from transformers import CLIPTokenizer, CLIPTextModel
# Disable the hf_transfer download backend to avoid download issues.
# NOTE(review): huggingface_hub may read this variable when it is first imported;
# diffusers/transformers are imported above, so this assignment might come too
# late to take effect -- confirm, and consider moving it above those imports.
os.environ.update({"HF_HUB_ENABLE_HF_TRANSFER": "0"})
def _resolve_device(device):
    """
    Return the device string to use, falling back when the requested backend is missing.

    'cuda' (or 'cuda:N') without CUDA falls back to 'mps' when available, else 'cpu'.
    'mps' without MPS falls back to 'cpu'. Any other value is returned unchanged.
    When the result is 'mps', PYTORCH_ENABLE_MPS_FALLBACK is set to "1" so that
    operators not implemented on MPS can fall back to the CPU.
    """
    device = str(device)
    # getattr guard: very old PyTorch builds have no torch.backends.mps at all.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_available = mps_backend is not None and mps_backend.is_available()

    if device.startswith("cuda") and not torch.cuda.is_available():
        device = "mps" if mps_available else "cpu"
    elif device == "mps" and not mps_available:
        device = "cpu"

    if device == "mps":
        # NOTE(review): PyTorch may only honour this if it is set before MPS is
        # first used -- confirm it takes effect when set here.
        os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
    return device


def load_models(device="cuda", pipeline_id="stable-diffusion-v1-5/stable-diffusion-v1-5"):
    """
    Load the necessary models for stable diffusion

    The VAE and UNet come from CompVis/stable-diffusion-v1-4, the tokenizer and
    text encoder from openai/clip-vit-large-patch14, and a separate full
    StableDiffusionPipeline (used for textual-inversion concept loading) from
    ``pipeline_id``.

    Args:
        device (str): Device to load models on ('cuda', 'mps', or 'cpu'). Falls back
            to MPS and then CPU when the requested backend is not available.
        pipeline_id (str): Hugging Face repo id of the full pipeline. Defaults to the
            community-maintained SD v1.5 mirror. NOTE(review): the previously used
            "runwayml/stable-diffusion-v1-5" repo is believed to have been removed
            from the Hub -- confirm before switching back.
    Returns:
        tuple: (vae, tokenizer, text_encoder, unet, scheduler, pipe)
    """
    from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel

    device = _resolve_device(device)
    print(f"Loading models on {device}...")

    # Autoencoder used to decode latents into image space (and encode images to latents)
    vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_safetensors=False)
    # Tokenizer and text encoder used to tokenize and encode the prompt text
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    # UNet model that generates the latents
    unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", use_safetensors=False)
    # Noise scheduler
    scheduler = LMSDiscreteScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
    )

    # Full pipeline for concept loading. Use half precision only when the models
    # actually run on CUDA. Deciding from torch.cuda.is_available() alone would put an
    # fp16 pipeline on the CPU whenever device="cpu" is requested on a CUDA machine.
    pipe_dtype = torch.float16 if device.startswith("cuda") else torch.float32
    pipe = StableDiffusionPipeline.from_pretrained(
        pipeline_id,
        torch_dtype=pipe_dtype,
        use_safetensors=False,
    )

    # Move models to the resolved device
    vae = vae.to(device)
    text_encoder = text_encoder.to(device)
    unet = unet.to(device)
    pipe = pipe.to(device)
    return vae, tokenizer, text_encoder, unet, scheduler, pipe
def clear_gpu_memory():
    """
    Clear cached accelerator memory.

    First runs the Python garbage collector, so tensors kept alive only by
    reference cycles are released. Then it empties the CUDA caching allocator
    (when CUDA is available) and the MPS allocator cache (when MPS is available
    and this PyTorch build exposes ``torch.mps.empty_cache``). Safe to call on
    CPU-only machines.
    """
    # Collect garbage first so the cache flush below can return the freed blocks.
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    mps_backend = getattr(torch.backends, "mps", None)
    mps_module = getattr(torch, "mps", None)
    if (
        mps_backend is not None
        and mps_backend.is_available()
        and hasattr(mps_module, "empty_cache")
    ):
        mps_module.empty_cache()
def set_timesteps(scheduler, num_inference_steps):
    """
    Configure ``scheduler`` for ``num_inference_steps`` steps and store its timesteps as float32.

    The float32 cast is a compatibility fix for Apple's MPS backend. The scheduler
    is modified in place; nothing is returned.
    """
    scheduler.set_timesteps(num_inference_steps)
    # MPS compatibility: keep the timesteps tensor in float32.
    scheduler.timesteps = scheduler.timesteps.float()
def pil_to_latent(input_im, vae, device):
    """
    Convert the image to latents

    Args:
        input_im: Input PIL image. It is converted to RGB first, so grayscale and
            RGBA inputs still give the VAE the 3 channels it expects.
        vae: VAE model exposing ``encode(x).latent_dist`` (diffusers AutoencoderKL)
        device: Device to run on
    Returns:
        Latents sampled from the VAE encoder's posterior, multiplied by the VAE's
        latent scaling factor (``vae.config.scaling_factor`` when present, otherwise
        0.18215). A single image yields a batch of one, e.g. (1, 4, 64, 64) for a
        512x512 image with an SD v1 VAE.
    """
    from torchvision import transforms as tfms

    # Scaling factor from the VAE config when available (0.18215 for SD v1 VAEs).
    scale = getattr(getattr(vae, "config", None), "scaling_factor", 0.18215)
    # Match the VAE parameter dtype so that an fp16 VAE does not raise a dtype mismatch.
    vae_dtype = getattr(vae, "dtype", torch.float32)

    # Single image -> single latent in a batch
    pixels = tfms.ToTensor()(input_im.convert("RGB")).unsqueeze(0)
    # ToTensor yields values in [0, 1]; the VAE expects inputs in [-1, 1].
    pixels = (pixels * 2 - 1).to(device=device, dtype=vae_dtype)
    with torch.no_grad():
        latent = vae.encode(pixels)
    return scale * latent.latent_dist.sample()
def latents_to_pil(latents, vae):
    """
    Convert the latents to images

    Latents are divided by the VAE's latent scaling factor
    (``vae.config.scaling_factor`` when present, otherwise 0.18215) and then
    decoded without gradient tracking. The decoded values are mapped from
    [-1, 1] to [0, 255] uint8.

    Args:
        latents: Batch of latents, batch dimension first
        vae: VAE model exposing ``decode(z).sample`` (diffusers AutoencoderKL)
    Returns:
        list: One PIL image per latent in the batch
    """
    scale = getattr(getattr(vae, "config", None), "scaling_factor", 0.18215)
    # Match the VAE parameter dtype so that an fp16 VAE accepts fp32 latents.
    vae_dtype = getattr(vae, "dtype", latents.dtype)

    with torch.no_grad():
        scaled = ((1 / scale) * latents).to(vae_dtype)
        decoded = vae.decode(scaled).sample

    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    # Cast to float32 before .numpy(): numpy has no bfloat16 dtype.
    arrays = decoded.cpu().permute(0, 2, 3, 1).float().numpy()
    arrays = (arrays * 255).round().astype("uint8")
    return [Image.fromarray(arr) for arr in arrays]
def image_grid(imgs, rows, cols, labels=None):
    """
    Create a grid of images with optional labels.

    Images are placed row-major (left to right, then top to bottom), and each
    cell takes the size of ``imgs[0]``. When labels are given, every row gets
    its own 30-pixel strip directly below it. Each label is drawn inside that
    strip under its image, so labels never cover image content or the next
    row's images.

    Args:
        imgs (list): List of PIL images to be arranged in a grid
        rows (int): Number of rows in the grid
        cols (int): Number of columns in the grid
        labels (list, optional): List of label strings for each image
    Returns:
        PIL.Image: An RGB image of size (cols*w, rows*(h + 30)) when labels are
        given, or (cols*w, rows*h) otherwise, where (w, h) is ``imgs[0].size``
    """
    assert len(imgs) == rows*cols, f"Number of images ({len(imgs)}) must equal rows*cols ({rows*cols})"
    if labels:
        assert len(labels) == len(imgs), "Number of labels must match number of images"

    w, h = imgs[0].size
    # Padding below every row for its labels (only if labels exist)
    label_height = 30 if labels else 0
    cell_h = h + label_height  # height of one grid row: image plus its label strip
    grid = Image.new('RGB', size=(cols * w, rows * cell_h))

    # Paste images
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * cell_h))

    if not labels:
        return grid

    draw = ImageDraw.Draw(grid)
    # Try to use a standard font, fall back to default if not available
    try:
        font = ImageFont.truetype("arial.ttf", 14)
    except OSError:
        font = ImageFont.load_default()

    for i, label in enumerate(labels):
        row, col = divmod(i, cols)
        x = col * w + 10
        # Place the text roughly centred (for a ~14 px font) in the strip below the image.
        y = row * cell_h + h + (label_height - 14) // 2
        # Black text with a 1-px white outline (drawn at four diagonal offsets) for visibility
        for dx, dy in ((1, 1), (-1, -1), (1, -1), (-1, 1)):
            draw.text((x + dx, y + dy), label, fill=(255, 255, 255), font=font)
        draw.text((x, y), label, fill=(0, 0, 0), font=font)
    return grid
def vignette_loss(images, vignette_strength=3.0, color_shift=(1.0, 0.5, 0.0)):
    """
    Creates a strong vignette effect (dark corners) and color shift.

    A target image is built from ``images`` in two steps:

    1. Every channel is multiplied by the radial mask ``exp(-vignette_strength * r)``.
    2. ``(color_shift - images) * (1 - r)**2`` is added.

    Here ``r`` is the distance from the image centre, normalised so the centre is
    0 and the corners are 1. The loss is the mean squared difference between
    ``images`` and that target.

    Args:
        images: Batch of images from VAE decoder (range 0-1), shape (B, C, H, W)
        vignette_strength: How strong the darkening effect is (higher = more dramatic)
        color_shift: RGB color to shift the center toward [r, g, b]. Any sequence
            of length C, or of length 1 to use the same value for every channel.
    Returns:
        torch.Tensor: Scalar loss value
    """
    _, _, height, width = images.shape
    device = images.device

    # Coordinate grid centred at 0 with range [-1, 1] on both axes -> (H, W)
    y = torch.linspace(-1, 1, height, device=device).view(-1, 1).expand(height, width)
    x = torch.linspace(-1, 1, width, device=device).view(1, -1).expand(height, width)
    # Distance from the centre, divided by sqrt(2) so the corners land exactly on 1.
    radius = torch.sqrt(x.pow(2) + y.pow(2)) / (2 ** 0.5)

    # Vignette mask: 1 at the centre, decaying towards the edges -> (H, W)
    vignette = torch.exp(-vignette_strength * radius).to(images.dtype)
    # Colour-shift weight: 1 at the centre, 0 at the corners; squared for a sharper falloff
    center_mask = (1.0 - radius).pow(2.0).view(1, 1, height, width).to(images.dtype)
    color = torch.as_tensor(color_shift, dtype=images.dtype, device=device).view(1, -1, 1, 1)

    # Target = vignetted image plus a pull towards `color` near the centre.
    # `target` is deliberately not detached, so gradients flow through both terms below.
    target = images * vignette + (color - images) * center_mask
    return (images - target).pow(2).mean()
def get_concept_embedding(concept_text, tokenizer, text_encoder, device):
    """
    Encode a text concept with CLIP and return the encoder's token embeddings.

    The text is padded and truncated to ``tokenizer.model_max_length`` tokens,
    moved to ``device``, and run through ``text_encoder`` with gradients disabled.
    The first element of the encoder output is returned.

    Args:
        concept_text (str): Text description of the concept (e.g., "sketch painting")
        tokenizer: CLIP tokenizer
        text_encoder: CLIP text encoder
        device: Device to run on
    Returns:
        torch.Tensor: CLIP embedding for the concept (first element of the encoder output)
    """
    encoded = tokenizer(
        concept_text,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoded.input_ids.to(device)

    # Inference only: no autograd graph is needed for the embedding.
    with torch.no_grad():
        outputs = text_encoder(input_ids)
    return outputs[0]
def load_concept_library(pipe):
    """
    Load textual inversion concepts from the SD concept library

    Each concept repo is loaded into ``pipe`` with ``load_textual_inversion``.
    Repos whose placeholder token is already in the tokenizer vocabulary are
    skipped, so calling this again on the same pipeline does not re-load them.
    The learned embedding rows are then read from the text encoder's
    input-embedding matrix, saved to ``learned_embeds.bin`` in the current
    working directory, and returned.

    Args:
        pipe: StableDiffusionPipeline
    Returns:
        tuple: (learned_embeds, tokens). ``learned_embeds`` is a dict mapping each
        placeholder token to its embedding (a detached CPU tensor); ``tokens`` is
        the list of placeholder tokens in load order.
    Raises:
        ValueError: If a placeholder token is still missing from the tokenizer
            vocabulary after loading (e.g. a repo registers a different token than
            the one listed here). Without this check its embedding would silently
            be the unknown-token row.
    """
    # (concept repo, placeholder token it registers), kept together so they cannot drift apart.
    concepts = [
        ("sd-concepts-library/dreams", "<meeg>"),
        ("sd-concepts-library/midjourney-style", "<midjourney-style>"),
        ("sd-concepts-library/moebius", "<moebius>"),
        ("sd-concepts-library/style-of-marc-allante", "<Marc_Allante>"),
        ("sd-concepts-library/wlop-style", "<wlop-style>"),
    ]

    # Load textual inversion embeddings; skip tokens that are already registered,
    # since loading an existing token again would fail.
    for repo_id, token in concepts:
        if token not in pipe.tokenizer.get_vocab():
            pipe.load_textual_inversion(repo_id)

    tokens = [token for _, token in concepts]
    vocab = pipe.tokenizer.get_vocab()
    missing = [token for token in tokens if token not in vocab]
    if missing:
        raise ValueError(
            f"Textual inversion tokens not found in tokenizer vocabulary: {', '.join(missing)}"
        )

    # Extract the embeddings from the pipeline
    token_ids = pipe.tokenizer.convert_tokens_to_ids(tokens)
    embeddings = pipe.text_encoder.get_input_embeddings().weight[token_ids].detach().cpu()
    learned_embeds = dict(zip(tokens, embeddings))

    # Save the embeddings for future use
    torch.save(learned_embeds, "learned_embeds.bin")
    print(f"Saved embeddings for tokens: {', '.join(tokens)}")
    return learned_embeds, tokens