| import torch |
| import numpy as np |
| from PIL import Image |
| from diffusers import AutoencoderKL |
| from tqdm import tqdm |
| import pathlib |
|
|
| |
# Load the pretrained 32-channel VAE in float32 and move it to the GPU in
# eval mode (inference only; no dropout/batch-norm updates).
vae = AutoencoderKL.from_pretrained("vae32ch", torch_dtype=torch.float32)
vae.eval().cuda()


# Spatial downsampling factor of the encoder: every down block except the
# last halves the resolution, hence 2^(num_blocks - 1).
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
|
|
| |
# Locate every PNG under the dataset root. Sorting makes the file order
# deterministic across runs.
dataset_path = pathlib.Path("/workspace/ds")
found = sorted(dataset_path.rglob("*.png"))
print(f"Найдено картинок: {len(found)}")

# Cap the workload: statistics over at most 30k images are sufficient.
image_paths = found[:30000]
|
|
| |
def preprocess(path, scale_factor=None):
    """Load an RGB image and prepare it as a VAE input tensor.

    The image is center-cropped so both sides are divisible by the VAE's
    spatial downsampling factor (so the encoder produces exact latent
    shapes), then converted to a float tensor scaled to [-1, 1].

    Args:
        path: Path to the image file.
        scale_factor: Required divisor for width/height. Defaults to the
            module-level ``vae_scale_factor`` (backward-compatible with the
            previous implicit-global behavior).

    Returns:
        torch.Tensor of shape (1, 3, H, W), float32, values in [-1, 1].
    """
    if scale_factor is None:
        scale_factor = vae_scale_factor

    img = Image.open(path).convert("RGB")
    w, h = img.size

    # Largest multiples of scale_factor that fit inside the image.
    new_w = (w // scale_factor) * scale_factor
    new_h = (h // scale_factor) * scale_factor

    # Center-crop only when needed, keeping the subject centered.
    # NOTE(review): images smaller than scale_factor crop to zero size and
    # will fail at encode time; the caller catches and logs that.
    if new_w != w or new_h != h:
        left = (w - new_w) // 2
        top = (h - new_h) // 2
        img = img.crop((left, top, left + new_w, top + new_h))

    # HWC uint8 [0, 255] -> NCHW float32 in [-1, 1] (diffusers convention).
    x = torch.from_numpy(np.array(img).astype(np.float32) / 255.0)
    x = x.permute(2, 0, 1).unsqueeze(0)
    x = x * 2.0 - 1.0
    return x
|
|
| |
# Number of latent channels produced by the encoder — one mean/std entry
# per channel in the resulting statistics.
latent_channels = vae.config.latent_channels


all_means = []  # per-image, per-channel latent means
all_stds = []   # per-image, per-channel latent stds
errors = []     # (path, message) for images that failed to load or encode


with torch.no_grad():
    for path in tqdm(image_paths, desc="Encoding"):
        try:
            x = preprocess(path).cuda()
            # Sample from the encoder posterior, then flatten the spatial
            # dims so each row is one latent channel.
            lat = vae.encode(x).latent_dist.sample()
            flat = lat.squeeze(0).float().reshape(latent_channels, -1)

            all_means.append(flat.mean(dim=1).cpu())
            all_stds.append(flat.std(dim=1).cpu())

        except Exception as e:
            # Best-effort: record the failure and keep going instead of
            # aborting a multi-hour encoding run.
            errors.append((path, str(e)))


if errors:
    print(f"\nОшибки ({len(errors)}):")
    for p, e in errors:
        print(f"  {p}: {e}")


# Fail loudly when nothing was encoded; torch.stack([]) below would
# otherwise raise an opaque RuntimeError.
if not all_means:
    raise RuntimeError("No images were successfully encoded; cannot compute latent statistics.")

# NOTE: this averages per-image statistics instead of pooling all latent
# values; for the std this is an approximation of the true dataset-wide
# std, but it is the usual recipe for latents_mean/latents_std.
mean = torch.stack(all_means).mean(dim=0)
std = torch.stack(all_stds).mean(dim=0)


print(f"\nОбработано картинок: {len(all_means)}")
print(f"\nlatents_mean ({latent_channels} каналов):")
print(mean.tolist())
print(f"\nlatents_std ({latent_channels} каналов):")
print(std.tolist())
|
|
| |
# Rebuild the VAE with an architecture identical to the loaded model by
# copying the structural fields out of its config.
cfg = vae.config


_arch_keys = (
    "in_channels",
    "out_channels",
    "latent_channels",
    "block_out_channels",
    "layers_per_block",
    "norm_num_groups",
    "act_fn",
    "down_block_types",
    "up_block_types",
)
new_vae = AutoencoderKL(**{key: getattr(cfg, key) for key in _arch_keys})
new_vae.eval()
|
|
| |
# Copy the trained weights into the freshly built model. strict=False
# tolerates key mismatches; the printed result lists any missing/unexpected
# keys, so a partially failed transfer is visible in the log.
result = new_vae.load_state_dict(vae.state_dict(), strict=False)
print(f"\nВеса перенесены: {result}")


# Register the dataset statistics in the model config. scaling_factor=1.0
# and shift_factor=0.0 so downstream pipelines normalize latents purely via
# latents_mean / latents_std.
new_vae.register_to_config(
    latents_mean = mean.tolist(),
    latents_std = std.tolist(),
    scaling_factor = 1.0,
    shift_factor = 0.0,
)


# Sanity check: the first few config values should match the statistics
# printed earlier.
print(f"\nlatents_mean в конфиге: {new_vae.config.latents_mean[:4]}...")
print(f"latents_std в конфиге: {new_vae.config.latents_std[:4]}...")


# Persist weights + updated config in diffusers format.
new_vae.save_pretrained("vae32ch2")
print("\nСохранено в vae32ch2/")
|
|