1.3b
Browse files- README.md +2 -2
- unet1.3b.ipynb → media/girl.jpg +2 -2
- media/result_grid.jpg +2 -2
- media/result_grid_base.jpg +0 -3
- pipeline_sdxs-Copy1.py +0 -335
- pipeline_sdxs.py +1 -1
- samples/unet_1.3b_320x640_0.jpg +2 -2
- samples/unet_1.3b_352x640_0.jpg +2 -2
- samples/unet_1.3b_384x640_0.jpg +2 -2
- samples/unet_1.3b_416x640_0.jpg +2 -2
- samples/unet_1.3b_448x640_0.jpg +2 -2
- samples/unet_1.3b_480x640_0.jpg +2 -2
- samples/unet_1.3b_512x640_0.jpg +2 -2
- samples/unet_1.3b_544x640_0.jpg +2 -2
- samples/unet_1.3b_576x640_0.jpg +2 -2
- samples/unet_1.3b_608x640_0.jpg +2 -2
- samples/unet_1.3b_640x320_0.jpg +2 -2
- samples/unet_1.3b_640x352_0.jpg +2 -2
- samples/unet_1.3b_640x384_0.jpg +2 -2
- samples/unet_1.3b_640x416_0.jpg +2 -2
- samples/unet_1.3b_640x448_0.jpg +2 -2
- samples/unet_1.3b_640x480_0.jpg +2 -2
- samples/unet_1.3b_640x512_0.jpg +2 -2
- samples/unet_1.3b_640x544_0.jpg +2 -2
- samples/unet_1.3b_640x576_0.jpg +2 -2
- samples/unet_1.3b_640x608_0.jpg +2 -2
- samples/unet_1.3b_640x640_0.jpg +2 -2
- test.ipynb +2 -2
- unet/config.json +2 -2
- unet/diffusion_pytorch_model.safetensors +2 -2
- unet_1.3b/config.json +0 -3
- unet_1.3b/diffusion_pytorch_model.safetensors +0 -3
- vae/.gitattributes +35 -0
- vae/config.json +2 -2
- vae/diffusion_pytorch_model.safetensors +2 -2
- vae/train_vae_fdl.py +624 -0
README.md
CHANGED
|
@@ -8,11 +8,11 @@ datasets:
|
|
| 8 |
# Simple Diffusion XS
|
| 9 |
|
| 10 |
*XS Size, Excess Quality*
|
| 11 |
-

|
| 17 |
- VAE: 16ch16x(8x-enc/16x-dec)
|
| 18 |
- Speed: Sampling 100%|██████████| 40/40 [00:01<00:00, 30.72it/s] (1024x1280)
|
|
|
|
| 8 |
# Simple Diffusion XS
|
| 9 |
|
| 10 |
*XS Size, Excess Quality*
|
| 11 |
+

