Spaces:
Runtime error
Runtime error
Damian Stewart
commited on
Commit
·
209d166
1
Parent(s):
0002379
provide empty negative prompt when training
Browse files- StableDiffuser.py +3 -1
- train.py +0 -2
StableDiffuser.py
CHANGED
|
@@ -114,9 +114,11 @@ class StableDiffuser(torch.nn.Module):
|
|
| 114 |
latents = noise * self.scheduler.init_noise_sigma
|
| 115 |
return latents
|
| 116 |
|
| 117 |
-
def get_text_embeddings(self, prompts, negative_prompts, n_imgs):
|
| 118 |
text_tokens = self.text_tokenize(prompts)
|
| 119 |
text_embeddings = self.text_encode(text_tokens)
|
|
|
|
|
|
|
| 120 |
unconditional_tokens = self.text_tokenize(negative_prompts)
|
| 121 |
unconditional_embeddings = self.text_encode(unconditional_tokens)
|
| 122 |
text_embeddings = torch.cat([unconditional_embeddings, text_embeddings]).repeat_interleave(n_imgs, dim=0)
|
|
|
|
| 114 |
latents = noise * self.scheduler.init_noise_sigma
|
| 115 |
return latents
|
| 116 |
|
| 117 |
+
def get_text_embeddings(self, prompts, negative_prompts=None, n_imgs=1):
|
| 118 |
text_tokens = self.text_tokenize(prompts)
|
| 119 |
text_embeddings = self.text_encode(text_tokens)
|
| 120 |
+
if negative_prompts is None:
|
| 121 |
+
negative_prompts = [""] * len(prompts)
|
| 122 |
unconditional_tokens = self.text_tokenize(negative_prompts)
|
| 123 |
unconditional_embeddings = self.text_encode(unconditional_tokens)
|
| 124 |
text_embeddings = torch.cat([unconditional_embeddings, text_embeddings]).repeat_interleave(n_imgs, dim=0)
|
train.py
CHANGED
|
@@ -36,11 +36,9 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
|
|
| 36 |
optimizer.zero_grad()
|
| 37 |
|
| 38 |
iteration = torch.randint(1, nsteps - 1, (1,)).item()
|
| 39 |
-
|
| 40 |
latents = diffuser.get_initial_latents(1, img_size, 1)
|
| 41 |
|
| 42 |
with finetuner:
|
| 43 |
-
|
| 44 |
latents_steps, _ = diffuser.diffusion(
|
| 45 |
latents,
|
| 46 |
positive_text_embeddings,
|
|
|
|
| 36 |
optimizer.zero_grad()
|
| 37 |
|
| 38 |
iteration = torch.randint(1, nsteps - 1, (1,)).item()
|
|
|
|
| 39 |
latents = diffuser.get_initial_latents(1, img_size, 1)
|
| 40 |
|
| 41 |
with finetuner:
|
|
|
|
| 42 |
latents_steps, _ = diffuser.diffusion(
|
| 43 |
latents,
|
| 44 |
positive_text_embeddings,
|