babkasotona commited on
Commit
21d0f47
·
verified ·
1 Parent(s): d067363

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. train_sdxs_vae.py +36 -50
train_sdxs_vae.py CHANGED
@@ -26,10 +26,11 @@ import wandb
26
  import lpips # pip install lpips
27
  from FDL_pytorch import FDL_loss # pip install fdl-pytorch
28
  from collections import deque
 
29
 
30
  # --------------------------- Параметры ---------------------------
31
  ds_path = "/workspace/d23"
32
- project = "vae"
33
  batch_size = 1
34
  base_learning_rate = 6e-6
35
  min_learning_rate = 7e-7
@@ -52,7 +53,7 @@ clip_grad_norm = 1.0
52
  mixed_precision = "no"
53
  gradient_accumulation_steps = 1
54
  generated_folder = "samples"
55
- save_as = "vae2"
56
  num_workers = 0
57
  device = None
58
  torch.backends.cuda.matmul.allow_tf32 = True
@@ -73,9 +74,8 @@ kl_ratio = 0.0
73
  loss_ratios = {
74
  "lpips": 0.70,#0.50,
75
  "fdl" : 0.10,#0.25,
76
- "edge": 0.05,
77
  "mse": 0.10,
78
- "mae": 0.05,
79
  "kl": 0.00,
80
  }
81
  median_coeff_steps = 250
@@ -195,33 +195,48 @@ else:
195
 
196
 
197
  print(f"[INFO] Разморожено параметров: {len(unfrozen_param_names)}. Первые 200 имён:")
198
- for nm in unfrozen_param_names[:100]:
199
  print(" ", nm)
200
 
201
  # --------------------------- Датасет ---------------------------
 
 
 
 
 
 
202
  class PngFolderDataset(Dataset):
203
- def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0):
204
- self.root_dir = root_dir
205
  self.resolution = resolution
206
  self.paths = []
 
207
  for root, _, files in os.walk(root_dir):
208
- for fname in files:
209
- if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)):
210
- self.paths.append(os.path.join(root, fname))
 
211
  if limit:
212
  self.paths = self.paths[:limit]
 
 
213
  valid = []
214
  for p in self.paths:
215
  try:
216
- with Image.open(p) as im:
217
- im.verify()
 
 
 
218
  valid.append(p)
219
  except (OSError, UnidentifiedImageError):
220
  continue
 
221
  self.paths = valid
222
- if len(self.paths) == 0:
223
- raise RuntimeError(f"No valid PNG images found under {root_dir}")
 
224
  random.shuffle(self.paths)
 
225
 
226
  def __len__(self):
227
  return len(self.paths)
@@ -230,21 +245,10 @@ class PngFolderDataset(Dataset):
230
  p = self.paths[idx % len(self.paths)]
231
  with Image.open(p) as img:
232
  img = img.convert("RGB")
233
- if not resize_long_side or resize_long_side <= 0:
234
- return img
235
- w, h = img.size
236
- long = max(w, h)
237
- if long <= resize_long_side:
238
- return img
239
- scale = resize_long_side / float(long)
240
- new_w = int(round(w * scale))
241
- new_h = int(round(h * scale))
242
- return img.resize((new_w, new_h), Image.BICUBIC)
243
 
244
  def random_crop(img, sz):
245
  w, h = img.size
246
- if w < sz or h < sz:
247
- img = img.resize((max(sz, w), max(sz, h)), Image.BICUBIC)
248
  x = random.randint(0, max(1, img.width - sz))
249
  y = random.randint(0, max(1, img.height - sz))
250
  return img.crop((x, y, x + sz, y + sz))
@@ -254,11 +258,6 @@ tfm = transforms.Compose([
254
  transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
255
  ])
256
 
257
- dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
258
- print("len(dataset)",len(dataset))
259
- if len(dataset) < batch_size:
260
- raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")
261
-
262
  def collate_fn(batch):
263
  imgs = []
264
  for img in batch:
@@ -266,15 +265,12 @@ def collate_fn(batch):
266
  imgs.append(tfm(img))
267
  return torch.stack(imgs)
268
 
269
- dataloader = DataLoader(
270
- dataset,
271
- batch_size=batch_size,
272
- shuffle=True,
273
- collate_fn=collate_fn,
274
- num_workers=num_workers,
275
- pin_memory=True,
276
- drop_last=True
277
- )
278
 
279
  # --------------------------- Оптимизатор ---------------------------
280
  def get_param_groups(module, weight_decay=0.001):
@@ -350,15 +346,6 @@ def _get_lpips():
350
  _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device).eval()
351
  return _lpips_net
352
 
353
- _sobel_kx = torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]], dtype=torch.float32)
354
- _sobel_ky = torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]], dtype=torch.float32)
355
- def sobel_edges(x: torch.Tensor) -> torch.Tensor:
356
- C = x.shape[1]
357
- kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
358
- ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
359
- gx = F.conv2d(x, kx, padding=1, groups=C)
360
- gy = F.conv2d(x, ky, padding=1, groups=C)
361
- return torch.sqrt(gx * gx + gy * gy + 1e-12)
362
 
