Resync verified files (1/1)
Browse files- .gitattributes +2 -0
- datasets/1.5_nc/2m_temperature/2m_temperature_2012_1.5deg.nc +3 -0
- datasets/1.5_nc/2m_temperature/2m_temperature_2016_1.5deg.nc +3 -0
- downscaling/engine.py +522 -77
- downscaling/output/WeatherPEFT/checkpoint-19.pth +3 -0
- downscaling/output/WeatherPEFT/checkpoint-29.pth +3 -0
- downscaling/output/WeatherPEFT/checkpoint-9.pth +3 -0
- downscaling/output/WeatherPEFT/log.txt +0 -0
- downscaling/output/backbone_anchor_smoke/log.txt +4 -0
- downscaling/output/val_freq_smoke/log.txt +4 -0
- downscaling/output/val_speed_after/log.txt +4 -0
- downscaling/output/val_speed_baseline/log.txt +4 -0
- downscaling/run_downscaling.py +99 -14
- downscaling/script_run_downscaling.sh +33 -3
- downscaling/script_smoke_feasibility.sh +1 -0
.gitattributes
CHANGED
|
@@ -1733,3 +1733,5 @@ aux_data/era5_tp_6hr_china.nc filter=lfs diff=lfs merge=lfs -text
|
|
| 1733 |
datasets/1.5_nc/2m_temperature/2m_temperature_2015_1.5deg.nc filter=lfs diff=lfs merge=lfs -text
|
| 1734 |
aux_data/climatology/total_precipitation_6hr_seeps_dry_fraction.nc filter=lfs diff=lfs merge=lfs -text
|
| 1735 |
aux_data/climatology/total_precipitation_6hr_seeps_threshold.nc filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 1733 |
datasets/1.5_nc/2m_temperature/2m_temperature_2015_1.5deg.nc filter=lfs diff=lfs merge=lfs -text
|
| 1734 |
aux_data/climatology/total_precipitation_6hr_seeps_dry_fraction.nc filter=lfs diff=lfs merge=lfs -text
|
| 1735 |
aux_data/climatology/total_precipitation_6hr_seeps_threshold.nc filter=lfs diff=lfs merge=lfs -text
|
| 1736 |
+
datasets/1.5_nc/2m_temperature/2m_temperature_2012_1.5deg.nc filter=lfs diff=lfs merge=lfs -text
|
| 1737 |
+
datasets/1.5_nc/2m_temperature/2m_temperature_2016_1.5deg.nc filter=lfs diff=lfs merge=lfs -text
|
datasets/1.5_nc/2m_temperature/2m_temperature_2012_1.5deg.nc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f72d50b681a70895fb3c36b02b65d1c294e4e252f5bea0e4cccc36467c56f9a3
|
| 3 |
+
size 115216361
|
datasets/1.5_nc/2m_temperature/2m_temperature_2016_1.5deg.nc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:92a5a6e02d41079c19dccde429f1071d4e51bd60fa92d1af2670694031551d3c
|
| 3 |
+
size 114847590
|
downscaling/engine.py
CHANGED
|
@@ -41,40 +41,119 @@ class TemporalCorrectionHead(nn.Module):
|
|
| 41 |
hidden_surface: int = 256,
|
| 42 |
hidden_upper: int = 512,
|
| 43 |
dropout: float = 0.0,
|
|
|
|
|
|
|
| 44 |
):
|
| 45 |
super().__init__()
|
| 46 |
upper_ch = num_upper_vars * num_levels
|
| 47 |
-
|
| 48 |
-
self.
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
return corr_surface, corr_upper
|
| 79 |
|
| 80 |
|
|
@@ -147,23 +226,41 @@ def _default_spatial_interp(surface, upper, target_hw):
|
|
| 147 |
return surface_up, upper_up
|
| 148 |
|
| 149 |
|
| 150 |
-
def _default_temporal_interp(surface_endpoints, upper_endpoints, out_steps):
|
| 151 |
-
"""Linear
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
if surface_endpoints.shape[1] < 2 or upper_endpoints.shape[1] < 2:
|
| 153 |
raise ValueError("Temporal interpolation needs at least two endpoint frames.")
|
| 154 |
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
).view(1, out_steps, 1, 1, 1)
|
| 163 |
-
alpha_u = alpha_s.view(1, out_steps, 1, 1, 1, 1)
|
| 164 |
|
| 165 |
-
surface_interp =
|
| 166 |
-
upper_interp =
|
| 167 |
return surface_interp, upper_interp
|
| 168 |
|
| 169 |
|
|
@@ -214,26 +311,33 @@ def forward_paths(
|
|
| 214 |
target_hw_6h = batch["y_hr6h_surface"].shape[-2:]
|
| 215 |
target_hw_1h = batch["y_hr1h_surface"].shape[-2:]
|
| 216 |
out_steps_1h = batch["y_hr1h_surface"].shape[1]
|
|
|
|
| 217 |
|
| 218 |
# Spatial path at 6h: S(x) = U_s(x) + C_s(x)
|
|
|
|
|
|
|
| 219 |
s_base_surface, s_base_upper = spatial_interp_fn(x_lr6h_surface, x_lr6h_upper, target_hw_6h)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
pred_s_hr6h_surface, pred_s_hr6h_upper = _apply_correction(
|
| 221 |
spatial_correction_fn,
|
| 222 |
x_lr6h_surface,
|
| 223 |
x_lr6h_upper,
|
| 224 |
-
|
| 225 |
-
|
| 226 |
)
|
| 227 |
-
# Optional Aurora-backbone replacement for S-path last endpoint (enables LoRA gradient path).
|
| 228 |
-
if backbone_s_last_surface is not None:
|
| 229 |
-
pred_s_hr6h_surface = pred_s_hr6h_surface.clone()
|
| 230 |
-
pred_s_hr6h_surface[:, -1] = backbone_s_last_surface
|
| 231 |
-
if backbone_s_last_upper is not None:
|
| 232 |
-
pred_s_hr6h_upper = pred_s_hr6h_upper.clone()
|
| 233 |
-
pred_s_hr6h_upper[:, -1] = backbone_s_last_upper
|
| 234 |
|
| 235 |
# Temporal path at LR: T(x) = U_t(x) + C_t(x)
|
| 236 |
-
t_base_surface, t_base_upper = temporal_interp_fn(
|
|
|
|
|
|
|
| 237 |
pred_t_lr1h_surface, pred_t_lr1h_upper = _apply_correction(
|
| 238 |
temporal_correction_fn,
|
| 239 |
x_lr6h_surface,
|
|
@@ -253,7 +357,9 @@ def forward_paths(
|
|
| 253 |
)
|
| 254 |
|
| 255 |
# ST path: x -> S -> T
|
| 256 |
-
st_base_surface, st_base_upper = temporal_interp_fn(
|
|
|
|
|
|
|
| 257 |
pred_st_hr1h_surface, pred_st_hr1h_upper = _apply_correction(
|
| 258 |
temporal_correction_fn,
|
| 259 |
pred_s_hr6h_surface,
|
|
@@ -265,6 +371,8 @@ def forward_paths(
|
|
| 265 |
return {
|
| 266 |
"pred_s_hr6h_surface": pred_s_hr6h_surface,
|
| 267 |
"pred_s_hr6h_upper": pred_s_hr6h_upper,
|
|
|
|
|
|
|
| 268 |
"pred_t_lr1h_surface": pred_t_lr1h_surface,
|
| 269 |
"pred_t_lr1h_upper": pred_t_lr1h_upper,
|
| 270 |
"pred_ts_hr1h_surface": pred_ts_hr1h_surface,
|
|
@@ -317,6 +425,8 @@ def _downscale_style_term_loss(
|
|
| 317 |
surf_vars=None,
|
| 318 |
upper_vars=None,
|
| 319 |
level=None,
|
|
|
|
|
|
|
| 320 |
):
|
| 321 |
"""
|
| 322 |
Match train_one_epoch_downscale style:
|
|
@@ -329,8 +439,33 @@ def _downscale_style_term_loss(
|
|
| 329 |
pred_upper_n = _normalise_upper_bt_vchw(pred_upper, upper_vars, level)
|
| 330 |
target_upper_n = _normalise_upper_bt_vchw(target_upper, upper_vars, level)
|
| 331 |
|
| 332 |
-
|
| 333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
return 0.25 * loss_surface + loss_upper
|
| 335 |
|
| 336 |
|
|
@@ -341,10 +476,13 @@ def loss_fn(
|
|
| 341 |
lambda_t=1.0,
|
| 342 |
lambda_st=1.0,
|
| 343 |
lambda_comm=1.0,
|
|
|
|
| 344 |
surface_var_weights=None,
|
| 345 |
upper_var_weights=None,
|
| 346 |
upper_level_weights=None,
|
| 347 |
interior_only_1h=False,
|
|
|
|
|
|
|
| 348 |
surf_vars=None,
|
| 349 |
upper_vars=None,
|
| 350 |
level=None,
|
|
@@ -356,8 +494,9 @@ def loss_fn(
|
|
| 356 |
return x[:, 1:-1]
|
| 357 |
return x
|
| 358 |
|
| 359 |
-
#
|
| 360 |
-
|
|
|
|
| 361 |
|
| 362 |
# L_S: S(x) supervised by HR-6h
|
| 363 |
l_s = _downscale_style_term_loss(
|
|
@@ -368,10 +507,14 @@ def loss_fn(
|
|
| 368 |
surf_vars=surf_vars,
|
| 369 |
upper_vars=upper_vars,
|
| 370 |
level=level,
|
|
|
|
|
|
|
| 371 |
)
|
| 372 |
|
| 373 |
-
# L_T:
|
| 374 |
-
|
|
|
|
|
|
|
| 375 |
_trim_time(preds["pred_t_lr1h_surface"]),
|
| 376 |
_trim_time(preds["pred_t_lr1h_upper"]),
|
| 377 |
_trim_time(batch["y_lr1h_surface"]),
|
|
@@ -379,7 +522,33 @@ def loss_fn(
|
|
| 379 |
surf_vars=surf_vars,
|
| 380 |
upper_vars=upper_vars,
|
| 381 |
level=level,
|
|
|
|
|
|
|
| 382 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
|
| 384 |
# L_ST: ST path supervised by HR-1h
|
| 385 |
l_st = _downscale_style_term_loss(
|
|
@@ -390,9 +559,11 @@ def loss_fn(
|
|
| 390 |
surf_vars=surf_vars,
|
| 391 |
upper_vars=upper_vars,
|
| 392 |
level=level,
|
|
|
|
|
|
|
| 393 |
)
|
| 394 |
|
| 395 |
-
# L_comm: TS and ST consistency
|
| 396 |
l_comm = _downscale_style_term_loss(
|
| 397 |
_trim_time(preds["pred_ts_hr1h_surface"]),
|
| 398 |
_trim_time(preds["pred_ts_hr1h_upper"]),
|
|
@@ -401,15 +572,47 @@ def loss_fn(
|
|
| 401 |
surf_vars=surf_vars,
|
| 402 |
upper_vars=upper_vars,
|
| 403 |
level=level,
|
|
|
|
|
|
|
| 404 |
)
|
| 405 |
|
| 406 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
loss_dict = {
|
| 408 |
"loss_total": total,
|
| 409 |
"loss_s": l_s,
|
| 410 |
"loss_t": l_t,
|
| 411 |
"loss_st": l_st,
|
| 412 |
"loss_comm": l_comm,
|
|
|
|
| 413 |
}
|
| 414 |
return total, loss_dict
|
| 415 |
|
|
@@ -687,6 +890,21 @@ def _trim_to_hw(x, h, w):
|
|
| 687 |
return x[..., :h, :w]
|
| 688 |
|
| 689 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 690 |
def _resize_surface_bt_vhw(x, out_hw):
|
| 691 |
b, t, v, h, w = x.shape
|
| 692 |
th, tw = out_hw
|
|
@@ -709,12 +927,123 @@ def _resize_upper_bt_vchw(x, out_hw):
|
|
| 709 |
).reshape(b, t, v, c, th, tw)
|
| 710 |
|
| 711 |
|
| 712 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 713 |
h_hr, w_hr = hr_hw
|
| 714 |
h_lr, w_lr = lr_hw
|
| 715 |
|
| 716 |
def _to_device(x, h, w):
|
| 717 |
-
|
|
|
|
| 718 |
|
| 719 |
out = {}
|
| 720 |
out["x_lr6h_surface"] = _to_device(batch["x_lr6h_surface"], h_lr, w_lr)
|
|
@@ -741,8 +1070,30 @@ def _prepare_factorized_batch(batch, hr_hw, lr_hw, device):
|
|
| 741 |
out["y_lr1h_surface"] = _to_device(batch["y_lr1h_surface"], h_lr, w_lr)
|
| 742 |
out["y_lr1h_upper"] = _to_device(batch["y_lr1h_upper"], h_lr, w_lr)
|
| 743 |
else:
|
| 744 |
-
|
| 745 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 746 |
|
| 747 |
out["time_unix"] = batch["time_unix"]
|
| 748 |
return out
|
|
@@ -851,6 +1202,8 @@ def _build_temporal_correction_fn(
|
|
| 851 |
upper_vars=None,
|
| 852 |
level=None,
|
| 853 |
correction_scale=1.0,
|
|
|
|
|
|
|
| 854 |
):
|
| 855 |
if temporal_corrector is None:
|
| 856 |
def _temporal_correction_zeros(input_surface, input_upper, base_surface, base_upper):
|
|
@@ -859,12 +1212,35 @@ def _build_temporal_correction_fn(
|
|
| 859 |
return _temporal_correction_zeros
|
| 860 |
|
| 861 |
def _temporal_correction(input_surface, input_upper, base_surface, base_upper):
|
|
|
|
|
|
|
| 862 |
base_surface_n = _normalise_surface_bt_vhw(base_surface, surf_vars)
|
| 863 |
base_upper_n = _normalise_upper_bt_vchw(base_upper, upper_vars, level)
|
| 864 |
-
corr_surface_n, corr_upper_n = temporal_corrector(
|
|
|
|
|
|
|
| 865 |
|
| 866 |
corr_surface_n = torch.nan_to_num(corr_surface_n, nan=0.0, posinf=0.0, neginf=0.0)
|
| 867 |
corr_upper_n = torch.nan_to_num(corr_upper_n, nan=0.0, posinf=0.0, neginf=0.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 868 |
|
| 869 |
s_scale = _surface_scale_tensor(surf_vars, base_surface)
|
| 870 |
u_scale = _upper_scale_tensor(upper_vars, level, base_upper)
|
|
@@ -902,14 +1278,18 @@ def train_one_epoch_factorized(
|
|
| 902 |
lambda_t=1.0,
|
| 903 |
lambda_st=1.0,
|
| 904 |
lambda_comm=1.0,
|
|
|
|
| 905 |
surface_var_weights=None,
|
| 906 |
upper_var_weights=None,
|
| 907 |
upper_level_weights=None,
|
| 908 |
interior_only_1h=False,
|
|
|
|
|
|
|
| 909 |
spatial_corrector=None,
|
| 910 |
temporal_corrector=None,
|
| 911 |
spatial_correction_scale=1.0,
|
| 912 |
temporal_correction_scale=1.0,
|
|
|
|
| 913 |
enable_backbone_lora=False,
|
| 914 |
accum_backward_mode="micro",
|
| 915 |
use_ours=False,
|
|
@@ -937,6 +1317,9 @@ def train_one_epoch_factorized(
|
|
| 937 |
backward_calls = 0
|
| 938 |
pending_loss = None
|
| 939 |
|
|
|
|
|
|
|
|
|
|
| 940 |
for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
|
| 941 |
step = data_iter_step // update_freq
|
| 942 |
if step >= num_training_steps_per_epoch:
|
|
@@ -956,6 +1339,10 @@ def train_one_epoch_factorized(
|
|
| 956 |
hr_hw=hr_hw,
|
| 957 |
lr_hw=lr_hw,
|
| 958 |
device=device,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 959 |
)
|
| 960 |
|
| 961 |
spatial_correction_fn = _build_spatial_correction_fn(
|
|
@@ -971,6 +1358,8 @@ def train_one_epoch_factorized(
|
|
| 971 |
upper_vars=upper_vars,
|
| 972 |
level=level,
|
| 973 |
correction_scale=temporal_correction_scale,
|
|
|
|
|
|
|
| 974 |
)
|
| 975 |
backbone_last_surface = None
|
| 976 |
backbone_last_upper = None
|
|
@@ -993,6 +1382,20 @@ def train_one_epoch_factorized(
|
|
| 993 |
backbone_s_last_surface=backbone_last_surface,
|
| 994 |
backbone_s_last_upper=backbone_last_upper,
|
| 995 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 996 |
loss, loss_dict = loss_fn(
|
| 997 |
preds=preds,
|
| 998 |
batch=batch_f,
|
|
@@ -1000,18 +1403,29 @@ def train_one_epoch_factorized(
|
|
| 1000 |
lambda_t=lambda_t,
|
| 1001 |
lambda_st=lambda_st,
|
| 1002 |
lambda_comm=lambda_comm,
|
|
|
|
| 1003 |
surface_var_weights=surface_var_weights,
|
| 1004 |
upper_var_weights=upper_var_weights,
|
| 1005 |
upper_level_weights=upper_level_weights,
|
| 1006 |
interior_only_1h=interior_only_1h,
|
|
|
|
|
|
|
| 1007 |
surf_vars=surf_vars,
|
| 1008 |
upper_vars=upper_vars,
|
| 1009 |
level=level,
|
| 1010 |
)
|
| 1011 |
-
|
| 1012 |
loss_value = loss.item()
|
| 1013 |
if not np.isfinite(loss_value):
|
| 1014 |
print(f"[factorized] non-finite loss at iter {data_iter_step}; skip update.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1015 |
optimizer.zero_grad()
|
| 1016 |
pending_loss = None
|
| 1017 |
continue
|
|
@@ -1064,6 +1478,7 @@ def train_one_epoch_factorized(
|
|
| 1064 |
metric_logger.update(loss_t=loss_dict["loss_t"].item())
|
| 1065 |
metric_logger.update(loss_st=loss_dict["loss_st"].item())
|
| 1066 |
metric_logger.update(loss_comm=loss_dict["loss_comm"].item())
|
|
|
|
| 1067 |
metric_logger.update(loss_scale=loss_scale_value)
|
| 1068 |
|
| 1069 |
min_lr = 10.0
|
|
@@ -1090,6 +1505,7 @@ def train_one_epoch_factorized(
|
|
| 1090 |
log_writer.update(loss_t=loss_dict["loss_t"].item(), head="loss")
|
| 1091 |
log_writer.update(loss_st=loss_dict["loss_st"].item(), head="loss")
|
| 1092 |
log_writer.update(loss_comm=loss_dict["loss_comm"].item(), head="loss")
|
|
|
|
| 1093 |
log_writer.update(loss_scale=loss_scale_value, head="opt")
|
| 1094 |
log_writer.update(lr=max_lr, head="opt")
|
| 1095 |
log_writer.update(min_lr=min_lr, head="opt")
|
|
@@ -1126,14 +1542,18 @@ def validation_one_epoch_factorized(
|
|
| 1126 |
lambda_t=1.0,
|
| 1127 |
lambda_st=1.0,
|
| 1128 |
lambda_comm=1.0,
|
|
|
|
| 1129 |
surface_var_weights=None,
|
| 1130 |
upper_var_weights=None,
|
| 1131 |
upper_level_weights=None,
|
| 1132 |
interior_only_1h=False,
|
|
|
|
|
|
|
| 1133 |
spatial_corrector=None,
|
| 1134 |
temporal_corrector=None,
|
| 1135 |
spatial_correction_scale=1.0,
|
| 1136 |
temporal_correction_scale=1.0,
|
|
|
|
| 1137 |
enable_backbone_lora=False,
|
| 1138 |
print_freq=20,
|
| 1139 |
val_max_steps=-1,
|
|
@@ -1142,6 +1562,7 @@ def validation_one_epoch_factorized(
|
|
| 1142 |
metric_logger.add_meter("prep_s", utils.SmoothedValue(window_size=20, fmt="{avg:.3f}"))
|
| 1143 |
metric_logger.add_meter("model_s", utils.SmoothedValue(window_size=20, fmt="{avg:.3f}"))
|
| 1144 |
metric_logger.add_meter("post_s", utils.SmoothedValue(window_size=20, fmt="{avg:.3f}"))
|
|
|
|
| 1145 |
header = "Val:"
|
| 1146 |
|
| 1147 |
model.eval()
|
|
@@ -1159,14 +1580,6 @@ def validation_one_epoch_factorized(
|
|
| 1159 |
level=level,
|
| 1160 |
correction_scale=spatial_correction_scale,
|
| 1161 |
)
|
| 1162 |
-
temporal_correction_fn = _build_temporal_correction_fn(
|
| 1163 |
-
temporal_corrector=temporal_corrector,
|
| 1164 |
-
surf_vars=surf_vars,
|
| 1165 |
-
upper_vars=upper_vars,
|
| 1166 |
-
level=level,
|
| 1167 |
-
correction_scale=temporal_correction_scale,
|
| 1168 |
-
)
|
| 1169 |
-
|
| 1170 |
warned_short_seq = False
|
| 1171 |
for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
|
| 1172 |
if val_max_steps is not None and val_max_steps > 0 and data_iter_step >= val_max_steps:
|
|
@@ -1177,12 +1590,25 @@ def validation_one_epoch_factorized(
|
|
| 1177 |
hr_hw=hr_hw,
|
| 1178 |
lr_hw=lr_hw,
|
| 1179 |
device=device,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1180 |
)
|
| 1181 |
if torch.cuda.is_available():
|
| 1182 |
torch.cuda.synchronize()
|
| 1183 |
t1 = time.perf_counter()
|
| 1184 |
|
| 1185 |
with torch.inference_mode():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1186 |
backbone_last_surface = None
|
| 1187 |
backbone_last_upper = None
|
| 1188 |
if enable_backbone_lora:
|
|
@@ -1203,6 +1629,20 @@ def validation_one_epoch_factorized(
|
|
| 1203 |
backbone_s_last_surface=backbone_last_surface,
|
| 1204 |
backbone_s_last_upper=backbone_last_upper,
|
| 1205 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1206 |
loss, loss_dict = loss_fn(
|
| 1207 |
preds=preds,
|
| 1208 |
batch=batch_f,
|
|
@@ -1210,10 +1650,13 @@ def validation_one_epoch_factorized(
|
|
| 1210 |
lambda_t=lambda_t,
|
| 1211 |
lambda_st=lambda_st,
|
| 1212 |
lambda_comm=lambda_comm,
|
|
|
|
| 1213 |
surface_var_weights=surface_var_weights,
|
| 1214 |
upper_var_weights=upper_var_weights,
|
| 1215 |
upper_level_weights=upper_level_weights,
|
| 1216 |
interior_only_1h=interior_only_1h,
|
|
|
|
|
|
|
| 1217 |
surf_vars=surf_vars,
|
| 1218 |
upper_vars=upper_vars,
|
| 1219 |
level=level,
|
|
@@ -1227,7 +1670,8 @@ def validation_one_epoch_factorized(
|
|
| 1227 |
target_seq_surface = batch_f["y_hr1h_surface"]
|
| 1228 |
target_seq_upper = batch_f["y_hr1h_upper"]
|
| 1229 |
|
| 1230 |
-
# Default setting is include_endpoints=
|
|
|
|
| 1231 |
# We expose per-delta metrics for t+1..t+6 when available.
|
| 1232 |
seq_len = pred_seq_surface.shape[1]
|
| 1233 |
if seq_len >= 7:
|
|
@@ -1269,6 +1713,7 @@ def validation_one_epoch_factorized(
|
|
| 1269 |
metric_logger.update(loss_t=loss_dict["loss_t"].item())
|
| 1270 |
metric_logger.update(loss_st=loss_dict["loss_st"].item())
|
| 1271 |
metric_logger.update(loss_comm=loss_dict["loss_comm"].item())
|
|
|
|
| 1272 |
metric_logger.update(prep_s=t1 - t0)
|
| 1273 |
metric_logger.update(model_s=t2 - t1)
|
| 1274 |
metric_logger.update(post_s=t3 - t2)
|
|
|
|
| 41 |
hidden_surface: int = 256,
|
| 42 |
hidden_upper: int = 512,
|
| 43 |
dropout: float = 0.0,
|
| 44 |
+
head_mode: str = "shared3d",
|
| 45 |
+
num_time_steps: int = 7,
|
| 46 |
):
|
| 47 |
super().__init__()
|
| 48 |
upper_ch = num_upper_vars * num_levels
|
| 49 |
+
self.head_mode = head_mode
|
| 50 |
+
self.num_time_steps = num_time_steps
|
| 51 |
+
|
| 52 |
+
if self.head_mode == "shared3d":
|
| 53 |
+
# Endpoint-conditioned temporal correction:
|
| 54 |
+
# input channels = [endpoint_{t-6}, base_{t+h}, endpoint_t]
|
| 55 |
+
# for each variable channel group.
|
| 56 |
+
self.surface_net = nn.Sequential(
|
| 57 |
+
nn.Conv3d(num_surface_vars * 3, hidden_surface, kernel_size=(3, 1, 1), padding=(1, 0, 0)),
|
| 58 |
+
nn.GELU(),
|
| 59 |
+
nn.Dropout3d(dropout) if dropout > 0 else nn.Identity(),
|
| 60 |
+
nn.Conv3d(hidden_surface, num_surface_vars, kernel_size=(3, 1, 1), padding=(1, 0, 0)),
|
| 61 |
+
)
|
| 62 |
+
self.upper_net = nn.Sequential(
|
| 63 |
+
nn.Conv3d(upper_ch * 3, hidden_upper, kernel_size=(3, 1, 1), padding=(1, 0, 0)),
|
| 64 |
+
nn.GELU(),
|
| 65 |
+
nn.Dropout3d(dropout) if dropout > 0 else nn.Identity(),
|
| 66 |
+
nn.Conv3d(hidden_upper, upper_ch, kernel_size=(3, 1, 1), padding=(1, 0, 0)),
|
| 67 |
+
)
|
| 68 |
+
# Start from interpolation baseline: initial correction ~= 0.
|
| 69 |
+
nn.init.zeros_(self.surface_net[-1].weight)
|
| 70 |
+
nn.init.zeros_(self.surface_net[-1].bias)
|
| 71 |
+
nn.init.zeros_(self.upper_net[-1].weight)
|
| 72 |
+
nn.init.zeros_(self.upper_net[-1].bias)
|
| 73 |
+
elif self.head_mode == "delta2d":
|
| 74 |
+
# Per-time-step heads with endpoint-conditioning.
|
| 75 |
+
# Each head sees [endpoint_{t-6}, base_{t+h}, endpoint_t] context.
|
| 76 |
+
self.surface_heads = nn.ModuleList([
|
| 77 |
+
nn.Sequential(
|
| 78 |
+
nn.Conv2d(num_surface_vars * 3, hidden_surface, kernel_size=3, padding=1),
|
| 79 |
+
nn.GELU(),
|
| 80 |
+
nn.Dropout2d(dropout) if dropout > 0 else nn.Identity(),
|
| 81 |
+
nn.Conv2d(hidden_surface, num_surface_vars, kernel_size=3, padding=1),
|
| 82 |
+
)
|
| 83 |
+
for _ in range(num_time_steps)
|
| 84 |
+
])
|
| 85 |
+
self.upper_heads = nn.ModuleList([
|
| 86 |
+
nn.Sequential(
|
| 87 |
+
nn.Conv2d(upper_ch * 3, hidden_upper, kernel_size=3, padding=1),
|
| 88 |
+
nn.GELU(),
|
| 89 |
+
nn.Dropout2d(dropout) if dropout > 0 else nn.Identity(),
|
| 90 |
+
nn.Conv2d(hidden_upper, upper_ch, kernel_size=3, padding=1),
|
| 91 |
+
)
|
| 92 |
+
for _ in range(num_time_steps)
|
| 93 |
+
])
|
| 94 |
+
for head in self.surface_heads:
|
| 95 |
+
nn.init.zeros_(head[-1].weight)
|
| 96 |
+
nn.init.zeros_(head[-1].bias)
|
| 97 |
+
for head in self.upper_heads:
|
| 98 |
+
nn.init.zeros_(head[-1].weight)
|
| 99 |
+
nn.init.zeros_(head[-1].bias)
|
| 100 |
+
else:
|
| 101 |
+
raise ValueError(f"Unsupported temporal head_mode: {self.head_mode}")
|
| 102 |
+
|
| 103 |
+
def forward(self, cond_surface_endpoints, cond_upper_endpoints, base_surface, base_upper):
|
| 104 |
+
# cond_surface_endpoints: [B, 2, V, H, W] (t-6, t) on the current grid
|
| 105 |
+
# cond_upper_endpoints: [B, 2, V, C, H, W]
|
| 106 |
+
# base_surface: [B, T, V, H, W] (U_t baseline for target horizons)
|
| 107 |
+
# base_upper: [B, T, V, C, H, W]
|
| 108 |
+
if cond_surface_endpoints is None or cond_upper_endpoints is None:
|
| 109 |
+
raise ValueError("TemporalCorrectionHead requires endpoint-conditioned inputs.")
|
| 110 |
+
if cond_surface_endpoints.shape[1] < 2 or cond_upper_endpoints.shape[1] < 2:
|
| 111 |
+
raise ValueError("TemporalCorrectionHead expects two endpoint frames in cond inputs.")
|
| 112 |
+
|
| 113 |
+
if self.head_mode == "shared3d":
|
| 114 |
+
b, t, v, h, w = base_surface.shape
|
| 115 |
+
s0 = cond_surface_endpoints[:, 0].unsqueeze(1).expand(-1, t, -1, -1, -1)
|
| 116 |
+
s1 = cond_surface_endpoints[:, -1].unsqueeze(1).expand(-1, t, -1, -1, -1)
|
| 117 |
+
x_surface = torch.cat((s0, base_surface, s1), dim=2)
|
| 118 |
+
x_surface = rearrange(x_surface, "b t v h w -> b v t h w")
|
| 119 |
+
corr_surface = self.surface_net(x_surface)
|
| 120 |
+
corr_surface = rearrange(corr_surface, "b v t h w -> b t v h w")
|
| 121 |
+
|
| 122 |
+
_, _, vu, c, hu, wu = base_upper.shape
|
| 123 |
+
u0 = cond_upper_endpoints[:, 0].unsqueeze(1).expand(-1, t, -1, -1, -1, -1)
|
| 124 |
+
u1 = cond_upper_endpoints[:, -1].unsqueeze(1).expand(-1, t, -1, -1, -1, -1)
|
| 125 |
+
x_upper = torch.cat((u0, base_upper, u1), dim=2)
|
| 126 |
+
x_upper = rearrange(x_upper, "b t v c h w -> b (v c) t h w")
|
| 127 |
+
corr_upper = self.upper_net(x_upper)
|
| 128 |
+
corr_upper = rearrange(
|
| 129 |
+
corr_upper, "b (v c) t h w -> b t v c h w", v=base_upper.shape[2], c=base_upper.shape[3]
|
| 130 |
+
)
|
| 131 |
+
return corr_surface, corr_upper
|
| 132 |
+
|
| 133 |
+
# delta2d mode
|
| 134 |
+
b, t, v, h, w = base_surface.shape
|
| 135 |
+
_, _, vu, c, hu, wu = base_upper.shape
|
| 136 |
+
s0 = cond_surface_endpoints[:, 0]
|
| 137 |
+
s1 = cond_surface_endpoints[:, -1]
|
| 138 |
+
u0 = cond_upper_endpoints[:, 0].reshape(b, vu * c, hu, wu)
|
| 139 |
+
u1 = cond_upper_endpoints[:, -1].reshape(b, vu * c, hu, wu)
|
| 140 |
+
|
| 141 |
+
corr_surface_steps = []
|
| 142 |
+
corr_upper_steps = []
|
| 143 |
+
for ti in range(t):
|
| 144 |
+
hi = min(ti, self.num_time_steps - 1)
|
| 145 |
+
s_cur = base_surface[:, ti]
|
| 146 |
+
s_in = torch.cat((s0, s_cur, s1), dim=1)
|
| 147 |
+
s_corr = self.surface_heads[hi](s_in)
|
| 148 |
+
corr_surface_steps.append(s_corr)
|
| 149 |
+
|
| 150 |
+
u_cur = base_upper[:, ti].reshape(b, vu * c, hu, wu)
|
| 151 |
+
u_in = torch.cat((u0, u_cur, u1), dim=1)
|
| 152 |
+
u_corr = self.upper_heads[hi](u_in).reshape(b, vu, c, hu, wu)
|
| 153 |
+
corr_upper_steps.append(u_corr)
|
| 154 |
+
|
| 155 |
+
corr_surface = torch.stack(corr_surface_steps, dim=1)
|
| 156 |
+
corr_upper = torch.stack(corr_upper_steps, dim=1)
|
| 157 |
return corr_surface, corr_upper
|
| 158 |
|
| 159 |
|
|
|
|
| 226 |
return surface_up, upper_up
|
| 227 |
|
| 228 |
|
| 229 |
+
def _default_temporal_interp(surface_endpoints, upper_endpoints, out_steps, offsets_hours=None):
|
| 230 |
+
"""Linear temporal baseline from [t-6, t] to target horizons.
|
| 231 |
+
|
| 232 |
+
By default, this generates a prediction-style baseline for [t+1, ..., t+out_steps]
|
| 233 |
+
using linear extrapolation from the two 6h endpoints.
|
| 234 |
+
"""
|
| 235 |
if surface_endpoints.shape[1] < 2 or upper_endpoints.shape[1] < 2:
|
| 236 |
raise ValueError("Temporal interpolation needs at least two endpoint frames.")
|
| 237 |
|
| 238 |
+
if offsets_hours is None:
|
| 239 |
+
offsets = torch.arange(
|
| 240 |
+
1, out_steps + 1, device=surface_endpoints.device, dtype=surface_endpoints.dtype
|
| 241 |
+
)
|
| 242 |
+
else:
|
| 243 |
+
offsets = offsets_hours
|
| 244 |
+
if not torch.is_tensor(offsets):
|
| 245 |
+
offsets = torch.as_tensor(offsets)
|
| 246 |
+
if offsets.dim() == 2:
|
| 247 |
+
offsets = offsets[0]
|
| 248 |
+
offsets = offsets.to(device=surface_endpoints.device, dtype=surface_endpoints.dtype).reshape(-1)
|
| 249 |
+
if offsets.numel() != out_steps:
|
| 250 |
+
raise ValueError(
|
| 251 |
+
f"offset length ({offsets.numel()}) must equal out_steps ({out_steps})."
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
x_prev_s = surface_endpoints[:, 0].unsqueeze(1) # at t-6
|
| 255 |
+
x_curr_s = surface_endpoints[:, -1].unsqueeze(1) # at t
|
| 256 |
+
x_prev_u = upper_endpoints[:, 0].unsqueeze(1)
|
| 257 |
+
x_curr_u = upper_endpoints[:, -1].unsqueeze(1)
|
| 258 |
|
| 259 |
+
beta_s = (offsets / 6.0).view(1, out_steps, 1, 1, 1)
|
| 260 |
+
beta_u = beta_s.view(1, out_steps, 1, 1, 1, 1)
|
|
|
|
|
|
|
| 261 |
|
| 262 |
+
surface_interp = x_curr_s + beta_s * (x_curr_s - x_prev_s)
|
| 263 |
+
upper_interp = x_curr_u + beta_u * (x_curr_u - x_prev_u)
|
| 264 |
return surface_interp, upper_interp
|
| 265 |
|
| 266 |
|
|
|
|
| 311 |
target_hw_6h = batch["y_hr6h_surface"].shape[-2:]
|
| 312 |
target_hw_1h = batch["y_hr1h_surface"].shape[-2:]
|
| 313 |
out_steps_1h = batch["y_hr1h_surface"].shape[1]
|
| 314 |
+
target_offsets_hours = batch.get("target_offsets_hours", None)
|
| 315 |
|
| 316 |
# Spatial path at 6h: S(x) = U_s(x) + C_s(x)
|
| 317 |
+
# Use pretrained Aurora 6h prior as base signal at current endpoint (centered prior),
|
| 318 |
+
# not a post-hoc overwrite.
|
| 319 |
s_base_surface, s_base_upper = spatial_interp_fn(x_lr6h_surface, x_lr6h_upper, target_hw_6h)
|
| 320 |
+
s_base_prior_surface = s_base_surface
|
| 321 |
+
s_base_prior_upper = s_base_upper
|
| 322 |
+
if backbone_s_last_surface is not None:
|
| 323 |
+
s_base_prior_surface = s_base_prior_surface.clone()
|
| 324 |
+
s_base_prior_surface[:, -1] = backbone_s_last_surface
|
| 325 |
+
if backbone_s_last_upper is not None:
|
| 326 |
+
s_base_prior_upper = s_base_prior_upper.clone()
|
| 327 |
+
s_base_prior_upper[:, -1] = backbone_s_last_upper
|
| 328 |
+
|
| 329 |
pred_s_hr6h_surface, pred_s_hr6h_upper = _apply_correction(
|
| 330 |
spatial_correction_fn,
|
| 331 |
x_lr6h_surface,
|
| 332 |
x_lr6h_upper,
|
| 333 |
+
s_base_prior_surface,
|
| 334 |
+
s_base_prior_upper,
|
| 335 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
|
| 337 |
# Temporal path at LR: T(x) = U_t(x) + C_t(x)
|
| 338 |
+
t_base_surface, t_base_upper = temporal_interp_fn(
|
| 339 |
+
x_lr6h_surface, x_lr6h_upper, out_steps_1h, offsets_hours=target_offsets_hours
|
| 340 |
+
)
|
| 341 |
pred_t_lr1h_surface, pred_t_lr1h_upper = _apply_correction(
|
| 342 |
temporal_correction_fn,
|
| 343 |
x_lr6h_surface,
|
|
|
|
| 357 |
)
|
| 358 |
|
| 359 |
# ST path: x -> S -> T
|
| 360 |
+
st_base_surface, st_base_upper = temporal_interp_fn(
|
| 361 |
+
pred_s_hr6h_surface, pred_s_hr6h_upper, out_steps_1h, offsets_hours=target_offsets_hours
|
| 362 |
+
)
|
| 363 |
pred_st_hr1h_surface, pred_st_hr1h_upper = _apply_correction(
|
| 364 |
temporal_correction_fn,
|
| 365 |
pred_s_hr6h_surface,
|
|
|
|
| 371 |
return {
|
| 372 |
"pred_s_hr6h_surface": pred_s_hr6h_surface,
|
| 373 |
"pred_s_hr6h_upper": pred_s_hr6h_upper,
|
| 374 |
+
"s_base_prior_surface": s_base_prior_surface,
|
| 375 |
+
"s_base_prior_upper": s_base_prior_upper,
|
| 376 |
"pred_t_lr1h_surface": pred_t_lr1h_surface,
|
| 377 |
"pred_t_lr1h_upper": pred_t_lr1h_upper,
|
| 378 |
"pred_ts_hr1h_surface": pred_ts_hr1h_surface,
|
|
|
|
| 425 |
surf_vars=None,
|
| 426 |
upper_vars=None,
|
| 427 |
level=None,
|
| 428 |
+
surface_var_weights=None,
|
| 429 |
+
upper_var_weights=None,
|
| 430 |
):
|
| 431 |
"""
|
| 432 |
Match train_one_epoch_downscale style:
|
|
|
|
| 439 |
pred_upper_n = _normalise_upper_bt_vchw(pred_upper, upper_vars, level)
|
| 440 |
target_upper_n = _normalise_upper_bt_vchw(target_upper, upper_vars, level)
|
| 441 |
|
| 442 |
+
if surface_var_weights is None:
|
| 443 |
+
loss_surface = F.mse_loss(pred_surface_n, target_surface_n)
|
| 444 |
+
else:
|
| 445 |
+
ws = torch.as_tensor(
|
| 446 |
+
surface_var_weights,
|
| 447 |
+
device=pred_surface_n.device,
|
| 448 |
+
dtype=pred_surface_n.dtype,
|
| 449 |
+
).view(1, 1, -1, 1, 1)
|
| 450 |
+
se_s = (pred_surface_n - target_surface_n) ** 2
|
| 451 |
+
loss_surface = (se_s * ws).sum() / (ws.sum() * se_s.shape[0] * se_s.shape[1] * se_s.shape[3] * se_s.shape[4])
|
| 452 |
+
|
| 453 |
+
if upper_var_weights is None:
|
| 454 |
+
loss_upper = F.mse_loss(pred_upper_n, target_upper_n)
|
| 455 |
+
else:
|
| 456 |
+
wu = torch.as_tensor(
|
| 457 |
+
upper_var_weights,
|
| 458 |
+
device=pred_upper_n.device,
|
| 459 |
+
dtype=pred_upper_n.dtype,
|
| 460 |
+
)
|
| 461 |
+
if wu.dim() == 1:
|
| 462 |
+
wu = wu.view(1, 1, -1, 1, 1, 1)
|
| 463 |
+
else:
|
| 464 |
+
wu = wu.view(1, 1, wu.shape[0], wu.shape[1], 1, 1)
|
| 465 |
+
se_u = (pred_upper_n - target_upper_n) ** 2
|
| 466 |
+
loss_upper = (se_u * wu).sum() / (
|
| 467 |
+
wu.sum() * se_u.shape[0] * se_u.shape[1] * se_u.shape[4] * se_u.shape[5]
|
| 468 |
+
)
|
| 469 |
return 0.25 * loss_surface + loss_upper
|
| 470 |
|
| 471 |
|
|
|
|
| 476 |
lambda_t=1.0,
|
| 477 |
lambda_st=1.0,
|
| 478 |
lambda_comm=1.0,
|
| 479 |
+
lambda_backbone=0.0,
|
| 480 |
surface_var_weights=None,
|
| 481 |
upper_var_weights=None,
|
| 482 |
upper_level_weights=None,
|
| 483 |
interior_only_1h=False,
|
| 484 |
+
lt_supervision_source="lr_path",
|
| 485 |
+
lt_hybrid_alpha=0.5,
|
| 486 |
surf_vars=None,
|
| 487 |
upper_vars=None,
|
| 488 |
level=None,
|
|
|
|
| 494 |
return x[:, 1:-1]
|
| 495 |
return x
|
| 496 |
|
| 497 |
+
# Backward compatibility: if explicit upper-level weights are provided, override upper_var_weights.
|
| 498 |
+
if upper_level_weights is not None:
|
| 499 |
+
upper_var_weights = upper_level_weights
|
| 500 |
|
| 501 |
# L_S: S(x) supervised by HR-6h
|
| 502 |
l_s = _downscale_style_term_loss(
|
|
|
|
| 507 |
surf_vars=surf_vars,
|
| 508 |
upper_vars=upper_vars,
|
| 509 |
level=level,
|
| 510 |
+
surface_var_weights=surface_var_weights,
|
| 511 |
+
upper_var_weights=upper_var_weights,
|
| 512 |
)
|
| 513 |
|
| 514 |
+
# L_T: temporal module supervision on LR 1h targets.
|
| 515 |
+
# "lr_path" is the strict factorized definition.
|
| 516 |
+
# "st_downscaled"/"hybrid" are optional variants.
|
| 517 |
+
l_t_lr = _downscale_style_term_loss(
|
| 518 |
_trim_time(preds["pred_t_lr1h_surface"]),
|
| 519 |
_trim_time(preds["pred_t_lr1h_upper"]),
|
| 520 |
_trim_time(batch["y_lr1h_surface"]),
|
|
|
|
| 522 |
surf_vars=surf_vars,
|
| 523 |
upper_vars=upper_vars,
|
| 524 |
level=level,
|
| 525 |
+
surface_var_weights=surface_var_weights,
|
| 526 |
+
upper_var_weights=upper_var_weights,
|
| 527 |
)
|
| 528 |
+
has_st_lr = ("pred_st_lr1h_surface" in preds) and ("pred_st_lr1h_upper" in preds)
|
| 529 |
+
if has_st_lr:
|
| 530 |
+
l_t_st = _downscale_style_term_loss(
|
| 531 |
+
_trim_time(preds["pred_st_lr1h_surface"]),
|
| 532 |
+
_trim_time(preds["pred_st_lr1h_upper"]),
|
| 533 |
+
_trim_time(batch["y_lr1h_surface"]),
|
| 534 |
+
_trim_time(batch["y_lr1h_upper"]),
|
| 535 |
+
surf_vars=surf_vars,
|
| 536 |
+
upper_vars=upper_vars,
|
| 537 |
+
level=level,
|
| 538 |
+
surface_var_weights=surface_var_weights,
|
| 539 |
+
upper_var_weights=upper_var_weights,
|
| 540 |
+
)
|
| 541 |
+
else:
|
| 542 |
+
l_t_st = l_t_lr
|
| 543 |
+
|
| 544 |
+
if lt_supervision_source == "st_downscaled":
|
| 545 |
+
l_t = l_t_st
|
| 546 |
+
elif lt_supervision_source == "hybrid":
|
| 547 |
+
a = float(lt_hybrid_alpha)
|
| 548 |
+
a = max(0.0, min(1.0, a))
|
| 549 |
+
l_t = a * l_t_lr + (1.0 - a) * l_t_st
|
| 550 |
+
else:
|
| 551 |
+
l_t = l_t_lr
|
| 552 |
|
| 553 |
# L_ST: ST path supervised by HR-1h
|
| 554 |
l_st = _downscale_style_term_loss(
|
|
|
|
| 559 |
surf_vars=surf_vars,
|
| 560 |
upper_vars=upper_vars,
|
| 561 |
level=level,
|
| 562 |
+
surface_var_weights=surface_var_weights,
|
| 563 |
+
upper_var_weights=upper_var_weights,
|
| 564 |
)
|
| 565 |
|
| 566 |
+
# L_comm: TS and ST consistency (commutativity regularizer)
|
| 567 |
l_comm = _downscale_style_term_loss(
|
| 568 |
_trim_time(preds["pred_ts_hr1h_surface"]),
|
| 569 |
_trim_time(preds["pred_ts_hr1h_upper"]),
|
|
|
|
| 572 |
surf_vars=surf_vars,
|
| 573 |
upper_vars=upper_vars,
|
| 574 |
level=level,
|
| 575 |
+
surface_var_weights=surface_var_weights,
|
| 576 |
+
upper_var_weights=upper_var_weights,
|
| 577 |
)
|
| 578 |
|
| 579 |
+
l_backbone = l_comm.new_zeros(())
|
| 580 |
+
if (
|
| 581 |
+
lambda_backbone > 0
|
| 582 |
+
and ("backbone_last_surface" in preds)
|
| 583 |
+
and ("backbone_last_upper" in preds)
|
| 584 |
+
):
|
| 585 |
+
# Optional prior-consistency regularizer:
|
| 586 |
+
# keep S current-endpoint prediction close to pretrained backbone prior.
|
| 587 |
+
bb_s = preds["backbone_last_surface"].detach().unsqueeze(1)
|
| 588 |
+
bb_u = preds["backbone_last_upper"].detach().unsqueeze(1)
|
| 589 |
+
l_backbone_s = _downscale_style_term_loss(
|
| 590 |
+
preds["pred_s_hr6h_surface"][:, -1:].contiguous(),
|
| 591 |
+
preds["pred_s_hr6h_upper"][:, -1:].contiguous(),
|
| 592 |
+
bb_s,
|
| 593 |
+
bb_u,
|
| 594 |
+
surf_vars=surf_vars,
|
| 595 |
+
upper_vars=upper_vars,
|
| 596 |
+
level=level,
|
| 597 |
+
surface_var_weights=surface_var_weights,
|
| 598 |
+
upper_var_weights=upper_var_weights,
|
| 599 |
+
)
|
| 600 |
+
l_backbone = l_backbone_s
|
| 601 |
+
|
| 602 |
+
total = (
|
| 603 |
+
lambda_s * l_s
|
| 604 |
+
+ lambda_t * l_t
|
| 605 |
+
+ lambda_st * l_st
|
| 606 |
+
+ lambda_comm * l_comm
|
| 607 |
+
+ lambda_backbone * l_backbone
|
| 608 |
+
)
|
| 609 |
loss_dict = {
|
| 610 |
"loss_total": total,
|
| 611 |
"loss_s": l_s,
|
| 612 |
"loss_t": l_t,
|
| 613 |
"loss_st": l_st,
|
| 614 |
"loss_comm": l_comm,
|
| 615 |
+
"loss_backbone": l_backbone,
|
| 616 |
}
|
| 617 |
return total, loss_dict
|
| 618 |
|
|
|
|
| 890 |
return x[..., :h, :w]
|
| 891 |
|
| 892 |
|
| 893 |
+
def _sanitize_tensor(x, clip_value=1e6):
|
| 894 |
+
x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
|
| 895 |
+
if clip_value is not None and clip_value > 0:
|
| 896 |
+
x = torch.clamp(x, min=-clip_value, max=clip_value)
|
| 897 |
+
return x
|
| 898 |
+
|
| 899 |
+
|
| 900 |
+
def _tensor_health(x):
|
| 901 |
+
finite_mask = torch.isfinite(x)
|
| 902 |
+
finite_ratio = float(finite_mask.float().mean().item())
|
| 903 |
+
x_safe = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
|
| 904 |
+
max_abs = float(x_safe.abs().max().item())
|
| 905 |
+
return finite_ratio, max_abs
|
| 906 |
+
|
| 907 |
+
|
| 908 |
def _resize_surface_bt_vhw(x, out_hw):
|
| 909 |
b, t, v, h, w = x.shape
|
| 910 |
th, tw = out_hw
|
|
|
|
| 927 |
).reshape(b, t, v, c, th, tw)
|
| 928 |
|
| 929 |
|
| 930 |
+
def _cell_edges_from_centers(centers_1d: torch.Tensor) -> torch.Tensor:
|
| 931 |
+
centers_1d = centers_1d.float()
|
| 932 |
+
if centers_1d.numel() < 2:
|
| 933 |
+
c = centers_1d[0]
|
| 934 |
+
return torch.stack((c - 0.5, c + 0.5), dim=0)
|
| 935 |
+
mids = 0.5 * (centers_1d[:-1] + centers_1d[1:])
|
| 936 |
+
first = centers_1d[0] - 0.5 * (centers_1d[1] - centers_1d[0])
|
| 937 |
+
last = centers_1d[-1] + 0.5 * (centers_1d[-1] - centers_1d[-2])
|
| 938 |
+
return torch.cat((first[None], mids, last[None]), dim=0)
|
| 939 |
+
|
| 940 |
+
|
| 941 |
+
def _interval_overlap_matrix(src_edges: torch.Tensor, tgt_edges: torch.Tensor) -> torch.Tensor:
|
| 942 |
+
# Returns [N_tgt, N_src] overlap lengths for 1D intervals.
|
| 943 |
+
src_lo = src_edges[:-1].unsqueeze(0) # [1, Ns]
|
| 944 |
+
src_hi = src_edges[1:].unsqueeze(0) # [1, Ns]
|
| 945 |
+
tgt_lo = tgt_edges[:-1].unsqueeze(1) # [Nt, 1]
|
| 946 |
+
tgt_hi = tgt_edges[1:].unsqueeze(1) # [Nt, 1]
|
| 947 |
+
return torch.clamp(torch.minimum(src_hi, tgt_hi) - torch.maximum(src_lo, tgt_lo), min=0.0)
|
| 948 |
+
|
| 949 |
+
|
| 950 |
+
def _maybe_flip_last_dim(x: torch.Tensor, dim: int, need_flip: bool) -> torch.Tensor:
|
| 951 |
+
if need_flip:
|
| 952 |
+
return torch.flip(x, dims=(dim,))
|
| 953 |
+
return x
|
| 954 |
+
|
| 955 |
+
|
| 956 |
+
def _conservative_remap_2d(
|
| 957 |
+
x: torch.Tensor,
|
| 958 |
+
lat_src: torch.Tensor,
|
| 959 |
+
lon_src: torch.Tensor,
|
| 960 |
+
lat_tgt: torch.Tensor,
|
| 961 |
+
lon_tgt: torch.Tensor,
|
| 962 |
+
eps: float = 1e-12,
|
| 963 |
+
) -> torch.Tensor:
|
| 964 |
+
"""
|
| 965 |
+
Conservative remap from source grid to target grid for tensors with trailing [..., H, W].
|
| 966 |
+
Uses area-overlap weights in (sin(lat), lon) coordinates (first-order conservative).
|
| 967 |
+
"""
|
| 968 |
+
# Keep target orientation to restore at end.
|
| 969 |
+
src_lat_desc = bool((lat_src[0] > lat_src[-1]).item())
|
| 970 |
+
src_lon_desc = bool((lon_src[0] > lon_src[-1]).item())
|
| 971 |
+
tgt_lat_desc = bool((lat_tgt[0] > lat_tgt[-1]).item())
|
| 972 |
+
tgt_lon_desc = bool((lon_tgt[0] > lon_tgt[-1]).item())
|
| 973 |
+
|
| 974 |
+
x_work = _maybe_flip_last_dim(x, -2, src_lat_desc)
|
| 975 |
+
x_work = _maybe_flip_last_dim(x_work, -1, src_lon_desc)
|
| 976 |
+
lat_src_inc = _maybe_flip_last_dim(lat_src, 0, src_lat_desc)
|
| 977 |
+
lon_src_inc = _maybe_flip_last_dim(lon_src, 0, src_lon_desc)
|
| 978 |
+
lat_tgt_inc = _maybe_flip_last_dim(lat_tgt, 0, tgt_lat_desc)
|
| 979 |
+
lon_tgt_inc = _maybe_flip_last_dim(lon_tgt, 0, tgt_lon_desc)
|
| 980 |
+
|
| 981 |
+
lat_src_edges = _cell_edges_from_centers(lat_src_inc).clamp(-90.0, 90.0)
|
| 982 |
+
lat_tgt_edges = _cell_edges_from_centers(lat_tgt_inc).clamp(-90.0, 90.0)
|
| 983 |
+
lon_src_edges = _cell_edges_from_centers(lon_src_inc)
|
| 984 |
+
lon_tgt_edges = _cell_edges_from_centers(lon_tgt_inc)
|
| 985 |
+
|
| 986 |
+
# Area metric on a sphere is proportional to d(lon) * d(sin(lat)).
|
| 987 |
+
lat_src_m = torch.sin(torch.deg2rad(lat_src_edges))
|
| 988 |
+
lat_tgt_m = torch.sin(torch.deg2rad(lat_tgt_edges))
|
| 989 |
+
lat_overlap = _interval_overlap_matrix(lat_src_m, lat_tgt_m) # [Ht, Hs]
|
| 990 |
+
lon_overlap = _interval_overlap_matrix(lon_src_edges, lon_tgt_edges) # [Wt, Ws]
|
| 991 |
+
|
| 992 |
+
tmp = torch.einsum("...hw,ih->...iw", x_work, lat_overlap)
|
| 993 |
+
num = torch.einsum("...iw,jw->...ij", tmp, lon_overlap)
|
| 994 |
+
|
| 995 |
+
lat_tgt_width = (lat_tgt_m[1:] - lat_tgt_m[:-1]).clamp_min(eps) # [Ht]
|
| 996 |
+
lon_tgt_width = (lon_tgt_edges[1:] - lon_tgt_edges[:-1]).clamp_min(eps) # [Wt]
|
| 997 |
+
den = lat_tgt_width[:, None] * lon_tgt_width[None, :]
|
| 998 |
+
out = num / den
|
| 999 |
+
|
| 1000 |
+
out = _maybe_flip_last_dim(out, -2, tgt_lat_desc)
|
| 1001 |
+
out = _maybe_flip_last_dim(out, -1, tgt_lon_desc)
|
| 1002 |
+
return out
|
| 1003 |
+
|
| 1004 |
+
|
| 1005 |
+
def _conservative_downscale_surface_bt_vhw(x, lat_hr, lon_hr, lat_lr, lon_lr):
|
| 1006 |
+
# x: [B, T, V, H, W]
|
| 1007 |
+
b, t, v, h, w = x.shape
|
| 1008 |
+
y = _conservative_remap_2d(
|
| 1009 |
+
x.reshape(b * t * v, h, w),
|
| 1010 |
+
lat_src=lat_hr,
|
| 1011 |
+
lon_src=lon_hr,
|
| 1012 |
+
lat_tgt=lat_lr,
|
| 1013 |
+
lon_tgt=lon_lr,
|
| 1014 |
+
)
|
| 1015 |
+
return y.reshape(b, t, v, lat_lr.numel(), lon_lr.numel())
|
| 1016 |
+
|
| 1017 |
+
|
| 1018 |
+
def _conservative_downscale_upper_bt_vchw(x, lat_hr, lon_hr, lat_lr, lon_lr):
|
| 1019 |
+
# x: [B, T, V, C, H, W]
|
| 1020 |
+
b, t, v, c, h, w = x.shape
|
| 1021 |
+
y = _conservative_remap_2d(
|
| 1022 |
+
x.reshape(b * t * v * c, h, w),
|
| 1023 |
+
lat_src=lat_hr,
|
| 1024 |
+
lon_src=lon_hr,
|
| 1025 |
+
lat_tgt=lat_lr,
|
| 1026 |
+
lon_tgt=lon_lr,
|
| 1027 |
+
)
|
| 1028 |
+
return y.reshape(b, t, v, c, lat_lr.numel(), lon_lr.numel())
|
| 1029 |
+
|
| 1030 |
+
|
| 1031 |
+
def _prepare_factorized_batch(
|
| 1032 |
+
batch,
|
| 1033 |
+
hr_hw,
|
| 1034 |
+
lr_hw,
|
| 1035 |
+
device,
|
| 1036 |
+
lat_hr=None,
|
| 1037 |
+
lon_hr=None,
|
| 1038 |
+
lat_lr=None,
|
| 1039 |
+
lon_lr=None,
|
| 1040 |
+
):
|
| 1041 |
h_hr, w_hr = hr_hw
|
| 1042 |
h_lr, w_lr = lr_hw
|
| 1043 |
|
| 1044 |
def _to_device(x, h, w):
|
| 1045 |
+
x = _trim_to_hw(x.float(), h, w).to(device, non_blocking=True)
|
| 1046 |
+
return _sanitize_tensor(x)
|
| 1047 |
|
| 1048 |
out = {}
|
| 1049 |
out["x_lr6h_surface"] = _to_device(batch["x_lr6h_surface"], h_lr, w_lr)
|
|
|
|
| 1070 |
out["y_lr1h_surface"] = _to_device(batch["y_lr1h_surface"], h_lr, w_lr)
|
| 1071 |
out["y_lr1h_upper"] = _to_device(batch["y_lr1h_upper"], h_lr, w_lr)
|
| 1072 |
else:
|
| 1073 |
+
use_conservative = (
|
| 1074 |
+
(lat_hr is not None) and (lon_hr is not None) and
|
| 1075 |
+
(lat_lr is not None) and (lon_lr is not None)
|
| 1076 |
+
)
|
| 1077 |
+
if use_conservative:
|
| 1078 |
+
lat_hr_d = lat_hr.to(device=device, dtype=out["y_hr1h_surface"].dtype)
|
| 1079 |
+
lon_hr_d = lon_hr.to(device=device, dtype=out["y_hr1h_surface"].dtype)
|
| 1080 |
+
lat_lr_d = lat_lr.to(device=device, dtype=out["y_hr1h_surface"].dtype)
|
| 1081 |
+
lon_lr_d = lon_lr.to(device=device, dtype=out["y_hr1h_surface"].dtype)
|
| 1082 |
+
out["y_lr1h_surface"] = _conservative_downscale_surface_bt_vhw(
|
| 1083 |
+
out["y_hr1h_surface"], lat_hr_d, lon_hr_d, lat_lr_d, lon_lr_d
|
| 1084 |
+
)
|
| 1085 |
+
out["y_lr1h_upper"] = _conservative_downscale_upper_bt_vchw(
|
| 1086 |
+
out["y_hr1h_upper"], lat_hr_d, lon_hr_d, lat_lr_d, lon_lr_d
|
| 1087 |
+
)
|
| 1088 |
+
else:
|
| 1089 |
+
out["y_lr1h_surface"] = _resize_surface_bt_vhw(out["y_hr1h_surface"], (h_lr, w_lr))
|
| 1090 |
+
out["y_lr1h_upper"] = _resize_upper_bt_vchw(out["y_hr1h_upper"], (h_lr, w_lr))
|
| 1091 |
+
|
| 1092 |
+
if "target_offsets_hours" in batch:
|
| 1093 |
+
offsets = batch["target_offsets_hours"]
|
| 1094 |
+
if not torch.is_tensor(offsets):
|
| 1095 |
+
offsets = torch.as_tensor(offsets)
|
| 1096 |
+
out["target_offsets_hours"] = offsets.to(device=device, non_blocking=True).long()
|
| 1097 |
|
| 1098 |
out["time_unix"] = batch["time_unix"]
|
| 1099 |
return out
|
|
|
|
| 1202 |
upper_vars=None,
|
| 1203 |
level=None,
|
| 1204 |
correction_scale=1.0,
|
| 1205 |
+
fix_endpoints=True,
|
| 1206 |
+
target_offsets_hours=None,
|
| 1207 |
):
|
| 1208 |
if temporal_corrector is None:
|
| 1209 |
def _temporal_correction_zeros(input_surface, input_upper, base_surface, base_upper):
|
|
|
|
| 1212 |
return _temporal_correction_zeros
|
| 1213 |
|
| 1214 |
def _temporal_correction(input_surface, input_upper, base_surface, base_upper):
|
| 1215 |
+
cond_surface_n = _normalise_surface_bt_vhw(input_surface, surf_vars)
|
| 1216 |
+
cond_upper_n = _normalise_upper_bt_vchw(input_upper, upper_vars, level)
|
| 1217 |
base_surface_n = _normalise_surface_bt_vhw(base_surface, surf_vars)
|
| 1218 |
base_upper_n = _normalise_upper_bt_vchw(base_upper, upper_vars, level)
|
| 1219 |
+
corr_surface_n, corr_upper_n = temporal_corrector(
|
| 1220 |
+
cond_surface_n, cond_upper_n, base_surface_n, base_upper_n
|
| 1221 |
+
)
|
| 1222 |
|
| 1223 |
corr_surface_n = torch.nan_to_num(corr_surface_n, nan=0.0, posinf=0.0, neginf=0.0)
|
| 1224 |
corr_upper_n = torch.nan_to_num(corr_upper_n, nan=0.0, posinf=0.0, neginf=0.0)
|
| 1225 |
+
if fix_endpoints and corr_surface_n.shape[1] >= 1 and corr_upper_n.shape[1] >= 1:
|
| 1226 |
+
corr_surface_n = corr_surface_n.clone()
|
| 1227 |
+
corr_upper_n = corr_upper_n.clone()
|
| 1228 |
+
fix_indices = []
|
| 1229 |
+
if target_offsets_hours is None:
|
| 1230 |
+
# Backward-compatibility fallback (legacy setting with endpoint-including sequence).
|
| 1231 |
+
fix_indices = [0, corr_surface_n.shape[1] - 1] if corr_surface_n.shape[1] >= 2 else [0]
|
| 1232 |
+
else:
|
| 1233 |
+
offs = target_offsets_hours
|
| 1234 |
+
if not torch.is_tensor(offs):
|
| 1235 |
+
offs = torch.as_tensor(offs)
|
| 1236 |
+
if offs.dim() == 2:
|
| 1237 |
+
offs = offs[0]
|
| 1238 |
+
offs = offs.to(device=corr_surface_n.device).reshape(-1).long()
|
| 1239 |
+
if offs.numel() == corr_surface_n.shape[1]:
|
| 1240 |
+
fix_indices = torch.nonzero(offs == 0, as_tuple=False).view(-1).tolist()
|
| 1241 |
+
for idx in fix_indices:
|
| 1242 |
+
corr_surface_n[:, idx] = 0.0
|
| 1243 |
+
corr_upper_n[:, idx] = 0.0
|
| 1244 |
|
| 1245 |
s_scale = _surface_scale_tensor(surf_vars, base_surface)
|
| 1246 |
u_scale = _upper_scale_tensor(upper_vars, level, base_upper)
|
|
|
|
| 1278 |
lambda_t=1.0,
|
| 1279 |
lambda_st=1.0,
|
| 1280 |
lambda_comm=1.0,
|
| 1281 |
+
lambda_backbone=0.0,
|
| 1282 |
surface_var_weights=None,
|
| 1283 |
upper_var_weights=None,
|
| 1284 |
upper_level_weights=None,
|
| 1285 |
interior_only_1h=False,
|
| 1286 |
+
lt_supervision_source="lr_path",
|
| 1287 |
+
lt_hybrid_alpha=0.5,
|
| 1288 |
spatial_corrector=None,
|
| 1289 |
temporal_corrector=None,
|
| 1290 |
spatial_correction_scale=1.0,
|
| 1291 |
temporal_correction_scale=1.0,
|
| 1292 |
+
temporal_fix_endpoints=True,
|
| 1293 |
enable_backbone_lora=False,
|
| 1294 |
accum_backward_mode="micro",
|
| 1295 |
use_ours=False,
|
|
|
|
| 1317 |
backward_calls = 0
|
| 1318 |
pending_loss = None
|
| 1319 |
|
| 1320 |
+
if torch.cuda.is_available():
|
| 1321 |
+
torch.cuda.reset_peak_memory_stats()
|
| 1322 |
+
|
| 1323 |
for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
|
| 1324 |
step = data_iter_step // update_freq
|
| 1325 |
if step >= num_training_steps_per_epoch:
|
|
|
|
| 1339 |
hr_hw=hr_hw,
|
| 1340 |
lr_hw=lr_hw,
|
| 1341 |
device=device,
|
| 1342 |
+
lat_hr=lat_hr,
|
| 1343 |
+
lon_hr=lon_hr,
|
| 1344 |
+
lat_lr=lat_lr,
|
| 1345 |
+
lon_lr=lon_lr,
|
| 1346 |
)
|
| 1347 |
|
| 1348 |
spatial_correction_fn = _build_spatial_correction_fn(
|
|
|
|
| 1358 |
upper_vars=upper_vars,
|
| 1359 |
level=level,
|
| 1360 |
correction_scale=temporal_correction_scale,
|
| 1361 |
+
fix_endpoints=temporal_fix_endpoints,
|
| 1362 |
+
target_offsets_hours=batch_f.get("target_offsets_hours", None),
|
| 1363 |
)
|
| 1364 |
backbone_last_surface = None
|
| 1365 |
backbone_last_upper = None
|
|
|
|
| 1382 |
backbone_s_last_surface=backbone_last_surface,
|
| 1383 |
backbone_s_last_upper=backbone_last_upper,
|
| 1384 |
)
|
| 1385 |
+
if lt_supervision_source in {"st_downscaled", "hybrid"}:
|
| 1386 |
+
lat_hr_d = lat_hr.to(device=preds["pred_st_hr1h_surface"].device, dtype=preds["pred_st_hr1h_surface"].dtype)
|
| 1387 |
+
lon_hr_d = lon_hr.to(device=preds["pred_st_hr1h_surface"].device, dtype=preds["pred_st_hr1h_surface"].dtype)
|
| 1388 |
+
lat_lr_d = lat_lr.to(device=preds["pred_st_hr1h_surface"].device, dtype=preds["pred_st_hr1h_surface"].dtype)
|
| 1389 |
+
lon_lr_d = lon_lr.to(device=preds["pred_st_hr1h_surface"].device, dtype=preds["pred_st_hr1h_surface"].dtype)
|
| 1390 |
+
preds["pred_st_lr1h_surface"] = _conservative_downscale_surface_bt_vhw(
|
| 1391 |
+
preds["pred_st_hr1h_surface"], lat_hr_d, lon_hr_d, lat_lr_d, lon_lr_d
|
| 1392 |
+
)
|
| 1393 |
+
preds["pred_st_lr1h_upper"] = _conservative_downscale_upper_bt_vchw(
|
| 1394 |
+
preds["pred_st_hr1h_upper"], lat_hr_d, lon_hr_d, lat_lr_d, lon_lr_d
|
| 1395 |
+
)
|
| 1396 |
+
if backbone_last_surface is not None and backbone_last_upper is not None:
|
| 1397 |
+
preds["backbone_last_surface"] = backbone_last_surface
|
| 1398 |
+
preds["backbone_last_upper"] = backbone_last_upper
|
| 1399 |
loss, loss_dict = loss_fn(
|
| 1400 |
preds=preds,
|
| 1401 |
batch=batch_f,
|
|
|
|
| 1403 |
lambda_t=lambda_t,
|
| 1404 |
lambda_st=lambda_st,
|
| 1405 |
lambda_comm=lambda_comm,
|
| 1406 |
+
lambda_backbone=lambda_backbone,
|
| 1407 |
surface_var_weights=surface_var_weights,
|
| 1408 |
upper_var_weights=upper_var_weights,
|
| 1409 |
upper_level_weights=upper_level_weights,
|
| 1410 |
interior_only_1h=interior_only_1h,
|
| 1411 |
+
lt_supervision_source=lt_supervision_source,
|
| 1412 |
+
lt_hybrid_alpha=lt_hybrid_alpha,
|
| 1413 |
surf_vars=surf_vars,
|
| 1414 |
upper_vars=upper_vars,
|
| 1415 |
level=level,
|
| 1416 |
)
|
|
|
|
| 1417 |
loss_value = loss.item()
|
| 1418 |
if not np.isfinite(loss_value):
|
| 1419 |
print(f"[factorized] non-finite loss at iter {data_iter_step}; skip update.")
|
| 1420 |
+
try:
|
| 1421 |
+
for k in ("x_lr6h_surface", "x_lr6h_upper", "y_hr1h_surface", "y_hr1h_upper"):
|
| 1422 |
+
fr, ma = _tensor_health(batch_f[k])
|
| 1423 |
+
print(f"[factorized][health] {k}: finite_ratio={fr:.6f}, max_abs={ma:.3e}")
|
| 1424 |
+
for k in ("pred_s_hr6h_surface", "pred_t_lr1h_surface", "pred_st_hr1h_surface"):
|
| 1425 |
+
fr, ma = _tensor_health(preds[k])
|
| 1426 |
+
print(f"[factorized][health] {k}: finite_ratio={fr:.6f}, max_abs={ma:.3e}")
|
| 1427 |
+
except Exception as health_err:
|
| 1428 |
+
print(f"[factorized][health] failed to collect tensor health: {health_err}")
|
| 1429 |
optimizer.zero_grad()
|
| 1430 |
pending_loss = None
|
| 1431 |
continue
|
|
|
|
| 1478 |
metric_logger.update(loss_t=loss_dict["loss_t"].item())
|
| 1479 |
metric_logger.update(loss_st=loss_dict["loss_st"].item())
|
| 1480 |
metric_logger.update(loss_comm=loss_dict["loss_comm"].item())
|
| 1481 |
+
metric_logger.update(loss_backbone=loss_dict["loss_backbone"].item())
|
| 1482 |
metric_logger.update(loss_scale=loss_scale_value)
|
| 1483 |
|
| 1484 |
min_lr = 10.0
|
|
|
|
| 1505 |
log_writer.update(loss_t=loss_dict["loss_t"].item(), head="loss")
|
| 1506 |
log_writer.update(loss_st=loss_dict["loss_st"].item(), head="loss")
|
| 1507 |
log_writer.update(loss_comm=loss_dict["loss_comm"].item(), head="loss")
|
| 1508 |
+
log_writer.update(loss_backbone=loss_dict["loss_backbone"].item(), head="loss")
|
| 1509 |
log_writer.update(loss_scale=loss_scale_value, head="opt")
|
| 1510 |
log_writer.update(lr=max_lr, head="opt")
|
| 1511 |
log_writer.update(min_lr=min_lr, head="opt")
|
|
|
|
| 1542 |
lambda_t=1.0,
|
| 1543 |
lambda_st=1.0,
|
| 1544 |
lambda_comm=1.0,
|
| 1545 |
+
lambda_backbone=0.0,
|
| 1546 |
surface_var_weights=None,
|
| 1547 |
upper_var_weights=None,
|
| 1548 |
upper_level_weights=None,
|
| 1549 |
interior_only_1h=False,
|
| 1550 |
+
lt_supervision_source="lr_path",
|
| 1551 |
+
lt_hybrid_alpha=0.5,
|
| 1552 |
spatial_corrector=None,
|
| 1553 |
temporal_corrector=None,
|
| 1554 |
spatial_correction_scale=1.0,
|
| 1555 |
temporal_correction_scale=1.0,
|
| 1556 |
+
temporal_fix_endpoints=True,
|
| 1557 |
enable_backbone_lora=False,
|
| 1558 |
print_freq=20,
|
| 1559 |
val_max_steps=-1,
|
|
|
|
| 1562 |
metric_logger.add_meter("prep_s", utils.SmoothedValue(window_size=20, fmt="{avg:.3f}"))
|
| 1563 |
metric_logger.add_meter("model_s", utils.SmoothedValue(window_size=20, fmt="{avg:.3f}"))
|
| 1564 |
metric_logger.add_meter("post_s", utils.SmoothedValue(window_size=20, fmt="{avg:.3f}"))
|
| 1565 |
+
metric_logger.add_meter("loss_backbone", utils.SmoothedValue(window_size=20, fmt="{avg:.3f}"))
|
| 1566 |
header = "Val:"
|
| 1567 |
|
| 1568 |
model.eval()
|
|
|
|
| 1580 |
level=level,
|
| 1581 |
correction_scale=spatial_correction_scale,
|
| 1582 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1583 |
warned_short_seq = False
|
| 1584 |
for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
|
| 1585 |
if val_max_steps is not None and val_max_steps > 0 and data_iter_step >= val_max_steps:
|
|
|
|
| 1590 |
hr_hw=hr_hw,
|
| 1591 |
lr_hw=lr_hw,
|
| 1592 |
device=device,
|
| 1593 |
+
lat_hr=lat_hr,
|
| 1594 |
+
lon_hr=lon_hr,
|
| 1595 |
+
lat_lr=lat_lr,
|
| 1596 |
+
lon_lr=lon_lr,
|
| 1597 |
)
|
| 1598 |
if torch.cuda.is_available():
|
| 1599 |
torch.cuda.synchronize()
|
| 1600 |
t1 = time.perf_counter()
|
| 1601 |
|
| 1602 |
with torch.inference_mode():
|
| 1603 |
+
temporal_correction_fn = _build_temporal_correction_fn(
|
| 1604 |
+
temporal_corrector=temporal_corrector,
|
| 1605 |
+
surf_vars=surf_vars,
|
| 1606 |
+
upper_vars=upper_vars,
|
| 1607 |
+
level=level,
|
| 1608 |
+
correction_scale=temporal_correction_scale,
|
| 1609 |
+
fix_endpoints=temporal_fix_endpoints,
|
| 1610 |
+
target_offsets_hours=batch_f.get("target_offsets_hours", None),
|
| 1611 |
+
)
|
| 1612 |
backbone_last_surface = None
|
| 1613 |
backbone_last_upper = None
|
| 1614 |
if enable_backbone_lora:
|
|
|
|
| 1629 |
backbone_s_last_surface=backbone_last_surface,
|
| 1630 |
backbone_s_last_upper=backbone_last_upper,
|
| 1631 |
)
|
| 1632 |
+
if lt_supervision_source in {"st_downscaled", "hybrid"}:
|
| 1633 |
+
lat_hr_d = lat_hr.to(device=preds["pred_st_hr1h_surface"].device, dtype=preds["pred_st_hr1h_surface"].dtype)
|
| 1634 |
+
lon_hr_d = lon_hr.to(device=preds["pred_st_hr1h_surface"].device, dtype=preds["pred_st_hr1h_surface"].dtype)
|
| 1635 |
+
lat_lr_d = lat_lr.to(device=preds["pred_st_hr1h_surface"].device, dtype=preds["pred_st_hr1h_surface"].dtype)
|
| 1636 |
+
lon_lr_d = lon_lr.to(device=preds["pred_st_hr1h_surface"].device, dtype=preds["pred_st_hr1h_surface"].dtype)
|
| 1637 |
+
preds["pred_st_lr1h_surface"] = _conservative_downscale_surface_bt_vhw(
|
| 1638 |
+
preds["pred_st_hr1h_surface"], lat_hr_d, lon_hr_d, lat_lr_d, lon_lr_d
|
| 1639 |
+
)
|
| 1640 |
+
preds["pred_st_lr1h_upper"] = _conservative_downscale_upper_bt_vchw(
|
| 1641 |
+
preds["pred_st_hr1h_upper"], lat_hr_d, lon_hr_d, lat_lr_d, lon_lr_d
|
| 1642 |
+
)
|
| 1643 |
+
if backbone_last_surface is not None and backbone_last_upper is not None:
|
| 1644 |
+
preds["backbone_last_surface"] = backbone_last_surface
|
| 1645 |
+
preds["backbone_last_upper"] = backbone_last_upper
|
| 1646 |
loss, loss_dict = loss_fn(
|
| 1647 |
preds=preds,
|
| 1648 |
batch=batch_f,
|
|
|
|
| 1650 |
lambda_t=lambda_t,
|
| 1651 |
lambda_st=lambda_st,
|
| 1652 |
lambda_comm=lambda_comm,
|
| 1653 |
+
lambda_backbone=lambda_backbone,
|
| 1654 |
surface_var_weights=surface_var_weights,
|
| 1655 |
upper_var_weights=upper_var_weights,
|
| 1656 |
upper_level_weights=upper_level_weights,
|
| 1657 |
interior_only_1h=interior_only_1h,
|
| 1658 |
+
lt_supervision_source=lt_supervision_source,
|
| 1659 |
+
lt_hybrid_alpha=lt_hybrid_alpha,
|
| 1660 |
surf_vars=surf_vars,
|
| 1661 |
upper_vars=upper_vars,
|
| 1662 |
level=level,
|
|
|
|
| 1670 |
target_seq_surface = batch_f["y_hr1h_surface"]
|
| 1671 |
target_seq_upper = batch_f["y_hr1h_upper"]
|
| 1672 |
|
| 1673 |
+
# Default setting is include_endpoints=False -> sequence is [t+1, ..., t+6] (len=6).
|
| 1674 |
+
# With include_endpoints=True, sequence is [t+0, t+1, ..., t+6] (len=7).
|
| 1675 |
# We expose per-delta metrics for t+1..t+6 when available.
|
| 1676 |
seq_len = pred_seq_surface.shape[1]
|
| 1677 |
if seq_len >= 7:
|
|
|
|
| 1713 |
metric_logger.update(loss_t=loss_dict["loss_t"].item())
|
| 1714 |
metric_logger.update(loss_st=loss_dict["loss_st"].item())
|
| 1715 |
metric_logger.update(loss_comm=loss_dict["loss_comm"].item())
|
| 1716 |
+
metric_logger.update(loss_backbone=loss_dict["loss_backbone"].item())
|
| 1717 |
metric_logger.update(prep_s=t1 - t0)
|
| 1718 |
metric_logger.update(model_s=t2 - t1)
|
| 1719 |
metric_logger.update(post_s=t3 - t2)
|
downscaling/output/WeatherPEFT/checkpoint-19.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4df5855ef1cd3b158eb5b2724fb4e3486ee2e543fe0cb799f5b88f72be3cfca1
|
| 3 |
+
size 5062400227
|
downscaling/output/WeatherPEFT/checkpoint-29.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1f90edeba59a2bc62285624d7fb5c532e535adca27c59d756c22234660a8230c
|
| 3 |
+
size 5062400227
|
downscaling/output/WeatherPEFT/checkpoint-9.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:78f21d453f4d85be909ec43cb4876e021acafd757e3668fc0617732a2eef4bd5
|
| 3 |
+
size 5062398726
|
downscaling/output/WeatherPEFT/log.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
downscaling/output/backbone_anchor_smoke/log.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"Eval:"
|
| 2 |
+
{"val_prep_s": 0.07330790814012289, "val_model_s": 0.5823051929473877, "val_post_s": 0.010655493009835482, "val_valid_loss": 0.6934626698493958, "val_loss_s": 0.2821418046951294, "val_loss_t": 0.05917414277791977, "val_loss_st": 0.238005131483078, "val_loss_comm": 0.11414159834384918, "val_loss_backbone": 0.0, "val_rmse_2t_tplus1h": 2.2092418670654297, "val_mean_bias_2t_tplus1h": 0.01837158203125, "val_rmse_10u_tplus1h": 2.0027098655700684, "val_mean_bias_10u_tplus1h": -0.045052409172058105, "val_rmse_10v_tplus1h": 2.0557847023010254, "val_mean_bias_10v_tplus1h": 0.08185440301895142, "val_rmse_t850_tplus1h": 1.5737048387527466, "val_mean_bias_t850_tplus1h": 0.1566162109375, "val_rmse_z500_tplus1h": 235.0081024169922, "val_mean_bias_z500_tplus1h": 13.34765625, "val_rmse_2t_tplus2h": 2.311091661453247, "val_mean_bias_2t_tplus2h": -0.047149658203125, "val_rmse_10u_tplus2h": 2.2796661853790283, "val_mean_bias_10u_tplus2h": -0.10876637697219849, "val_rmse_10v_tplus2h": 2.4048244953155518, "val_mean_bias_10v_tplus2h": 0.15996196866035461, "val_rmse_t850_tplus2h": 1.8633790016174316, "val_mean_bias_t850_tplus2h": 0.289154052734375, "val_rmse_z500_tplus2h": 294.5515441894531, "val_mean_bias_z500_tplus2h": 25.08203125, "val_rmse_2t_tplus3h": 2.514113426208496, "val_mean_bias_2t_tplus3h": -0.05084228515625, "val_rmse_10u_tplus3h": 2.639458417892456, "val_mean_bias_10u_tplus3h": -0.17488084733486176, "val_rmse_10v_tplus3h": 2.870990514755249, "val_mean_bias_10v_tplus3h": 0.233941450715065, "val_rmse_t850_tplus3h": 2.23453950881958, "val_mean_bias_t850_tplus3h": 0.416259765625, "val_rmse_z500_tplus3h": 366.6177673339844, "val_mean_bias_z500_tplus3h": 35.8828125, "val_rmse_2t_tplus4h": 2.748657464981079, "val_mean_bias_2t_tplus4h": -0.073028564453125, "val_rmse_10u_tplus4h": 3.0420165061950684, "val_mean_bias_10u_tplus4h": -0.24148303270339966, "val_rmse_10v_tplus4h": 3.3935253620147705, "val_mean_bias_10v_tplus4h": 0.29964739084243774, "val_rmse_t850_tplus4h": 2.6447532176971436, "val_mean_bias_t850_tplus4h": 0.5400390625, "val_rmse_z500_tplus4h": 444.1191101074219, "val_mean_bias_z500_tplus4h": 46.0546875, "val_rmse_2t_tplus5h": 2.8865745067596436, "val_mean_bias_2t_tplus5h": -0.24896240234375, "val_rmse_10u_tplus5h": 3.4593324661254883, "val_mean_bias_10u_tplus5h": -0.30440056324005127, "val_rmse_10v_tplus5h": 3.940408706665039, "val_mean_bias_10v_tplus5h": 0.3687131106853485, "val_rmse_t850_tplus5h": 3.0733869075775146, "val_mean_bias_t850_tplus5h": 0.66290283203125, "val_rmse_z500_tplus5h": 522.9864501953125, "val_mean_bias_z500_tplus5h": 55.27734375, "val_rmse_2t_tplus6h": 3.113255262374878, "val_mean_bias_2t_tplus6h": -0.284454345703125, "val_rmse_10u_tplus6h": 3.8806846141815186, "val_mean_bias_10u_tplus6h": -0.3701193332672119, "val_rmse_10v_tplus6h": 4.501622200012207, "val_mean_bias_10v_tplus6h": 0.4297124147415161, "val_rmse_t850_tplus6h": 3.5143191814422607, "val_mean_bias_t850_tplus6h": 0.78155517578125, "val_rmse_z500_tplus6h": 600.989013671875, "val_mean_bias_z500_tplus6h": 63.31640625, "val_rmse_2t": 3.113255262374878, "val_mean_bias_2t": -0.284454345703125, "val_rmse_10u": 3.8806846141815186, "val_mean_bias_10u": -0.3701193332672119, "val_rmse_10v": 4.501622200012207, "val_mean_bias_10v": 0.4297124147415161, "val_rmse_t850": 3.5143191814422607, "val_mean_bias_t850": 0.78155517578125, "val_rmse_z500": 600.989013671875, "val_mean_bias_z500": 63.31640625}
|
| 3 |
+
" &0.073 &0.582 &0.011 &0.282 &0.059 &0.238 &0.114 &0.0 &2.209 &0.018 &2.003 &-0.045 &2.056 &0.082 &1.574 &0.157 &235.008 &13.348 &2.311 &-0.047 &2.28 &-0.109 &2.405 &0.16 &1.863 &0.289 &294.552 &25.082 &2.514 &-0.051 &2.639 &-0.175 &2.871 &0.234 &2.235 &0.416 &366.618 &35.883 &2.749 &-0.073 &3.042 &-0.241 &3.394 &0.3 &2.645 &0.54 &444.119 &46.055 &2.887 &-0.249 &3.459 &-0.304 &3.94 &0.369 &3.073 &0.663 &522.986 &55.277 &3.113 &-0.284 &3.881 &-0.37 &4.502 &0.43 &3.514 &0.782 &600.989 &63.316 &3.113 &-0.284 &3.881 &-0.37 &4.502 &0.43 &3.514 &0.782 &600.989 &63.316"
|
| 4 |
+
|
downscaling/output/val_freq_smoke/log.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"Eval:"
|
| 2 |
+
{"val_prep_s": 0.05949366791173816, "val_model_s": 0.3861960484646261, "val_post_s": 0.005747961113229394, "val_valid_loss": 0.6520180404186249, "val_loss_s": 0.2656427472829819, "val_loss_t": 0.05454788729548454, "val_loss_st": 0.22411399334669113, "val_loss_comm": 0.10771343111991882, "val_rmse_2t_tplus1h": 2.1595544815063477, "val_mean_bias_2t_tplus1h": 0.0153045654296875, "val_rmse_10u_tplus1h": 1.9929717779159546, "val_mean_bias_10u_tplus1h": -0.06720291078090668, "val_rmse_10v_tplus1h": 2.0678257942199707, "val_mean_bias_10v_tplus1h": 0.04859817773103714, "val_rmse_t850_tplus1h": 1.5690162181854248, "val_mean_bias_t850_tplus1h": 0.1692962646484375, "val_rmse_z500_tplus1h": 229.7540283203125, "val_mean_bias_z500_tplus1h": 11.56640625, "val_rmse_2t_tplus2h": 2.325566053390503, "val_mean_bias_2t_tplus2h": -0.111175537109375, "val_rmse_10u_tplus2h": 2.259079933166504, "val_mean_bias_10u_tplus2h": -0.12979388236999512, "val_rmse_10v_tplus2h": 2.4088897705078125, "val_mean_bias_10v_tplus2h": 0.0989474430680275, "val_rmse_t850_tplus2h": 1.8684790134429932, "val_mean_bias_t850_tplus2h": 0.3101043701171875, "val_rmse_z500_tplus2h": 285.2650146484375, "val_mean_bias_z500_tplus2h": 23.6953125, "val_rmse_2t_tplus3h": 2.559150457382202, "val_mean_bias_2t_tplus3h": -0.1783599853515625, "val_rmse_10u_tplus3h": 2.5984840393066406, "val_mean_bias_10u_tplus3h": -0.1920362114906311, "val_rmse_10v_tplus3h": 2.8546905517578125, "val_mean_bias_10v_tplus3h": 0.15115559101104736, "val_rmse_t850_tplus3h": 2.2511355876922607, "val_mean_bias_t850_tplus3h": 0.44720458984375, "val_rmse_z500_tplus3h": 353.3114013671875, "val_mean_bias_z500_tplus3h": 34.533203125, "val_rmse_2t_tplus4h": 2.829373598098755, "val_mean_bias_2t_tplus4h": -0.268890380859375, "val_rmse_10u_tplus4h": 2.954393148422241, "val_mean_bias_10u_tplus4h": -0.2777942419052124, "val_rmse_10v_tplus4h": 3.3426637649536133, "val_mean_bias_10v_tplus4h": 0.18345022201538086, "val_rmse_t850_tplus4h": 2.6910040378570557, "val_mean_bias_t850_tplus4h": 0.591400146484375, "val_rmse_z500_tplus4h": 428.23114013671875, "val_mean_bias_z500_tplus4h": 46.494140625, "val_rmse_2t_tplus5h": 2.988680839538574, "val_mean_bias_2t_tplus5h": -0.3951263427734375, "val_rmse_10u_tplus5h": 3.360450506210327, "val_mean_bias_10u_tplus5h": -0.33785462379455566, "val_rmse_10v_tplus5h": 3.8764631748199463, "val_mean_bias_10v_tplus5h": 0.23951280117034912, "val_rmse_t850_tplus5h": 3.136207103729248, "val_mean_bias_t850_tplus5h": 0.7298583984375, "val_rmse_z500_tplus5h": 503.41949462890625, "val_mean_bias_z500_tplus5h": 58.720703125, "val_rmse_2t_tplus6h": 3.1790788173675537, "val_mean_bias_2t_tplus6h": -0.4730224609375, "val_rmse_10u_tplus6h": 3.7709269523620605, "val_mean_bias_10u_tplus6h": -0.39494484663009644, "val_rmse_10v_tplus6h": 4.418558120727539, "val_mean_bias_10v_tplus6h": 0.29149335622787476, "val_rmse_t850_tplus6h": 3.5930073261260986, "val_mean_bias_t850_tplus6h": 0.868896484375, "val_rmse_z500_tplus6h": 578.9581298828125, "val_mean_bias_z500_tplus6h": 69.849609375, "val_rmse_2t": 3.1790788173675537, "val_mean_bias_2t": -0.4730224609375, "val_rmse_10u": 3.7709269523620605, "val_mean_bias_10u": -0.39494484663009644, "val_rmse_10v": 4.418558120727539, "val_mean_bias_10v": 0.29149335622787476, "val_rmse_t850": 3.5930073261260986, "val_mean_bias_t850": 0.868896484375, "val_rmse_z500": 578.9581298828125, "val_mean_bias_z500": 69.849609375}
|
| 3 |
+
" &0.059 &0.386 &0.006 &0.266 &0.055 &0.224 &0.108 &2.16 &0.015 &1.993 &-0.067 &2.068 &0.049 &1.569 &0.169 &229.754 &11.566 &2.326 &-0.111 &2.259 &-0.13 &2.409 &0.099 &1.868 &0.31 &285.265 &23.695 &2.559 &-0.178 &2.598 &-0.192 &2.855 &0.151 &2.251 &0.447 &353.311 &34.533 &2.829 &-0.269 &2.954 &-0.278 &3.343 &0.183 &2.691 &0.591 &428.231 &46.494 &2.989 &-0.395 &3.36 &-0.338 &3.876 &0.24 &3.136 &0.73 &503.419 &58.721 &3.179 &-0.473 &3.771 &-0.395 &4.419 &0.291 &3.593 &0.869 &578.958 &69.85 &3.179 &-0.473 &3.771 &-0.395 &4.419 &0.291 &3.593 &0.869 &578.958 &69.85"
|
| 4 |
+
|
downscaling/output/val_speed_after/log.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"Eval:"
|
| 2 |
+
{"val_prep_s": 0.084631056097957, "val_model_s": 0.6615325791062787, "val_post_s": 0.0011346005291367571, "val_sample_s": 0.7511531277947748, "val_valid_loss": 0.6753831406434377, "val_loss_s": 0.27440425256888074, "val_loss_t": 0.05877547711133957, "val_loss_st": 0.23192234834035239, "val_loss_comm": 0.11028106758991878, "val_rmse_2t_tplus1h": 2.120826482772827, "val_mean_bias_2t_tplus1h": -0.08579571545124054, "val_rmse_10u_tplus1h": 2.0343947410583496, "val_mean_bias_10u_tplus1h": -0.044072579592466354, "val_rmse_10v_tplus1h": 2.0728132724761963, "val_mean_bias_10v_tplus1h": 0.05038035660982132, "val_rmse_t850_tplus1h": 1.610738754272461, "val_mean_bias_t850_tplus1h": 0.1519871950149536, "val_rmse_z500_tplus1h": 237.20074462890625, "val_mean_bias_z500_tplus1h": 9.716571807861328, "val_rmse_2t_tplus2h": 2.293344020843506, "val_mean_bias_2t_tplus2h": -0.18076401948928833, "val_rmse_10u_tplus2h": 2.3025588989257812, "val_mean_bias_10u_tplus2h": -0.08629271388053894, "val_rmse_10v_tplus2h": 2.409632444381714, "val_mean_bias_10v_tplus2h": 0.102354496717453, "val_rmse_t850_tplus2h": 1.9087270498275757, "val_mean_bias_t850_tplus2h": 0.2832726240158081, "val_rmse_z500_tplus2h": 296.6208801269531, "val_mean_bias_z500_tplus2h": 20.60210418701172, "val_rmse_2t_tplus3h": 2.481692314147949, "val_mean_bias_2t_tplus3h": -0.2545323371887207, "val_rmse_10u_tplus3h": 2.6446375846862793, "val_mean_bias_10u_tplus3h": -0.12955179810523987, "val_rmse_10v_tplus3h": 2.8543546199798584, "val_mean_bias_10v_tplus3h": 0.15386676788330078, "val_rmse_t850_tplus3h": 2.282665491104126, "val_mean_bias_t850_tplus3h": 0.4141359329223633, "val_rmse_z500_tplus3h": 367.32757568359375, "val_mean_bias_z500_tplus3h": 30.99903678894043, "val_rmse_2t_tplus4h": 2.7042994499206543, "val_mean_bias_2t_tplus4h": -0.3467860221862793, "val_rmse_10u_tplus4h": 2.998748540878296, "val_mean_bias_10u_tplus4h": -0.19702103734016418, "val_rmse_10v_tplus4h": 3.3468921184539795, "val_mean_bias_10v_tplus4h": 0.20130321383476257, "val_rmse_t850_tplus4h": 2.7033567428588867, "val_mean_bias_t850_tplus4h": 0.5463706254959106, "val_rmse_z500_tplus4h": 444.17047119140625, "val_mean_bias_z500_tplus4h": 42.49504089355469, "val_rmse_2t_tplus5h": 2.890420913696289, "val_mean_bias_2t_tplus5h": -0.4551726281642914, "val_rmse_10u_tplus5h": 3.40653395652771, "val_mean_bias_10u_tplus5h": -0.23875072598457336, "val_rmse_10v_tplus5h": 3.876429319381714, "val_mean_bias_10v_tplus5h": 0.25390106439590454, "val_rmse_t850_tplus5h": 3.1348283290863037, "val_mean_bias_t850_tplus5h": 0.6817998886108398, "val_rmse_z500_tplus5h": 521.3743286132812, "val_mean_bias_z500_tplus5h": 54.1654052734375, "val_rmse_2t_tplus6h": 3.1032655239105225, "val_mean_bias_2t_tplus6h": -0.538508415222168, "val_rmse_10u_tplus6h": 3.821709394454956, "val_mean_bias_10u_tplus6h": -0.27899816632270813, "val_rmse_10v_tplus6h": 4.415755271911621, "val_mean_bias_10v_tplus6h": 0.3063257932662964, "val_rmse_t850_tplus6h": 3.5775928497314453, "val_mean_bias_t850_tplus6h": 0.8159565329551697, "val_rmse_z500_tplus6h": 599.5579833984375, "val_mean_bias_z500_tplus6h": 65.35942077636719, "val_rmse_2t": 3.1032655239105225, "val_mean_bias_2t": -0.538508415222168, "val_rmse_10u": 3.821709394454956, "val_mean_bias_10u": -0.27899816632270813, "val_rmse_10v": 4.415755271911621, "val_mean_bias_10v": 0.3063257932662964, "val_rmse_t850": 3.5775928497314453, "val_mean_bias_t850": 0.8159565329551697, "val_rmse_z500": 599.5579833984375, "val_mean_bias_z500": 65.35942077636719}
|
| 3 |
+
" &0.085 &0.662 &0.001 &0.751 &0.274 &0.059 &0.232 &0.11 &2.121 &-0.086 &2.034 &-0.044 &2.073 &0.05 &1.611 &0.152 &237.201 &9.717 &2.293 &-0.181 &2.303 &-0.086 &2.41 &0.102 &1.909 &0.283 &296.621 &20.602 &2.482 &-0.255 &2.645 &-0.13 &2.854 &0.154 &2.283 &0.414 &367.328 &30.999 &2.704 &-0.347 &2.999 &-0.197 &3.347 &0.201 &2.703 &0.546 &444.17 &42.495 &2.89 &-0.455 &3.407 &-0.239 &3.876 &0.254 &3.135 &0.682 &521.374 &54.165 &3.103 &-0.539 &3.822 &-0.279 &4.416 &0.306 &3.578 &0.816 &599.558 &65.359 &3.103 &-0.539 &3.822 &-0.279 &4.416 &0.306 &3.578 &0.816 &599.558 &65.359"
|
| 4 |
+
|
downscaling/output/val_speed_baseline/log.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"Eval:"
|
| 2 |
+
{"val_prep_s": 0.05412517949783554, "val_model_s": 0.16398887091781944, "val_post_s": 0.003232622635550797, "val_valid_loss": 0.6753831406434377, "val_loss_s": 0.27440425256888074, "val_loss_t": 0.05877547711133957, "val_loss_st": 0.23192234834035239, "val_loss_comm": 0.11028106758991878, "val_rmse_2t_tplus1h": 2.120826482772827, "val_mean_bias_2t_tplus1h": -0.0857950896024704, "val_rmse_10u_tplus1h": 2.0343947410583496, "val_mean_bias_10u_tplus1h": -0.044072579592466354, "val_rmse_10v_tplus1h": 2.0728132724761963, "val_mean_bias_10v_tplus1h": 0.05038034915924072, "val_rmse_t850_tplus1h": 1.610738754272461, "val_mean_bias_t850_tplus1h": 0.1519927978515625, "val_rmse_z500_tplus1h": 237.20074462890625, "val_mean_bias_z500_tplus1h": 9.715169906616211, "val_rmse_2t_tplus2h": 2.293344020843506, "val_mean_bias_2t_tplus2h": -0.18076324462890625, "val_rmse_10u_tplus2h": 2.3025588989257812, "val_mean_bias_10u_tplus2h": -0.08629272878170013, "val_rmse_10v_tplus2h": 2.409632444381714, "val_mean_bias_10v_tplus2h": 0.1023545116186142, "val_rmse_t850_tplus2h": 1.9087270498275757, "val_mean_bias_t850_tplus2h": 0.28327688574790955, "val_rmse_z500_tplus2h": 296.6208801269531, "val_mean_bias_z500_tplus2h": 20.602214813232422, "val_rmse_2t_tplus3h": 2.481692314147949, "val_mean_bias_2t_tplus3h": -0.2545318603515625, "val_rmse_10u_tplus3h": 2.6446375846862793, "val_mean_bias_10u_tplus3h": -0.12955179810523987, "val_rmse_10v_tplus3h": 2.8543546199798584, "val_mean_bias_10v_tplus3h": 0.15386676788330078, "val_rmse_t850_tplus3h": 2.282665491104126, "val_mean_bias_t850_tplus3h": 0.4141286313533783, "val_rmse_z500_tplus3h": 367.32757568359375, "val_mean_bias_z500_tplus3h": 30.99837303161621, "val_rmse_2t_tplus4h": 2.7042994499206543, "val_mean_bias_2t_tplus4h": -0.34678396582603455, "val_rmse_10u_tplus4h": 2.998748540878296, "val_mean_bias_10u_tplus4h": -0.19702103734016418, "val_rmse_10v_tplus4h": 3.3468921184539795, "val_mean_bias_10v_tplus4h": 0.20130321383476257, "val_rmse_t850_tplus4h": 2.7033567428588867, "val_mean_bias_t850_tplus4h": 0.5463689565658569, "val_rmse_z500_tplus4h": 444.17047119140625, "val_mean_bias_z500_tplus4h": 42.49609375, "val_rmse_2t_tplus5h": 2.890420913696289, "val_mean_bias_2t_tplus5h": -0.45517224073410034, "val_rmse_10u_tplus5h": 3.40653395652771, "val_mean_bias_10u_tplus5h": -0.23875072598457336, "val_rmse_10v_tplus5h": 3.876429319381714, "val_mean_bias_10v_tplus5h": 0.25390106439590454, "val_rmse_t850_tplus5h": 3.1348283290863037, "val_mean_bias_t850_tplus5h": 0.6818059682846069, "val_rmse_z500_tplus5h": 521.3743286132812, "val_mean_bias_z500_tplus5h": 54.1650390625, "val_rmse_2t_tplus6h": 3.1032655239105225, "val_mean_bias_2t_tplus6h": -0.5385081171989441, "val_rmse_10u_tplus6h": 3.821709394454956, "val_mean_bias_10u_tplus6h": -0.2789981961250305, "val_rmse_10v_tplus6h": 4.415755271911621, "val_mean_bias_10v_tplus6h": 0.3063257932662964, "val_rmse_t850_tplus6h": 3.5775928497314453, "val_mean_bias_t850_tplus6h": 0.8159612417221069, "val_rmse_z500_tplus6h": 599.5579833984375, "val_mean_bias_z500_tplus6h": 65.3583984375, "val_rmse_2t": 3.1032655239105225, "val_mean_bias_2t": -0.5385081171989441, "val_rmse_10u": 3.821709394454956, "val_mean_bias_10u": -0.2789981961250305, "val_rmse_10v": 4.415755271911621, "val_mean_bias_10v": 0.3063257932662964, "val_rmse_t850": 3.5775928497314453, "val_mean_bias_t850": 0.8159612417221069, "val_rmse_z500": 599.5579833984375, "val_mean_bias_z500": 65.3583984375}
|
| 3 |
+
" &0.054 &0.164 &0.003 &0.274 &0.059 &0.232 &0.11 &2.121 &-0.086 &2.034 &-0.044 &2.073 &0.05 &1.611 &0.152 &237.201 &9.715 &2.293 &-0.181 &2.303 &-0.086 &2.41 &0.102 &1.909 &0.283 &296.621 &20.602 &2.482 &-0.255 &2.645 &-0.13 &2.854 &0.154 &2.283 &0.414 &367.328 &30.998 &2.704 &-0.347 &2.999 &-0.197 &3.347 &0.201 &2.703 &0.546 &444.17 &42.496 &2.89 &-0.455 &3.407 &-0.239 &3.876 &0.254 &3.135 &0.682 &521.374 &54.165 &3.103 &-0.539 &3.822 &-0.279 &4.416 &0.306 &3.578 &0.816 &599.558 &65.358 &3.103 &-0.539 &3.822 &-0.279 &4.416 &0.306 &3.578 &0.816 &599.558 &65.358"
|
| 4 |
+
|
downscaling/run_downscaling.py
CHANGED
|
@@ -87,6 +87,7 @@ def _resolve_pretrained_ckpt(path_str: str) -> Path:
|
|
| 87 |
|
| 88 |
def _attach_factorized_heads(model, surf_vars, upper_vars, level, args):
|
| 89 |
"""Attach factorized temporal/spatial correction heads."""
|
|
|
|
| 90 |
model.temporal_corrector = TemporalCorrectionHead(
|
| 91 |
num_surface_vars=len(surf_vars),
|
| 92 |
num_upper_vars=len(upper_vars),
|
|
@@ -94,6 +95,8 @@ def _attach_factorized_heads(model, surf_vars, upper_vars, level, args):
|
|
| 94 |
hidden_surface=args.temporal_hidden_surface,
|
| 95 |
hidden_upper=args.temporal_hidden_upper,
|
| 96 |
dropout=args.temporal_dropout,
|
|
|
|
|
|
|
| 97 |
)
|
| 98 |
model.spatial_corrector = SpatialCorrectionHead(
|
| 99 |
num_surface_vars=len(surf_vars),
|
|
@@ -327,6 +330,8 @@ def get_args():
|
|
| 327 |
help='validation progress print frequency (iterations)')
|
| 328 |
parser.add_argument('--val_max_steps', default=-1, type=int,
|
| 329 |
help='if >0, limit validation iterations per epoch for faster feedback')
|
|
|
|
|
|
|
| 330 |
parser.add_argument('--persistent_workers', action='store_true', default=True,
|
| 331 |
help='keep dataloader workers persistent across epochs')
|
| 332 |
parser.add_argument('--no_persistent_workers', action='store_false', dest='persistent_workers')
|
|
@@ -366,6 +371,8 @@ def get_args():
|
|
| 366 |
parser.add_argument('--lora_mode', default='single', choices=['single', 'all'])
|
| 367 |
parser.add_argument('--disable_backbone_lora_in_factorized', action='store_true',
|
| 368 |
help='If set in finetune_mode=lora, freeze LoRA adapters and train only correction heads.')
|
|
|
|
|
|
|
| 369 |
parser.add_argument('--lr_data_dir', default='', type=str,
|
| 370 |
help='LR dataset folder under datasets (auto if empty), e.g. 5.625')
|
| 371 |
parser.add_argument('--hr_data_dir', default='', type=str,
|
|
@@ -386,12 +393,19 @@ def get_args():
|
|
| 386 |
parser.add_argument('--derive_lr1h_from_hr1h', action='store_true', default=True,
|
| 387 |
help='Build y_lr1h from y_hr1h in training loop to reduce dataset IO.')
|
| 388 |
parser.add_argument('--no_derive_lr1h_from_hr1h', action='store_false', dest='derive_lr1h_from_hr1h')
|
| 389 |
-
parser.add_argument('--include_endpoints', action='store_true', default=
|
| 390 |
-
help='
|
|
|
|
| 391 |
parser.add_argument('--lambda_s', default=1.0, type=float)
|
| 392 |
parser.add_argument('--lambda_t', default=1.0, type=float)
|
| 393 |
parser.add_argument('--lambda_st', default=1.0, type=float)
|
| 394 |
parser.add_argument('--lambda_comm', default=1.0, type=float)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
parser.add_argument('--interior_only_1h', action='store_true',
|
| 396 |
help='If set, compute 1h losses on interior frames only (exclude endpoints)')
|
| 397 |
parser.add_argument('--ours_prompt_length', default=0, type=int,
|
|
@@ -402,6 +416,8 @@ def get_args():
|
|
| 402 |
help='Hidden channels for temporal correction head (upper vars x levels).')
|
| 403 |
parser.add_argument('--temporal_dropout', default=0.0, type=float,
|
| 404 |
help='Dropout for temporal correction head.')
|
|
|
|
|
|
|
| 405 |
parser.add_argument('--spatial_hidden_surface', default=256, type=int,
|
| 406 |
help='Hidden channels for spatial correction head (surface vars).')
|
| 407 |
parser.add_argument('--spatial_hidden_upper', default=512, type=int,
|
|
@@ -412,6 +428,8 @@ def get_args():
|
|
| 412 |
help='Residual scale multiplier for spatial correction outputs.')
|
| 413 |
parser.add_argument('--temporal_correction_scale', default=1e-3, type=float,
|
| 414 |
help='Residual scale multiplier for temporal correction outputs.')
|
|
|
|
|
|
|
| 415 |
parser.add_argument('--disable_factorized_lr_scale', action='store_true',
|
| 416 |
help='Disable LR scaling by total_batch_size/256 for factorized training.')
|
| 417 |
parser.add_argument('--disable_scale_balanced_loss', action='store_true',
|
|
@@ -563,6 +581,10 @@ def main(args, ds_init):
|
|
| 563 |
)
|
| 564 |
|
| 565 |
print("Sampler_train = %s" % str(sampler_train))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 566 |
if args.dist_eval:
|
| 567 |
if len(dataset_val) % num_tasks != 0:
|
| 568 |
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
|
|
@@ -607,6 +629,8 @@ def main(args, ds_init):
|
|
| 607 |
f"[factorized-data] derive_hr6h_from_hr1h={args.derive_hr6h_from_hr1h}, "
|
| 608 |
f"derive_lr1h_from_hr1h={args.derive_lr1h_from_hr1h}"
|
| 609 |
)
|
|
|
|
|
|
|
| 610 |
|
| 611 |
if dataset_val is not None:
|
| 612 |
val_loader_kwargs = dict(
|
|
@@ -627,6 +651,10 @@ def main(args, ds_init):
|
|
| 627 |
f"prefetch_factor={args.prefetch_factor if args.num_workers > 0 else 'n/a'}, "
|
| 628 |
f"pin_mem={args.pin_mem}"
|
| 629 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 630 |
else:
|
| 631 |
data_loader_val = None
|
| 632 |
|
|
@@ -666,14 +694,28 @@ def main(args, ds_init):
|
|
| 666 |
|
| 667 |
use_ours = finetune_mode == "ours"
|
| 668 |
use_lora = finetune_mode == "lora"
|
| 669 |
-
|
|
|
|
| 670 |
if finetune_mode == "lora":
|
| 671 |
if args.lora_mode == "all" and args.lora_steps > 1:
|
| 672 |
print(
|
| 673 |
"[lora mode] NOTE: factorized path uses single-step Aurora forward (rollout_step=0), "
|
| 674 |
"so only step-0 LoRA branch is used even when lora_mode=all."
|
| 675 |
)
|
| 676 |
-
print(f"[lora mode]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 677 |
|
| 678 |
model = Aurora(
|
| 679 |
use_ours=use_ours,
|
|
@@ -701,13 +743,14 @@ def main(args, ds_init):
|
|
| 701 |
print(
|
| 702 |
"[factorized] Attached correction heads "
|
| 703 |
f"(temporal: {args.temporal_hidden_surface}/{args.temporal_hidden_upper}, "
|
|
|
|
| 704 |
f"spatial: {args.spatial_hidden_surface}/{args.spatial_hidden_upper})."
|
| 705 |
)
|
| 706 |
|
| 707 |
_set_trainable_parameters(
|
| 708 |
model,
|
| 709 |
finetune_mode,
|
| 710 |
-
train_lora_adapters=
|
| 711 |
)
|
| 712 |
total_params, trainable_params = _count_parameters(model)
|
| 713 |
print(f"[mode={finetune_mode}] trainable params: {trainable_params:,} / {total_params:,}")
|
|
@@ -875,15 +918,19 @@ def main(args, ds_init):
|
|
| 875 |
lambda_t=args.lambda_t,
|
| 876 |
lambda_st=args.lambda_st,
|
| 877 |
lambda_comm=args.lambda_comm,
|
|
|
|
| 878 |
surface_var_weights=surface_var_weights,
|
| 879 |
upper_var_weights=upper_var_level_weights,
|
| 880 |
upper_level_weights=None,
|
| 881 |
interior_only_1h=args.interior_only_1h,
|
|
|
|
|
|
|
| 882 |
spatial_corrector=_get_spatial_corrector(model),
|
| 883 |
temporal_corrector=_get_temporal_corrector(model),
|
| 884 |
spatial_correction_scale=args.spatial_correction_scale,
|
| 885 |
temporal_correction_scale=args.temporal_correction_scale,
|
| 886 |
-
|
|
|
|
| 887 |
print_freq=args.val_print_freq,
|
| 888 |
val_max_steps=args.val_max_steps,
|
| 889 |
)
|
|
@@ -908,6 +955,10 @@ def main(args, ds_init):
|
|
| 908 |
f"[factorized] correction scales: spatial={args.spatial_correction_scale:.2e}, "
|
| 909 |
f"temporal={args.temporal_correction_scale:.2e}"
|
| 910 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 911 |
start_time = time.time()
|
| 912 |
train_time_only = 0
|
| 913 |
total_step = args.epochs * num_training_steps_per_epoch
|
|
@@ -929,15 +980,19 @@ def main(args, ds_init):
|
|
| 929 |
surf_vars=surf_vars, upper_vars=upper_vars,
|
| 930 |
lambda_s=args.lambda_s, lambda_t=args.lambda_t,
|
| 931 |
lambda_st=args.lambda_st, lambda_comm=args.lambda_comm,
|
|
|
|
| 932 |
surface_var_weights=surface_var_weights,
|
| 933 |
upper_var_weights=upper_var_level_weights,
|
| 934 |
upper_level_weights=None,
|
| 935 |
interior_only_1h=args.interior_only_1h,
|
|
|
|
|
|
|
| 936 |
spatial_corrector=_get_spatial_corrector(model),
|
| 937 |
temporal_corrector=_get_temporal_corrector(model),
|
| 938 |
spatial_correction_scale=args.spatial_correction_scale,
|
| 939 |
temporal_correction_scale=args.temporal_correction_scale,
|
| 940 |
-
|
|
|
|
| 941 |
accum_backward_mode=args.accum_backward_mode,
|
| 942 |
use_ours=use_ours, total_step=total_step
|
| 943 |
)
|
|
@@ -948,7 +1003,19 @@ def main(args, ds_init):
|
|
| 948 |
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
|
| 949 |
loss_scaler=loss_scaler, epoch=epoch)
|
| 950 |
|
| 951 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 952 |
test_stats = validation_one_epoch_factorized(
|
| 953 |
data_loader_val,
|
| 954 |
model,
|
|
@@ -966,31 +1033,49 @@ def main(args, ds_init):
|
|
| 966 |
lambda_t=args.lambda_t,
|
| 967 |
lambda_st=args.lambda_st,
|
| 968 |
lambda_comm=args.lambda_comm,
|
|
|
|
| 969 |
surface_var_weights=surface_var_weights,
|
| 970 |
upper_var_weights=upper_var_level_weights,
|
| 971 |
upper_level_weights=None,
|
| 972 |
interior_only_1h=args.interior_only_1h,
|
|
|
|
|
|
|
| 973 |
spatial_corrector=_get_spatial_corrector(model),
|
| 974 |
temporal_corrector=_get_temporal_corrector(model),
|
| 975 |
spatial_correction_scale=args.spatial_correction_scale,
|
| 976 |
temporal_correction_scale=args.temporal_correction_scale,
|
| 977 |
-
|
|
|
|
| 978 |
print_freq=args.val_print_freq,
|
| 979 |
val_max_steps=args.val_max_steps,
|
| 980 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 981 |
log_stats = {'epoch': epoch,
|
| 982 |
**{f'train_{k}': v for k, v in train_stats.items()}}
|
| 983 |
-
|
| 984 |
-
|
| 985 |
-
|
|
|
|
|
|
|
|
|
|
| 986 |
|
| 987 |
if args.output_dir and utils.is_main_process():
|
| 988 |
if log_writer is not None:
|
| 989 |
log_writer.flush()
|
| 990 |
with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
|
| 991 |
f.write(json.dumps(log_stats) + "\n")
|
| 992 |
-
|
| 993 |
-
|
|
|
|
|
|
|
| 994 |
|
| 995 |
total_time = time.time() - start_time
|
| 996 |
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
|
|
|
| 87 |
|
| 88 |
def _attach_factorized_heads(model, surf_vars, upper_vars, level, args):
|
| 89 |
"""Attach factorized temporal/spatial correction heads."""
|
| 90 |
+
temporal_num_steps = 7 if args.include_endpoints else 6
|
| 91 |
model.temporal_corrector = TemporalCorrectionHead(
|
| 92 |
num_surface_vars=len(surf_vars),
|
| 93 |
num_upper_vars=len(upper_vars),
|
|
|
|
| 95 |
hidden_surface=args.temporal_hidden_surface,
|
| 96 |
hidden_upper=args.temporal_hidden_upper,
|
| 97 |
dropout=args.temporal_dropout,
|
| 98 |
+
head_mode=args.temporal_head_mode,
|
| 99 |
+
num_time_steps=temporal_num_steps,
|
| 100 |
)
|
| 101 |
model.spatial_corrector = SpatialCorrectionHead(
|
| 102 |
num_surface_vars=len(surf_vars),
|
|
|
|
| 330 |
help='validation progress print frequency (iterations)')
|
| 331 |
parser.add_argument('--val_max_steps', default=-1, type=int,
|
| 332 |
help='if >0, limit validation iterations per epoch for faster feedback')
|
| 333 |
+
parser.add_argument('--val_freq', default=1, type=int,
|
| 334 |
+
help='run validation every N epochs (1 means every epoch)')
|
| 335 |
parser.add_argument('--persistent_workers', action='store_true', default=True,
|
| 336 |
help='keep dataloader workers persistent across epochs')
|
| 337 |
parser.add_argument('--no_persistent_workers', action='store_false', dest='persistent_workers')
|
|
|
|
| 371 |
parser.add_argument('--lora_mode', default='single', choices=['single', 'all'])
|
| 372 |
parser.add_argument('--disable_backbone_lora_in_factorized', action='store_true',
|
| 373 |
help='If set in finetune_mode=lora, freeze LoRA adapters and train only correction heads.')
|
| 374 |
+
parser.add_argument('--disable_backbone_prior_in_factorized', action='store_true',
|
| 375 |
+
help='If set, do not use Aurora pretrained 6h endpoint prediction as factorized prior/anchor.')
|
| 376 |
parser.add_argument('--lr_data_dir', default='', type=str,
|
| 377 |
help='LR dataset folder under datasets (auto if empty), e.g. 5.625')
|
| 378 |
parser.add_argument('--hr_data_dir', default='', type=str,
|
|
|
|
| 393 |
parser.add_argument('--derive_lr1h_from_hr1h', action='store_true', default=True,
|
| 394 |
help='Build y_lr1h from y_hr1h in training loop to reduce dataset IO.')
|
| 395 |
parser.add_argument('--no_derive_lr1h_from_hr1h', action='store_false', dest='derive_lr1h_from_hr1h')
|
| 396 |
+
parser.add_argument('--include_endpoints', action='store_true', default=False,
|
| 397 |
+
help='If set, use 1h targets [t+0..t+6]. Default uses [t+1..t+6].')
|
| 398 |
+
parser.add_argument('--no_include_endpoints', action='store_false', dest='include_endpoints')
|
| 399 |
parser.add_argument('--lambda_s', default=1.0, type=float)
|
| 400 |
parser.add_argument('--lambda_t', default=1.0, type=float)
|
| 401 |
parser.add_argument('--lambda_st', default=1.0, type=float)
|
| 402 |
parser.add_argument('--lambda_comm', default=1.0, type=float)
|
| 403 |
+
parser.add_argument('--lambda_backbone', default=0.1, type=float,
|
| 404 |
+
help='Weight for Aurora-6h anchor loss on unanchored S-module 6h endpoint prediction.')
|
| 405 |
+
parser.add_argument('--lt_supervision_source', default='lr_path', choices=['lr_path', 'st_downscaled', 'hybrid'],
|
| 406 |
+
help='Source for L_t supervision: direct LR temporal path or ST(HR) downscaled to LR.')
|
| 407 |
+
parser.add_argument('--lt_hybrid_alpha', default=0.5, type=float,
|
| 408 |
+
help='When lt_supervision_source=hybrid, weight on direct lr_path term (0..1).')
|
| 409 |
parser.add_argument('--interior_only_1h', action='store_true',
|
| 410 |
help='If set, compute 1h losses on interior frames only (exclude endpoints)')
|
| 411 |
parser.add_argument('--ours_prompt_length', default=0, type=int,
|
|
|
|
| 416 |
help='Hidden channels for temporal correction head (upper vars x levels).')
|
| 417 |
parser.add_argument('--temporal_dropout', default=0.0, type=float,
|
| 418 |
help='Dropout for temporal correction head.')
|
| 419 |
+
parser.add_argument('--temporal_head_mode', default='shared3d', choices=['shared3d', 'delta2d'],
|
| 420 |
+
help='Temporal correction head type: shared 3D conv or per-delta 2D heads.')
|
| 421 |
parser.add_argument('--spatial_hidden_surface', default=256, type=int,
|
| 422 |
help='Hidden channels for spatial correction head (surface vars).')
|
| 423 |
parser.add_argument('--spatial_hidden_upper', default=512, type=int,
|
|
|
|
| 428 |
help='Residual scale multiplier for spatial correction outputs.')
|
| 429 |
parser.add_argument('--temporal_correction_scale', default=1e-3, type=float,
|
| 430 |
help='Residual scale multiplier for temporal correction outputs.')
|
| 431 |
+
parser.add_argument('--allow_temporal_endpoint_correction', action='store_true',
|
| 432 |
+
help='Allow temporal correction head on offset=0 frame when include_endpoints is used.')
|
| 433 |
parser.add_argument('--disable_factorized_lr_scale', action='store_true',
|
| 434 |
help='Disable LR scaling by total_batch_size/256 for factorized training.')
|
| 435 |
parser.add_argument('--disable_scale_balanced_loss', action='store_true',
|
|
|
|
| 581 |
)
|
| 582 |
|
| 583 |
print("Sampler_train = %s" % str(sampler_train))
|
| 584 |
+
if args.distributed and (dataset_val is not None) and (not args.dist_eval):
|
| 585 |
+
args.dist_eval = True
|
| 586 |
+
print("[validation] Distributed training detected; enabling --dist_eval to shard validation across ranks.")
|
| 587 |
+
|
| 588 |
if args.dist_eval:
|
| 589 |
if len(dataset_val) % num_tasks != 0:
|
| 590 |
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
|
|
|
|
| 629 |
f"[factorized-data] derive_hr6h_from_hr1h={args.derive_hr6h_from_hr1h}, "
|
| 630 |
f"derive_lr1h_from_hr1h={args.derive_lr1h_from_hr1h}"
|
| 631 |
)
|
| 632 |
+
if args.derive_lr1h_from_hr1h:
|
| 633 |
+
print("[factorized-data] y_lr1h generation: conservative coarsen from y_hr1h (engine-side)")
|
| 634 |
|
| 635 |
if dataset_val is not None:
|
| 636 |
val_loader_kwargs = dict(
|
|
|
|
| 651 |
f"prefetch_factor={args.prefetch_factor if args.num_workers > 0 else 'n/a'}, "
|
| 652 |
f"pin_mem={args.pin_mem}"
|
| 653 |
)
|
| 654 |
+
print(
|
| 655 |
+
f"[validation] samples={len(dataset_val)}, batches={len(data_loader_val)}, "
|
| 656 |
+
f"val_freq={args.val_freq}, val_max_steps={args.val_max_steps}"
|
| 657 |
+
)
|
| 658 |
else:
|
| 659 |
data_loader_val = None
|
| 660 |
|
|
|
|
| 694 |
|
| 695 |
use_ours = finetune_mode == "ours"
|
| 696 |
use_lora = finetune_mode == "lora"
|
| 697 |
+
train_lora_adapters = finetune_mode == "lora" and (not args.disable_backbone_lora_in_factorized)
|
| 698 |
+
use_backbone_prior = not args.disable_backbone_prior_in_factorized
|
| 699 |
if finetune_mode == "lora":
|
| 700 |
if args.lora_mode == "all" and args.lora_steps > 1:
|
| 701 |
print(
|
| 702 |
"[lora mode] NOTE: factorized path uses single-step Aurora forward (rollout_step=0), "
|
| 703 |
"so only step-0 LoRA branch is used even when lora_mode=all."
|
| 704 |
)
|
| 705 |
+
print(f"[lora mode] train LoRA adapters: {train_lora_adapters}")
|
| 706 |
+
print(f"[factorized] use backbone 6h prior: {use_backbone_prior} (lambda_backbone={args.lambda_backbone})")
|
| 707 |
+
print(
|
| 708 |
+
"[factorized] loss lambdas: "
|
| 709 |
+
f"lambda_s={args.lambda_s}, "
|
| 710 |
+
f"lambda_t={args.lambda_t}, "
|
| 711 |
+
f"lambda_st={args.lambda_st}, "
|
| 712 |
+
f"lambda_comm={args.lambda_comm}, "
|
| 713 |
+
f"lambda_backbone={args.lambda_backbone}"
|
| 714 |
+
)
|
| 715 |
+
print(
|
| 716 |
+
f"[factorized] L_t supervision source: {args.lt_supervision_source} "
|
| 717 |
+
f"(hybrid alpha={args.lt_hybrid_alpha:.2f})"
|
| 718 |
+
)
|
| 719 |
|
| 720 |
model = Aurora(
|
| 721 |
use_ours=use_ours,
|
|
|
|
| 743 |
print(
|
| 744 |
"[factorized] Attached correction heads "
|
| 745 |
f"(temporal: {args.temporal_hidden_surface}/{args.temporal_hidden_upper}, "
|
| 746 |
+
f"temporal_mode={args.temporal_head_mode}, "
|
| 747 |
f"spatial: {args.spatial_hidden_surface}/{args.spatial_hidden_upper})."
|
| 748 |
)
|
| 749 |
|
| 750 |
_set_trainable_parameters(
|
| 751 |
model,
|
| 752 |
finetune_mode,
|
| 753 |
+
train_lora_adapters=train_lora_adapters,
|
| 754 |
)
|
| 755 |
total_params, trainable_params = _count_parameters(model)
|
| 756 |
print(f"[mode={finetune_mode}] trainable params: {trainable_params:,} / {total_params:,}")
|
|
|
|
| 918 |
lambda_t=args.lambda_t,
|
| 919 |
lambda_st=args.lambda_st,
|
| 920 |
lambda_comm=args.lambda_comm,
|
| 921 |
+
lambda_backbone=args.lambda_backbone,
|
| 922 |
surface_var_weights=surface_var_weights,
|
| 923 |
upper_var_weights=upper_var_level_weights,
|
| 924 |
upper_level_weights=None,
|
| 925 |
interior_only_1h=args.interior_only_1h,
|
| 926 |
+
lt_supervision_source=args.lt_supervision_source,
|
| 927 |
+
lt_hybrid_alpha=args.lt_hybrid_alpha,
|
| 928 |
spatial_corrector=_get_spatial_corrector(model),
|
| 929 |
temporal_corrector=_get_temporal_corrector(model),
|
| 930 |
spatial_correction_scale=args.spatial_correction_scale,
|
| 931 |
temporal_correction_scale=args.temporal_correction_scale,
|
| 932 |
+
temporal_fix_endpoints=(not args.allow_temporal_endpoint_correction),
|
| 933 |
+
enable_backbone_lora=use_backbone_prior,
|
| 934 |
print_freq=args.val_print_freq,
|
| 935 |
val_max_steps=args.val_max_steps,
|
| 936 |
)
|
|
|
|
| 955 |
f"[factorized] correction scales: spatial={args.spatial_correction_scale:.2e}, "
|
| 956 |
f"temporal={args.temporal_correction_scale:.2e}"
|
| 957 |
)
|
| 958 |
+
print(
|
| 959 |
+
f"[factorized] temporal endpoint correction enabled: "
|
| 960 |
+
f"{args.allow_temporal_endpoint_correction}"
|
| 961 |
+
)
|
| 962 |
start_time = time.time()
|
| 963 |
train_time_only = 0
|
| 964 |
total_step = args.epochs * num_training_steps_per_epoch
|
|
|
|
| 980 |
surf_vars=surf_vars, upper_vars=upper_vars,
|
| 981 |
lambda_s=args.lambda_s, lambda_t=args.lambda_t,
|
| 982 |
lambda_st=args.lambda_st, lambda_comm=args.lambda_comm,
|
| 983 |
+
lambda_backbone=args.lambda_backbone,
|
| 984 |
surface_var_weights=surface_var_weights,
|
| 985 |
upper_var_weights=upper_var_level_weights,
|
| 986 |
upper_level_weights=None,
|
| 987 |
interior_only_1h=args.interior_only_1h,
|
| 988 |
+
lt_supervision_source=args.lt_supervision_source,
|
| 989 |
+
lt_hybrid_alpha=args.lt_hybrid_alpha,
|
| 990 |
spatial_corrector=_get_spatial_corrector(model),
|
| 991 |
temporal_corrector=_get_temporal_corrector(model),
|
| 992 |
spatial_correction_scale=args.spatial_correction_scale,
|
| 993 |
temporal_correction_scale=args.temporal_correction_scale,
|
| 994 |
+
temporal_fix_endpoints=(not args.allow_temporal_endpoint_correction),
|
| 995 |
+
enable_backbone_lora=use_backbone_prior,
|
| 996 |
accum_backward_mode=args.accum_backward_mode,
|
| 997 |
use_ours=use_ours, total_step=total_step
|
| 998 |
)
|
|
|
|
| 1003 |
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
|
| 1004 |
loss_scaler=loss_scaler, epoch=epoch)
|
| 1005 |
|
| 1006 |
+
should_validate = (
|
| 1007 |
+
data_loader_val is not None
|
| 1008 |
+
and (args.val_freq > 0)
|
| 1009 |
+
and (((epoch + 1) % args.val_freq == 0) or (epoch + 1 == args.epochs))
|
| 1010 |
+
)
|
| 1011 |
+
test_stats = None
|
| 1012 |
+
if should_validate:
|
| 1013 |
+
val_start_time = time.time()
|
| 1014 |
+
val_rank_samples = len(data_loader_val.sampler) if hasattr(data_loader_val, "sampler") else len(data_loader_val.dataset)
|
| 1015 |
+
print(
|
| 1016 |
+
f"[validation] epoch={epoch} start: rank_samples={val_rank_samples}, "
|
| 1017 |
+
f"batches={len(data_loader_val)}"
|
| 1018 |
+
)
|
| 1019 |
test_stats = validation_one_epoch_factorized(
|
| 1020 |
data_loader_val,
|
| 1021 |
model,
|
|
|
|
| 1033 |
lambda_t=args.lambda_t,
|
| 1034 |
lambda_st=args.lambda_st,
|
| 1035 |
lambda_comm=args.lambda_comm,
|
| 1036 |
+
lambda_backbone=args.lambda_backbone,
|
| 1037 |
surface_var_weights=surface_var_weights,
|
| 1038 |
upper_var_weights=upper_var_level_weights,
|
| 1039 |
upper_level_weights=None,
|
| 1040 |
interior_only_1h=args.interior_only_1h,
|
| 1041 |
+
lt_supervision_source=args.lt_supervision_source,
|
| 1042 |
+
lt_hybrid_alpha=args.lt_hybrid_alpha,
|
| 1043 |
spatial_corrector=_get_spatial_corrector(model),
|
| 1044 |
temporal_corrector=_get_temporal_corrector(model),
|
| 1045 |
spatial_correction_scale=args.spatial_correction_scale,
|
| 1046 |
temporal_correction_scale=args.temporal_correction_scale,
|
| 1047 |
+
temporal_fix_endpoints=(not args.allow_temporal_endpoint_correction),
|
| 1048 |
+
enable_backbone_lora=use_backbone_prior,
|
| 1049 |
print_freq=args.val_print_freq,
|
| 1050 |
val_max_steps=args.val_max_steps,
|
| 1051 |
)
|
| 1052 |
+
val_elapsed = time.time() - val_start_time
|
| 1053 |
+
val_steps_done = min(len(data_loader_val), args.val_max_steps) if args.val_max_steps > 0 else len(data_loader_val)
|
| 1054 |
+
print(
|
| 1055 |
+
f"[validation] epoch={epoch} done in {val_elapsed:.1f}s "
|
| 1056 |
+
f"(~{val_elapsed / max(val_rank_samples, 1):.3f}s/sample on this rank, "
|
| 1057 |
+
f"{val_steps_done} batches)."
|
| 1058 |
+
)
|
| 1059 |
+
elif data_loader_val is not None:
|
| 1060 |
+
print(f"[validation] epoch={epoch} skipped (val_freq={args.val_freq}).")
|
| 1061 |
log_stats = {'epoch': epoch,
|
| 1062 |
**{f'train_{k}': v for k, v in train_stats.items()}}
|
| 1063 |
+
if test_stats is not None:
|
| 1064 |
+
test_log = {**{f'val_{k}': _to_float(v) for k, v in test_stats.items()}}
|
| 1065 |
+
copy_log = " &"+" &".join([str(round(_to_float(v), 3)) for k, v in test_stats.items() if k!="valid_loss"])
|
| 1066 |
+
else:
|
| 1067 |
+
test_log = None
|
| 1068 |
+
copy_log = None
|
| 1069 |
|
| 1070 |
if args.output_dir and utils.is_main_process():
|
| 1071 |
if log_writer is not None:
|
| 1072 |
log_writer.flush()
|
| 1073 |
with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
|
| 1074 |
f.write(json.dumps(log_stats) + "\n")
|
| 1075 |
+
if test_log is not None:
|
| 1076 |
+
f.write(json.dumps(test_log) + "\n")
|
| 1077 |
+
f.write(json.dumps(copy_log) + "\n")
|
| 1078 |
+
f.write("\n")
|
| 1079 |
|
| 1080 |
total_time = time.time() - start_time
|
| 1081 |
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
downscaling/script_run_downscaling.sh
CHANGED
|
@@ -10,6 +10,22 @@ VAL_END_DATE="2018"
|
|
| 10 |
OUTPUT_DIR="output/WeatherPEFT"
|
| 11 |
GRAD_ACCUM_STEPS=8
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
OMP_NUM_THREADS=1 python -m torch.distributed.run --nproc_per_node=2 \
|
| 14 |
--master_port 12326 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" \
|
| 15 |
run_downscaling.py \
|
|
@@ -21,18 +37,19 @@ OMP_NUM_THREADS=1 python -m torch.distributed.run --nproc_per_node=2 \
|
|
| 21 |
--output_dir ${OUTPUT_DIR} \
|
| 22 |
--batch_size 2 \
|
| 23 |
--grad_accum_steps ${GRAD_ACCUM_STEPS} \
|
| 24 |
-
--val_batch_size
|
| 25 |
--num_workers 8 \
|
| 26 |
--persistent_workers \
|
| 27 |
--prefetch_factor 4 \
|
| 28 |
--save_ckpt_freq 10 \
|
| 29 |
--opt adamw \
|
| 30 |
-
--lr
|
| 31 |
--opt_betas 0.9 0.999 \
|
| 32 |
--weight_decay 0.05 \
|
| 33 |
--clip_grad 1.0 \
|
| 34 |
--warmup_epochs 3 \
|
| 35 |
--epochs 30 \
|
|
|
|
| 36 |
--dist_eval \
|
| 37 |
--mode lora \
|
| 38 |
--finetune_mode lora \
|
|
@@ -44,9 +61,22 @@ OMP_NUM_THREADS=1 python -m torch.distributed.run --nproc_per_node=2 \
|
|
| 44 |
--hr1h_data_dir 1.5_1h_nc \
|
| 45 |
--lr_degree_tag 5.625 \
|
| 46 |
--hr_degree_tag 1.5 \
|
|
|
|
| 47 |
--norm_mean_path ../aux_data/normalize_mean_wb2.npz \
|
| 48 |
--norm_std_path ../aux_data/normalize_std_wb2.npz \
|
| 49 |
-
--norm_std_floor 1e-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
--no_auto_resume \
|
| 51 |
--spatial_correction_scale 1.0 \
|
| 52 |
--temporal_correction_scale 1.0 \
|
|
|
|
| 10 |
OUTPUT_DIR="output/WeatherPEFT"
|
| 11 |
GRAD_ACCUM_STEPS=8
|
| 12 |
|
| 13 |
+
# Preserve-first recipe:
|
| 14 |
+
# 1) keep pretrained 6h behavior strong at start
|
| 15 |
+
# 2) learn 1h~5h corrections gradually
|
| 16 |
+
LAMBDA_BACKBONE=1.0
|
| 17 |
+
LAMBDA_S=1.0
|
| 18 |
+
LAMBDA_T=1.0
|
| 19 |
+
LAMBDA_ST=1.0
|
| 20 |
+
LAMBDA_COMM=1.0
|
| 21 |
+
LT_SUPERVISION_SOURCE="lr_path"
|
| 22 |
+
SPATIAL_HIDDEN_SURFACE=64
|
| 23 |
+
SPATIAL_HIDDEN_UPPER=128
|
| 24 |
+
TEMPORAL_HIDDEN_SURFACE=64
|
| 25 |
+
TEMPORAL_HIDDEN_UPPER=128
|
| 26 |
+
TEMPORAL_HEAD_MODE="delta2d"
|
| 27 |
+
# backbone prior is ON by default (do not pass --disable_backbone_prior_in_factorized)
|
| 28 |
+
|
| 29 |
OMP_NUM_THREADS=1 python -m torch.distributed.run --nproc_per_node=2 \
|
| 30 |
--master_port 12326 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" \
|
| 31 |
run_downscaling.py \
|
|
|
|
| 37 |
--output_dir ${OUTPUT_DIR} \
|
| 38 |
--batch_size 2 \
|
| 39 |
--grad_accum_steps ${GRAD_ACCUM_STEPS} \
|
| 40 |
+
--val_batch_size 2 \
|
| 41 |
--num_workers 8 \
|
| 42 |
--persistent_workers \
|
| 43 |
--prefetch_factor 4 \
|
| 44 |
--save_ckpt_freq 10 \
|
| 45 |
--opt adamw \
|
| 46 |
+
--lr 5e-4 \
|
| 47 |
--opt_betas 0.9 0.999 \
|
| 48 |
--weight_decay 0.05 \
|
| 49 |
--clip_grad 1.0 \
|
| 50 |
--warmup_epochs 3 \
|
| 51 |
--epochs 30 \
|
| 52 |
+
--val_freq 1 \
|
| 53 |
--dist_eval \
|
| 54 |
--mode lora \
|
| 55 |
--finetune_mode lora \
|
|
|
|
| 61 |
--hr1h_data_dir 1.5_1h_nc \
|
| 62 |
--lr_degree_tag 5.625 \
|
| 63 |
--hr_degree_tag 1.5 \
|
| 64 |
+
--no_include_endpoints \
|
| 65 |
--norm_mean_path ../aux_data/normalize_mean_wb2.npz \
|
| 66 |
--norm_std_path ../aux_data/normalize_std_wb2.npz \
|
| 67 |
+
--norm_std_floor 1e-8 \
|
| 68 |
+
--lambda_s ${LAMBDA_S} \
|
| 69 |
+
--lambda_t ${LAMBDA_T} \
|
| 70 |
+
--lambda_st ${LAMBDA_ST} \
|
| 71 |
+
--lambda_comm ${LAMBDA_COMM} \
|
| 72 |
+
--lambda_backbone ${LAMBDA_BACKBONE} \
|
| 73 |
+
--lt_supervision_source ${LT_SUPERVISION_SOURCE} \
|
| 74 |
+
--disable_scale_balanced_loss \
|
| 75 |
+
--spatial_hidden_surface ${SPATIAL_HIDDEN_SURFACE} \
|
| 76 |
+
--spatial_hidden_upper ${SPATIAL_HIDDEN_UPPER} \
|
| 77 |
+
--temporal_hidden_surface ${TEMPORAL_HIDDEN_SURFACE} \
|
| 78 |
+
--temporal_hidden_upper ${TEMPORAL_HIDDEN_UPPER} \
|
| 79 |
+
--temporal_head_mode ${TEMPORAL_HEAD_MODE} \
|
| 80 |
--no_auto_resume \
|
| 81 |
--spatial_correction_scale 1.0 \
|
| 82 |
--temporal_correction_scale 1.0 \
|
downscaling/script_smoke_feasibility.sh
CHANGED
|
@@ -48,6 +48,7 @@ for MODE in "${MODES[@]}"; do
|
|
| 48 |
--hr1h_data_dir 1.5_1h_nc \
|
| 49 |
--lr_degree_tag 5.625 \
|
| 50 |
--hr_degree_tag 1.5 \
|
|
|
|
| 51 |
--norm_mean_path ../aux_data/normalize_mean_wb2.npz \
|
| 52 |
--norm_std_path ../aux_data/normalize_std_wb2.npz \
|
| 53 |
--norm_std_floor 1e-8 \
|
|
|
|
| 48 |
--hr1h_data_dir 1.5_1h_nc \
|
| 49 |
--lr_degree_tag 5.625 \
|
| 50 |
--hr_degree_tag 1.5 \
|
| 51 |
+
--no_include_endpoints \
|
| 52 |
--norm_mean_path ../aux_data/normalize_mean_wb2.npz \
|
| 53 |
--norm_std_path ../aux_data/normalize_std_wb2.npz \
|
| 54 |
--norm_std_floor 1e-8 \
|