2512
Browse files- samples/sdxs_1b_384x768_0.jpg +2 -2
- samples/sdxs_1b_416x768_0.jpg +2 -2
- samples/sdxs_1b_448x768_0.jpg +2 -2
- samples/sdxs_1b_480x768_0.jpg +2 -2
- samples/sdxs_1b_512x768_0.jpg +2 -2
- samples/sdxs_1b_544x768_0.jpg +2 -2
- samples/sdxs_1b_576x768_0.jpg +2 -2
- samples/sdxs_1b_608x768_0.jpg +2 -2
- samples/sdxs_1b_640x768_0.jpg +2 -2
- samples/sdxs_1b_672x768_0.jpg +2 -2
- samples/sdxs_1b_704x768_0.jpg +2 -2
- samples/sdxs_1b_736x768_0.jpg +2 -2
- samples/sdxs_1b_768x384_0.jpg +2 -2
- samples/sdxs_1b_768x416_0.jpg +2 -2
- samples/sdxs_1b_768x448_0.jpg +2 -2
- samples/sdxs_1b_768x480_0.jpg +2 -2
- samples/sdxs_1b_768x512_0.jpg +2 -2
- samples/sdxs_1b_768x544_0.jpg +2 -2
- samples/sdxs_1b_768x576_0.jpg +2 -2
- samples/sdxs_1b_768x608_0.jpg +2 -2
- samples/sdxs_1b_768x640_0.jpg +2 -2
- samples/sdxs_1b_768x672_0.jpg +2 -2
- samples/sdxs_1b_768x704_0.jpg +2 -2
- samples/sdxs_1b_768x736_0.jpg +2 -2
- samples/sdxs_1b_768x768_0.jpg +2 -2
- sdxs_1b/diffusion_pytorch_model.safetensors +1 -1
- train.py +13 -9
samples/sdxs_1b_384x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_416x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_448x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_480x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_512x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_544x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_576x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_608x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_640x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_672x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_704x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_736x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_768x384_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_768x416_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_768x448_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_768x480_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_768x512_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_768x544_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_768x576_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_768x608_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_768x640_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_768x672_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_768x704_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_768x736_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/sdxs_1b_768x768_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
sdxs_1b/diffusion_pytorch_model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 4463672488
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3773b931314cc4356c69abfae833a9e948b3092b58d436a73d1cff8a36a74376
|
| 3 |
size 4463672488
|
train.py
CHANGED
|
@@ -34,6 +34,7 @@ base_learning_rate = 4e-5 #2.7e-5
|
|
| 34 |
min_learning_rate = 9e-6 #2.7e-5
|
| 35 |
num_epochs = 10
|
| 36 |
sample_interval_share = 20
|
|
|
|
| 37 |
max_length = 192
|
| 38 |
use_wandb = True
|
| 39 |
use_comet_ml = False
|
|
@@ -95,8 +96,8 @@ lora_alpha = 64
|
|
| 95 |
print("init")
|
| 96 |
|
| 97 |
loss_ratios = {
|
| 98 |
-
"mse": 0.
|
| 99 |
-
"mae": 0.
|
| 100 |
}
|
| 101 |
median_coeff_steps = 256
|
| 102 |
|
|
@@ -104,8 +105,9 @@ median_coeff_steps = 256
|
|
| 104 |
class MedianLossNormalizer:
|
| 105 |
def __init__(self, desired_ratios: dict, window_steps: int):
|
| 106 |
# нормируем доли на случай, если сумма != 1
|
| 107 |
-
s = sum(desired_ratios.values())
|
| 108 |
-
self.ratios = {k: (v / s) for k, v in desired_ratios.items()}
|
|
|
|
| 109 |
self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
|
| 110 |
self.window = window_steps
|
| 111 |
|
|
@@ -358,7 +360,7 @@ def collate_fn_simple(batch):
|
|
| 358 |
raw_texts = [item["text"] for item in batch]
|
| 359 |
texts = [
|
| 360 |
"" if t.lower().startswith("zero")
|
| 361 |
-
else "" if random.random() <
|
| 362 |
else t[1:].lstrip() if t.startswith(".")
|
| 363 |
else t.replace("The image shows ", "").replace("The image is ", "").replace("This image captures ","").strip()
|
| 364 |
for t in raw_texts
|
|
@@ -480,7 +482,7 @@ fixed_samples = get_fixed_samples_by_resolution(dataset)
|
|
| 480 |
# --- [UPDATED] Функция для негативного эмбеддинга (возвращает 3 элемента) ---
|
| 481 |
def get_negative_embedding(neg_prompt="", batch_size=1):
|
| 482 |
if not neg_prompt:
|
| 483 |
-
hidden_dim =
|
| 484 |
seq_len = max_length
|
| 485 |
empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
|
| 486 |
empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
|
|
@@ -567,6 +569,8 @@ def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
|
|
| 567 |
latents = scheduler.step(flow, t, latents).prev_sample
|
| 568 |
|
| 569 |
current_latents = latents
|
|
|
|
|
|
|
| 570 |
|
| 571 |
latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
|
| 572 |
decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
|
|
@@ -667,9 +671,9 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
|
|
| 667 |
# шум
|
| 668 |
noise = torch.randn_like(latents, dtype=latents.dtype)
|
| 669 |
# берём t из [0, 1]
|
| 670 |
-
|
| 671 |
-
u = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
|
| 672 |
-
t = torch.sigmoid(torch.randn_like(u))
|
| 673 |
|
| 674 |
# интерполяция между x0 и шумом
|
| 675 |
noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
|
|
|
|
| 34 |
min_learning_rate = 9e-6 #2.7e-5
|
| 35 |
num_epochs = 10
|
| 36 |
sample_interval_share = 20
|
| 37 |
+
cfg_dropout = 0.5
|
| 38 |
max_length = 192
|
| 39 |
use_wandb = True
|
| 40 |
use_comet_ml = False
|
|
|
|
| 96 |
print("init")
|
| 97 |
|
| 98 |
loss_ratios = {
|
| 99 |
+
"mse": 0.8,
|
| 100 |
+
"mae": 0.2,
|
| 101 |
}
|
| 102 |
median_coeff_steps = 256
|
| 103 |
|
|
|
|
| 105 |
class MedianLossNormalizer:
|
| 106 |
def __init__(self, desired_ratios: dict, window_steps: int):
|
| 107 |
# нормируем доли на случай, если сумма != 1
|
| 108 |
+
#s = sum(desired_ratios.values())
|
| 109 |
+
#self.ratios = {k: (v / s) for k, v in desired_ratios.items()}
|
| 110 |
+
self.ratios = {k: float(v) for k, v in desired_ratios.items()}
|
| 111 |
self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
|
| 112 |
self.window = window_steps
|
| 113 |
|
|
|
|
| 360 |
raw_texts = [item["text"] for item in batch]
|
| 361 |
texts = [
|
| 362 |
"" if t.lower().startswith("zero")
|
| 363 |
+
else "" if random.random() < cfg_dropout
|
| 364 |
else t[1:].lstrip() if t.startswith(".")
|
| 365 |
else t.replace("The image shows ", "").replace("The image is ", "").replace("This image captures ","").strip()
|
| 366 |
for t in raw_texts
|
|
|
|
| 482 |
# --- [UPDATED] Функция для негативного эмбеддинга (возвращает 3 элемента) ---
|
| 483 |
def get_negative_embedding(neg_prompt="", batch_size=1):
|
| 484 |
if not neg_prompt:
|
| 485 |
+
hidden_dim = 1024
|
| 486 |
seq_len = max_length
|
| 487 |
empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
|
| 488 |
empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
|
|
|
|
| 569 |
latents = scheduler.step(flow, t, latents).prev_sample
|
| 570 |
|
| 571 |
current_latents = latents
|
| 572 |
+
if step==0:
|
| 573 |
+
current_latents = sample_latents
|
| 574 |
|
| 575 |
latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
|
| 576 |
decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
|
|
|
|
| 671 |
# шум
|
| 672 |
noise = torch.randn_like(latents, dtype=latents.dtype)
|
| 673 |
# берём t из [0, 1]
|
| 674 |
+
t = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
|
| 675 |
+
#u = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
|
| 676 |
+
#t = torch.sigmoid(torch.randn_like(u))
|
| 677 |
|
| 678 |
# интерполяция между x0 и шумом
|
| 679 |
noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
|