Upload folder using huggingface_hub
Browse files- .ipynb_checkpoints/config-checkpoint.json +103 -0
- Untitled.ipynb +326 -0
- config.json +103 -0
- diffusion_pytorch_model.safetensors +3 -0
- scale.py +107 -0
.ipynb_checkpoints/config-checkpoint.json
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKL",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"act_fn": "silu",
|
| 5 |
+
"block_out_channels": [
|
| 6 |
+
128,
|
| 7 |
+
256,
|
| 8 |
+
512,
|
| 9 |
+
512
|
| 10 |
+
],
|
| 11 |
+
"down_block_types": [
|
| 12 |
+
"DownEncoderBlock2D",
|
| 13 |
+
"DownEncoderBlock2D",
|
| 14 |
+
"DownEncoderBlock2D",
|
| 15 |
+
"DownEncoderBlock2D"
|
| 16 |
+
],
|
| 17 |
+
"force_upcast": true,
|
| 18 |
+
"in_channels": 3,
|
| 19 |
+
"latent_channels": 32,
|
| 20 |
+
"latents_mean": [
|
| 21 |
+
-0.03542253375053406,
|
| 22 |
+
0.20086465775966644,
|
| 23 |
+
-0.016413161531090736,
|
| 24 |
+
-0.0956302210688591,
|
| 25 |
+
-0.2672063112258911,
|
| 26 |
+
0.2609933018684387,
|
| 27 |
+
-0.07806991040706635,
|
| 28 |
+
-0.48407721519470215,
|
| 29 |
+
0.21844269335269928,
|
| 30 |
+
-0.1122383326292038,
|
| 31 |
+
0.27197545766830444,
|
| 32 |
+
-0.18958772718906403,
|
| 33 |
+
0.18776826560497284,
|
| 34 |
+
0.0987580344080925,
|
| 35 |
+
0.2837068736553192,
|
| 36 |
+
-0.4486690163612366,
|
| 37 |
+
0.4816776514053345,
|
| 38 |
+
0.02947971224784851,
|
| 39 |
+
-0.1337375044822693,
|
| 40 |
+
-0.39750921726226807,
|
| 41 |
+
-0.08513020724058151,
|
| 42 |
+
-0.054023586213588715,
|
| 43 |
+
-0.3943594992160797,
|
| 44 |
+
0.23918119072914124,
|
| 45 |
+
-0.12466679513454437,
|
| 46 |
+
0.09935147315263748,
|
| 47 |
+
0.31858691573143005,
|
| 48 |
+
0.48585832118988037,
|
| 49 |
+
-0.6416525840759277,
|
| 50 |
+
-0.15164820849895477,
|
| 51 |
+
-0.4693508744239807,
|
| 52 |
+
-0.13071806728839874
|
| 53 |
+
],
|
| 54 |
+
"latents_std": [
|
| 55 |
+
1.5792087316513062,
|
| 56 |
+
1.5769503116607666,
|
| 57 |
+
1.5864241123199463,
|
| 58 |
+
1.6454921960830688,
|
| 59 |
+
1.5336694717407227,
|
| 60 |
+
1.5587652921676636,
|
| 61 |
+
1.5838669538497925,
|
| 62 |
+
1.5659377574920654,
|
| 63 |
+
1.6860467195510864,
|
| 64 |
+
1.5192310810089111,
|
| 65 |
+
1.573639988899231,
|
| 66 |
+
1.5953549146652222,
|
| 67 |
+
1.5271092653274536,
|
| 68 |
+
1.6246271133422852,
|
| 69 |
+
1.7054023742675781,
|
| 70 |
+
1.607722282409668,
|
| 71 |
+
1.558642864227295,
|
| 72 |
+
1.5824549198150635,
|
| 73 |
+
1.6202995777130127,
|
| 74 |
+
1.6206320524215698,
|
| 75 |
+
1.6379750967025757,
|
| 76 |
+
1.6527063846588135,
|
| 77 |
+
1.498811960220337,
|
| 78 |
+
1.5706247091293335,
|
| 79 |
+
1.5854856967926025,
|
| 80 |
+
1.4828169345855713,
|
| 81 |
+
1.5693111419677734,
|
| 82 |
+
1.692481517791748,
|
| 83 |
+
1.6409776210784912,
|
| 84 |
+
1.6216280460357666,
|
| 85 |
+
1.6087706089019775,
|
| 86 |
+
1.5776633024215698
|
| 87 |
+
],
|
| 88 |
+
"layers_per_block": 2,
|
| 89 |
+
"mid_block_add_attention": true,
|
| 90 |
+
"norm_num_groups": 32,
|
| 91 |
+
"out_channels": 3,
|
| 92 |
+
"sample_size": 32,
|
| 93 |
+
"scaling_factor": 1.0,
|
| 94 |
+
"shift_factor": 0.0,
|
| 95 |
+
"up_block_types": [
|
| 96 |
+
"UpDecoderBlock2D",
|
| 97 |
+
"UpDecoderBlock2D",
|
| 98 |
+
"UpDecoderBlock2D",
|
| 99 |
+
"UpDecoderBlock2D"
|
| 100 |
+
],
|
| 101 |
+
"use_post_quant_conv": true,
|
| 102 |
+
"use_quant_conv": true
|
| 103 |
+
}
|
Untitled.ipynb
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 4,
|
| 6 |
+
"id": "dccce86b-90a0-47c7-aaad-2ebb16d90756",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stdout",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"Картинка загружена: torch.Size([1, 3, 1280, 1280])\n",
|
| 14 |
+
"\n",
|
| 15 |
+
"=======================================================\n",
|
| 16 |
+
"VAE : FLUX.2\n",
|
| 17 |
+
"repo: AiArtLab/sdxs-1b\n",
|
| 18 |
+
"latent_channels : 32\n",
|
| 19 |
+
"scaling_factor : 1.00000\n",
|
| 20 |
+
"shift_factor : 0.00000\n",
|
| 21 |
+
"latents_mean : нет\n",
|
| 22 |
+
"latents_std : нет\n",
|
| 23 |
+
"\n",
|
| 24 |
+
"[encode] raw latents: torch.Size([1, 32, 160, 160])\n",
|
| 25 |
+
"[flux2] patchify : torch.Size([1, 32, 160, 160]) → torch.Size([1, 128, 80, 80])\n",
|
| 26 |
+
"[flux2] BN norm : mean=-0.0096 std=1.7674\n",
|
| 27 |
+
"\n",
|
| 28 |
+
"[STATS] после BN нормализации (128ch):\n",
|
| 29 |
+
" log-variance : -0.0767 (идеал ≈ 0.0)\n",
|
| 30 |
+
" mean : -0.0134\n",
|
| 31 |
+
" std : 0.9624\n",
|
| 32 |
+
" shape : torch.Size([1, 128, 80, 80])\n",
|
| 33 |
+
"\n",
|
| 34 |
+
"[flux2] BN denorm + unpatchify: torch.Size([1, 32, 160, 160])\n",
|
| 35 |
+
"Сохранено: vaetest/decoded_FLUX.2.png\n",
|
| 36 |
+
"\n",
|
| 37 |
+
"=======================================================\n",
|
| 38 |
+
"VAE : vae32ch2\n",
|
| 39 |
+
"repo: vae32ch2\n",
|
| 40 |
+
"latent_channels : 32\n",
|
| 41 |
+
"scaling_factor : 1.00000\n",
|
| 42 |
+
"shift_factor : 0.00000\n",
|
| 43 |
+
"latents_mean : да (32ch)\n",
|
| 44 |
+
"latents_std : да (32ch)\n",
|
| 45 |
+
"\n",
|
| 46 |
+
"[encode] raw latents: torch.Size([1, 32, 160, 160])\n",
|
| 47 |
+
"\n",
|
| 48 |
+
"[STATS] после per-channel нормализации (32ch):\n",
|
| 49 |
+
" log-variance : 0.1192 (идеал ≈ 0.0)\n",
|
| 50 |
+
" mean : -0.0016\n",
|
| 51 |
+
" std : 1.0614\n",
|
| 52 |
+
" shape : torch.Size([1, 32, 160, 160])\n",
|
| 53 |
+
"\n",
|
| 54 |
+
"[vae32ch2] denorm: torch.Size([1, 32, 160, 160])\n",
|
| 55 |
+
"Сохранено: vaetest/decoded_vae32ch2.png\n",
|
| 56 |
+
"\n",
|
| 57 |
+
"=======================================================\n",
|
| 58 |
+
"Готово\n"
|
| 59 |
+
]
|
| 60 |
+
}
|
| 61 |
+
],
|
| 62 |
+
"source": [
|
| 63 |
+
"\n",
|
| 64 |
+
"import torch\n",
|
| 65 |
+
"from PIL import Image\n",
|
| 66 |
+
"from diffusers import AutoencoderKL, AutoencoderKLFlux2\n",
|
| 67 |
+
"from torchvision.transforms.functional import to_pil_image\n",
|
| 68 |
+
"import matplotlib.pyplot as plt\n",
|
| 69 |
+
"import os\n",
|
| 70 |
+
"from torchvision.transforms import ToTensor, Normalize, CenterCrop\n",
|
| 71 |
+
"\n",
|
| 72 |
+
"# ── Настройки ─────────────────────────────────────────────────────────────────\n",
|
| 73 |
+
"IMG_PATH = \"1234.png\"\n",
|
| 74 |
+
"OUT_DIR = \"vaetest\"\n",
|
| 75 |
+
"device = \"cuda\"\n",
|
| 76 |
+
"dtype = torch.float32\n",
|
| 77 |
+
"os.makedirs(OUT_DIR, exist_ok=True)\n",
|
| 78 |
+
"\n",
|
| 79 |
+
"VAES = {\n",
|
| 80 |
+
" \"FLUX.2\": (\"flux2\", \"AiArtLab/sdxs-1b\"),\n",
|
| 81 |
+
" \"vae32ch2\": (\"vae32ch\", \"vae32ch2\"),\n",
|
| 82 |
+
"}\n",
|
| 83 |
+
"\n",
|
| 84 |
+
"# ── Patchify / Unpatchify ─────────────────────────────────────────────────────\n",
|
| 85 |
+
"def _patchify_latents(latents):\n",
|
| 86 |
+
" B, C, H, W = latents.shape\n",
|
| 87 |
+
" latents = latents.view(B, C, H // 2, 2, W // 2, 2)\n",
|
| 88 |
+
" latents = latents.permute(0, 1, 3, 5, 2, 4)\n",
|
| 89 |
+
" latents = latents.reshape(B, C * 4, H // 2, W // 2)\n",
|
| 90 |
+
" return latents\n",
|
| 91 |
+
"\n",
|
| 92 |
+
"def _unpatchify_latents(latents):\n",
|
| 93 |
+
" B, C, H, W = latents.shape\n",
|
| 94 |
+
" latents = latents.reshape(B, C // 4, 2, 2, H, W)\n",
|
| 95 |
+
" latents = latents.permute(0, 1, 4, 2, 5, 3)\n",
|
| 96 |
+
" latents = latents.reshape(B, C // 4, H * 2, W * 2)\n",
|
| 97 |
+
" return latents\n",
|
| 98 |
+
"\n",
|
| 99 |
+
"# ── Загрузка картинки ─────────────────────────────────────────────────────────\n",
|
| 100 |
+
"def load_image(path):\n",
|
| 101 |
+
" img = Image.open(path).convert(\"RGB\")\n",
|
| 102 |
+
" w, h = img.size\n",
|
| 103 |
+
" img = CenterCrop((h // 8 * 8, w // 8 * 8))(img)\n",
|
| 104 |
+
" tensor = ToTensor()(img).unsqueeze(0)\n",
|
| 105 |
+
" tensor = Normalize(mean=[0.5]*3, std=[0.5]*3)(tensor)\n",
|
| 106 |
+
" return img, tensor.to(device, dtype=dtype)\n",
|
| 107 |
+
"\n",
|
| 108 |
+
"def tensor_to_img(t):\n",
|
| 109 |
+
" t = (t * 0.5 + 0.5).clamp(0, 1)\n",
|
| 110 |
+
" return to_pil_image(t[0])\n",
|
| 111 |
+
"\n",
|
| 112 |
+
"# ── Статистика ────────────────────────────────────────────────────────────────\n",
|
| 113 |
+
"def logvariance(latents):\n",
|
| 114 |
+
" return torch.log(latents.var() + 1e-8).item()\n",
|
| 115 |
+
"\n",
|
| 116 |
+
"def print_stats(name, latents):\n",
|
| 117 |
+
" lv = logvariance(latents)\n",
|
| 118 |
+
" print(f\" log-variance : {lv:.4f} (идеал ≈ 0.0)\")\n",
|
| 119 |
+
" print(f\" mean : {latents.mean():.4f}\")\n",
|
| 120 |
+
" print(f\" std : {latents.std():.4f}\")\n",
|
| 121 |
+
" print(f\" shape : {latents.shape}\")\n",
|
| 122 |
+
"\n",
|
| 123 |
+
"def plot_latent_distribution(latents, title, save_path):\n",
|
| 124 |
+
" from scipy.stats import probplot\n",
|
| 125 |
+
" lat = latents.detach().cpu().float().numpy().flatten()\n",
|
| 126 |
+
"\n",
|
| 127 |
+
" plt.figure(figsize=(10, 4))\n",
|
| 128 |
+
"\n",
|
| 129 |
+
" plt.subplot(1, 2, 1)\n",
|
| 130 |
+
" plt.hist(lat, bins=100, density=True, alpha=0.7, color=\"steelblue\")\n",
|
| 131 |
+
" plt.title(f\"{title} histogram\")\n",
|
| 132 |
+
" plt.xlabel(\"latent value\")\n",
|
| 133 |
+
" plt.ylabel(\"density\")\n",
|
| 134 |
+
"\n",
|
| 135 |
+
" plt.subplot(1, 2, 2)\n",
|
| 136 |
+
" probplot(lat, dist=\"norm\", plot=plt)\n",
|
| 137 |
+
" plt.title(f\"{title} QQ-plot\")\n",
|
| 138 |
+
"\n",
|
| 139 |
+
" plt.tight_layout()\n",
|
| 140 |
+
" plt.savefig(save_path)\n",
|
| 141 |
+
" plt.close()\n",
|
| 142 |
+
" print(f\" график сохранён: {save_path}\")\n",
|
| 143 |
+
"\n",
|
| 144 |
+
"# ── Нормализация из конфига (per-channel для vae32ch) ────────────────────────\n",
|
| 145 |
+
"def make_norm_tensors(cfg, latent_channels, device, dtype):\n",
|
| 146 |
+
" mean = getattr(cfg, \"latents_mean\", None)\n",
|
| 147 |
+
" std = getattr(cfg, \"latents_std\", None)\n",
|
| 148 |
+
" shift = getattr(cfg, \"shift_factor\", 0.0)\n",
|
| 149 |
+
" scale = getattr(cfg, \"scaling_factor\", 1.0)\n",
|
| 150 |
+
"\n",
|
| 151 |
+
" if mean is not None:\n",
|
| 152 |
+
" mean = torch.tensor(mean, device=device, dtype=dtype).view(1, latent_channels, 1, 1)\n",
|
| 153 |
+
" if std is not None:\n",
|
| 154 |
+
" std = torch.tensor(std, device=device, dtype=dtype).view(1, latent_channels, 1, 1)\n",
|
| 155 |
+
"\n",
|
| 156 |
+
" shift = torch.tensor(shift if shift else 0., device=device, dtype=dtype)\n",
|
| 157 |
+
" scale = torch.tensor(scale, device=device, dtype=dtype)\n",
|
| 158 |
+
" return mean, std, shift, scale\n",
|
| 159 |
+
"\n",
|
| 160 |
+
"# ── Основной цикл ─────────────────────────────────────────────────────────────\n",
|
| 161 |
+
"img, x = load_image(IMG_PATH)\n",
|
| 162 |
+
"img.save(os.path.join(OUT_DIR, \"original.png\"))\n",
|
| 163 |
+
"print(f\"Картинка загружена: {x.shape}\")\n",
|
| 164 |
+
"\n",
|
| 165 |
+
"for name, (kind, repo) in VAES.items():\n",
|
| 166 |
+
" print(f\"\\n{'='*55}\")\n",
|
| 167 |
+
" print(f\"VAE : {name}\")\n",
|
| 168 |
+
" print(f\"repo: {repo}\")\n",
|
| 169 |
+
"\n",
|
| 170 |
+
" # --- загружаем нужный класс ---\n",
|
| 171 |
+
" if kind == \"flux2\":\n",
|
| 172 |
+
" vae = AutoencoderKLFlux2.from_pretrained(\n",
|
| 173 |
+
" repo, subfolder=\"vae\", torch_dtype=dtype\n",
|
| 174 |
+
" ).to(device)\n",
|
| 175 |
+
" else:\n",
|
| 176 |
+
" vae = AutoencoderKL.from_pretrained(\n",
|
| 177 |
+
" repo, torch_dtype=dtype\n",
|
| 178 |
+
" ).to(device)\n",
|
| 179 |
+
" vae.eval()\n",
|
| 180 |
+
"\n",
|
| 181 |
+
" latent_channels = vae.config.latent_channels\n",
|
| 182 |
+
" mean_t, std_t, shift_t, scale_t = make_norm_tensors(\n",
|
| 183 |
+
" vae.config, latent_channels, device, dtype\n",
|
| 184 |
+
" )\n",
|
| 185 |
+
"\n",
|
| 186 |
+
" print(f\"latent_channels : {latent_channels}\")\n",
|
| 187 |
+
" print(f\"scaling_factor : {scale_t.item():.5f}\")\n",
|
| 188 |
+
" print(f\"shift_factor : {shift_t.item():.5f}\")\n",
|
| 189 |
+
" print(f\"latents_mean : {'да (' + str(latent_channels) + 'ch)' if mean_t is not None else 'нет'}\")\n",
|
| 190 |
+
" print(f\"latents_std : {'да (' + str(latent_channels) + 'ch)' if std_t is not None else 'нет'}\")\n",
|
| 191 |
+
"\n",
|
| 192 |
+
" with torch.no_grad():\n",
|
| 193 |
+
"\n",
|
| 194 |
+
" # ── ENCODE ────────────────────────────────────────────────────────────\n",
|
| 195 |
+
" latents = vae.encode(x).latent_dist.sample().to(dtype)\n",
|
| 196 |
+
" print(f\"\\n[encode] raw latents: {latents.shape}\")\n",
|
| 197 |
+
"\n",
|
| 198 |
+
" if kind == \"flux2\":\n",
|
| 199 |
+
" # 32ch → patchify → 128ch\n",
|
| 200 |
+
" latents_patched = _patchify_latents(latents)\n",
|
| 201 |
+
" print(f\"[flux2] patchify : {latents.shape} → {latents_patched.shape}\")\n",
|
| 202 |
+
"\n",
|
| 203 |
+
" # BN нормализация в 128-канальном пространстве\n",
|
| 204 |
+
" bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(device, dtype)\n",
|
| 205 |
+
" bn_std = torch.sqrt(\n",
|
| 206 |
+
" vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps\n",
|
| 207 |
+
" ).to(device, dtype)\n",
|
| 208 |
+
" latents_normed = (latents_patched - bn_mean) / bn_std\n",
|
| 209 |
+
" print(f\"[flux2] BN norm : mean={bn_mean.mean():.4f} std={bn_std.mean():.4f}\")\n",
|
| 210 |
+
"\n",
|
| 211 |
+
" # считаем статистику в 128ch нормализованном пространстве\n",
|
| 212 |
+
" print(\"\\n[STATS] после BN нормализации (128ch):\")\n",
|
| 213 |
+
" print_stats(name, latents_normed)\n",
|
| 214 |
+
" #plot_latent_distribution(\n",
|
| 215 |
+
" # latents_normed,\n",
|
| 216 |
+
" # f\"{name}_latents\",\n",
|
| 217 |
+
" # os.path.join(OUT_DIR, f\"dist_{name}.png\")\n",
|
| 218 |
+
" #)\n",
|
| 219 |
+
"\n",
|
| 220 |
+
" # unpatchify → 32ch (для decode)\n",
|
| 221 |
+
" latents = _unpatchify_latents(latents_normed)\n",
|
| 222 |
+
"\n",
|
| 223 |
+
" else: # vae32ch2\n",
|
| 224 |
+
" # per-channel нормализация из конфига\n",
|
| 225 |
+
" if mean_t is not None and std_t is not None:\n",
|
| 226 |
+
" latents = (latents - mean_t) / std_t\n",
|
| 227 |
+
" latents = (latents - shift_t) / scale_t\n",
|
| 228 |
+
"\n",
|
| 229 |
+
" print(f\"\\n[STATS] после per-channel нормализации ({latent_channels}ch):\")\n",
|
| 230 |
+
" print_stats(name, latents)\n",
|
| 231 |
+
" #plot_latent_distribution(\n",
|
| 232 |
+
" # latents,\n",
|
| 233 |
+
" # f\"{name}_latents\",\n",
|
| 234 |
+
" # os.path.join(OUT_DIR, f\"dist_{name}.png\")\n",
|
| 235 |
+
" #)\n",
|
| 236 |
+
"\n",
|
| 237 |
+
" # ── DECODE ────────────────────────────────────────────────────────────\n",
|
| 238 |
+
" if kind == \"flux2\":\n",
|
| 239 |
+
" # patchify → denorm → unpatchify\n",
|
| 240 |
+
" latents_patched = _patchify_latents(latents)\n",
|
| 241 |
+
" latents_denormed = latents_patched * bn_std + bn_mean\n",
|
| 242 |
+
" latents = _unpatchify_latents(latents_denormed)\n",
|
| 243 |
+
" print(f\"\\n[flux2] BN denorm + unpatchify: {latents.shape}\")\n",
|
| 244 |
+
"\n",
|
| 245 |
+
" else: # vae32ch2\n",
|
| 246 |
+
" latents = latents * scale_t + shift_t\n",
|
| 247 |
+
" if mean_t is not None and std_t is not None:\n",
|
| 248 |
+
" latents = latents * std_t + mean_t\n",
|
| 249 |
+
" print(f\"\\n[vae32ch2] denorm: {latents.shape}\")\n",
|
| 250 |
+
"\n",
|
| 251 |
+
" rec = vae.decode(latents).sample\n",
|
| 252 |
+
"\n",
|
| 253 |
+
" out_path = os.path.join(OUT_DIR, f\"decoded_{name}.png\")\n",
|
| 254 |
+
" tensor_to_img(rec).save(out_path)\n",
|
| 255 |
+
" print(f\"Сохранено: {out_path}\")\n",
|
| 256 |
+
"\n",
|
| 257 |
+
"print(f\"\\n{'='*55}\")\n",
|
| 258 |
+
"print(\"Готово\")"
|
| 259 |
+
]
|
| 260 |
+
},
|
| 261 |
+
{
|
| 262 |
+
"cell_type": "code",
|
| 263 |
+
"execution_count": 3,
|
| 264 |
+
"id": "c219c07b-8da2-4182-ace6-8c3cc63ae3b1",
|
| 265 |
+
"metadata": {},
|
| 266 |
+
"outputs": [
|
| 267 |
+
{
|
| 268 |
+
"name": "stdout",
|
| 269 |
+
"output_type": "stream",
|
| 270 |
+
"text": [
|
| 271 |
+
"Requirement already satisfied: scipy in /usr/local/lib/python3.12/dist-packages (1.17.1)\n",
|
| 272 |
+
"Requirement already satisfied: numpy<2.7,>=1.26.4 in /usr/local/lib/python3.12/dist-packages (from scipy) (2.4.0)\n",
|
| 273 |
+
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
|
| 274 |
+
"\u001b[0m"
|
| 275 |
+
]
|
| 276 |
+
},
|
| 277 |
+
{
|
| 278 |
+
"ename": "ModuleNotFoundError",
|
| 279 |
+
"evalue": "No module named 'scipy'",
|
| 280 |
+
"output_type": "error",
|
| 281 |
+
"traceback": [
|
| 282 |
+
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
| 283 |
+
"\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)",
|
| 284 |
+
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m get_ipython().system(\u001b[33m'\u001b[39m\u001b[33mpip install --user scipy\u001b[39m\u001b[33m'\u001b[39m)\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mscipy\u001b[39;00m\n\u001b[32m 4\u001b[39m \u001b[38;5;28mprint\u001b[39m(scipy.__version__)\n",
|
| 285 |
+
"\u001b[31mModuleNotFoundError\u001b[39m: No module named 'scipy'"
|
| 286 |
+
]
|
| 287 |
+
}
|
| 288 |
+
],
|
| 289 |
+
"source": [
|
| 290 |
+
"!pip install --user scipy\n",
|
| 291 |
+
"\n",
|
| 292 |
+
"import scipy\n",
|
| 293 |
+
"print(scipy.__version__)\n"
|
| 294 |
+
]
|
| 295 |
+
},
|
| 296 |
+
{
|
| 297 |
+
"cell_type": "code",
|
| 298 |
+
"execution_count": null,
|
| 299 |
+
"id": "43a4e1bc-2b02-4604-b69e-1a5aa276b6ac",
|
| 300 |
+
"metadata": {},
|
| 301 |
+
"outputs": [],
|
| 302 |
+
"source": []
|
| 303 |
+
}
|
| 304 |
+
],
|
| 305 |
+
"metadata": {
|
| 306 |
+
"kernelspec": {
|
| 307 |
+
"display_name": "Python3 (ipykernel)",
|
| 308 |
+
"language": "python",
|
| 309 |
+
"name": "python3"
|
| 310 |
+
},
|
| 311 |
+
"language_info": {
|
| 312 |
+
"codemirror_mode": {
|
| 313 |
+
"name": "ipython",
|
| 314 |
+
"version": 3
|
| 315 |
+
},
|
| 316 |
+
"file_extension": ".py",
|
| 317 |
+
"mimetype": "text/x-python",
|
| 318 |
+
"name": "python",
|
| 319 |
+
"nbconvert_exporter": "python",
|
| 320 |
+
"pygments_lexer": "ipython3",
|
| 321 |
+
"version": "3.12.12"
|
| 322 |
+
}
|
| 323 |
+
},
|
| 324 |
+
"nbformat": 4,
|
| 325 |
+
"nbformat_minor": 5
|
| 326 |
+
}
|
config.json
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "AutoencoderKL",
|
| 3 |
+
"_diffusers_version": "0.36.0",
|
| 4 |
+
"act_fn": "silu",
|
| 5 |
+
"block_out_channels": [
|
| 6 |
+
128,
|
| 7 |
+
256,
|
| 8 |
+
512,
|
| 9 |
+
512
|
| 10 |
+
],
|
| 11 |
+
"down_block_types": [
|
| 12 |
+
"DownEncoderBlock2D",
|
| 13 |
+
"DownEncoderBlock2D",
|
| 14 |
+
"DownEncoderBlock2D",
|
| 15 |
+
"DownEncoderBlock2D"
|
| 16 |
+
],
|
| 17 |
+
"force_upcast": true,
|
| 18 |
+
"in_channels": 3,
|
| 19 |
+
"latent_channels": 32,
|
| 20 |
+
"latents_mean": [
|
| 21 |
+
-0.03542253375053406,
|
| 22 |
+
0.20086465775966644,
|
| 23 |
+
-0.016413161531090736,
|
| 24 |
+
-0.0956302210688591,
|
| 25 |
+
-0.2672063112258911,
|
| 26 |
+
0.2609933018684387,
|
| 27 |
+
-0.07806991040706635,
|
| 28 |
+
-0.48407721519470215,
|
| 29 |
+
0.21844269335269928,
|
| 30 |
+
-0.1122383326292038,
|
| 31 |
+
0.27197545766830444,
|
| 32 |
+
-0.18958772718906403,
|
| 33 |
+
0.18776826560497284,
|
| 34 |
+
0.0987580344080925,
|
| 35 |
+
0.2837068736553192,
|
| 36 |
+
-0.4486690163612366,
|
| 37 |
+
0.4816776514053345,
|
| 38 |
+
0.02947971224784851,
|
| 39 |
+
-0.1337375044822693,
|
| 40 |
+
-0.39750921726226807,
|
| 41 |
+
-0.08513020724058151,
|
| 42 |
+
-0.054023586213588715,
|
| 43 |
+
-0.3943594992160797,
|
| 44 |
+
0.23918119072914124,
|
| 45 |
+
-0.12466679513454437,
|
| 46 |
+
0.09935147315263748,
|
| 47 |
+
0.31858691573143005,
|
| 48 |
+
0.48585832118988037,
|
| 49 |
+
-0.6416525840759277,
|
| 50 |
+
-0.15164820849895477,
|
| 51 |
+
-0.4693508744239807,
|
| 52 |
+
-0.13071806728839874
|
| 53 |
+
],
|
| 54 |
+
"latents_std": [
|
| 55 |
+
1.5792087316513062,
|
| 56 |
+
1.5769503116607666,
|
| 57 |
+
1.5864241123199463,
|
| 58 |
+
1.6454921960830688,
|
| 59 |
+
1.5336694717407227,
|
| 60 |
+
1.5587652921676636,
|
| 61 |
+
1.5838669538497925,
|
| 62 |
+
1.5659377574920654,
|
| 63 |
+
1.6860467195510864,
|
| 64 |
+
1.5192310810089111,
|
| 65 |
+
1.573639988899231,
|
| 66 |
+
1.5953549146652222,
|
| 67 |
+
1.5271092653274536,
|
| 68 |
+
1.6246271133422852,
|
| 69 |
+
1.7054023742675781,
|
| 70 |
+
1.607722282409668,
|
| 71 |
+
1.558642864227295,
|
| 72 |
+
1.5824549198150635,
|
| 73 |
+
1.6202995777130127,
|
| 74 |
+
1.6206320524215698,
|
| 75 |
+
1.6379750967025757,
|
| 76 |
+
1.6527063846588135,
|
| 77 |
+
1.498811960220337,
|
| 78 |
+
1.5706247091293335,
|
| 79 |
+
1.5854856967926025,
|
| 80 |
+
1.4828169345855713,
|
| 81 |
+
1.5693111419677734,
|
| 82 |
+
1.692481517791748,
|
| 83 |
+
1.6409776210784912,
|
| 84 |
+
1.6216280460357666,
|
| 85 |
+
1.6087706089019775,
|
| 86 |
+
1.5776633024215698
|
| 87 |
+
],
|
| 88 |
+
"layers_per_block": 2,
|
| 89 |
+
"mid_block_add_attention": true,
|
| 90 |
+
"norm_num_groups": 32,
|
| 91 |
+
"out_channels": 3,
|
| 92 |
+
"sample_size": 32,
|
| 93 |
+
"scaling_factor": 1.0,
|
| 94 |
+
"shift_factor": 0.0,
|
| 95 |
+
"up_block_types": [
|
| 96 |
+
"UpDecoderBlock2D",
|
| 97 |
+
"UpDecoderBlock2D",
|
| 98 |
+
"UpDecoderBlock2D",
|
| 99 |
+
"UpDecoderBlock2D"
|
| 100 |
+
],
|
| 101 |
+
"use_post_quant_conv": true,
|
| 102 |
+
"use_quant_conv": true
|
| 103 |
+
}
|
diffusion_pytorch_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6369e370ff02168a240a9ebfd47810dd7babb36f76b7d9999e5d78cb4a1976c2
|
| 3 |
+
size 336212308
|
scale.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from diffusers import AutoencoderKL
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import pathlib
|
| 7 |
+
|
| 8 |
+
# ── 1. Загружаем VAE ──────────────────────────────────────────────────────────
|
| 9 |
+
vae = AutoencoderKL.from_pretrained("vae32ch", torch_dtype=torch.float32)
|
| 10 |
+
vae.eval().cuda()
|
| 11 |
+
|
| 12 |
+
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) # = 8
|
| 13 |
+
|
| 14 |
+
# ── 2. Собираем все PNG рекурсивно ───────────────────────────────────────────
|
| 15 |
+
dataset_path = pathlib.Path("/workspace/ds")
|
| 16 |
+
image_paths = sorted(dataset_path.rglob("*.png"))
|
| 17 |
+
print(f"Найдено картинок: {len(image_paths)}")
|
| 18 |
+
|
| 19 |
+
# Берём первые 3000
|
| 20 |
+
image_paths = image_paths[:30000]
|
| 21 |
+
|
| 22 |
+
# ── 3. Препроцессинг — кроп до кратного 8 без ресайза ────────────────────────
|
| 23 |
+
def preprocess(path):
|
| 24 |
+
img = Image.open(path).convert("RGB")
|
| 25 |
+
w, h = img.size
|
| 26 |
+
|
| 27 |
+
new_w = (w // vae_scale_factor) * vae_scale_factor
|
| 28 |
+
new_h = (h // vae_scale_factor) * vae_scale_factor
|
| 29 |
+
|
| 30 |
+
if new_w != w or new_h != h:
|
| 31 |
+
left = (w - new_w) // 2
|
| 32 |
+
top = (h - new_h) // 2
|
| 33 |
+
img = img.crop((left, top, left + new_w, top + new_h))
|
| 34 |
+
|
| 35 |
+
x = torch.from_numpy(np.array(img).astype(np.float32) / 255.0)
|
| 36 |
+
x = x.permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W]
|
| 37 |
+
x = x * 2.0 - 1.0 # [-1, 1]
|
| 38 |
+
return x
|
| 39 |
+
|
| 40 |
+
# ── 4. Считаем статистику по каналам ─────────────────────────────────────────
|
| 41 |
+
latent_channels = vae.config.latent_channels # 32
|
| 42 |
+
|
| 43 |
+
all_means = [] # [N, C]
|
| 44 |
+
all_stds = [] # [N, C]
|
| 45 |
+
errors = []
|
| 46 |
+
|
| 47 |
+
with torch.no_grad():
|
| 48 |
+
for path in tqdm(image_paths, desc="Encoding"):
|
| 49 |
+
try:
|
| 50 |
+
x = preprocess(path).cuda()
|
| 51 |
+
lat = vae.encode(x).latent_dist.sample() # [1, C, H, W]
|
| 52 |
+
flat = lat.squeeze(0).float().reshape(latent_channels, -1) # [C, H*W]
|
| 53 |
+
|
| 54 |
+
all_means.append(flat.mean(dim=1).cpu()) # [C]
|
| 55 |
+
all_stds.append(flat.std(dim=1).cpu()) # [C]
|
| 56 |
+
|
| 57 |
+
except Exception as e:
|
| 58 |
+
errors.append((path, str(e)))
|
| 59 |
+
|
| 60 |
+
if errors:
|
| 61 |
+
print(f"\nОшибки ({len(errors)}):")
|
| 62 |
+
for p, e in errors:
|
| 63 |
+
print(f" {p}: {e}")
|
| 64 |
+
|
| 65 |
+
mean = torch.stack(all_means).mean(dim=0) # [C]
|
| 66 |
+
std = torch.stack(all_stds).mean(dim=0) # [C]
|
| 67 |
+
|
| 68 |
+
print(f"\nОбработано картинок: {len(all_means)}")
|
| 69 |
+
print(f"\nlatents_mean ({latent_channels} каналов):")
|
| 70 |
+
print(mean.tolist())
|
| 71 |
+
print(f"\nlatents_std ({latent_channels} каналов):")
|
| 72 |
+
print(std.tolist())
|
| 73 |
+
|
| 74 |
+
# ── 5. Создаём новый VAE с той же архитектурой + scaling векторы ──────────────
|
| 75 |
+
cfg = vae.config
|
| 76 |
+
|
| 77 |
+
new_vae = AutoencoderKL(
|
| 78 |
+
in_channels = cfg.in_channels,
|
| 79 |
+
out_channels = cfg.out_channels,
|
| 80 |
+
latent_channels = cfg.latent_channels,
|
| 81 |
+
block_out_channels = cfg.block_out_channels,
|
| 82 |
+
layers_per_block = cfg.layers_per_block,
|
| 83 |
+
norm_num_groups = cfg.norm_num_groups,
|
| 84 |
+
act_fn = cfg.act_fn,
|
| 85 |
+
down_block_types = cfg.down_block_types,
|
| 86 |
+
up_block_types = cfg.up_block_types,
|
| 87 |
+
)
|
| 88 |
+
new_vae.eval()
|
| 89 |
+
|
| 90 |
+
# Переносим веса
|
| 91 |
+
result = new_vae.load_state_dict(vae.state_dict(), strict=False)
|
| 92 |
+
print(f"\nВеса перенесены: {result}")
|
| 93 |
+
|
| 94 |
+
# Прописываем scaling векторы в конфиг
|
| 95 |
+
new_vae.register_to_config(
|
| 96 |
+
latents_mean = mean.tolist(),
|
| 97 |
+
latents_std = std.tolist(),
|
| 98 |
+
scaling_factor = 1.0,
|
| 99 |
+
shift_factor = 0.0,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
print(f"\nlatents_mean в конфиге: {new_vae.config.latents_mean[:4]}...")
|
| 103 |
+
print(f"latents_std в конфиге: {new_vae.config.latents_std[:4]}...")
|
| 104 |
+
|
| 105 |
+
# ── 6. Сохраняем ──────────────────────────────────────────────────────────────
|
| 106 |
+
new_vae.save_pretrained("vae32ch2")
|
| 107 |
+
print("\nСохранено в vae32ch2/")
|