ymin98 committed on
Commit dd5a134 · verified · 1 Parent(s): 8f5f7e0

Upload 15 files

Files changed (5)
  1. README +14 -0
  2. app.py +375 -0
  3. model.py +565 -0
  4. requirements.txt +8 -0
  5. wacaunet_val_f1_331_0.8757.pth +3 -0
README ADDED
@@ -0,0 +1,14 @@
1
+ ---
2
+ title: WACA-UNet
3
+ emoji: ⚡
4
+ colorFrom: indigo
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 5.0.0
8
+ app_file: app.py
9
+ ---
10
+
11
+ # WACA-UNet IR-drop Demo
12
+
13
+ Gradio demo for WACA-UNet (Weakness-Aware Channel Attention U-Net)
14
+ for static IR-drop prediction on the ICCAD-2023 benchmark.
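
For quick reference, a minimal sketch of using the uploaded checkpoint outside the Gradio UI (paths, shapes and the mV output unit follow the assumptions in `app.py`; the zero tensor below is only a placeholder input):

```python
# Minimal programmatic inference sketch; see app.py for the full demo logic.
import torch
from model import WACA_Unet

model = WACA_Unet(in_ch=25, depth=4, base_ch=64)
state = torch.load("wacaunet_val_f1_331_0.8757.pth", map_location="cpu")
if isinstance(state, dict) and "state_dict" in state:
    state = state["state_dict"]
# (app.py additionally strips a possible 'module.' prefix from DDP-saved checkpoints)
model.load_state_dict(state, strict=False)
model.eval()

x = torch.zeros(1, 25, 256, 256)  # (batch, channels, H, W) placeholder input
with torch.no_grad():
    pred = model(x)["x_recon"]    # (1, 1, H, W); IR-drop in mV per the demo's assumption
```

To launch the demo itself, install the packages from `requirements.txt` and run `python app.py`.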
app.py ADDED
@@ -0,0 +1,375 @@
1
+ # app.py
2
+ import os
3
+ import io
4
+ import math
5
+ from typing import Tuple
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import matplotlib
12
+ matplotlib.use("Agg")
13
+ import matplotlib.pyplot as plt
14
+ from PIL import Image
15
+
16
+ # ---- Project modules ----
17
+ from model import WACA_Unet
18
+
19
+
20
+ # ==========================
21
+ # Settings
22
+ # ==========================
23
+
24
+ IN_CHANNELS = 25
25
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
+
27
+ # Path to the pretrained checkpoint.
28
+ # This file should be available in the Space (or local environment).
29
+ MODEL_CHECKPOINT_PATH = (
30
+ "wacaunet_val_f1_331_0.8757.pth"
31
+ )
32
+
33
+ # WACA-UNet output unit:
34
+ # - The model is assumed to predict IR-drop directly in mV.
35
+ # - If SCALE_TO_V is True, the demo will convert mV -> V (divide by 1000) for display.
36
+ SCALE_TO_V = False
37
+
38
+ # Directory containing example .npy inputs (created offline)
39
+ SAMPLES_DIR = "tools/samples"
40
+
41
+
42
+ # ==========================
43
+ # Utility functions
44
+ # ==========================
45
+
46
+ def load_checkpoint_state(path: str, device: str):
47
+ """
48
+ Load a checkpoint and return a state_dict in a robust way.
49
+ Handles common patterns: {'state_dict': ...}, {'net': ...}, or raw state_dict.
50
+ """
51
+ state = torch.load(path, map_location=device)
52
+ if isinstance(state, dict):
53
+ if "state_dict" in state:
54
+ return state["state_dict"]
55
+ if "net" in state:
56
+ return state["net"]
57
+ # Fallback: assume the object itself is a state_dict
58
+ return state
59
+
60
+
61
+ def get_model() -> Tuple[nn.Module, str]:
62
+ """
63
+ Load WACA-UNet once and cache it for reuse in the Gradio session.
64
+ """
65
+ if not hasattr(get_model, "_cache"):
66
+ if not os.path.exists(MODEL_CHECKPOINT_PATH):
67
+ raise FileNotFoundError(
68
+ f"Checkpoint not found at '{MODEL_CHECKPOINT_PATH}'. "
69
+ f"Please upload the checkpoint file and update MODEL_CHECKPOINT_PATH."
70
+ )
71
+
72
+ device = DEVICE
73
+ model = WACA_Unet(in_ch=IN_CHANNELS, depth=4, base_ch=64)
74
+
75
+ state_dict = load_checkpoint_state(MODEL_CHECKPOINT_PATH, device)
76
+ # Strip 'module.' prefix from keys if the checkpoint was saved with DDP
77
+ new_state = {}
78
+ for k, v in state_dict.items():
79
+ if k.startswith("module."):
80
+ new_state[k[len("module."):]] = v
81
+ else:
82
+ new_state[k] = v
83
+
84
+ model.load_state_dict(new_state, strict=False)
85
+ model.to(device)
86
+ model.eval()
87
+
88
+ get_model._cache = (model, device)
89
+ return get_model._cache # type: ignore
90
+
91
+
92
+ def list_sample_files() -> list:
93
+ """
94
+ List sample .npy files under SAMPLES_DIR.
95
+ Each .npy is assumed to store a (25, H, W) or (H, W, 25) array.
96
+ """
97
+ if not os.path.exists(SAMPLES_DIR):
98
+ return []
99
+ files = []
100
+ for fname in sorted(os.listdir(SAMPLES_DIR)):
101
+ if fname.lower().endswith(".npy"):
102
+ files.append(os.path.join(SAMPLES_DIR, fname))
103
+ return files
104
+
105
+
106
+ def fig_to_pil(fig: matplotlib.figure.Figure) -> Image.Image:
107
+ buf = io.BytesIO()
108
+ fig.savefig(buf, format="png", bbox_inches="tight", dpi=150)
109
+ plt.close(fig)
110
+ buf.seek(0)
111
+ img = Image.open(buf).convert("RGB")
112
+ return img
113
+
114
+
115
+ def ensure_chw(arr: np.ndarray) -> np.ndarray:
116
+ """
117
+ Ensure the array has shape (C, H, W).
118
+
119
+ Supported shapes:
120
+ - (C, H, W): returned as-is
121
+ - (H, W, C): transposed to (C, H, W)
122
+ - (H, W): treated as (1, H, W) (not recommended for this demo)
123
+ """
124
+ if arr.ndim == 3:
125
+ c_first = arr.shape[0]
126
+ c_last = arr.shape[-1]
127
+ if c_first == IN_CHANNELS:
128
+ # Already (C, H, W)
129
+ return arr.astype(np.float32)
130
+ elif c_last == IN_CHANNELS:
131
+ # (H, W, C) -> (C, H, W)
132
+ return np.transpose(arr, (2, 0, 1)).astype(np.float32)
133
+ else:
134
+ raise ValueError(
135
+ f"3D array but channel dimension is not {IN_CHANNELS}. "
136
+ f"Got shape={arr.shape}."
137
+ )
138
+ elif arr.ndim == 2:
139
+ # (H, W) -> (1, H, W)
140
+ return arr.astype(np.float32)[None, ...]
141
+ else:
142
+ raise ValueError(
143
+ f"Unsupported input ndim: {arr.ndim}, shape={arr.shape}. "
144
+ f"Expected (C,H,W) or (H,W,C) or (H,W)."
145
+ )
146
+
147
+
148
+ def preprocess_input(arr: np.ndarray) -> torch.Tensor:
149
+ """
150
+ Convert a numpy array to a (1, C, H, W) torch.FloatTensor.
151
+
152
+ Assumptions:
153
+ - The .npy already contains the same normalization used during training,
154
+ e.g., per-channel z-score from the IRDropDataset / CFIRSTNET configuration.
155
+ """
156
+ chw = ensure_chw(arr)
157
+ if chw.shape[0] != IN_CHANNELS:
158
+ raise ValueError(
159
+ f"Unexpected number of channels: expected {IN_CHANNELS}, "
160
+ f"got C={chw.shape[0]}, shape={chw.shape}."
161
+ )
162
+ x = torch.from_numpy(chw).unsqueeze(0) # (1, C, H, W)
163
+ return x
164
+
165
+
166
+ def plot_input_grid(chw: np.ndarray, title_prefix: str = "Input") -> Image.Image:
167
+ """
168
+ Visualize 25 input channels as a 5×5 grid.
169
+ """
170
+ C, H, W = chw.shape
171
+ cols = 5
172
+ rows = math.ceil(C / cols)
173
+
174
+ fig, axes = plt.subplots(rows, cols, figsize=(cols * 2.0, rows * 2.0))
175
+ axes = np.atleast_2d(axes)
176
+
177
+ for idx in range(rows * cols):
178
+ r = idx // cols
179
+ c = idx % cols
180
+ ax = axes[r, c]
181
+ if idx < C:
182
+ ch_img = chw[idx]
183
+ ax.imshow(ch_img, cmap="jet")
184
+ ax.set_title(f"{title_prefix} C{idx}", fontsize=7)
185
+ ax.axis("off")
186
+ else:
187
+ ax.axis("off")
188
+
189
+ fig.suptitle("25-channel Input (5×5)", fontsize=12)
190
+ fig.tight_layout()
191
+ return fig_to_pil(fig)
192
+
193
+
194
+ def plot_prediction(pred_2d: np.ndarray, to_v: bool = False) -> Image.Image:
195
+ """
196
+ Visualize the predicted IR-drop (single channel) with a colormap and colorbar.
197
+
198
+ Model output is assumed to be in mV:
199
+ - to_v == False: render directly in mV.
200
+ - to_v == True : convert mV -> V by dividing by 1000 for display.
201
+ """
202
+ pred_mV = np.nan_to_num(pred_2d.astype(np.float32))
203
+
204
+ if to_v:
205
+ pred = pred_mV * 1e-3 # mV -> V
206
+ unit = "V"
207
+ else:
208
+ pred = pred_mV
209
+ unit = "mV"
210
+
211
+ fig, ax = plt.subplots(figsize=(4, 4))
212
+ im = ax.imshow(pred, cmap="jet")
213
+ ax.set_title(f"Predicted IR-drop [{unit}]")
214
+ ax.axis("off")
215
+ fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
216
+ fig.tight_layout()
217
+ return fig_to_pil(fig)
218
+
219
+
220
+ def load_npy_file(path: str) -> np.ndarray:
221
+ return np.array(np.load(path))
222
+
223
+
224
+ def load_uploaded_npy(file_obj) -> np.ndarray:
225
+ """
226
+ Load a numpy array from a Gradio file input (a filepath string in recent Gradio versions, or an object exposing `.name` in older ones).
227
+ """
228
+ path = file_obj if isinstance(file_obj, str) else file_obj.name
+ return np.load(path)
229
+
230
+
231
+ # ==========================
232
+ # Gradio callback
233
+ # ==========================
234
+
235
+ def infer(
236
+ sample_path: str,
237
+ uploaded_file,
238
+ use_uploaded: bool,
239
+ scale_to_v: bool,
240
+ ) -> Tuple[Image.Image, Image.Image, str]:
241
+ """
242
+ Main inference function for Gradio.
243
+
244
+ Args:
245
+ sample_path: Selected path under tools/samples/.
246
+ uploaded_file: User-uploaded .npy file (Gradio File).
247
+ use_uploaded: If True and a file is provided, uploaded_file takes priority.
248
+ scale_to_v: If True, convert model output from mV to V for visualization.
249
+
250
+ Returns:
251
+ input_img: 5×5 grid of the 25 input channels.
252
+ pred_img: Predicted IR-drop heatmap.
253
+ info: Text summary (shapes, ranges, units).
254
+ """
255
+ # 1) Decide input source
256
+ if use_uploaded and uploaded_file is not None:
257
+ arr = load_uploaded_npy(uploaded_file)
258
+ upload_path = uploaded_file if isinstance(uploaded_file, str) else uploaded_file.name
+ src_desc = f"Uploaded file: {os.path.basename(upload_path)}"
259
+ else:
260
+ if not sample_path:
261
+ raise gr.Error("Please select a sample file or upload your own .npy input.")
262
+ arr = load_npy_file(sample_path)
263
+ src_desc = f"Sample: {os.path.basename(sample_path)}"
264
+
265
+ # 2) Preprocess
266
+ try:
267
+ x = preprocess_input(arr) # (1, C, H, W)
268
+ except ValueError as e:
269
+ raise gr.Error(str(e))
270
+
271
+ model, device = get_model()
272
+ x = x.to(device)
273
+
274
+ # 3) Inference
275
+ with torch.no_grad():
276
+ out = model(x)
277
+ if isinstance(out, dict) and "x_recon" in out:
278
+ y = out["x_recon"]
279
+ else:
280
+ y = out
281
+
282
+ # y: (1, 1, H, W) or (1, H, W) -> (H, W)
283
+ pred_np = y.detach().cpu().squeeze().numpy()
284
+
285
+ chw = ensure_chw(arr)
286
+ pred_img = plot_prediction(pred_np, to_v=scale_to_v)
287
+ input_img = plot_input_grid(chw, title_prefix="Input")
288
+
289
+ unit_display = "V" if scale_to_v else "mV"
290
+ info = (
291
+ f"{src_desc}\n"
292
+ f"Input shape: {arr.shape}, "
293
+ f"Pred shape: {pred_np.shape}, "
294
+ f"Pred range (model output in mV): "
295
+ f"[{float(pred_np.min()):.4g}, {float(pred_np.max()):.4g}] "
296
+ f"→ displayed in [{unit_display}].\n"
297
+ f"(Input feature composition follows the 25-channel configuration "
298
+ f"from the official CFIRSTNET ICCAD-2023 public repository.)"
299
+ )
300
+
301
+ return input_img, pred_img, info
302
+
303
+
304
+ # ==========================
305
+ # Gradio UI
306
+ # ==========================
307
+
308
+ def build_demo():
309
+ sample_files = list_sample_files()
310
+ sample_choices = (
311
+ [("", "--- choose sample ---")]
312
+ + [(p, os.path.basename(p)) for p in sample_files]
313
+ )
314
+
315
+ with gr.Blocks(title="WACA-UNet IR-drop Demo") as demo:
316
+ gr.Markdown(
317
+ """
318
+ # WACA-UNet IR-drop Prediction Demo
319
+
320
+ - **Input**: 25-channel physical/power-delivery feature maps
321
+ (e.g., HIRD, WR, effective distance, PDN density, etc.).
322
+ - **Output**: Predicted static IR-drop map (visualized in mV or V).
323
+ - The 25-channel input composition follows the configuration used
324
+ in the official [CFIRSTNET GitHub repository](https://github.com/jason122490/CFIRSTNET)
325
+ for the ICCAD-2023 benchmark.
326
+ - You can either:
327
+ - select a pre-generated sample from `tools/samples/`, or
328
+ - upload your own `.npy` file with shape `(25, H, W)` or `(H, W, 25)`.
329
+ """
330
+ )
331
+
332
+
333
+ with gr.Row():
334
+ with gr.Column(scale=1):
335
+ sample_dropdown = gr.Dropdown(
336
+ choices=[c[0] for c in sample_choices],
337
+ value=sample_choices[1][0] if len(sample_choices) > 1 else "",
338
+ label="Sample .npy from tools/samples/",
339
+ info="Pre-generated example inputs (25 × H × W) stored as .npy files.",
340
+ )
341
+ use_uploaded = gr.Checkbox(
342
+ label="Use uploaded file if available",
343
+ value=True,
344
+ )
345
+ uploaded_file = gr.File(
346
+ label="Custom input (.npy, shape = (25, H, W) or (H, W, 25))",
347
+ file_types=[".npy"],
348
+ )
349
+ scale_to_v_input = gr.Checkbox(
350
+ label="Convert prediction from mV to V (divide by 1000)",
351
+ value=SCALE_TO_V,
352
+ )
353
+ run_btn = gr.Button("Run Inference")
354
+
355
+ with gr.Column(scale=2):
356
+ pred_img = gr.Image(label="Predicted IR-drop", type="pil")
357
+ input_img = gr.Image(label="25-channel Input (5×5 grid)", type="pil")
358
+ info_text = gr.Textbox(
359
+ label="Info",
360
+ interactive=False,
361
+ lines=4,
362
+ )
363
+
364
+ run_btn.click(
365
+ fn=infer,
366
+ inputs=[sample_dropdown, uploaded_file, use_uploaded, scale_to_v_input],
367
+ outputs=[input_img, pred_img, info_text],
368
+ )
369
+
370
+ return demo
371
+
372
+
373
+ if __name__ == "__main__":
374
+ demo = build_demo()
375
+ demo.launch()
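
The sample dropdown is populated from `tools/samples/` (pre-generated offline); a hedged sketch of creating a placeholder sample so the dropdown has an entry (random values, not real normalized ICCAD-2023 features):

```python
# Writes a dummy (25, H, W) float32 array that list_sample_files() will pick up.
import os
import numpy as np

os.makedirs("tools/samples", exist_ok=True)
dummy = np.random.randn(25, 256, 256).astype(np.float32)  # placeholder, not real features
np.save(os.path.join("tools/samples", "dummy_case.npy"), dummy)
```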
model.py ADDED
@@ -0,0 +1,565 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from timm.layers import DropPath
5
+ import math
6
+ from typing import List, Tuple, Optional, Dict
7
+
8
+
9
+ class WACA_CBAM(nn.Module):
10
+ def __init__(self, channels, reduction=16):
11
+ super(WACA_CBAM, self).__init__()
12
+ self.channels = channels
13
+ if channels < reduction or channels // reduction == 0:
14
+ self.reduced_channels = channels // 2 if channels > 1 else 1
15
+ else:
16
+ self.reduced_channels = channels // reduction
17
+
18
+ self.fc_layers = nn.Sequential(
19
+ nn.Conv2d(self.channels, self.reduced_channels, kernel_size=1, bias=False),
20
+ nn.ReLU(inplace=True),
21
+ nn.Conv2d(self.reduced_channels, self.channels, kernel_size=1, bias=False)
22
+ )
23
+ self.spatial_attn = nn.Sequential(
24
+ nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False),
25
+ nn.Sigmoid()
26
+ )
27
+
28
+ def forward(self, x):
29
+ avg_pool = F.adaptive_avg_pool2d(x, 1)
30
+ max_pool = F.adaptive_max_pool2d(x, 1)
31
+ avg_out = self.fc_layers(avg_pool)
32
+ max_out = self.fc_layers(max_pool)
33
+
34
+ gate_logits = avg_out + max_out
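+ # Weakness-aware gating: sigmoid(-logits) up-weights weakly responding channels,
+ # and attention computed from this weak-channel view is averaged with the standard channel attention below.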
35
+ weakness_scores = torch.sigmoid(-gate_logits)
36
+ attn_scores = torch.sigmoid(gate_logits)
37
+ gated_weak = x * weakness_scores
38
+ squeezed_2_avg = F.adaptive_avg_pool2d(gated_weak, 1)
39
+ squeezed_2_max = F.adaptive_max_pool2d(gated_weak, 1)
40
+ gate_logits_2 = self.fc_layers(squeezed_2_avg + squeezed_2_max) # current
41
+ # gate_logits_2 = self.fc_layers(squeezed_2_avg) + self.fc_layers(squeezed_2_max) # naive
42
+ attn_scores_2 = torch.sigmoid(gate_logits_2)
43
+ gated_attn = x * (attn_scores + attn_scores_2) * 0.5
44
+
45
+ # Spatial Attention (CBAM)
46
+ avg_out = torch.mean(gated_attn, dim=1, keepdim=True)
47
+ max_out, _ = torch.max(gated_attn, dim=1, keepdim=True)
48
+ sa_input = torch.cat([avg_out, max_out], dim=1)
49
+ sa_weight = self.spatial_attn(sa_input)
50
+ out = gated_attn * sa_weight
51
+
52
+ return out
53
+
54
+
55
+ ##################################################################################
56
+ # copied from https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
57
+ class LayerNorm(nn.Module):
58
+ """ LayerNorm that supports two data formats: channels_last (default) or channels_first.
59
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
60
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
61
+ with shape (batch_size, channels, height, width).
62
+ """
63
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
64
+ super().__init__()
65
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
66
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
67
+ self.eps = eps
68
+ self.data_format = data_format
69
+ if self.data_format not in ["channels_last", "channels_first"]:
70
+ raise NotImplementedError
71
+ self.normalized_shape = (normalized_shape, )
72
+
73
+ def forward(self, x):
74
+ if self.data_format == "channels_last":
75
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
76
+ elif self.data_format == "channels_first":
77
+ u = x.mean(1, keepdim=True)
78
+ s = (x - u).pow(2).mean(1, keepdim=True)
79
+ x = (x - u) / torch.sqrt(s + self.eps)
80
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
81
+ return x
82
+
83
+ class GRN(nn.Module):
84
+ """ GRN (Global Response Normalization) layer
85
+ """
86
+ def __init__(self, dim):
87
+ super().__init__()
88
+ self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
89
+ self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
90
+
91
+ def forward(self, x):
92
+ Gx = torch.norm(x, p=2, dim=(1,2), keepdim=True)
93
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
94
+ return self.gamma * (x * Nx) + self.beta + x
95
+
96
+ #################################################################################################
97
+
98
+ import torch
99
+ import torch.nn as nn
100
+ from torch.nn import functional as F
101
+
102
+ class ConvNeXtV2BlockWACA_Atrous(nn.Module):
103
+ def __init__(self, in_ch, out_ch, reduction=16, drop_path=0., dilation=3):
104
+ super().__init__()
105
+
106
+ # Atrous (dilated) depthwise convolution
107
+ # adjust padding so the dilated kernel keeps the same effective receptive field
108
+ padding = dilation * 3 # = (7-1)//2 * dilation, since kernel_size=7
109
+ self.dwconv = nn.Conv2d(
110
+ in_ch, in_ch,
111
+ kernel_size=7,
112
+ padding=padding,
113
+ groups=in_ch,
114
+ dilation=dilation # apply atrous (dilated) convolution
115
+ )
116
+
117
+ self.norm = LayerNorm(in_ch, eps=1e-6)
118
+ self.pwconv1 = nn.Linear(in_ch, 4 * in_ch)
119
+ self.act = nn.GELU()
120
+ self.grn = GRN(4 * in_ch)
121
+ self.pwconv2 = nn.Linear(4 * in_ch, out_ch)
122
+ self.fow = WACA_CBAM(out_ch, reduction=reduction)
123
+
124
+ self.proj = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1)
125
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
126
+
127
+ def forward(self, x):
128
+ input_x = x
129
+ x = self.dwconv(x)
130
+ x = x.permute(0, 2, 3, 1) # BCHW -> BHWC
131
+ x = self.norm(x)
132
+ x = self.pwconv1(x)
133
+ x = self.act(x)
134
+ x = self.grn(x)
135
+ x = self.pwconv2(x)
136
+ x = x.permute(0, 3, 1, 2) # BHWC -> BCHW
137
+ x = self.fow(x)
138
+ x = self.drop_path(x)
139
+
140
+ out = self.proj(input_x) + x
141
+ return out
142
+
143
+
144
+ # Variant using multi-scale atrous (dilated) convolutions
145
+ class ConvNeXtV2BlockWACA_MultiAtrous(nn.Module):
146
+ def __init__(self, in_ch, out_ch, reduction=16, drop_path=0., dilations=[1, 2, 4]):
147
+ super().__init__()
148
+
149
+ # Depthwise convolutions with several different dilation rates
150
+ self.dwconv_branches = nn.ModuleList([
151
+ nn.Conv2d(
152
+ in_ch, in_ch // len(dilations),
153
+ kernel_size=7,
154
+ padding=d * 3, # padding for kernel_size=7
155
+ groups=in_ch // len(dilations),
156
+ dilation=d
157
+ ) for d in dilations
158
+ ])
159
+
160
+ # After merging the branches, project back to the original channel count
161
+ self.combine_conv = nn.Conv2d(in_ch, in_ch, 1)
162
+
163
+ self.norm = LayerNorm(in_ch, eps=1e-6)
164
+ self.pwconv1 = nn.Linear(in_ch, 4 * in_ch)
165
+ self.act = nn.GELU()
166
+ self.grn = GRN(4 * in_ch)
167
+ self.pwconv2 = nn.Linear(4 * in_ch, out_ch)
168
+ self.fow = WACA_CBAM(out_ch, reduction=reduction)
169
+
170
+ self.proj = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1)
171
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
172
+
173
+ def forward(self, x):
174
+ input_x = x
175
+
176
+ # Multi-scale atrous convolution
177
+ branch_outputs = []
178
+ for i, dwconv in enumerate(self.dwconv_branches):
179
+ # Select the channel slice corresponding to this branch
180
+ channels_per_branch = x.size(1) // len(self.dwconv_branches)
181
+ start_idx = i * channels_per_branch
182
+ end_idx = (i + 1) * channels_per_branch if i < len(self.dwconv_branches) - 1 else x.size(1)
183
+ branch_input = x[:, start_idx:end_idx, :, :]
184
+ branch_outputs.append(dwconv(branch_input))
185
+
186
+ # Concatenate all branch outputs
187
+ x = torch.cat(branch_outputs, dim=1)
188
+ x = self.combine_conv(x)
189
+
190
+ x = x.permute(0, 2, 3, 1) # BCHW -> BHWC
191
+ x = self.norm(x)
192
+ x = self.pwconv1(x)
193
+ x = self.act(x)
194
+ x = self.grn(x)
195
+ x = self.pwconv2(x)
196
+ x = x.permute(0, 3, 1, 2) # BHWC -> BCHW
197
+ x = self.fow(x)
198
+ x = self.drop_path(x)
199
+
200
+ out = self.proj(input_x) + x
201
+ return out
202
+
203
+
204
+ # ASPP (Atrous Spatial Pyramid Pooling)-style variant
205
+ class ConvNeXtV2BlockWACA_ASPP(nn.Module):
206
+ def __init__(self, in_ch, out_ch, reduction=16, drop_path=0., dilations=[1, 6, 12, 18]):
207
+ super().__init__()
208
+
209
+ # ASPP-style parallel atrous convolutions
210
+ self.aspp_branches = nn.ModuleList()
211
+
212
+ for dilation in dilations:
213
+ if dilation == 1:
214
+ # First branch: a plain convolution
215
+ branch = nn.Conv2d(in_ch, in_ch // len(dilations), 1)
216
+ else:
217
+ # Remaining branches: atrous (dilated) convolutions
218
+ branch = nn.Conv2d(
219
+ in_ch, in_ch // len(dilations),
220
+ kernel_size=3,
221
+ padding=dilation,
222
+ dilation=dilation,
223
+ groups=in_ch // len(dilations)
224
+ )
225
+ self.aspp_branches.append(branch)
226
+
227
+ # Global Average Pooling branch
228
+ self.global_avg_pool = nn.Sequential(
229
+ nn.AdaptiveAvgPool2d((1, 1)),
230
+ nn.Conv2d(in_ch, in_ch // len(dilations), 1),
231
+ )
232
+
233
+ # Convolution that fuses all branches
234
+ total_channels = (len(dilations) + 1) * (in_ch // len(dilations))
235
+ self.combine_conv = nn.Conv2d(total_channels, in_ch, 1)
236
+
237
+ self.norm = LayerNorm(in_ch, eps=1e-6)
238
+ self.pwconv1 = nn.Linear(in_ch, 4 * in_ch)
239
+ self.act = nn.GELU()
240
+ self.grn = GRN(4 * in_ch)
241
+ self.pwconv2 = nn.Linear(4 * in_ch, out_ch)
242
+ self.fow = WACA_CBAM(out_ch, reduction=reduction)
243
+
244
+ self.proj = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1)
245
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
246
+
247
+ def forward(self, x):
248
+ input_x = x
249
+ h, w = x.size()[2:]
250
+
251
+ # ASPP branches
252
+ branch_outputs = []
253
+ for branch in self.aspp_branches:
254
+ branch_outputs.append(branch(x))
255
+
256
+ # Global average pooling branch
257
+ global_feat = self.global_avg_pool(x)
258
+ global_feat = F.interpolate(global_feat, size=(h, w), mode='bilinear', align_corners=False)
259
+ branch_outputs.append(global_feat)
260
+
261
+ # Concatenate all branches
262
+ x = torch.cat(branch_outputs, dim=1)
263
+ x = self.combine_conv(x)
264
+
265
+ x = x.permute(0, 2, 3, 1) # BCHW -> BHWC
266
+ x = self.norm(x)
267
+ x = self.pwconv1(x)
268
+ x = self.act(x)
269
+ x = self.grn(x)
270
+ x = self.pwconv2(x)
271
+ x = x.permute(0, 3, 1, 2) # BHWC -> BCHW
272
+ x = self.fow(x)
273
+ x = self.drop_path(x)
274
+
275
+ out = self.proj(input_x) + x
276
+ return out
277
+
278
+
279
+
280
+
281
+ #################################################################################################
282
+ class ConvNeXtV2BlockWACA(nn.Module):
283
+ def __init__(self, in_ch, out_ch, reduction=16, drop_path=0.,use_grn=True):
284
+ super().__init__()
285
+ self.dwconv = nn.Conv2d(in_ch, in_ch, kernel_size=7, padding=3, groups=in_ch)
286
+ self.norm = LayerNorm(in_ch, eps=1e-6)
287
+ self.pwconv1 = nn.Linear(in_ch, 4 * in_ch)
288
+ self.act = nn.GELU()
289
+ self.grn = GRN(4 * in_ch) if use_grn else nn.Identity()
290
+ self.pwconv2 = nn.Linear(4 * in_ch, out_ch)
291
+ self.fow = WACA_CBAM(out_ch,reduction=reduction)
292
+ self.proj = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1)
293
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
294
+
295
+ def forward(self, x):
296
+ input_x = x
297
+ x = self.dwconv(x)
298
+ x = x.permute(0, 2, 3, 1) # BCHW -> BHWC
299
+ x = self.norm(x)
300
+ x = self.pwconv1(x)
301
+ x = self.act(x)
302
+ x = self.grn(x)
303
+ x = self.pwconv2(x)
304
+ x = x.permute(0, 3, 1, 2) # BHWC -> BCHW
305
+ x = self.fow(x)
306
+ x = self.drop_path(x)
307
+ out = self.proj(input_x) + x
308
+ return out
309
+
310
+
311
+ class AttentionGate(nn.Module):
312
+ def __init__(self, in_ch_x, in_ch_g, out_ch):
313
+ super().__init__()
314
+ self.act = nn.ReLU(inplace=True)
315
+ self.w_x_g = nn.Conv2d(in_ch_x + in_ch_g, out_ch, kernel_size=1, stride=1, padding=0, bias=False)
316
+ self.attn = nn.Conv2d(out_ch, out_ch, kernel_size=1, padding=0, bias=False)
317
+ def forward(self, x, g):
318
+ res = x
319
+ xg = torch.cat([x, g], dim=1) # B, (x_c+g_c), H, W
320
+ xg = self.w_x_g(xg)
321
+ xg = self.act(xg)
322
+ attn = torch.sigmoid(self.attn(xg))
323
+ out = res * attn
324
+ return out
325
+
326
+
327
+ class WACA_Unet(nn.Module):
328
+ def __init__(self, in_ch=25, out_ch=1, base_ch=64, reduction=16,
329
+ depth=4, drop_path=0.2, block=ConvNeXtV2BlockWACA, **kwargs):
330
+ super().__init__()
331
+ self.depth = depth
332
+ chs = [base_ch * 2**i for i in range(depth+1)]
333
+ self.drop_path = drop_path
334
+
335
+ n_enc_blocks = depth + 1
336
+ n_dec_blocks = depth
337
+ total_blocks = n_enc_blocks + n_dec_blocks
338
+
339
+ drop_path_rates = torch.linspace(0, drop_path, total_blocks).tolist()
340
+ enc_dp_rates = drop_path_rates[:n_enc_blocks]
341
+ dec_dp_rates = drop_path_rates[n_enc_blocks:]
342
+
343
+ # Encoder
344
+ self.enc_blocks = nn.ModuleList([
345
+ block(in_ch, chs[0], reduction, drop_path=enc_dp_rates[0])
346
+ ] + [
347
+ block(chs[i], chs[i+1], reduction, drop_path=enc_dp_rates[i+1])
348
+ for i in range(depth)
349
+ ])
350
+ self.pool = nn.ModuleList([
351
+ nn.Conv2d(chs[i], chs[i], kernel_size=3, stride=2, padding=1, groups=chs[i])
352
+ for i in range(depth)
353
+ ])
354
+
355
+ # Decoder
356
+ self.upconvs = nn.ModuleList([
357
+ nn.ConvTranspose2d(chs[i+1], chs[i], kernel_size=2, stride=2)
358
+ for i in reversed(range(depth))
359
+ ])
360
+ self.dec_blocks = nn.ModuleList([
361
+ block(chs[i]*2, chs[i], reduction, drop_path=dec_dp_rates[i])
362
+ for i in reversed(range(depth))
363
+ ])
364
+ # Attention Gates
365
+ self.attn_gates = nn.ModuleList([
366
+ AttentionGate(chs[i], chs[i], chs[i])
367
+ for i in reversed(range(depth))
368
+ ])
369
+
370
+ self.final_head = nn.Sequential(
371
+ nn.Conv2d(chs[0], out_ch, kernel_size=1)
372
+ )
373
+
374
+ def forward(self, x):
375
+ enc_feats = []
376
+ for i, enc in enumerate(self.enc_blocks):
377
+ x = enc(x)
378
+ enc_feats.append(x)
379
+ if i < self.depth:
380
+ x = self.pool[i](x)
381
+ # Decoder
382
+ for i in range(self.depth):
383
+ x = self.upconvs[i](x)
384
+ enc_feat = enc_feats[self.depth-1-i]
385
+ # AttentionGate: (encoder feature, decoder upconv output)
386
+ attn_enc_feat = self.attn_gates[i](enc_feat, x)
387
+ x = torch.cat([attn_enc_feat, x], dim=1)
388
+ x = self.dec_blocks[i](x)
389
+
390
+ out = self.final_head(x)
391
+ return {
392
+ 'x_recon': out
393
+ }
394
+
395
+ ###############################################################################
396
+ from torch.nn.utils.rnn import pack_padded_sequence
397
+
398
+ class GRUStem(nn.Module):
399
+ """
400
+ Zero-padded variable-channel input: each channel is embedded with a shared encoder (phi),
401
+ then the channel axis is treated as a sequence axis and fused by a bidirectional GRU.
402
+ """
403
+ def __init__(self, out_channels: int = 64, embed_channels: int = 16, small_input: bool = True):
404
+ super().__init__()
405
+ stride = 1 if small_input else 2
406
+ self.phi = nn.Sequential(
407
+ nn.Conv2d(1, embed_channels, kernel_size=3, stride=stride, padding=1, bias=False),
408
+ nn.BatchNorm2d(embed_channels),
409
+ nn.ReLU(inplace=True),
410
+ )
411
+ hidden = out_channels // 2
412
+ assert hidden > 0, "out_channels must be >=2 to support BiGRU"
413
+ self.gru = nn.GRU(input_size=embed_channels, hidden_size=hidden,
414
+ num_layers=1, bidirectional=True)
415
+ self.bn = nn.BatchNorm2d(out_channels)
416
+ self.act = nn.ReLU(inplace=True)
417
+ self.out_channels = out_channels
418
+ self.small_input = small_input
419
+
420
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
421
+ # x: [B, Cmax, H, W] with zero-padded channels
422
+ B, Cmax, H, W = x.shape
423
+
424
+ # non-zero channel lengths
425
+ with torch.no_grad():
426
+ nonzero_ch = (x.abs().sum(dim=(2, 3)) > 0) # [B, Cmax]
427
+ lengths = nonzero_ch.sum(dim=1).clamp(min=1) # [B]
428
+
429
+ # shared encoder φ for each channel
430
+ feat_per_c = [self.phi(x[:, c:c+1, :, :]) for c in range(Cmax)] # list of [B,E,H',W']
431
+ Fstack = torch.stack(feat_per_c, dim=0) # [Cmax, B, E, H', W']
432
+ Cseq, Bsz, E, Hp, Wp = Fstack.shape
433
+
434
+ # sequence for GRU: [T=Cmax, N=B*Hp*Wp, E]
435
+ Fseq = Fstack.permute(0, 1, 3, 4, 2).contiguous().view(Cseq, Bsz * Hp * Wp, E)
436
+ lens = lengths.repeat_interleave(Hp * Wp).cpu() # [N]
437
+ packed = pack_padded_sequence(Fseq, lens, enforce_sorted=False)
438
+ _, h_n = self.gru(packed) # [2, N, hidden]
439
+ h_cat = torch.cat([h_n[0], h_n[1]], dim=-1) # [N, out_ch]
440
+
441
+ out = h_cat.view(Bsz, Hp, Wp, -1).permute(0, 3, 1, 2).contiguous() # [B,out,H',W']
442
+ out = self.act(self.bn(out))
443
+ return out
444
+
445
+ ##################################################################
446
+
447
+ class _PoolDownMixin:
448
+ def __init__(self, small_input: bool):
449
+ self._stride = 1 if small_input else 2
450
+ def _maybe_down(self, y: torch.Tensor) -> torch.Tensor:
451
+ if self._stride == 2:
452
+ return F.avg_pool2d(y, 2)
453
+ return y
454
+
455
+ class FourierStem2D(nn.Module, _PoolDownMixin):
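+ # Channel-to-feature stem: at every pixel, the C input channel values are projected onto
+ # out_dim fixed cosine (Fourier) or Chebyshev basis functions and then linearly mixed.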
456
+ def __init__(self, out_dim=64, basis="chebyshev", small_input: bool = True):
457
+ nn.Module.__init__(self); _PoolDownMixin.__init__(self, small_input)
458
+ assert basis in ("fourier", "chebyshev")
459
+ self.out_dim = out_dim; self.basis = basis
460
+ self.proj = nn.Linear(out_dim, out_dim)
461
+ self._basis_cache: Dict[Tuple[int, str, torch.device, torch.dtype], torch.Tensor] = {}
462
+
463
+ def _get_basis(self, C, device, dtype):
464
+ key = (C, self.basis, device, dtype)
465
+ if key in self._basis_cache: return self._basis_cache[key]
466
+ idx = torch.linspace(-1, 1, C, device=device, dtype=dtype).unsqueeze(0)
467
+ if self.basis == "fourier":
468
+ B = torch.stack([torch.cos(idx * i * math.pi) for i in range(1, self.out_dim+1)], dim=-1)
469
+ else:
470
+ B = torch.stack([torch.cos(i * torch.acos(idx)) for i in range(1, self.out_dim+1)], dim=-1)
471
+ self._basis_cache[key] = B; return B
472
+
473
+ def forward(self, x: torch.Tensor):
474
+ B, C, H, W = x.shape; device, dtype = x.device, x.dtype
475
+ x_flat = x.permute(0,2,3,1).reshape(-1, C) # [BHW,C]
476
+ basis = self._get_basis(C, device, dtype)[0] # [C,D]
477
+ emb = (x_flat @ basis) # [BHW,D]
478
+ emb = self.proj(emb).view(B, H, W, self.out_dim).permute(0,3,1,2)
479
+ return self._maybe_down(emb)
480
+
481
+
482
+ class WACA_Unet_stem(nn.Module):
483
+ def __init__(self, in_ch=25, out_ch=1, base_ch=64, reduction=16,
484
+ depth=4, drop_path=0.2, block=ConvNeXtV2BlockWACA, **kwargs):
485
+ super().__init__()
486
+ self.depth = depth
487
+ chs = [base_ch * 2**i for i in range(depth+1)]
488
+ self.drop_path = drop_path
489
+
490
+ n_enc_blocks = depth + 1
491
+ n_dec_blocks = depth
492
+ total_blocks = n_enc_blocks + n_dec_blocks
493
+
494
+ drop_path_rates = torch.linspace(0, drop_path, total_blocks).tolist()
495
+ enc_dp_rates = drop_path_rates[:n_enc_blocks]
496
+ dec_dp_rates = drop_path_rates[n_enc_blocks:]
497
+ # self.stem = GRUStem(out_channels=16,embed_channels=16,small_input=True)
498
+ self.stem = FourierStem2D(chs[0])
499
+ # self.up0 = nn.Upsample(scale_factor=2,mode='bicubic')
500
+ # Encoder
501
+ self.enc_blocks = nn.ModuleList([
502
+ block(chs[0], chs[0], reduction, drop_path=enc_dp_rates[0])
503
+ ] + [
504
+ block(chs[i], chs[i+1], reduction, drop_path=enc_dp_rates[i+1])
505
+ for i in range(depth)
506
+ ])
507
+ self.pool = nn.ModuleList([
508
+ nn.Conv2d(chs[i], chs[i], kernel_size=3, stride=2, padding=1, groups=chs[i])
509
+ for i in range(depth)
510
+ ])
511
+
512
+ # Decoder
513
+ self.upconvs = nn.ModuleList([
514
+ nn.ConvTranspose2d(chs[i+1], chs[i], kernel_size=2, stride=2)
515
+ for i in reversed(range(depth))
516
+ ])
517
+ self.dec_blocks = nn.ModuleList([
518
+ block(chs[i]*2, chs[i], reduction, drop_path=dec_dp_rates[i])
519
+ for i in reversed(range(depth))
520
+ ])
521
+ # Attention Gates
522
+ self.attn_gates = nn.ModuleList([
523
+ AttentionGate(chs[i], chs[i], chs[i])
524
+ for i in reversed(range(depth))
525
+ ])
526
+
527
+ self.final_head = nn.Sequential(
528
+ nn.Conv2d(chs[0], out_ch, kernel_size=1)
529
+ )
530
+
531
+ def forward(self, x):
532
+ enc_feats = []
533
+ x = self.stem(x)
534
+ # x = self.up0(x)
535
+ for i, enc in enumerate(self.enc_blocks):
536
+ x = enc(x)
537
+ enc_feats.append(x)
538
+ if i < self.depth:
539
+ x = self.pool[i](x)
540
+ # Decoder
541
+ for i in range(self.depth):
542
+ x = self.upconvs[i](x)
543
+ enc_feat = enc_feats[self.depth-1-i]
544
+ # AttentionGate: (encoder feature, decoder upconv output)
545
+ attn_enc_feat = self.attn_gates[i](enc_feat, x)
546
+ x = torch.cat([attn_enc_feat, x], dim=1)
547
+ x = self.dec_blocks[i](x)
548
+
549
+ out = self.final_head(x)
550
+ return {
551
+ 'x_recon': out
552
+ }
553
+
554
+
555
+
556
+ if __name__ == '__main__':
557
+ for block in [ConvNeXtV2BlockWACA]:
558
+ print(f"Testing block: {block.__name__}")
559
+ model = WACA_Unet_stem(in_ch=25, out_ch=1,block=block, depth=4)
560
+ dummy = torch.randn(2,25, 384, 384)
561
+ out = model(dummy)['x_recon']
562
+ print(f"Input shape: {dummy.shape}")
563
+ print(f"Output shape: {out.shape}")
564
+ total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
565
+ print(f"Total trainable parameters: {total_params:,}")
requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ gradio>=5.0.0
2
+ numpy>=1.23
3
+ matplotlib>=3.7
4
+ pillow>=10.0
5
+ torch>=2.1.0
6
+ torchvision>=0.16.0
7
+ timm>=0.9.5
8
+ einops>=0.8.0
wacaunet_val_f1_331_0.8757.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c20f01f72d3ab9a4efb22d732bad281fc7de53b6a3506e8c0ed626385a77951
3
+ size 72818550