Spaces:
Runtime error
Runtime error
Download model locally
Browse files
app.py
CHANGED
|
@@ -8,13 +8,12 @@ from PIL import Image, ImageDraw
|
|
| 8 |
from diffusers import DDIMScheduler
|
| 9 |
from diffusers.models.attention_processor import AttnProcessor2_0
|
| 10 |
from pipeline_stable_diffusion_xl_opt import StableDiffusionXLPipeline
|
| 11 |
-
from injection_utils import register_attention_editor_diffusers
|
| 12 |
from bounded_attention import BoundedAttention
|
| 13 |
from pytorch_lightning import seed_everything
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
MODEL_PATH = "stabilityai/stable-diffusion-xl-base-1.0"
|
| 18 |
RESOLUTION = 256
|
| 19 |
MIN_SIZE = 0.01
|
| 20 |
WHITE = 255
|
|
@@ -114,7 +113,6 @@ FOOTNOTE = """
|
|
| 114 |
|
| 115 |
|
| 116 |
def inference(
|
| 117 |
-
model,
|
| 118 |
boxes,
|
| 119 |
prompts,
|
| 120 |
subject_token_indices,
|
|
@@ -135,7 +133,10 @@ def inference(
|
|
| 135 |
raise gr.Error("cuda is not available")
|
| 136 |
|
| 137 |
device = torch.device("cuda")
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
seed_everything(seed)
|
| 141 |
start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
|
|
@@ -160,10 +161,7 @@ def inference(
|
|
| 160 |
)
|
| 161 |
|
| 162 |
register_attention_editor_diffusers(model, editor)
|
| 163 |
-
|
| 164 |
-
unregister_attention_editor_diffusers(model)
|
| 165 |
-
model.double().to(torch.device("cpu"))
|
| 166 |
-
return images
|
| 167 |
|
| 168 |
|
| 169 |
@spaces.GPU(duration=300)
|
|
@@ -198,7 +196,7 @@ def generate(
|
|
| 198 |
prompts = [prompt.strip(".").strip(",").strip()] * batch_size
|
| 199 |
|
| 200 |
images = inference(
|
| 201 |
-
|
| 202 |
final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale, classifier_free_guidance_scale,
|
| 203 |
num_iterations, loss_threshold, num_guidance_steps, seed)
|
| 204 |
|
|
@@ -255,10 +253,9 @@ def clear(batch_size):
|
|
| 255 |
def main():
|
| 256 |
nltk.download("averaged_perceptron_tagger")
|
| 257 |
|
| 258 |
-
|
| 259 |
-
model
|
| 260 |
-
model
|
| 261 |
-
model.enable_sequential_cpu_offload()
|
| 262 |
|
| 263 |
with gr.Blocks(
|
| 264 |
css=CSS,
|
|
@@ -330,7 +327,7 @@ def main():
|
|
| 330 |
)
|
| 331 |
|
| 332 |
generate_image_button.click(
|
| 333 |
-
fn=
|
| 334 |
inputs=[
|
| 335 |
prompt, subject_token_indices, filter_token_indices, num_tokens,
|
| 336 |
init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
|
|
|
|
| 8 |
from diffusers import DDIMScheduler
|
| 9 |
from diffusers.models.attention_processor import AttnProcessor2_0
|
| 10 |
from pipeline_stable_diffusion_xl_opt import StableDiffusionXLPipeline
|
| 11 |
+
from injection_utils import register_attention_editor_diffusers
|
| 12 |
from bounded_attention import BoundedAttention
|
| 13 |
from pytorch_lightning import seed_everything
|
| 14 |
|
| 15 |
+
REMOTE_MODEL_PATH = "stabilityai/stable-diffusion-xl-base-1.0"
|
| 16 |
+
LOCAL_MODEL_PATH = "./model"
|
|
|
|
| 17 |
RESOLUTION = 256
|
| 18 |
MIN_SIZE = 0.01
|
| 19 |
WHITE = 255
|
|
|
|
| 113 |
|
| 114 |
|
| 115 |
def inference(
|
|
|
|
| 116 |
boxes,
|
| 117 |
prompts,
|
| 118 |
subject_token_indices,
|
|
|
|
| 133 |
raise gr.Error("cuda is not available")
|
| 134 |
|
| 135 |
device = torch.device("cuda")
|
| 136 |
+
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
|
| 137 |
+
model = StableDiffusionXLPipeline.from_pretrained(LOCAL_MODEL_PATH, scheduler=scheduler, torch_dtype=torch.float16).to(device)
|
| 138 |
+
model.unet.set_attn_processor(AttnProcessor2_0())
|
| 139 |
+
model.enable_sequential_cpu_offload()
|
| 140 |
|
| 141 |
seed_everything(seed)
|
| 142 |
start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
|
|
|
|
| 161 |
)
|
| 162 |
|
| 163 |
register_attention_editor_diffusers(model, editor)
|
| 164 |
+
return model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
|
| 167 |
@spaces.GPU(duration=300)
|
|
|
|
| 196 |
prompts = [prompt.strip(".").strip(",").strip()] * batch_size
|
| 197 |
|
| 198 |
images = inference(
|
| 199 |
+
boxes, prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
|
| 200 |
final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale, classifier_free_guidance_scale,
|
| 201 |
num_iterations, loss_threshold, num_guidance_steps, seed)
|
| 202 |
|
|
|
|
| 253 |
def main():
|
| 254 |
nltk.download("averaged_perceptron_tagger")
|
| 255 |
|
| 256 |
+
model = StableDiffusionXLPipeline.from_pretrained(REMOTE_MODEL_PATH, scheduler=scheduler)
|
| 257 |
+
model.save_pretrained(LOCAL_MODEL_PATH)
|
| 258 |
+
del model
|
|
|
|
| 259 |
|
| 260 |
with gr.Blocks(
|
| 261 |
css=CSS,
|
|
|
|
| 327 |
)
|
| 328 |
|
| 329 |
generate_image_button.click(
|
| 330 |
+
fn=generate,
|
| 331 |
inputs=[
|
| 332 |
prompt, subject_token_indices, filter_token_indices, num_tokens,
|
| 333 |
init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
|