|
| 12 |
|
| 13 |
At AiArtLab, we strive to create a free, compact and fast model that can be trained on consumer graphics cards.
|
| 14 |
|
| 15 |
+
- Unet: 1.3b parameters
|
| 16 |
- Clip: [LongCLIP with 248 tokens](https://huggingface.co/zer0int/CLIP-KO-LITE-TypoAttack-Attn-Dropout-ViT-L-14)
|
| 17 |
- VAE: 16ch16x(8x-enc/16x-dec)
|
| 18 |
- Speed: Sampling 100%|██████████| 40/40 [00:01<00:00, 30.72it/s] (1024x1280)
|
unet1.3b.ipynb → media/girl.jpg
RENAMED
|
File without changes
|
media/result_grid.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
media/result_grid_base.jpg
DELETED
Git LFS Details
|
pipeline_sdxs-Copy1.py
DELETED
|
@@ -1,335 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import numpy as np
|
| 3 |
-
from PIL import Image
|
| 4 |
-
from typing import List, Union, Optional, Tuple
|
| 5 |
-
from dataclasses import dataclass
|
| 6 |
-
|
| 7 |
-
from diffusers import DiffusionPipeline
|
| 8 |
-
from diffusers.utils import BaseOutput
|
| 9 |
-
from tqdm import tqdm
|
| 10 |
-
|
| 11 |
-
@dataclass
|
| 12 |
-
class SdxsPipelineOutput(BaseOutput):
|
| 13 |
-
images: Union[List[Image.Image], np.ndarray]
|
| 14 |
-
|
| 15 |
-
class SdxsPipeline(DiffusionPipeline):
|
| 16 |
-
def __init__(self, vae, text_encoder, tokenizer, unet, scheduler):
|
| 17 |
-
super().__init__()
|
| 18 |
-
self.register_modules(
|
| 19 |
-
vae=vae,
|
| 20 |
-
text_encoder=text_encoder,
|
| 21 |
-
tokenizer=tokenizer,
|
| 22 |
-
unet=unet,
|
| 23 |
-
scheduler=scheduler
|
| 24 |
-
)
|
| 25 |
-
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
| 26 |
-
|
| 27 |
-
def create_frequency_soft_cutoff_mask(self, height: int, width: int, cutoff_radius: float,
|
| 28 |
-
transition_width: float = 5.0, device: torch.device = None) -> torch.Tensor:
|
| 29 |
-
"""Создает плавную маску частотного среза для сохранения структуры."""
|
| 30 |
-
if device is None:
|
| 31 |
-
device = torch.device('cpu')
|
| 32 |
-
|
| 33 |
-
u = torch.arange(height, device=device)
|
| 34 |
-
v = torch.arange(width, device=device)
|
| 35 |
-
u, v = torch.meshgrid(u, v, indexing='ij')
|
| 36 |
-
|
| 37 |
-
center_u, center_v = height // 2, width // 2
|
| 38 |
-
frequency_radius = torch.sqrt((u - center_u)**2 + (v - center_v)**2)
|
| 39 |
-
|
| 40 |
-
mask = torch.exp(-(frequency_radius - cutoff_radius)**2 / (2 * transition_width**2))
|
| 41 |
-
mask = torch.where(frequency_radius <= cutoff_radius, torch.ones_like(mask), mask)
|
| 42 |
-
|
| 43 |
-
return mask
|
| 44 |
-
|
| 45 |
-
def generate_structured_noise(
|
| 46 |
-
self,
|
| 47 |
-
image_latents: torch.Tensor,
|
| 48 |
-
cutoff_radius: Optional[float] = None,
|
| 49 |
-
transition_width: float = 2.0,
|
| 50 |
-
noise_std: float = 1.0,
|
| 51 |
-
) -> torch.Tensor:
|
| 52 |
-
"""
|
| 53 |
-
Генерирует структурированный шум для латентов с сохранением низкочастотной структуры.
|
| 54 |
-
|
| 55 |
-
Args:
|
| 56 |
-
image_latents: Чистые латенты изображения [B, C, H, W]
|
| 57 |
-
cutoff_radius: Радиус среза частот (None = авто-расчет на основе coef)
|
| 58 |
-
transition_width: Ширина плавного перехода
|
| 59 |
-
noise_std: Стандартное отклонение шума
|
| 60 |
-
|
| 61 |
-
Returns:
|
| 62 |
-
Структурированный шум с той же размерностью
|
| 63 |
-
"""
|
| 64 |
-
batch_size, channels, height, width = image_latents.shape
|
| 65 |
-
device = image_latents.device
|
| 66 |
-
dtype = image_latents.dtype
|
| 67 |
-
|
| 68 |
-
# Автоматический расчет cutoff_radius если не задан
|
| 69 |
-
if cutoff_radius is None:
|
| 70 |
-
# Сохраняем больше низких частот для лучшей структуры
|
| 71 |
-
max_radius = min(height, width) / 2
|
| 72 |
-
cutoff_radius = max_radius * 0.7 # Сохраняем 70% низких частот
|
| 73 |
-
|
| 74 |
-
# Создаем частотную маску
|
| 75 |
-
freq_mask = self.create_frequency_soft_cutoff_mask(
|
| 76 |
-
height, width, cutoff_radius, transition_width, device
|
| 77 |
-
)
|
| 78 |
-
freq_mask = freq_mask.unsqueeze(0).unsqueeze(0) # [1, 1, H, W]
|
| 79 |
-
|
| 80 |
-
# Преобразуем латенты в частотную область
|
| 81 |
-
fft_image = torch.fft.fft2(image_latents, dim=(-2, -1))
|
| 82 |
-
fft_shifted = torch.fft.fftshift(fft_image, dim=(-2, -1))
|
| 83 |
-
|
| 84 |
-
# Извлекаем фазу изображения
|
| 85 |
-
image_phase = torch.angle(fft_shifted)
|
| 86 |
-
|
| 87 |
-
# Генерируем гауссовский шум
|
| 88 |
-
noise = torch.randn_like(image_latents) * noise_std
|
| 89 |
-
|
| 90 |
-
# Преобразуем шум в частотную область
|
| 91 |
-
fft_noise = torch.fft.fft2(noise, dim=(-2, -1))
|
| 92 |
-
fft_noise_shifted = torch.fft.fftshift(fft_noise, dim=(-2, -1))
|
| 93 |
-
|
| 94 |
-
# Извлекаем амплитуду шума
|
| 95 |
-
noise_magnitude = torch.abs(fft_noise_shifted)
|
| 96 |
-
noise_phase = torch.angle(fft_noise_shifted)
|
| 97 |
-
|
| 98 |
-
# Смешиваем фазы: низкие частоты - фаза изображения, высокие - фаза шума
|
| 99 |
-
mixed_phase = freq_mask * image_phase + (1 - freq_mask) * noise_phase
|
| 100 |
-
|
| 101 |
-
# Собираем обратно: амплитуда шума + смешанная фаза
|
| 102 |
-
fft_combined = noise_magnitude * torch.exp(1j * mixed_phase)
|
| 103 |
-
fft_unshifted = torch.fft.ifftshift(fft_combined, dim=(-2, -1))
|
| 104 |
-
|
| 105 |
-
# Обратное преобразование
|
| 106 |
-
structured_noise = torch.fft.ifft2(fft_unshifted, dim=(-2, -1))
|
| 107 |
-
structured_noise = torch.real(structured_noise)
|
| 108 |
-
|
| 109 |
-
# Нормализуе�� для сохранения статистики гауссовского шума
|
| 110 |
-
current_std = torch.std(structured_noise)
|
| 111 |
-
if current_std > 0:
|
| 112 |
-
structured_noise = structured_noise / current_std * noise_std
|
| 113 |
-
|
| 114 |
-
return structured_noise.to(dtype)
|
| 115 |
-
|
| 116 |
-
def preprocess_image(self, image: Image.Image, width: int, height: int):
|
| 117 |
-
"""Ресайз и центрированный кроп изображения для асимметричного VAE."""
|
| 118 |
-
# Для энкодера с масштабом 8
|
| 119 |
-
target_height = ((height // self.vae_scale_factor) * self.vae_scale_factor)//2
|
| 120 |
-
target_width = ((width // self.vae_scale_factor) * self.vae_scale_factor)//2
|
| 121 |
-
|
| 122 |
-
w, h = image.size
|
| 123 |
-
aspect_ratio = target_width / target_height
|
| 124 |
-
|
| 125 |
-
if w / h > aspect_ratio:
|
| 126 |
-
new_w = int(h * aspect_ratio)
|
| 127 |
-
left = (w - new_w) // 2
|
| 128 |
-
image = image.crop((left, 0, left + new_w, h))
|
| 129 |
-
else:
|
| 130 |
-
new_h = int(w / aspect_ratio)
|
| 131 |
-
top = (h - new_h) // 2
|
| 132 |
-
image = image.crop((0, top, w, top + new_h))
|
| 133 |
-
|
| 134 |
-
image = image.resize((target_width, target_height), resample=Image.LANCZOS)
|
| 135 |
-
image = np.array(image).astype(np.float32) / 255.0
|
| 136 |
-
image = image[None].transpose(0, 3, 1, 2) # [1, C, H, W]
|
| 137 |
-
image = torch.from_numpy(image)
|
| 138 |
-
return 2.0 * image - 1.0 # [-1, 1]
|
| 139 |
-
|
| 140 |
-
def encode_prompt(self, prompt, negative_prompt, device, dtype):
|
| 141 |
-
def get_single_encode(texts, is_negative=False):
|
| 142 |
-
if texts is None or texts == "":
|
| 143 |
-
hidden_dim = self.text_encoder.config.hidden_size
|
| 144 |
-
shape = (1, self.text_encoder.config.max_position_embeddings, hidden_dim)
|
| 145 |
-
emb = torch.zeros(shape, dtype=dtype, device=device)
|
| 146 |
-
mask = torch.ones((1, self.text_encoder.config.max_position_embeddings), dtype=torch.int64, device=device)
|
| 147 |
-
return emb, mask
|
| 148 |
-
|
| 149 |
-
if isinstance(texts, str):
|
| 150 |
-
texts = [texts]
|
| 151 |
-
|
| 152 |
-
with torch.no_grad():
|
| 153 |
-
toks = self.tokenizer(
|
| 154 |
-
texts,
|
| 155 |
-
padding="max_length",
|
| 156 |
-
max_length=self.text_encoder.config.max_position_embeddings,
|
| 157 |
-
truncation=True,
|
| 158 |
-
return_tensors="pt"
|
| 159 |
-
).to(device)
|
| 160 |
-
|
| 161 |
-
outputs = self.text_encoder(
|
| 162 |
-
input_ids=toks.input_ids,
|
| 163 |
-
attention_mask=toks.attention_mask,
|
| 164 |
-
output_hidden_states=True
|
| 165 |
-
)
|
| 166 |
-
|
| 167 |
-
layer_index = -2
|
| 168 |
-
prompt_embeds = outputs.hidden_states[layer_index]
|
| 169 |
-
final_layer_norm = self.text_encoder.text_model.final_layer_norm
|
| 170 |
-
prompt_embeds = final_layer_norm(prompt_embeds)
|
| 171 |
-
|
| 172 |
-
return prompt_embeds, toks.attention_mask
|
| 173 |
-
|
| 174 |
-
pos_embeds, pos_mask = get_single_encode(prompt)
|
| 175 |
-
neg_embeds, neg_mask = get_single_encode(negative_prompt, is_negative=True)
|
| 176 |
-
|
| 177 |
-
batch_size = pos_embeds.shape[0]
|
| 178 |
-
if neg_embeds.shape[0] != batch_size:
|
| 179 |
-
neg_embeds = neg_embeds.repeat(batch_size, 1, 1)
|
| 180 |
-
neg_mask = neg_mask.repeat(batch_size, 1)
|
| 181 |
-
|
| 182 |
-
text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
|
| 183 |
-
final_mask = torch.cat([neg_mask, pos_mask], dim=0)
|
| 184 |
-
|
| 185 |
-
return text_embeddings.to(dtype=dtype), final_mask.to(dtype=torch.int64)
|
| 186 |
-
|
| 187 |
-
@torch.no_grad()
|
| 188 |
-
def __call__(
|
| 189 |
-
self,
|
| 190 |
-
prompt: Union[str, List[str]],
|
| 191 |
-
image: Optional[Union[Image.Image, List[Image.Image]]] = None,
|
| 192 |
-
coef: float = 0.5, # strength: 1.0 - полный шум, 0.0 - оригинал
|
| 193 |
-
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 194 |
-
height: int = 1024,
|
| 195 |
-
width: int = 1024,
|
| 196 |
-
num_inference_steps: int = 40,
|
| 197 |
-
guidance_scale: float = 4.0,
|
| 198 |
-
generator: Optional[torch.Generator] = None,
|
| 199 |
-
seed: Optional[int] = None,
|
| 200 |
-
output_type: str = "pil",
|
| 201 |
-
return_dict: bool = True,
|
| 202 |
-
structure_preservation: float = 0.2, # Новый параметр: сохранение структуры 0-1
|
| 203 |
-
**kwargs,
|
| 204 |
-
):
|
| 205 |
-
device = self.device
|
| 206 |
-
dtype = self.unet.dtype
|
| 207 |
-
|
| 208 |
-
if generator is None and seed is not None:
|
| 209 |
-
if torch.cuda.is_available():
|
| 210 |
-
generator = torch.Generator(device=device)
|
| 211 |
-
else:
|
| 212 |
-
generator = torch.Generator()
|
| 213 |
-
generator.manual_seed(seed)
|
| 214 |
-
|
| 215 |
-
# 1. Encode Prompt
|
| 216 |
-
text_embeddings, attention_mask = self.encode_prompt(
|
| 217 |
-
prompt, negative_prompt, device, dtype
|
| 218 |
-
)
|
| 219 |
-
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
| 220 |
-
|
| 221 |
-
# 2. Настройка таймстепов - ВСЕГДА используем ВСЕ шаги
|
| 222 |
-
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 223 |
-
timesteps = self.scheduler.timesteps # Все 40 шагов
|
| 224 |
-
|
| 225 |
-
#print(f"Используем ВСЕ шаги: {len(timesteps)}")
|
| 226 |
-
#print(f"Диапазон таймстепов: [{timesteps[0].item():.3f}, {timesteps[-1].item():.3f}]")
|
| 227 |
-
#print(f"Коэффициент смешивания (coef): {coef}")
|
| 228 |
-
|
| 229 |
-
# 3. Обработка img2img с структурированным шумом
|
| 230 |
-
if image is not None:
|
| 231 |
-
# Подготовка изображения
|
| 232 |
-
if isinstance(image, Image.Image):
|
| 233 |
-
image_tensor = self.preprocess_image(image, width, height).to(
|
| 234 |
-
device=device, dtype=self.vae.dtype
|
| 235 |
-
)
|
| 236 |
-
else:
|
| 237 |
-
image_tensor = self.preprocess_image(image[0], width, height).to(
|
| 238 |
-
device=device, dtype=self.vae.dtype
|
| 239 |
-
)
|
| 240 |
-
|
| 241 |
-
# Кодируем в латенты
|
| 242 |
-
latents_clean = self.vae.encode(image_tensor).latent_dist.sample(generator=generator)
|
| 243 |
-
vae_scaling_factor = getattr(self.vae.config, "scaling_factor", 1.0)
|
| 244 |
-
vae_shift_factor = getattr(self.vae.config, "shift_factor", 0.0)
|
| 245 |
-
latents_clean = (latents_clean - vae_shift_factor) / vae_scaling_factor
|
| 246 |
-
latents_clean = latents_clean.to(dtype=dtype)
|
| 247 |
-
|
| 248 |
-
# Автоматический расчет cutoff_radius на основе structure_preservation
|
| 249 |
-
latent_height, latent_width = latents_clean.shape[-2], latents_clean.shape[-1]
|
| 250 |
-
max_radius = min(latent_height, latent_width) / 2
|
| 251 |
-
cutoff_radius = max_radius * structure_preservation
|
| 252 |
-
|
| 253 |
-
# Генерируем структурированный шум
|
| 254 |
-
structured_noise = self.generate_structured_noise(
|
| 255 |
-
image_latents=latents_clean,
|
| 256 |
-
cutoff_radius=cutoff_radius,
|
| 257 |
-
transition_width=2.0,
|
| 258 |
-
noise_std=1.0
|
| 259 |
-
)
|
| 260 |
-
|
| 261 |
-
# Нормализуем шум
|
| 262 |
-
current_std = torch.std(structured_noise)
|
| 263 |
-
if current_std > 0:
|
| 264 |
-
structured_noise = structured_noise / current_std
|
| 265 |
-
|
| 266 |
-
# КЛЮЧЕВОЕ ИЗМЕНЕНИЕ: Простое линейное смешивание
|
| 267 |
-
# coef=0.0 -> 100% оригинал, 0% шум
|
| 268 |
-
# coef=1.0 -> 0% оригинал, 100% шум
|
| 269 |
-
print(f"Смешивание: {100*(1-coef):.1f}% оригинал + {100*coef:.1f}% шум")
|
| 270 |
-
|
| 271 |
-
# Важно: инвертируем coef, если хотим, чтобы coef=0.1 давал слабое изменение
|
| 272 |
-
# coef=0.1 -> 90% оригинал + 10% шум
|
| 273 |
-
# coef=0.9 -> 10% оригинал + 90% шум
|
| 274 |
-
latents = (1.0 - coef) * latents_clean + coef * structured_noise
|
| 275 |
-
|
| 276 |
-
else:
|
| 277 |
-
# TXT2IMG: начинаем с чистого шума (coef=1.0 эквивалентно)
|
| 278 |
-
vae_scaling_factor = getattr(self.vae.config, "scaling_factor", 1.0)
|
| 279 |
-
vae_shift_factor = getattr(self.vae.config, "shift_factor", 0.0)
|
| 280 |
-
|
| 281 |
-
latent_height = height // self.vae_scale_factor
|
| 282 |
-
latent_width = width // self.vae_scale_factor
|
| 283 |
-
|
| 284 |
-
latents = torch.randn(
|
| 285 |
-
(batch_size, self.unet.config.in_channels, latent_height, latent_width),
|
| 286 |
-
generator=generator,
|
| 287 |
-
device=device,
|
| 288 |
-
dtype=dtype
|
| 289 |
-
)
|
| 290 |
-
|
| 291 |
-
# 4. Denoising Loop
|
| 292 |
-
for i, t in enumerate(tqdm(timesteps, desc="Sampling")):
|
| 293 |
-
# CFG preparation
|
| 294 |
-
latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents
|
| 295 |
-
|
| 296 |
-
# Predict flow
|
| 297 |
-
model_out = self.unet(
|
| 298 |
-
latent_model_input,
|
| 299 |
-
t,
|
| 300 |
-
encoder_hidden_states=text_embeddings,
|
| 301 |
-
encoder_attention_mask=attention_mask,
|
| 302 |
-
return_dict=False,
|
| 303 |
-
)[0]
|
| 304 |
-
|
| 305 |
-
# CFG
|
| 306 |
-
if guidance_scale > 1:
|
| 307 |
-
flow_uncond, flow_cond = model_out.chunk(2)
|
| 308 |
-
model_out = flow_uncond + guidance_scale * (flow_cond - flow_uncond)
|
| 309 |
-
|
| 310 |
-
# Euler step для flow matching
|
| 311 |
-
latents = self.scheduler.step(model_out, t, latents, return_dict=False)[0]
|
| 312 |
-
|
| 313 |
-
# 5. Decode
|
| 314 |
-
if output_type == "latent":
|
| 315 |
-
return SdxsPipelineOutput(images=latents)
|
| 316 |
-
|
| 317 |
-
# Масштабируем обратно для VAE
|
| 318 |
-
latents = latents * vae_scaling_factor + vae_shift_factor
|
| 319 |
-
image_output = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
# Нормализуем к [0, 1] для PIL
|
| 323 |
-
image_output = (image_output.clamp(-1, 1) + 1) / 2
|
| 324 |
-
image_np = image_output.cpu().permute(0, 2, 3, 1).float().numpy()
|
| 325 |
-
|
| 326 |
-
if output_type == "pil":
|
| 327 |
-
image_np = (image_np * 255).round().astype("uint8")
|
| 328 |
-
images = [Image.fromarray(img) for img in image_np]
|
| 329 |
-
else:
|
| 330 |
-
images = image_np
|
| 331 |
-
|
| 332 |
-
if not return_dict:
|
| 333 |
-
return images
|
| 334 |
-
|
| 335 |
-
return SdxsPipelineOutput(images=images)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pipeline_sdxs.py
CHANGED
|
@@ -100,7 +100,7 @@ class SdxsPipeline(DiffusionPipeline):
|
|
| 100 |
self,
|
| 101 |
prompt: Union[str, List[str]],
|
| 102 |
image: Optional[Union[Image.Image, List[Image.Image]]] = None,
|
| 103 |
-
coef: float = 0.
|
| 104 |
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 105 |
height: int = 1024,
|
| 106 |
width: int = 1024,
|
|
|
|
| 100 |
self,
|
| 101 |
prompt: Union[str, List[str]],
|
| 102 |
image: Optional[Union[Image.Image, List[Image.Image]]] = None,
|
| 103 |
+
coef: float = 0.97, # ← strength (0.0 = оригинал, 1.0 = полный шум)
|
| 104 |
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 105 |
height: int = 1024,
|
| 106 |
width: int = 1024,
|
samples/unet_1.3b_320x640_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_1.3b_352x640_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_1.3b_384x640_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_1.3b_416x640_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_1.3b_448x640_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_1.3b_480x640_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_1.3b_512x640_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_1.3b_544x640_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_1.3b_576x640_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_1.3b_608x640_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_1.3b_640x320_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_1.3b_640x352_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_1.3b_640x384_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_1.3b_640x416_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_1.3b_640x448_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_1.3b_640x480_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_1.3b_640x512_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_1.3b_640x544_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_1.3b_640x576_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_1.3b_640x608_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
samples/unet_1.3b_640x640_0.jpg
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
test.ipynb
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:632553d43b6346999891e3f671bbf197a6f4e1bca39362387e48e3eb6f357630
|
| 3 |
+
size 5305844
|
unet/config.json
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:18ea6e42455fff8208c1a900f1e343224198079dc0814f90d4ff283209c3924a
|
| 3 |
+
size 1843
|
unet/diffusion_pytorch_model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e342af4061c1f6cce4214d23962c1033cf61f2cfee942c42661f72dc29b179a3
|
| 3 |
+
size 2581889304
|
unet_1.3b/config.json
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:4a93a03ab94ebcdad5427326dae7fabbc74a7f46dca4a3804c3e5c11e667ff7e
|
| 3 |
-
size 1848
|
|
|
|
|
|
|
|
|
|
|
|
unet_1.3b/diffusion_pytorch_model.safetensors
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:0acc47c6590d2347959a9338ef56fefbd59b8b09b44bdf8db019d700e4cb3bef
|
| 3 |
-
size 5163635688
|
|
|
|
|
|
|
|
|
|
|
|
vae/.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
vae/config.json
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:12d252abeac321629cb81b908a6b49d1bf8d7f60247e00b6be83fd03b0f98b39
|
| 3 |
+
size 852
|
vae/diffusion_pytorch_model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:19bd4d341b7cc8d20893e2f257760c8b964bad447a243263191c4be1c89c1aaf
|
| 3 |
+
size 427466716
|
vae/train_vae_fdl.py
ADDED
|
@@ -0,0 +1,624 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
import os
|
| 3 |
+
import math
|
| 4 |
+
import re
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
import random
|
| 8 |
+
import gc
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
import torchvision.transforms as transforms
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from torch.utils.data import DataLoader, Dataset
|
| 15 |
+
from torch.optim.lr_scheduler import LambdaLR
|
| 16 |
+
from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
|
| 17 |
+
# QWEN: импорт класса
|
| 18 |
+
from diffusers import AutoencoderKLQwenImage
|
| 19 |
+
from diffusers import AutoencoderKLWan
|
| 20 |
+
|
| 21 |
+
from accelerate import Accelerator
|
| 22 |
+
from PIL import Image, UnidentifiedImageError
|
| 23 |
+
from tqdm import tqdm
|
| 24 |
+
import bitsandbytes as bnb
|
| 25 |
+
import wandb
|
| 26 |
+
import lpips # pip install lpips
|
| 27 |
+
from FDL_pytorch import FDL_loss # pip install fdl-pytorch
|
| 28 |
+
from collections import deque
|
| 29 |
+
|
| 30 |
+
# --------------------------- Parameters ---------------------------
ds_path = "/workspace/d23"        # root folder with training PNGs
project = "vae10"                 # W&B project name; also the from_pretrained path below
batch_size = 1
base_learning_rate = 6e-6
min_learning_rate = 7e-7
num_epochs = 2
sample_interval_share = 25        # controls how often samples/checkpoints happen
use_wandb = True
save_model = True
use_decay = True                  # enable warmup + cosine LR schedule
optimizer_type = "adam8bit"
dtype = torch.float32

model_resolution = 512 #288       # resolution fed to the VAE encoder
high_resolution = 1024 #576       # resolution of the reconstruction target
limit = 0                         # 0 = use the whole dataset
save_barrier = 1.3                # save when avg_loss < min_loss * save_barrier
warmup_percent = 0.005            # fraction of total steps used for LR warmup
percentile_clipping = 99          # NOTE(review): defined but not used anywhere below
beta2 = 0.997
eps = 1e-8
clip_grad_norm = 1.0
mixed_precision = "no"
gradient_accumulation_steps = 1
generated_folder = "samples"
save_as = "vae10"                 # output directory for save_pretrained
num_workers = 0
device = None                     # assigned from the accelerator below
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Enable Flash Attention 2 / SDPA  # MAX_JOBS=4 pip install flash-attn --no-build-isolation
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(False)

# --- Training modes ---
# QWEN: train only the decoder
train_decoder_only = True
train_up_only = False
full_training = False # if True — train the whole VAE and add the KL term (below)
kl_ratio = 0.00

# Relative shares of the individual loss terms (balanced by MedianLossNormalizer).
loss_ratios = {
    "lpips": 0.70,#0.50,
    "fdl" : 0.10,#0.25,
    "edge": 0.05,
    "mse": 0.10,
    "mae": 0.05,
    "kl": 0.00, # activated when full_training=True
}
median_coeff_steps = 250  # rolling-median window (in steps) for loss balancing

resize_long_side = 1280 # resize the long side of source images to this many pixels

# QWEN: which VAE class to load
vae_kind = "kl" # "qwen" or "kl" (regular)

Path(generated_folder).mkdir(parents=True, exist_ok=True)

accelerator = Accelerator(
    mixed_precision=mixed_precision,
    gradient_accumulation_steps=gradient_accumulation_steps
)
device = accelerator.device

# reproducibility: seed derived from today's date
seed = int(datetime.now().strftime("%Y%m%d")) + 13
torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
torch.backends.cudnn.benchmark = False
|
| 101 |
+
|
| 102 |
+
# --------------------------- WandB ---------------------------
# Log the run configuration once, from the main process only.
if use_wandb and accelerator.is_main_process:
    wandb.init(project=project, config={
        "batch_size": batch_size,
        "base_learning_rate": base_learning_rate,
        "num_epochs": num_epochs,
        "optimizer_type": optimizer_type,
        "model_resolution": model_resolution,
        "high_resolution": high_resolution,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "train_decoder_only": train_decoder_only,
        "full_training": full_training,
        "kl_ratio": kl_ratio,
        "vae_kind": vae_kind,
    })
|
| 117 |
+
|
| 118 |
+
# --------------------------- VAE ---------------------------
|
| 119 |
+
def get_core_model(model):
    """Return the underlying module, unwrapping a torch.compile wrapper if present."""
    # torch.compile keeps the original module on the `_orig_mod` attribute.
    return model._orig_mod if hasattr(model, "_orig_mod") else model
|
| 125 |
+
|
| 126 |
+
def is_video_vae(model) -> bool:
    """True when the VAE is a video model (5-D conv input), e.g. WAN/Qwen."""
    # WAN and Qwen VAEs are known video models — short-circuit on the config flag.
    if vae_kind in ("wan", "qwen"):
        return True
    # Structural fallback: a video VAE has a 5-D encoder conv_in weight.
    try:
        conv_weight = getattr(
            getattr(getattr(get_core_model(model), "encoder", None), "conv_in", None),
            "weight",
            None,
        )
        if isinstance(conv_weight, torch.nn.Parameter):
            return conv_weight.ndim == 5
    except Exception:
        pass
    return False
|
| 141 |
+
|
| 142 |
+
# Model loading: pick the VAE class based on vae_kind.
if vae_kind == "qwen":
    vae = AutoencoderKLQwenImage.from_pretrained("Qwen/Qwen-Image", subfolder="vae")
else:
    if vae_kind == "wan":
        vae = AutoencoderKLWan.from_pretrained(project)
    else:
        # legacy behaviour (example)
        if model_resolution==high_resolution:
            vae = AutoencoderKL.from_pretrained(project)
        else:
            # Asymmetric VAE: encoder and decoder operate at different scales.
            vae = AsymmetricAutoencoderKL.from_pretrained(project)

vae = vae.to(dtype)

# torch.compile (optional)
if hasattr(torch, "compile"):
    try:
        vae = torch.compile(vae)
    except Exception as e:
        print(f"[WARN] torch.compile failed: {e}")
|
| 163 |
+
|
| 164 |
+
# --------------------------- Freeze/Unfreeze ---------------------------
core = get_core_model(vae)

# Freeze everything first; selectively unfreeze below.
for p in core.parameters():
    p.requires_grad = False

unfrozen_param_names = []

if full_training and not train_decoder_only:
    # Train the whole VAE and enable the KL term.
    for name, p in core.named_parameters():
        p.requires_grad = True
        unfrozen_param_names.append(name)
    loss_ratios["kl"] = float(kl_ratio)
    trainable_module = core
else:
    # Train only (parts of) the decoder + post_quant_conv.
    if hasattr(core, "decoder"):
        if train_up_only:#hasattr(core.decoder, "up_blocks") and len(core.decoder.up_blocks) > 0:
            # --- only up_block 0 ---
            for name, p in core.decoder.up_blocks[0].named_parameters():
                p.requires_grad = True
                unfrozen_param_names.append(f"{name}")
        else:
            print("Decoder — fallback to full decoder")
            for name, p in core.decoder.named_parameters():
                p.requires_grad = True
                unfrozen_param_names.append(f"decoder.{name}")
    if hasattr(core, "post_quant_conv"):
        for name, p in core.post_quant_conv.named_parameters():
            p.requires_grad = True
            unfrozen_param_names.append(f"post_quant_conv.{name}")
    trainable_module = core.decoder if hasattr(core, "decoder") else core


# Runtime output string intentionally left untranslated (user-facing text).
print(f"[INFO] Разморожено параметров: {len(unfrozen_param_names)}. Первые 200 имён:")
for nm in unfrozen_param_names[:200]:
    print(" ", nm)
|
| 201 |
+
|
| 202 |
+
# --------------------------- Датасет ---------------------------
|
| 203 |
+
class PngFolderDataset(Dataset):
    """Recursively collects PNG files, validates them, and serves PIL RGB images.

    Images whose long side exceeds the module-level ``resize_long_side`` are
    downscaled to it; cropping and tensor conversion happen in ``collate_fn``.
    """

    def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0):
        self.root_dir = root_dir
        # NOTE(review): `resolution` is stored but not used by this class itself.
        self.resolution = resolution
        self.paths = []
        # Recursively gather files matching the allowed extensions.
        for root, _, files in os.walk(root_dir):
            for fname in files:
                if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)):
                    self.paths.append(os.path.join(root, fname))
        if limit:
            self.paths = self.paths[:limit]
        # Drop files PIL cannot verify (truncated/corrupt images).
        valid = []
        for p in self.paths:
            try:
                with Image.open(p) as im:
                    im.verify()
                valid.append(p)
            except (OSError, UnidentifiedImageError):
                continue
        self.paths = valid
        if len(self.paths) == 0:
            raise RuntimeError(f"No valid PNG images found under {root_dir}")
        random.shuffle(self.paths)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Wrap-around indexing keeps any integer idx valid.
        p = self.paths[idx % len(self.paths)]
        with Image.open(p) as img:
            img = img.convert("RGB")
        if not resize_long_side or resize_long_side <= 0:
            return img
        w, h = img.size
        long = max(w, h)
        if long <= resize_long_side:
            return img
        # Downscale so the long side equals resize_long_side, preserving aspect.
        scale = resize_long_side / float(long)
        new_w = int(round(w * scale))
        new_h = int(round(h * scale))
        return img.resize((new_w, new_h), Image.BICUBIC)
|
| 244 |
+
|
| 245 |
+
def random_crop(img, sz):
    """Return a random sz x sz crop of a PIL image, upscaling first if needed.

    After the optional resize both sides are >= sz, so the valid crop offset
    range is [0, side - sz]. The previous `max(1, side - sz)` upper bound could
    shift the window one pixel past the edge when a side equals sz exactly,
    making PIL pad the crop with black.
    """
    w, h = img.size
    if w < sz or h < sz:
        # Upscale the short side(s) so a sz x sz crop fits.
        img = img.resize((max(sz, w), max(sz, h)), Image.BICUBIC)
    # Fix: upper bound is side - sz (never negative here), not max(1, side - sz).
    x = random.randint(0, max(0, img.width - sz))
    y = random.randint(0, max(0, img.height - sz))
    return img.crop((x, y, x + sz, y + sz))
|
| 252 |
+
|
| 253 |
+
# PIL [0, 255] -> tensor in [-1, 1] (per-channel 0.5/0.5 normalization).
tfm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
print("len(dataset)",len(dataset))
if len(dataset) < batch_size:
    raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")
|
| 262 |
+
|
| 263 |
+
def collate_fn(batch):
    """Random-crop each PIL image to high_resolution, transform, and stack a batch."""
    return torch.stack([tfm(random_crop(img, high_resolution)) for img in batch])
|
| 269 |
+
|
| 270 |
+
# DataLoader over the validated PNGs; drop_last keeps every batch full.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=True,  # faster host->GPU copies
    drop_last=True
)
|
| 279 |
+
|
| 280 |
+
# --------------------------- Оптимизатор ---------------------------
|
| 281 |
+
def get_param_groups(module, weight_decay=0.001):
    """Split `module`'s trainable parameters into decay / no-decay optimizer groups.

    Biases and common normalization weight names are excluded from weight decay.
    Fix: the original iterated the global `vae` and ignored its `module`
    argument; it now honors the argument.
    NOTE(review): this definition is immediately shadowed by a second
    `get_param_groups` defined later in the file.
    """
    no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "ln_1.weight", "ln_f.weight"]
    decay_params, no_decay_params = [], []
    for n, p in module.named_parameters():
        if not p.requires_grad:
            continue  # frozen parameters never reach the optimizer
        if any(nd in n for nd in no_decay):
            no_decay_params.append(p)
        else:
            decay_params.append(p)
    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
