Spaces:
Paused
Paused
lanny xu
commited on
Commit
ยท
6b60426
1
Parent(s):
447a3ac
delete VAE all files
Browse files- Channel่ฎฒ่งฃ.txt +0 -337
- vae_model_structure.py +0 -162
- vae_structrue.txt +0 -506
- vae_training_example.py +0 -239
Channel่ฎฒ่งฃ.txt
DELETED
|
@@ -1,337 +0,0 @@
|
|
| 1 |
-
ไปไนๆฏ้้๏ผChannel๏ผ๏ผ
|
| 2 |
-
1. ๅพๅ็้้
|
| 3 |
-
้้็ๅบๆฌๆฆๅฟต๏ผ
|
| 4 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 5 |
-
|
| 6 |
-
้้ = ๅพๅ็"็ปดๅบฆ"ๆ"ๅฑ"
|
| 7 |
-
ๆฏไธช้้ๆฏไธไธช 2D ็ฉ้ต๏ผๅ
ๅซ็นๅฎ็ไฟกๆฏ
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
็คบไพ 1: ็ฐๅบฆๅพๅ๏ผ1 ้้๏ผ
|
| 11 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 12 |
-
|
| 13 |
-
ๅฝข็ถ: (1, 28, 28)
|
| 14 |
-
โ โ โ
|
| 15 |
-
้้ๆฐ ้ซ ๅฎฝ
|
| 16 |
-
|
| 17 |
-
ๅชๆไธไธช้้๏ผๅญๅจ็ฐๅบฆๅผ:
|
| 18 |
-
โโโโโโโโโโโโโโโโโโโโโโโ
|
| 19 |
-
โ ้้ 0 (็ฐๅบฆๅผ) โ
|
| 20 |
-
โ โโโโโโโโโโโโโโโโโโโโ
|
| 21 |
-
โ [0.2, 0.5, 0.8, ...]โ
|
| 22 |
-
โ [0.1, 0.9, 0.3, ...]โ
|
| 23 |
-
โ [0.7, 0.4, 0.6, ...]โ
|
| 24 |
-
โ ... โ
|
| 25 |
-
โ โ
|
| 26 |
-
โ 28ร28 ็็ฉ้ต โ
|
| 27 |
-
โโโโโโโโโโโโโโโโโโโโโโโ
|
| 28 |
-
|
| 29 |
-
ๆฏไธชๅ็ด ๅผ: 0.0 (้ป) ~ 1.0 (็ฝ)
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
็คบไพ 2: RGB ๅฝฉ่ฒๅพๅ๏ผ3 ้้๏ผ
|
| 33 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 34 |
-
|
| 35 |
-
ๅฝข็ถ: (3, 224, 224)
|
| 36 |
-
โ โ โ
|
| 37 |
-
3้้ ้ซ ๅฎฝ
|
| 38 |
-
|
| 39 |
-
ๆไธไธช้้๏ผๅๅซๅญๅจ RGB ไฟกๆฏ:
|
| 40 |
-
|
| 41 |
-
โโโโโโโโโโโโโโโโโโโโโโโ โโโโโโโโโโโโโโโโโโโโโโโ โโโโโโโโโโโโโโโโโโโโโโโ
|
| 42 |
-
โ ้้ 0 (Red) โ โ ้้ 1 (Green) โ โ ้้ 2 (Blue) โ
|
| 43 |
-
โ โโโโโโโโโโโโโโโโโโโโ โ โโโโโโโโโโโโโโโโโโโโ โ โโโโโโโโโโโโโโโโโโโโ
|
| 44 |
-
โ [0.9, 0.2, 0.1, ...]โ โ [0.1, 0.8, 0.3, ...]โ โ [0.2, 0.3, 0.9, ...]โ
|
| 45 |
-
โ [0.8, 0.3, 0.2, ...]โ โ [0.2, 0.7, 0.4, ...]โ โ [0.1, 0.4, 0.8, ...]โ
|
| 46 |
-
โ [0.7, 0.4, 0.3, ...]โ โ [0.3, 0.6, 0.5, ...]โ โ [0.3, 0.5, 0.7, ...]โ
|
| 47 |
-
โ ... โ โ ... โ โ ... โ
|
| 48 |
-
โ 224ร224 โ โ 224ร224 โ โ 224ร224 โ
|
| 49 |
-
โโโโโโโโโโโโโโโโโโโโโโโ โโโโโโโโโโโโโโโโโโโโโโโ โโโโโโโโโโโโโโโโโโโโโโโ
|
| 50 |
-
็บข่ฒๅผบๅบฆ ็ปฟ่ฒๅผบๅบฆ ่่ฒๅผบๅบฆ
|
| 51 |
-
|
| 52 |
-
ไธไธชๅ็ด ็น็ๅฎๆด้ข่ฒ = (R, G, B)
|
| 53 |
-
ไพๅฆ: ไฝ็ฝฎ (10, 15) ็้ข่ฒ = (0.9, 0.1, 0.2) โ ็บข่ฒๅๅค
|
| 54 |
-
|
| 55 |
-
2. ๅท็งฏๅ็้้๏ผ็นๅพๅพ Feature Map๏ผ
|
| 56 |
-
1 ้้ โ 32 ้้็ๅซไน๏ผ
|
| 57 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 58 |
-
|
| 59 |
-
่พๅ
ฅ: (1, 28, 28) # 1 ไธช็ฐๅบฆ้้
|
| 60 |
-
โ ๅท็งฏๆไฝ
|
| 61 |
-
่พๅบ: (32, 14, 14) # 32 ไธช็นๅพ้้
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
ไปไนๆฏ 32 ไธช้้๏ผ
|
| 65 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 66 |
-
|
| 67 |
-
ๆฏไธช้้ = ไธไธชๅท็งฏๆ ธ๏ผๆปคๆณขๅจ๏ผๆฃๆตๅฐ็็นๅพ
|
| 68 |
-
32 ไธช้้ = 32 ไธชไธๅ็็นๅพๆฃๆตๅจ
|
| 69 |
-
|
| 70 |
-
โโโโโโโโโโโโโโโโโโโโโโโ
|
| 71 |
-
โ ้้ 0: ๆฃๆต่พน็ผ โ 14ร14 ็ฉ้ต
|
| 72 |
-
โ โโโโโโโโโโโโโโโโโโโโ
|
| 73 |
-
โ [0.8, 0.1, 0.0, ...]โ ้ซๆฟๆดปๅผ โ ๆฃๆตๅฐ่พน็ผ
|
| 74 |
-
โ [0.9, 0.2, 0.0, ...]โ
|
| 75 |
-
โ ... โ
|
| 76 |
-
โโโโโโโโโโโโโโโโโโโโโโโ
|
| 77 |
-
|
| 78 |
-
โโโโโโโโโโโโโโโโโโโโโโโ
|
| 79 |
-
โ ้้ 1: ๆฃๆตๅๅฝข โ 14ร14 ็ฉ้ต
|
| 80 |
-
โ โโโโโโโโโโโโโโโโโโโโ
|
| 81 |
-
โ [0.0, 0.7, 0.0, ...]โ ้ซๆฟๆดปๅผ โ ๆฃๆตๅฐๅๅฝข
|
| 82 |
-
โ [0.0, 0.9, 0.1, ...]โ
|
| 83 |
-
โ ... โ
|
| 84 |
-
โโโโโโโโโโโโโโโโโโโโโโโ
|
| 85 |
-
|
| 86 |
-
โโโโโโโโโโโโโโโโโโโโโโโ
|
| 87 |
-
โ ้้ 2: ๆฃๆต็บน็ โ 14ร14 ็ฉ้ต
|
| 88 |
-
โ โโโโโโโโโโโโโโโโโโโโ
|
| 89 |
-
โ [0.3, 0.4, 0.8, ...]โ
|
| 90 |
-
โ ... โ
|
| 91 |
-
โโโโโโโโโโโโโโโโโโโโโโโ
|
| 92 |
-
|
| 93 |
-
...
|
| 94 |
-
|
| 95 |
-
โโโโโโโโโโโโโโโโโโโโโโโ
|
| 96 |
-
โ ้้ 31: ๆฃๆตๅคๆๆจกๅผโ 14ร14 ็ฉ้ต
|
| 97 |
-
โ โโโโโโโโโโโโโโโโโโโโ
|
| 98 |
-
โ [0.2, 0.6, 0.3, ...]โ
|
| 99 |
-
โ ... โ
|
| 100 |
-
โโโโโโโโโโโโโโโโโโโโโโโ
|
| 101 |
-
|
| 102 |
-
ๆปๅ
ฑ 32 ไธช็นๅพๅพ๏ผๆฏไธช้ฝๆฏ 14ร14 ็็ฉ้ต
|
| 103 |
-
|
| 104 |
-
3. ๅท็งฏๅฆไฝไบง็ๅค้้๏ผ
|
| 105 |
-
|
| 106 |
-
Conv2d(in_channels=1, out_channels=32, kernel_size=4)
|
| 107 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ๏ฟฝ๏ฟฝ๏ฟฝโโโโโโโโโโโโโโโโโโโ
|
| 108 |
-
|
| 109 |
-
ๅๆฐ่งฃ้๏ผ
|
| 110 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 111 |
-
in_channels=1: ่พๅ
ฅๆ 1 ไธช้้๏ผ็ฐๅบฆๅพ๏ผ
|
| 112 |
-
out_channels=32: ่พๅบๆ 32 ไธช้้๏ผ32 ไธช็นๅพๅพ๏ผ
|
| 113 |
-
kernel_size=4: ๆฏไธชๅท็งฏๆ ธๆฏ 4ร4 ็็ฉ้ต
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
ๅ
้จๆ 32 ไธชๅท็งฏๆ ธ๏ผๆปคๆณขๅจ๏ผ๏ผ
|
| 117 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 118 |
-
|
| 119 |
-
ๅท็งฏๆ ธ 1 (4ร4): ๅท็งฏๆ ธ 2 (4ร4): ๅท็งฏๆ ธ 32 (4ร4):
|
| 120 |
-
โโโโโโโโโโโโโโโโโ โโโโโโโโโโโโโโโโโ โโโโโโโโโโโโโโโโโ
|
| 121 |
-
โ [ 1 1 1 1] โ โ [ 1 0 -1 0] โ โ [ 0 1 0 1] โ
|
| 122 |
-
โ [ 1 1 1 1] โ โ [ 0 1 0 -1] โ โ [ 1 0 1 0] โ
|
| 123 |
-
โ [-1 -1 -1 -1] โ โ [-1 0 1 0] โ โ [ 0 1 0 1] โ
|
| 124 |
-
โ [-1 -1 -1 -1] โ โ [ 0 -1 0 1] โ โ [ 1 0 1 0] โ
|
| 125 |
-
โโโโโโโโโโโโโโโโโ โโโโโโโโโโโโโโโโโ โโโโโโโโโโโโโโโโโ
|
| 126 |
-
ๆฃๆตๆฐดๅนณ่พน็ผ ๆฃๆตๅฏน่ง็บฟ ๆฃๆตๆฃ็ๆ ผ
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
ๅท็งฏ่ฟ็จ๏ผ
|
| 130 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 131 |
-
|
| 132 |
-
่พๅ
ฅๅพๅ (1, 28, 28):
|
| 133 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 134 |
-
โ ่พๅ
ฅ้้ 0 โ
|
| 135 |
-
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 136 |
-
โ [0.2, 0.5, 0.8, 0.1, ...] โ
|
| 137 |
-
โ [0.1, 0.9, 0.3, 0.4, ...] โ
|
| 138 |
-
โ [0.7, 0.4, 0.6, 0.2, ...] โ
|
| 139 |
-
โ ... โ
|
| 140 |
-
โ 28ร28 โ
|
| 141 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 142 |
-
โ
|
| 143 |
-
โโโโโโโโโโโโโโโโโโโฌโโโโโโโโโโโโโโโโโฌโโโโโโโโโโโโ ... โโโ
|
| 144 |
-
โ โ โ โ
|
| 145 |
-
ๅท็งฏๆ ธ 1 ๆซๆ ๅท็งฏๆ ธ 2 ๆซๆ ๅท็งฏๆ ธ 3 ๆซๆ ... ๅท็งฏๆ ธ 32 ๆซๆ
|
| 146 |
-
โ โ โ โ
|
| 147 |
-
่พๅบ้้ 0 ่พๅบ้้ 1 ่พๅบ้้ 2 ... ่พๅบ้้ 31
|
| 148 |
-
(14ร14) (14ร14) (14ร14) (14ร14)
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
ๆ็ป่พๅบ (32, 14, 14):
|
| 152 |
-
32 ไธช็นๅพๅพ๏ผๆฏไธช 14ร14
|
| 153 |
-
|
| 154 |
-
4. ไธบไปไน้่ฆๅค้้๏ผ
|
| 155 |
-
ๅค้้็ไฝ็จ๏ผ
|
| 156 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 157 |
-
|
| 158 |
-
1. ๆๅไธๅ็็นๅพ
|
| 159 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 160 |
-
|
| 161 |
-
่พๅ
ฅ: ๆๅๆฐๅญ "8" ็ๅพๅ
|
| 162 |
-
|
| 163 |
-
้้ 0: ๆฃๆตๅ็ด่พน็ผ โ ๆฟๆดปๅผ้ซๅจ 8 ็ไธคไธช็ซ็บฟ
|
| 164 |
-
้้ 1: ๆฃๆตๆฐดๅนณ่พน็ผ โ ๆฟๆดปๅผ้ซๅจ 8 ็ไธไธญไธๆจช็บฟ
|
| 165 |
-
้้ 2: ๆฃๆตๅๅฝข โ ๆฟๆดปๅผ้ซๅจ 8 ็ไธไธไธคไธชๅ
|
| 166 |
-
้้ 3: ๆฃๆตไบคๅ็น โ ๆฟๆดปๅผ้ซๅจ 8 ็ไธญ้ดไบคๅๅค
|
| 167 |
-
...
|
| 168 |
-
้้ 31: ๆฃๆตๅคๆ็บน็ โ ๆฟๆดปๅผ้ซๅจ็นๅฎไฝ็ฝฎ
|
| 169 |
-
|
| 170 |
-
้่ฟ 32 ไธชไธๅ็็นๅพ๏ผๆจกๅๅฏไปฅไปไธๅ่งๅบฆ็่งฃๅพๅ๏ผ
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
2. ้ๅฑๆๅๆฝ่ฑก็นๅพ
|
| 174 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 175 |
-
|
| 176 |
-
Layer 1 (1โ32 ้้):
|
| 177 |
-
ๆฃๆตไฝ็บง็นๅพ: ่พน็ผใ็บฟๆกใ่ง็น
|
| 178 |
-
|
| 179 |
-
Layer 2 (32โ64 ้้):
|
| 180 |
-
็ปๅไฝ็บง็นๅพ๏ผๆฃๆตไธญ็บง็นๅพ: ๆฒ็บฟใ็ฎๅๅฝข็ถ
|
| 181 |
-
|
| 182 |
-
Layer 3 (64โ128 ้้):
|
| 183 |
-
็ปๅไธญ็บง็นๅพ๏ผๆฃๆต้ซ็บง็นๅพ: ๅฎๆด็ๆฐๅญๅฝข็ถ
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
3. ๅขๅผบ่กจ่พพ่ฝๅ
|
| 187 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 188 |
-
|
| 189 |
-
ๅฆๆๅชๆ 1 ไธช้้๏ผ1 ไธชๅท็งฏๆ ธ๏ผ:
|
| 190 |
-
โ ๅช่ฝๆฃๆตไธ็งๆจกๅผ
|
| 191 |
-
โ ่กจ่พพ่ฝๅๆ้
|
| 192 |
-
โ ๅ็กฎ็ไฝ
|
| 193 |
-
|
| 194 |
-
ๆ 32 ไธช้้๏ผ32 ไธชๅท็งฏๆ ธ๏ผ:
|
| 195 |
-
โ
ๅฏไปฅๅๆถๆฃๆต 32 ็งไธๅ็ๆจกๅผ
|
| 196 |
-
โ
่กจ่พพ่ฝๅๅผบ
|
| 197 |
-
โ
ๅ็กฎ็้ซ
|
| 198 |
-
|
| 199 |
-
5. ้้ๆฐ็้ๆฉ
|
| 200 |
-
ไธบไปไน้ๆฉ 32ใ64ใ128 ็ญ๏ผ
|
| 201 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 202 |
-
|
| 203 |
-
ๅ
ธๅ็ CNN ้้ๆฐ่ฎพ่ฎก๏ผ
|
| 204 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 205 |
-
|
| 206 |
-
่พๅ
ฅ: 1 ้้๏ผ็ฐๅบฆๅพ๏ผๆ 3 ้้๏ผRGB ๅพ๏ผ
|
| 207 |
-
โ
|
| 208 |
-
Conv1: 32 ๏ฟฝ๏ฟฝ้ # ๆๅๅบ็ก็นๅพ
|
| 209 |
-
โ
|
| 210 |
-
Conv2: 64 ้้ # ็นๅพๆฐ้็ฟปๅ
|
| 211 |
-
โ
|
| 212 |
-
Conv3: 128 ้้ # ่ฟไธๆญฅๅขๅ
|
| 213 |
-
โ
|
| 214 |
-
Conv4: 256 ้้ # ้ซๅฑๆฝ่ฑก็นๅพ
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
่งๅพ๏ผ
|
| 218 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 219 |
-
|
| 220 |
-
โ
้้ๆฐ้ๅฑๅขๅ ๏ผ32 โ 64 โ 128 โ 256๏ผ
|
| 221 |
-
- ็ฉบ้ดๅฐบๅฏธ้ๅฑๅๅฐ๏ผ28ร28 โ 14ร14 โ 7ร7๏ผ
|
| 222 |
-
- ็จๆดๅค้้ๅผฅ่กฅ็ฉบ้ดไฟกๆฏ็ๆๅคฑ
|
| 223 |
-
|
| 224 |
-
โ
้ๅธธๆฏ 2 ็ๅนๆฌก๏ผ32, 64, 128, 256๏ผ
|
| 225 |
-
- ไพฟไบ GPU ่ฎก็ฎไผๅ
|
| 226 |
-
- ๆนไพฟๅ
ๅญๅฏน้ฝ
|
| 227 |
-
|
| 228 |
-
โ
ๆ นๆฎไปปๅกๅคๆๅบฆ่ฐๆด
|
| 229 |
-
- ็ฎๅไปปๅก: 16, 32, 64
|
| 230 |
-
- ๅคๆไปปๅก: 64, 128, 256, 512
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
ๅๆฐ้ๅฏนๆฏ๏ผ
|
| 234 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 235 |
-
|
| 236 |
-
Conv2d(1, 32, kernel_size=4):
|
| 237 |
-
ๅๆฐ้ = 1 ร 32 ร 4 ร 4 = 512
|
| 238 |
-
|
| 239 |
-
Conv2d(1, 64, kernel_size=4):
|
| 240 |
-
ๅๆฐ้ = 1 ร 64 ร 4 ร 4 = 1,024 # ็ฟปๅ
|
| 241 |
-
|
| 242 |
-
Conv2d(32, 64, kernel_size=4):
|
| 243 |
-
ๅๆฐ้ = 32 ร 64 ร 4 ร 4 = 32,768 # ๆดๅค๏ผ
|
| 244 |
-
|
| 245 |
-
้้ๆฐ่ถๅค โ ๅๆฐ้่ถๅคง โ ่กจ่พพ่ฝๅ่ถๅผบ๏ผไฝ่ฎก็ฎๆๆฌไน่ถ้ซ
|
| 246 |
-
|
| 247 |
-
6. ็ด่ง็ฑปๆฏ
|
| 248 |
-
้้็็ฑปๆฏ็่งฃ๏ผ
|
| 249 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 250 |
-
|
| 251 |
-
็ฑปๆฏ 1: ๅคไธชไธๅฎถ
|
| 252 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 253 |
-
|
| 254 |
-
่พๅ
ฅๅพๅ = ไธไปฝ็
ๅ
|
| 255 |
-
|
| 256 |
-
1 ไธช้้ = 1 ไธชๅป็็็
ๅ
|
| 257 |
-
โ ๅชๆไธไธช่ง่ง๏ผๅฏ่ฝๆผ่ฏ
|
| 258 |
-
|
| 259 |
-
32 ไธช้้ = 32 ไธชไธๅฎถๅๆถ็็
ๅ
|
| 260 |
-
โ ไธๅฎถ 1: ็่พน็ผ๏ผๆฏๅฆๆๅคไผค๏ผ
|
| 261 |
-
โ ไธๅฎถ 2: ็้ข่ฒ๏ผๆฏๅฆๆ็็๏ผ
|
| 262 |
-
โ ไธๅฎถ 3: ็ๅฝข็ถ๏ผๆฏๅฆๆ่ฟ็ค๏ผ
|
| 263 |
-
โ ...
|
| 264 |
-
โ ไธๅฎถ 32: ็ๅคๆๆจกๅผ
|
| 265 |
-
|
| 266 |
-
็ปผๅ 32 ไธชไธๅฎถ็ๆ่ง โ ๆดๅ็กฎ็่ฏๆญ๏ผ
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
็ฑปๆฏ 2: ๅคไธชๆปค้
|
| 270 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 271 |
-
|
| 272 |
-
่พๅ
ฅๅพๅ = ๅๅง็
ง็
|
| 273 |
-
|
| 274 |
-
1 ไธช้้ = ๅบ็จ 1 ไธชๆปค้
|
| 275 |
-
โ ๅชๆไธ็งๆๆ
|
| 276 |
-
|
| 277 |
-
32 ไธช้้ = ๅๆถๅบ็จ 32 ไธชๆปค้
|
| 278 |
-
โ ๆปค้ 1: ่พน็ผๅขๅผบ
|
| 279 |
-
โ ๆปค้ 2: ๅฏนๆฏๅบฆๅขๅผบ
|
| 280 |
-
โ ๆปค้ 3: ้ซๆฏๆจก็ณ
|
| 281 |
-
โ ...
|
| 282 |
-
โ ๆปค้ 32: ๅคๆๅๆข
|
| 283 |
-
|
| 284 |
-
ๆฏไธชๆปค้ๆๅไธๅ็่ง่งไฟกๆฏ๏ผ
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
็ฑปๆฏ 3: ๅคไธชไพฆๆข
|
| 288 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 289 |
-
|
| 290 |
-
่พๅ
ฅๅพๅ = ๆกๅ็ฐๅบ
|
| 291 |
-
|
| 292 |
-
1 ไธช้้ = 1 ไธชไพฆๆข่ฐๆฅ
|
| 293 |
-
โ ๅฏ่ฝ้ๆผ็บฟ็ดข
|
| 294 |
-
|
| 295 |
-
32 ไธช้้ = 32 ไธชไพฆๆขๅๆถ่ฐๆฅ
|
| 296 |
-
โ ไพฆๆข 1: ๆฃๆฅๆ็บน
|
| 297 |
-
โ ไพฆๆข 2: ๅๆ่ถณ่ฟน
|
| 298 |
-
โ ไพฆๆข 3: ๆฅ็็ๆง
|
| 299 |
-
โ ...
|
| 300 |
-
โ ไพฆๆข 32: ็ปผๅๅๆ
|
| 301 |
-
|
| 302 |
-
ๆฏไธชไพฆๆขๅ
ณๆณจไธๅ็็บฟ็ดข๏ผๆๅ็ปผๅๅพๅบ็ป่ฎบ๏ผ
|
| 303 |
-
|
| 304 |
-
ๆป็ป
|
| 305 |
-
้้๏ผChannel๏ผๆป็ป๏ผ
|
| 306 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 307 |
-
|
| 308 |
-
ๅฎไน๏ผ
|
| 309 |
-
้้ = ็นๅพๅพ็"ๅฑ"ๆ"็ปดๅบฆ"
|
| 310 |
-
ๆฏไธช้้ๆฏไธไธช 2D ็ฉ้ต๏ผๅญๅจ็นๅฎ็็นๅพไฟกๆฏ
|
| 311 |
-
|
| 312 |
-
ไฝ็จ๏ผ
|
| 313 |
-
โ
1 ไธช้้ โ 1 ไธชๅท็งฏๆ ธ โ ๆฃๆต 1 ็ง็นๅพ
|
| 314 |
-
โ
32 ไธช้้ โ 32 ไธชๅท็งฏๆ ธ โ ๆฃๆต 32 ็งไธๅ็นๅพ
|
| 315 |
-
โ
้้ๆฐ่ถๅค โ ็นๅพ่ถไธฐๅฏ โ ่กจ่พพ่ฝๅ่ถๅผบ
|
| 316 |
-
|
| 317 |
-
ๆผๅ๏ผ
|
| 318 |
-
่พๅ
ฅ: 1 ้้๏ผ็ฐๅบฆๅพ๏ผๆ 3 ้้๏ผRGB๏ผ
|
| 319 |
-
โ
|
| 320 |
-
Conv1: 32 ้้๏ผๆฃๆตๅบ็ก็นๅพ๏ผ
|
| 321 |
-
โ
|
| 322 |
-
Conv2: 64 ้้๏ผๆฃๆตไธญ็บง็นๅพ๏ผ
|
| 323 |
-
โ
|
| 324 |
-
ๆดๅคๅฑ: 128, 256 ้้๏ผๆฃๆต้ซ็บงๆฝ่ฑก็นๅพ๏ผ
|
| 325 |
-
|
| 326 |
-
ๅ
ณ้ฎ็น๏ผ
|
| 327 |
-
๐ ๆฏไธช้้ = ไธไธช็ฌ็ซ็็นๅพๆฃๆตๅจ
|
| 328 |
-
๐ ๅค้้ๅนถ่กๅทฅไฝ๏ผไปไธๅ่งๅบฆ็่งฃๅพๅ
|
| 329 |
-
๐ ้้ๆฐๆฏๅฏไปฅ่ฎพ่ฎก็่ถ
ๅๆฐ
|
| 330 |
-
๐ ้ๅธธ้ๅฑๅขๅ ๏ผๅผฅ่กฅ็ฉบ้ดๅฐบๅฏธ็ๅๅฐ
|
| 331 |
-
|
| 332 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 333 |
-
|
| 334 |
-
1 ้้ โ 32 ้้ = ไปๅไธ่ง่ง โ ๅคไธชไธๅฎถ็็ปผๅ่ง่ง๏ผ
|
| 335 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 336 |
-
|
| 337 |
-
็ฎๅๆฅ่ฏด๏ผ้้ๅฐฑๅๅคไธชไธๅฎถๅๆถ็ๅไธๅผ ๅพๅ๏ผๆฏไธชไธๅฎถ๏ผ้้๏ผๅ
ณๆณจไธๅ็็นๅพ๏ผๆๅ็ปผๅๆๆไธๅฎถ็ๆ่ง๏ผๅพๅฐๅฏนๅพๅๆดๅ
จ้ข็็่งฃ๏ผ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vae_model_structure.py
DELETED
|
@@ -1,162 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
import torch.nn.functional as F
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
class VAE(nn.Module):
|
| 7 |
-
"""ๅๅ่ช็ผ็ ๅจ"""
|
| 8 |
-
|
| 9 |
-
def __init__(self, latent_dim=20):
|
| 10 |
-
super(VAE, self).__init__()
|
| 11 |
-
|
| 12 |
-
# ============================================
|
| 13 |
-
# Encoder (็ผ็ ๅจ)
|
| 14 |
-
# ============================================
|
| 15 |
-
|
| 16 |
-
# ๅท็งฏๅฑ 1: 1โ32 channels, 28ร28โ14ร14
|
| 17 |
-
self.conv1 = nn.Conv2d(
|
| 18 |
-
in_channels=1,
|
| 19 |
-
out_channels=32,
|
| 20 |
-
kernel_size=4,
|
| 21 |
-
stride=2,
|
| 22 |
-
padding=1
|
| 23 |
-
)
|
| 24 |
-
|
| 25 |
-
# ๅท็งฏๅฑ 2: 32โ64 channels, 14ร14โ7ร7
|
| 26 |
-
self.conv2 = nn.Conv2d(
|
| 27 |
-
in_channels=32,
|
| 28 |
-
out_channels=64,
|
| 29 |
-
kernel_size=4,
|
| 30 |
-
stride=2,
|
| 31 |
-
padding=1
|
| 32 |
-
)
|
| 33 |
-
|
| 34 |
-
# ๅ
จ่ฟๆฅๅฑ: 3136โ256
|
| 35 |
-
self.fc1 = nn.Linear(64 * 7 * 7, 256)
|
| 36 |
-
|
| 37 |
-
# ๆฝๅจ็ฉบ้ดๅๆฏ
|
| 38 |
-
self.fc_mu = nn.Linear(256, latent_dim) # ๅๅผ
|
| 39 |
-
self.fc_logvar = nn.Linear(256, latent_dim) # ๅฏนๆฐๆนๅทฎ
|
| 40 |
-
|
| 41 |
-
# ============================================
|
| 42 |
-
# Decoder (่งฃ็ ๅจ)
|
| 43 |
-
# ============================================
|
| 44 |
-
|
| 45 |
-
# ๅ
จ่ฟๆฅๅฑ: 20โ256โ3136
|
| 46 |
-
self.fc2 = nn.Linear(latent_dim, 256)
|
| 47 |
-
self.fc3 = nn.Linear(256, 64 * 7 * 7)
|
| 48 |
-
|
| 49 |
-
# ่ฝฌ็ฝฎๅท็งฏ 1: 64โ32 channels, 7ร7โ14ร14
|
| 50 |
-
self.deconv1 = nn.ConvTranspose2d(
|
| 51 |
-
in_channels=64,
|
| 52 |
-
out_channels=32,
|
| 53 |
-
kernel_size=4,
|
| 54 |
-
stride=2,
|
| 55 |
-
padding=1
|
| 56 |
-
)
|
| 57 |
-
|
| 58 |
-
# ่ฝฌ็ฝฎๅท็งฏ 2: 32โ1 channels, 14ร14โ28ร28
|
| 59 |
-
self.deconv2 = nn.ConvTranspose2d(
|
| 60 |
-
in_channels=32,
|
| 61 |
-
out_channels=1,
|
| 62 |
-
kernel_size=4,
|
| 63 |
-
stride=2,
|
| 64 |
-
padding=1
|
| 65 |
-
)
|
| 66 |
-
|
| 67 |
-
def encode(self, x):
|
| 68 |
-
"""็ผ็ ๅจ: x โ ฮผ, log(ฯยฒ)"""
|
| 69 |
-
# x: (batch, 1, 28, 28)
|
| 70 |
-
|
| 71 |
-
h = F.relu(self.conv1(x)) # โ (batch, 32, 14, 14)
|
| 72 |
-
h = F.relu(self.conv2(h)) # โ (batch, 64, 7, 7)
|
| 73 |
-
h = h.view(-1, 64 * 7 * 7) # โ (batch, 3136)
|
| 74 |
-
h = F.relu(self.fc1(h)) # โ (batch, 256)
|
| 75 |
-
|
| 76 |
-
mu = self.fc_mu(h) # โ (batch, 20)
|
| 77 |
-
logvar = self.fc_logvar(h) # โ (batch, 20)
|
| 78 |
-
|
| 79 |
-
return mu, logvar
|
| 80 |
-
|
| 81 |
-
def reparameterize(self, mu, logvar):
|
| 82 |
-
"""้ๅๆฐๅ: z = ฮผ + ฯฮต"""
|
| 83 |
-
std = torch.exp(0.5 * logvar) # ฯ = exp(log(ฯยฒ)/2)
|
| 84 |
-
eps = torch.randn_like(std) # ฮต ~ N(0,1)
|
| 85 |
-
z = mu + eps * std # z = ฮผ + ฯฮต
|
| 86 |
-
return z
|
| 87 |
-
|
| 88 |
-
def decode(self, z):
|
| 89 |
-
"""่งฃ็ ๅจ: z โ x'"""
|
| 90 |
-
# z: (batch, 20)
|
| 91 |
-
|
| 92 |
-
h = F.relu(self.fc2(z)) # โ (batch, 256)
|
| 93 |
-
h = F.relu(self.fc3(h)) # โ (batch, 3136)
|
| 94 |
-
h = h.view(-1, 64, 7, 7) # โ (batch, 64, 7, 7)
|
| 95 |
-
h = F.relu(self.deconv1(h)) # โ (batch, 32, 14, 14)
|
| 96 |
-
x_recon = torch.sigmoid(self.deconv2(h)) # โ (batch, 1, 28, 28)
|
| 97 |
-
|
| 98 |
-
return x_recon
|
| 99 |
-
|
| 100 |
-
def forward(self, x):
|
| 101 |
-
"""ๅๅไผ ๆญ"""
|
| 102 |
-
mu, logvar = self.encode(x) # ็ผ็
|
| 103 |
-
z = self.reparameterize(mu, logvar) # ้ๆ ท
|
| 104 |
-
x_recon = self.decode(z) # ่งฃ็
|
| 105 |
-
return x_recon, mu, logvar
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
# ============================================
|
| 109 |
-
# ๆๅคฑๅฝๆฐ
|
| 110 |
-
# ============================================
|
| 111 |
-
|
| 112 |
-
def vae_loss(x_recon, x, mu, logvar):
|
| 113 |
-
"""
|
| 114 |
-
VAE ๆๅคฑ = ้ๅปบๆๅคฑ + KL ๆฃๅบฆ
|
| 115 |
-
|
| 116 |
-
Args:
|
| 117 |
-
x_recon: ้ๅปบๅพๅ (batch, 1, 28, 28)
|
| 118 |
-
x: ๅๅงๅพๅ (batch, 1, 28, 28)
|
| 119 |
-
mu: ๅๅผ (batch, latent_dim)
|
| 120 |
-
logvar: ๅฏนๆฐๆนๅทฎ (batch, latent_dim)
|
| 121 |
-
"""
|
| 122 |
-
# 1. ้ๅปบๆๅคฑ (Binary Cross Entropy)
|
| 123 |
-
# ่กก้้ๅปบๅพๅไธๅๅพ็ๅทฎๅผ
|
| 124 |
-
recon_loss = F.binary_cross_entropy(
|
| 125 |
-
x_recon, x, reduction='sum'
|
| 126 |
-
)
|
| 127 |
-
|
| 128 |
-
# 2. KL ๆฃๅบฆ (Kullback-Leibler Divergence)
|
| 129 |
-
# ่กก้ q(z|x) ไธๅ
้ช p(z)=N(0,1) ็ๅทฎๅผ
|
| 130 |
-
# KL(q||p) = -0.5 * ฮฃ(1 + log(ฯยฒ) - ฮผยฒ - ฯยฒ)
|
| 131 |
-
kl_loss = -0.5 * torch.sum(
|
| 132 |
-
1 + logvar - mu.pow(2) - logvar.exp()
|
| 133 |
-
)
|
| 134 |
-
|
| 135 |
-
# ๆปๆๅคฑ
|
| 136 |
-
total_loss = recon_loss + kl_loss
|
| 137 |
-
|
| 138 |
-
return total_loss, recon_loss, kl_loss
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
# ============================================
|
| 142 |
-
# ไฝฟ็จ็คบไพ
|
| 143 |
-
# ============================================
|
| 144 |
-
|
| 145 |
-
# ๅๅปบๆจกๅ
|
| 146 |
-
model = VAE(latent_dim=20)
|
| 147 |
-
|
| 148 |
-
# ่พๅ
ฅๅพๅ (batch_size=32, channels=1, height=28, width=28)
|
| 149 |
-
x = torch.randn(32, 1, 28, 28)
|
| 150 |
-
|
| 151 |
-
# ๅๅไผ ๆญ
|
| 152 |
-
x_recon, mu, logvar = model(x)
|
| 153 |
-
|
| 154 |
-
# ่ฎก็ฎๆๅคฑ
|
| 155 |
-
loss, recon_loss, kl_loss = vae_loss(x_recon, x, mu, logvar)
|
| 156 |
-
|
| 157 |
-
print(f"้ๅปบๅฝข็ถ: {x_recon.shape}") # (32, 1, 28, 28)
|
| 158 |
-
print(f"ฮผ ๅฝข็ถ: {mu.shape}") # (32, 20)
|
| 159 |
-
print(f"log(ฯยฒ) ๅฝข็ถ: {logvar.shape}") # (32, 20)
|
| 160 |
-
print(f"ๆปๆๅคฑ: {loss.item():.2f}")
|
| 161 |
-
print(f"้ๅปบๆๅคฑ: {recon_loss.item():.2f}")
|
| 162 |
-
print(f"KLๆฃๅบฆ: {kl_loss.item():.2f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vae_structrue.txt
DELETED
|
@@ -1,506 +0,0 @@
|
|
| 1 |
-
### VAE ๆจกๅๅฎๆดๆถๆ
|
| 2 |
-
VAE ๆดไฝ็ปๆ๏ผ
|
| 3 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 4 |
-
|
| 5 |
-
่พๅ
ฅๅพๅ (x)
|
| 6 |
-
โ
|
| 7 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 8 |
-
โ Encoder (็ผ็ ๅจ) โ
|
| 9 |
-
โ ๅฐ่พๅ
ฅๅ็ผฉๅฐๆฝๅจ็ฉบ้ด โ
|
| 10 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 11 |
-
โ
|
| 12 |
-
ๆฝๅจ่กจ็คบ (z) = ๅๅผ (ฮผ) + ๆ ๅๅทฎ (ฯ) ร ฮต
|
| 13 |
-
โ
|
| 14 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 15 |
-
โ Decoder (่งฃ็ ๅจ) โ
|
| 16 |
-
โ ไปๆฝๅจ็ฉบ้ด้ๅปบ่พๅ
ฅ โ
|
| 17 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 18 |
-
โ
|
| 19 |
-
้ๅปบๅพๅ (x')
|
| 20 |
-
|
| 21 |
-
ๆๅคฑ = ้ๅปบๆๅคฑ + KL ๆฃๅบฆ
|
| 22 |
-
|
| 23 |
-
่ฏฆ็ป Layer ็ปๆ่งฃๆ
|
| 24 |
-
|
| 25 |
-
1. Encoder (็ผ็ ๅจ)
|
| 26 |
-
|
| 27 |
-
Encoder ็ปๆ๏ผไปฅ MNIST 28ร28 ๅพๅไธบไพ๏ผ
|
| 28 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 29 |
-
|
| 30 |
-
่พๅ
ฅ: (batch_size, 1, 28, 28) # ็ฐๅบฆๅพๅ
|
| 31 |
-
|
| 32 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 33 |
-
โ Layer 1: Conv2d โ
|
| 34 |
-
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 35 |
-
โ ่พๅ
ฅ้้: 1 โ
|
| 36 |
-
โ ่พๅบ้้: 32 โ
|
| 37 |
-
โ ๅท็งฏๆ ธ: 4ร4 โ
|
| 38 |
-
โ ๆญฅ้ฟ: 2 โ
|
| 39 |
-
โ ๅกซๅ
: 1 โ
|
| 40 |
-
โ ่พๅบๅฝข็ถ: (batch, 32, 14, 14) โ
|
| 41 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 42 |
-
โ
|
| 43 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 44 |
-
โ Activation: ReLU โ
|
| 45 |
-
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 46 |
-
โ ่พๅบๅฝข็ถ: (batch, 32, 14, 14) โ
|
| 47 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 48 |
-
โ
|
| 49 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 50 |
-
โ Layer 2: Conv2d โ
|
| 51 |
-
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 52 |
-
โ ่พๅ
ฅ้้: 32 โ
|
| 53 |
-
โ ่พๅบ้้: 64 โ
|
| 54 |
-
โ ๅท็งฏๆ ธ: 4ร4 โ
|
| 55 |
-
โ ๆญฅ้ฟ: 2 โ
|
| 56 |
-
โ ๅกซๅ
: 1 โ
|
| 57 |
-
โ ่พๅบๅฝข็ถ: (batch, 64, 7, 7) โ
|
| 58 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 59 |
-
โ
|
| 60 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 61 |
-
โ Activation: ReLU โ
|
| 62 |
-
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 63 |
-
โ ่พๅบๅฝข็ถ: (batch, 64, 7, 7) โ
|
| 64 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 65 |
-
โ
|
| 66 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 67 |
-
โ Flatten โ
|
| 68 |
-
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 69 |
-
โ ่พๅ
ฅ: (batch, 64, 7, 7) โ
|
| 70 |
-
โ ่พๅบ: (batch, 3136) # 64ร7ร7=3136 โ
|
| 71 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 72 |
-
โ
|
| 73 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 74 |
-
โ Layer 3: Linear (ๅ
จ่ฟๆฅๅฑ) โ
|
| 75 |
-
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 76 |
-
โ ่พๅ
ฅ็ปดๅบฆ: 3136 โ
|
| 77 |
-
โ ่พๅบ็ปดๅบฆ: 256 โ
|
| 78 |
-
โ ่พๅบๅฝข็ถ: (batch, 256) โ
|
| 79 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 80 |
-
โ
|
| 81 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 82 |
-
โ Activation: ReLU โ
|
| 83 |
-
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 84 |
-
โ ่พๅบๅฝข็ถ: (batch, 256) โ
|
| 85 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 86 |
-
โ
|
| 87 |
-
โโโโโโโโโโโโโโโโโโโโโโโโฌโโโโโโโโโโโโโโโโโโโโโโ
|
| 88 |
-
โ โ โ
|
| 89 |
-
โโโโโโโโโโโโโโโโ โโโโโโโโโโโโโโโโ โโโโโโโโโโโโโโโโ
|
| 90 |
-
โ fc_mu โ โ fc_logvar โ โ (ไธคไธชๅๆฏ) โ
|
| 91 |
-
โ Linear โ โ Linear โ โ โ
|
| 92 |
-
โ 256 โ 20 โ โ 256 โ 20 โ โ ๆฝๅจ็ปดๅบฆ=20 โ
|
| 93 |
-
โ โ โ โ โ โ
|
| 94 |
-
โ ฮผ (ๅๅผ) โ โ log(ฯยฒ) โ โ โ
|
| 95 |
-
โ (batch, 20) โ โ (batch, 20) โ โ โ
|
| 96 |
-
โโโโโโโโโโโโโโโโ โโโโโโโโโโโโโโโโ โโโโโโโโโโโโโโโโ
|
| 97 |
-
โ โ
|
| 98 |
-
โโโโโโโโโโโโโโโโโโโโโโโโดโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 99 |
-
้ๅๆฐๅๆๅทง
|
| 100 |
-
z = ฮผ + ฯ ร ฮต
|
| 101 |
-
(batch, 20)
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
ๅๆฐ็ป่ฎก๏ผ
|
| 105 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 106 |
-
Conv1: 1ร32ร4ร4 + 32 bias = 544
|
| 107 |
-
Conv2: 32ร64ร4ร4 + 64 bias = 32,832
|
| 108 |
-
Linear: 3136ร256 + 256 bias = 803,072
|
| 109 |
-
fc_mu: 256ร20 + 20 bias = 5,140
|
| 110 |
-
fc_logvar: 256ร20 + 20 bias = 5,140
|
| 111 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 112 |
-
ๆป่ฎก: ~846,728 ๅๆฐ
|
| 113 |
-
|
| 114 |
-
2. ้ๅๆฐๅๆๅทง (Reparameterization Trick)
|
| 115 |
-
|
| 116 |
-
้ๅๆฐๅๅฑ๏ผ
|
| 117 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 118 |
-
|
| 119 |
-
่พๅ
ฅ: ฮผ (ๅๅผ), log(ฯยฒ) (ๅฏนๆฐๆนๅทฎ)
|
| 120 |
-
(batch, 20), (batch, 20)
|
| 121 |
-
|
| 122 |
-
ๆญฅ้ชค:
|
| 123 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 124 |
-
|
| 125 |
-
1. ่ฎก็ฎๆ ๅๅทฎ:
|
| 126 |
-
ฯ = exp(0.5 ร log(ฯยฒ))
|
| 127 |
-
= exp(log(ฯ))
|
| 128 |
-
= ฯ
|
| 129 |
-
|
| 130 |
-
2. ้ๆ ท้ๆบๅชๅฃฐ:
|
| 131 |
-
ฮต ~ N(0, 1) # ๆ ๅๆญฃๆๅๅธ
|
| 132 |
-
ๅฝข็ถ: (batch, 20)
|
| 133 |
-
|
| 134 |
-
3. ้ๅๆฐๅ:
|
| 135 |
-
z = ฮผ + ฯ ร ฮต
|
| 136 |
-
|
| 137 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 138 |
-
โ ไธบไปไน่ฟๆ ทๅ๏ผ โ
|
| 139 |
-
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 140 |
-
โ ็ดๆฅไป N(ฮผ, ฯยฒ) ้ๆ ทไธๅฏๅพฎๅ โ
|
| 141 |
-
โ ้่ฟ ฮต ~ N(0,1) ไฝฟๆขฏๅบฆๅฏไปฅๅไผ โ
|
| 142 |
-
โ ฮผ ๅ ฯ ้ฝๅฏไปฅ่ขซไผๅ โ
|
| 143 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 144 |
-
|
| 145 |
-
่พๅบ: z (ๆฝๅจๅ้)
|
| 146 |
-
(batch, 20)
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
ไปฃ็ ๅฎ็ฐ:
|
| 150 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 151 |
-
|
| 152 |
-
def reparameterize(mu, logvar):
|
| 153 |
-
std = torch.exp(0.5 * logvar) # ฯ = exp(log(ฯยฒ)/2)
|
| 154 |
-
eps = torch.randn_like(std) # ฮต ~ N(0,1)
|
| 155 |
-
z = mu + eps * std # z = ฮผ + ฯฮต
|
| 156 |
-
return z
|
| 157 |
-
|
| 158 |
-
3. Decoder (่งฃ็ ๅจ)
|
| 159 |
-
|
| 160 |
-
Decoder ็ปๆ๏ผ
|
| 161 |
-
โโโโโโโโโโโโโโโโโ๏ฟฝ๏ฟฝ๏ฟฝโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 162 |
-
|
| 163 |
-
่พๅ
ฅ: z (ๆฝๅจๅ้)
|
| 164 |
-
(batch, 20)
|
| 165 |
-
|
| 166 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 167 |
-
โ Layer 1: Linear (ๅ
จ่ฟๆฅๅฑ) โ
|
| 168 |
-
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 169 |
-
โ ่พๅ
ฅ็ปดๅบฆ: 20 โ
|
| 170 |
-
โ ่พๅบ็ปดๅบฆ: 256 โ
|
| 171 |
-
โ ่พๅบๅฝข็ถ: (batch, 256) โ
|
| 172 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 173 |
-
โ
|
| 174 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 175 |
-
โ Activation: ReLU โ
|
| 176 |
-
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 177 |
-
โ ่พๅบๅฝข็ถ: (batch, 256) โ
|
| 178 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 179 |
-
โ
|
| 180 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 181 |
-
โ Layer 2: Linear โ
|
| 182 |
-
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 183 |
-
โ ่พๅ
ฅ็ปดๅบฆ: 256 โ
|
| 184 |
-
โ ่พๅบ็ปดๅบฆ: 3136 # 64ร7ร7 โ
|
| 185 |
-
โ ่พๅบๅฝข็ถ: (batch, 3136) โ
|
| 186 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 187 |
-
โ
|
| 188 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 189 |
-
โ Activation: ReLU โ
|
| 190 |
-
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 191 |
-
โ ่พๅบๅฝข็ถ: (batch, 3136) โ
|
| 192 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 193 |
-
โ
|
| 194 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 195 |
-
โ Reshape (Unflatten) โ
|
| 196 |
-
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 197 |
-
โ ่พๅ
ฅ: (batch, 3136) โ
|
| 198 |
-
โ ่พๅบ: (batch, 64, 7, 7) โ
|
| 199 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 200 |
-
โ
|
| 201 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 202 |
-
โ Layer 3: ConvTranspose2d (่ฝฌ็ฝฎๅท็งฏ/ไธ้ๆ ท) โ
|
| 203 |
-
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 204 |
-
โ ่พๅ
ฅ้้: 64 โ
|
| 205 |
-
โ ่พๅบ้้: 32 โ
|
| 206 |
-
โ ๅท็งฏๆ ธ: 4ร4 โ
|
| 207 |
-
โ ๆญฅ้ฟ: 2 โ
|
| 208 |
-
โ ๅกซๅ
: 1 โ
|
| 209 |
-
โ ่พๅบๅฝข็ถ: (batch, 32, 14, 14) โ
|
| 210 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 211 |
-
โ
|
| 212 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 213 |
-
โ Activation: ReLU โ
|
| 214 |
-
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 215 |
-
โ ่พๅบๅฝข็ถ: (batch, 32, 14, 14) โ
|
| 216 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 217 |
-
โ
|
| 218 |
-
โโโ๏ฟฝ๏ฟฝ๏ฟฝโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 219 |
-
โ Layer 4: ConvTranspose2d โ
|
| 220 |
-
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 221 |
-
โ ่พๅ
ฅ้้: 32 โ
|
| 222 |
-
โ ่พๅบ้้: 1 โ
|
| 223 |
-
โ ๅท็งฏๆ ธ: 4ร4 โ
|
| 224 |
-
โ ๆญฅ้ฟ: 2 โ
|
| 225 |
-
โ ๅกซๅ
: 1 โ
|
| 226 |
-
โ ่พๅบๅฝข็ถ: (batch, 1, 28, 28) โ
|
| 227 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 228 |
-
โ
|
| 229 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 230 |
-
โ Activation: Sigmoid โ
|
| 231 |
-
โ โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 232 |
-
โ ไฝ็จ: ๅฐ่พๅบๅ็ผฉๅฐ [0, 1] ่ๅด โ
|
| 233 |
-
โ ่พๅบๅฝข็ถ: (batch, 1, 28, 28) โ
|
| 234 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 235 |
-
โ
|
| 236 |
-
่พๅบ: ้ๅปบๅพๅ x'
|
| 237 |
-
(batch, 1, 28, 28)
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
ๅๆฐ็ป่ฎก๏ผ
|
| 241 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 242 |
-
Linear1: 20ร256 + 256 bias = 5,376
|
| 243 |
-
Linear2: 256ร3136 + 3136 bias = 806,016
|
| 244 |
-
ConvTranspose1: 64ร32ร4ร4 + 32 bias = 32,800
|
| 245 |
-
ConvTranspose2: 32ร1ร4ร4 + 1 bias = 513
|
| 246 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 247 |
-
ๆป่ฎก: ~844,705 ๅๆฐ
|
| 248 |
-
|
| 249 |
-
ๅฎๆด PyTorch ๅฎ็ฐ
|
| 250 |
-
|
| 251 |
-
import torch
|
| 252 |
-
import torch.nn as nn
|
| 253 |
-
import torch.nn.functional as F
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
class VAE(nn.Module):
|
| 257 |
-
"""ๅๅ่ช็ผ็ ๅจ"""
|
| 258 |
-
|
| 259 |
-
def __init__(self, latent_dim=20):
|
| 260 |
-
super(VAE, self).__init__()
|
| 261 |
-
|
| 262 |
-
# ============================================
|
| 263 |
-
# Encoder (็ผ็ ๅจ)
|
| 264 |
-
# ============================================
|
| 265 |
-
|
| 266 |
-
# ๅท็งฏๅฑ 1: 1โ32 channels, 28ร28โ14ร14
|
| 267 |
-
self.conv1 = nn.Conv2d(
|
| 268 |
-
in_channels=1,
|
| 269 |
-
out_channels=32,
|
| 270 |
-
kernel_size=4,
|
| 271 |
-
stride=2,
|
| 272 |
-
padding=1
|
| 273 |
-
)
|
| 274 |
-
|
| 275 |
-
# ๅท็งฏๅฑ 2: 32โ64 channels, 14ร14โ7ร7
|
| 276 |
-
self.conv2 = nn.Conv2d(
|
| 277 |
-
in_channels=32,
|
| 278 |
-
out_channels=64,
|
| 279 |
-
kernel_size=4,
|
| 280 |
-
stride=2,
|
| 281 |
-
padding=1
|
| 282 |
-
)
|
| 283 |
-
|
| 284 |
-
# ๅ
จ่ฟๆฅๅฑ: 3136โ256
|
| 285 |
-
self.fc1 = nn.Linear(64 * 7 * 7, 256)
|
| 286 |
-
|
| 287 |
-
# ๆฝๅจ็ฉบ้ดๅๆฏ
|
| 288 |
-
self.fc_mu = nn.Linear(256, latent_dim) # ๅๅผ
|
| 289 |
-
self.fc_logvar = nn.Linear(256, latent_dim) # ๅฏนๆฐๆนๅทฎ
|
| 290 |
-
|
| 291 |
-
# ============================================
|
| 292 |
-
# Decoder (่งฃ็ ๅจ)
|
| 293 |
-
# ============================================
|
| 294 |
-
|
| 295 |
-
# ๅ
จ่ฟๆฅๅฑ: 20โ256โ3136
|
| 296 |
-
self.fc2 = nn.Linear(latent_dim, 256)
|
| 297 |
-
self.fc3 = nn.Linear(256, 64 * 7 * 7)
|
| 298 |
-
|
| 299 |
-
# ่ฝฌ็ฝฎๅท็งฏ 1: 64โ32 channels, 7ร7โ14ร14
|
| 300 |
-
self.deconv1 = nn.ConvTranspose2d(
|
| 301 |
-
in_channels=64,
|
| 302 |
-
out_channels=32,
|
| 303 |
-
kernel_size=4,
|
| 304 |
-
stride=2,
|
| 305 |
-
padding=1
|
| 306 |
-
)
|
| 307 |
-
|
| 308 |
-
# ่ฝฌ็ฝฎๅท็งฏ 2: 32โ1 channels, 14ร14โ28ร28
|
| 309 |
-
self.deconv2 = nn.ConvTranspose2d(
|
| 310 |
-
in_channels=32,
|
| 311 |
-
out_channels=1,
|
| 312 |
-
kernel_size=4,
|
| 313 |
-
stride=2,
|
| 314 |
-
padding=1
|
| 315 |
-
)
|
| 316 |
-
|
| 317 |
-
def encode(self, x):
|
| 318 |
-
"""็ผ็ ๅจ: x โ ฮผ, log(ฯยฒ)"""
|
| 319 |
-
# x: (batch, 1, 28, 28)
|
| 320 |
-
|
| 321 |
-
h = F.relu(self.conv1(x)) # โ (batch, 32, 14, 14)
|
| 322 |
-
h = F.relu(self.conv2(h)) # โ (batch, 64, 7, 7)
|
| 323 |
-
h = h.view(-1, 64 * 7 * 7) # โ (batch, 3136)
|
| 324 |
-
h = F.relu(self.fc1(h)) # โ (batch, 256)
|
| 325 |
-
|
| 326 |
-
mu = self.fc_mu(h) # โ (batch, 20)
|
| 327 |
-
logvar = self.fc_logvar(h) # โ (batch, 20)
|
| 328 |
-
|
| 329 |
-
return mu, logvar
|
| 330 |
-
|
| 331 |
-
def reparameterize(self, mu, logvar):
|
| 332 |
-
"""้ๅๆฐๅ: z = ฮผ + ฯฮต"""
|
| 333 |
-
std = torch.exp(0.5 * logvar) # ฯ = exp(log(ฯยฒ)/2)
|
| 334 |
-
eps = torch.randn_like(std) # ฮต ~ N(0,1)
|
| 335 |
-
z = mu + eps * std # z = ฮผ + ฯฮต
|
| 336 |
-
return z
|
| 337 |
-
|
| 338 |
-
def decode(self, z):
|
| 339 |
-
"""่งฃ็ ๅจ: z โ x'"""
|
| 340 |
-
# z: (batch, 20)
|
| 341 |
-
|
| 342 |
-
h = F.relu(self.fc2(z)) # โ (batch, 256)
|
| 343 |
-
h = F.relu(self.fc3(h)) # โ (batch, 3136)
|
| 344 |
-
h = h.view(-1, 64, 7, 7) # โ (batch, 64, 7, 7)
|
| 345 |
-
h = F.relu(self.deconv1(h)) # โ (batch, 32, 14, 14)
|
| 346 |
-
x_recon = torch.sigmoid(self.deconv2(h)) # โ (batch, 1, 28, 28)
|
| 347 |
-
|
| 348 |
-
return x_recon
|
| 349 |
-
|
| 350 |
-
def forward(self, x):
|
| 351 |
-
"""ๅๅไผ ๆญ"""
|
| 352 |
-
mu, logvar = self.encode(x) # ็ผ็
|
| 353 |
-
z = self.reparameterize(mu, logvar) # ้ๆ ท
|
| 354 |
-
x_recon = self.decode(z) # ่งฃ็
|
| 355 |
-
return x_recon, mu, logvar
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
# ============================================
|
| 359 |
-
# ๆๅคฑๅฝๆฐ
|
| 360 |
-
# ============================================
|
| 361 |
-
|
| 362 |
-
def vae_loss(x_recon, x, mu, logvar):
|
| 363 |
-
"""
|
| 364 |
-
VAE ๆๅคฑ = ้ๅปบๆๅคฑ + KL ๆฃๅบฆ
|
| 365 |
-
|
| 366 |
-
Args:
|
| 367 |
-
x_recon: ้ๅปบๅพๅ (batch, 1, 28, 28)
|
| 368 |
-
x: ๅๅงๅพๅ (batch, 1, 28, 28)
|
| 369 |
-
mu: ๅๅผ (batch, latent_dim)
|
| 370 |
-
logvar: ๅฏนๆฐๆนๅทฎ (batch, latent_dim)
|
| 371 |
-
"""
|
| 372 |
-
# 1. ้ๅปบๆๅคฑ (Binary Cross Entropy)
|
| 373 |
-
# ่กก้้ๅปบๅพๅไธๅๅพ็ๅทฎๅผ
|
| 374 |
-
recon_loss = F.binary_cross_entropy(
|
| 375 |
-
x_recon, x, reduction='sum'
|
| 376 |
-
)
|
| 377 |
-
|
| 378 |
-
# 2. KL ๆฃๅบฆ (Kullback-Leibler Divergence)
|
| 379 |
-
# ่กก้ q(z|x) ไธๅ
้ช p(z)=N(0,1) ็ๅทฎๅผ
|
| 380 |
-
# KL(q||p) = -0.5 * ฮฃ(1 + log(ฯยฒ) - ฮผยฒ - ฯยฒ)
|
| 381 |
-
kl_loss = -0.5 * torch.sum(
|
| 382 |
-
1 + logvar - mu.pow(2) - logvar.exp()
|
| 383 |
-
)
|
| 384 |
-
|
| 385 |
-
# ๆปๆๅคฑ
|
| 386 |
-
total_loss = recon_loss + kl_loss
|
| 387 |
-
|
| 388 |
-
return total_loss, recon_loss, kl_loss
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
# ============================================
|
| 392 |
-
# ไฝฟ็จ็คบไพ
|
| 393 |
-
# ============================================
|
| 394 |
-
|
| 395 |
-
# ๅๅปบๆจกๅ
|
| 396 |
-
model = VAE(latent_dim=20)
|
| 397 |
-
|
| 398 |
-
# ่พๅ
ฅๅพๅ (batch_size=32, channels=1, height=28, width=28)
|
| 399 |
-
x = torch.randn(32, 1, 28, 28)
|
| 400 |
-
|
| 401 |
-
# ๅๅไผ ๆญ
|
| 402 |
-
x_recon, mu, logvar = model(x)
|
| 403 |
-
|
| 404 |
-
# ่ฎก็ฎๆๅคฑ
|
| 405 |
-
loss, recon_loss, kl_loss = vae_loss(x_recon, x, mu, logvar)
|
| 406 |
-
|
| 407 |
-
print(f"้ๅปบๅฝข็ถ: {x_recon.shape}") # (32, 1, 28, 28)
|
| 408 |
-
print(f"ฮผ ๅฝข็ถ: {mu.shape}") # (32, 20)
|
| 409 |
-
print(f"log(ฯยฒ) ๅฝข็ถ: {logvar.shape}") # (32, 20)
|
| 410 |
-
print(f"ๆปๆๅคฑ: {loss.item():.2f}")
|
| 411 |
-
print(f"้ๅปบๆๅคฑ: {recon_loss.item():.2f}")
|
| 412 |
-
print(f"KLๆฃๅบฆ: {kl_loss.item():.2f}")
|
| 413 |
-
|
| 414 |
-
#### ๅฑ็บงๆฐๆฎๆตๅจ่ฏฆ่งฃ
|
| 415 |
-
|
| 416 |
-
ๅฎๆดๆฐๆฎๆตๅจ๏ผ
|
| 417 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 418 |
-
|
| 419 |
-
่พๅ
ฅๅพๅ x:
|
| 420 |
-
(32, 1, 28, 28) # batch=32, ็ฐๅบฆๅพ, 28ร28ๅ็ด
|
| 421 |
-
โ
|
| 422 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 423 |
-
Encoder
|
| 424 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 425 |
-
Conv1 + ReLU:
|
| 426 |
-
(32, 1, 28, 28) โ (32, 32, 14, 14)
|
| 427 |
-
โ
|
| 428 |
-
Conv2 + ReLU:
|
| 429 |
-
(32, 32, 14, 14) โ (32, 64, 7, 7)
|
| 430 |
-
โ
|
| 431 |
-
Flatten:
|
| 432 |
-
(32, 64, 7, 7) โ (32, 3136)
|
| 433 |
-
โ
|
| 434 |
-
FC1 + ReLU:
|
| 435 |
-
(32, 3136) โ (32, 256)
|
| 436 |
-
โ
|
| 437 |
-
โโโโโโโโโโโโโโโโโโฌโโโโโโโโโโโโโโโโโ
|
| 438 |
-
fc_mu fc_logvar
|
| 439 |
-
(32, 256)โ(32,20) (32, 256)โ(32,20)
|
| 440 |
-
โ โ
|
| 441 |
-
ฮผ log(ฯยฒ)
|
| 442 |
-
|
| 443 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 444 |
-
Reparameterization
|
| 445 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 446 |
-
z = ฮผ + exp(0.5รlog(ฯยฒ)) ร ฮต
|
| 447 |
-
โ
|
| 448 |
-
(32, 20) # ๆฝๅจๅ้
|
| 449 |
-
|
| 450 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 451 |
-
Decoder
|
| 452 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 453 |
-
FC2 + ReLU:
|
| 454 |
-
(32, 20) โ (32, 256)
|
| 455 |
-
โ
|
| 456 |
-
FC3 + ReLU:
|
| 457 |
-
(32, 256) โ (32, 3136)
|
| 458 |
-
โ
|
| 459 |
-
Reshape:
|
| 460 |
-
(32, 3136) โ (32, 64, 7, 7)
|
| 461 |
-
โ
|
| 462 |
-
ConvTranspose1 + ReLU:
|
| 463 |
-
(32, 64, 7, 7) โ (32, 32, 14, 14)
|
| 464 |
-
โ
|
| 465 |
-
ConvTranspose2 + Sigmoid:
|
| 466 |
-
(32, 32, 14, 14) โ (32, 1, 28, 28)
|
| 467 |
-
โ
|
| 468 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 469 |
-
|
| 470 |
-
่พๅบ้ๅปบๅพๅ x':
|
| 471 |
-
(32, 1, 28, 28) # ไธ่พๅ
ฅ็ธๅๅฝข็ถ
|
| 472 |
-
|
| 473 |
-
ๅ
ณ้ฎ่ฎพ่ฎก่ฆ็น
|
| 474 |
-
|
| 475 |
-
VAE ็ๆ ธๅฟ่ฎพ่ฎก๏ผ
|
| 476 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 477 |
-
|
| 478 |
-
1. ๅฏน็งฐ็ Encoder-Decoder ็ปๆ
|
| 479 |
-
โ
Encoder ๅ็ผฉ: 28ร28 โ 20
|
| 480 |
-
โ
Decoder ่งฃๅ: 20 โ 28ร28
|
| 481 |
-
|
| 482 |
-
2. ๆฝๅจ็ฉบ้ด็ๆฆ็ๆง
|
| 483 |
-
โ
ไธๆฏ็กฎๅฎๆง็ๅ้๏ผ่ๆฏๅๅธ N(ฮผ, ฯยฒ)
|
| 484 |
-
โ
ๅ
่ฎธๅนณๆปๆๅผๅ้ๆ ทๆฐๆ ทๆฌ
|
| 485 |
-
|
| 486 |
-
3. ้ๅๆฐๅๆๅทง
|
| 487 |
-
โ
ไฝฟ้ๆบ้ๆ ท่ฟ็จๅฏๅพฎๅ
|
| 488 |
-
โ
ๅ
่ฎธ๏ฟฝ๏ฟฝ๏ฟฝๅบฆๅไผ
|
| 489 |
-
|
| 490 |
-
4. ๅ้ๆๅคฑ
|
| 491 |
-
โ
้ๅปบๆๅคฑ: ็กฎไฟ้ๅปบ่ดจ้
|
| 492 |
-
โ
KL ๆฃๅบฆ: ๆญฃๅๅๆฝๅจ็ฉบ้ด
|
| 493 |
-
|
| 494 |
-
5. ไธ BERT ็ๅฏนๆฏ
|
| 495 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 496 |
-
BERT: ๅๅ็ผ็ ๅจ๏ผ็จไบ็่งฃ
|
| 497 |
-
VAE: ็ผ็ -่งฃ็ ๅจ๏ผ็จไบ็ๆ
|
| 498 |
-
|
| 499 |
-
BERT: ็กฎๅฎๆง่พๅบ (ๅบๅฎ็ embedding)
|
| 500 |
-
VAE: ๆฆ็ๆง่พๅบ (ไปๅๅธไธญ้ๆ ท)
|
| 501 |
-
|
| 502 |
-
BERT: [CLS] ่ๅๅ
จๅฑ่ฏญไน
|
| 503 |
-
VAE: z ็ผ็ ๆฝๅจ็นๅพ๏ผๅฏ็ๆๆฐๆ ทๆฌ
|
| 504 |
-
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 505 |
-
|
| 506 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vae_training_example.py
DELETED
|
@@ -1,239 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
VAE่ฎญ็ป่ฟ็จ่ฏฆ็ป็คบไพ
|
| 3 |
-
ๅฑ็คบVAEๅจMNISTๆฐๆฎ้ไธ็ๅฎๆด่ฎญ็ปๆต็จ
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
import torch.nn as nn
|
| 8 |
-
import torch.optim as optim
|
| 9 |
-
import torchvision
|
| 10 |
-
import torchvision.transforms as transforms
|
| 11 |
-
from torch.utils.data import DataLoader
|
| 12 |
-
import matplotlib.pyplot as plt
|
| 13 |
-
import numpy as np
|
| 14 |
-
|
| 15 |
-
from vae_model_structure import VAE, VAELoss
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
class VAEVisualizer:
|
| 19 |
-
"""VAEๅฏ่งๅๅทฅๅ
ท็ฑป"""
|
| 20 |
-
|
| 21 |
-
@staticmethod
|
| 22 |
-
def plot_reconstruction(original, reconstructed, epoch):
|
| 23 |
-
"""็ปๅถๅๅงๅพๅๅ้ๅปบๅพๅ็ๅฏนๆฏ"""
|
| 24 |
-
fig, axes = plt.subplots(2, 10, figsize=(15, 3))
|
| 25 |
-
|
| 26 |
-
for i in range(10):
|
| 27 |
-
# ๅๅงๅพๅ
|
| 28 |
-
axes[0, i].imshow(original[i].cpu().detach().numpy().reshape(28, 28), cmap='gray')
|
| 29 |
-
axes[0, i].set_title('Original')
|
| 30 |
-
axes[0, i].axis('off')
|
| 31 |
-
|
| 32 |
-
# ้ๅปบๅพๅ
|
| 33 |
-
axes[1, i].imshow(reconstructed[i].cpu().detach().numpy().reshape(28, 28), cmap='gray')
|
| 34 |
-
axes[1, i].set_title('Reconstructed')
|
| 35 |
-
axes[1, i].axis('off')
|
| 36 |
-
|
| 37 |
-
plt.suptitle(f'Epoch {epoch} - Reconstruction Comparison')
|
| 38 |
-
plt.tight_layout()
|
| 39 |
-
plt.show()
|
| 40 |
-
|
| 41 |
-
@staticmethod
|
| 42 |
-
def plot_latent_space(model, test_loader, device):
|
| 43 |
-
"""ๅฏ่งๅๆฝๅจ็ฉบ้ด"""
|
| 44 |
-
model.eval()
|
| 45 |
-
|
| 46 |
-
latent_vectors = []
|
| 47 |
-
labels = []
|
| 48 |
-
|
| 49 |
-
with torch.no_grad():
|
| 50 |
-
for data, target in test_loader:
|
| 51 |
-
data = data.to(device)
|
| 52 |
-
mu, _ = model.encoder(data.view(-1, 784))
|
| 53 |
-
latent_vectors.append(mu.cpu().numpy())
|
| 54 |
-
labels.append(target.numpy())
|
| 55 |
-
|
| 56 |
-
latent_vectors = np.concatenate(latent_vectors)
|
| 57 |
-
labels = np.concatenate(labels)
|
| 58 |
-
|
| 59 |
-
plt.figure(figsize=(10, 8))
|
| 60 |
-
scatter = plt.scatter(latent_vectors[:, 0], latent_vectors[:, 1],
|
| 61 |
-
c=labels, cmap='tab10', alpha=0.6)
|
| 62 |
-
plt.colorbar(scatter)
|
| 63 |
-
plt.title('2D Latent Space Visualization')
|
| 64 |
-
plt.xlabel('Latent Dimension 1')
|
| 65 |
-
plt.ylabel('Latent Dimension 2')
|
| 66 |
-
plt.show()
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
class VAETrainer:
|
| 70 |
-
"""VAE่ฎญ็ปๅจ็ฑป"""
|
| 71 |
-
|
| 72 |
-
def __init__(self, model, train_loader, test_loader, device, learning_rate=1e-3):
|
| 73 |
-
self.model = model.to(device)
|
| 74 |
-
self.train_loader = train_loader
|
| 75 |
-
self.test_loader = test_loader
|
| 76 |
-
self.device = device
|
| 77 |
-
self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
|
| 78 |
-
|
| 79 |
-
# ่ฎฐๅฝ่ฎญ็ปๅๅฒ
|
| 80 |
-
self.train_losses = []
|
| 81 |
-
self.test_losses = []
|
| 82 |
-
|
| 83 |
-
def train_epoch(self, epoch):
|
| 84 |
-
"""่ฎญ็ปไธไธชepoch"""
|
| 85 |
-
self.model.train()
|
| 86 |
-
train_loss = 0
|
| 87 |
-
|
| 88 |
-
for batch_idx, (data, _) in enumerate(self.train_loader):
|
| 89 |
-
data = data.to(self.device)
|
| 90 |
-
self.optimizer.zero_grad()
|
| 91 |
-
|
| 92 |
-
# ๅๅไผ ๆญ
|
| 93 |
-
recon_batch, mu, logvar = self.model(data.view(-1, 784))
|
| 94 |
-
|
| 95 |
-
# ่ฎก็ฎๆๅคฑ
|
| 96 |
-
loss = VAELoss.loss_function(recon_batch, data.view(-1, 784), mu, logvar)
|
| 97 |
-
|
| 98 |
-
# ๅๅไผ ๆญ
|
| 99 |
-
loss.backward()
|
| 100 |
-
train_loss += loss.item()
|
| 101 |
-
self.optimizer.step()
|
| 102 |
-
|
| 103 |
-
if batch_idx % 100 == 0:
|
| 104 |
-
print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(self.train_loader.dataset)} '
|
| 105 |
-
f'({100. * batch_idx / len(self.train_loader):.0f}%)]\tLoss: {loss.item() / len(data):.6f}')
|
| 106 |
-
|
| 107 |
-
avg_loss = train_loss / len(self.train_loader.dataset)
|
| 108 |
-
self.train_losses.append(avg_loss)
|
| 109 |
-
return avg_loss
|
| 110 |
-
|
| 111 |
-
def test_epoch(self, epoch):
|
| 112 |
-
"""ๆต่ฏไธไธชepoch"""
|
| 113 |
-
self.model.eval()
|
| 114 |
-
test_loss = 0
|
| 115 |
-
|
| 116 |
-
with torch.no_grad():
|
| 117 |
-
for data, _ in self.test_loader:
|
| 118 |
-
data = data.to(self.device)
|
| 119 |
-
recon_batch, mu, logvar = self.model(data.view(-1, 784))
|
| 120 |
-
test_loss += VAELoss.loss_function(recon_batch, data.view(-1, 784), mu, logvar).item()
|
| 121 |
-
|
| 122 |
-
avg_loss = test_loss / len(self.test_loader.dataset)
|
| 123 |
-
self.test_losses.append(avg_loss)
|
| 124 |
-
|
| 125 |
-
print(f'====> Test set loss: {avg_loss:.4f}')
|
| 126 |
-
return avg_loss
|
| 127 |
-
|
| 128 |
-
def train(self, epochs=10):
|
| 129 |
-
"""ๅฎๆด่ฎญ็ป่ฟ็จ"""
|
| 130 |
-
print("๐ ๅผๅงVAE่ฎญ็ป...")
|
| 131 |
-
|
| 132 |
-
for epoch in range(1, epochs + 1):
|
| 133 |
-
train_loss = self.train_epoch(epoch)
|
| 134 |
-
test_loss = self.test_epoch(epoch)
|
| 135 |
-
|
| 136 |
-
# ๆฏ5ไธชepochๅฏ่งๅไธๆฌก
|
| 137 |
-
if epoch % 5 == 0:
|
| 138 |
-
self.visualize_reconstruction(epoch)
|
| 139 |
-
|
| 140 |
-
print("โ
่ฎญ็ปๅฎๆ๏ผ")
|
| 141 |
-
self.plot_training_history()
|
| 142 |
-
|
| 143 |
-
def visualize_reconstruction(self, epoch):
|
| 144 |
-
"""ๅฏ่งๅ้ๅปบ็ปๆ"""
|
| 145 |
-
self.model.eval()
|
| 146 |
-
|
| 147 |
-
with torch.no_grad():
|
| 148 |
-
# ่ทๅไธๆนๆต่ฏๆฐๆฎ
|
| 149 |
-
data_iter = iter(self.test_loader)
|
| 150 |
-
test_data, _ = next(data_iter)
|
| 151 |
-
test_data = test_data.to(self.device)
|
| 152 |
-
|
| 153 |
-
# ้ๅปบ
|
| 154 |
-
recon_batch, _, _ = self.model(test_data.view(-1, 784))
|
| 155 |
-
|
| 156 |
-
# ๅฏ่งๅ
|
| 157 |
-
VAEVisualizer.plot_reconstruction(test_data.view(-1, 784)[:10],
|
| 158 |
-
recon_batch[:10], epoch)
|
| 159 |
-
|
| 160 |
-
def plot_training_history(self):
|
| 161 |
-
"""็ปๅถ่ฎญ็ปๅๅฒ"""
|
| 162 |
-
plt.figure(figsize=(10, 6))
|
| 163 |
-
plt.plot(self.train_losses, label='Train Loss')
|
| 164 |
-
plt.plot(self.test_losses, label='Test Loss')
|
| 165 |
-
plt.xlabel('Epoch')
|
| 166 |
-
plt.ylabel('Loss')
|
| 167 |
-
plt.title('VAE Training History')
|
| 168 |
-
plt.legend()
|
| 169 |
-
plt.grid(True)
|
| 170 |
-
plt.show()
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
# ============================================================================
|
| 174 |
-
# ๆฐๆฎๅๅค
|
| 175 |
-
# ============================================================================
|
| 176 |
-
|
| 177 |
-
def load_mnist_data(batch_size=128):
|
| 178 |
-
"""ๅ ่ฝฝMNISTๆฐๆฎ้"""
|
| 179 |
-
|
| 180 |
-
transform = transforms.Compose([
|
| 181 |
-
transforms.ToTensor(),
|
| 182 |
-
])
|
| 183 |
-
|
| 184 |
-
# ่ฎญ็ป้
|
| 185 |
-
train_dataset = torchvision.datasets.MNIST(
|
| 186 |
-
root='./data',
|
| 187 |
-
train=True,
|
| 188 |
-
download=True,
|
| 189 |
-
transform=transform
|
| 190 |
-
)
|
| 191 |
-
|
| 192 |
-
# ๆต่ฏ้
|
| 193 |
-
test_dataset = torchvision.datasets.MNIST(
|
| 194 |
-
root='./data',
|
| 195 |
-
train=False,
|
| 196 |
-
download=True,
|
| 197 |
-
transform=transform
|
| 198 |
-
)
|
| 199 |
-
|
| 200 |
-
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 201 |
-
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
|
| 202 |
-
|
| 203 |
-
return train_loader, test_loader
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
# ============================================================================
|
| 207 |
-
# ไธปๅฝๆฐ
|
| 208 |
-
# ============================================================================
|
| 209 |
-
|
| 210 |
-
def main():
|
| 211 |
-
"""ไธป่ฎญ็ปๅฝๆฐ"""
|
| 212 |
-
|
| 213 |
-
# ่ฎพ็ฝฎ่ฎพๅค
|
| 214 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 215 |
-
print(f"ไฝฟ็จ่ฎพๅค: {device}")
|
| 216 |
-
|
| 217 |
-
# ๅ ่ฝฝๆฐๆฎ
|
| 218 |
-
print("๐ฆ ๅ ่ฝฝMNISTๆฐๆฎ้...")
|
| 219 |
-
train_loader, test_loader = load_mnist_data()
|
| 220 |
-
|
| 221 |
-
# ๅๅปบVAEๆจกๅ
|
| 222 |
-
print("๐๏ธ ๅๅปบVAEๆจกๅ...")
|
| 223 |
-
model = VAE(input_dim=784, hidden_dims=[512, 256], latent_dim=20)
|
| 224 |
-
|
| 225 |
-
# ๅๅปบ่ฎญ็ปๅจ
|
| 226 |
-
trainer = VAETrainer(model, train_loader, test_loader, device)
|
| 227 |
-
|
| 228 |
-
# ๅผๅง่ฎญ็ป
|
| 229 |
-
trainer.train(epochs=10)
|
| 230 |
-
|
| 231 |
-
# ๅฏ่งๅๆฝๅจ็ฉบ้ด
|
| 232 |
-
print("\n๐ ๅฏ่งๅๆฝๅจ็ฉบ้ด...")
|
| 233 |
-
VAEVisualizer.plot_latent_space(model, test_loader, device)
|
| 234 |
-
|
| 235 |
-
return model, trainer
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
if __name__ == "__main__":
|
| 239 |
-
model, trainer = main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|