lanny xu committed
Commit 447a3ac · 1 Parent(s): d3957d7

add VAE description

Files changed (1):
  1. vae_model_structure.py +135 -163
vae_model_structure.py CHANGED
@@ -1,190 +1,162 @@
- """
- A complete walkthrough of the VAE (Variational Autoencoder) model structure.
- Covers the encoder, decoder, reparameterization trick, and loss function.
- """
-
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- import numpy as np


- class Encoder(nn.Module):
-     """Encoder: maps input data to the mean and variance of the latent space"""

-     def __init__(self, input_dim=784, hidden_dims=[512, 256], latent_dim=20):
-         super(Encoder, self).__init__()
-
-         # Layer 1: input → first hidden layer
-         self.fc1 = nn.Linear(input_dim, hidden_dims[0])
-
-         # Layer 2: first hidden layer → second hidden layer
-         self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
-
-         # Layer 3: second hidden layer → latent mean
-         self.fc_mu = nn.Linear(hidden_dims[1], latent_dim)
-
-         # Layer 4: second hidden layer → latent log-variance
-         self.fc_logvar = nn.Linear(hidden_dims[1], latent_dim)
-
-     def forward(self, x):
-         print("\n🔍 Encoder forward pass:")
-         print(f"Input shape: {x.shape}")
-
-         # Layer 1: input → hidden layer 1
-         h1 = F.relu(self.fc1(x))
-         print(f"After layer 1: {h1.shape}")
-
-         # Layer 2: hidden layer 1 → hidden layer 2
-         h2 = F.relu(self.fc2(h1))
-         print(f"After layer 2: {h2.shape}")
-
-         # Layer 3: compute the mean μ
-         mu = self.fc_mu(h2)
-         print(f"Mean μ shape: {mu.shape}")
-
-         # Layer 4: compute the log-variance log(σ²)
-         logvar = self.fc_logvar(h2)
-         print(f"Log-variance shape: {logvar.shape}")

-         return mu, logvar
-
-
- class Decoder(nn.Module):
-     """Decoder: reconstructs the original data from the latent space"""

-     def __init__(self, latent_dim=20, hidden_dims=[256, 512], output_dim=784):
-         super(Decoder, self).__init__()
-
-         # Layer 1: latent space → first hidden layer
-         self.fc1 = nn.Linear(latent_dim, hidden_dims[0])
-
-         # Layer 2: first hidden layer → second hidden layer
-         self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1])
-
-         # Layer 3: second hidden layer → output layer
-         self.fc3 = nn.Linear(hidden_dims[1], output_dim)
-
-     def forward(self, z):
-         print("\n🔧 Decoder forward pass:")
-         print(f"Latent variable z shape: {z.shape}")

-         # Layer 1: latent space → hidden layer 1
-         h1 = F.relu(self.fc1(z))
-         print(f"After layer 1: {h1.shape}")

-         # Layer 2: hidden layer 1 → hidden layer 2
-         h2 = F.relu(self.fc2(h1))
-         print(f"After layer 2: {h2.shape}")

-         # Layer 3: hidden layer 2 → output (sigmoid keeps values in [0, 1])
-         recon_x = torch.sigmoid(self.fc3(h2))
-         print(f"Reconstruction shape: {recon_x.shape}")
-
-         return recon_x
-
-
- class Reparameterization:
-     """Reparameterization trick: sample from N(μ, σ²) while keeping gradients flowing"""

-     @staticmethod
-     def reparameterize(mu, logvar):
-         print("\n🔄 Reparameterization:")
-         print(f"Mean μ: {mu.shape}, log-variance: {logvar.shape}")
-
-         # Standard deviation σ = exp(0.5 * log(σ²))
-         std = torch.exp(0.5 * logvar)
-         print(f"Standard deviation σ shape: {std.shape}")
-
-         # Sample noise ε ~ N(0, I) from a standard normal
-         eps = torch.randn_like(std)
-         print(f"Noise ε shape: {eps.shape}")
-
-         # Reparameterize: z = μ + σ ⊙ ε
-         z = mu + eps * std
-         print(f"Sample z shape: {z.shape}")

          return z
-
-
- class VAELoss:
-     """VAE loss: reconstruction loss + KL divergence"""

-     @staticmethod
-     def loss_function(recon_x, x, mu, logvar):
-         print("\n📊 Loss computation:")

-         # 1. Reconstruction loss (binary cross-entropy)
-         BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
-         print(f"Reconstruction loss BCE: {BCE.item():.2f}")

-         # 2. KL divergence (between the latent distribution and a standard normal)
-         # KL = -0.5 * Σ(1 + log(σ²) - μ² - σ²)
-         KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
-         print(f"KL divergence KLD: {KLD.item():.2f}")
-
-         # 3. Total loss
-         total_loss = BCE + KLD
-         print(f"Total loss: {total_loss.item():.2f} (BCE + KLD)")
-
-         return total_loss
-
-
- class VAE(nn.Module):
-     """The complete VAE model"""

-     def __init__(self, input_dim=784, hidden_dims=[512, 256], latent_dim=20):
-         super(VAE, self).__init__()
-
-         self.encoder = Encoder(input_dim, hidden_dims, latent_dim)
-         self.decoder = Decoder(latent_dim, hidden_dims[::-1], input_dim)

      def forward(self, x):
-         print("=" * 60)
-         print("🚀 Full VAE forward pass")
-         print("=" * 60)
-
-         # Encoder: x → (μ, logvar)
-         mu, logvar = self.encoder(x)
-
-         # Reparameterization: sample z from N(μ, σ²)
-         z = Reparameterization.reparameterize(mu, logvar)
-
-         # Decoder: z → reconstruction
-         recon_x = self.decoder(z)
-
-         return recon_x, mu, logvar


- # ============================================================================
- # Test code
- # ============================================================================

- def test_vae():
-     """Exercise the VAE model structure"""
-
-     print("🧪 Testing the VAE model...")
-
-     # Dummy data (batch_size=4, input dimension 784)
-     batch_size = 4
-     input_dim = 784
-     x = torch.randn(batch_size, input_dim)
-
-     print(f"\n📦 Input shape: {x.shape}")

-     # Initialize the VAE model
-     model = VAE(input_dim=input_dim)

-     # Forward pass
-     recon_x, mu, logvar = model(x)

-     # Compute the loss
-     loss = VAELoss.loss_function(recon_x, x, mu, logvar)

-     print("\n✅ VAE model test complete!")
-
-     return model, loss


- if __name__ == "__main__":
-     test_vae()

  import torch
  import torch.nn as nn
  import torch.nn.functional as F


+ class VAE(nn.Module):
+     """Variational Autoencoder"""

+     def __init__(self, latent_dim=20):
+         super(VAE, self).__init__()

+         # ============================================
+         # Encoder
+         # ============================================
+
+         # Conv layer 1: 1→32 channels, 28×28→14×14
+         self.conv1 = nn.Conv2d(
+             in_channels=1,
+             out_channels=32,
+             kernel_size=4,
+             stride=2,
+             padding=1
+         )
+
+         # Conv layer 2: 32→64 channels, 14×14→7×7
+         self.conv2 = nn.Conv2d(
+             in_channels=32,
+             out_channels=64,
+             kernel_size=4,
+             stride=2,
+             padding=1
+         )
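+
+         # Shape check: with kernel_size=4, stride=2, padding=1, each conv
+         # halves the spatial size: out = (in + 2*1 - 4) // 2 + 1 = in // 2,
+         # so 28→14→7.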
+
+         # Fully connected layer: 3136→256
+         self.fc1 = nn.Linear(64 * 7 * 7, 256)
+
+         # Latent-space heads
+         self.fc_mu = nn.Linear(256, latent_dim)      # mean
+         self.fc_logvar = nn.Linear(256, latent_dim)  # log-variance
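+         # Predicting log(σ²) rather than σ² keeps this head's output
+         # unconstrained (any real number is valid) and makes the later
+         # exp(0.5 * logvar) numerically stable.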
+
+         # ============================================
+         # Decoder
+         # ============================================
+
+         # Fully connected layers: 20→256→3136
+         self.fc2 = nn.Linear(latent_dim, 256)
+         self.fc3 = nn.Linear(256, 64 * 7 * 7)
+
+         # Transposed conv 1: 64→32 channels, 7×7→14×14
+         self.deconv1 = nn.ConvTranspose2d(
+             in_channels=64,
+             out_channels=32,
+             kernel_size=4,
+             stride=2,
+             padding=1
+         )
+
+         # Transposed conv 2: 32→1 channels, 14×14→28×28
+         self.deconv2 = nn.ConvTranspose2d(
+             in_channels=32,
+             out_channels=1,
+             kernel_size=4,
+             stride=2,
+             padding=1
+         )

+     def encode(self, x):
+         """Encoder: x → μ, log(σ²)"""
+         # x: (batch, 1, 28, 28)

+         h = F.relu(self.conv1(x))   # → (batch, 32, 14, 14)
+         h = F.relu(self.conv2(h))   # → (batch, 64, 7, 7)
+         h = h.view(-1, 64 * 7 * 7)  # → (batch, 3136)
+         h = F.relu(self.fc1(h))     # → (batch, 256)

+         mu = self.fc_mu(h)          # → (batch, 20)
+         logvar = self.fc_logvar(h)  # → (batch, 20)

+         return mu, logvar

+     def reparameterize(self, mu, logvar):
+         """Reparameterization: z = μ + σε"""
+         std = torch.exp(0.5 * logvar)  # σ = exp(log(σ²)/2)
+         eps = torch.randn_like(std)    # ε ~ N(0, 1)
+         z = mu + eps * std             # z = μ + σε
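+         # Sampling z ~ N(μ, σ²) directly would not be differentiable with
+         # respect to μ and σ; expressing z as μ + σε isolates the randomness
+         # in ε, so gradients can flow back through μ and σ.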

          return z

+     def decode(self, z):
+         """Decoder: z → x'"""
+         # z: (batch, 20)

+         h = F.relu(self.fc2(z))                    # → (batch, 256)
+         h = F.relu(self.fc3(h))                    # → (batch, 3136)
+         h = h.view(-1, 64, 7, 7)                   # → (batch, 64, 7, 7)
+         h = F.relu(self.deconv1(h))                # → (batch, 32, 14, 14)
+         x_recon = torch.sigmoid(self.deconv2(h))   # → (batch, 1, 28, 28)

+         return x_recon

      def forward(self, x):
+         """Forward pass"""
+         mu, logvar = self.encode(x)          # encode
+         z = self.reparameterize(mu, logvar)  # sample
+         x_recon = self.decode(z)             # decode
+         return x_recon, mu, logvar


+ # ============================================
+ # Loss function
+ # ============================================

+ def vae_loss(x_recon, x, mu, logvar):
+     """
+     VAE loss = reconstruction loss + KL divergence
+
+     Args:
+         x_recon: reconstructed images (batch, 1, 28, 28)
+         x: original images (batch, 1, 28, 28)
+         mu: latent means (batch, latent_dim)
+         logvar: latent log-variances (batch, latent_dim)
+     """
+     # 1. Reconstruction loss (binary cross-entropy)
+     #    Measures how the reconstruction differs from the original image
+     recon_loss = F.binary_cross_entropy(
+         x_recon, x, reduction='sum'
+     )
+
+     # 2. KL divergence (Kullback-Leibler divergence)
+     #    Measures how q(z|x) differs from the prior p(z) = N(0, 1)
+     #    KL(q||p) = -0.5 * Σ(1 + log(σ²) - μ² - σ²)
+     kl_loss = -0.5 * torch.sum(
+         1 + logvar - mu.pow(2) - logvar.exp()
+     )
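+
+     # Note: with reduction='sum', both terms scale with the batch size;
+     # dividing by the batch size gives a per-sample loss if preferred.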

+     # Total loss
+     total_loss = recon_loss + kl_loss

+     return total_loss, recon_loss, kl_loss


+ # ============================================
+ # Usage example
+ # ============================================

+ # Create the model
+ model = VAE(latent_dim=20)

+ # Input images (batch_size=32, channels=1, height=28, width=28);
+ # torch.rand keeps values in [0, 1], as binary_cross_entropy requires
+ x = torch.rand(32, 1, 28, 28)

+ # Forward pass
+ x_recon, mu, logvar = model(x)

+ # Compute the loss
+ loss, recon_loss, kl_loss = vae_loss(x_recon, x, mu, logvar)

+ print(f"Reconstruction shape: {x_recon.shape}")  # (32, 1, 28, 28)
+ print(f"μ shape: {mu.shape}")                    # (32, 20)
+ print(f"log(σ²) shape: {logvar.shape}")          # (32, 20)
+ print(f"Total loss: {loss.item():.2f}")
+ print(f"Reconstruction loss: {recon_loss.item():.2f}")
+ print(f"KL divergence: {kl_loss.item():.2f}")