|
| 295 |
+
|
| 296 |
+
def get_param_groups(module, weight_decay=0.001):
    """Split trainable parameters into weight-decay and no-decay optimizer groups.

    Biases and any normalization weights (name contains "norm"/"rms"/
    "layernorm") are excluded from weight decay.
    """
    skip_decay = ("bias", "norm", "rms", "layernorm")
    with_decay = []
    without_decay = []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights never reach the optimizer
        bucket = without_decay if any(tok in name.lower() for tok in skip_decay) else with_decay
        bucket.append(param)
    return [
        {"params": with_decay, "weight_decay": weight_decay},
        {"params": without_decay, "weight_decay": 0.0},
    ]
|
| 311 |
+
|
| 312 |
+
def create_optimizer(name, param_groups):
    """Build the optimizer selected by `name`; only 8-bit AdamW is supported."""
    if name != "adam8bit":
        raise ValueError(name)
    return bnb.optim.AdamW8bit(param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps)
|
| 316 |
+
|
| 317 |
+
# Build the optimizer over the unfrozen parameters of the unwrapped model.
param_groups = get_param_groups(get_core_model(vae), weight_decay=0.001)
optimizer = create_optimizer(optimizer_type, param_groups)

# --------------------------- LR schedule ---------------------------
batches_per_epoch = len(dataloader)
# One optimizer step per `gradient_accumulation_steps` micro-batches.
steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps)))
total_steps = steps_per_epoch * num_epochs
|
| 324 |
+
|
| 325 |
+
def lr_lambda(step):
    """LR multiplier: linear warmup from min_ratio up to 1, then cosine decay back."""
    if not use_decay:
        return 1.0
    progress = float(step) / float(max(1, total_steps))
    warmup = float(warmup_percent)
    floor = float(min_learning_rate) / float(base_learning_rate)
    span = 1.0 - floor
    if progress < warmup:
        # Linear ramp over the warmup fraction of training.
        return floor + span * (progress / warmup)
    # Cosine anneal over the remaining fraction.
    cos_t = (progress - warmup) / (1.0 - warmup)
    return floor + 0.5 * span * (1.0 + math.cos(math.pi * cos_t))
|
| 335 |
+
|
| 336 |
+
scheduler = LambdaLR(optimizer, lr_lambda)

# Preparation: let accelerate wrap dataloader/model/optimizer/scheduler.
dataloader, vae, optimizer, scheduler = accelerator.prepare(dataloader, vae, optimizer, scheduler)
trainable_params = [p for p in vae.parameters() if p.requires_grad]

# FDL (frequency-domain) perceptual loss.
fdl_loss = FDL_loss()
fdl_loss = fdl_loss.to(accelerator.device)
|
| 345 |
+
|
| 346 |
+
# --------------------------- LPIPS и вспомогательные ---------------------------
|
| 347 |
+
# Lazily constructed LPIPS network (VGG backbone), shared process-wide.
_lpips_net = None
def _get_lpips():
    """Return the shared LPIPS(VGG) network, creating it on first use."""
    global _lpips_net
    if _lpips_net is None:
        # Fix: `.eval()` was previously called twice in the chain; once suffices.
        _lpips_net = lpips.LPIPS(net='vgg', verbose=False).to(accelerator.device).eval()
    return _lpips_net
|
| 353 |
+
|
| 354 |
+
# 3x3 Sobel kernels, shaped [1, 1, 3, 3] for depthwise conv2d.
_sobel_kx = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], dtype=torch.float32)
_sobel_ky = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]], dtype=torch.float32)
def sobel_edges(x: torch.Tensor) -> torch.Tensor:
    """Per-channel Sobel gradient magnitude of a [B, C, H, W] tensor."""
    channels = x.shape[1]
    # Depthwise convolution: one copy of each kernel per channel.
    kern_x = _sobel_kx.to(x.device, x.dtype).repeat(channels, 1, 1, 1)
    kern_y = _sobel_ky.to(x.device, x.dtype).repeat(channels, 1, 1, 1)
    grad_x = F.conv2d(x, kern_x, padding=1, groups=channels)
    grad_y = F.conv2d(x, kern_y, padding=1, groups=channels)
    # Small epsilon keeps sqrt differentiable where the gradient is zero.
    return torch.sqrt(grad_x * grad_x + grad_y * grad_y + 1e-12)
|
| 363 |
+
|
| 364 |
+
class MedianLossNormalizer:
    """Balances multiple loss terms by their rolling median magnitudes.

    Each term is rescaled so that its recent median contribution matches the
    desired ratio, keeping the weighted mix stable as loss scales drift.
    """

    def __init__(self, desired_ratios: dict, window_steps: int):
        total_ratio = sum(desired_ratios.values())
        # Normalize the requested shares to sum to 1 (all-zero input -> all zeros).
        self.ratios = {
            name: (share / total_ratio if total_ratio > 0 else 0.0)
            for name, share in desired_ratios.items()
        }
        self.buffers = {name: deque(maxlen=window_steps) for name in self.ratios}
        self.window = window_steps

    def update_and_total(self, abs_losses: dict):
        """Record the current losses; return (weighted total, coeffs, medians)."""
        for name, value in abs_losses.items():
            if name in self.buffers:
                self.buffers[name].append(float(value.detach().abs().cpu()))
        meds = {
            name: (np.median(buf) if len(buf) > 0 else 1.0)
            for name, buf in self.buffers.items()
        }
        coeffs = {name: self.ratios[name] / max(meds[name], 1e-12) for name in self.ratios}
        total = sum(coeffs[name] * abs_losses[name] for name in abs_losses if name in coeffs)
        return total, coeffs, meds
