gihakkk commited on
Commit
e1b5e8e
·
verified ·
1 Parent(s): fa9e235

Upload UNet.py

Browse files
Files changed (1) hide show
  1. UNet.py +190 -0
UNet.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ # ๋‚œ์ˆ˜ ์ƒ์„ฑ์„ ์œ„ํ•œ ํ—ฌํผ (๊ฐ€์ค‘์น˜๋ฅผ ์˜๋ฏธ)
4
+ def randn(*shape):
5
+ # Xavier/Glorot ์ดˆ๊ธฐํ™”์™€ ์œ ์‚ฌํ•˜๊ฒŒ ์Šค์ผ€์ผ๋ง (์ดํ•ด๋ฅผ ๋•๊ธฐ ์œ„ํ•จ)
6
+ return np.random.randn(*shape) * np.sqrt(2.0 / (shape[0] * np.prod(shape[2:])))
7
+
8
def randn_bias(*shape):
    """Return a zero-initialised bias tensor of the given shape."""
    return np.full(shape, 0.0)
10
+
11
+ class NumpyUNet:
12
+ def __init__(self, in_channels=1, out_classes=2):
13
+ """
14
+ NumPy๋กœ U-Net ๊ฐ€์ค‘์น˜๋ฅผ ์ดˆ๊ธฐํ™”ํ•ฉ๋‹ˆ๋‹ค.
15
+ ์—ฌ๊ธฐ์„œ๋Š” 2-Level U-Net์„ ํ•˜๋“œ์ฝ”๋”ฉํ•ฉ๋‹ˆ๋‹ค. (์˜ˆ: 64 -> 128 -> 256(๋ฐ”๋‹ฅ) -> 128 -> 64)
16
+ """
17
+ self.weights = {}
18
+
19
+ # --- ์ธ์ฝ”๋” (Encoder) ๊ฐ€์ค‘์น˜ ---
20
+ # Level 1 (Input -> 64 filters)
21
+ self.weights['enc1_w1'] = randn(64, in_channels, 3, 3)
22
+ self.weights['enc1_b1'] = randn_bias(64)
23
+ self.weights['enc1_w2'] = randn(64, 64, 3, 3)
24
+ self.weights['enc1_b2'] = randn_bias(64)
25
+
26
+ # Level 2 (64 -> 128 filters)
27
+ self.weights['enc2_w1'] = randn(128, 64, 3, 3)
28
+ self.weights['enc2_b1'] = randn_bias(128)
29
+ self.weights['enc2_w2'] = randn(128, 128, 3, 3)
30
+ self.weights['enc2_b2'] = randn_bias(128)
31
+
32
+ # --- ๋ฐ”๋‹ฅ (Bottleneck) ๊ฐ€์ค‘์น˜ ---
33
+ # (128 -> 256 filters)
34
+ self.weights['bottle_w1'] = randn(256, 128, 3, 3)
35
+ self.weights['bottle_b1'] = randn_bias(256)
36
+ self.weights['bottle_w2'] = randn(256, 256, 3, 3)
37
+ self.weights['bottle_b2'] = randn_bias(256)
38
+
39
+ # --- ๋””์ฝ”๋” (Decoder) ๊ฐ€์ค‘์น˜ ---
40
+ # Level 1 (Up-Conv 256 + Skip 128 = 384 -> 128 filters)
41
+ self.weights['dec1_w1'] = randn(128, 384, 3, 3)
42
+ self.weights['dec1_b1'] = randn_bias(128)
43
+ self.weights['dec1_w2'] = randn(128, 128, 3, 3)
44
+ self.weights['dec1_b2'] = randn_bias(128)
45
+
46
+ # Level 2 (Up-Conv 128 + Skip 64 = 192 -> 64 filters)
47
+ self.weights['dec2_w1'] = randn(64, 192, 3, 3)
48
+ self.weights['dec2_b1'] = randn_bias(64)
49
+ self.weights['dec2_w2'] = randn(64, 64, 3, 3)
50
+ self.weights['dec2_b2'] = randn_bias(64)
51
+
52
+ # --- ์ตœ์ข… 1x1 Conv ---
53
+ self.weights['final_w'] = randn(out_classes, 64, 1, 1)
54
+ self.weights['final_b'] = randn_bias(out_classes)
55
+
56
+ # --- U-Net์˜ ํ•ต์‹ฌ ์—ฐ์‚ฐ๋“ค ---
57
+
58
+ def _relu(self, x):
59
+ return np.maximum(0, x)
60
+
61
+ def _conv2d(self, x, kernel, bias, padding=1):
62
+ """
63
+ NumPy๋ฅผ ์‚ฌ์šฉํ•œ 'same' 2D ์ปจ๋ณผ๋ฃจ์…˜ (stride=1)
64
+ x: (In_C, H, W)
65
+ kernel: (Out_C, In_C, K, K)
66
+ bias: (Out_C,)
67
+ """
68
+ in_C, in_H, in_W = x.shape
69
+ out_C, _, K, _ = kernel.shape
70
+
71
+ # ํŒจ๋”ฉ ์ ์šฉ ('same'์„ ์œ„ํ•ด)
72
+ padded_x = np.pad(x, ((0, 0), (padding, padding), (padding, padding)), 'constant')
73
+
74
+ # ์ถœ๋ ฅ ๋งต ์ดˆ๊ธฐํ™”
75
+ out_H, out_W = in_H, in_W # 'same' ํŒจ๋”ฉ์ด๋ฏ€๋กœ ํฌ๊ธฐ ๋™์ผ
76
+ output = np.zeros((out_C, out_H, out_W))
77
+
78
+ # ์ปจ๋ณผ๋ฃจ์…˜ ์—ฐ์‚ฐ (๋งค์šฐ ๋А๋ฆฐ ์ด์ค‘ ๋ฃจํ”„)
79
+ for k in range(out_C): # ์ถœ๋ ฅ ์ฑ„๋„
80
+ for i in range(out_H): # ๋†’์ด
81
+ for j in range(out_W): # ๋„ˆ๋น„
82
+ # (In_C, K, K) ํฌ๊ธฐ์˜ ํŒจ์น˜๋ฅผ ์ž˜๋ผ๋ƒ„
83
+ patch = padded_x[:, i:i+K, j:j+K]
84
+ # (Out_C[k], In_C, K, K) ์ปค๋„๊ณผ ์š”์†Œ๋ณ„ ๊ณฑ์…ˆ ํ›„ ํ•ฉ์‚ฐ
85
+ output[k, i, j] = np.sum(patch * kernel[k]) + bias[k]
86
+ return output
87
+
88
+ def _max_pool2d(self, x, pool_size=2):
89
+ """ 2x2 Max Pooling """
90
+ in_C, in_H, in_W = x.shape
91
+ out_H = in_H // pool_size
92
+ out_W = in_W // pool_size
93
+ output = np.zeros((in_C, out_H, out_W))
94
+
95
+ for c in range(in_C):
96
+ for i in range(out_H):
97
+ for j in range(out_W):
98
+ patch = x[c, i*pool_size:(i+1)*pool_size, j*pool_size:(j+1)*pool_size]
99
+ output[c, i, j] = np.max(patch)
100
+ return output
101
+
102
+ def _upsample2d(self, x, scale=2):
103
+ """
104
+ Transposed Conv ๋Œ€์‹  ๊ฐ„๋‹จํ•œ Nearest-neighbor ์—…์ƒ˜ํ”Œ๋ง ๊ตฌํ˜„
105
+ """
106
+ # np.repeat๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๊ฐ ํ–‰๊ณผ ์—ด์„ 'scale'๋งŒํผ ๋ฐ˜๋ณต
107
+ return x.repeat(scale, axis=1).repeat(scale, axis=2)
108
+
109
+ def _conv_block(self, x, w1, b1, w2, b2):
110
+ """ (3x3 Conv + ReLU) * 2ํšŒ ๋ฐ˜๋ณต ๋ธ”๋ก """
111
+ x = self._conv2d(x, w1, b1, padding=1)
112
+ x = self._relu(x)
113
+ x = self._conv2d(x, w2, b2, padding=1)
114
+ x = self._relu(x)
115
+ return x
116
+
117
+ # --- U-Net ์ˆœ์ „ํŒŒ (Forward Pass) ---
118
+
119
+ def forward(self, x):
120
+ """
121
+ U-Net ์•„ํ‚คํ…์ฒ˜๋ฅผ ๋”ฐ๋ผ ์ˆœ์ „ํŒŒ๋ฅผ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค.
122
+ x: (In_C, H, W)
123
+ """
124
+ w = self.weights
125
+ skip_connections = []
126
+
127
+ print(f"Input: \t\t{x.shape}")
128
+
129
+ # === 1. ์ธ์ฝ”๋” (์ˆ˜์ถ• ๊ฒฝ๋กœ) ===
130
+ # Level 1
131
+ e1 = self._conv_block(x, w['enc1_w1'], w['enc1_b1'], w['enc1_w2'], w['enc1_b2'])
132
+ p1 = self._max_pool2d(e1)
133
+ skip_connections.append(e1) # ์Šคํ‚ต ์—ฐ๊ฒฐ์„ ์œ„ํ•ด ์ €์žฅ
134
+ print(f"Encoder 1: \t{e1.shape} -> Pool: {p1.shape}")
135
+
136
+ # Level 2
137
+ e2 = self._conv_block(p1, w['enc2_w1'], w['enc2_b1'], w['enc2_w2'], w['enc2_b2'])
138
+ p2 = self._max_pool2d(e2)
139
+ skip_connections.append(e2) # ์Šคํ‚ต ์—ฐ๊ฒฐ์„ ์œ„ํ•ด ์ €์žฅ
140
+ print(f"Encoder 2: \t{e2.shape} -> Pool: {p2.shape}")
141
+
142
+ # === 2. ๋ฐ”๋‹ฅ (Bottleneck) ===
143
+ b = self._conv_block(p2, w['bottle_w1'], w['bottle_b1'], w['bottle_w2'], w['bottle_b2'])
144
+ print(f"Bottleneck: \t{b.shape}")
145
+
146
+ # === 3. ๋””์ฝ”๋” (ํ™•์žฅ ๊ฒฝ๋กœ) ===
147
+ skip_connections = skip_connections[::-1] # ์ˆœ์„œ ๋’ค์ง‘๊ธฐ (LIFO)
148
+
149
+ # Level 1
150
+ u1 = self._upsample2d(b)
151
+ s1 = skip_connections[0] # Encoder 2์˜ ์ถœ๋ ฅ (e2)
152
+ c1 = np.concatenate((u1, s1), axis=0) # ์ฑ„๋„ ์ถ•(axis=0)์œผ๋กœ ๊ฒฐํ•ฉ
153
+ d1 = self._conv_block(c1, w['dec1_w1'], w['dec1_b1'], w['dec1_w2'], w['dec1_b2'])
154
+ print(f"Decoder 1: \tUp: {u1.shape} + Skip: {s1.shape} = Concat: {c1.shape} -> Block: {d1.shape}")
155
+
156
+ # Level 2
157
+ u2 = self._upsample2d(d1)
158
+ s2 = skip_connections[1] # Encoder 1์˜ ์ถœ๋ ฅ (e1)
159
+ c2 = np.concatenate((u2, s2), axis=0) # ๊ฒฐํ•ฉ
160
+ d2 = self._conv_block(c2, w['dec2_w1'], w['dec2_b1'], w['dec2_w2'], w['dec2_b2'])
161
+ print(f"Decoder 2: \tUp: {u2.shape} + Skip: {s2.shape} = Concat: {c2.shape} -> Block: {d2.shape}")
162
+
163
+ # === 4. ์ตœ์ข… 1x1 Conv ===
164
+ # 1x1 Conv๋Š” 3x3 Conv์™€ ๋™์ผํ•˜์ง€๋งŒ K=1, padding=0์„ ์‚ฌ์šฉ
165
+ output = self._conv2d(d2, w['final_w'], w['final_b'], padding=0)
166
+ print(f"Final 1x1 Conv: {output.shape}")
167
+
168
+ return output
169
+
170
+ # --- ์‹คํ–‰ ์˜ˆ์‹œ ---
171
# --- Demo run ---
def _demo():
    """Build a tiny random image, run one forward pass, sanity-check shapes."""
    # (channels, height, width) — H and W must be divisible by 4 (two 2x2 pools).
    # Kept small because the pure-NumPy convolutions are very slow.
    dummy_image = np.random.randn(1, 32, 32)

    # Model with 1 input channel and 2 output classes.
    model = NumpyUNet(in_channels=1, out_classes=2)

    print("--- U-Net Forward Pass Start ---")
    output_map = model.forward(dummy_image)
    print("--- U-Net Forward Pass End ---")

    print(f"\n์ตœ์ข… ์ž…๋ ฅ ์ด๋ฏธ์ง€ Shape: {dummy_image.shape}")
    print(f"์ตœ์ข… ์ถœ๋ ฅ ๋งต Shape: {output_map.shape}")

    # Height/width are preserved; only the channel count changes.
    assert dummy_image.shape[1:] == output_map.shape[1:]
    assert output_map.shape[0] == 2


if __name__ == "__main__":
    _demo()