Spaces:
Runtime error
Runtime error
Upload 2 files
Browse files
- app.py +23 -20
- requirements.txt +6 -6
app.py
CHANGED
|
@@ -31,13 +31,16 @@ def load_model():
|
|
| 31 |
"runwayml/stable-diffusion-v1-5",
|
| 32 |
torch_dtype=torch.float16,
|
| 33 |
safety_checker=None
|
| 34 |
-
)
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
# Load concept library
|
| 40 |
-
concept_embeds, concept_tokens = load_concept_library(
|
| 41 |
|
| 42 |
# Define art style concepts
|
| 43 |
art_concepts = {
|
|
@@ -84,13 +87,13 @@ def generate_latents(prompt, seed, num_inference_steps, guidance_scale,
|
|
| 84 |
elif concept_style in art_concepts:
|
| 85 |
# Generate concept embedding from text description
|
| 86 |
concept_text = art_concepts[concept_style]
|
| 87 |
-
concept_embedding = get_concept_embedding(concept_text,
|
| 88 |
|
| 89 |
# Prep text
|
| 90 |
-
text_input =
|
| 91 |
truncation=True, return_tensors="pt")
|
| 92 |
with torch.inference_mode():
|
| 93 |
-
text_embeddings =
|
| 94 |
|
| 95 |
# Apply concept embedding influence if provided
|
| 96 |
if concept_embedding is not None and concept_strength > 0:
|
|
@@ -104,34 +107,34 @@ def generate_latents(prompt, seed, num_inference_steps, guidance_scale,
|
|
| 104 |
|
| 105 |
# Unconditional embedding for classifier-free guidance
|
| 106 |
max_length = text_input.input_ids.shape[-1]
|
| 107 |
-
uncond_input =
|
| 108 |
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
|
| 109 |
)
|
| 110 |
with torch.inference_mode():
|
| 111 |
-
uncond_embeddings =
|
| 112 |
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
| 113 |
|
| 114 |
# Prep Scheduler
|
| 115 |
-
set_timesteps(
|
| 116 |
|
| 117 |
# Prep latents
|
| 118 |
latents = torch.randn(
|
| 119 |
-
(batch_size,
|
| 120 |
generator=generator,
|
| 121 |
)
|
| 122 |
latents = latents.to(device)
|
| 123 |
-
latents = latents *
|
| 124 |
|
| 125 |
# Loop through diffusion process
|
| 126 |
-
for i, t in tqdm(enumerate(
|
| 127 |
# Expand latents for classifier-free guidance
|
| 128 |
latent_model_input = torch.cat([latents] * 2)
|
| 129 |
-
sigma =
|
| 130 |
-
latent_model_input =
|
| 131 |
|
| 132 |
# Predict the noise residual
|
| 133 |
with torch.inference_mode():
|
| 134 |
-
noise_pred =
|
| 135 |
|
| 136 |
# Perform classifier-free guidance
|
| 137 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
|
@@ -146,7 +149,7 @@ def generate_latents(prompt, seed, num_inference_steps, guidance_scale,
|
|
| 146 |
latents_x0 = latents - sigma * noise_pred
|
| 147 |
|
| 148 |
# Decode to image space
|
| 149 |
-
denoised_images =
|
| 150 |
|
| 151 |
# Calculate loss
|
| 152 |
loss = vignette_loss(denoised_images) * vignette_loss_scale
|
|
@@ -158,7 +161,7 @@ def generate_latents(prompt, seed, num_inference_steps, guidance_scale,
|
|
| 158 |
latents = latents.detach() - cond_grad * sigma**2
|
| 159 |
|
| 160 |
# Step with scheduler
|
| 161 |
-
latents =
|
| 162 |
|
| 163 |
return latents
|
| 164 |
|
|
@@ -201,7 +204,7 @@ def generate_image(prompt, seed=42, num_inference_steps=30, guidance_scale=7.5,
|
|
| 201 |
)
|
| 202 |
|
| 203 |
# Convert latents to image
|
| 204 |
-
images = latents_to_pil(latents,
|
| 205 |
|
| 206 |
return images[0]
|
| 207 |
|
|
@@ -241,7 +244,7 @@ def generate_style_grid(prompt, seed=42, num_inference_steps=30, guidance_scale=
|
|
| 241 |
)
|
| 242 |
|
| 243 |
# Convert latents to image
|
| 244 |
-
style_images = latents_to_pil(latents,
|
| 245 |
images.append(style_images[0])
|
| 246 |
labels.append(style)
|
| 247 |
|
|
|
|
| 31 |
"runwayml/stable-diffusion-v1-5",
|
| 32 |
torch_dtype=torch.float16,
|
| 33 |
safety_checker=None
|
| 34 |
+
)
|
| 35 |
|
| 36 |
+
@spaces.GPU
|
| 37 |
+
@gr.Cache()
|
| 38 |
+
def get_pipeline():
|
| 39 |
+
pipe = load_model()
|
| 40 |
+
return pipe.to("cuda")
|
| 41 |
|
| 42 |
# Load concept library
|
| 43 |
+
concept_embeds, concept_tokens = load_concept_library(get_pipeline())
|
| 44 |
|
| 45 |
# Define art style concepts
|
| 46 |
art_concepts = {
|
|
|
|
| 87 |
elif concept_style in art_concepts:
|
| 88 |
# Generate concept embedding from text description
|
| 89 |
concept_text = art_concepts[concept_style]
|
| 90 |
+
concept_embedding = get_concept_embedding(concept_text, get_pipeline().tokenizer, get_pipeline().text_encoder, device)
|
| 91 |
|
| 92 |
# Prep text
|
| 93 |
+
text_input = get_pipeline().tokenizer([prompt], padding="max_length", max_length=get_pipeline().tokenizer.model_max_length,
|
| 94 |
truncation=True, return_tensors="pt")
|
| 95 |
with torch.inference_mode():
|
| 96 |
+
text_embeddings = get_pipeline().text_encoder(text_input.input_ids.to(device))[0]
|
| 97 |
|
| 98 |
# Apply concept embedding influence if provided
|
| 99 |
if concept_embedding is not None and concept_strength > 0:
|
|
|
|
| 107 |
|
| 108 |
# Unconditional embedding for classifier-free guidance
|
| 109 |
max_length = text_input.input_ids.shape[-1]
|
| 110 |
+
uncond_input = get_pipeline().tokenizer(
|
| 111 |
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
|
| 112 |
)
|
| 113 |
with torch.inference_mode():
|
| 114 |
+
uncond_embeddings = get_pipeline().text_encoder(uncond_input.input_ids.to(device))[0]
|
| 115 |
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
| 116 |
|
| 117 |
# Prep Scheduler
|
| 118 |
+
set_timesteps(get_pipeline().scheduler, num_inference_steps)
|
| 119 |
|
| 120 |
# Prep latents
|
| 121 |
latents = torch.randn(
|
| 122 |
+
(batch_size, get_pipeline().unet.in_channels, height // 8, width // 8),
|
| 123 |
generator=generator,
|
| 124 |
)
|
| 125 |
latents = latents.to(device)
|
| 126 |
+
latents = latents * get_pipeline().scheduler.init_noise_sigma
|
| 127 |
|
| 128 |
# Loop through diffusion process
|
| 129 |
+
for i, t in tqdm(enumerate(get_pipeline().scheduler.timesteps), total=len(get_pipeline().scheduler.timesteps)):
|
| 130 |
# Expand latents for classifier-free guidance
|
| 131 |
latent_model_input = torch.cat([latents] * 2)
|
| 132 |
+
sigma = get_pipeline().scheduler.sigmas[i]
|
| 133 |
+
latent_model_input = get_pipeline().scheduler.scale_model_input(latent_model_input, t)
|
| 134 |
|
| 135 |
# Predict the noise residual
|
| 136 |
with torch.inference_mode():
|
| 137 |
+
noise_pred = get_pipeline().unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
|
| 138 |
|
| 139 |
# Perform classifier-free guidance
|
| 140 |
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
|
|
|
| 149 |
latents_x0 = latents - sigma * noise_pred
|
| 150 |
|
| 151 |
# Decode to image space
|
| 152 |
+
denoised_images = get_pipeline().vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
|
| 153 |
|
| 154 |
# Calculate loss
|
| 155 |
loss = vignette_loss(denoised_images) * vignette_loss_scale
|
|
|
|
| 161 |
latents = latents.detach() - cond_grad * sigma**2
|
| 162 |
|
| 163 |
# Step with scheduler
|
| 164 |
+
latents = get_pipeline().scheduler.step(noise_pred, t, latents).prev_sample
|
| 165 |
|
| 166 |
return latents
|
| 167 |
|
|
|
|
| 204 |
)
|
| 205 |
|
| 206 |
# Convert latents to image
|
| 207 |
+
images = latents_to_pil(latents, get_pipeline().vae)
|
| 208 |
|
| 209 |
return images[0]
|
| 210 |
|
|
|
|
| 244 |
)
|
| 245 |
|
| 246 |
# Convert latents to image
|
| 247 |
+
style_images = latents_to_pil(latents, get_pipeline().vae)
|
| 248 |
images.append(style_images[0])
|
| 249 |
labels.append(style)
|
| 250 |
|
requirements.txt
CHANGED
|
@@ -1,11 +1,11 @@
|
|
| 1 |
# Core dependencies with pinned versions for compatibility
|
| 2 |
-
torch
|
| 3 |
-
torchvision
|
| 4 |
-
diffusers
|
| 5 |
transformers>=4.38.0
|
| 6 |
-
accelerate
|
| 7 |
ftfy>=6.1.1
|
| 8 |
-
gradio
|
| 9 |
numpy>=1.22.0
|
| 10 |
Pillow>=10.0.0
|
| 11 |
tqdm>=4.64.0
|
|
@@ -13,7 +13,7 @@ huggingface-hub>=0.22.2
|
|
| 13 |
|
| 14 |
# HF Spaces specific
|
| 15 |
gradio-client>=0.15.0
|
| 16 |
-
spaces
|
| 17 |
|
| 18 |
# Optional dependencies for better performance
|
| 19 |
scipy>=1.9.0
|
|
|
|
| 1 |
# Core dependencies with pinned versions for compatibility
|
| 2 |
+
torch==2.2.1
|
| 3 |
+
torchvision==0.17.1
|
| 4 |
+
diffusers==0.28.0
|
| 5 |
transformers>=4.38.0
|
| 6 |
+
accelerate==0.28.0
|
| 7 |
ftfy>=6.1.1
|
| 8 |
+
gradio==4.25.0
|
| 9 |
numpy>=1.22.0
|
| 10 |
Pillow>=10.0.0
|
| 11 |
tqdm>=4.64.0
|
|
|
|
| 13 |
|
| 14 |
# HF Spaces specific
|
| 15 |
gradio-client>=0.15.0
|
| 16 |
+
spaces==0.32.0
|
| 17 |
|
| 18 |
# Optional dependencies for better performance
|
| 19 |
scipy>=1.9.0
|