gihakkk committed on
Commit
33ac553
·
verified ·
1 Parent(s): 8217c1c

Upload 2 files

Browse files
Files changed (2) hide show
  1. Autoencoder.py +63 -0
  2. CNN_AutoEncoder.py +66 -0
Autoencoder.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class DenseAutoencoder(nn.Module):
    """Fully connected autoencoder: input_dim -> latent_dim -> input_dim.

    The encoder stacks Linear+ReLU layers sized by *hidden_dims* and ends in
    a Linear projection to *latent_dim*; the decoder mirrors the stack in
    reverse. The final decoder layer is a plain Linear (no Sigmoid), so the
    model suits MSE reconstruction loss on unbounded real-valued inputs.

    Args:
        input_dim: size of the flattened input vector (default 784 = 28*28).
        hidden_dims: widths of the intermediate encoder layers; the decoder
            uses them in reverse order. Any sequence of ints is accepted.
        latent_dim: size of the bottleneck code z.
    """

    def __init__(self, input_dim=784, hidden_dims=(256, 64), latent_dim=32):
        # NOTE: default changed from the mutable list [256, 64] to a tuple —
        # a mutable default is shared across all calls (classic Python
        # pitfall). Callers may still pass a list; behavior is unchanged.
        super().__init__()

        # Encoder: input_dim -> 256 -> 64 -> latent_dim
        encoder_layers = []
        prev_dim = input_dim
        for h in hidden_dims:
            encoder_layers += [nn.Linear(prev_dim, h), nn.ReLU()]
            prev_dim = h
        encoder_layers.append(nn.Linear(prev_dim, latent_dim))
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: latent_dim -> 64 -> 256 -> input_dim
        decoder_layers = []
        prev_dim = latent_dim
        for h in reversed(hidden_dims):
            decoder_layers += [nn.Linear(prev_dim, h), nn.ReLU()]
            prev_dim = h

        # Final output layer (deliberately no Sigmoid: output is unbounded,
        # matching MSE loss in the demo below).
        decoder_layers.append(nn.Linear(prev_dim, input_dim))
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return (reconstruction, latent code) for a (batch, input_dim) tensor."""
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat, z
32
+
33
# --- Demo: runs only when this script is executed directly ---
if __name__ == "__main__":
    # 1. Instantiate the model with a 32-dimensional latent space.
    model = DenseAutoencoder(latent_dim=32)

    # 2. Dummy batch of flattened 28x28 inputs, shaped (batch, input_dim).
    #    randn (unbounded values) is fine because MSE loss is used below.
    x = torch.randn(64, 784)
    print(f"์›๋ณธ ๋ฐ์ดํ„ฐ(x) shape: {x.shape}")

    # 3. Forward pass: reconstruction and latent code.
    x_hat, z = model(x)
    print(f"์••์ถ•๋œ ์ž ์žฌ ๋ฒกํ„ฐ(z) shape: {z.shape}")
    print(f"๋ณต์›๋œ ๋ฐ์ดํ„ฐ(x_hat) shape: {x_hat.shape}")

    # 4. Reconstruction loss: MSE, since the decoder output is a plain Linear.
    criterion = nn.MSELoss()
    loss = criterion(x_hat, x)
    print(f"\n๊ณ„์‚ฐ๋œ ์†์‹ค(MSELoss): {loss.item()}")

    # 5. Backpropagate to populate gradients.
    loss.backward()
    print("์—ญ์ „ํŒŒ ์™„๋ฃŒ")
CNN_AutoEncoder.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class CNNAutoencoder(nn.Module):
    """Convolutional autoencoder for 1x28x28 images with a dense latent code.

    Encoder: two Conv+ReLU+MaxPool stages (spatial size 28 -> 14 -> 7),
    then a Linear projection to *latent_dim*. Decoder: Linear back to
    32*7*7, two transposed convolutions up to 28x28, and a final Sigmoid
    so outputs lie in [0, 1] (compatible with BCE loss).
    """

    def __init__(self, latent_dim=32):
        super().__init__()
        feat = 32 * 7 * 7  # flattened feature size after the conv stack

        # Encoder: (B, 1, 28, 28) -> (B, latent_dim)
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1),   # -> (B, 16, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(2, 2),               # -> (B, 16, 14, 14)
            nn.Conv2d(16, 32, 3, padding=1),  # -> (B, 32, 14, 14)
            nn.ReLU(),
            nn.MaxPool2d(2, 2),               # -> (B, 32, 7, 7)
            nn.Flatten(),
            nn.Linear(feat, latent_dim),      # -> (B, latent_dim)
        )

        # Decoder: expand the code, then upsample back to the input shape.
        self.decoder_input = nn.Linear(latent_dim, feat)
        self.decoder = nn.Sequential(
            nn.Unflatten(1, (32, 7, 7)),              # -> (B, 32, 7, 7)
            nn.ConvTranspose2d(32, 16, 2, stride=2),  # -> (B, 16, 14, 14)
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 2, stride=2),   # -> (B, 1, 28, 28)
            nn.Sigmoid(),                             # values in [0, 1]
        )

    def forward(self, x):
        """Return (reconstruction, latent code) for a (B, 1, 28, 28) tensor."""
        z = self.encoder(x)
        x_hat = self.decoder(self.decoder_input(z))
        return x_hat, z
35
+
36
# --- Demo: runs only when this script is executed directly ---
if __name__ == "__main__":
    # 1. Instantiate the model with a 32-dimensional latent space.
    model = CNNAutoencoder(latent_dim=32)

    # 2. Dummy batch shaped (batch, channels, height, width).
    #    torch.rand keeps values in [0, 1], valid targets for BCELoss.
    x = torch.rand(64, 1, 28, 28)
    print(f"์›๋ณธ ๋ฐ์ดํ„ฐ(x) shape: {x.shape}")

    # 3. Forward pass: reconstruction and latent code.
    x_hat, z = model(x)
    print(f"์••์ถ•๋œ ์ž ์žฌ ๋ฒกํ„ฐ(z) shape: {z.shape}")
    print(f"๋ณต์›๋œ ๋ฐ์ดํ„ฐ(x_hat) shape: {x_hat.shape}")

    # 4. BCE loss matches the decoder's Sigmoid output range.
    criterion = nn.BCELoss()
    loss = criterion(x_hat, x)
    print(f"\n๊ณ„์‚ฐ๋œ ์†์‹ค(BCELoss): {loss.item()}")

    # 5. Backpropagate to populate gradients.
    loss.backward()
    print("์—ญ์ „ํŒŒ ์™„๋ฃŒ")