363
  class MedianLossNormalizer:
364
  def __init__(self, desired_ratios: dict, window_steps: int):
@@ -532,7 +519,6 @@ for epoch in range(num_epochs):
532
  "mse": F.mse_loss(rec_f32, imgs_f32),
533
  "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
534
  "fdl": fdl_loss(rec_f32, imgs_f32),
535
- "edge": F.l1_loss(sobel_edges(rec_f32), sobel_edges(imgs_f32)),
536
  }
537
 
538
  if full_training and not train_decoder_only:
 
26
  import lpips # pip install lpips
27
  from FDL_pytorch import FDL_loss # pip install fdl-pytorch
28
  from collections import deque
29
+ import torch.nn.functional as F
30
 
31
  # --------------------------- Параметры ---------------------------
32
  ds_path = "/workspace/d23"
33
+ project = "vae2"
34
  batch_size = 1
35
  base_learning_rate = 6e-6
36
  min_learning_rate = 7e-7
 
53
  mixed_precision = "no"
54
  gradient_accumulation_steps = 1
55
  generated_folder = "samples"
56
+ save_as = "vae3"
57
  num_workers = 0
58
  device = None
59
  torch.backends.cuda.matmul.allow_tf32 = True
 
74
  loss_ratios = {
75
  "lpips": 0.70,#0.50,
76
  "fdl" : 0.10,#0.25,
 
77
  "mse": 0.10,
78
+ "mae": 0.10,
79
  "kl": 0.00,
80
  }
81
  median_coeff_steps = 250
 
195
 
196
 
197
  print(f"[INFO] Разморожено параметров: {len(unfrozen_param_names)}. Первые 200 имён:")
198
+ for nm in unfrozen_param_names[:10]:
199
  print(" ", nm)
200
 
201
  # --------------------------- Датасет ---------------------------
202
+ from torch.utils.data import Dataset
203
+ from PIL import Image, UnidentifiedImageError
204
+ import random
205
+ import torchvision.transforms as transforms
206
+ import os
207
+
208
  class PngFolderDataset(Dataset):
209
+ def __init__(self, root_dir, resolution=1024, min_exts=('.png',), limit=0):
 
210
  self.resolution = resolution
211
  self.paths = []
212
+
213
  for root, _, files in os.walk(root_dir):
214
+ for f in files:
215
+ if f.lower().endswith(tuple(ext.lower() for ext in min_exts)):
216
+ self.paths.append(os.path.join(root, f))
217
+
218
  if limit:
219
  self.paths = self.paths[:limit]
220
+
221
+ # фильтруем недопустимые картинки
222
  valid = []
223
  for p in self.paths:
224
  try:
225
+ with Image.open(p) as img:
226
+ img.verify() # только метаданные
227
+ w, h = img.size
228
+ if w < resolution or h < resolution:
229
+ continue
230
  valid.append(p)
231
  except (OSError, UnidentifiedImageError):
232
  continue
233
+
234
  self.paths = valid
235
+ if not self.paths:
236
+ raise RuntimeError("No valid images found")
237
+
238
  random.shuffle(self.paths)
239
+ self.transform = transforms.ToTensor() # конвертирует сразу [0,1] float32
240
 
241
  def __len__(self):
242
  return len(self.paths)
 
245
  p = self.paths[idx % len(self.paths)]
246
  with Image.open(p) as img:
247
  img = img.convert("RGB")
248
+ return img
 
 
 
 
 
 
 
 
 
249
 
250
  def random_crop(img, sz):
251
  w, h = img.size
 
 
252
  x = random.randint(0, max(1, img.width - sz))
253
  y = random.randint(0, max(1, img.height - sz))
254
  return img.crop((x, y, x + sz, y + sz))
 
258
  transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
259
  ])
260
 
 
 
 
 
 
261
  def collate_fn(batch):
262
  imgs = []
263
  for img in batch:
 
265
  imgs.append(tfm(img))
266
  return torch.stack(imgs)
267
 
268
+ dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
269
+ print("len(dataset)",len(dataset))
270
+ if len(dataset) < batch_size:
271
+ raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")
272
+
273
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, drop_last=True)
 
 
 
274
 
275
  # --------------------------- Оптимизатор ---------------------------
276
  def get_param_groups(module, weight_decay=0.001):
 
346
  _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device).eval()
347
  return _lpips_net
348
 
 
 
 
 
 
 
 
 
 
349
 
350
  class MedianLossNormalizer:
351
  def __init__(self, desired_ratios: dict, window_steps: int):
 
519
  "mse": F.mse_loss(rec_f32, imgs_f32),
520
  "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
521
  "fdl": fdl_loss(rec_f32, imgs_f32),
 
522
  }
523
 
524
  if full_training and not train_decoder_only: