File size: 4,311 Bytes
14c7142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
import numpy as np
from PIL import Image
from diffusers import AutoencoderKL
from tqdm import tqdm
import pathlib

# ── 1. Загружаем VAE ──────────────────────────────────────────────────────────
vae = AutoencoderKL.from_pretrained("vae32ch", torch_dtype=torch.float32)
vae.eval().cuda()

vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)  # = 8

# ── 2. Собираем все PNG рекурсивно ───────────────────────────────────────────
dataset_path = pathlib.Path("/workspace/ds")
image_paths  = sorted(dataset_path.rglob("*.png"))
print(f"Найдено картинок: {len(image_paths)}")

# Берём первые 3000
image_paths = image_paths[:30000]

# ── 3. Препроцессинг — кроп до кратного 8 без ресайза ────────────────────────
def preprocess(path):
    img = Image.open(path).convert("RGB")
    w, h = img.size

    new_w = (w // vae_scale_factor) * vae_scale_factor
    new_h = (h // vae_scale_factor) * vae_scale_factor

    if new_w != w or new_h != h:
        left = (w - new_w) // 2
        top  = (h - new_h) // 2
        img  = img.crop((left, top, left + new_w, top + new_h))

    x = torch.from_numpy(np.array(img).astype(np.float32) / 255.0)
    x = x.permute(2, 0, 1).unsqueeze(0)  # [1, 3, H, W]
    x = x * 2.0 - 1.0                    # [-1, 1]
    return x

# ── 4. Считаем статистику по каналам ─────────────────────────────────────────
latent_channels = vae.config.latent_channels  # 32

all_means = []  # [N, C]
all_stds  = []  # [N, C]
errors    = []

with torch.no_grad():
    for path in tqdm(image_paths, desc="Encoding"):
        try:
            x    = preprocess(path).cuda()
            lat  = vae.encode(x).latent_dist.sample()          # [1, C, H, W]
            flat = lat.squeeze(0).float().reshape(latent_channels, -1)  # [C, H*W]

            all_means.append(flat.mean(dim=1).cpu())  # [C]
            all_stds.append(flat.std(dim=1).cpu())    # [C]

        except Exception as e:
            errors.append((path, str(e)))

if errors:
    print(f"\nОшибки ({len(errors)}):")
    for p, e in errors:
        print(f"  {p}: {e}")

mean = torch.stack(all_means).mean(dim=0)  # [C]
std  = torch.stack(all_stds).mean(dim=0)   # [C]

print(f"\nОбработано картинок: {len(all_means)}")
print(f"\nlatents_mean ({latent_channels} каналов):")
print(mean.tolist())
print(f"\nlatents_std ({latent_channels} каналов):")
print(std.tolist())

# ── 5. Создаём новый VAE с той же архитектурой + scaling векторы ──────────────
cfg = vae.config

new_vae = AutoencoderKL(
    in_channels        = cfg.in_channels,
    out_channels       = cfg.out_channels,
    latent_channels    = cfg.latent_channels,
    block_out_channels = cfg.block_out_channels,
    layers_per_block   = cfg.layers_per_block,
    norm_num_groups    = cfg.norm_num_groups,
    act_fn             = cfg.act_fn,
    down_block_types   = cfg.down_block_types,
    up_block_types     = cfg.up_block_types,
)
new_vae.eval()

# Переносим веса
result = new_vae.load_state_dict(vae.state_dict(), strict=False)
print(f"\nВеса перенесены: {result}")

# Прописываем scaling векторы в конфиг
new_vae.register_to_config(
    latents_mean   = mean.tolist(),
    latents_std    = std.tolist(),
    scaling_factor = 1.0,
    shift_factor   = 0.0,
)

print(f"\nlatents_mean в конфиге: {new_vae.config.latents_mean[:4]}...")
print(f"latents_std  в конфиге: {new_vae.config.latents_std[:4]}...")

# ── 6. Сохраняем ──────────────────────────────────────────────────────────────
new_vae.save_pretrained("vae32ch2")
print("\nСохранено в vae32ch2/")