|
| 379 |
+
|
| 380 |
+
# Re-assert the KL share when the full VAE (incl. encoder) is being trained.
if full_training and not train_decoder_only:
    loss_ratios["kl"] = float(kl_ratio)
normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
|
| 383 |
+
|
| 384 |
+
# --------------------------- Сэмплы ---------------------------
|
| 385 |
+
@torch.no_grad()
def get_fixed_samples(n=3):
    """Pick up to n random dataset images, crop to high_resolution, return a batch."""
    chosen = random.sample(range(len(dataset)), min(n, len(dataset)))
    tensors = [tfm(random_crop(dataset[i], high_resolution)) for i in chosen]
    return torch.stack(tensors).to(accelerator.device, dtype)
|
| 394 |
+
|
| 395 |
+
# Fixed evaluation batch, reused for every sampling checkpoint.
fixed_samples = get_fixed_samples()
|
| 396 |
+
|
| 397 |
+
@torch.no_grad()
def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
    """Convert a [-1, 1] CHW float tensor into an 8-bit HWC PIL image."""
    scaled = (img_tensor.float().clamp(-1, 1) + 1.0) * 127.5
    hwc = scaled.clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
    return Image.fromarray(hwc)
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
@torch.no_grad()
def generate_and_save_samples(step=None):
    """Reconstruct the fixed sample batch, save real/decoded PNGs, log LPIPS to W&B."""
    try:
        #temp_vae = accelerator.unwrap_model(vae).eval()
        if hasattr(vae, "module"):
            # DDP / DistributedDataParallel wrapper
            unwrapped_vae = vae.module
        else:
            unwrapped_vae = vae

        # If torch.compile was used, fetch the original module
        if hasattr(unwrapped_vae, "_orig_mod"):
            temp_vae = unwrapped_vae._orig_mod
        else:
            temp_vae = unwrapped_vae

        temp_vae = temp_vae.eval()
        lpips_net = _get_lpips()
        with torch.no_grad():
            orig_high = fixed_samples
            # Downscale to the encoder's working resolution.
            orig_low = F.interpolate(
                orig_high,
                size=(model_resolution, model_resolution),
                mode="bilinear",
                align_corners=False
            )
            model_dtype = next(temp_vae.parameters()).dtype
            orig_low = orig_low.to(dtype=model_dtype)

            # Encode/decode, accounting for the video-VAE (5-D) layout
            if is_video_vae(temp_vae):
                x_in = orig_low.unsqueeze(2)  # [B,3,1,H,W]
                enc = temp_vae.encode(x_in)
                latents_mean = enc.latent_dist.mean
                dec = temp_vae.decode(latents_mean).sample  # [B,3,1,H,W]
                rec = dec.squeeze(2)  # [B,3,H,W]
            else:
                enc = temp_vae.encode(orig_low)
                latents_mean = enc.latent_dist.mean
                rec = temp_vae.decode(latents_mean).sample

            # Resize back if shapes differ (currently disabled)
            #if rec.shape[-2:] != orig_high.shape[-2:]:
            #    rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)

            # Save every real/decoded pair
            for i in range(rec.shape[0]):
                real_img = _to_pil_uint8(orig_high[i])
                dec_img = _to_pil_uint8(rec[i])
                real_img.save(f"{generated_folder}/sample_real_{i}.png")
                dec_img.save(f"{generated_folder}/sample_decoded_{i}.png")

            # LPIPS between full-res original and reconstruction
            lpips_scores = []
            for i in range(rec.shape[0]):
                orig_full = orig_high[i:i+1].to(torch.float32)
                rec_full = rec[i:i+1].to(torch.float32)
                #if rec_full.shape[-2:] != orig_full.shape[-2:]:
                #    rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
                lpips_val = lpips_net(orig_full, rec_full).item()
                lpips_scores.append(lpips_val)
            avg_lpips = float(np.mean(lpips_scores))

            # W&B logging
            if use_wandb and accelerator.is_main_process:
                log_data = {"lpips_mean": avg_lpips}
                for i in range(rec.shape[0]):
                    log_data[f"sample/real_{i}"] = wandb.Image(f"{generated_folder}/sample_real_{i}.png", caption=f"real_{i}")
                    log_data[f"sample/decoded_{i}"] = wandb.Image(f"{generated_folder}/sample_decoded_{i}.png", caption=f"decoded_{i}")
                wandb.log(log_data, step=step)

    finally:
        # Free CUDA memory even if sampling failed; exceptions still propagate.
        gc.collect()
        torch.cuda.empty_cache()
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
# Baseline samples before the first training step (runtime string kept verbatim).
if accelerator.is_main_process and save_model:
    print("Генерация сэмплов до старта обучения...")
    generate_and_save_samples(0)

