Instructions to use babkasotona/vae2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use babkasotona/vae2 with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("babkasotona/vae2", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
Upload folder using huggingface_hub
Browse files- train_sdxs_vae.py +36 -50
train_sdxs_vae.py
CHANGED
|
@@ -26,10 +26,11 @@ import wandb
|
|
| 26 |
import lpips # pip install lpips
|
| 27 |
from FDL_pytorch import FDL_loss # pip install fdl-pytorch
|
| 28 |
from collections import deque
|
|
|
|
| 29 |
|
| 30 |
# --------------------------- Параметры ---------------------------
|
| 31 |
ds_path = "/workspace/d23"
|
| 32 |
-
project = "
|
| 33 |
batch_size = 1
|
| 34 |
base_learning_rate = 6e-6
|
| 35 |
min_learning_rate = 7e-7
|
|
@@ -52,7 +53,7 @@ clip_grad_norm = 1.0
|
|
| 52 |
mixed_precision = "no"
|
| 53 |
gradient_accumulation_steps = 1
|
| 54 |
generated_folder = "samples"
|
| 55 |
-
save_as = "
|
| 56 |
num_workers = 0
|
| 57 |
device = None
|
| 58 |
torch.backends.cuda.matmul.allow_tf32 = True
|
|
@@ -73,9 +74,8 @@ kl_ratio = 0.0
|
|
| 73 |
loss_ratios = {
|
| 74 |
"lpips": 0.70,#0.50,
|
| 75 |
"fdl" : 0.10,#0.25,
|
| 76 |
-
"edge": 0.05,
|
| 77 |
"mse": 0.10,
|
| 78 |
-
"mae": 0.
|
| 79 |
"kl": 0.00,
|
| 80 |
}
|
| 81 |
median_coeff_steps = 250
|
|
@@ -195,33 +195,48 @@ else:
|
|
| 195 |
|
| 196 |
|
| 197 |
print(f"[INFO] Разморожено параметров: {len(unfrozen_param_names)}. Первые 200 имён:")
|
| 198 |
-
for nm in unfrozen_param_names[:
|
| 199 |
print(" ", nm)
|
| 200 |
|
| 201 |
# --------------------------- Датасет ---------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
class PngFolderDataset(Dataset):
|
| 203 |
-
def __init__(self, root_dir, min_exts=('.png',),
|
| 204 |
-
self.root_dir = root_dir
|
| 205 |
self.resolution = resolution
|
| 206 |
self.paths = []
|
|
|
|
| 207 |
for root, _, files in os.walk(root_dir):
|
| 208 |
-
for
|
| 209 |
-
if
|
| 210 |
-
self.paths.append(os.path.join(root,
|
|
|
|
| 211 |
if limit:
|
| 212 |
self.paths = self.paths[:limit]
|
|
|
|
|
|
|
| 213 |
valid = []
|
| 214 |
for p in self.paths:
|
| 215 |
try:
|
| 216 |
-
with Image.open(p) as
|
| 217 |
-
|
|
|
|
|
|
|
|
|
|
| 218 |
valid.append(p)
|
| 219 |
except (OSError, UnidentifiedImageError):
|
| 220 |
continue
|
|
|
|
| 221 |
self.paths = valid
|
| 222 |
-
if
|
| 223 |
-
raise RuntimeError(
|
|
|
|
| 224 |
random.shuffle(self.paths)
|
|
|
|
| 225 |
|
| 226 |
def __len__(self):
|
| 227 |
return len(self.paths)
|
|
@@ -230,21 +245,10 @@ class PngFolderDataset(Dataset):
|
|
| 230 |
p = self.paths[idx % len(self.paths)]
|
| 231 |
with Image.open(p) as img:
|
| 232 |
img = img.convert("RGB")
|
| 233 |
-
|
| 234 |
-
return img
|
| 235 |
-
w, h = img.size
|
| 236 |
-
long = max(w, h)
|
| 237 |
-
if long <= resize_long_side:
|
| 238 |
-
return img
|
| 239 |
-
scale = resize_long_side / float(long)
|
| 240 |
-
new_w = int(round(w * scale))
|
| 241 |
-
new_h = int(round(h * scale))
|
| 242 |
-
return img.resize((new_w, new_h), Image.BICUBIC)
|
| 243 |
|
| 244 |
def random_crop(img, sz):
|
| 245 |
w, h = img.size
|
| 246 |
-
if w < sz or h < sz:
|
| 247 |
-
img = img.resize((max(sz, w), max(sz, h)), Image.BICUBIC)
|
| 248 |
x = random.randint(0, max(1, img.width - sz))
|
| 249 |
y = random.randint(0, max(1, img.height - sz))
|
| 250 |
return img.crop((x, y, x + sz, y + sz))
|
|
@@ -254,11 +258,6 @@ tfm = transforms.Compose([
|
|
| 254 |
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
|
| 255 |
])
|
| 256 |
|
| 257 |
-
dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
|
| 258 |
-
print("len(dataset)",len(dataset))
|
| 259 |
-
if len(dataset) < batch_size:
|
| 260 |
-
raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")
|
| 261 |
-
|
| 262 |
def collate_fn(batch):
|
| 263 |
imgs = []
|
| 264 |
for img in batch:
|
|
@@ -266,15 +265,12 @@ def collate_fn(batch):
|
|
| 266 |
imgs.append(tfm(img))
|
| 267 |
return torch.stack(imgs)
|
| 268 |
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
pin_memory=True,
|
| 276 |
-
drop_last=True
|
| 277 |
-
)
|
| 278 |
|
| 279 |
# --------------------------- Оптимизатор ---------------------------
|
| 280 |
def get_param_groups(module, weight_decay=0.001):
|
|
@@ -350,15 +346,6 @@ def _get_lpips():
|
|
| 350 |
_lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device).eval()
|
| 351 |
return _lpips_net
|
| 352 |
|
| 353 |
-
_sobel_kx = torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]], dtype=torch.float32)
|
| 354 |
-
_sobel_ky = torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]], dtype=torch.float32)
|
| 355 |
-
def sobel_edges(x: torch.Tensor) -> torch.Tensor:
|
| 356 |
-
C = x.shape[1]
|
| 357 |
-
kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
|
| 358 |
-
ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
|
| 359 |
-
gx = F.conv2d(x, kx, padding=1, groups=C)
|
| 360 |
-
gy = F.conv2d(x, ky, padding=1, groups=C)
|
| 361 |
-
return torch.sqrt(gx * gx + gy * gy + 1e-12)
|
| 362 |
|
| 363 |
class MedianLossNormalizer:
|
| 364 |
def __init__(self, desired_ratios: dict, window_steps: int):
|
|
@@ -532,7 +519,6 @@ for epoch in range(num_epochs):
|
|
| 532 |
"mse": F.mse_loss(rec_f32, imgs_f32),
|
| 533 |
"lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
|
| 534 |
"fdl": fdl_loss(rec_f32, imgs_f32),
|
| 535 |
-
"edge": F.l1_loss(sobel_edges(rec_f32), sobel_edges(imgs_f32)),
|
| 536 |
}
|
| 537 |
|
| 538 |
if full_training and not train_decoder_only:
|
|
|
|
| 26 |
import lpips # pip install lpips
|
| 27 |
from FDL_pytorch import FDL_loss # pip install fdl-pytorch
|
| 28 |
from collections import deque
|
| 29 |
+
import torch.nn.functional as F
|
| 30 |
|
| 31 |
# --------------------------- Параметры ---------------------------
|
| 32 |
ds_path = "/workspace/d23"
|
| 33 |
+
project = "vae2"
|
| 34 |
batch_size = 1
|
| 35 |
base_learning_rate = 6e-6
|
| 36 |
min_learning_rate = 7e-7
|
|
|
|
| 53 |
mixed_precision = "no"
|
| 54 |
gradient_accumulation_steps = 1
|
| 55 |
generated_folder = "samples"
|
| 56 |
+
save_as = "vae3"
|
| 57 |
num_workers = 0
|
| 58 |
device = None
|
| 59 |
torch.backends.cuda.matmul.allow_tf32 = True
|
|
|
|
| 74 |
loss_ratios = {
|
| 75 |
"lpips": 0.70,#0.50,
|
| 76 |
"fdl" : 0.10,#0.25,
|
|
|
|
| 77 |
"mse": 0.10,
|
| 78 |
+
"mae": 0.10,
|
| 79 |
"kl": 0.00,
|
| 80 |
}
|
| 81 |
median_coeff_steps = 250
|
|
|
|
| 195 |
|
| 196 |
|
| 197 |
print(f"[INFO] Разморожено параметров: {len(unfrozen_param_names)}. Первые 200 имён:")
|
| 198 |
+
for nm in unfrozen_param_names[:10]:
|
| 199 |
print(" ", nm)
|
| 200 |
|
| 201 |
# --------------------------- Датасет ---------------------------
|
| 202 |
+
from torch.utils.data import Dataset
|
| 203 |
+
from PIL import Image, UnidentifiedImageError
|
| 204 |
+
import random
|
| 205 |
+
import torchvision.transforms as transforms
|
| 206 |
+
import os
|
| 207 |
+
|
| 208 |
class PngFolderDataset(Dataset):
|
| 209 |
+
def __init__(self, root_dir, resolution=1024, min_exts=('.png',), limit=0):
|
|
|
|
| 210 |
self.resolution = resolution
|
| 211 |
self.paths = []
|
| 212 |
+
|
| 213 |
for root, _, files in os.walk(root_dir):
|
| 214 |
+
for f in files:
|
| 215 |
+
if f.lower().endswith(tuple(ext.lower() for ext in min_exts)):
|
| 216 |
+
self.paths.append(os.path.join(root, f))
|
| 217 |
+
|
| 218 |
if limit:
|
| 219 |
self.paths = self.paths[:limit]
|
| 220 |
+
|
| 221 |
+
# фильтруем недопустимые картинки
|
| 222 |
valid = []
|
| 223 |
for p in self.paths:
|
| 224 |
try:
|
| 225 |
+
with Image.open(p) as img:
|
| 226 |
+
img.verify() # только метаданные
|
| 227 |
+
w, h = img.size
|
| 228 |
+
if w < resolution or h < resolution:
|
| 229 |
+
continue
|
| 230 |
valid.append(p)
|
| 231 |
except (OSError, UnidentifiedImageError):
|
| 232 |
continue
|
| 233 |
+
|
| 234 |
self.paths = valid
|
| 235 |
+
if not self.paths:
|
| 236 |
+
raise RuntimeError("No valid images found")
|
| 237 |
+
|
| 238 |
random.shuffle(self.paths)
|
| 239 |
+
self.transform = transforms.ToTensor() # конвертирует сразу [0,1] float32
|
| 240 |
|
| 241 |
def __len__(self):
|
| 242 |
return len(self.paths)
|
|
|
|
| 245 |
p = self.paths[idx % len(self.paths)]
|
| 246 |
with Image.open(p) as img:
|
| 247 |
img = img.convert("RGB")
|
| 248 |
+
return img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
|
| 250 |
def random_crop(img, sz):
|
| 251 |
w, h = img.size
|
|
|
|
|
|
|
| 252 |
x = random.randint(0, max(1, img.width - sz))
|
| 253 |
y = random.randint(0, max(1, img.height - sz))
|
| 254 |
return img.crop((x, y, x + sz, y + sz))
|
|
|
|
| 258 |
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
|
| 259 |
])
|
| 260 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
def collate_fn(batch):
|
| 262 |
imgs = []
|
| 263 |
for img in batch:
|
|
|
|
| 265 |
imgs.append(tfm(img))
|
| 266 |
return torch.stack(imgs)
|
| 267 |
|
| 268 |
+
dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
|
| 269 |
+
print("len(dataset)",len(dataset))
|
| 270 |
+
if len(dataset) < batch_size:
|
| 271 |
+
raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")
|
| 272 |
+
|
| 273 |
+
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, drop_last=True)
|
|
|
|
|
|
|
|
|
|
| 274 |
|
| 275 |
# --------------------------- Оптимизатор ---------------------------
|
| 276 |
def get_param_groups(module, weight_decay=0.001):
|
|
|
|
| 346 |
_lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device).eval()
|
| 347 |
return _lpips_net
|
| 348 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
|
| 350 |
class MedianLossNormalizer:
|
| 351 |
def __init__(self, desired_ratios: dict, window_steps: int):
|
|
|
|
| 519 |
"mse": F.mse_loss(rec_f32, imgs_f32),
|
| 520 |
"lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
|
| 521 |
"fdl": fdl_loss(rec_f32, imgs_f32),
|
|
|
|
| 522 |
}
|
| 523 |
|
| 524 |
if full_training and not train_decoder_only:
|