Text-to-Image
Diffusers
Safetensors
recoilme commited on
Commit
ff63018
·
1 Parent(s): dd3b192
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.jpg filter=lfs diff=lfs merge=lfs -text
37
+ *.png filter=lfs diff=lfs merge=lfs -text
38
+ *.ipynb filter=lfs diff=lfs merge=lfs -text
39
+ *.json filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Jupyter Notebook
2
+ __pycache__/
3
+ *.pyc
4
+ .ipynb_checkpoints/
5
+ *.ipynb_checkpoints/*
6
+ .ipynb_checkpoints/*
7
+ src/samples
8
+ # cache
9
+ cache
10
+ datasets
11
+ test
12
+ wandb
13
+ nohup.out
model_index.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7262c65619c10d3d77607525c101c525b6b6a9a3af89f503f35b42e91dd88e2
3
+ size 417
pipeline_sdxs.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ from typing import List, Union, Optional, Tuple
5
+ from dataclasses import dataclass
6
+
7
+ from diffusers import DiffusionPipeline
8
+ from diffusers.utils import BaseOutput
9
+ from tqdm import tqdm
10
+
11
@dataclass
class SdxsPipelineOutput(BaseOutput):
    """Output of SdxsPipeline.__call__.

    images: the generated images — a list of PIL images when
    output_type="pil", otherwise a numpy array (or the raw latents tensor
    when output_type="latent").
    """
    images: Union[List[Image.Image], np.ndarray]
14
+
15
class SdxsPipeline(DiffusionPipeline):
    """Minimal text-to-image pipeline: a CLIP-style text encoder + UNet + VAE,
    sampled with a flow-matching (Euler) scheduler and optional
    classifier-free guidance (CFG)."""

    def __init__(self, vae, text_encoder, tokenizer, unet, scheduler):
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler
        )
        # Spatial downscale of the VAE: one factor of 2 per down block.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def encode_prompt(self, prompt, negative_prompt, device, dtype):
        """Encode positive and negative prompts exactly like the training
        script's encode_texts / get_negative_embedding.

        Returns:
            (embeddings, mask): both concatenated as [negative, positive]
            along the batch axis, ready for CFG.
        """
        def get_single_encode(texts, is_negative=False):
            if texts is None or texts == "":
                # Training-time negative embedding: zeros for the hidden
                # states and ones for the attention mask.
                hidden_dim = self.text_encoder.config.hidden_size
                shape = (1, self.text_encoder.config.max_position_embeddings, hidden_dim)
                emb = torch.zeros(shape, dtype=dtype, device=device)
                mask = torch.ones(
                    (1, self.text_encoder.config.max_position_embeddings),
                    dtype=torch.int64, device=device
                )
                return emb, mask

            if isinstance(texts, str):
                texts = [texts]

            with torch.no_grad():
                toks = self.tokenizer(
                    texts,
                    padding="max_length",
                    max_length=self.text_encoder.config.max_position_embeddings,
                    truncation=True,
                    return_tensors="pt"
                ).to(device)

                outputs = self.text_encoder(
                    input_ids=toks.input_ids,
                    attention_mask=toks.attention_mask,
                    output_hidden_states=True
                )

                # Penultimate transformer block (-2), the common choice for
                # SD-style conditioning; -1 would be the final block.
                layer_index = -2
                prompt_embeds = outputs.hidden_states[layer_index]

                # hidden_states bypass CLIP's final LayerNorm, so apply it
                # manually to match last_hidden_state statistics.
                final_layer_norm = self.text_encoder.text_model.final_layer_norm
                prompt_embeds = final_layer_norm(prompt_embeds)

            return prompt_embeds, toks.attention_mask

        pos_embeds, pos_mask = get_single_encode(prompt)
        neg_embeds, neg_mask = get_single_encode(negative_prompt, is_negative=True)

        # Broadcast the (usually single) negative embedding over the batch.
        batch_size = pos_embeds.shape[0]
        if neg_embeds.shape[0] != batch_size:
            neg_embeds = neg_embeds.repeat(batch_size, 1, 1)
            neg_mask = neg_mask.repeat(batch_size, 1)

        # CFG layout: [negative, positive].
        text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
        final_mask = torch.cat([neg_mask, pos_mask], dim=0)

        return text_embeddings.to(dtype=dtype), final_mask.to(dtype=torch.int64)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 40,  # matches n_diffusion_steps in training
        guidance_scale: float = 4.0,    # matches training
        generator: Optional[torch.Generator] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        **kwargs,
    ):
        """Generate images from text prompts.

        Returns SdxsPipelineOutput (or a bare image list/array when
        return_dict=False); raw latents when output_type="latent".
        """
        device = self.device
        # Make sure the VAE lives on the sampling device.
        self.vae.to(device)

        # VAE (de)normalization; training used scaling=1.0, shift=0.0.
        vae_scaling_factor = getattr(self.vae.config, "scaling_factor", 1.0)
        vae_shift_factor = getattr(self.vae.config, "shift_factor", 0.0)

        do_cfg = guidance_scale > 1

        # 1. Encode prompt (returns [negative, positive] stacked for CFG).
        dtype = self.text_encoder.dtype
        text_embeddings, attention_mask = self.encode_prompt(
            prompt, negative_prompt, device, dtype
        )
        if not do_cfg:
            # BUGFIX: without CFG the UNet receives a batch of plain latents,
            # so drop the negative half of the conditioning to keep batch
            # sizes aligned (previously this path crashed with a UNet batch
            # mismatch: 2B embeddings vs B latents).
            half = text_embeddings.shape[0] // 2
            text_embeddings = text_embeddings[half:]
            attention_mask = attention_mask[half:]

        # 2. Prepare latents
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        latent_channels = self.unet.config.in_channels

        latents = torch.randn(
            (batch_size, latent_channels, height // self.vae_scale_factor, width // self.vae_scale_factor),
            generator=generator,
            device=device,
            dtype=dtype
        )

        # 3. Flow-matching scheduler setup
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 4. Denoising loop
        for t in tqdm(timesteps, desc="Sampling"):
            # Duplicate latents only when CFG is active.
            latent_model_input = torch.cat([latents] * 2) if do_cfg else latents

            # Flow matching usually needs no input scaling; kept for
            # compatibility with the scheduler interface.
            if hasattr(self.scheduler, "scale_model_input"):
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            model_out = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings,
                encoder_attention_mask=attention_mask,
                return_dict=False,
            )[0]

            if do_cfg:
                flow_uncond, flow_cond = model_out.chunk(2)
                # Same formula as training:
                # flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)
                model_out = flow_uncond + guidance_scale * (flow_cond - flow_uncond)

            # Euler step of the flow-matching scheduler.
            latents = self.scheduler.step(model_out, t, latents, return_dict=False)[0]

        # 5. Decode
        if output_type == "latent":
            return SdxsPipelineOutput(images=latents)

        # Decode exactly as in training:
        # latent_for_vae = latents * scaling_factor + shift_factor
        latents = latents * vae_scaling_factor + vae_shift_factor

        image = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]

        # Post-processing: [-1, 1] -> [0, 1] -> HWC numpy.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = (image * 255).round().astype("uint8")
            image = [Image.fromarray(img) for img in image]

        if not return_dict:
            return image

        return SdxsPipelineOutput(images=image)
r.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ diffusers>=0.32.2
2
+ accelerate>=1.5.2
3
+ datasets>=3.5.0
4
+ matplotlib>=3.10.1
5
+ wandb>=0.19.8
6
+ huggingface_hub>=0.29.3
7
+ bitsandbytes>=0.45.4
8
+ transformers
9
+ hf_transfer
10
+ comet_ml
samples/unet_192x320_0.jpg ADDED

Git LFS Details

  • SHA256: 10657f980e3f583e8dfbe88961a3dbee15d82955bead5062880e5c3cc929edd4
  • Pointer size: 131 Bytes
  • Size of remote file: 133 kB
samples/unet_256x320_0.jpg ADDED

Git LFS Details

  • SHA256: 0909229838083e96ca220559d1fd2509f416202bf2bcc55544052c10eca09400
  • Pointer size: 131 Bytes
  • Size of remote file: 143 kB
samples/unet_320x192_0.jpg ADDED

Git LFS Details

  • SHA256: da185847c499bada58b3edf0e8267545907cc535cfeec70840f14899815bfbd7
  • Pointer size: 130 Bytes
  • Size of remote file: 78 kB
samples/unet_320x256_0.jpg ADDED

Git LFS Details

  • SHA256: dcb2295f69e1ace89e7a8600a3112168d1243971b00b96a25915e0d31f79ebfe
  • Pointer size: 131 Bytes
  • Size of remote file: 100 kB
samples/unet_320x320_0.jpg ADDED

Git LFS Details

  • SHA256: d66cd59072615413dd799fa4493cc15884f38789bb7cbce4f96e85a0117621b6
  • Pointer size: 131 Bytes
  • Size of remote file: 225 kB
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52c357e163f8971a39a9df21baa338f94e189c8723b59d7507944c55f6f4a781
3
+ size 486
test.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c161ff625c1577dc117863f9bcdbfaef3459e166743e6e43936500896a07dff2
3
+ size 753513
text_encoder/config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7c06cbeeddf5d93f5c7abc16b17a251b2a9ba6a6f08d7114fd8a269efeab1975
3
+ size 563
text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:17471843ad4564851df2001ef03466891ff356fafa4230f0578b605d2f6c0f6c
3
+ size 492790480
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2cdb3b8331a60c92fc1e55a13e9fd61fd2293c5a51275fdcccd62b780052530e
3
+ size 588
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ae623b1013846a4edb4e2206a09ea3f7a4a92b3215f6f840f3706ccfedcef2d
3
+ size 737
tokenizer/vocab.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e089ad92ba36837a0d31433e555c8f45fe601ab5c221d4f607ded32d9f7a4349
3
+ size 1059962
train.py ADDED
@@ -0,0 +1,897 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #from comet_ml import Experiment
2
+ import os
3
+ import math
4
+ import torch
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ from torch.utils.data import DataLoader, Sampler
8
+ from torch.utils.data.distributed import DistributedSampler
9
+ from torch.optim.lr_scheduler import LambdaLR
10
+ from collections import defaultdict
11
+ from diffusers import UNet2DConditionModel, AutoencoderKL,AutoencoderKLFlux2,AsymmetricAutoencoderKL,FlowMatchEulerDiscreteScheduler
12
+ from accelerate import Accelerator
13
+ from datasets import load_from_disk
14
+ from tqdm import tqdm
15
+ from PIL import Image, ImageOps
16
+ import wandb
17
+ import random
18
+ import gc
19
+ from accelerate.state import DistributedType
20
+ from torch.distributed import broadcast_object_list
21
+ from torch.utils.checkpoint import checkpoint
22
+ from diffusers.models.attention_processor import AttnProcessor2_0
23
+ from datetime import datetime
24
+ import bitsandbytes as bnb
25
+ import torch.nn.functional as F
26
+ from collections import deque
27
+ from transformers import AutoTokenizer, AutoModel
28
+
29
# --------------------------- Hyperparameters ---------------------------
ds_path = "/workspace/sdxs-08b/datasets/d123_simplevae_sd15"
project = "unet"
batch_size = 72
base_learning_rate = 4e-5  # earlier value: 2.7e-5
min_learning_rate = 2e-5  # earlier value: 2.7e-5
num_epochs = 100
sample_interval_share = 5
cfg_dropout = 0.15  # probability of dropping a caption to "" (CFG training)
max_length = 248  # tokenizer max length (was 192)
use_wandb = False
use_comet_ml = True
save_model = True
use_decay = True  # warmup + cosine LR decay (see lr_schedule)
fbp = False  # per-parameter fused optimizers via grad hooks
optimizer_type = "adam8bit"
torch_compile = False
unet_gradient = True  # enable gradient checkpointing on the UNet
loss_normalize = False
fixed_seed = False
shuffle = True
# SECURITY NOTE(review): a live Comet API key is hardcoded and committed to
# the repository — rotate this key and load it from an environment variable.
comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r"
comet_ml_workspace = "recoilme"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.enable_mem_efficient_sdp(False)
dtype = torch.float32
save_barrier = 1.01
warmup_percent = 0.01  # fraction of total steps used for LR warmup
percentile_clipping = 95  # bitsandbytes percentile grad clipping (was 96/97)
betta2 = 0.995  # Adam beta2
eps = 1e-7
clip_grad_norm = 1.0
limit = 0  # >0 truncates the dataset for quick smoke runs
checkpoints_folder = ""
mixed_precision = "no"
gradient_accumulation_steps = 1

accelerator = Accelerator(
    mixed_precision=mixed_precision,
    gradient_accumulation_steps=gradient_accumulation_steps
)
device = accelerator.device

# Diffusion sampling parameters (used for periodic preview generation)
n_diffusion_steps = 40
samples_to_generate = 12
guidance_scale = 4

# Output folder for generated preview samples
generated_folder = "samples"
os.makedirs(generated_folder, exist_ok=True)

# Seed setup: date-derived seed, only applied when fixed_seed is True
current_date = datetime.now()
seed = int(current_date.strftime("%Y%m%d"))
if fixed_seed:
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# --------------------------- LoRA parameters ---------------------------
lora_name = ""  # empty string disables the LoRA path
lora_rank = 32
lora_alpha = 64

print("init")

# Desired relative contribution of each loss term (see MedianLossNormalizer)
loss_ratios = {
    "mse": 1.25,
    "mae": 0.25,
}
median_coeff_steps = 256  # sliding-window length for loss medians
104
+
105
# Median-based loss normalization: turns desired loss ratios into coefficients.
class MedianLossNormalizer:
    """Keeps a sliding window of |loss| magnitudes per loss name and rescales
    each loss so its typical (median) magnitude matches the configured ratio."""

    def __init__(self, desired_ratios: dict, window_steps: int):
        # Target weight per loss name, kept as-is (not renormalized to sum to 1).
        self.ratios = {name: float(weight) for name, weight in desired_ratios.items()}
        # One bounded history buffer per loss; old entries fall off automatically.
        self.buffers = {name: deque(maxlen=window_steps) for name in self.ratios}
        self.window = window_steps

    def update_and_total(self, losses: dict):
        """
        losses: mapping name -> loss tensor.

        Behavior:
        - buffers abs(mean(loss)) only for active (ratio > 0) losses
        - coeff = ratio / median(|loss|), with the median floored at 1e-12
        - total = sum(coeff * loss) over active losses
        Absolute values are buffered so the median stays positive and the
        division never flips sign.
        """
        for name, value in losses.items():
            if name in self.buffers and self.ratios.get(name, 0) > 0:
                # .item() extracts a Python float from the detached scalar.
                self.buffers[name].append(value.detach().abs().mean().cpu().item())

        meds = {
            name: (np.median(buf) if len(buf) > 0 else 1.0)
            for name, buf in self.buffers.items()
        }
        coeffs = {name: self.ratios[name] / max(meds[name], 1e-12) for name in self.ratios}

        # Only active (ratio > 0) losses contribute to the total.
        total = sum(coeffs[name] * losses[name]
                    for name in coeffs if self.ratios.get(name, 0) > 0)
        return total, coeffs, meds
137
+
138
# Create the normalizer after loss_ratios is defined
normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)

# ------------------ Experiment tracking (WandB / Comet) ------------------
if accelerator.is_main_process:
    if use_wandb:
        wandb.init(project=project+lora_name, config={
            "batch_size": batch_size,
            "base_learning_rate": base_learning_rate,
            "num_epochs": num_epochs,
            "optimizer_type": optimizer_type,
        })
    if use_comet_ml:
        from comet_ml import Experiment
        comet_experiment = Experiment(
            api_key=comet_ml_api_key,
            project_name=project,
            workspace=comet_ml_workspace
        )
        hyper_params = {
            "batch_size": batch_size,
            "base_learning_rate": base_learning_rate,
            "num_epochs": num_epochs,
        }
        comet_experiment.log_parameters(hyper_params)

# Enable Flash Attention 2 / SDPA kernels
torch.backends.cuda.enable_flash_sdp(True)

# --------------------------- Model loading ---------------------------
# Earlier VAE variants kept for reference:
#vae = AutoencoderKL.from_pretrained("vae", torch_dtype=dtype).to("cpu").eval()
#vae = AutoencoderKLFlux2.from_pretrained("black-forest-labs/FLUX.2-dev",subfolder="vae",torch_dtype=dtype).to(device).eval()
vae = AsymmetricAutoencoderKL.from_pretrained("vae",torch_dtype=dtype).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained("tokenizer")
text_model = AutoModel.from_pretrained("text_encoder").to(device).eval()
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("scheduler")
174
+
175
def encode_texts(texts, max_length=77):  # SD 1.5 conditioning is fixed at 77 tokens
    """Tokenize and encode `texts` with the module-level tokenizer/text_model.

    Returns (prompt_embeds, attention_mask): the penultimate hidden layer
    passed through CLIP's final LayerNorm, plus the tokenizer's attention
    mask. None is treated as a single empty caption.
    """
    if texts is None:
        texts = [""]
    if isinstance(texts, str):
        texts = [texts]

    with torch.no_grad():
        toks = tokenizer(
            texts,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt"
        ).to(device)

        attention_mask = toks.attention_mask

        outputs = text_model(
            input_ids=toks.input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True  # needed to reach layer -2 below
        )

        # Penultimate transformer layer (-2), the usual SD 1.5 choice.
        hidden = outputs.hidden_states[-2]

        # hidden_states bypass CLIP's final LayerNorm — apply it manually so
        # the embeddings match last_hidden_state statistics.
        hidden = text_model.text_model.final_layer_norm(hidden)

    return hidden, attention_mask
232
+
233
# --- [UPDATED] Alternative text-encoding helper (with truncation warning) ---
def encode_texts22(texts, max_length=max_length):
    """Encode `texts` into (prompt_embeds, attention_mask) using the
    module-level tokenizer/text_model, warning when input is truncated.

    Note: unlike encode_texts(), this returns the model's final hidden
    states (no penultimate-layer selection / extra LayerNorm).
    """
    if texts is None:
        texts = [""]

    with torch.no_grad():
        if isinstance(texts, str):
            texts = [texts]

        toks = tokenizer(
            texts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_length
        ).to(device)

        text_input_ids = toks.input_ids
        untruncated_ids = tokenizer(texts, padding="longest", return_tensors="pt").input_ids.to(device)

        # BUGFIX: compare the id tensors. The old code compared the
        # BatchEncoding object itself (`toks.shape` / `torch.equal(toks, ...)`
        # raised AttributeError) and referenced `self.tokenizer` inside a
        # module-level function (NameError).
        if untruncated_ids.shape[-1] > text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids[:, :max_length]
        ):
            removed_text = tokenizer.batch_decode(
                untruncated_ids[:, max_length - 1 : -1]
            )
            print(
                "The following part of your input was truncated because CLIP can only handle sequences up to",
                max_length, "tokens:", removed_text
            )

        attention_mask = toks.attention_mask.to(device)

        # BUGFIX: pass input_ids explicitly; the old code fed the whole
        # BatchEncoding as the first positional argument and then indexed the
        # output with [-1] (the pooled vector, not the token sequence).
        outputs = text_model(input_ids=text_input_ids, attention_mask=attention_mask)
        prompt_embeds = outputs.last_hidden_state
        return prompt_embeds, attention_mask
313
+
314
# VAE latent (de)normalization constants. Missing/None config values fall
# back to the identity transform (shift 0.0, scale 1.0).
shift_factor = getattr(vae.config, "shift_factor", 0.0)
if shift_factor is None: shift_factor = 0.0
scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
if scaling_factor is None: scaling_factor = 1.0
# NOTE(review): scaling_factor is deliberately forced to 1.0 here,
# overriding whatever the VAE config declared.
scaling_factor = 1.0
319
+
320
class DistributedResolutionBatchSampler(Sampler):
    """Batch sampler that groups dataset rows by (width, height) so every
    batch is resolution-homogeneous, then shards the batches across DDP
    replicas.

    NOTE(review): __iter__ calls the module-level `accelerator` for a
    barrier, so instances are tied to this training script.
    """
    def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
        self.dataset = dataset
        # Per-replica batch size; `batch_size` is the global batch.
        self.batch_size = max(1, batch_size // num_replicas)
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.epoch = 0

        try:
            widths = np.array(dataset["width"])
            heights = np.array(dataset["height"])
        except KeyError:
            # No resolution columns: everything falls into one (0, 0) group.
            widths = np.zeros(len(dataset))
            heights = np.zeros(len(dataset))

        # Unique (w, h) pairs, and the row indices belonging to each pair.
        self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
        self.size_groups = {}
        for w, h in self.size_keys:
            mask = (widths == w) & (heights == h)
            self.size_groups[(w, h)] = np.where(mask)[0]

        # Only full global batches are kept; each group's tail is dropped.
        self.group_num_batches = {}
        total_batches = 0
        for size, indices in self.size_groups.items():
            num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
            self.group_num_batches[size] = num_full_batches
            total_batches += num_full_batches

        # Round down to a multiple of num_replicas so all ranks agree.
        # NOTE(review): __len__ reports this rounded count, while __iter__
        # yields every full batch — confirm downstream code tolerates the
        # small mismatch.
        self.num_batches = (total_batches // self.num_replicas) * self.num_replicas

    def __iter__(self):
        # Yields this rank's slice of each shuffled, resolution-pure batch.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        all_batches = []
        # Seeded by epoch only, so every rank derives the identical shuffle
        # sequence (the RNG call order below must not change).
        rng = np.random.RandomState(self.epoch)

        for size, indices in self.size_groups.items():
            indices = indices.copy()
            if self.shuffle:
                rng.shuffle(indices)
            num_full_batches = self.group_num_batches[size]
            if num_full_batches == 0:
                continue
            valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
            # Rows are global batches; columns are per-rank slots.
            batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
            start_idx = self.rank * self.batch_size
            end_idx = start_idx + self.batch_size
            gpu_batches = batches[:, start_idx:end_idx]
            all_batches.extend(gpu_batches)

        if self.shuffle:
            # Global shuffle so consecutive batches vary in resolution.
            rng.shuffle(all_batches)
        accelerator.wait_for_everyone()
        return iter(all_batches)

    def __len__(self):
        return self.num_batches

    def set_epoch(self, epoch):
        # Called once per epoch so shuffling differs between epochs but
        # stays identical across ranks.
        self.epoch = epoch
382
+
383
# --- [UPDATED] Fixed-sample picker, one bucket per image resolution ---
def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
    """Pick a few fixed dataset rows per (width, height) bucket and pre-encode
    their latents/texts for periodic preview sampling during training."""
    buckets = defaultdict(list)
    try:
        widths = dataset["width"]
        heights = dataset["height"]
    except KeyError:
        # Resolution columns are optional; fall back to one (0, 0) bucket.
        widths = [0] * len(dataset)
        heights = [0] * len(dataset)
    for row_idx, dims in enumerate(zip(widths, heights)):
        buckets[dims].append(row_idx)

    fixed_samples = {}
    single_bucket = len(buckets) == 1
    for dims, row_indices in buckets.items():
        # With a single bucket, take the full preview count instead.
        count = samples_to_generate if single_bucket else min(samples_per_group, len(row_indices))
        if count == 0:
            continue
        chosen = random.sample(row_indices, count)
        rows = [dataset[idx] for idx in chosen]

        latents = torch.tensor(np.array([row["vae"] for row in rows])).to(device=device, dtype=dtype)
        texts = [row["text"] for row in rows]

        # Encode on the fly so we also get attention masks.
        embeddings, masks = encode_texts(texts)

        fixed_samples[dims] = (latents, embeddings, masks, texts)

    print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
    return fixed_samples
416
+
417
# Load the dataset, optionally truncated to `limit` rows for smoke tests.
if limit > 0:
    dataset = load_from_disk(ds_path).select(range(limit))
else:
    dataset = load_from_disk(ds_path)

# Drop rows whose source images come from the animesfw folders.
dataset = dataset.filter(
    lambda x: [not (path.startswith("/workspace/ds/animesfw") or path.startswith("/workspace/ds/d4/animesfw")) for path in x["image_path"]],
    batched=True,
    batch_size=10000,  # process 10k rows at a time
    num_proc=8
)
print(f"Осталось примеров после фильтрации: {len(dataset)}")
429
+
430
# --- [UPDATED] Collate function ---
def collate_fn_simple(batch):
    """Collate dataset rows into (latents, text embeddings, attention mask).

    Captions are cleaned on the fly and randomly replaced with "" with
    probability cfg_dropout (classifier-free guidance training).
    """

    def clean_caption(text):
        # Captions starting with "zero" become unconditional samples.
        if text.lower().startswith("zero"):
            return ""
        # CFG dropout: randomly drop the caption.
        if random.random() < cfg_dropout:
            return ""
        # A leading "." is stripped along with the whitespace after it.
        if text.startswith("."):
            return text[1:].lstrip()
        cleaned = text.replace("The image shows ", "").replace("The image is ", "").replace("This image captures ","")
        return cleaned.strip()

    # 1. Precomputed VAE latents.
    latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device, dtype=dtype)

    # 2. Raw captions from the dataset, cleaned per-row.
    texts = [clean_caption(item["text"]) for item in batch]

    # 3. Encode on the fly: hidden (B, L, D) and mask (B, L).
    embeddings, attention_mask = encode_texts(texts)

    # Ensure an integer mask for the UNet's encoder_attention_mask.
    attention_mask = attention_mask.to(dtype=torch.int64)

    return latents, embeddings, attention_mask
453
+
454
# Resolution-aware batch sampler shared across DDP ranks.
batch_sampler = DistributedResolutionBatchSampler(
    dataset=dataset,
    batch_size=batch_size,
    num_replicas=accelerator.num_processes,
    rank=accelerator.process_index,
    shuffle=shuffle
)

dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
if accelerator.is_main_process:
    print("Total samples", len(dataloader))
dataloader = accelerator.prepare(dataloader)

start_epoch = 0
global_step = 0
total_training_steps = (len(dataloader) * num_epochs)
world_size = accelerator.state.num_processes

# Load the UNet from the local checkpoint directory (must exist).
latest_checkpoint = os.path.join(checkpoints_folder, project)
if os.path.isdir(latest_checkpoint):
    print("Загружаем UNet из чекпоинта:", latest_checkpoint)
    unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
    if unet_gradient:
        # Trade compute for memory during backprop.
        unet.enable_gradient_checkpointing()
    unet.set_use_memory_efficient_attention_xformers(False)
    try:
        # Prefer PyTorch SDPA; fall back to xformers if it fails.
        unet.set_attn_processor(AttnProcessor2_0())
    except Exception as e:
        print(f"Ошибка при включении SDPA: {e}")
        unet.set_use_memory_efficient_attention_xformers(True)
else:
    raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}")

if lora_name:
    # LoRA wiring intentionally omitted in this revision (re-add the
    # original block when LoRA training is needed).
    pass

# Optimizer parameter collection. The full list is only materialized for the
# LoRA or fbp paths; otherwise the optimizer is built later from
# unet.parameters() directly.
if lora_name:
    trainable_params = [p for p in unet.parameters() if p.requires_grad]
else:
    if fbp:
        trainable_params = list(unet.parameters())
498
+
499
+
500
def create_optimizer(name, params):
    """Build the optimizer selected by ``name`` over ``params``.

    Supported names:
      * ``"adam8bit"`` -- bitsandbytes ``AdamW8bit`` (reads closure vars
        ``base_learning_rate``, ``betta2``, ``eps``, ``percentile_clipping``).
      * ``"adam"``     -- plain ``torch.optim.AdamW``.
      * ``"muon"``     -- ``MuonWithAuxAdam`` with >=2-D weights on Muon and
        gains/biases on an auxiliary Adam, wrapped in ``SnooC``.

    Raises:
        ValueError: for any unrecognized optimizer name.
    """
    if name == "adam8bit":
        # NOTE(review): this branch honours the closure-level `eps`, while the
        # "adam" branch hard-codes 1e-8 -- confirm the asymmetry is intended.
        return bnb.optim.AdamW8bit(
            params,
            lr=base_learning_rate,
            betas=(0.9, betta2),
            eps=eps,
            weight_decay=0.01,
            percentile_clipping=percentile_clipping,
        )
    if name == "adam":
        return torch.optim.AdamW(
            params,
            lr=base_learning_rate,
            betas=(0.9, betta2),
            eps=1e-8,
            weight_decay=0.01,
        )
    if name == "muon":
        from muon import MuonWithAuxAdam

        grads_enabled = [p for p in params if p.requires_grad]
        # Muon only handles matrix-shaped weights; 1-D gains/biases go to Adam.
        matrix_weights = [p for p in grads_enabled if p.ndim >= 2]
        gains_biases = [p for p in grads_enabled if p.ndim < 2]

        groups = [
            dict(params=matrix_weights, use_muon=True,
                 lr=1e-3, weight_decay=1e-4),
            dict(params=gains_biases, use_muon=False,
                 lr=1e-4, betas=(0.9, 0.95), weight_decay=1e-4),
        ]
        inner = MuonWithAuxAdam(groups)
        from snooc import SnooC
        return SnooC(inner)
    raise ValueError(f"Unknown optimizer: {name}")
527
+
528
+ if fbp:
529
+ optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
530
def optimizer_hook(param):
    """Step the per-parameter optimizer as soon as this param's grad is ready.

    Fused-backward-pass mode: every trainable parameter owns its own
    optimizer in ``optimizer_dict``, so no global ``optimizer.step()`` runs.
    """
    opt = optimizer_dict[param]
    opt.step()
    opt.zero_grad(set_to_none=True)
533
+ for param in trainable_params:
534
+ param.register_post_accumulate_grad_hook(optimizer_hook)
535
+ unet, optimizer = accelerator.prepare(unet, optimizer_dict)
536
+ else:
537
+ # 1. Сначала замораживаем ВСЕ параметры UNet
538
+ #unet.requires_grad_(False)
539
+
540
+ # 2. Размораживаем только нужные
541
+ #trainable_params_names = ["conv_in.weight", "conv_in.bias", "conv_out.weight", "conv_out.bias"]
542
+ #train_params = []
543
+
544
+ #for name, param in unet.named_parameters():
545
+ # if any(target in name for target in trainable_params_names):
546
+ # param.requires_grad = True
547
+ # train_params.append(param)
548
+ # print(f"Обучаемый слой: {name}")
549
+
550
+ # 3. Передаем в оптимизатор ТОЛЬКО обучаемые параметры
551
+ #optimizer = create_optimizer(optimizer_type, train_params)
552
+
553
+ unet.requires_grad_(True)
554
+ optimizer = create_optimizer(optimizer_type, unet.parameters())
555
+
556
def lr_schedule(step):
    """Linear-warmup / cosine-decay schedule returning an ABSOLUTE LR.

    The surrounding ``LambdaLR`` divides the result by ``base_learning_rate``
    to recover the multiplicative factor it expects.
    """
    # Fraction of the global run completed; total_training_steps is per-GPU,
    # hence the world_size scaling -- TODO confirm against accelerate's step counting.
    progress = step / (total_training_steps * world_size)
    if not use_decay:
        return base_learning_rate
    if progress < warmup_percent:
        # Ramp linearly from min to base LR over the warmup fraction.
        ramp = progress / warmup_percent
        return min_learning_rate + (base_learning_rate - min_learning_rate) * ramp
    # Cosine anneal from base back down to min over the remainder.
    decay_ratio = (progress - warmup_percent) / (1 - warmup_percent)
    cosine = 0.5 * (1 + math.cos(math.pi * decay_ratio))
    return min_learning_rate + (base_learning_rate - min_learning_rate) * cosine
566
+ lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
567
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
568
+
569
+ if torch_compile:
570
+ print("compiling")
571
+ unet = torch.compile(unet)
572
+ print("compiling - ok")
573
+
574
+ # Фиксированные семплы
575
+ fixed_samples = get_fixed_samples_by_resolution(dataset)
576
+
577
+ # --- [UPDATED] Функция для негативного эмбеддинга (возвращает 3 элемента) ---
578
def get_negative_embedding(neg_prompt="", batch_size=1):
    """Return ``(embedding, attention_mask)`` for a negative prompt batch.

    An empty prompt yields an all-zero embedding with a fully-on mask;
    otherwise the prompt is encoded once and tiled across the batch.
    """
    if not neg_prompt:
        # assumes the text encoder's hidden size is 2048 -- TODO confirm
        hidden_dim = 2048
        seq_len = max_length
        zero_emb = torch.zeros(
            (batch_size, seq_len, hidden_dim), dtype=dtype, device=device
        )
        on_mask = torch.ones(
            (batch_size, seq_len), dtype=torch.int64, device=device
        )
        return zero_emb, on_mask

    emb, mask = encode_texts([neg_prompt])
    emb = emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1)
    mask = mask.to(device=device).repeat(batch_size, 1)
    return emb, mask
591
+
592
+ # Получаем негативные (пустые) условия для валидации
593
+ uncond_emb, uncond_mask = get_negative_embedding("low quality")
594
+
595
+ # --- Функция генерации семплов ---
596
@torch.compiler.disable()
@torch.no_grad()
def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
    """Run flow-matching sampling on the fixed validation set and save JPEGs.

    Args:
        fixed_samples_cpu: dict mapping (width, height) -> a 4-tuple of
            (latents, text embeddings, attention mask, captions), all on CPU.
        uncond_data: (embedding, mask) pair for the negative/CFG branch.
        step: current global step; at step 0 the ground-truth latents are
            decoded instead of the generated ones (sanity check of the VAE path).

    Side effects: writes JPEGs to ``generated_folder`` and logs images to
    W&B / Comet when enabled; moves the VAE to GPU during the call and back
    to CPU in ``finally``.
    """
    uncond_emb, uncond_mask = uncond_data

    original_model = None
    try:
        # torch.compile-wrapped models must not be unwrapped (would drop the graph).
        if not torch_compile:
            original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
        else:
            original_model = unet.eval()

        vae.to(device=device).eval()

        all_generated_images = []
        all_captions = []

        # One bucket per resolution: (latents, text embeddings, mask, captions).
        for size, (sample_latents, sample_text_embeddings, sample_mask, sample_text) in fixed_samples_cpu.items():
            width, height = size
            sample_latents = sample_latents.to(dtype=dtype, device=device)
            sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
            sample_mask = sample_mask.to(device=device)

            # Fixed seed so validation grids are comparable across steps.
            latents = torch.randn(
                sample_latents.shape,
                device=device,
                dtype=sample_latents.dtype,
                generator=torch.Generator(device=device).manual_seed(seed)
            )

            scheduler.set_timesteps(n_diffusion_steps, device=device)

            for t in scheduler.timesteps:
                if guidance_scale != 1:
                    # Classifier-free guidance: duplicate latents, stack
                    # [negative, positive] conditions along the batch dim.
                    latent_model_input = torch.cat([latents, latents], dim=0)

                    curr_batch_size = sample_text_embeddings.shape[0]
                    seq_len = sample_text_embeddings.shape[1]
                    hidden_dim = sample_text_embeddings.shape[2]

                    # 1. Embeddings: broadcast the single negative row.
                    neg_emb_batch = uncond_emb[0:1].expand(curr_batch_size, -1, -1)
                    text_embeddings_batch = torch.cat([neg_emb_batch, sample_text_embeddings], dim=0)

                    # 2. Attention masks, same layout as the embeddings.
                    neg_mask_batch = uncond_mask[0:1].expand(curr_batch_size, -1)
                    attention_mask_batch = torch.cat([neg_mask_batch, sample_mask], dim=0)

                else:
                    latent_model_input = latents
                    text_embeddings_batch = sample_text_embeddings
                    attention_mask_batch = sample_mask

                # UNet forward with full conditioning (embeddings + mask).
                model_out = original_model(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings_batch,
                    encoder_attention_mask=attention_mask_batch,
                )
                # Some wrappers return the tensor directly instead of an
                # output object with a `.sample` attribute.
                flow = getattr(model_out, "sample", model_out)

                if guidance_scale != 1:
                    # Standard CFG combination: uncond + s * (cond - uncond).
                    flow_uncond, flow_cond = flow.chunk(2)
                    flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)

                latents = scheduler.step(flow, t, latents).prev_sample

            current_latents = latents
            # Step 0: decode the dataset latents themselves to validate the
            # VAE round-trip before any training has happened.
            if step==0:
                current_latents = sample_latents

            latent_for_vae = current_latents.detach() * scaling_factor + shift_factor
            decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
            decoded_fp32 = decoded.to(torch.float32)

            for img_idx, img_tensor in enumerate(decoded_fp32):
                # [-1, 1] -> [0, 1], CHW -> HWC for PIL.
                img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
                img = img.transpose(1, 2, 0)

                # NOTE(review): despite the message, execution continues and
                # the NaN image is still saved -- confirm this is intended.
                if np.isnan(img).any():
                    print("NaNs found, saving stopped! Step:", step)
                pil_img = Image.fromarray((img * 255).astype("uint8"))

                # Pad every sample to the largest resolution so the logged
                # grid has uniform tiles.
                max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
                max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
                # NOTE(review): 255 looks like it was meant to be 256 -- confirm.
                max_w_overall = max(255, max_w_overall)
                max_h_overall = max(255, max_h_overall)

                padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
                all_generated_images.append(padded_img)

                # Truncate captions for logging UIs.
                caption_text = sample_text[img_idx][:300] if img_idx < len(sample_text) else ""
                all_captions.append(caption_text)

                # The unpadded image is what gets written to disk.
                sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
                pil_img.save(sample_path, "JPEG", quality=96)

        if use_wandb and accelerator.is_main_process:
            wandb_images = [
                wandb.Image(img, caption=f"{all_captions[i]}")
                for i, img in enumerate(all_generated_images)
            ]
            wandb.log({"generated_images": wandb_images})
        if use_comet_ml and accelerator.is_main_process:
            for i, img in enumerate(all_generated_images):
                comet_experiment.log_image(
                    image_data=img,
                    name=f"step_{step}_img_{i}",
                    step=step,
                    metadata={"caption": all_captions[i]}
                )
    finally:
        # Always evict the VAE and reclaim GPU memory, even on failure.
        vae.to("cpu")
        torch.cuda.empty_cache()
        gc.collect()
714
+
715
+ # --------------------------- Генерация сэмплов перед обучением ---------------------------
716
+ if accelerator.is_main_process:
717
+ if save_model:
718
+ print("Генерация сэмплов до старта обучения...")
719
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), 0)
720
+ accelerator.wait_for_everyone()
721
+
722
def save_checkpoint(unet, variant=""):
    """Persist the current UNet weights (main process only).

    A non-empty ``variant`` casts the model to fp16 before saving.
    NOTE(review): ``Module.to(dtype=...)`` converts IN PLACE, so after a
    variant save the live model is fp16 until the caller restores its dtype.
    """
    if not accelerator.is_main_process:
        return

    if lora_name:
        save_lora_checkpoint(unet)
        return

    # Keep the compiled wrapper intact; unwrap only the accelerate shell.
    if not torch_compile:
        model_to_save = accelerator.unwrap_model(unet)
    else:
        model_to_save = unet

    out_dir = os.path.join(checkpoints_folder, f"{project}")
    if variant != "":
        model_to_save.to(dtype=torch.float16).save_pretrained(out_dir, variant=variant)
    else:
        model_to_save.save_pretrained(out_dir)
739
+
740
+ unet = unet.to(dtype=dtype)
741
+
742
+ # --------------------------- Тренировочный цикл ---------------------------
743
+ if accelerator.is_main_process:
744
+ print(f"Total steps per GPU: {total_training_steps}")
745
+
746
+ epoch_loss_points = []
747
+ progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
748
+
749
+ steps_per_epoch = len(dataloader)
750
+ sample_interval = max(1, steps_per_epoch // sample_interval_share)
751
+ min_loss = 4.
752
+
753
+ for epoch in range(start_epoch, start_epoch + num_epochs):
754
+ batch_losses = []
755
+ batch_grads = []
756
+ batch_sampler.set_epoch(epoch)
757
+ accelerator.wait_for_everyone()
758
+ unet.train()
759
+
760
+ for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
761
+ with accelerator.accumulate(unet):
762
+ if save_model == False and epoch == 0 and step == 5 :
763
+ used_gb = torch.cuda.max_memory_allocated() / 1024**3
764
+ print(f"Шаг {step}: {used_gb:.2f} GB")
765
+
766
+ # шум
767
+ noise = torch.randn_like(latents, dtype=latents.dtype)
768
+
769
+ # 3. Время t (сэмплим, как и раньше, но чуть сжимаем края)
770
+ u = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
771
+ t = u * (1 - 2 * 1e-5) + 1e-5 # Теперь t строго в (0.00001 ... 0.99999)
772
+ # интерполяция между x0 и шумом
773
+ noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
774
+ # делаем integer timesteps для UNet
775
+ timesteps = t.to(torch.float32).mul(999.0)
776
+ timesteps = timesteps.clamp(0, scheduler.config.num_train_timesteps - 1)
777
+
778
+ # --- Вызов UNet с маской ---
779
+ model_pred = unet(
780
+ noisy_latents,
781
+ timesteps,
782
+ encoder_hidden_states=embeddings,
783
+ encoder_attention_mask=attention_mask
784
+ ).sample
785
+
786
+ target = noise - latents
787
+
788
+ mse_loss = F.mse_loss(model_pred.float(), target.float())
789
+ mae_loss = F.l1_loss(model_pred.float(), target.float())
790
+
791
+ batch_losses.append(mse_loss.detach().item())
792
+
793
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
794
+ accelerator.wait_for_everyone()
795
+
796
+ losses_dict = {}
797
+ losses_dict["mse"] = mse_loss
798
+ losses_dict["mae"] = mae_loss
799
+
800
+ # === Нормализация всех лоссов ===
801
+ abs_for_norm = {k: losses_dict.get(k, torch.tensor(0.0, device=device)) for k in normalizer.ratios.keys()}
802
+ total_loss, coeffs, meds = normalizer.update_and_total(abs_for_norm)
803
+
804
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
805
+ accelerator.wait_for_everyone()
806
+
807
+ if loss_normalize:
808
+ accelerator.backward(total_loss)
809
+ else:
810
+ accelerator.backward(mse_loss)
811
+
812
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
813
+ accelerator.wait_for_everyone()
814
+
815
+ grad = 0.0
816
+ if not fbp:
817
+ if accelerator.sync_gradients:
818
+ #with torch.amp.autocast('cuda', enabled=False):
819
+ grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
820
+ grad = float(grad_val)
821
+ optimizer.step()
822
+ lr_scheduler.step()
823
+ optimizer.zero_grad(set_to_none=True)
824
+
825
+ if accelerator.sync_gradients:
826
+ global_step += 1
827
+ progress_bar.update(1)
828
+ if accelerator.is_main_process:
829
+ if fbp:
830
+ current_lr = base_learning_rate
831
+ else:
832
+ current_lr = lr_scheduler.get_last_lr()[0]
833
+ batch_grads.append(grad)
834
+
835
+ log_data = {}
836
+ log_data["loss_mse"] = mse_loss.detach().item()
837
+ log_data["loss_mae"] = mae_loss.detach().item()
838
+ log_data["lr"] = current_lr
839
+ log_data["grad"] = grad
840
+ log_data["loss_norm"] = float(total_loss.item())
841
+ for k, c in coeffs.items():
842
+ log_data[f"coeff_{k}"] = float(c)
843
+ if accelerator.sync_gradients:
844
+ if use_wandb:
845
+ wandb.log(log_data, step=global_step)
846
+ if use_comet_ml:
847
+ comet_experiment.log_metrics(log_data, step=global_step)
848
+
849
+ if global_step % sample_interval == 0:
850
+ # Передаем tuple (emb, mask) для негатива
851
+ if save_model:
852
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
853
+ elif epoch % 10 == 0:
854
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
855
+ last_n = sample_interval
856
+
857
+ if save_model:
858
+ has_losses = len(batch_losses) > 0
859
+ avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if has_losses else 0.0
860
+ last_loss = batch_losses[-1] if has_losses else 0.0
861
+ max_loss = max(avg_sample_loss, last_loss)
862
+ should_save = max_loss < min_loss * save_barrier
863
+ print(
864
+ f"Saving: {should_save} | Max: {max_loss:.4f} | "
865
+ f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}"
866
+ )
867
+ # 6. Сохранение и обновление
868
+ if should_save:
869
+ min_loss = max_loss
870
+ save_checkpoint(unet)
871
+
872
+ if accelerator.is_main_process:
873
+ avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
874
+ avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
875
+
876
+ print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
877
+ log_data_ep = {
878
+ "epoch_loss": avg_epoch_loss,
879
+ "epoch_grad": avg_epoch_grad,
880
+ "epoch": epoch + 1,
881
+ }
882
+ if use_wandb:
883
+ wandb.log(log_data_ep)
884
+ if use_comet_ml:
885
+ comet_experiment.log_metrics(log_data_ep)
886
+
887
+ if accelerator.is_main_process:
888
+ print("Обучение завершено! Сохраняем финальную модель...")
889
+ #if save_model:
890
+ save_checkpoint(unet,"fp16")
891
+ if use_comet_ml:
892
+ comet_experiment.end()
893
+ accelerator.free_memory()
894
+ if torch.distributed.is_initialized():
895
+ torch.distributed.destroy_process_group()
896
+
897
+ print("Готово!")
unet/config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a14cb265318c2fc495111e70c6f22f9364fe38511040f7fe97530379027afc52
3
+ size 1812
unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:608813f0c133cac3751231f49866dfdb1ac3623c842672d7685dc4ff69c74260
3
+ size 3438444088
vae/config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:794f861f46cdc204bf5ebc53357612c1c6af20c0f45b06fc29c05c9e3e5262f9
3
+ size 852
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ffde397a3e78a779adff8ba78297f66d01af5e397512f6ed6d500df30e9833a1
3
+ size 382598708