accelerator.wait_for_everyone()
|
| 484 |
+
|
| 485 |
+
# --------------------------- Training ---------------------------
progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
global_step = 0
min_loss = float("inf")
# Number of steps between sample generations / intermediate checkpoints.
sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs))

for epoch in range(num_epochs):
    vae.train()
    batch_losses, batch_grads = [], []
    # Per-term history; currently only appended to, never read back.
    track_losses = {k: [] for k in loss_ratios.keys()}

    for imgs in dataloader:
        with accelerator.accumulate(vae):
            imgs = imgs.to(accelerator.device)

            # Encoder input at model_resolution; the target stays at high_resolution.
            if high_resolution != model_resolution:
                imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution),mode="area") # mode="bilinear", align_corners=False)
            else:
                imgs_low = imgs

            model_dtype = next(vae.parameters()).dtype
            imgs_low_model = imgs_low.to(dtype=model_dtype) if imgs_low.dtype != model_dtype else imgs_low

            # Instead of: current_vae = accelerator.unwrap_model(vae)
            unwrapped = vae.module if hasattr(vae, "module") else vae
            current_vae = getattr(unwrapped, "_orig_mod", unwrapped)


            # QWEN: encode/decode with T=1
            if is_video_vae(current_vae):
                x_in = imgs_low_model.unsqueeze(2)  # [B,3,1,H,W]
                enc = current_vae.encode(x_in)
                # Decoder-only training uses the deterministic latent mean.
                latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
                dec = current_vae.decode(latents).sample  # [B,3,1,H,W]
                rec = dec.squeeze(2)  # [B,3,H,W]
            else:
                enc = current_vae.encode(imgs_low_model)
                latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
                rec = current_vae.decode(latents).sample

            #if rec.shape[-2:] != imgs.shape[-2:]:
            #    rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False)

            # All losses are computed in fp32 for numerical stability.
            rec_f32 = rec.to(torch.float32)
            imgs_f32 = imgs.to(torch.float32)

            abs_losses = {
                "mae": F.l1_loss(rec_f32, imgs_f32),
                "mse": F.mse_loss(rec_f32, imgs_f32),
                "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
                "fdl": fdl_loss(rec_f32, imgs_f32),
                "edge": F.l1_loss(sobel_edges(rec_f32), sobel_edges(imgs_f32)),
            }

            if full_training and not train_decoder_only:
                # Standard Gaussian KL against N(0, I).
                mean = enc.latent_dist.mean
                logvar = enc.latent_dist.logvar
                kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
                abs_losses["kl"] = kl
            else:
                abs_losses["kl"] = torch.tensor(0.0, device=accelerator.device, dtype=torch.float32)

            total_loss, coeffs, meds = normalizer.update_and_total(abs_losses)

            if torch.isnan(total_loss) or torch.isinf(total_loss):
                raise RuntimeError("NaN/Inf loss")

            accelerator.backward(total_loss)

            grad_norm = torch.tensor(0.0, device=accelerator.device)
            if accelerator.sync_gradients:
                grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
        global_step += 1
        progress.update(1)

        if accelerator.is_main_process:
            try:
                current_lr = optimizer.param_groups[0]["lr"]
            except Exception:
                current_lr = scheduler.get_last_lr()[0]

            batch_losses.append(total_loss.detach().item())
            batch_grads.append(float(grad_norm.detach().cpu().item()) if isinstance(grad_norm, torch.Tensor) else float(grad_norm))
            for k, v in abs_losses.items():
                track_losses[k].append(float(v.detach().item()))

            if use_wandb and accelerator.sync_gradients:
                log_dict = {
                    "total_loss": float(total_loss.detach().item()),
                    "learning_rate": current_lr,
                    "epoch": epoch,
                    "grad_norm": batch_grads[-1],
                }
                for k, v in abs_losses.items():
                    log_dict[f"loss_{k}"] = float(v.detach().item())
                for k in coeffs:
                    log_dict[f"coeff_{k}"] = float(coeffs[k])
                    log_dict[f"median_{k}"] = float(meds[k])
                wandb.log(log_dict, step=global_step)

        # Periodic sampling + intermediate checkpoint.
        if global_step > 0 and global_step % sample_interval == 0:
            if accelerator.is_main_process:
                generate_and_save_samples(global_step)
            accelerator.wait_for_everyone()

            n_micro = sample_interval * gradient_accumulation_steps
            avg_loss = float(np.mean(batch_losses[-n_micro:])) if len(batch_losses) >= n_micro else float(np.mean(batch_losses)) if batch_losses else float("nan")
            avg_grad = float(np.mean(batch_grads[-n_micro:])) if len(batch_grads) >= 1 else float(np.mean(batch_grads)) if batch_grads else 0.0

            if accelerator.is_main_process:
                print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}")
                # Save whenever the running average is within save_barrier of the best.
                if save_model and avg_loss < min_loss * save_barrier:
                    min_loss = avg_loss
                    unwrapped = vae.module if hasattr(vae, "module") else vae
                    current_vae = getattr(unwrapped, "_orig_mod", unwrapped)
                    current_vae.save_pretrained(save_as)
                if use_wandb:
                    wandb.log({"interm_loss": avg_loss, "interm_grad": avg_grad}, step=global_step)

    if accelerator.is_main_process:
        epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan")
        print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}")
        if use_wandb:
            wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step)
|
| 612 |
+
|
| 613 |
+
# --------------------------- Final save ---------------------------
if accelerator.is_main_process:
    print("Training finished – saving final model")
    if save_model:
        unwrapped = vae.module if hasattr(vae, "module") else vae
        current_vae = getattr(unwrapped, "_orig_mod", unwrapped)
        current_vae.save_pretrained(save_as)

accelerator.free_memory()
# Tear down the distributed process group if one was created.
if torch.distributed.is_initialized():
    torch.distributed.destroy_process_group()
print("Готово!")
|