2512
- media/result_grid.jpg +2 -2
- pipeline_sdxs.py +17 -3
- samples/unet_320x640_0.jpg +2 -2
- samples/unet_384x640_0.jpg +2 -2
- samples/unet_448x640_0.jpg +2 -2
- samples/unet_512x640_0.jpg +2 -2
- samples/unet_576x640_0.jpg +2 -2
- samples/unet_640x320_0.jpg +2 -2
- samples/unet_640x384_0.jpg +2 -2
- samples/unet_640x448_0.jpg +2 -2
- samples/unet_640x512_0.jpg +2 -2
- samples/unet_640x576_0.jpg +2 -2
- samples/unet_640x640_0.jpg +2 -2
- src/sdxs_create.ipynb +2 -2
- test.ipynb +2 -2
- train.py +19 -4
- unet/config.json +2 -2
- unet/diffusion_pytorch_model.safetensors +2 -2
- {unet_old → unet_very_old}/config.json +0 -0
- {unet_old → unet_very_old}/diffusion_pytorch_model.safetensors +0 -0
media/result_grid.jpg
CHANGED (Git LFS binary file)
pipeline_sdxs.py
CHANGED
@@ -78,8 +78,22 @@ class SdxsPipeline(DiffusionPipeline):
         sequence_lengths = attention_mask.sum(dim=1) - 1
         batch_size = hidden.shape[0]
         pooled = hidden[torch.arange(batch_size, device=hidden.device), sequence_lengths]
-
-        return hidden, attention_mask, pooled
+
+        # --- NEW LOGIC: CONCATENATION FOR CROSS-ATTENTION ---
+        # 1. Expand the pooled vector into a one-token sequence [B, 1, 1024]
+        pooled_expanded = pooled.unsqueeze(1)
+
+        # 2. Concatenate the token sequence and the pooled vector
+        # !!! CHANGE HERE !!!: the pooled vector goes FIRST
+        # Result: [B, 1 + L, 1024]; the pooled vector becomes the token at the START.
+        new_encoder_hidden_states = torch.cat([pooled_expanded, hidden], dim=1)
+
+        # 3. Update the attention mask for the new token
+        # Attention mask: [B, 1 + L]; prepend a 1 at the START.
+        # torch.ones((batch_size, 1), device=device) creates a [B, 1] mask of ones.
+        new_attention_mask = torch.cat([torch.ones((batch_size, 1), device=device), attention_mask], dim=1)
+
+        return new_encoder_hidden_states, new_attention_mask, pooled
 
         # Encode the positive and negative prompts
         # FIX: now return (None, None, None) to avoid UnboundLocalError
@@ -190,7 +204,7 @@ class SdxsPipeline(DiffusionPipeline):
             t,
             encoder_hidden_states=text_embeddings,
             encoder_attention_mask=unet_attention_mask,
-            added_cond_kwargs={'text_embeds': unet_pooled_embeddings}
+            #added_cond_kwargs={'text_embeds': unet_pooled_embeddings}
         ).sample
 
         if guidance_scale > 1:
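Pulled out of the diff, the new encoder logic reduces to one standalone transform: prepend the pooled vector as an extra cross-attention token and extend the mask to cover it. A minimal sketch for reference only; the helper name prepend_pooled_token and the dummy tensors are illustrative assumptions, not code from this commit:

import torch

def prepend_pooled_token(hidden, attention_mask, pooled):
    # hidden: [B, L, D] token states; attention_mask: [B, L]; pooled: [B, D]
    batch_size = hidden.shape[0]
    pooled_expanded = pooled.unsqueeze(1)                     # [B, 1, D]
    new_hidden = torch.cat([pooled_expanded, hidden], dim=1)  # [B, 1 + L, D]
    # Prepend a 1 so the new leading token is never masked out
    ones = torch.ones((batch_size, 1),
                      device=attention_mask.device,
                      dtype=attention_mask.dtype)
    new_mask = torch.cat([ones, attention_mask], dim=1)       # [B, 1 + L]
    return new_hidden, new_mask, pooled

# Quick shape check with dummy tensors (L=192 and D=1024, matching the diff)
hidden = torch.randn(2, 192, 1024)
mask = torch.ones(2, 192, dtype=torch.long)
pooled = torch.randn(2, 1024)
h, m, p = prepend_pooled_token(hidden, mask, pooled)
assert h.shape == (2, 193, 1024) and m.shape == (2, 193)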
samples/unet_320x640_0.jpg
CHANGED (Git LFS binary file)
samples/unet_384x640_0.jpg
CHANGED (Git LFS binary file)
samples/unet_448x640_0.jpg
CHANGED (Git LFS binary file)
samples/unet_512x640_0.jpg
CHANGED (Git LFS binary file)
samples/unet_576x640_0.jpg
CHANGED (Git LFS binary file)
samples/unet_640x320_0.jpg
CHANGED (Git LFS binary file)
samples/unet_640x384_0.jpg
CHANGED (Git LFS binary file)
samples/unet_640x448_0.jpg
CHANGED (Git LFS binary file)
samples/unet_640x512_0.jpg
CHANGED (Git LFS binary file)
samples/unet_640x576_0.jpg
CHANGED (Git LFS binary file)
samples/unet_640x640_0.jpg
CHANGED (Git LFS binary file)
src/sdxs_create.ipynb
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:e204cfa450a5fed8f3651be4a44f5ba8c86108bf4e51c9c61f6bee8d6a4be98f
+size 8034
test.ipynb
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:c2b00b404e66b2ee215280298d3e670b5c6d7ff7d70075052011f4f7719973f5
+size 4598308
train.py
CHANGED
@@ -32,7 +32,7 @@ batch_size = 256
 base_learning_rate = 3e-5
 min_learning_rate = 2.5e-5
 num_epochs = 10
-sample_interval_share =
+sample_interval_share = 40
 max_length = 192
 use_wandb = True
 use_comet_ml = False
@@ -170,7 +170,22 @@ def encode_texts(texts, max_length=max_length):
     batch_size = hidden.shape[0]
     pooled = hidden[torch.arange(batch_size, device=hidden.device), sequence_lengths]
 
-    return hidden, attention_mask, pooled
+    #return hidden, attention_mask, pooled
+    # --- NEW LOGIC: CONCATENATION FOR CROSS-ATTENTION ---
+    # 1. Expand the pooled vector into a one-token sequence [B, 1, 1024]
+    pooled_expanded = pooled.unsqueeze(1)
+
+    # 2. Concatenate the token sequence and the pooled vector
+    # !!! CHANGE HERE !!!: the pooled vector goes FIRST
+    # Result: [B, 1 + L, 1024]; the pooled vector becomes the token at the START.
+    new_encoder_hidden_states = torch.cat([pooled_expanded, hidden], dim=1)
+
+    # 3. Update the attention mask for the new token
+    # Attention mask: [B, 1 + L]; prepend a 1 at the START.
+    # torch.ones((batch_size, 1), device=device) creates a [B, 1] mask of ones.
+    new_attention_mask = torch.cat([torch.ones((batch_size, 1), device=device), attention_mask], dim=1)
+
+    return new_encoder_hidden_states, new_attention_mask, pooled
 
 shift_factor = getattr(vae.config, "shift_factor", 0.0)
 if shift_factor is None: shift_factor = 0.0
@@ -482,7 +497,7 @@ def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
             t,
             encoder_hidden_states=text_embeddings_batch,
             encoder_attention_mask=attention_mask_batch,
-            added_cond_kwargs={"text_embeds": pooled_batch}  # <--- POOLING HERE
+            #added_cond_kwargs={"text_embeds": pooled_batch}  # <--- POOLING HERE
         )
         flow = getattr(model_out, "sample", model_out)
 
@@ -606,7 +621,7 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
             timesteps,
             encoder_hidden_states=embeddings,
             encoder_attention_mask=attention_mask,
-            added_cond_kwargs={"text_embeds": pooled_embeddings}  # <--- pass the pooled embeddings
+            #added_cond_kwargs={"text_embeds": pooled_embeddings}  # <--- pass the pooled embeddings
         ).sample
 
         target = noise - latents
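With the pooled vector now riding along as token 0 of the cross-attention sequence, the added_cond_kwargs path is redundant, which is why it is commented out in every unet(...) call above. A quick shape check that a diffusers-style UNet2DConditionModel accepts the widened sequence; the deliberately tiny config below is an illustrative assumption, not the repo's actual unet/config.json:

import torch
from diffusers import UNet2DConditionModel

# Tiny UNet used only to verify tensor shapes, not the real model config.
unet = UNet2DConditionModel(
    sample_size=8,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=1024,
    layers_per_block=1,
)
latents = torch.randn(1, 4, 8, 8)
embeddings = torch.randn(1, 193, 1024)       # [B, 1 + L, 1024], pooled token first
mask = torch.ones(1, 193, dtype=torch.long)  # leading 1 covers the pooled token
out = unet(latents, 1,
           encoder_hidden_states=embeddings,
           encoder_attention_mask=mask).sample
assert out.shape == (1, 4, 8, 8)  # no added_cond_kwargs needed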
unet/config.json
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:8ab7222cdd538ff5178adc870a764d22ab24a185f0a7b63852ea728b3b09fcff
+size 1876
unet/diffusion_pytorch_model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:402b5747329ecd5573887b46e35e687b8f3d6c79ccba522c06e58082d0eace87
+size 6604736640
{unet_old → unet_very_old}/config.json
RENAMED
File without changes
{unet_old → unet_very_old}/diffusion_pytorch_model.safetensors
RENAMED
File without changes