Spaces:
Runtime error
Runtime error
Allow repeated inference
Browse files
app.py
CHANGED
|
@@ -10,6 +10,7 @@ from pipeline_stable_diffusion_xl_opt import StableDiffusionXLPipeline
|
|
| 10 |
from injection_utils import regiter_attention_editor_diffusers
|
| 11 |
from bounded_attention import BoundedAttention
|
| 12 |
from pytorch_lightning import seed_everything
|
|
|
|
| 13 |
|
| 14 |
from functools import partial
|
| 15 |
|
|
@@ -40,26 +41,45 @@ def inference(
|
|
| 40 |
):
|
| 41 |
seed_everything(seed)
|
| 42 |
start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
max_guidance_iter=num_guidance_steps
|
| 54 |
-
max_guidance_iter_per_step=num_iterations
|
| 55 |
-
start_step_size=init_step_size
|
| 56 |
-
|
| 57 |
-
loss_stopping_value=loss_threshold
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
return model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images
|
| 64 |
|
| 65 |
|
|
|
|
| 10 |
from injection_utils import regiter_attention_editor_diffusers
|
| 11 |
from bounded_attention import BoundedAttention
|
| 12 |
from pytorch_lightning import seed_everything
|
| 13 |
+
from torch_kmeans import KMeans
|
| 14 |
|
| 15 |
from functools import partial
|
| 16 |
|
|
|
|
| 41 |
):
|
| 42 |
seed_everything(seed)
|
| 43 |
start_code = torch.randn([len(prompts), 4, 128, 128], device=device)
|
| 44 |
+
eos_token_index = num_tokens + 1
|
| 45 |
+
|
| 46 |
+
if hasattr(model, 'editor'):
|
| 47 |
+
editor.boxes = boxes
|
| 48 |
+
editor.prompts = prompts
|
| 49 |
+
editor.subject_token_indices = subject_token_indices
|
| 50 |
+
editor.filter_token_indices = filter_token_indices
|
| 51 |
+
editor.eos_token_index = eos_token_index
|
| 52 |
+
editor.cross_loss_coef = cross_loss_scale
|
| 53 |
+
editor.self_loss_coef = self_loss_scale
|
| 54 |
+
editor.max_guidance_iter = num_guidance_steps
|
| 55 |
+
editor.max_guidance_iter_per_step = num_iterations
|
| 56 |
+
editor.start_step_size = init_step_size
|
| 57 |
+
self.step_size_coef = (final_step_size - init_step_size) / num_guidance_steps
|
| 58 |
+
editor.loss_stopping_value = loss_threshold
|
| 59 |
+
num_clusters = len(boxes) * num_clusters_per_subject
|
| 60 |
+
self.clustering = KMeans(n_clusters=num_clusters, num_init=100)
|
| 61 |
+
|
| 62 |
+
else:
|
| 63 |
+
editor = BoundedAttention(
|
| 64 |
+
boxes,
|
| 65 |
+
prompts,
|
| 66 |
+
subject_token_indices,
|
| 67 |
+
list(range(70, 82)),
|
| 68 |
+
list(range(70, 82)),
|
| 69 |
+
filter_token_indices=filter_token_indices,
|
| 70 |
+
eos_token_index=eos_token_index,
|
| 71 |
+
cross_loss_coef=cross_loss_scale,
|
| 72 |
+
self_loss_coef=self_loss_scale,
|
| 73 |
+
max_guidance_iter=num_guidance_steps,
|
| 74 |
+
max_guidance_iter_per_step=num_iterations,
|
| 75 |
+
start_step_size=init_step_size,
|
| 76 |
+
end_step_size=final_step_size,
|
| 77 |
+
loss_stopping_value=loss_threshold,
|
| 78 |
+
num_clusters_per_box=num_clusters_per_subject,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
regiter_attention_editor_diffusers(model, editor)
|
| 82 |
+
|
| 83 |
return model(prompts, latents=start_code, guidance_scale=classifier_free_guidance_scale).images
|
| 84 |
|
| 85 |
|