Text-to-Image
Diffusers
Safetensors
recoilme commited on
Commit
0be0b27
·
1 Parent(s): e6e8c46
README.md CHANGED
@@ -8,11 +8,11 @@ datasets:
8
  # Simple Diffusion XS
9
 
10
  *XS Size, Excess Quality*
11
- ![promo](media/result_grid_base.jpg)
12
 
13
  At AiArtLab, we strive to create a free, compact and fast model that can be trained on consumer graphics cards.
14
 
15
- - Unet: 1.16b parameters
16
  - Clip: [LongCLIP with 248 tokens](https://huggingface.co/zer0int/CLIP-KO-LITE-TypoAttack-Attn-Dropout-ViT-L-14)
17
  - VAE: 16ch16x(8x-enc/16x-dec)
18
  - Speed: Sampling 100%|██████████| 40/40 [00:01<00:00, 30.72it/s] (1024x1280)
 
8
  # Simple Diffusion XS
9
 
10
  *XS Size, Excess Quality*
11
+ ![promo](media/girl.jpg)
12
 
13
  At AiArtLab, we strive to create a free, compact and fast model that can be trained on consumer graphics cards.
14
 
15
+ - Unet: 1.3b parameters
16
  - Clip: [LongCLIP with 248 tokens](https://huggingface.co/zer0int/CLIP-KO-LITE-TypoAttack-Attn-Dropout-ViT-L-14)
17
  - VAE: 16ch16x(8x-enc/16x-dec)
18
  - Speed: Sampling 100%|██████████| 40/40 [00:01<00:00, 30.72it/s] (1024x1280)
unet1.3b.ipynb → media/girl.jpg RENAMED
File without changes
media/result_grid.jpg CHANGED

Git LFS Details

  • SHA256: 5476fefcc3b778a967365e090500e349ac1f79b4fe76e80f864f4a1ca059bc21
  • Pointer size: 132 Bytes
  • Size of remote file: 4.32 MB

Git LFS Details

  • SHA256: 7258028c9041d5853da2dedd8bd7fff472e76dc349c805fc99406937692b1321
  • Pointer size: 132 Bytes
  • Size of remote file: 3.7 MB
media/result_grid_base.jpg DELETED

Git LFS Details

  • SHA256: 2e550d11a47bd98b5deba7510e80eb43f964cbacaa78872352c921e41cb580e1
  • Pointer size: 131 Bytes
  • Size of remote file: 994 kB
pipeline_sdxs-Copy1.py DELETED
@@ -1,335 +0,0 @@
1
- import torch
2
- import numpy as np
3
- from PIL import Image
4
- from typing import List, Union, Optional, Tuple
5
- from dataclasses import dataclass
6
-
7
- from diffusers import DiffusionPipeline
8
- from diffusers.utils import BaseOutput
9
- from tqdm import tqdm
10
-
11
- @dataclass
12
- class SdxsPipelineOutput(BaseOutput):
13
- images: Union[List[Image.Image], np.ndarray]
14
-
15
- class SdxsPipeline(DiffusionPipeline):
16
- def __init__(self, vae, text_encoder, tokenizer, unet, scheduler):
17
- super().__init__()
18
- self.register_modules(
19
- vae=vae,
20
- text_encoder=text_encoder,
21
- tokenizer=tokenizer,
22
- unet=unet,
23
- scheduler=scheduler
24
- )
25
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
26
-
27
- def create_frequency_soft_cutoff_mask(self, height: int, width: int, cutoff_radius: float,
28
- transition_width: float = 5.0, device: torch.device = None) -> torch.Tensor:
29
- """Создает плавную маску частотного среза для сохранения структуры."""
30
- if device is None:
31
- device = torch.device('cpu')
32
-
33
- u = torch.arange(height, device=device)
34
- v = torch.arange(width, device=device)
35
- u, v = torch.meshgrid(u, v, indexing='ij')
36
-
37
- center_u, center_v = height // 2, width // 2
38
- frequency_radius = torch.sqrt((u - center_u)**2 + (v - center_v)**2)
39
-
40
- mask = torch.exp(-(frequency_radius - cutoff_radius)**2 / (2 * transition_width**2))
41
- mask = torch.where(frequency_radius <= cutoff_radius, torch.ones_like(mask), mask)
42
-
43
- return mask
44
-
45
- def generate_structured_noise(
46
- self,
47
- image_latents: torch.Tensor,
48
- cutoff_radius: Optional[float] = None,
49
- transition_width: float = 2.0,
50
- noise_std: float = 1.0,
51
- ) -> torch.Tensor:
52
- """
53
- Генерирует структурированный шум для латентов с сохранением низкочастотной структуры.
54
-
55
- Args:
56
- image_latents: Чистые латенты изображения [B, C, H, W]
57
- cutoff_radius: Радиус среза частот (None = авто-расчет на основе coef)
58
- transition_width: Ширина плавного перехода
59
- noise_std: Стандартное отклонение шума
60
-
61
- Returns:
62
- Структурированный шум с той же размерностью
63
- """
64
- batch_size, channels, height, width = image_latents.shape
65
- device = image_latents.device
66
- dtype = image_latents.dtype
67
-
68
- # Автоматический расчет cutoff_radius если не задан
69
- if cutoff_radius is None:
70
- # Сохраняем больше низких частот для лучшей структуры
71
- max_radius = min(height, width) / 2
72
- cutoff_radius = max_radius * 0.7 # Сохраняем 70% низких частот
73
-
74
- # Создаем частотную маску
75
- freq_mask = self.create_frequency_soft_cutoff_mask(
76
- height, width, cutoff_radius, transition_width, device
77
- )
78
- freq_mask = freq_mask.unsqueeze(0).unsqueeze(0) # [1, 1, H, W]
79
-
80
- # Преобразуем латенты в частотную область
81
- fft_image = torch.fft.fft2(image_latents, dim=(-2, -1))
82
- fft_shifted = torch.fft.fftshift(fft_image, dim=(-2, -1))
83
-
84
- # Извлекаем фазу изображения
85
- image_phase = torch.angle(fft_shifted)
86
-
87
- # Генерируем гауссовский шум
88
- noise = torch.randn_like(image_latents) * noise_std
89
-
90
- # Преобразуем шум в частотную область
91
- fft_noise = torch.fft.fft2(noise, dim=(-2, -1))
92
- fft_noise_shifted = torch.fft.fftshift(fft_noise, dim=(-2, -1))
93
-
94
- # Извлекаем амплитуду шума
95
- noise_magnitude = torch.abs(fft_noise_shifted)
96
- noise_phase = torch.angle(fft_noise_shifted)
97
-
98
- # Смешиваем фазы: низкие частоты - фаза изображения, высокие - фаза шума
99
- mixed_phase = freq_mask * image_phase + (1 - freq_mask) * noise_phase
100
-
101
- # Собираем обратно: амплитуда шума + смешанная фаза
102
- fft_combined = noise_magnitude * torch.exp(1j * mixed_phase)
103
- fft_unshifted = torch.fft.ifftshift(fft_combined, dim=(-2, -1))
104
-
105
- # Обратное преобразование
106
- structured_noise = torch.fft.ifft2(fft_unshifted, dim=(-2, -1))
107
- structured_noise = torch.real(structured_noise)
108
-
109
- # Нормализуе�� для сохранения статистики гауссовского шума
110
- current_std = torch.std(structured_noise)
111
- if current_std > 0:
112
- structured_noise = structured_noise / current_std * noise_std
113
-
114
- return structured_noise.to(dtype)
115
-
116
- def preprocess_image(self, image: Image.Image, width: int, height: int):
117
- """Ресайз и центрированный кроп изображения для асимметричного VAE."""
118
- # Для энкодера с масштабом 8
119
- target_height = ((height // self.vae_scale_factor) * self.vae_scale_factor)//2
120
- target_width = ((width // self.vae_scale_factor) * self.vae_scale_factor)//2
121
-
122
- w, h = image.size
123
- aspect_ratio = target_width / target_height
124
-
125
- if w / h > aspect_ratio:
126
- new_w = int(h * aspect_ratio)
127
- left = (w - new_w) // 2
128
- image = image.crop((left, 0, left + new_w, h))
129
- else:
130
- new_h = int(w / aspect_ratio)
131
- top = (h - new_h) // 2
132
- image = image.crop((0, top, w, top + new_h))
133
-
134
- image = image.resize((target_width, target_height), resample=Image.LANCZOS)
135
- image = np.array(image).astype(np.float32) / 255.0
136
- image = image[None].transpose(0, 3, 1, 2) # [1, C, H, W]
137
- image = torch.from_numpy(image)
138
- return 2.0 * image - 1.0 # [-1, 1]
139
-
140
- def encode_prompt(self, prompt, negative_prompt, device, dtype):
141
- def get_single_encode(texts, is_negative=False):
142
- if texts is None or texts == "":
143
- hidden_dim = self.text_encoder.config.hidden_size
144
- shape = (1, self.text_encoder.config.max_position_embeddings, hidden_dim)
145
- emb = torch.zeros(shape, dtype=dtype, device=device)
146
- mask = torch.ones((1, self.text_encoder.config.max_position_embeddings), dtype=torch.int64, device=device)
147
- return emb, mask
148
-
149
- if isinstance(texts, str):
150
- texts = [texts]
151
-
152
- with torch.no_grad():
153
- toks = self.tokenizer(
154
- texts,
155
- padding="max_length",
156
- max_length=self.text_encoder.config.max_position_embeddings,
157
- truncation=True,
158
- return_tensors="pt"
159
- ).to(device)
160
-
161
- outputs = self.text_encoder(
162
- input_ids=toks.input_ids,
163
- attention_mask=toks.attention_mask,
164
- output_hidden_states=True
165
- )
166
-
167
- layer_index = -2
168
- prompt_embeds = outputs.hidden_states[layer_index]
169
- final_layer_norm = self.text_encoder.text_model.final_layer_norm
170
- prompt_embeds = final_layer_norm(prompt_embeds)
171
-
172
- return prompt_embeds, toks.attention_mask
173
-
174
- pos_embeds, pos_mask = get_single_encode(prompt)
175
- neg_embeds, neg_mask = get_single_encode(negative_prompt, is_negative=True)
176
-
177
- batch_size = pos_embeds.shape[0]
178
- if neg_embeds.shape[0] != batch_size:
179
- neg_embeds = neg_embeds.repeat(batch_size, 1, 1)
180
- neg_mask = neg_mask.repeat(batch_size, 1)
181
-
182
- text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
183
- final_mask = torch.cat([neg_mask, pos_mask], dim=0)
184
-
185
- return text_embeddings.to(dtype=dtype), final_mask.to(dtype=torch.int64)
186
-
187
- @torch.no_grad()
188
- def __call__(
189
- self,
190
- prompt: Union[str, List[str]],
191
- image: Optional[Union[Image.Image, List[Image.Image]]] = None,
192
- coef: float = 0.5, # strength: 1.0 - полный шум, 0.0 - оригинал
193
- negative_prompt: Optional[Union[str, List[str]]] = None,
194
- height: int = 1024,
195
- width: int = 1024,
196
- num_inference_steps: int = 40,
197
- guidance_scale: float = 4.0,
198
- generator: Optional[torch.Generator] = None,
199
- seed: Optional[int] = None,
200
- output_type: str = "pil",
201
- return_dict: bool = True,
202
- structure_preservation: float = 0.2, # Новый параметр: сохранение структуры 0-1
203
- **kwargs,
204
- ):
205
- device = self.device
206
- dtype = self.unet.dtype
207
-
208
- if generator is None and seed is not None:
209
- if torch.cuda.is_available():
210
- generator = torch.Generator(device=device)
211
- else:
212
- generator = torch.Generator()
213
- generator.manual_seed(seed)
214
-
215
- # 1. Encode Prompt
216
- text_embeddings, attention_mask = self.encode_prompt(
217
- prompt, negative_prompt, device, dtype
218
- )
219
- batch_size = 1 if isinstance(prompt, str) else len(prompt)
220
-
221
- # 2. Настройка таймстепов - ВСЕГДА используем ВСЕ шаги
222
- self.scheduler.set_timesteps(num_inference_steps, device=device)
223
- timesteps = self.scheduler.timesteps # Все 40 шагов
224
-
225
- #print(f"Используем ВСЕ шаги: {len(timesteps)}")
226
- #print(f"Диапазон таймстепов: [{timesteps[0].item():.3f}, {timesteps[-1].item():.3f}]")
227
- #print(f"Коэффициент смешивания (coef): {coef}")
228
-
229
- # 3. Обработка img2img с структурированным шумом
230
- if image is not None:
231
- # Подготовка изображения
232
- if isinstance(image, Image.Image):
233
- image_tensor = self.preprocess_image(image, width, height).to(
234
- device=device, dtype=self.vae.dtype
235
- )
236
- else:
237
- image_tensor = self.preprocess_image(image[0], width, height).to(
238
- device=device, dtype=self.vae.dtype
239
- )
240
-
241
- # Кодируем в латенты
242
- latents_clean = self.vae.encode(image_tensor).latent_dist.sample(generator=generator)
243
- vae_scaling_factor = getattr(self.vae.config, "scaling_factor", 1.0)
244
- vae_shift_factor = getattr(self.vae.config, "shift_factor", 0.0)
245
- latents_clean = (latents_clean - vae_shift_factor) / vae_scaling_factor
246
- latents_clean = latents_clean.to(dtype=dtype)
247
-
248
- # Автоматический расчет cutoff_radius на основе structure_preservation
249
- latent_height, latent_width = latents_clean.shape[-2], latents_clean.shape[-1]
250
- max_radius = min(latent_height, latent_width) / 2
251
- cutoff_radius = max_radius * structure_preservation
252
-
253
- # Генерируем структурированный шум
254
- structured_noise = self.generate_structured_noise(
255
- image_latents=latents_clean,
256
- cutoff_radius=cutoff_radius,
257
- transition_width=2.0,
258
- noise_std=1.0
259
- )
260
-
261
- # Нормализуем шум
262
- current_std = torch.std(structured_noise)
263
- if current_std > 0:
264
- structured_noise = structured_noise / current_std
265
-
266
- # КЛЮЧЕВОЕ ИЗМЕНЕНИЕ: Простое линейное смешивание
267
- # coef=0.0 -> 100% оригинал, 0% шум
268
- # coef=1.0 -> 0% оригинал, 100% шум
269
- print(f"Смешивание: {100*(1-coef):.1f}% оригинал + {100*coef:.1f}% шум")
270
-
271
- # Важно: инвертируем coef, если хотим, чтобы coef=0.1 давал слабое изменение
272
- # coef=0.1 -> 90% оригинал + 10% шум
273
- # coef=0.9 -> 10% оригинал + 90% шум
274
- latents = (1.0 - coef) * latents_clean + coef * structured_noise
275
-
276
- else:
277
- # TXT2IMG: начинаем с чистого шума (coef=1.0 эквивалентно)
278
- vae_scaling_factor = getattr(self.vae.config, "scaling_factor", 1.0)
279
- vae_shift_factor = getattr(self.vae.config, "shift_factor", 0.0)
280
-
281
- latent_height = height // self.vae_scale_factor
282
- latent_width = width // self.vae_scale_factor
283
-
284
- latents = torch.randn(
285
- (batch_size, self.unet.config.in_channels, latent_height, latent_width),
286
- generator=generator,
287
- device=device,
288
- dtype=dtype
289
- )
290
-
291
- # 4. Denoising Loop
292
- for i, t in enumerate(tqdm(timesteps, desc="Sampling")):
293
- # CFG preparation
294
- latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents
295
-
296
- # Predict flow
297
- model_out = self.unet(
298
- latent_model_input,
299
- t,
300
- encoder_hidden_states=text_embeddings,
301
- encoder_attention_mask=attention_mask,
302
- return_dict=False,
303
- )[0]
304
-
305
- # CFG
306
- if guidance_scale > 1:
307
- flow_uncond, flow_cond = model_out.chunk(2)
308
- model_out = flow_uncond + guidance_scale * (flow_cond - flow_uncond)
309
-
310
- # Euler step для flow matching
311
- latents = self.scheduler.step(model_out, t, latents, return_dict=False)[0]
312
-
313
- # 5. Decode
314
- if output_type == "latent":
315
- return SdxsPipelineOutput(images=latents)
316
-
317
- # Масштабируем обратно для VAE
318
- latents = latents * vae_scaling_factor + vae_shift_factor
319
- image_output = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
320
-
321
-
322
- # Нормализуем к [0, 1] для PIL
323
- image_output = (image_output.clamp(-1, 1) + 1) / 2
324
- image_np = image_output.cpu().permute(0, 2, 3, 1).float().numpy()
325
-
326
- if output_type == "pil":
327
- image_np = (image_np * 255).round().astype("uint8")
328
- images = [Image.fromarray(img) for img in image_np]
329
- else:
330
- images = image_np
331
-
332
- if not return_dict:
333
- return images
334
-
335
- return SdxsPipelineOutput(images=images)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pipeline_sdxs.py CHANGED
@@ -100,7 +100,7 @@ class SdxsPipeline(DiffusionPipeline):
100
  self,
101
  prompt: Union[str, List[str]],
102
  image: Optional[Union[Image.Image, List[Image.Image]]] = None,
103
- coef: float = 0.5, # ← strength (0.0 = оригинал, 1.0 = полный шум)
104
  negative_prompt: Optional[Union[str, List[str]]] = None,
105
  height: int = 1024,
106
  width: int = 1024,
 
100
  self,
101
  prompt: Union[str, List[str]],
102
  image: Optional[Union[Image.Image, List[Image.Image]]] = None,
103
+ coef: float = 0.97, # ← strength (0.0 = оригинал, 1.0 = полный шум)
104
  negative_prompt: Optional[Union[str, List[str]]] = None,
105
  height: int = 1024,
106
  width: int = 1024,
samples/unet_1.3b_320x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 07922452d06ae4d5df6fa749ee6449122c8dceb39442aace81e9a059e4e73792
  • Pointer size: 131 Bytes
  • Size of remote file: 365 kB

Git LFS Details

  • SHA256: 659e43cffc99f9b878d5bcd1a9c32753e15cd50c16e4497be4a083869870bf78
  • Pointer size: 131 Bytes
  • Size of remote file: 372 kB
samples/unet_1.3b_352x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 5018ea71663135ed24699cb312f0d2202461887439dc7f7cfd55f87fd7febdbd
  • Pointer size: 131 Bytes
  • Size of remote file: 475 kB

Git LFS Details

  • SHA256: b71316f15ad24605f5f8cb8bee5a71f72d4e63433514ed03215680e5e59e03d0
  • Pointer size: 131 Bytes
  • Size of remote file: 463 kB
samples/unet_1.3b_384x640_0.jpg CHANGED

Git LFS Details

  • SHA256: ce31d1b72c15c5dec26bd9742450c9d4b82251e1537ca4e82da9118fb2b84b89
  • Pointer size: 130 Bytes
  • Size of remote file: 49.4 kB

Git LFS Details

  • SHA256: 3ae9686567a0f789db23e2289061639cbe03695f7ec3baead9f0c87c5ee564bd
  • Pointer size: 130 Bytes
  • Size of remote file: 42.9 kB
samples/unet_1.3b_416x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 2981cb631fe07fec4d34d65cd9743f8ef775a9be455e766f8216a8e8d77a9f46
  • Pointer size: 131 Bytes
  • Size of remote file: 395 kB

Git LFS Details

  • SHA256: 9f897f50cae214e09b971999833077c85995d779847ae0a1cd265a2a5620461f
  • Pointer size: 131 Bytes
  • Size of remote file: 410 kB
samples/unet_1.3b_448x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 6f6be66ba7b15077e03892832494a28d33c07f250e6d6f3e79b597e37922691b
  • Pointer size: 130 Bytes
  • Size of remote file: 57.1 kB

Git LFS Details

  • SHA256: 6e92ce24c77a31d6c1fd5c8ad7f36bb67f3be7cb16362a7a8f522e34670158f4
  • Pointer size: 130 Bytes
  • Size of remote file: 56.6 kB
samples/unet_1.3b_480x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 4b0412a2d715977208259d31f87e7b0544652b8f922db80c943f48c77c352efc
  • Pointer size: 131 Bytes
  • Size of remote file: 377 kB

Git LFS Details

  • SHA256: dc9554c7b4de99af58003f8fb82eda695e11db4dd8bf7d90b4146524d50665ed
  • Pointer size: 131 Bytes
  • Size of remote file: 387 kB
samples/unet_1.3b_512x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 530fb1a49ea89c5f38a00632b035788d4ef2daacf00b0c47cce9d45555fa780f
  • Pointer size: 131 Bytes
  • Size of remote file: 378 kB

Git LFS Details

  • SHA256: d0e722dd9e67543e5254927a8708737b43c41963d90fbb5f2deecb8ade370b67
  • Pointer size: 131 Bytes
  • Size of remote file: 359 kB
samples/unet_1.3b_544x640_0.jpg CHANGED

Git LFS Details

  • SHA256: aefeb62a8532d968d4cc5e13cda909d7416ac78373049cdb0a321da487ae7d6e
  • Pointer size: 131 Bytes
  • Size of remote file: 145 kB

Git LFS Details

  • SHA256: fd84c7451640166f445e766e8e501fd03f01dba90b8142e1698bfb368736537c
  • Pointer size: 131 Bytes
  • Size of remote file: 135 kB
samples/unet_1.3b_576x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 79494968e0e818aeff528a5e69423b7b34b784fe0ba16e18ee7b81cbdb1eeeed
  • Pointer size: 131 Bytes
  • Size of remote file: 194 kB

Git LFS Details

  • SHA256: 61e6a190f48d4c251a444f3c43114036347db7ea1f8ba4437eb624221505ea73
  • Pointer size: 131 Bytes
  • Size of remote file: 187 kB
samples/unet_1.3b_608x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 49f38527a7289f799a7520150382c88b08ea62cf48d4d20360ca54b0a8fc533a
  • Pointer size: 130 Bytes
  • Size of remote file: 68.9 kB

Git LFS Details

  • SHA256: d59184854278fbe08f1671ab4a7d33c60ef9836904d1d4b10452e180138e613a
  • Pointer size: 130 Bytes
  • Size of remote file: 68.9 kB
samples/unet_1.3b_640x320_0.jpg CHANGED

Git LFS Details

  • SHA256: 19169f84c9991a249e279403b9a6cd6e0b576b42d8b6baedc73fd5629ad39016
  • Pointer size: 131 Bytes
  • Size of remote file: 260 kB

Git LFS Details

  • SHA256: 51677340dcd99d10c6c84417627a505be5c8f5fd4a53b0d71549f99fc945d26c
  • Pointer size: 131 Bytes
  • Size of remote file: 256 kB
samples/unet_1.3b_640x352_0.jpg CHANGED

Git LFS Details

  • SHA256: 26ef58ca23838de2ef382736edccf55e65e2f1071f01cdc8a8a40517aba355bd
  • Pointer size: 131 Bytes
  • Size of remote file: 688 kB

Git LFS Details

  • SHA256: 049205877afbfc05294dad3ddb141574466a800c57cf97abeb30e8a2af9d5131
  • Pointer size: 131 Bytes
  • Size of remote file: 698 kB
samples/unet_1.3b_640x384_0.jpg CHANGED

Git LFS Details

  • SHA256: 047dbcf6f75784068ed23efd904c4e0da9f5fc78ae9ba282c3f223d0807574eb
  • Pointer size: 131 Bytes
  • Size of remote file: 155 kB

Git LFS Details

  • SHA256: 667123de67ae1fefb3d7a880f84482cd05f0f5d4a0f309b6a8afc4a14af40f0b
  • Pointer size: 131 Bytes
  • Size of remote file: 153 kB
samples/unet_1.3b_640x416_0.jpg CHANGED

Git LFS Details

  • SHA256: 7bb9f61f1964340bbdd437bca164f6e1f792bc2909ddc46c99e21bbad9f06efc
  • Pointer size: 131 Bytes
  • Size of remote file: 253 kB

Git LFS Details

  • SHA256: 7f551eb2e78e8922769af08a1dcfff718e98b805387b8e7eda513772052cd91e
  • Pointer size: 131 Bytes
  • Size of remote file: 249 kB
samples/unet_1.3b_640x448_0.jpg CHANGED

Git LFS Details

  • SHA256: 02e35cf855964fa36815537419d193135f0eabe81c7a2f4f50b85d0c017f970b
  • Pointer size: 131 Bytes
  • Size of remote file: 677 kB

Git LFS Details

  • SHA256: e93eda862756b206310f6f94bc39786c5ecf9a49d96e2e8b360962162187b26a
  • Pointer size: 131 Bytes
  • Size of remote file: 654 kB
samples/unet_1.3b_640x480_0.jpg CHANGED

Git LFS Details

  • SHA256: 3290923e437b7446ef12c6fc6a22be3eaff6e0ad7831a0d217bc60dc290afde2
  • Pointer size: 131 Bytes
  • Size of remote file: 266 kB

Git LFS Details

  • SHA256: 5c9f533250df54fb5c6b1fb2225c8fc382cea66aca3999379c231e89c2aa76dd
  • Pointer size: 131 Bytes
  • Size of remote file: 245 kB
samples/unet_1.3b_640x512_0.jpg CHANGED

Git LFS Details

  • SHA256: 109118912ffdcfd57bdedbddf261ea4c7b3a01f2425cef9e398246f0f67cede2
  • Pointer size: 131 Bytes
  • Size of remote file: 860 kB

Git LFS Details

  • SHA256: 0437ed62213f8b2a86d0cc4de83d101406887c25104772be6040ada6bb41bc1b
  • Pointer size: 131 Bytes
  • Size of remote file: 847 kB
samples/unet_1.3b_640x544_0.jpg CHANGED

Git LFS Details

  • SHA256: 1408b7978a31d7d771bafdcf652cb03a10e846589c542978759ce9b3a4909618
  • Pointer size: 131 Bytes
  • Size of remote file: 459 kB

Git LFS Details

  • SHA256: a0faf7c5ed5fd226c26e97334dfcd40c518224eba924d1c00c243d7b23e06acd
  • Pointer size: 131 Bytes
  • Size of remote file: 436 kB
samples/unet_1.3b_640x576_0.jpg CHANGED

Git LFS Details

  • SHA256: c2efea377d0a45508609cb0f5719f8584a19859335b4d993b2e1d69ae0a5fa52
  • Pointer size: 131 Bytes
  • Size of remote file: 183 kB

Git LFS Details

  • SHA256: 67f3809caaa1cd45f7dc70d7ada93cab1d0c68ea0b128868be8fdf1b6013ccd7
  • Pointer size: 131 Bytes
  • Size of remote file: 183 kB
samples/unet_1.3b_640x608_0.jpg CHANGED

Git LFS Details

  • SHA256: 312add8b7469110aac5224b878b510dc4c5368ab47e49435e029150b8bd432d0
  • Pointer size: 131 Bytes
  • Size of remote file: 227 kB

Git LFS Details

  • SHA256: 19fa2ace6c58a9f7e2902005d080530999deef35c9d04d4194e634454879b478
  • Pointer size: 131 Bytes
  • Size of remote file: 223 kB
samples/unet_1.3b_640x640_0.jpg CHANGED

Git LFS Details

  • SHA256: e778921a9ebb2674a7139ac4417f3c238c673df469cbbe7a82c4f37bfda5ca89
  • Pointer size: 131 Bytes
  • Size of remote file: 514 kB

Git LFS Details

  • SHA256: ab410b747944754174039e1b8148b6fd38f9a553a8fb0fa74d1b95d07e885ca7
  • Pointer size: 131 Bytes
  • Size of remote file: 520 kB
test.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:28095b7ec189c24ef448aaf1f3e44fdf347533348d04d187f46b9599ff83a06c
3
- size 5721667
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:632553d43b6346999891e3f671bbf197a6f4e1bca39362387e48e3eb6f357630
3
+ size 5305844
unet/config.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4a93a03ab94ebcdad5427326dae7fabbc74a7f46dca4a3804c3e5c11e667ff7e
3
- size 1848
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18ea6e42455fff8208c1a900f1e343224198079dc0814f90d4ff283209c3924a
3
+ size 1843
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d91907509f145bade2cea2809952d4bdfcbef609efb23ec606e5b9bb10239d2f
3
- size 5163635688
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e342af4061c1f6cce4214d23962c1033cf61f2cfee942c42661f72dc29b179a3
3
+ size 2581889304
unet_1.3b/config.json DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:4a93a03ab94ebcdad5427326dae7fabbc74a7f46dca4a3804c3e5c11e667ff7e
3
- size 1848
 
 
 
 
unet_1.3b/diffusion_pytorch_model.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0acc47c6590d2347959a9338ef56fefbd59b8b09b44bdf8db019d700e4cb3bef
3
- size 5163635688
 
 
 
 
vae/.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
vae/config.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:aca6cfc9bfab2f27ae086c52df8bc1d7b4ebede07c3d6d1b29c47e39cca1b753
3
- size 752
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:12d252abeac321629cb81b908a6b49d1bf8d7f60247e00b6be83fd03b0f98b39
3
+ size 852
vae/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b6342c154910ff1065c25915b107ab1326f146c44055f3a4119ffb95f5159a4f
3
- size 382598708
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:19bd4d341b7cc8d20893e2f257760c8b964bad447a243263191c4be1c89c1aaf
3
+ size 427466716
vae/train_vae_fdl.py ADDED
@@ -0,0 +1,624 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import math
4
+ import re
5
+ import torch
6
+ import numpy as np
7
+ import random
8
+ import gc
9
+ from datetime import datetime
10
+ from pathlib import Path
11
+
12
+ import torchvision.transforms as transforms
13
+ import torch.nn.functional as F
14
+ from torch.utils.data import DataLoader, Dataset
15
+ from torch.optim.lr_scheduler import LambdaLR
16
+ from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
17
+ # QWEN: импорт класса
18
+ from diffusers import AutoencoderKLQwenImage
19
+ from diffusers import AutoencoderKLWan
20
+
21
+ from accelerate import Accelerator
22
+ from PIL import Image, UnidentifiedImageError
23
+ from tqdm import tqdm
24
+ import bitsandbytes as bnb
25
+ import wandb
26
+ import lpips # pip install lpips
27
+ from FDL_pytorch import FDL_loss # pip install fdl-pytorch
28
+ from collections import deque
29
+
30
+ # --------------------------- Параметры ---------------------------
31
+ ds_path = "/workspace/d23"
32
+ project = "vae10"
33
+ batch_size = 1
34
+ base_learning_rate = 6e-6
35
+ min_learning_rate = 7e-7
36
+ num_epochs = 2
37
+ sample_interval_share = 25
38
+ use_wandb = True
39
+ save_model = True
40
+ use_decay = True
41
+ optimizer_type = "adam8bit"
42
+ dtype = torch.float32
43
+
44
+ model_resolution = 512 #288
45
+ high_resolution = 1024 #576
46
+ limit = 0
47
+ save_barrier = 1.3
48
+ warmup_percent = 0.005
49
+ percentile_clipping = 99
50
+ beta2 = 0.997
51
+ eps = 1e-8
52
+ clip_grad_norm = 1.0
53
+ mixed_precision = "no"
54
+ gradient_accumulation_steps = 1
55
+ generated_folder = "samples"
56
+ save_as = "vae10"
57
+ num_workers = 0
58
+ device = None
59
+ torch.backends.cuda.matmul.allow_tf32 = True
60
+ torch.backends.cudnn.allow_tf32 = True
61
+ # Включение Flash Attention 2/SDPA #MAX_JOBS=4 pip install flash-attn --no-build-isolation
62
+ torch.backends.cuda.enable_flash_sdp(True)
63
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
64
+ torch.backends.cuda.enable_math_sdp(False)
65
+
66
+ # --- Режимы обучения ---
67
+ # QWEN: учим только декодер
68
+ train_decoder_only = True
69
+ train_up_only = False
70
+ full_training = False # если True — учим весь VAE и добавляем KL (ниже)
71
+ kl_ratio = 0.00
72
+
73
+ # Доли лоссов
74
+ loss_ratios = {
75
+ "lpips": 0.70,#0.50,
76
+ "fdl" : 0.10,#0.25,
77
+ "edge": 0.05,
78
+ "mse": 0.10,
79
+ "mae": 0.05,
80
+ "kl": 0.00, # активируем при full_training=True
81
+ }
82
+ median_coeff_steps = 250
83
+
84
+ resize_long_side = 1280 # ресайз длинной стороны исходных картинок
85
+
86
+ # QWEN: конфиг загрузки модели
87
+ vae_kind = "kl" # "qwen" или "kl" (обычный)
88
+
89
+ Path(generated_folder).mkdir(parents=True, exist_ok=True)
90
+
91
+ accelerator = Accelerator(
92
+ mixed_precision=mixed_precision,
93
+ gradient_accumulation_steps=gradient_accumulation_steps
94
+ )
95
+ device = accelerator.device
96
+
97
+ # reproducibility
98
+ seed = int(datetime.now().strftime("%Y%m%d")) + 13
99
+ torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
100
+ torch.backends.cudnn.benchmark = False
101
+
102
+ # --------------------------- WandB ---------------------------
103
+ if use_wandb and accelerator.is_main_process:
104
+ wandb.init(project=project, config={
105
+ "batch_size": batch_size,
106
+ "base_learning_rate": base_learning_rate,
107
+ "num_epochs": num_epochs,
108
+ "optimizer_type": optimizer_type,
109
+ "model_resolution": model_resolution,
110
+ "high_resolution": high_resolution,
111
+ "gradient_accumulation_steps": gradient_accumulation_steps,
112
+ "train_decoder_only": train_decoder_only,
113
+ "full_training": full_training,
114
+ "kl_ratio": kl_ratio,
115
+ "vae_kind": vae_kind,
116
+ })
117
+
118
+ # --------------------------- VAE ---------------------------
119
+ def get_core_model(model):
120
+ m = model
121
+ # если модель уже обёрнута torch.compile
122
+ if hasattr(m, "_orig_mod"):
123
+ m = m._orig_mod
124
+ return m
125
+
126
+ def is_video_vae(model) -> bool:
127
+ # WAN/Qwen — это видео-VAEs
128
+ if vae_kind in ("wan", "qwen"):
129
+ return True
130
+ # fallback по структуре (если понадобится)
131
+ try:
132
+ core = get_core_model(model)
133
+ enc = getattr(core, "encoder", None)
134
+ conv_in = getattr(enc, "conv_in", None)
135
+ w = getattr(conv_in, "weight", None)
136
+ if isinstance(w, torch.nn.Parameter):
137
+ return w.ndim == 5
138
+ except Exception:
139
+ pass
140
+ return False
141
+
142
+ # загрузка
143
+ if vae_kind == "qwen":
144
+ vae = AutoencoderKLQwenImage.from_pretrained("Qwen/Qwen-Image", subfolder="vae")
145
+ else:
146
+ if vae_kind == "wan":
147
+ vae = AutoencoderKLWan.from_pretrained(project)
148
+ else:
149
+ # старое поведение (пример)
150
+ if model_resolution==high_resolution:
151
+ vae = AutoencoderKL.from_pretrained(project)
152
+ else:
153
+ vae = AsymmetricAutoencoderKL.from_pretrained(project)
154
+
155
+ vae = vae.to(dtype)
156
+
157
+ # torch.compile (опцион��льно)
158
+ if hasattr(torch, "compile"):
159
+ try:
160
+ vae = torch.compile(vae)
161
+ except Exception as e:
162
+ print(f"[WARN] torch.compile failed: {e}")
163
+
164
+ # --------------------------- Freeze/Unfreeze ---------------------------
165
+ core = get_core_model(vae)
166
+
167
+ for p in core.parameters():
168
+ p.requires_grad = False
169
+
170
+ unfrozen_param_names = []
171
+
172
+ if full_training and not train_decoder_only:
173
+ for name, p in core.named_parameters():
174
+ p.requires_grad = True
175
+ unfrozen_param_names.append(name)
176
+ loss_ratios["kl"] = float(kl_ratio)
177
+ trainable_module = core
178
+ else:
179
+ # учим только 0-й блок декодера + post_quant_conv
180
+ if hasattr(core, "decoder"):
181
+ if train_up_only:#hasattr(core.decoder, "up_blocks") and len(core.decoder.up_blocks) > 0:
182
+ # --- только 0-й up_block ---
183
+ for name, p in core.decoder.up_blocks[0].named_parameters():
184
+ p.requires_grad = True
185
+ unfrozen_param_names.append(f"{name}")
186
+ else:
187
+ print("Decoder — fallback to full decoder")
188
+ for name, p in core.decoder.named_parameters():
189
+ p.requires_grad = True
190
+ unfrozen_param_names.append(f"decoder.{name}")
191
+ if hasattr(core, "post_quant_conv"):
192
+ for name, p in core.post_quant_conv.named_parameters():
193
+ p.requires_grad = True
194
+ unfrozen_param_names.append(f"post_quant_conv.{name}")
195
+ trainable_module = core.decoder if hasattr(core, "decoder") else core
196
+
197
+
198
+ print(f"[INFO] Разморожено параметров: {len(unfrozen_param_names)}. Первые 200 имён:")
199
+ for nm in unfrozen_param_names[:200]:
200
+ print(" ", nm)
201
+
202
+ # --------------------------- Датасет ---------------------------
203
+ class PngFolderDataset(Dataset):
204
+ def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0):
205
+ self.root_dir = root_dir
206
+ self.resolution = resolution
207
+ self.paths = []
208
+ for root, _, files in os.walk(root_dir):
209
+ for fname in files:
210
+ if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)):
211
+ self.paths.append(os.path.join(root, fname))
212
+ if limit:
213
+ self.paths = self.paths[:limit]
214
+ valid = []
215
+ for p in self.paths:
216
+ try:
217
+ with Image.open(p) as im:
218
+ im.verify()
219
+ valid.append(p)
220
+ except (OSError, UnidentifiedImageError):
221
+ continue
222
+ self.paths = valid
223
+ if len(self.paths) == 0:
224
+ raise RuntimeError(f"No valid PNG images found under {root_dir}")
225
+ random.shuffle(self.paths)
226
+
227
+ def __len__(self):
228
+ return len(self.paths)
229
+
230
+ def __getitem__(self, idx):
231
+ p = self.paths[idx % len(self.paths)]
232
+ with Image.open(p) as img:
233
+ img = img.convert("RGB")
234
+ if not resize_long_side or resize_long_side <= 0:
235
+ return img
236
+ w, h = img.size
237
+ long = max(w, h)
238
+ if long <= resize_long_side:
239
+ return img
240
+ scale = resize_long_side / float(long)
241
+ new_w = int(round(w * scale))
242
+ new_h = int(round(h * scale))
243
+ return img.resize((new_w, new_h), Image.BICUBIC)
244
+
245
+ def random_crop(img, sz):
246
+ w, h = img.size
247
+ if w < sz or h < sz:
248
+ img = img.resize((max(sz, w), max(sz, h)), Image.BICUBIC)
249
+ x = random.randint(0, max(1, img.width - sz))
250
+ y = random.randint(0, max(1, img.height - sz))
251
+ return img.crop((x, y, x + sz, y + sz))
252
+
253
+ tfm = transforms.Compose([
254
+ transforms.ToTensor(),
255
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
256
+ ])
257
+
258
+ dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
259
+ print("len(dataset)",len(dataset))
260
+ if len(dataset) < batch_size:
261
+ raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")
262
+
263
+ def collate_fn(batch):
264
+ imgs = []
265
+ for img in batch:
266
+ img = random_crop(img, high_resolution)
267
+ imgs.append(tfm(img))
268
+ return torch.stack(imgs)
269
+
270
+ dataloader = DataLoader(
271
+ dataset,
272
+ batch_size=batch_size,
273
+ shuffle=True,
274
+ collate_fn=collate_fn,
275
+ num_workers=num_workers,
276
+ pin_memory=True,
277
+ drop_last=True
278
+ )
279
+
280
+ # --------------------------- Оптимизатор ---------------------------
281
+ def get_param_groups(module, weight_decay=0.001):
282
+ no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "ln_1.weight", "ln_f.weight"]
283
+ decay_params, no_decay_params = [], []
284
+ for n, p in vae.named_parameters(): # глобально по vae, с фильтром requires_grad
285
+ if not p.requires_grad:
286
+ continue
287
+ if any(nd in n for nd in no_decay):
288
+ no_decay_params.append(p)
289
+ else:
290
+ decay_params.append(p)
291
+ return [
292
+ {"params": decay_params, "weight_decay": weight_decay},
293
+ {"params": no_decay_params, "weight_decay": 0.0},
294
+ ]
295
+
296
+ def get_param_groups(module, weight_decay=0.001):
297
+ no_decay_tokens = ("bias", "norm", "rms", "layernorm")
298
+ decay_params, no_decay_params = [], []
299
+ for n, p in module.named_parameters():
300
+ if not p.requires_grad:
301
+ continue
302
+ n_l = n.lower()
303
+ if any(t in n_l for t in no_decay_tokens):
304
+ no_decay_params.append(p)
305
+ else:
306
+ decay_params.append(p)
307
+ return [
308
+ {"params": decay_params, "weight_decay": weight_decay},
309
+ {"params": no_decay_params, "weight_decay": 0.0},
310
+ ]
311
+
312
+ def create_optimizer(name, param_groups):
313
+ if name == "adam8bit":
314
+ return bnb.optim.AdamW8bit(param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps)
315
+ raise ValueError(name)
316
+
317
+ param_groups = get_param_groups(get_core_model(vae), weight_decay=0.001)
318
+ optimizer = create_optimizer(optimizer_type, param_groups)
319
+
320
+ # --------------------------- LR schedule ---------------------------
321
+ batches_per_epoch = len(dataloader)
322
+ steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps)))
323
+ total_steps = steps_per_epoch * num_epochs
324
+
325
+ def lr_lambda(step):
326
+ if not use_decay:
327
+ return 1.0
328
+ x = float(step) / float(max(1, total_steps))
329
+ warmup = float(warmup_percent)
330
+ min_ratio = float(min_learning_rate) / float(base_learning_rate)
331
+ if x < warmup:
332
+ return min_ratio + (1.0 - min_ratio) * (x / warmup)
333
+ decay_ratio = (x - warmup) / (1.0 - warmup)
334
+ return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay_ratio))
335
+
336
+ scheduler = LambdaLR(optimizer, lr_lambda)
337
+
338
+ # Подготовка
339
+ dataloader, vae, optimizer, scheduler = accelerator.prepare(dataloader, vae, optimizer, scheduler)
340
+ trainable_params = [p for p in vae.parameters() if p.requires_grad]
341
+
342
+ # fdl
343
+ fdl_loss = FDL_loss()
344
+ fdl_loss = fdl_loss.to(accelerator.device)
345
+
346
+ # --------------------------- LPIPS и вспомогательные ---------------------------
347
+ _lpips_net = None
348
+ def _get_lpips():
349
+ global _lpips_net
350
+ if _lpips_net is None:
351
+ _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device).eval()
352
+ return _lpips_net
353
+
354
+ _sobel_kx = torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]], dtype=torch.float32)
355
+ _sobel_ky = torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]], dtype=torch.float32)
356
+ def sobel_edges(x: torch.Tensor) -> torch.Tensor:
357
+ C = x.shape[1]
358
+ kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
359
+ ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
360
+ gx = F.conv2d(x, kx, padding=1, groups=C)
361
+ gy = F.conv2d(x, ky, padding=1, groups=C)
362
+ return torch.sqrt(gx * gx + gy * gy + 1e-12)
363
+
364
+ class MedianLossNormalizer:
365
+ def __init__(self, desired_ratios: dict, window_steps: int):
366
+ s = sum(desired_ratios.values())
367
+ self.ratios = {k: (v / s) if s > 0 else 0.0 for k, v in desired_ratios.items()}
368
+ self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
369
+ self.window = window_steps
370
+
371
+ def update_and_total(self, abs_losses: dict):
372
+ for k, v in abs_losses.items():
373
+ if k in self.buffers:
374
+ self.buffers[k].append(float(v.detach().abs().cpu()))
375
+ meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
376
+ coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
377
+ total = sum(coeffs[k] * abs_losses[k] for k in abs_losses if k in coeffs)
378
+ return total, coeffs, meds
379
+
380
+ if full_training and not train_decoder_only:
381
+ loss_ratios["kl"] = float(kl_ratio)
382
+ normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
383
+
384
+ # --------------------------- Сэмплы ---------------------------
385
+ @torch.no_grad()
386
+ def get_fixed_samples(n=3):
387
+ idx = random.sample(range(len(dataset)), min(n, len(dataset)))
388
+ pil_imgs = [dataset[i] for i in idx]
389
+ tensors = []
390
+ for img in pil_imgs:
391
+ img = random_crop(img, high_resolution)
392
+ tensors.append(tfm(img))
393
+ return torch.stack(tensors).to(accelerator.device, dtype)
394
+
395
+ fixed_samples = get_fixed_samples()
396
+
397
+ @torch.no_grad()
398
+ def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
399
+ arr = ((img_tensor.float().clamp(-1, 1) + 1.0) * 127.5).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
400
+ return Image.fromarray(arr)
401
+
402
+
403
+ @torch.no_grad()
404
+ def generate_and_save_samples(step=None):
405
+ try:
406
+ #temp_vae = accelerator.unwrap_model(vae).eval()
407
+ if hasattr(vae, "module"):
408
+ # Если это DDP или DistributedDataParallel
409
+ unwrapped_vae = vae.module
410
+ else:
411
+ unwrapped_vae = vae
412
+
413
+ # Если использовался torch.compile, достаем оригинал
414
+ if hasattr(unwrapped_vae, "_orig_mod"):
415
+ temp_vae = unwrapped_vae._orig_mod
416
+ else:
417
+ temp_vae = unwrapped_vae
418
+
419
+ temp_vae = temp_vae.eval()
420
+ lpips_net = _get_lpips()
421
+ with torch.no_grad():
422
+ orig_high = fixed_samples
423
+ orig_low = F.interpolate(
424
+ orig_high,
425
+ size=(model_resolution, model_resolution),
426
+ mode="bilinear",
427
+ align_corners=False
428
+ )
429
+ model_dtype = next(temp_vae.parameters()).dtype
430
+ orig_low = orig_low.to(dtype=model_dtype)
431
+
432
+ # Encode/decode с учётом видео-режима
433
+ if is_video_vae(temp_vae):
434
+ x_in = orig_low.unsqueeze(2) # [B,3,1,H,W]
435
+ enc = temp_vae.encode(x_in)
436
+ latents_mean = enc.latent_dist.mean
437
+ dec = temp_vae.decode(latents_mean).sample # [B,3,1,H,W]
438
+ rec = dec.squeeze(2) # [B,3,H,W]
439
+ else:
440
+ enc = temp_vae.encode(orig_low)
441
+ latents_mean = enc.latent_dist.mean
442
+ rec = temp_vae.decode(latents_mean).sample
443
+
444
+ # Подгон размеров, если надо
445
+ #if rec.shape[-2:] != orig_high.shape[-2:]:
446
+ # rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)
447
+
448
+ # Сохраняем все real/decoded
449
+ for i in range(rec.shape[0]):
450
+ real_img = _to_pil_uint8(orig_high[i])
451
+ dec_img = _to_pil_uint8(rec[i])
452
+ real_img.save(f"{generated_folder}/sample_real_{i}.png")
453
+ dec_img.save(f"{generated_folder}/sample_decoded_{i}.png")
454
+
455
+ # LPIPS
456
+ lpips_scores = []
457
+ for i in range(rec.shape[0]):
458
+ orig_full = orig_high[i:i+1].to(torch.float32)
459
+ rec_full = rec[i:i+1].to(torch.float32)
460
+ #if rec_full.shape[-2:] != orig_full.shape[-2:]:
461
+ # rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
462
+ lpips_val = lpips_net(orig_full, rec_full).item()
463
+ lpips_scores.append(lpips_val)
464
+ avg_lpips = float(np.mean(lpips_scores))
465
+
466
+ # W&B логирование
467
+ if use_wandb and accelerator.is_main_process:
468
+ log_data = {"lpips_mean": avg_lpips}
469
+ for i in range(rec.shape[0]):
470
+ log_data[f"sample/real_{i}"] = wandb.Image(f"{generated_folder}/sample_real_{i}.png", caption=f"real_{i}")
471
+ log_data[f"sample/decoded_{i}"] = wandb.Image(f"{generated_folder}/sample_decoded_{i}.png", caption=f"decoded_{i}")
472
+ wandb.log(log_data, step=step)
473
+
474
+ finally:
475
+ gc.collect()
476
+ torch.cuda.empty_cache()
477
+
478
+
479
+ if accelerator.is_main_process and save_model:
480
+ print("Генерация сэмплов до старта обучения...")
481
+ generate_and_save_samples(0)
482
+
483
+ accelerator.wait_for_everyone()
484
+
485
+ # --------------------------- Тренировка ---------------------------
486
+ progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
487
+ global_step = 0
488
+ min_loss = float("inf")
489
+ sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs))
490
+
491
+ for epoch in range(num_epochs):
492
+ vae.train()
493
+ batch_losses, batch_grads = [], []
494
+ track_losses = {k: [] for k in loss_ratios.keys()}
495
+
496
+ for imgs in dataloader:
497
+ with accelerator.accumulate(vae):
498
+ imgs = imgs.to(accelerator.device)
499
+
500
+ if high_resolution != model_resolution:
501
+ imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution),mode="area") # mode="bilinear", align_corners=False)
502
+ else:
503
+ imgs_low = imgs
504
+
505
+ model_dtype = next(vae.parameters()).dtype
506
+ imgs_low_model = imgs_low.to(dtype=model_dtype) if imgs_low.dtype != model_dtype else imgs_low
507
+
508
+ # Вместо: current_vae = accelerator.unwrap_model(vae)
509
+ unwrapped = vae.module if hasattr(vae, "module") else vae
510
+ current_vae = getattr(unwrapped, "_orig_mod", unwrapped)
511
+
512
+
513
+ # QWEN: encode/decode с T=1
514
+ if is_video_vae(current_vae):
515
+ x_in = imgs_low_model.unsqueeze(2) # [B,3,1,H,W]
516
+ enc = current_vae.encode(x_in)
517
+ latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
518
+ dec = current_vae.decode(latents).sample # [B,3,1,H,W]
519
+ rec = dec.squeeze(2) # [B,3,H,W]
520
+ else:
521
+ enc = current_vae.encode(imgs_low_model)
522
+ latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
523
+ rec = current_vae.decode(latents).sample
524
+
525
+ #if rec.shape[-2:] != imgs.shape[-2:]:
526
+ # rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False)
527
+
528
+ rec_f32 = rec.to(torch.float32)
529
+ imgs_f32 = imgs.to(torch.float32)
530
+
531
+ abs_losses = {
532
+ "mae": F.l1_loss(rec_f32, imgs_f32),
533
+ "mse": F.mse_loss(rec_f32, imgs_f32),
534
+ "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
535
+ "fdl": fdl_loss(rec_f32, imgs_f32),
536
+ "edge": F.l1_loss(sobel_edges(rec_f32), sobel_edges(imgs_f32)),
537
+ }
538
+
539
+ if full_training and not train_decoder_only:
540
+ mean = enc.latent_dist.mean
541
+ logvar = enc.latent_dist.logvar
542
+ kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
543
+ abs_losses["kl"] = kl
544
+ else:
545
+ abs_losses["kl"] = torch.tensor(0.0, device=accelerator.device, dtype=torch.float32)
546
+
547
+ total_loss, coeffs, meds = normalizer.update_and_total(abs_losses)
548
+
549
+ if torch.isnan(total_loss) or torch.isinf(total_loss):
550
+ raise RuntimeError("NaN/Inf loss")
551
+
552
+ accelerator.backward(total_loss)
553
+
554
+ grad_norm = torch.tensor(0.0, device=accelerator.device)
555
+ if accelerator.sync_gradients:
556
+ grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm)
557
+ optimizer.step()
558
+ scheduler.step()
559
+ optimizer.zero_grad(set_to_none=True)
560
+ global_step += 1
561
+ progress.update(1)
562
+
563
+ if accelerator.is_main_process:
564
+ try:
565
+ current_lr = optimizer.param_groups[0]["lr"]
566
+ except Exception:
567
+ current_lr = scheduler.get_last_lr()[0]
568
+
569
+ batch_losses.append(total_loss.detach().item())
570
+ batch_grads.append(float(grad_norm.detach().cpu().item()) if isinstance(grad_norm, torch.Tensor) else float(grad_norm))
571
+ for k, v in abs_losses.items():
572
+ track_losses[k].append(float(v.detach().item()))
573
+
574
+ if use_wandb and accelerator.sync_gradients:
575
+ log_dict = {
576
+ "total_loss": float(total_loss.detach().item()),
577
+ "learning_rate": current_lr,
578
+ "epoch": epoch,
579
+ "grad_norm": batch_grads[-1],
580
+ }
581
+ for k, v in abs_losses.items():
582
+ log_dict[f"loss_{k}"] = float(v.detach().item())
583
+ for k in coeffs:
584
+ log_dict[f"coeff_{k}"] = float(coeffs[k])
585
+ log_dict[f"median_{k}"] = float(meds[k])
586
+ wandb.log(log_dict, step=global_step)
587
+
588
+ if global_step > 0 and global_step % sample_interval == 0:
589
+ if accelerator.is_main_process:
590
+ generate_and_save_samples(global_step)
591
+ accelerator.wait_for_everyone()
592
+
593
+ n_micro = sample_interval * gradient_accumulation_steps
594
+ avg_loss = float(np.mean(batch_losses[-n_micro:])) if len(batch_losses) >= n_micro else float(np.mean(batch_losses)) if batch_losses else float("nan")
595
+ avg_grad = float(np.mean(batch_grads[-n_micro:])) if len(batch_grads) >= 1 else float(np.mean(batch_grads)) if batch_grads else 0.0
596
+
597
+ if accelerator.is_main_process:
598
+ print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}")
599
+ if save_model and avg_loss < min_loss * save_barrier:
600
+ min_loss = avg_loss
601
+ unwrapped = vae.module if hasattr(vae, "module") else vae
602
+ current_vae = getattr(unwrapped, "_orig_mod", unwrapped)
603
+ current_vae.save_pretrained(save_as)
604
+ if use_wandb:
605
+ wandb.log({"interm_loss": avg_loss, "interm_grad": avg_grad}, step=global_step)
606
+
607
+ if accelerator.is_main_process:
608
+ epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan")
609
+ print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}")
610
+ if use_wandb:
611
+ wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step)
612
+
613
+ # --------------------------- Финальное сохранение ---------------------------
614
+ if accelerator.is_main_process:
615
+ print("Training finished – saving final model")
616
+ if save_model:
617
+ unwrapped = vae.module if hasattr(vae, "module") else vae
618
+ current_vae = getattr(unwrapped, "_orig_mod", unwrapped)
619
+ current_vae.save_pretrained(save_as)
620
+
621
+ accelerator.free_memory()
622
+ if torch.distributed.is_initialized():
623
+ torch.distributed.destroy_process_group()
624
+ print("Готово!")