bidulki-99 commited on
Commit
1e7bad2
·
verified ·
1 Parent(s): 844ba07

Resync verified files (1/1)

Browse files
.gitattributes CHANGED
@@ -1733,3 +1733,5 @@ aux_data/era5_tp_6hr_china.nc filter=lfs diff=lfs merge=lfs -text
1733
  datasets/1.5_nc/2m_temperature/2m_temperature_2015_1.5deg.nc filter=lfs diff=lfs merge=lfs -text
1734
  aux_data/climatology/total_precipitation_6hr_seeps_dry_fraction.nc filter=lfs diff=lfs merge=lfs -text
1735
  aux_data/climatology/total_precipitation_6hr_seeps_threshold.nc filter=lfs diff=lfs merge=lfs -text
 
 
 
1733
  datasets/1.5_nc/2m_temperature/2m_temperature_2015_1.5deg.nc filter=lfs diff=lfs merge=lfs -text
1734
  aux_data/climatology/total_precipitation_6hr_seeps_dry_fraction.nc filter=lfs diff=lfs merge=lfs -text
1735
  aux_data/climatology/total_precipitation_6hr_seeps_threshold.nc filter=lfs diff=lfs merge=lfs -text
1736
+ datasets/1.5_nc/2m_temperature/2m_temperature_2012_1.5deg.nc filter=lfs diff=lfs merge=lfs -text
1737
+ datasets/1.5_nc/2m_temperature/2m_temperature_2016_1.5deg.nc filter=lfs diff=lfs merge=lfs -text
datasets/1.5_nc/2m_temperature/2m_temperature_2012_1.5deg.nc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f72d50b681a70895fb3c36b02b65d1c294e4e252f5bea0e4cccc36467c56f9a3
3
+ size 115216361
datasets/1.5_nc/2m_temperature/2m_temperature_2016_1.5deg.nc ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92a5a6e02d41079c19dccde429f1071d4e51bd60fa92d1af2670694031551d3c
3
+ size 114847590
downscaling/engine.py CHANGED
@@ -41,40 +41,119 @@ class TemporalCorrectionHead(nn.Module):
41
  hidden_surface: int = 256,
42
  hidden_upper: int = 512,
43
  dropout: float = 0.0,
 
 
44
  ):
45
  super().__init__()
46
  upper_ch = num_upper_vars * num_levels
47
-
48
- self.surface_net = nn.Sequential(
49
- nn.Conv3d(num_surface_vars, hidden_surface, kernel_size=(3, 1, 1), padding=(1, 0, 0)),
50
- nn.GELU(),
51
- nn.Dropout3d(dropout) if dropout > 0 else nn.Identity(),
52
- nn.Conv3d(hidden_surface, num_surface_vars, kernel_size=(3, 1, 1), padding=(1, 0, 0)),
53
- )
54
- self.upper_net = nn.Sequential(
55
- nn.Conv3d(upper_ch, hidden_upper, kernel_size=(3, 1, 1), padding=(1, 0, 0)),
56
- nn.GELU(),
57
- nn.Dropout3d(dropout) if dropout > 0 else nn.Identity(),
58
- nn.Conv3d(hidden_upper, upper_ch, kernel_size=(3, 1, 1), padding=(1, 0, 0)),
59
- )
60
-
61
- # Start from interpolation baseline: initial correction ~= 0.
62
- nn.init.zeros_(self.surface_net[-1].weight)
63
- nn.init.zeros_(self.surface_net[-1].bias)
64
- nn.init.zeros_(self.upper_net[-1].weight)
65
- nn.init.zeros_(self.upper_net[-1].bias)
66
-
67
- def forward(self, base_surface, base_upper):
68
- # base_surface: [B, T, V, H, W]
69
- # base_upper: [B, T, V, C, H, W]
70
- x_surface = rearrange(base_surface, "b t v h w -> b v t h w")
71
- corr_surface = self.surface_net(x_surface)
72
- corr_surface = rearrange(corr_surface, "b v t h w -> b t v h w")
73
-
74
- x_upper = rearrange(base_upper, "b t v c h w -> b (v c) t h w")
75
- corr_upper = self.upper_net(x_upper)
76
- corr_upper = rearrange(corr_upper, "b (v c) t h w -> b t v c h w", v=base_upper.shape[2], c=base_upper.shape[3])
77
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  return corr_surface, corr_upper
79
 
80
 
@@ -147,23 +226,41 @@ def _default_spatial_interp(surface, upper, target_hw):
147
  return surface_up, upper_up
148
 
149
 
150
- def _default_temporal_interp(surface_endpoints, upper_endpoints, out_steps):
151
- """Linear interpolation along time between the two endpoint frames."""
 
 
 
 
152
  if surface_endpoints.shape[1] < 2 or upper_endpoints.shape[1] < 2:
153
  raise ValueError("Temporal interpolation needs at least two endpoint frames.")
154
 
155
- s0 = surface_endpoints[:, 0].unsqueeze(1)
156
- s1 = surface_endpoints[:, -1].unsqueeze(1)
157
- u0 = upper_endpoints[:, 0].unsqueeze(1)
158
- u1 = upper_endpoints[:, -1].unsqueeze(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
- alpha_s = torch.linspace(
161
- 0.0, 1.0, out_steps, device=surface_endpoints.device, dtype=surface_endpoints.dtype
162
- ).view(1, out_steps, 1, 1, 1)
163
- alpha_u = alpha_s.view(1, out_steps, 1, 1, 1, 1)
164
 
165
- surface_interp = (1.0 - alpha_s) * s0 + alpha_s * s1
166
- upper_interp = (1.0 - alpha_u) * u0 + alpha_u * u1
167
  return surface_interp, upper_interp
168
 
169
 
@@ -214,26 +311,33 @@ def forward_paths(
214
  target_hw_6h = batch["y_hr6h_surface"].shape[-2:]
215
  target_hw_1h = batch["y_hr1h_surface"].shape[-2:]
216
  out_steps_1h = batch["y_hr1h_surface"].shape[1]
 
217
 
218
  # Spatial path at 6h: S(x) = U_s(x) + C_s(x)
 
 
219
  s_base_surface, s_base_upper = spatial_interp_fn(x_lr6h_surface, x_lr6h_upper, target_hw_6h)
 
 
 
 
 
 
 
 
 
220
  pred_s_hr6h_surface, pred_s_hr6h_upper = _apply_correction(
221
  spatial_correction_fn,
222
  x_lr6h_surface,
223
  x_lr6h_upper,
224
- s_base_surface,
225
- s_base_upper,
226
  )
227
- # Optional Aurora-backbone replacement for S-path last endpoint (enables LoRA gradient path).
228
- if backbone_s_last_surface is not None:
229
- pred_s_hr6h_surface = pred_s_hr6h_surface.clone()
230
- pred_s_hr6h_surface[:, -1] = backbone_s_last_surface
231
- if backbone_s_last_upper is not None:
232
- pred_s_hr6h_upper = pred_s_hr6h_upper.clone()
233
- pred_s_hr6h_upper[:, -1] = backbone_s_last_upper
234
 
235
  # Temporal path at LR: T(x) = U_t(x) + C_t(x)
236
- t_base_surface, t_base_upper = temporal_interp_fn(x_lr6h_surface, x_lr6h_upper, out_steps_1h)
 
 
237
  pred_t_lr1h_surface, pred_t_lr1h_upper = _apply_correction(
238
  temporal_correction_fn,
239
  x_lr6h_surface,
@@ -253,7 +357,9 @@ def forward_paths(
253
  )
254
 
255
  # ST path: x -> S -> T
256
- st_base_surface, st_base_upper = temporal_interp_fn(pred_s_hr6h_surface, pred_s_hr6h_upper, out_steps_1h)
 
 
257
  pred_st_hr1h_surface, pred_st_hr1h_upper = _apply_correction(
258
  temporal_correction_fn,
259
  pred_s_hr6h_surface,
@@ -265,6 +371,8 @@ def forward_paths(
265
  return {
266
  "pred_s_hr6h_surface": pred_s_hr6h_surface,
267
  "pred_s_hr6h_upper": pred_s_hr6h_upper,
 
 
268
  "pred_t_lr1h_surface": pred_t_lr1h_surface,
269
  "pred_t_lr1h_upper": pred_t_lr1h_upper,
270
  "pred_ts_hr1h_surface": pred_ts_hr1h_surface,
@@ -317,6 +425,8 @@ def _downscale_style_term_loss(
317
  surf_vars=None,
318
  upper_vars=None,
319
  level=None,
 
 
320
  ):
321
  """
322
  Match train_one_epoch_downscale style:
@@ -329,8 +439,33 @@ def _downscale_style_term_loss(
329
  pred_upper_n = _normalise_upper_bt_vchw(pred_upper, upper_vars, level)
330
  target_upper_n = _normalise_upper_bt_vchw(target_upper, upper_vars, level)
331
 
332
- loss_surface = F.mse_loss(pred_surface_n, target_surface_n)
333
- loss_upper = F.mse_loss(pred_upper_n, target_upper_n)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  return 0.25 * loss_surface + loss_upper
335
 
336
 
@@ -341,10 +476,13 @@ def loss_fn(
341
  lambda_t=1.0,
342
  lambda_st=1.0,
343
  lambda_comm=1.0,
 
344
  surface_var_weights=None,
345
  upper_var_weights=None,
346
  upper_level_weights=None,
347
  interior_only_1h=False,
 
 
348
  surf_vars=None,
349
  upper_vars=None,
350
  level=None,
@@ -356,8 +494,9 @@ def loss_fn(
356
  return x[:, 1:-1]
357
  return x
358
 
359
- # Keep signature compatibility; old downscale style does not use these explicit weights.
360
- _ = surface_var_weights, upper_var_weights, upper_level_weights
 
361
 
362
  # L_S: S(x) supervised by HR-6h
363
  l_s = _downscale_style_term_loss(
@@ -368,10 +507,14 @@ def loss_fn(
368
  surf_vars=surf_vars,
369
  upper_vars=upper_vars,
370
  level=level,
 
 
371
  )
372
 
373
- # L_T: T(x) supervised by LR-1h
374
- l_t = _downscale_style_term_loss(
 
 
375
  _trim_time(preds["pred_t_lr1h_surface"]),
376
  _trim_time(preds["pred_t_lr1h_upper"]),
377
  _trim_time(batch["y_lr1h_surface"]),
@@ -379,7 +522,33 @@ def loss_fn(
379
  surf_vars=surf_vars,
380
  upper_vars=upper_vars,
381
  level=level,
 
 
382
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
 
384
  # L_ST: ST path supervised by HR-1h
385
  l_st = _downscale_style_term_loss(
@@ -390,9 +559,11 @@ def loss_fn(
390
  surf_vars=surf_vars,
391
  upper_vars=upper_vars,
392
  level=level,
 
 
393
  )
394
 
395
- # L_comm: TS and ST consistency
396
  l_comm = _downscale_style_term_loss(
397
  _trim_time(preds["pred_ts_hr1h_surface"]),
398
  _trim_time(preds["pred_ts_hr1h_upper"]),
@@ -401,15 +572,47 @@ def loss_fn(
401
  surf_vars=surf_vars,
402
  upper_vars=upper_vars,
403
  level=level,
 
 
404
  )
405
 
406
- total = lambda_s * l_s + lambda_t * l_t + lambda_st * l_st + lambda_comm * l_comm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
  loss_dict = {
408
  "loss_total": total,
409
  "loss_s": l_s,
410
  "loss_t": l_t,
411
  "loss_st": l_st,
412
  "loss_comm": l_comm,
 
413
  }
414
  return total, loss_dict
415
 
@@ -687,6 +890,21 @@ def _trim_to_hw(x, h, w):
687
  return x[..., :h, :w]
688
 
689
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
690
  def _resize_surface_bt_vhw(x, out_hw):
691
  b, t, v, h, w = x.shape
692
  th, tw = out_hw
@@ -709,12 +927,123 @@ def _resize_upper_bt_vchw(x, out_hw):
709
  ).reshape(b, t, v, c, th, tw)
710
 
711
 
712
- def _prepare_factorized_batch(batch, hr_hw, lr_hw, device):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
713
  h_hr, w_hr = hr_hw
714
  h_lr, w_lr = lr_hw
715
 
716
  def _to_device(x, h, w):
717
- return _trim_to_hw(x.float(), h, w).to(device, non_blocking=True)
 
718
 
719
  out = {}
720
  out["x_lr6h_surface"] = _to_device(batch["x_lr6h_surface"], h_lr, w_lr)
@@ -741,8 +1070,30 @@ def _prepare_factorized_batch(batch, hr_hw, lr_hw, device):
741
  out["y_lr1h_surface"] = _to_device(batch["y_lr1h_surface"], h_lr, w_lr)
742
  out["y_lr1h_upper"] = _to_device(batch["y_lr1h_upper"], h_lr, w_lr)
743
  else:
744
- out["y_lr1h_surface"] = _resize_surface_bt_vhw(out["y_hr1h_surface"], (h_lr, w_lr))
745
- out["y_lr1h_upper"] = _resize_upper_bt_vchw(out["y_hr1h_upper"], (h_lr, w_lr))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
746
 
747
  out["time_unix"] = batch["time_unix"]
748
  return out
@@ -851,6 +1202,8 @@ def _build_temporal_correction_fn(
851
  upper_vars=None,
852
  level=None,
853
  correction_scale=1.0,
 
 
854
  ):
855
  if temporal_corrector is None:
856
  def _temporal_correction_zeros(input_surface, input_upper, base_surface, base_upper):
@@ -859,12 +1212,35 @@ def _build_temporal_correction_fn(
859
  return _temporal_correction_zeros
860
 
861
  def _temporal_correction(input_surface, input_upper, base_surface, base_upper):
 
 
862
  base_surface_n = _normalise_surface_bt_vhw(base_surface, surf_vars)
863
  base_upper_n = _normalise_upper_bt_vchw(base_upper, upper_vars, level)
864
- corr_surface_n, corr_upper_n = temporal_corrector(base_surface_n, base_upper_n)
 
 
865
 
866
  corr_surface_n = torch.nan_to_num(corr_surface_n, nan=0.0, posinf=0.0, neginf=0.0)
867
  corr_upper_n = torch.nan_to_num(corr_upper_n, nan=0.0, posinf=0.0, neginf=0.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
868
 
869
  s_scale = _surface_scale_tensor(surf_vars, base_surface)
870
  u_scale = _upper_scale_tensor(upper_vars, level, base_upper)
@@ -902,14 +1278,18 @@ def train_one_epoch_factorized(
902
  lambda_t=1.0,
903
  lambda_st=1.0,
904
  lambda_comm=1.0,
 
905
  surface_var_weights=None,
906
  upper_var_weights=None,
907
  upper_level_weights=None,
908
  interior_only_1h=False,
 
 
909
  spatial_corrector=None,
910
  temporal_corrector=None,
911
  spatial_correction_scale=1.0,
912
  temporal_correction_scale=1.0,
 
913
  enable_backbone_lora=False,
914
  accum_backward_mode="micro",
915
  use_ours=False,
@@ -937,6 +1317,9 @@ def train_one_epoch_factorized(
937
  backward_calls = 0
938
  pending_loss = None
939
 
 
 
 
940
  for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
941
  step = data_iter_step // update_freq
942
  if step >= num_training_steps_per_epoch:
@@ -956,6 +1339,10 @@ def train_one_epoch_factorized(
956
  hr_hw=hr_hw,
957
  lr_hw=lr_hw,
958
  device=device,
 
 
 
 
959
  )
960
 
961
  spatial_correction_fn = _build_spatial_correction_fn(
@@ -971,6 +1358,8 @@ def train_one_epoch_factorized(
971
  upper_vars=upper_vars,
972
  level=level,
973
  correction_scale=temporal_correction_scale,
 
 
974
  )
975
  backbone_last_surface = None
976
  backbone_last_upper = None
@@ -993,6 +1382,20 @@ def train_one_epoch_factorized(
993
  backbone_s_last_surface=backbone_last_surface,
994
  backbone_s_last_upper=backbone_last_upper,
995
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
996
  loss, loss_dict = loss_fn(
997
  preds=preds,
998
  batch=batch_f,
@@ -1000,18 +1403,29 @@ def train_one_epoch_factorized(
1000
  lambda_t=lambda_t,
1001
  lambda_st=lambda_st,
1002
  lambda_comm=lambda_comm,
 
1003
  surface_var_weights=surface_var_weights,
1004
  upper_var_weights=upper_var_weights,
1005
  upper_level_weights=upper_level_weights,
1006
  interior_only_1h=interior_only_1h,
 
 
1007
  surf_vars=surf_vars,
1008
  upper_vars=upper_vars,
1009
  level=level,
1010
  )
1011
-
1012
  loss_value = loss.item()
1013
  if not np.isfinite(loss_value):
1014
  print(f"[factorized] non-finite loss at iter {data_iter_step}; skip update.")
 
 
 
 
 
 
 
 
 
1015
  optimizer.zero_grad()
1016
  pending_loss = None
1017
  continue
@@ -1064,6 +1478,7 @@ def train_one_epoch_factorized(
1064
  metric_logger.update(loss_t=loss_dict["loss_t"].item())
1065
  metric_logger.update(loss_st=loss_dict["loss_st"].item())
1066
  metric_logger.update(loss_comm=loss_dict["loss_comm"].item())
 
1067
  metric_logger.update(loss_scale=loss_scale_value)
1068
 
1069
  min_lr = 10.0
@@ -1090,6 +1505,7 @@ def train_one_epoch_factorized(
1090
  log_writer.update(loss_t=loss_dict["loss_t"].item(), head="loss")
1091
  log_writer.update(loss_st=loss_dict["loss_st"].item(), head="loss")
1092
  log_writer.update(loss_comm=loss_dict["loss_comm"].item(), head="loss")
 
1093
  log_writer.update(loss_scale=loss_scale_value, head="opt")
1094
  log_writer.update(lr=max_lr, head="opt")
1095
  log_writer.update(min_lr=min_lr, head="opt")
@@ -1126,14 +1542,18 @@ def validation_one_epoch_factorized(
1126
  lambda_t=1.0,
1127
  lambda_st=1.0,
1128
  lambda_comm=1.0,
 
1129
  surface_var_weights=None,
1130
  upper_var_weights=None,
1131
  upper_level_weights=None,
1132
  interior_only_1h=False,
 
 
1133
  spatial_corrector=None,
1134
  temporal_corrector=None,
1135
  spatial_correction_scale=1.0,
1136
  temporal_correction_scale=1.0,
 
1137
  enable_backbone_lora=False,
1138
  print_freq=20,
1139
  val_max_steps=-1,
@@ -1142,6 +1562,7 @@ def validation_one_epoch_factorized(
1142
  metric_logger.add_meter("prep_s", utils.SmoothedValue(window_size=20, fmt="{avg:.3f}"))
1143
  metric_logger.add_meter("model_s", utils.SmoothedValue(window_size=20, fmt="{avg:.3f}"))
1144
  metric_logger.add_meter("post_s", utils.SmoothedValue(window_size=20, fmt="{avg:.3f}"))
 
1145
  header = "Val:"
1146
 
1147
  model.eval()
@@ -1159,14 +1580,6 @@ def validation_one_epoch_factorized(
1159
  level=level,
1160
  correction_scale=spatial_correction_scale,
1161
  )
1162
- temporal_correction_fn = _build_temporal_correction_fn(
1163
- temporal_corrector=temporal_corrector,
1164
- surf_vars=surf_vars,
1165
- upper_vars=upper_vars,
1166
- level=level,
1167
- correction_scale=temporal_correction_scale,
1168
- )
1169
-
1170
  warned_short_seq = False
1171
  for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
1172
  if val_max_steps is not None and val_max_steps > 0 and data_iter_step >= val_max_steps:
@@ -1177,12 +1590,25 @@ def validation_one_epoch_factorized(
1177
  hr_hw=hr_hw,
1178
  lr_hw=lr_hw,
1179
  device=device,
 
 
 
 
1180
  )
1181
  if torch.cuda.is_available():
1182
  torch.cuda.synchronize()
1183
  t1 = time.perf_counter()
1184
 
1185
  with torch.inference_mode():
 
 
 
 
 
 
 
 
 
1186
  backbone_last_surface = None
1187
  backbone_last_upper = None
1188
  if enable_backbone_lora:
@@ -1203,6 +1629,20 @@ def validation_one_epoch_factorized(
1203
  backbone_s_last_surface=backbone_last_surface,
1204
  backbone_s_last_upper=backbone_last_upper,
1205
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1206
  loss, loss_dict = loss_fn(
1207
  preds=preds,
1208
  batch=batch_f,
@@ -1210,10 +1650,13 @@ def validation_one_epoch_factorized(
1210
  lambda_t=lambda_t,
1211
  lambda_st=lambda_st,
1212
  lambda_comm=lambda_comm,
 
1213
  surface_var_weights=surface_var_weights,
1214
  upper_var_weights=upper_var_weights,
1215
  upper_level_weights=upper_level_weights,
1216
  interior_only_1h=interior_only_1h,
 
 
1217
  surf_vars=surf_vars,
1218
  upper_vars=upper_vars,
1219
  level=level,
@@ -1227,7 +1670,8 @@ def validation_one_epoch_factorized(
1227
  target_seq_surface = batch_f["y_hr1h_surface"]
1228
  target_seq_upper = batch_f["y_hr1h_upper"]
1229
 
1230
- # Default setting is include_endpoints=True -> sequence is [t0, t+1, ..., t+6] (len=7).
 
1231
  # We expose per-delta metrics for t+1..t+6 when available.
1232
  seq_len = pred_seq_surface.shape[1]
1233
  if seq_len >= 7:
@@ -1269,6 +1713,7 @@ def validation_one_epoch_factorized(
1269
  metric_logger.update(loss_t=loss_dict["loss_t"].item())
1270
  metric_logger.update(loss_st=loss_dict["loss_st"].item())
1271
  metric_logger.update(loss_comm=loss_dict["loss_comm"].item())
 
1272
  metric_logger.update(prep_s=t1 - t0)
1273
  metric_logger.update(model_s=t2 - t1)
1274
  metric_logger.update(post_s=t3 - t2)
 
41
  hidden_surface: int = 256,
42
  hidden_upper: int = 512,
43
  dropout: float = 0.0,
44
+ head_mode: str = "shared3d",
45
+ num_time_steps: int = 7,
46
  ):
47
  super().__init__()
48
  upper_ch = num_upper_vars * num_levels
49
+ self.head_mode = head_mode
50
+ self.num_time_steps = num_time_steps
51
+
52
+ if self.head_mode == "shared3d":
53
+ # Endpoint-conditioned temporal correction:
54
+ # input channels = [endpoint_{t-6}, base_{t+h}, endpoint_t]
55
+ # for each variable channel group.
56
+ self.surface_net = nn.Sequential(
57
+ nn.Conv3d(num_surface_vars * 3, hidden_surface, kernel_size=(3, 1, 1), padding=(1, 0, 0)),
58
+ nn.GELU(),
59
+ nn.Dropout3d(dropout) if dropout > 0 else nn.Identity(),
60
+ nn.Conv3d(hidden_surface, num_surface_vars, kernel_size=(3, 1, 1), padding=(1, 0, 0)),
61
+ )
62
+ self.upper_net = nn.Sequential(
63
+ nn.Conv3d(upper_ch * 3, hidden_upper, kernel_size=(3, 1, 1), padding=(1, 0, 0)),
64
+ nn.GELU(),
65
+ nn.Dropout3d(dropout) if dropout > 0 else nn.Identity(),
66
+ nn.Conv3d(hidden_upper, upper_ch, kernel_size=(3, 1, 1), padding=(1, 0, 0)),
67
+ )
68
+ # Start from interpolation baseline: initial correction ~= 0.
69
+ nn.init.zeros_(self.surface_net[-1].weight)
70
+ nn.init.zeros_(self.surface_net[-1].bias)
71
+ nn.init.zeros_(self.upper_net[-1].weight)
72
+ nn.init.zeros_(self.upper_net[-1].bias)
73
+ elif self.head_mode == "delta2d":
74
+ # Per-time-step heads with endpoint-conditioning.
75
+ # Each head sees [endpoint_{t-6}, base_{t+h}, endpoint_t] context.
76
+ self.surface_heads = nn.ModuleList([
77
+ nn.Sequential(
78
+ nn.Conv2d(num_surface_vars * 3, hidden_surface, kernel_size=3, padding=1),
79
+ nn.GELU(),
80
+ nn.Dropout2d(dropout) if dropout > 0 else nn.Identity(),
81
+ nn.Conv2d(hidden_surface, num_surface_vars, kernel_size=3, padding=1),
82
+ )
83
+ for _ in range(num_time_steps)
84
+ ])
85
+ self.upper_heads = nn.ModuleList([
86
+ nn.Sequential(
87
+ nn.Conv2d(upper_ch * 3, hidden_upper, kernel_size=3, padding=1),
88
+ nn.GELU(),
89
+ nn.Dropout2d(dropout) if dropout > 0 else nn.Identity(),
90
+ nn.Conv2d(hidden_upper, upper_ch, kernel_size=3, padding=1),
91
+ )
92
+ for _ in range(num_time_steps)
93
+ ])
94
+ for head in self.surface_heads:
95
+ nn.init.zeros_(head[-1].weight)
96
+ nn.init.zeros_(head[-1].bias)
97
+ for head in self.upper_heads:
98
+ nn.init.zeros_(head[-1].weight)
99
+ nn.init.zeros_(head[-1].bias)
100
+ else:
101
+ raise ValueError(f"Unsupported temporal head_mode: {self.head_mode}")
102
+
103
+ def forward(self, cond_surface_endpoints, cond_upper_endpoints, base_surface, base_upper):
104
+ # cond_surface_endpoints: [B, 2, V, H, W] (t-6, t) on the current grid
105
+ # cond_upper_endpoints: [B, 2, V, C, H, W]
106
+ # base_surface: [B, T, V, H, W] (U_t baseline for target horizons)
107
+ # base_upper: [B, T, V, C, H, W]
108
+ if cond_surface_endpoints is None or cond_upper_endpoints is None:
109
+ raise ValueError("TemporalCorrectionHead requires endpoint-conditioned inputs.")
110
+ if cond_surface_endpoints.shape[1] < 2 or cond_upper_endpoints.shape[1] < 2:
111
+ raise ValueError("TemporalCorrectionHead expects two endpoint frames in cond inputs.")
112
+
113
+ if self.head_mode == "shared3d":
114
+ b, t, v, h, w = base_surface.shape
115
+ s0 = cond_surface_endpoints[:, 0].unsqueeze(1).expand(-1, t, -1, -1, -1)
116
+ s1 = cond_surface_endpoints[:, -1].unsqueeze(1).expand(-1, t, -1, -1, -1)
117
+ x_surface = torch.cat((s0, base_surface, s1), dim=2)
118
+ x_surface = rearrange(x_surface, "b t v h w -> b v t h w")
119
+ corr_surface = self.surface_net(x_surface)
120
+ corr_surface = rearrange(corr_surface, "b v t h w -> b t v h w")
121
+
122
+ _, _, vu, c, hu, wu = base_upper.shape
123
+ u0 = cond_upper_endpoints[:, 0].unsqueeze(1).expand(-1, t, -1, -1, -1, -1)
124
+ u1 = cond_upper_endpoints[:, -1].unsqueeze(1).expand(-1, t, -1, -1, -1, -1)
125
+ x_upper = torch.cat((u0, base_upper, u1), dim=2)
126
+ x_upper = rearrange(x_upper, "b t v c h w -> b (v c) t h w")
127
+ corr_upper = self.upper_net(x_upper)
128
+ corr_upper = rearrange(
129
+ corr_upper, "b (v c) t h w -> b t v c h w", v=base_upper.shape[2], c=base_upper.shape[3]
130
+ )
131
+ return corr_surface, corr_upper
132
+
133
+ # delta2d mode
134
+ b, t, v, h, w = base_surface.shape
135
+ _, _, vu, c, hu, wu = base_upper.shape
136
+ s0 = cond_surface_endpoints[:, 0]
137
+ s1 = cond_surface_endpoints[:, -1]
138
+ u0 = cond_upper_endpoints[:, 0].reshape(b, vu * c, hu, wu)
139
+ u1 = cond_upper_endpoints[:, -1].reshape(b, vu * c, hu, wu)
140
+
141
+ corr_surface_steps = []
142
+ corr_upper_steps = []
143
+ for ti in range(t):
144
+ hi = min(ti, self.num_time_steps - 1)
145
+ s_cur = base_surface[:, ti]
146
+ s_in = torch.cat((s0, s_cur, s1), dim=1)
147
+ s_corr = self.surface_heads[hi](s_in)
148
+ corr_surface_steps.append(s_corr)
149
+
150
+ u_cur = base_upper[:, ti].reshape(b, vu * c, hu, wu)
151
+ u_in = torch.cat((u0, u_cur, u1), dim=1)
152
+ u_corr = self.upper_heads[hi](u_in).reshape(b, vu, c, hu, wu)
153
+ corr_upper_steps.append(u_corr)
154
+
155
+ corr_surface = torch.stack(corr_surface_steps, dim=1)
156
+ corr_upper = torch.stack(corr_upper_steps, dim=1)
157
  return corr_surface, corr_upper
158
 
159
 
 
226
  return surface_up, upper_up
227
 
228
 
229
+ def _default_temporal_interp(surface_endpoints, upper_endpoints, out_steps, offsets_hours=None):
230
+ """Linear temporal baseline from [t-6, t] to target horizons.
231
+
232
+ By default, this generates a prediction-style baseline for [t+1, ..., t+out_steps]
233
+ using linear extrapolation from the two 6h endpoints.
234
+ """
235
  if surface_endpoints.shape[1] < 2 or upper_endpoints.shape[1] < 2:
236
  raise ValueError("Temporal interpolation needs at least two endpoint frames.")
237
 
238
+ if offsets_hours is None:
239
+ offsets = torch.arange(
240
+ 1, out_steps + 1, device=surface_endpoints.device, dtype=surface_endpoints.dtype
241
+ )
242
+ else:
243
+ offsets = offsets_hours
244
+ if not torch.is_tensor(offsets):
245
+ offsets = torch.as_tensor(offsets)
246
+ if offsets.dim() == 2:
247
+ offsets = offsets[0]
248
+ offsets = offsets.to(device=surface_endpoints.device, dtype=surface_endpoints.dtype).reshape(-1)
249
+ if offsets.numel() != out_steps:
250
+ raise ValueError(
251
+ f"offset length ({offsets.numel()}) must equal out_steps ({out_steps})."
252
+ )
253
+
254
+ x_prev_s = surface_endpoints[:, 0].unsqueeze(1) # at t-6
255
+ x_curr_s = surface_endpoints[:, -1].unsqueeze(1) # at t
256
+ x_prev_u = upper_endpoints[:, 0].unsqueeze(1)
257
+ x_curr_u = upper_endpoints[:, -1].unsqueeze(1)
258
 
259
+ beta_s = (offsets / 6.0).view(1, out_steps, 1, 1, 1)
260
+ beta_u = beta_s.view(1, out_steps, 1, 1, 1, 1)
 
 
261
 
262
+ surface_interp = x_curr_s + beta_s * (x_curr_s - x_prev_s)
263
+ upper_interp = x_curr_u + beta_u * (x_curr_u - x_prev_u)
264
  return surface_interp, upper_interp
265
 
266
 
 
311
  target_hw_6h = batch["y_hr6h_surface"].shape[-2:]
312
  target_hw_1h = batch["y_hr1h_surface"].shape[-2:]
313
  out_steps_1h = batch["y_hr1h_surface"].shape[1]
314
+ target_offsets_hours = batch.get("target_offsets_hours", None)
315
 
316
  # Spatial path at 6h: S(x) = U_s(x) + C_s(x)
317
+ # Use pretrained Aurora 6h prior as base signal at current endpoint (centered prior),
318
+ # not a post-hoc overwrite.
319
  s_base_surface, s_base_upper = spatial_interp_fn(x_lr6h_surface, x_lr6h_upper, target_hw_6h)
320
+ s_base_prior_surface = s_base_surface
321
+ s_base_prior_upper = s_base_upper
322
+ if backbone_s_last_surface is not None:
323
+ s_base_prior_surface = s_base_prior_surface.clone()
324
+ s_base_prior_surface[:, -1] = backbone_s_last_surface
325
+ if backbone_s_last_upper is not None:
326
+ s_base_prior_upper = s_base_prior_upper.clone()
327
+ s_base_prior_upper[:, -1] = backbone_s_last_upper
328
+
329
  pred_s_hr6h_surface, pred_s_hr6h_upper = _apply_correction(
330
  spatial_correction_fn,
331
  x_lr6h_surface,
332
  x_lr6h_upper,
333
+ s_base_prior_surface,
334
+ s_base_prior_upper,
335
  )
 
 
 
 
 
 
 
336
 
337
  # Temporal path at LR: T(x) = U_t(x) + C_t(x)
338
+ t_base_surface, t_base_upper = temporal_interp_fn(
339
+ x_lr6h_surface, x_lr6h_upper, out_steps_1h, offsets_hours=target_offsets_hours
340
+ )
341
  pred_t_lr1h_surface, pred_t_lr1h_upper = _apply_correction(
342
  temporal_correction_fn,
343
  x_lr6h_surface,
 
357
  )
358
 
359
  # ST path: x -> S -> T
360
+ st_base_surface, st_base_upper = temporal_interp_fn(
361
+ pred_s_hr6h_surface, pred_s_hr6h_upper, out_steps_1h, offsets_hours=target_offsets_hours
362
+ )
363
  pred_st_hr1h_surface, pred_st_hr1h_upper = _apply_correction(
364
  temporal_correction_fn,
365
  pred_s_hr6h_surface,
 
371
  return {
372
  "pred_s_hr6h_surface": pred_s_hr6h_surface,
373
  "pred_s_hr6h_upper": pred_s_hr6h_upper,
374
+ "s_base_prior_surface": s_base_prior_surface,
375
+ "s_base_prior_upper": s_base_prior_upper,
376
  "pred_t_lr1h_surface": pred_t_lr1h_surface,
377
  "pred_t_lr1h_upper": pred_t_lr1h_upper,
378
  "pred_ts_hr1h_surface": pred_ts_hr1h_surface,
 
425
  surf_vars=None,
426
  upper_vars=None,
427
  level=None,
428
+ surface_var_weights=None,
429
+ upper_var_weights=None,
430
  ):
431
  """
432
  Match train_one_epoch_downscale style:
 
439
  pred_upper_n = _normalise_upper_bt_vchw(pred_upper, upper_vars, level)
440
  target_upper_n = _normalise_upper_bt_vchw(target_upper, upper_vars, level)
441
 
442
+ if surface_var_weights is None:
443
+ loss_surface = F.mse_loss(pred_surface_n, target_surface_n)
444
+ else:
445
+ ws = torch.as_tensor(
446
+ surface_var_weights,
447
+ device=pred_surface_n.device,
448
+ dtype=pred_surface_n.dtype,
449
+ ).view(1, 1, -1, 1, 1)
450
+ se_s = (pred_surface_n - target_surface_n) ** 2
451
+ loss_surface = (se_s * ws).sum() / (ws.sum() * se_s.shape[0] * se_s.shape[1] * se_s.shape[3] * se_s.shape[4])
452
+
453
+ if upper_var_weights is None:
454
+ loss_upper = F.mse_loss(pred_upper_n, target_upper_n)
455
+ else:
456
+ wu = torch.as_tensor(
457
+ upper_var_weights,
458
+ device=pred_upper_n.device,
459
+ dtype=pred_upper_n.dtype,
460
+ )
461
+ if wu.dim() == 1:
462
+ wu = wu.view(1, 1, -1, 1, 1, 1)
463
+ else:
464
+ wu = wu.view(1, 1, wu.shape[0], wu.shape[1], 1, 1)
465
+ se_u = (pred_upper_n - target_upper_n) ** 2
466
+ loss_upper = (se_u * wu).sum() / (
467
+ wu.sum() * se_u.shape[0] * se_u.shape[1] * se_u.shape[4] * se_u.shape[5]
468
+ )
469
  return 0.25 * loss_surface + loss_upper
470
 
471
 
 
476
  lambda_t=1.0,
477
  lambda_st=1.0,
478
  lambda_comm=1.0,
479
+ lambda_backbone=0.0,
480
  surface_var_weights=None,
481
  upper_var_weights=None,
482
  upper_level_weights=None,
483
  interior_only_1h=False,
484
+ lt_supervision_source="lr_path",
485
+ lt_hybrid_alpha=0.5,
486
  surf_vars=None,
487
  upper_vars=None,
488
  level=None,
 
494
  return x[:, 1:-1]
495
  return x
496
 
497
+ # Backward compatibility: if explicit upper-level weights are provided, override upper_var_weights.
498
+ if upper_level_weights is not None:
499
+ upper_var_weights = upper_level_weights
500
 
501
  # L_S: S(x) supervised by HR-6h
502
  l_s = _downscale_style_term_loss(
 
507
  surf_vars=surf_vars,
508
  upper_vars=upper_vars,
509
  level=level,
510
+ surface_var_weights=surface_var_weights,
511
+ upper_var_weights=upper_var_weights,
512
  )
513
 
514
+ # L_T: temporal module supervision on LR 1h targets.
515
+ # "lr_path" is the strict factorized definition.
516
+ # "st_downscaled"/"hybrid" are optional variants.
517
+ l_t_lr = _downscale_style_term_loss(
518
  _trim_time(preds["pred_t_lr1h_surface"]),
519
  _trim_time(preds["pred_t_lr1h_upper"]),
520
  _trim_time(batch["y_lr1h_surface"]),
 
522
  surf_vars=surf_vars,
523
  upper_vars=upper_vars,
524
  level=level,
525
+ surface_var_weights=surface_var_weights,
526
+ upper_var_weights=upper_var_weights,
527
  )
528
+ has_st_lr = ("pred_st_lr1h_surface" in preds) and ("pred_st_lr1h_upper" in preds)
529
+ if has_st_lr:
530
+ l_t_st = _downscale_style_term_loss(
531
+ _trim_time(preds["pred_st_lr1h_surface"]),
532
+ _trim_time(preds["pred_st_lr1h_upper"]),
533
+ _trim_time(batch["y_lr1h_surface"]),
534
+ _trim_time(batch["y_lr1h_upper"]),
535
+ surf_vars=surf_vars,
536
+ upper_vars=upper_vars,
537
+ level=level,
538
+ surface_var_weights=surface_var_weights,
539
+ upper_var_weights=upper_var_weights,
540
+ )
541
+ else:
542
+ l_t_st = l_t_lr
543
+
544
+ if lt_supervision_source == "st_downscaled":
545
+ l_t = l_t_st
546
+ elif lt_supervision_source == "hybrid":
547
+ a = float(lt_hybrid_alpha)
548
+ a = max(0.0, min(1.0, a))
549
+ l_t = a * l_t_lr + (1.0 - a) * l_t_st
550
+ else:
551
+ l_t = l_t_lr
552
 
553
  # L_ST: ST path supervised by HR-1h
554
  l_st = _downscale_style_term_loss(
 
559
  surf_vars=surf_vars,
560
  upper_vars=upper_vars,
561
  level=level,
562
+ surface_var_weights=surface_var_weights,
563
+ upper_var_weights=upper_var_weights,
564
  )
565
 
566
+ # L_comm: TS and ST consistency (commutativity regularizer)
567
  l_comm = _downscale_style_term_loss(
568
  _trim_time(preds["pred_ts_hr1h_surface"]),
569
  _trim_time(preds["pred_ts_hr1h_upper"]),
 
572
  surf_vars=surf_vars,
573
  upper_vars=upper_vars,
574
  level=level,
575
+ surface_var_weights=surface_var_weights,
576
+ upper_var_weights=upper_var_weights,
577
  )
578
 
579
+ l_backbone = l_comm.new_zeros(())
580
+ if (
581
+ lambda_backbone > 0
582
+ and ("backbone_last_surface" in preds)
583
+ and ("backbone_last_upper" in preds)
584
+ ):
585
+ # Optional prior-consistency regularizer:
586
+ # keep S current-endpoint prediction close to pretrained backbone prior.
587
+ bb_s = preds["backbone_last_surface"].detach().unsqueeze(1)
588
+ bb_u = preds["backbone_last_upper"].detach().unsqueeze(1)
589
+ l_backbone_s = _downscale_style_term_loss(
590
+ preds["pred_s_hr6h_surface"][:, -1:].contiguous(),
591
+ preds["pred_s_hr6h_upper"][:, -1:].contiguous(),
592
+ bb_s,
593
+ bb_u,
594
+ surf_vars=surf_vars,
595
+ upper_vars=upper_vars,
596
+ level=level,
597
+ surface_var_weights=surface_var_weights,
598
+ upper_var_weights=upper_var_weights,
599
+ )
600
+ l_backbone = l_backbone_s
601
+
602
+ total = (
603
+ lambda_s * l_s
604
+ + lambda_t * l_t
605
+ + lambda_st * l_st
606
+ + lambda_comm * l_comm
607
+ + lambda_backbone * l_backbone
608
+ )
609
  loss_dict = {
610
  "loss_total": total,
611
  "loss_s": l_s,
612
  "loss_t": l_t,
613
  "loss_st": l_st,
614
  "loss_comm": l_comm,
615
+ "loss_backbone": l_backbone,
616
  }
617
  return total, loss_dict
618
 
 
890
  return x[..., :h, :w]
891
 
892
 
893
+ def _sanitize_tensor(x, clip_value=1e6):
894
+ x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
895
+ if clip_value is not None and clip_value > 0:
896
+ x = torch.clamp(x, min=-clip_value, max=clip_value)
897
+ return x
898
+
899
+
900
+ def _tensor_health(x):
901
+ finite_mask = torch.isfinite(x)
902
+ finite_ratio = float(finite_mask.float().mean().item())
903
+ x_safe = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
904
+ max_abs = float(x_safe.abs().max().item())
905
+ return finite_ratio, max_abs
906
+
907
+
908
  def _resize_surface_bt_vhw(x, out_hw):
909
  b, t, v, h, w = x.shape
910
  th, tw = out_hw
 
927
  ).reshape(b, t, v, c, th, tw)
928
 
929
 
930
+ def _cell_edges_from_centers(centers_1d: torch.Tensor) -> torch.Tensor:
931
+ centers_1d = centers_1d.float()
932
+ if centers_1d.numel() < 2:
933
+ c = centers_1d[0]
934
+ return torch.stack((c - 0.5, c + 0.5), dim=0)
935
+ mids = 0.5 * (centers_1d[:-1] + centers_1d[1:])
936
+ first = centers_1d[0] - 0.5 * (centers_1d[1] - centers_1d[0])
937
+ last = centers_1d[-1] + 0.5 * (centers_1d[-1] - centers_1d[-2])
938
+ return torch.cat((first[None], mids, last[None]), dim=0)
939
+
940
+
941
+ def _interval_overlap_matrix(src_edges: torch.Tensor, tgt_edges: torch.Tensor) -> torch.Tensor:
942
+ # Returns [N_tgt, N_src] overlap lengths for 1D intervals.
943
+ src_lo = src_edges[:-1].unsqueeze(0) # [1, Ns]
944
+ src_hi = src_edges[1:].unsqueeze(0) # [1, Ns]
945
+ tgt_lo = tgt_edges[:-1].unsqueeze(1) # [Nt, 1]
946
+ tgt_hi = tgt_edges[1:].unsqueeze(1) # [Nt, 1]
947
+ return torch.clamp(torch.minimum(src_hi, tgt_hi) - torch.maximum(src_lo, tgt_lo), min=0.0)
948
+
949
+
950
+ def _maybe_flip_last_dim(x: torch.Tensor, dim: int, need_flip: bool) -> torch.Tensor:
951
+ if need_flip:
952
+ return torch.flip(x, dims=(dim,))
953
+ return x
954
+
955
+
956
def _conservative_remap_2d(
    x: torch.Tensor,
    lat_src: torch.Tensor,
    lon_src: torch.Tensor,
    lat_tgt: torch.Tensor,
    lon_tgt: torch.Tensor,
    eps: float = 1e-12,
) -> torch.Tensor:
    """
    First-order conservative regridding for tensors with trailing [..., H, W].

    Cell weights are 1D interval overlaps computed in (sin(lat), lon)
    coordinates, which is proportional to spherical cell area. Ascending vs
    descending axis orientation is normalized internally; the target grid's
    original orientation is restored before returning.
    # NOTE(review): longitude wrap-around (e.g. 359° vs 0°) is not handled —
    # assumes both grids share a common, non-wrapping longitude span.
    """

    def _descending(axis: torch.Tensor) -> bool:
        return bool((axis[0] > axis[-1]).item())

    lat_s_desc = _descending(lat_src)
    lon_s_desc = _descending(lon_src)
    lat_t_desc = _descending(lat_tgt)
    lon_t_desc = _descending(lon_tgt)

    # Work on ascending axes throughout.
    data = _maybe_flip_last_dim(_maybe_flip_last_dim(x, -2, lat_s_desc), -1, lon_s_desc)
    edges_lat_s = _cell_edges_from_centers(
        _maybe_flip_last_dim(lat_src, 0, lat_s_desc)
    ).clamp(-90.0, 90.0)
    edges_lat_t = _cell_edges_from_centers(
        _maybe_flip_last_dim(lat_tgt, 0, lat_t_desc)
    ).clamp(-90.0, 90.0)
    edges_lon_s = _cell_edges_from_centers(_maybe_flip_last_dim(lon_src, 0, lon_s_desc))
    edges_lon_t = _cell_edges_from_centers(_maybe_flip_last_dim(lon_tgt, 0, lon_t_desc))

    # Spherical area element is proportional to d(lon) * d(sin(lat)).
    sin_lat_s = torch.sin(torch.deg2rad(edges_lat_s))
    sin_lat_t = torch.sin(torch.deg2rad(edges_lat_t))
    w_lat = _interval_overlap_matrix(sin_lat_s, sin_lat_t)      # [Ht, Hs]
    w_lon = _interval_overlap_matrix(edges_lon_s, edges_lon_t)  # [Wt, Ws]

    # Contract source H then source W against the overlap weights.
    partial = torch.einsum("...hw,ih->...iw", data, w_lat)
    weighted_sum = torch.einsum("...iw,jw->...ij", partial, w_lon)

    # Normalize by target-cell measure; clamp guards degenerate (zero-width) cells.
    height_t = (sin_lat_t[1:] - sin_lat_t[:-1]).clamp_min(eps)   # [Ht]
    width_t = (edges_lon_t[1:] - edges_lon_t[:-1]).clamp_min(eps)  # [Wt]
    remapped = weighted_sum / (height_t[:, None] * width_t[None, :])

    # Restore the caller's target-grid orientation.
    remapped = _maybe_flip_last_dim(remapped, -2, lat_t_desc)
    return _maybe_flip_last_dim(remapped, -1, lon_t_desc)
1003
+
1004
+
1005
def _conservative_downscale_surface_bt_vhw(x, lat_hr, lon_hr, lat_lr, lon_lr):
    """Conservatively regrid a surface tensor [B, T, V, H, W] from HR to LR grid."""
    b, t, v, h, w = x.shape
    flat = x.reshape(b * t * v, h, w)
    remapped = _conservative_remap_2d(
        flat,
        lat_src=lat_hr,
        lon_src=lon_hr,
        lat_tgt=lat_lr,
        lon_tgt=lon_lr,
    )
    return remapped.reshape(b, t, v, lat_lr.numel(), lon_lr.numel())
1016
+
1017
+
1018
def _conservative_downscale_upper_bt_vchw(x, lat_hr, lon_hr, lat_lr, lon_lr):
    """Conservatively regrid an upper-air tensor [B, T, V, C, H, W] from HR to LR grid."""
    b, t, v, c, h, w = x.shape
    flat = x.reshape(b * t * v * c, h, w)
    remapped = _conservative_remap_2d(
        flat,
        lat_src=lat_hr,
        lon_src=lon_hr,
        lat_tgt=lat_lr,
        lon_tgt=lon_lr,
    )
    return remapped.reshape(b, t, v, c, lat_lr.numel(), lon_lr.numel())
1029
+
1030
+
1031
+ def _prepare_factorized_batch(
1032
+ batch,
1033
+ hr_hw,
1034
+ lr_hw,
1035
+ device,
1036
+ lat_hr=None,
1037
+ lon_hr=None,
1038
+ lat_lr=None,
1039
+ lon_lr=None,
1040
+ ):
1041
  h_hr, w_hr = hr_hw
1042
  h_lr, w_lr = lr_hw
1043
 
1044
  def _to_device(x, h, w):
1045
+ x = _trim_to_hw(x.float(), h, w).to(device, non_blocking=True)
1046
+ return _sanitize_tensor(x)
1047
 
1048
  out = {}
1049
  out["x_lr6h_surface"] = _to_device(batch["x_lr6h_surface"], h_lr, w_lr)
 
1070
  out["y_lr1h_surface"] = _to_device(batch["y_lr1h_surface"], h_lr, w_lr)
1071
  out["y_lr1h_upper"] = _to_device(batch["y_lr1h_upper"], h_lr, w_lr)
1072
  else:
1073
+ use_conservative = (
1074
+ (lat_hr is not None) and (lon_hr is not None) and
1075
+ (lat_lr is not None) and (lon_lr is not None)
1076
+ )
1077
+ if use_conservative:
1078
+ lat_hr_d = lat_hr.to(device=device, dtype=out["y_hr1h_surface"].dtype)
1079
+ lon_hr_d = lon_hr.to(device=device, dtype=out["y_hr1h_surface"].dtype)
1080
+ lat_lr_d = lat_lr.to(device=device, dtype=out["y_hr1h_surface"].dtype)
1081
+ lon_lr_d = lon_lr.to(device=device, dtype=out["y_hr1h_surface"].dtype)
1082
+ out["y_lr1h_surface"] = _conservative_downscale_surface_bt_vhw(
1083
+ out["y_hr1h_surface"], lat_hr_d, lon_hr_d, lat_lr_d, lon_lr_d
1084
+ )
1085
+ out["y_lr1h_upper"] = _conservative_downscale_upper_bt_vchw(
1086
+ out["y_hr1h_upper"], lat_hr_d, lon_hr_d, lat_lr_d, lon_lr_d
1087
+ )
1088
+ else:
1089
+ out["y_lr1h_surface"] = _resize_surface_bt_vhw(out["y_hr1h_surface"], (h_lr, w_lr))
1090
+ out["y_lr1h_upper"] = _resize_upper_bt_vchw(out["y_hr1h_upper"], (h_lr, w_lr))
1091
+
1092
+ if "target_offsets_hours" in batch:
1093
+ offsets = batch["target_offsets_hours"]
1094
+ if not torch.is_tensor(offsets):
1095
+ offsets = torch.as_tensor(offsets)
1096
+ out["target_offsets_hours"] = offsets.to(device=device, non_blocking=True).long()
1097
 
1098
  out["time_unix"] = batch["time_unix"]
1099
  return out
 
1202
  upper_vars=None,
1203
  level=None,
1204
  correction_scale=1.0,
1205
+ fix_endpoints=True,
1206
+ target_offsets_hours=None,
1207
  ):
1208
  if temporal_corrector is None:
1209
  def _temporal_correction_zeros(input_surface, input_upper, base_surface, base_upper):
 
1212
  return _temporal_correction_zeros
1213
 
1214
  def _temporal_correction(input_surface, input_upper, base_surface, base_upper):
1215
+ cond_surface_n = _normalise_surface_bt_vhw(input_surface, surf_vars)
1216
+ cond_upper_n = _normalise_upper_bt_vchw(input_upper, upper_vars, level)
1217
  base_surface_n = _normalise_surface_bt_vhw(base_surface, surf_vars)
1218
  base_upper_n = _normalise_upper_bt_vchw(base_upper, upper_vars, level)
1219
+ corr_surface_n, corr_upper_n = temporal_corrector(
1220
+ cond_surface_n, cond_upper_n, base_surface_n, base_upper_n
1221
+ )
1222
 
1223
  corr_surface_n = torch.nan_to_num(corr_surface_n, nan=0.0, posinf=0.0, neginf=0.0)
1224
  corr_upper_n = torch.nan_to_num(corr_upper_n, nan=0.0, posinf=0.0, neginf=0.0)
1225
+ if fix_endpoints and corr_surface_n.shape[1] >= 1 and corr_upper_n.shape[1] >= 1:
1226
+ corr_surface_n = corr_surface_n.clone()
1227
+ corr_upper_n = corr_upper_n.clone()
1228
+ fix_indices = []
1229
+ if target_offsets_hours is None:
1230
+ # Backward-compatibility fallback (legacy setting with endpoint-including sequence).
1231
+ fix_indices = [0, corr_surface_n.shape[1] - 1] if corr_surface_n.shape[1] >= 2 else [0]
1232
+ else:
1233
+ offs = target_offsets_hours
1234
+ if not torch.is_tensor(offs):
1235
+ offs = torch.as_tensor(offs)
1236
+ if offs.dim() == 2:
1237
+ offs = offs[0]
1238
+ offs = offs.to(device=corr_surface_n.device).reshape(-1).long()
1239
+ if offs.numel() == corr_surface_n.shape[1]:
1240
+ fix_indices = torch.nonzero(offs == 0, as_tuple=False).view(-1).tolist()
1241
+ for idx in fix_indices:
1242
+ corr_surface_n[:, idx] = 0.0
1243
+ corr_upper_n[:, idx] = 0.0
1244
 
1245
  s_scale = _surface_scale_tensor(surf_vars, base_surface)
1246
  u_scale = _upper_scale_tensor(upper_vars, level, base_upper)
 
1278
  lambda_t=1.0,
1279
  lambda_st=1.0,
1280
  lambda_comm=1.0,
1281
+ lambda_backbone=0.0,
1282
  surface_var_weights=None,
1283
  upper_var_weights=None,
1284
  upper_level_weights=None,
1285
  interior_only_1h=False,
1286
+ lt_supervision_source="lr_path",
1287
+ lt_hybrid_alpha=0.5,
1288
  spatial_corrector=None,
1289
  temporal_corrector=None,
1290
  spatial_correction_scale=1.0,
1291
  temporal_correction_scale=1.0,
1292
+ temporal_fix_endpoints=True,
1293
  enable_backbone_lora=False,
1294
  accum_backward_mode="micro",
1295
  use_ours=False,
 
1317
  backward_calls = 0
1318
  pending_loss = None
1319
 
1320
+ if torch.cuda.is_available():
1321
+ torch.cuda.reset_peak_memory_stats()
1322
+
1323
  for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
1324
  step = data_iter_step // update_freq
1325
  if step >= num_training_steps_per_epoch:
 
1339
  hr_hw=hr_hw,
1340
  lr_hw=lr_hw,
1341
  device=device,
1342
+ lat_hr=lat_hr,
1343
+ lon_hr=lon_hr,
1344
+ lat_lr=lat_lr,
1345
+ lon_lr=lon_lr,
1346
  )
1347
 
1348
  spatial_correction_fn = _build_spatial_correction_fn(
 
1358
  upper_vars=upper_vars,
1359
  level=level,
1360
  correction_scale=temporal_correction_scale,
1361
+ fix_endpoints=temporal_fix_endpoints,
1362
+ target_offsets_hours=batch_f.get("target_offsets_hours", None),
1363
  )
1364
  backbone_last_surface = None
1365
  backbone_last_upper = None
 
1382
  backbone_s_last_surface=backbone_last_surface,
1383
  backbone_s_last_upper=backbone_last_upper,
1384
  )
1385
+ if lt_supervision_source in {"st_downscaled", "hybrid"}:
1386
+ lat_hr_d = lat_hr.to(device=preds["pred_st_hr1h_surface"].device, dtype=preds["pred_st_hr1h_surface"].dtype)
1387
+ lon_hr_d = lon_hr.to(device=preds["pred_st_hr1h_surface"].device, dtype=preds["pred_st_hr1h_surface"].dtype)
1388
+ lat_lr_d = lat_lr.to(device=preds["pred_st_hr1h_surface"].device, dtype=preds["pred_st_hr1h_surface"].dtype)
1389
+ lon_lr_d = lon_lr.to(device=preds["pred_st_hr1h_surface"].device, dtype=preds["pred_st_hr1h_surface"].dtype)
1390
+ preds["pred_st_lr1h_surface"] = _conservative_downscale_surface_bt_vhw(
1391
+ preds["pred_st_hr1h_surface"], lat_hr_d, lon_hr_d, lat_lr_d, lon_lr_d
1392
+ )
1393
+ preds["pred_st_lr1h_upper"] = _conservative_downscale_upper_bt_vchw(
1394
+ preds["pred_st_hr1h_upper"], lat_hr_d, lon_hr_d, lat_lr_d, lon_lr_d
1395
+ )
1396
+ if backbone_last_surface is not None and backbone_last_upper is not None:
1397
+ preds["backbone_last_surface"] = backbone_last_surface
1398
+ preds["backbone_last_upper"] = backbone_last_upper
1399
  loss, loss_dict = loss_fn(
1400
  preds=preds,
1401
  batch=batch_f,
 
1403
  lambda_t=lambda_t,
1404
  lambda_st=lambda_st,
1405
  lambda_comm=lambda_comm,
1406
+ lambda_backbone=lambda_backbone,
1407
  surface_var_weights=surface_var_weights,
1408
  upper_var_weights=upper_var_weights,
1409
  upper_level_weights=upper_level_weights,
1410
  interior_only_1h=interior_only_1h,
1411
+ lt_supervision_source=lt_supervision_source,
1412
+ lt_hybrid_alpha=lt_hybrid_alpha,
1413
  surf_vars=surf_vars,
1414
  upper_vars=upper_vars,
1415
  level=level,
1416
  )
 
1417
  loss_value = loss.item()
1418
  if not np.isfinite(loss_value):
1419
  print(f"[factorized] non-finite loss at iter {data_iter_step}; skip update.")
1420
+ try:
1421
+ for k in ("x_lr6h_surface", "x_lr6h_upper", "y_hr1h_surface", "y_hr1h_upper"):
1422
+ fr, ma = _tensor_health(batch_f[k])
1423
+ print(f"[factorized][health] {k}: finite_ratio={fr:.6f}, max_abs={ma:.3e}")
1424
+ for k in ("pred_s_hr6h_surface", "pred_t_lr1h_surface", "pred_st_hr1h_surface"):
1425
+ fr, ma = _tensor_health(preds[k])
1426
+ print(f"[factorized][health] {k}: finite_ratio={fr:.6f}, max_abs={ma:.3e}")
1427
+ except Exception as health_err:
1428
+ print(f"[factorized][health] failed to collect tensor health: {health_err}")
1429
  optimizer.zero_grad()
1430
  pending_loss = None
1431
  continue
 
1478
  metric_logger.update(loss_t=loss_dict["loss_t"].item())
1479
  metric_logger.update(loss_st=loss_dict["loss_st"].item())
1480
  metric_logger.update(loss_comm=loss_dict["loss_comm"].item())
1481
+ metric_logger.update(loss_backbone=loss_dict["loss_backbone"].item())
1482
  metric_logger.update(loss_scale=loss_scale_value)
1483
 
1484
  min_lr = 10.0
 
1505
  log_writer.update(loss_t=loss_dict["loss_t"].item(), head="loss")
1506
  log_writer.update(loss_st=loss_dict["loss_st"].item(), head="loss")
1507
  log_writer.update(loss_comm=loss_dict["loss_comm"].item(), head="loss")
1508
+ log_writer.update(loss_backbone=loss_dict["loss_backbone"].item(), head="loss")
1509
  log_writer.update(loss_scale=loss_scale_value, head="opt")
1510
  log_writer.update(lr=max_lr, head="opt")
1511
  log_writer.update(min_lr=min_lr, head="opt")
 
1542
  lambda_t=1.0,
1543
  lambda_st=1.0,
1544
  lambda_comm=1.0,
1545
+ lambda_backbone=0.0,
1546
  surface_var_weights=None,
1547
  upper_var_weights=None,
1548
  upper_level_weights=None,
1549
  interior_only_1h=False,
1550
+ lt_supervision_source="lr_path",
1551
+ lt_hybrid_alpha=0.5,
1552
  spatial_corrector=None,
1553
  temporal_corrector=None,
1554
  spatial_correction_scale=1.0,
1555
  temporal_correction_scale=1.0,
1556
+ temporal_fix_endpoints=True,
1557
  enable_backbone_lora=False,
1558
  print_freq=20,
1559
  val_max_steps=-1,
 
1562
  metric_logger.add_meter("prep_s", utils.SmoothedValue(window_size=20, fmt="{avg:.3f}"))
1563
  metric_logger.add_meter("model_s", utils.SmoothedValue(window_size=20, fmt="{avg:.3f}"))
1564
  metric_logger.add_meter("post_s", utils.SmoothedValue(window_size=20, fmt="{avg:.3f}"))
1565
+ metric_logger.add_meter("loss_backbone", utils.SmoothedValue(window_size=20, fmt="{avg:.3f}"))
1566
  header = "Val:"
1567
 
1568
  model.eval()
 
1580
  level=level,
1581
  correction_scale=spatial_correction_scale,
1582
  )
 
 
 
 
 
 
 
 
1583
  warned_short_seq = False
1584
  for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
1585
  if val_max_steps is not None and val_max_steps > 0 and data_iter_step >= val_max_steps:
 
1590
  hr_hw=hr_hw,
1591
  lr_hw=lr_hw,
1592
  device=device,
1593
+ lat_hr=lat_hr,
1594
+ lon_hr=lon_hr,
1595
+ lat_lr=lat_lr,
1596
+ lon_lr=lon_lr,
1597
  )
1598
  if torch.cuda.is_available():
1599
  torch.cuda.synchronize()
1600
  t1 = time.perf_counter()
1601
 
1602
  with torch.inference_mode():
1603
+ temporal_correction_fn = _build_temporal_correction_fn(
1604
+ temporal_corrector=temporal_corrector,
1605
+ surf_vars=surf_vars,
1606
+ upper_vars=upper_vars,
1607
+ level=level,
1608
+ correction_scale=temporal_correction_scale,
1609
+ fix_endpoints=temporal_fix_endpoints,
1610
+ target_offsets_hours=batch_f.get("target_offsets_hours", None),
1611
+ )
1612
  backbone_last_surface = None
1613
  backbone_last_upper = None
1614
  if enable_backbone_lora:
 
1629
  backbone_s_last_surface=backbone_last_surface,
1630
  backbone_s_last_upper=backbone_last_upper,
1631
  )
1632
+ if lt_supervision_source in {"st_downscaled", "hybrid"}:
1633
+ lat_hr_d = lat_hr.to(device=preds["pred_st_hr1h_surface"].device, dtype=preds["pred_st_hr1h_surface"].dtype)
1634
+ lon_hr_d = lon_hr.to(device=preds["pred_st_hr1h_surface"].device, dtype=preds["pred_st_hr1h_surface"].dtype)
1635
+ lat_lr_d = lat_lr.to(device=preds["pred_st_hr1h_surface"].device, dtype=preds["pred_st_hr1h_surface"].dtype)
1636
+ lon_lr_d = lon_lr.to(device=preds["pred_st_hr1h_surface"].device, dtype=preds["pred_st_hr1h_surface"].dtype)
1637
+ preds["pred_st_lr1h_surface"] = _conservative_downscale_surface_bt_vhw(
1638
+ preds["pred_st_hr1h_surface"], lat_hr_d, lon_hr_d, lat_lr_d, lon_lr_d
1639
+ )
1640
+ preds["pred_st_lr1h_upper"] = _conservative_downscale_upper_bt_vchw(
1641
+ preds["pred_st_hr1h_upper"], lat_hr_d, lon_hr_d, lat_lr_d, lon_lr_d
1642
+ )
1643
+ if backbone_last_surface is not None and backbone_last_upper is not None:
1644
+ preds["backbone_last_surface"] = backbone_last_surface
1645
+ preds["backbone_last_upper"] = backbone_last_upper
1646
  loss, loss_dict = loss_fn(
1647
  preds=preds,
1648
  batch=batch_f,
 
1650
  lambda_t=lambda_t,
1651
  lambda_st=lambda_st,
1652
  lambda_comm=lambda_comm,
1653
+ lambda_backbone=lambda_backbone,
1654
  surface_var_weights=surface_var_weights,
1655
  upper_var_weights=upper_var_weights,
1656
  upper_level_weights=upper_level_weights,
1657
  interior_only_1h=interior_only_1h,
1658
+ lt_supervision_source=lt_supervision_source,
1659
+ lt_hybrid_alpha=lt_hybrid_alpha,
1660
  surf_vars=surf_vars,
1661
  upper_vars=upper_vars,
1662
  level=level,
 
1670
  target_seq_surface = batch_f["y_hr1h_surface"]
1671
  target_seq_upper = batch_f["y_hr1h_upper"]
1672
 
1673
+ # Default setting is include_endpoints=False -> sequence is [t+1, ..., t+6] (len=6).
1674
+ # With include_endpoints=True, sequence is [t+0, t+1, ..., t+6] (len=7).
1675
  # We expose per-delta metrics for t+1..t+6 when available.
1676
  seq_len = pred_seq_surface.shape[1]
1677
  if seq_len >= 7:
 
1713
  metric_logger.update(loss_t=loss_dict["loss_t"].item())
1714
  metric_logger.update(loss_st=loss_dict["loss_st"].item())
1715
  metric_logger.update(loss_comm=loss_dict["loss_comm"].item())
1716
+ metric_logger.update(loss_backbone=loss_dict["loss_backbone"].item())
1717
  metric_logger.update(prep_s=t1 - t0)
1718
  metric_logger.update(model_s=t2 - t1)
1719
  metric_logger.update(post_s=t3 - t2)
downscaling/output/WeatherPEFT/checkpoint-19.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4df5855ef1cd3b158eb5b2724fb4e3486ee2e543fe0cb799f5b88f72be3cfca1
3
+ size 5062400227
downscaling/output/WeatherPEFT/checkpoint-29.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f90edeba59a2bc62285624d7fb5c532e535adca27c59d756c22234660a8230c
3
+ size 5062400227
downscaling/output/WeatherPEFT/checkpoint-9.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78f21d453f4d85be909ec43cb4876e021acafd757e3668fc0617732a2eef4bd5
3
+ size 5062398726
downscaling/output/WeatherPEFT/log.txt ADDED
The diff for this file is too large to render. See raw diff
 
downscaling/output/backbone_anchor_smoke/log.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ "Eval:"
2
+ {"val_prep_s": 0.07330790814012289, "val_model_s": 0.5823051929473877, "val_post_s": 0.010655493009835482, "val_valid_loss": 0.6934626698493958, "val_loss_s": 0.2821418046951294, "val_loss_t": 0.05917414277791977, "val_loss_st": 0.238005131483078, "val_loss_comm": 0.11414159834384918, "val_loss_backbone": 0.0, "val_rmse_2t_tplus1h": 2.2092418670654297, "val_mean_bias_2t_tplus1h": 0.01837158203125, "val_rmse_10u_tplus1h": 2.0027098655700684, "val_mean_bias_10u_tplus1h": -0.045052409172058105, "val_rmse_10v_tplus1h": 2.0557847023010254, "val_mean_bias_10v_tplus1h": 0.08185440301895142, "val_rmse_t850_tplus1h": 1.5737048387527466, "val_mean_bias_t850_tplus1h": 0.1566162109375, "val_rmse_z500_tplus1h": 235.0081024169922, "val_mean_bias_z500_tplus1h": 13.34765625, "val_rmse_2t_tplus2h": 2.311091661453247, "val_mean_bias_2t_tplus2h": -0.047149658203125, "val_rmse_10u_tplus2h": 2.2796661853790283, "val_mean_bias_10u_tplus2h": -0.10876637697219849, "val_rmse_10v_tplus2h": 2.4048244953155518, "val_mean_bias_10v_tplus2h": 0.15996196866035461, "val_rmse_t850_tplus2h": 1.8633790016174316, "val_mean_bias_t850_tplus2h": 0.289154052734375, "val_rmse_z500_tplus2h": 294.5515441894531, "val_mean_bias_z500_tplus2h": 25.08203125, "val_rmse_2t_tplus3h": 2.514113426208496, "val_mean_bias_2t_tplus3h": -0.05084228515625, "val_rmse_10u_tplus3h": 2.639458417892456, "val_mean_bias_10u_tplus3h": -0.17488084733486176, "val_rmse_10v_tplus3h": 2.870990514755249, "val_mean_bias_10v_tplus3h": 0.233941450715065, "val_rmse_t850_tplus3h": 2.23453950881958, "val_mean_bias_t850_tplus3h": 0.416259765625, "val_rmse_z500_tplus3h": 366.6177673339844, "val_mean_bias_z500_tplus3h": 35.8828125, "val_rmse_2t_tplus4h": 2.748657464981079, "val_mean_bias_2t_tplus4h": -0.073028564453125, "val_rmse_10u_tplus4h": 3.0420165061950684, "val_mean_bias_10u_tplus4h": -0.24148303270339966, "val_rmse_10v_tplus4h": 3.3935253620147705, "val_mean_bias_10v_tplus4h": 0.29964739084243774, "val_rmse_t850_tplus4h": 
2.6447532176971436, "val_mean_bias_t850_tplus4h": 0.5400390625, "val_rmse_z500_tplus4h": 444.1191101074219, "val_mean_bias_z500_tplus4h": 46.0546875, "val_rmse_2t_tplus5h": 2.8865745067596436, "val_mean_bias_2t_tplus5h": -0.24896240234375, "val_rmse_10u_tplus5h": 3.4593324661254883, "val_mean_bias_10u_tplus5h": -0.30440056324005127, "val_rmse_10v_tplus5h": 3.940408706665039, "val_mean_bias_10v_tplus5h": 0.3687131106853485, "val_rmse_t850_tplus5h": 3.0733869075775146, "val_mean_bias_t850_tplus5h": 0.66290283203125, "val_rmse_z500_tplus5h": 522.9864501953125, "val_mean_bias_z500_tplus5h": 55.27734375, "val_rmse_2t_tplus6h": 3.113255262374878, "val_mean_bias_2t_tplus6h": -0.284454345703125, "val_rmse_10u_tplus6h": 3.8806846141815186, "val_mean_bias_10u_tplus6h": -0.3701193332672119, "val_rmse_10v_tplus6h": 4.501622200012207, "val_mean_bias_10v_tplus6h": 0.4297124147415161, "val_rmse_t850_tplus6h": 3.5143191814422607, "val_mean_bias_t850_tplus6h": 0.78155517578125, "val_rmse_z500_tplus6h": 600.989013671875, "val_mean_bias_z500_tplus6h": 63.31640625, "val_rmse_2t": 3.113255262374878, "val_mean_bias_2t": -0.284454345703125, "val_rmse_10u": 3.8806846141815186, "val_mean_bias_10u": -0.3701193332672119, "val_rmse_10v": 4.501622200012207, "val_mean_bias_10v": 0.4297124147415161, "val_rmse_t850": 3.5143191814422607, "val_mean_bias_t850": 0.78155517578125, "val_rmse_z500": 600.989013671875, "val_mean_bias_z500": 63.31640625}
3
+ " &0.073 &0.582 &0.011 &0.282 &0.059 &0.238 &0.114 &0.0 &2.209 &0.018 &2.003 &-0.045 &2.056 &0.082 &1.574 &0.157 &235.008 &13.348 &2.311 &-0.047 &2.28 &-0.109 &2.405 &0.16 &1.863 &0.289 &294.552 &25.082 &2.514 &-0.051 &2.639 &-0.175 &2.871 &0.234 &2.235 &0.416 &366.618 &35.883 &2.749 &-0.073 &3.042 &-0.241 &3.394 &0.3 &2.645 &0.54 &444.119 &46.055 &2.887 &-0.249 &3.459 &-0.304 &3.94 &0.369 &3.073 &0.663 &522.986 &55.277 &3.113 &-0.284 &3.881 &-0.37 &4.502 &0.43 &3.514 &0.782 &600.989 &63.316 &3.113 &-0.284 &3.881 &-0.37 &4.502 &0.43 &3.514 &0.782 &600.989 &63.316"
4
+
downscaling/output/val_freq_smoke/log.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ "Eval:"
2
+ {"val_prep_s": 0.05949366791173816, "val_model_s": 0.3861960484646261, "val_post_s": 0.005747961113229394, "val_valid_loss": 0.6520180404186249, "val_loss_s": 0.2656427472829819, "val_loss_t": 0.05454788729548454, "val_loss_st": 0.22411399334669113, "val_loss_comm": 0.10771343111991882, "val_rmse_2t_tplus1h": 2.1595544815063477, "val_mean_bias_2t_tplus1h": 0.0153045654296875, "val_rmse_10u_tplus1h": 1.9929717779159546, "val_mean_bias_10u_tplus1h": -0.06720291078090668, "val_rmse_10v_tplus1h": 2.0678257942199707, "val_mean_bias_10v_tplus1h": 0.04859817773103714, "val_rmse_t850_tplus1h": 1.5690162181854248, "val_mean_bias_t850_tplus1h": 0.1692962646484375, "val_rmse_z500_tplus1h": 229.7540283203125, "val_mean_bias_z500_tplus1h": 11.56640625, "val_rmse_2t_tplus2h": 2.325566053390503, "val_mean_bias_2t_tplus2h": -0.111175537109375, "val_rmse_10u_tplus2h": 2.259079933166504, "val_mean_bias_10u_tplus2h": -0.12979388236999512, "val_rmse_10v_tplus2h": 2.4088897705078125, "val_mean_bias_10v_tplus2h": 0.0989474430680275, "val_rmse_t850_tplus2h": 1.8684790134429932, "val_mean_bias_t850_tplus2h": 0.3101043701171875, "val_rmse_z500_tplus2h": 285.2650146484375, "val_mean_bias_z500_tplus2h": 23.6953125, "val_rmse_2t_tplus3h": 2.559150457382202, "val_mean_bias_2t_tplus3h": -0.1783599853515625, "val_rmse_10u_tplus3h": 2.5984840393066406, "val_mean_bias_10u_tplus3h": -0.1920362114906311, "val_rmse_10v_tplus3h": 2.8546905517578125, "val_mean_bias_10v_tplus3h": 0.15115559101104736, "val_rmse_t850_tplus3h": 2.2511355876922607, "val_mean_bias_t850_tplus3h": 0.44720458984375, "val_rmse_z500_tplus3h": 353.3114013671875, "val_mean_bias_z500_tplus3h": 34.533203125, "val_rmse_2t_tplus4h": 2.829373598098755, "val_mean_bias_2t_tplus4h": -0.268890380859375, "val_rmse_10u_tplus4h": 2.954393148422241, "val_mean_bias_10u_tplus4h": -0.2777942419052124, "val_rmse_10v_tplus4h": 3.3426637649536133, "val_mean_bias_10v_tplus4h": 0.18345022201538086, "val_rmse_t850_tplus4h": 2.6910040378570557, 
"val_mean_bias_t850_tplus4h": 0.591400146484375, "val_rmse_z500_tplus4h": 428.23114013671875, "val_mean_bias_z500_tplus4h": 46.494140625, "val_rmse_2t_tplus5h": 2.988680839538574, "val_mean_bias_2t_tplus5h": -0.3951263427734375, "val_rmse_10u_tplus5h": 3.360450506210327, "val_mean_bias_10u_tplus5h": -0.33785462379455566, "val_rmse_10v_tplus5h": 3.8764631748199463, "val_mean_bias_10v_tplus5h": 0.23951280117034912, "val_rmse_t850_tplus5h": 3.136207103729248, "val_mean_bias_t850_tplus5h": 0.7298583984375, "val_rmse_z500_tplus5h": 503.41949462890625, "val_mean_bias_z500_tplus5h": 58.720703125, "val_rmse_2t_tplus6h": 3.1790788173675537, "val_mean_bias_2t_tplus6h": -0.4730224609375, "val_rmse_10u_tplus6h": 3.7709269523620605, "val_mean_bias_10u_tplus6h": -0.39494484663009644, "val_rmse_10v_tplus6h": 4.418558120727539, "val_mean_bias_10v_tplus6h": 0.29149335622787476, "val_rmse_t850_tplus6h": 3.5930073261260986, "val_mean_bias_t850_tplus6h": 0.868896484375, "val_rmse_z500_tplus6h": 578.9581298828125, "val_mean_bias_z500_tplus6h": 69.849609375, "val_rmse_2t": 3.1790788173675537, "val_mean_bias_2t": -0.4730224609375, "val_rmse_10u": 3.7709269523620605, "val_mean_bias_10u": -0.39494484663009644, "val_rmse_10v": 4.418558120727539, "val_mean_bias_10v": 0.29149335622787476, "val_rmse_t850": 3.5930073261260986, "val_mean_bias_t850": 0.868896484375, "val_rmse_z500": 578.9581298828125, "val_mean_bias_z500": 69.849609375}
3
+ " &0.059 &0.386 &0.006 &0.266 &0.055 &0.224 &0.108 &2.16 &0.015 &1.993 &-0.067 &2.068 &0.049 &1.569 &0.169 &229.754 &11.566 &2.326 &-0.111 &2.259 &-0.13 &2.409 &0.099 &1.868 &0.31 &285.265 &23.695 &2.559 &-0.178 &2.598 &-0.192 &2.855 &0.151 &2.251 &0.447 &353.311 &34.533 &2.829 &-0.269 &2.954 &-0.278 &3.343 &0.183 &2.691 &0.591 &428.231 &46.494 &2.989 &-0.395 &3.36 &-0.338 &3.876 &0.24 &3.136 &0.73 &503.419 &58.721 &3.179 &-0.473 &3.771 &-0.395 &4.419 &0.291 &3.593 &0.869 &578.958 &69.85 &3.179 &-0.473 &3.771 &-0.395 &4.419 &0.291 &3.593 &0.869 &578.958 &69.85"
4
+
downscaling/output/val_speed_after/log.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ "Eval:"
2
+ {"val_prep_s": 0.084631056097957, "val_model_s": 0.6615325791062787, "val_post_s": 0.0011346005291367571, "val_sample_s": 0.7511531277947748, "val_valid_loss": 0.6753831406434377, "val_loss_s": 0.27440425256888074, "val_loss_t": 0.05877547711133957, "val_loss_st": 0.23192234834035239, "val_loss_comm": 0.11028106758991878, "val_rmse_2t_tplus1h": 2.120826482772827, "val_mean_bias_2t_tplus1h": -0.08579571545124054, "val_rmse_10u_tplus1h": 2.0343947410583496, "val_mean_bias_10u_tplus1h": -0.044072579592466354, "val_rmse_10v_tplus1h": 2.0728132724761963, "val_mean_bias_10v_tplus1h": 0.05038035660982132, "val_rmse_t850_tplus1h": 1.610738754272461, "val_mean_bias_t850_tplus1h": 0.1519871950149536, "val_rmse_z500_tplus1h": 237.20074462890625, "val_mean_bias_z500_tplus1h": 9.716571807861328, "val_rmse_2t_tplus2h": 2.293344020843506, "val_mean_bias_2t_tplus2h": -0.18076401948928833, "val_rmse_10u_tplus2h": 2.3025588989257812, "val_mean_bias_10u_tplus2h": -0.08629271388053894, "val_rmse_10v_tplus2h": 2.409632444381714, "val_mean_bias_10v_tplus2h": 0.102354496717453, "val_rmse_t850_tplus2h": 1.9087270498275757, "val_mean_bias_t850_tplus2h": 0.2832726240158081, "val_rmse_z500_tplus2h": 296.6208801269531, "val_mean_bias_z500_tplus2h": 20.60210418701172, "val_rmse_2t_tplus3h": 2.481692314147949, "val_mean_bias_2t_tplus3h": -0.2545323371887207, "val_rmse_10u_tplus3h": 2.6446375846862793, "val_mean_bias_10u_tplus3h": -0.12955179810523987, "val_rmse_10v_tplus3h": 2.8543546199798584, "val_mean_bias_10v_tplus3h": 0.15386676788330078, "val_rmse_t850_tplus3h": 2.282665491104126, "val_mean_bias_t850_tplus3h": 0.4141359329223633, "val_rmse_z500_tplus3h": 367.32757568359375, "val_mean_bias_z500_tplus3h": 30.99903678894043, "val_rmse_2t_tplus4h": 2.7042994499206543, "val_mean_bias_2t_tplus4h": -0.3467860221862793, "val_rmse_10u_tplus4h": 2.998748540878296, "val_mean_bias_10u_tplus4h": -0.19702103734016418, "val_rmse_10v_tplus4h": 3.3468921184539795, "val_mean_bias_10v_tplus4h": 
0.20130321383476257, "val_rmse_t850_tplus4h": 2.7033567428588867, "val_mean_bias_t850_tplus4h": 0.5463706254959106, "val_rmse_z500_tplus4h": 444.17047119140625, "val_mean_bias_z500_tplus4h": 42.49504089355469, "val_rmse_2t_tplus5h": 2.890420913696289, "val_mean_bias_2t_tplus5h": -0.4551726281642914, "val_rmse_10u_tplus5h": 3.40653395652771, "val_mean_bias_10u_tplus5h": -0.23875072598457336, "val_rmse_10v_tplus5h": 3.876429319381714, "val_mean_bias_10v_tplus5h": 0.25390106439590454, "val_rmse_t850_tplus5h": 3.1348283290863037, "val_mean_bias_t850_tplus5h": 0.6817998886108398, "val_rmse_z500_tplus5h": 521.3743286132812, "val_mean_bias_z500_tplus5h": 54.1654052734375, "val_rmse_2t_tplus6h": 3.1032655239105225, "val_mean_bias_2t_tplus6h": -0.538508415222168, "val_rmse_10u_tplus6h": 3.821709394454956, "val_mean_bias_10u_tplus6h": -0.27899816632270813, "val_rmse_10v_tplus6h": 4.415755271911621, "val_mean_bias_10v_tplus6h": 0.3063257932662964, "val_rmse_t850_tplus6h": 3.5775928497314453, "val_mean_bias_t850_tplus6h": 0.8159565329551697, "val_rmse_z500_tplus6h": 599.5579833984375, "val_mean_bias_z500_tplus6h": 65.35942077636719, "val_rmse_2t": 3.1032655239105225, "val_mean_bias_2t": -0.538508415222168, "val_rmse_10u": 3.821709394454956, "val_mean_bias_10u": -0.27899816632270813, "val_rmse_10v": 4.415755271911621, "val_mean_bias_10v": 0.3063257932662964, "val_rmse_t850": 3.5775928497314453, "val_mean_bias_t850": 0.8159565329551697, "val_rmse_z500": 599.5579833984375, "val_mean_bias_z500": 65.35942077636719}
3
+ " &0.085 &0.662 &0.001 &0.751 &0.274 &0.059 &0.232 &0.11 &2.121 &-0.086 &2.034 &-0.044 &2.073 &0.05 &1.611 &0.152 &237.201 &9.717 &2.293 &-0.181 &2.303 &-0.086 &2.41 &0.102 &1.909 &0.283 &296.621 &20.602 &2.482 &-0.255 &2.645 &-0.13 &2.854 &0.154 &2.283 &0.414 &367.328 &30.999 &2.704 &-0.347 &2.999 &-0.197 &3.347 &0.201 &2.703 &0.546 &444.17 &42.495 &2.89 &-0.455 &3.407 &-0.239 &3.876 &0.254 &3.135 &0.682 &521.374 &54.165 &3.103 &-0.539 &3.822 &-0.279 &4.416 &0.306 &3.578 &0.816 &599.558 &65.359 &3.103 &-0.539 &3.822 &-0.279 &4.416 &0.306 &3.578 &0.816 &599.558 &65.359"
4
+
downscaling/output/val_speed_baseline/log.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ "Eval:"
2
+ {"val_prep_s": 0.05412517949783554, "val_model_s": 0.16398887091781944, "val_post_s": 0.003232622635550797, "val_valid_loss": 0.6753831406434377, "val_loss_s": 0.27440425256888074, "val_loss_t": 0.05877547711133957, "val_loss_st": 0.23192234834035239, "val_loss_comm": 0.11028106758991878, "val_rmse_2t_tplus1h": 2.120826482772827, "val_mean_bias_2t_tplus1h": -0.0857950896024704, "val_rmse_10u_tplus1h": 2.0343947410583496, "val_mean_bias_10u_tplus1h": -0.044072579592466354, "val_rmse_10v_tplus1h": 2.0728132724761963, "val_mean_bias_10v_tplus1h": 0.05038034915924072, "val_rmse_t850_tplus1h": 1.610738754272461, "val_mean_bias_t850_tplus1h": 0.1519927978515625, "val_rmse_z500_tplus1h": 237.20074462890625, "val_mean_bias_z500_tplus1h": 9.715169906616211, "val_rmse_2t_tplus2h": 2.293344020843506, "val_mean_bias_2t_tplus2h": -0.18076324462890625, "val_rmse_10u_tplus2h": 2.3025588989257812, "val_mean_bias_10u_tplus2h": -0.08629272878170013, "val_rmse_10v_tplus2h": 2.409632444381714, "val_mean_bias_10v_tplus2h": 0.1023545116186142, "val_rmse_t850_tplus2h": 1.9087270498275757, "val_mean_bias_t850_tplus2h": 0.28327688574790955, "val_rmse_z500_tplus2h": 296.6208801269531, "val_mean_bias_z500_tplus2h": 20.602214813232422, "val_rmse_2t_tplus3h": 2.481692314147949, "val_mean_bias_2t_tplus3h": -0.2545318603515625, "val_rmse_10u_tplus3h": 2.6446375846862793, "val_mean_bias_10u_tplus3h": -0.12955179810523987, "val_rmse_10v_tplus3h": 2.8543546199798584, "val_mean_bias_10v_tplus3h": 0.15386676788330078, "val_rmse_t850_tplus3h": 2.282665491104126, "val_mean_bias_t850_tplus3h": 0.4141286313533783, "val_rmse_z500_tplus3h": 367.32757568359375, "val_mean_bias_z500_tplus3h": 30.99837303161621, "val_rmse_2t_tplus4h": 2.7042994499206543, "val_mean_bias_2t_tplus4h": -0.34678396582603455, "val_rmse_10u_tplus4h": 2.998748540878296, "val_mean_bias_10u_tplus4h": -0.19702103734016418, "val_rmse_10v_tplus4h": 3.3468921184539795, "val_mean_bias_10v_tplus4h": 0.20130321383476257, 
"val_rmse_t850_tplus4h": 2.7033567428588867, "val_mean_bias_t850_tplus4h": 0.5463689565658569, "val_rmse_z500_tplus4h": 444.17047119140625, "val_mean_bias_z500_tplus4h": 42.49609375, "val_rmse_2t_tplus5h": 2.890420913696289, "val_mean_bias_2t_tplus5h": -0.45517224073410034, "val_rmse_10u_tplus5h": 3.40653395652771, "val_mean_bias_10u_tplus5h": -0.23875072598457336, "val_rmse_10v_tplus5h": 3.876429319381714, "val_mean_bias_10v_tplus5h": 0.25390106439590454, "val_rmse_t850_tplus5h": 3.1348283290863037, "val_mean_bias_t850_tplus5h": 0.6818059682846069, "val_rmse_z500_tplus5h": 521.3743286132812, "val_mean_bias_z500_tplus5h": 54.1650390625, "val_rmse_2t_tplus6h": 3.1032655239105225, "val_mean_bias_2t_tplus6h": -0.5385081171989441, "val_rmse_10u_tplus6h": 3.821709394454956, "val_mean_bias_10u_tplus6h": -0.2789981961250305, "val_rmse_10v_tplus6h": 4.415755271911621, "val_mean_bias_10v_tplus6h": 0.3063257932662964, "val_rmse_t850_tplus6h": 3.5775928497314453, "val_mean_bias_t850_tplus6h": 0.8159612417221069, "val_rmse_z500_tplus6h": 599.5579833984375, "val_mean_bias_z500_tplus6h": 65.3583984375, "val_rmse_2t": 3.1032655239105225, "val_mean_bias_2t": -0.5385081171989441, "val_rmse_10u": 3.821709394454956, "val_mean_bias_10u": -0.2789981961250305, "val_rmse_10v": 4.415755271911621, "val_mean_bias_10v": 0.3063257932662964, "val_rmse_t850": 3.5775928497314453, "val_mean_bias_t850": 0.8159612417221069, "val_rmse_z500": 599.5579833984375, "val_mean_bias_z500": 65.3583984375}
3
+ " &0.054 &0.164 &0.003 &0.274 &0.059 &0.232 &0.11 &2.121 &-0.086 &2.034 &-0.044 &2.073 &0.05 &1.611 &0.152 &237.201 &9.715 &2.293 &-0.181 &2.303 &-0.086 &2.41 &0.102 &1.909 &0.283 &296.621 &20.602 &2.482 &-0.255 &2.645 &-0.13 &2.854 &0.154 &2.283 &0.414 &367.328 &30.998 &2.704 &-0.347 &2.999 &-0.197 &3.347 &0.201 &2.703 &0.546 &444.17 &42.496 &2.89 &-0.455 &3.407 &-0.239 &3.876 &0.254 &3.135 &0.682 &521.374 &54.165 &3.103 &-0.539 &3.822 &-0.279 &4.416 &0.306 &3.578 &0.816 &599.558 &65.358 &3.103 &-0.539 &3.822 &-0.279 &4.416 &0.306 &3.578 &0.816 &599.558 &65.358"
4
+
downscaling/run_downscaling.py CHANGED
@@ -87,6 +87,7 @@ def _resolve_pretrained_ckpt(path_str: str) -> Path:
87
 
88
  def _attach_factorized_heads(model, surf_vars, upper_vars, level, args):
89
  """Attach factorized temporal/spatial correction heads."""
 
90
  model.temporal_corrector = TemporalCorrectionHead(
91
  num_surface_vars=len(surf_vars),
92
  num_upper_vars=len(upper_vars),
@@ -94,6 +95,8 @@ def _attach_factorized_heads(model, surf_vars, upper_vars, level, args):
94
  hidden_surface=args.temporal_hidden_surface,
95
  hidden_upper=args.temporal_hidden_upper,
96
  dropout=args.temporal_dropout,
 
 
97
  )
98
  model.spatial_corrector = SpatialCorrectionHead(
99
  num_surface_vars=len(surf_vars),
@@ -327,6 +330,8 @@ def get_args():
327
  help='validation progress print frequency (iterations)')
328
  parser.add_argument('--val_max_steps', default=-1, type=int,
329
  help='if >0, limit validation iterations per epoch for faster feedback')
 
 
330
  parser.add_argument('--persistent_workers', action='store_true', default=True,
331
  help='keep dataloader workers persistent across epochs')
332
  parser.add_argument('--no_persistent_workers', action='store_false', dest='persistent_workers')
@@ -366,6 +371,8 @@ def get_args():
366
  parser.add_argument('--lora_mode', default='single', choices=['single', 'all'])
367
  parser.add_argument('--disable_backbone_lora_in_factorized', action='store_true',
368
  help='If set in finetune_mode=lora, freeze LoRA adapters and train only correction heads.')
 
 
369
  parser.add_argument('--lr_data_dir', default='', type=str,
370
  help='LR dataset folder under datasets (auto if empty), e.g. 5.625')
371
  parser.add_argument('--hr_data_dir', default='', type=str,
@@ -386,12 +393,19 @@ def get_args():
386
  parser.add_argument('--derive_lr1h_from_hr1h', action='store_true', default=True,
387
  help='Build y_lr1h from y_hr1h in training loop to reduce dataset IO.')
388
  parser.add_argument('--no_derive_lr1h_from_hr1h', action='store_false', dest='derive_lr1h_from_hr1h')
389
- parser.add_argument('--include_endpoints', action='store_true', default=True,
390
- help='Use 1h targets including endpoints t0 and t0+6h')
 
391
  parser.add_argument('--lambda_s', default=1.0, type=float)
392
  parser.add_argument('--lambda_t', default=1.0, type=float)
393
  parser.add_argument('--lambda_st', default=1.0, type=float)
394
  parser.add_argument('--lambda_comm', default=1.0, type=float)
 
 
 
 
 
 
395
  parser.add_argument('--interior_only_1h', action='store_true',
396
  help='If set, compute 1h losses on interior frames only (exclude endpoints)')
397
  parser.add_argument('--ours_prompt_length', default=0, type=int,
@@ -402,6 +416,8 @@ def get_args():
402
  help='Hidden channels for temporal correction head (upper vars x levels).')
403
  parser.add_argument('--temporal_dropout', default=0.0, type=float,
404
  help='Dropout for temporal correction head.')
 
 
405
  parser.add_argument('--spatial_hidden_surface', default=256, type=int,
406
  help='Hidden channels for spatial correction head (surface vars).')
407
  parser.add_argument('--spatial_hidden_upper', default=512, type=int,
@@ -412,6 +428,8 @@ def get_args():
412
  help='Residual scale multiplier for spatial correction outputs.')
413
  parser.add_argument('--temporal_correction_scale', default=1e-3, type=float,
414
  help='Residual scale multiplier for temporal correction outputs.')
 
 
415
  parser.add_argument('--disable_factorized_lr_scale', action='store_true',
416
  help='Disable LR scaling by total_batch_size/256 for factorized training.')
417
  parser.add_argument('--disable_scale_balanced_loss', action='store_true',
@@ -563,6 +581,10 @@ def main(args, ds_init):
563
  )
564
 
565
  print("Sampler_train = %s" % str(sampler_train))
 
 
 
 
566
  if args.dist_eval:
567
  if len(dataset_val) % num_tasks != 0:
568
  print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
@@ -607,6 +629,8 @@ def main(args, ds_init):
607
  f"[factorized-data] derive_hr6h_from_hr1h={args.derive_hr6h_from_hr1h}, "
608
  f"derive_lr1h_from_hr1h={args.derive_lr1h_from_hr1h}"
609
  )
 
 
610
 
611
  if dataset_val is not None:
612
  val_loader_kwargs = dict(
@@ -627,6 +651,10 @@ def main(args, ds_init):
627
  f"prefetch_factor={args.prefetch_factor if args.num_workers > 0 else 'n/a'}, "
628
  f"pin_mem={args.pin_mem}"
629
  )
 
 
 
 
630
  else:
631
  data_loader_val = None
632
 
@@ -666,14 +694,28 @@ def main(args, ds_init):
666
 
667
  use_ours = finetune_mode == "ours"
668
  use_lora = finetune_mode == "lora"
669
- enable_backbone_lora = finetune_mode == "lora" and (not args.disable_backbone_lora_in_factorized)
 
670
  if finetune_mode == "lora":
671
  if args.lora_mode == "all" and args.lora_steps > 1:
672
  print(
673
  "[lora mode] NOTE: factorized path uses single-step Aurora forward (rollout_step=0), "
674
  "so only step-0 LoRA branch is used even when lora_mode=all."
675
  )
676
- print(f"[lora mode] backbone LoRA enabled in factorized loop: {enable_backbone_lora}")
 
 
 
 
 
 
 
 
 
 
 
 
 
677
 
678
  model = Aurora(
679
  use_ours=use_ours,
@@ -701,13 +743,14 @@ def main(args, ds_init):
701
  print(
702
  "[factorized] Attached correction heads "
703
  f"(temporal: {args.temporal_hidden_surface}/{args.temporal_hidden_upper}, "
 
704
  f"spatial: {args.spatial_hidden_surface}/{args.spatial_hidden_upper})."
705
  )
706
 
707
  _set_trainable_parameters(
708
  model,
709
  finetune_mode,
710
- train_lora_adapters=enable_backbone_lora,
711
  )
712
  total_params, trainable_params = _count_parameters(model)
713
  print(f"[mode={finetune_mode}] trainable params: {trainable_params:,} / {total_params:,}")
@@ -875,15 +918,19 @@ def main(args, ds_init):
875
  lambda_t=args.lambda_t,
876
  lambda_st=args.lambda_st,
877
  lambda_comm=args.lambda_comm,
 
878
  surface_var_weights=surface_var_weights,
879
  upper_var_weights=upper_var_level_weights,
880
  upper_level_weights=None,
881
  interior_only_1h=args.interior_only_1h,
 
 
882
  spatial_corrector=_get_spatial_corrector(model),
883
  temporal_corrector=_get_temporal_corrector(model),
884
  spatial_correction_scale=args.spatial_correction_scale,
885
  temporal_correction_scale=args.temporal_correction_scale,
886
- enable_backbone_lora=enable_backbone_lora,
 
887
  print_freq=args.val_print_freq,
888
  val_max_steps=args.val_max_steps,
889
  )
@@ -908,6 +955,10 @@ def main(args, ds_init):
908
  f"[factorized] correction scales: spatial={args.spatial_correction_scale:.2e}, "
909
  f"temporal={args.temporal_correction_scale:.2e}"
910
  )
 
 
 
 
911
  start_time = time.time()
912
  train_time_only = 0
913
  total_step = args.epochs * num_training_steps_per_epoch
@@ -929,15 +980,19 @@ def main(args, ds_init):
929
  surf_vars=surf_vars, upper_vars=upper_vars,
930
  lambda_s=args.lambda_s, lambda_t=args.lambda_t,
931
  lambda_st=args.lambda_st, lambda_comm=args.lambda_comm,
 
932
  surface_var_weights=surface_var_weights,
933
  upper_var_weights=upper_var_level_weights,
934
  upper_level_weights=None,
935
  interior_only_1h=args.interior_only_1h,
 
 
936
  spatial_corrector=_get_spatial_corrector(model),
937
  temporal_corrector=_get_temporal_corrector(model),
938
  spatial_correction_scale=args.spatial_correction_scale,
939
  temporal_correction_scale=args.temporal_correction_scale,
940
- enable_backbone_lora=enable_backbone_lora,
 
941
  accum_backward_mode=args.accum_backward_mode,
942
  use_ours=use_ours, total_step=total_step
943
  )
@@ -948,7 +1003,19 @@ def main(args, ds_init):
948
  args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
949
  loss_scaler=loss_scaler, epoch=epoch)
950
 
951
- if data_loader_val is not None:
 
 
 
 
 
 
 
 
 
 
 
 
952
  test_stats = validation_one_epoch_factorized(
953
  data_loader_val,
954
  model,
@@ -966,31 +1033,49 @@ def main(args, ds_init):
966
  lambda_t=args.lambda_t,
967
  lambda_st=args.lambda_st,
968
  lambda_comm=args.lambda_comm,
 
969
  surface_var_weights=surface_var_weights,
970
  upper_var_weights=upper_var_level_weights,
971
  upper_level_weights=None,
972
  interior_only_1h=args.interior_only_1h,
 
 
973
  spatial_corrector=_get_spatial_corrector(model),
974
  temporal_corrector=_get_temporal_corrector(model),
975
  spatial_correction_scale=args.spatial_correction_scale,
976
  temporal_correction_scale=args.temporal_correction_scale,
977
- enable_backbone_lora=enable_backbone_lora,
 
978
  print_freq=args.val_print_freq,
979
  val_max_steps=args.val_max_steps,
980
  )
 
 
 
 
 
 
 
 
 
981
  log_stats = {'epoch': epoch,
982
  **{f'train_{k}': v for k, v in train_stats.items()}}
983
- test_log = {**{f'val_{k}': _to_float(v) for k, v in test_stats.items()}}
984
-
985
- copy_log = " &"+" &".join([str(round(_to_float(v), 3)) for k, v in test_stats.items() if k!="valid_loss"])
 
 
 
986
 
987
  if args.output_dir and utils.is_main_process():
988
  if log_writer is not None:
989
  log_writer.flush()
990
  with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
991
  f.write(json.dumps(log_stats) + "\n")
992
- f.write(json.dumps(test_log) + "\n")
993
- f.write(json.dumps(copy_log) + "\n" + "\n")
 
 
994
 
995
  total_time = time.time() - start_time
996
  total_time_str = str(datetime.timedelta(seconds=int(total_time)))
 
87
 
88
  def _attach_factorized_heads(model, surf_vars, upper_vars, level, args):
89
  """Attach factorized temporal/spatial correction heads."""
90
+ temporal_num_steps = 7 if args.include_endpoints else 6
91
  model.temporal_corrector = TemporalCorrectionHead(
92
  num_surface_vars=len(surf_vars),
93
  num_upper_vars=len(upper_vars),
 
95
  hidden_surface=args.temporal_hidden_surface,
96
  hidden_upper=args.temporal_hidden_upper,
97
  dropout=args.temporal_dropout,
98
+ head_mode=args.temporal_head_mode,
99
+ num_time_steps=temporal_num_steps,
100
  )
101
  model.spatial_corrector = SpatialCorrectionHead(
102
  num_surface_vars=len(surf_vars),
 
330
  help='validation progress print frequency (iterations)')
331
  parser.add_argument('--val_max_steps', default=-1, type=int,
332
  help='if >0, limit validation iterations per epoch for faster feedback')
333
+ parser.add_argument('--val_freq', default=1, type=int,
334
+ help='run validation every N epochs (1 means every epoch)')
335
  parser.add_argument('--persistent_workers', action='store_true', default=True,
336
  help='keep dataloader workers persistent across epochs')
337
  parser.add_argument('--no_persistent_workers', action='store_false', dest='persistent_workers')
 
371
  parser.add_argument('--lora_mode', default='single', choices=['single', 'all'])
372
  parser.add_argument('--disable_backbone_lora_in_factorized', action='store_true',
373
  help='If set in finetune_mode=lora, freeze LoRA adapters and train only correction heads.')
374
+ parser.add_argument('--disable_backbone_prior_in_factorized', action='store_true',
375
+ help='If set, do not use Aurora pretrained 6h endpoint prediction as factorized prior/anchor.')
376
  parser.add_argument('--lr_data_dir', default='', type=str,
377
  help='LR dataset folder under datasets (auto if empty), e.g. 5.625')
378
  parser.add_argument('--hr_data_dir', default='', type=str,
 
393
  parser.add_argument('--derive_lr1h_from_hr1h', action='store_true', default=True,
394
  help='Build y_lr1h from y_hr1h in training loop to reduce dataset IO.')
395
  parser.add_argument('--no_derive_lr1h_from_hr1h', action='store_false', dest='derive_lr1h_from_hr1h')
396
+ parser.add_argument('--include_endpoints', action='store_true', default=False,
397
+ help='If set, use 1h targets [t+0..t+6]. Default uses [t+1..t+6].')
398
+ parser.add_argument('--no_include_endpoints', action='store_false', dest='include_endpoints')
399
  parser.add_argument('--lambda_s', default=1.0, type=float)
400
  parser.add_argument('--lambda_t', default=1.0, type=float)
401
  parser.add_argument('--lambda_st', default=1.0, type=float)
402
  parser.add_argument('--lambda_comm', default=1.0, type=float)
403
+ parser.add_argument('--lambda_backbone', default=0.1, type=float,
404
+ help='Weight for Aurora-6h anchor loss on unanchored S-module 6h endpoint prediction.')
405
+ parser.add_argument('--lt_supervision_source', default='lr_path', choices=['lr_path', 'st_downscaled', 'hybrid'],
406
+ help='Source for L_t supervision: direct LR temporal path or ST(HR) downscaled to LR.')
407
+ parser.add_argument('--lt_hybrid_alpha', default=0.5, type=float,
408
+ help='When lt_supervision_source=hybrid, weight on direct lr_path term (0..1).')
409
  parser.add_argument('--interior_only_1h', action='store_true',
410
  help='If set, compute 1h losses on interior frames only (exclude endpoints)')
411
  parser.add_argument('--ours_prompt_length', default=0, type=int,
 
416
  help='Hidden channels for temporal correction head (upper vars x levels).')
417
  parser.add_argument('--temporal_dropout', default=0.0, type=float,
418
  help='Dropout for temporal correction head.')
419
+ parser.add_argument('--temporal_head_mode', default='shared3d', choices=['shared3d', 'delta2d'],
420
+ help='Temporal correction head type: shared 3D conv or per-delta 2D heads.')
421
  parser.add_argument('--spatial_hidden_surface', default=256, type=int,
422
  help='Hidden channels for spatial correction head (surface vars).')
423
  parser.add_argument('--spatial_hidden_upper', default=512, type=int,
 
428
  help='Residual scale multiplier for spatial correction outputs.')
429
  parser.add_argument('--temporal_correction_scale', default=1e-3, type=float,
430
  help='Residual scale multiplier for temporal correction outputs.')
431
+ parser.add_argument('--allow_temporal_endpoint_correction', action='store_true',
432
+ help='Allow temporal correction head on offset=0 frame when include_endpoints is used.')
433
  parser.add_argument('--disable_factorized_lr_scale', action='store_true',
434
  help='Disable LR scaling by total_batch_size/256 for factorized training.')
435
  parser.add_argument('--disable_scale_balanced_loss', action='store_true',
 
581
  )
582
 
583
  print("Sampler_train = %s" % str(sampler_train))
584
+ if args.distributed and (dataset_val is not None) and (not args.dist_eval):
585
+ args.dist_eval = True
586
+ print("[validation] Distributed training detected; enabling --dist_eval to shard validation across ranks.")
587
+
588
  if args.dist_eval:
589
  if len(dataset_val) % num_tasks != 0:
590
  print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
 
629
  f"[factorized-data] derive_hr6h_from_hr1h={args.derive_hr6h_from_hr1h}, "
630
  f"derive_lr1h_from_hr1h={args.derive_lr1h_from_hr1h}"
631
  )
632
+ if args.derive_lr1h_from_hr1h:
633
+ print("[factorized-data] y_lr1h generation: conservative coarsen from y_hr1h (engine-side)")
634
 
635
  if dataset_val is not None:
636
  val_loader_kwargs = dict(
 
651
  f"prefetch_factor={args.prefetch_factor if args.num_workers > 0 else 'n/a'}, "
652
  f"pin_mem={args.pin_mem}"
653
  )
654
+ print(
655
+ f"[validation] samples={len(dataset_val)}, batches={len(data_loader_val)}, "
656
+ f"val_freq={args.val_freq}, val_max_steps={args.val_max_steps}"
657
+ )
658
  else:
659
  data_loader_val = None
660
 
 
694
 
695
  use_ours = finetune_mode == "ours"
696
  use_lora = finetune_mode == "lora"
697
+ train_lora_adapters = finetune_mode == "lora" and (not args.disable_backbone_lora_in_factorized)
698
+ use_backbone_prior = not args.disable_backbone_prior_in_factorized
699
  if finetune_mode == "lora":
700
  if args.lora_mode == "all" and args.lora_steps > 1:
701
  print(
702
  "[lora mode] NOTE: factorized path uses single-step Aurora forward (rollout_step=0), "
703
  "so only step-0 LoRA branch is used even when lora_mode=all."
704
  )
705
+ print(f"[lora mode] train LoRA adapters: {train_lora_adapters}")
706
+ print(f"[factorized] use backbone 6h prior: {use_backbone_prior} (lambda_backbone={args.lambda_backbone})")
707
+ print(
708
+ "[factorized] loss lambdas: "
709
+ f"lambda_s={args.lambda_s}, "
710
+ f"lambda_t={args.lambda_t}, "
711
+ f"lambda_st={args.lambda_st}, "
712
+ f"lambda_comm={args.lambda_comm}, "
713
+ f"lambda_backbone={args.lambda_backbone}"
714
+ )
715
+ print(
716
+ f"[factorized] L_t supervision source: {args.lt_supervision_source} "
717
+ f"(hybrid alpha={args.lt_hybrid_alpha:.2f})"
718
+ )
719
 
720
  model = Aurora(
721
  use_ours=use_ours,
 
743
  print(
744
  "[factorized] Attached correction heads "
745
  f"(temporal: {args.temporal_hidden_surface}/{args.temporal_hidden_upper}, "
746
+ f"temporal_mode={args.temporal_head_mode}, "
747
  f"spatial: {args.spatial_hidden_surface}/{args.spatial_hidden_upper})."
748
  )
749
 
750
  _set_trainable_parameters(
751
  model,
752
  finetune_mode,
753
+ train_lora_adapters=train_lora_adapters,
754
  )
755
  total_params, trainable_params = _count_parameters(model)
756
  print(f"[mode={finetune_mode}] trainable params: {trainable_params:,} / {total_params:,}")
 
918
  lambda_t=args.lambda_t,
919
  lambda_st=args.lambda_st,
920
  lambda_comm=args.lambda_comm,
921
+ lambda_backbone=args.lambda_backbone,
922
  surface_var_weights=surface_var_weights,
923
  upper_var_weights=upper_var_level_weights,
924
  upper_level_weights=None,
925
  interior_only_1h=args.interior_only_1h,
926
+ lt_supervision_source=args.lt_supervision_source,
927
+ lt_hybrid_alpha=args.lt_hybrid_alpha,
928
  spatial_corrector=_get_spatial_corrector(model),
929
  temporal_corrector=_get_temporal_corrector(model),
930
  spatial_correction_scale=args.spatial_correction_scale,
931
  temporal_correction_scale=args.temporal_correction_scale,
932
+ temporal_fix_endpoints=(not args.allow_temporal_endpoint_correction),
933
+ enable_backbone_lora=use_backbone_prior,
934
  print_freq=args.val_print_freq,
935
  val_max_steps=args.val_max_steps,
936
  )
 
955
  f"[factorized] correction scales: spatial={args.spatial_correction_scale:.2e}, "
956
  f"temporal={args.temporal_correction_scale:.2e}"
957
  )
958
+ print(
959
+ f"[factorized] temporal endpoint correction enabled: "
960
+ f"{args.allow_temporal_endpoint_correction}"
961
+ )
962
  start_time = time.time()
963
  train_time_only = 0
964
  total_step = args.epochs * num_training_steps_per_epoch
 
980
  surf_vars=surf_vars, upper_vars=upper_vars,
981
  lambda_s=args.lambda_s, lambda_t=args.lambda_t,
982
  lambda_st=args.lambda_st, lambda_comm=args.lambda_comm,
983
+ lambda_backbone=args.lambda_backbone,
984
  surface_var_weights=surface_var_weights,
985
  upper_var_weights=upper_var_level_weights,
986
  upper_level_weights=None,
987
  interior_only_1h=args.interior_only_1h,
988
+ lt_supervision_source=args.lt_supervision_source,
989
+ lt_hybrid_alpha=args.lt_hybrid_alpha,
990
  spatial_corrector=_get_spatial_corrector(model),
991
  temporal_corrector=_get_temporal_corrector(model),
992
  spatial_correction_scale=args.spatial_correction_scale,
993
  temporal_correction_scale=args.temporal_correction_scale,
994
+ temporal_fix_endpoints=(not args.allow_temporal_endpoint_correction),
995
+ enable_backbone_lora=use_backbone_prior,
996
  accum_backward_mode=args.accum_backward_mode,
997
  use_ours=use_ours, total_step=total_step
998
  )
 
1003
  args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
1004
  loss_scaler=loss_scaler, epoch=epoch)
1005
 
1006
+ should_validate = (
1007
+ data_loader_val is not None
1008
+ and (args.val_freq > 0)
1009
+ and (((epoch + 1) % args.val_freq == 0) or (epoch + 1 == args.epochs))
1010
+ )
1011
+ test_stats = None
1012
+ if should_validate:
1013
+ val_start_time = time.time()
1014
+ val_rank_samples = len(data_loader_val.sampler) if hasattr(data_loader_val, "sampler") else len(data_loader_val.dataset)
1015
+ print(
1016
+ f"[validation] epoch={epoch} start: rank_samples={val_rank_samples}, "
1017
+ f"batches={len(data_loader_val)}"
1018
+ )
1019
  test_stats = validation_one_epoch_factorized(
1020
  data_loader_val,
1021
  model,
 
1033
  lambda_t=args.lambda_t,
1034
  lambda_st=args.lambda_st,
1035
  lambda_comm=args.lambda_comm,
1036
+ lambda_backbone=args.lambda_backbone,
1037
  surface_var_weights=surface_var_weights,
1038
  upper_var_weights=upper_var_level_weights,
1039
  upper_level_weights=None,
1040
  interior_only_1h=args.interior_only_1h,
1041
+ lt_supervision_source=args.lt_supervision_source,
1042
+ lt_hybrid_alpha=args.lt_hybrid_alpha,
1043
  spatial_corrector=_get_spatial_corrector(model),
1044
  temporal_corrector=_get_temporal_corrector(model),
1045
  spatial_correction_scale=args.spatial_correction_scale,
1046
  temporal_correction_scale=args.temporal_correction_scale,
1047
+ temporal_fix_endpoints=(not args.allow_temporal_endpoint_correction),
1048
+ enable_backbone_lora=use_backbone_prior,
1049
  print_freq=args.val_print_freq,
1050
  val_max_steps=args.val_max_steps,
1051
  )
1052
+ val_elapsed = time.time() - val_start_time
1053
+ val_steps_done = min(len(data_loader_val), args.val_max_steps) if args.val_max_steps > 0 else len(data_loader_val)
1054
+ print(
1055
+ f"[validation] epoch={epoch} done in {val_elapsed:.1f}s "
1056
+ f"(~{val_elapsed / max(val_rank_samples, 1):.3f}s/sample on this rank, "
1057
+ f"{val_steps_done} batches)."
1058
+ )
1059
+ elif data_loader_val is not None:
1060
+ print(f"[validation] epoch={epoch} skipped (val_freq={args.val_freq}).")
1061
  log_stats = {'epoch': epoch,
1062
  **{f'train_{k}': v for k, v in train_stats.items()}}
1063
+ if test_stats is not None:
1064
+ test_log = {**{f'val_{k}': _to_float(v) for k, v in test_stats.items()}}
1065
+ copy_log = " &"+" &".join([str(round(_to_float(v), 3)) for k, v in test_stats.items() if k!="valid_loss"])
1066
+ else:
1067
+ test_log = None
1068
+ copy_log = None
1069
 
1070
  if args.output_dir and utils.is_main_process():
1071
  if log_writer is not None:
1072
  log_writer.flush()
1073
  with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
1074
  f.write(json.dumps(log_stats) + "\n")
1075
+ if test_log is not None:
1076
+ f.write(json.dumps(test_log) + "\n")
1077
+ f.write(json.dumps(copy_log) + "\n")
1078
+ f.write("\n")
1079
 
1080
  total_time = time.time() - start_time
1081
  total_time_str = str(datetime.timedelta(seconds=int(total_time)))
downscaling/script_run_downscaling.sh CHANGED
@@ -10,6 +10,22 @@ VAL_END_DATE="2018"
10
  OUTPUT_DIR="output/WeatherPEFT"
11
  GRAD_ACCUM_STEPS=8
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  OMP_NUM_THREADS=1 python -m torch.distributed.run --nproc_per_node=2 \
14
  --master_port 12326 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" \
15
  run_downscaling.py \
@@ -21,18 +37,19 @@ OMP_NUM_THREADS=1 python -m torch.distributed.run --nproc_per_node=2 \
21
  --output_dir ${OUTPUT_DIR} \
22
  --batch_size 2 \
23
  --grad_accum_steps ${GRAD_ACCUM_STEPS} \
24
- --val_batch_size 5 \
25
  --num_workers 8 \
26
  --persistent_workers \
27
  --prefetch_factor 4 \
28
  --save_ckpt_freq 10 \
29
  --opt adamw \
30
- --lr 7e-4 \
31
  --opt_betas 0.9 0.999 \
32
  --weight_decay 0.05 \
33
  --clip_grad 1.0 \
34
  --warmup_epochs 3 \
35
  --epochs 30 \
 
36
  --dist_eval \
37
  --mode lora \
38
  --finetune_mode lora \
@@ -44,9 +61,22 @@ OMP_NUM_THREADS=1 python -m torch.distributed.run --nproc_per_node=2 \
44
  --hr1h_data_dir 1.5_1h_nc \
45
  --lr_degree_tag 5.625 \
46
  --hr_degree_tag 1.5 \
 
47
  --norm_mean_path ../aux_data/normalize_mean_wb2.npz \
48
  --norm_std_path ../aux_data/normalize_std_wb2.npz \
49
- --norm_std_floor 1e-6 \
 
 
 
 
 
 
 
 
 
 
 
 
50
  --no_auto_resume \
51
  --spatial_correction_scale 1.0 \
52
  --temporal_correction_scale 1.0 \
 
10
  OUTPUT_DIR="output/WeatherPEFT"
11
  GRAD_ACCUM_STEPS=8
12
 
13
+ # Preserve-first recipe:
14
+ # 1) keep pretrained 6h behavior strong at start
15
+ # 2) learn 1h~5h corrections gradually
16
+ LAMBDA_BACKBONE=1.0
17
+ LAMBDA_S=1.0
18
+ LAMBDA_T=1.0
19
+ LAMBDA_ST=1.0
20
+ LAMBDA_COMM=1.0
21
+ LT_SUPERVISION_SOURCE="lr_path"
22
+ SPATIAL_HIDDEN_SURFACE=64
23
+ SPATIAL_HIDDEN_UPPER=128
24
+ TEMPORAL_HIDDEN_SURFACE=64
25
+ TEMPORAL_HIDDEN_UPPER=128
26
+ TEMPORAL_HEAD_MODE="delta2d"
27
+ # backbone prior is ON by default (do not pass --disable_backbone_prior_in_factorized)
28
+
29
  OMP_NUM_THREADS=1 python -m torch.distributed.run --nproc_per_node=2 \
30
  --master_port 12326 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" \
31
  run_downscaling.py \
 
37
  --output_dir ${OUTPUT_DIR} \
38
  --batch_size 2 \
39
  --grad_accum_steps ${GRAD_ACCUM_STEPS} \
40
+ --val_batch_size 2 \
41
  --num_workers 8 \
42
  --persistent_workers \
43
  --prefetch_factor 4 \
44
  --save_ckpt_freq 10 \
45
  --opt adamw \
46
+ --lr 5e-4 \
47
  --opt_betas 0.9 0.999 \
48
  --weight_decay 0.05 \
49
  --clip_grad 1.0 \
50
  --warmup_epochs 3 \
51
  --epochs 30 \
52
+ --val_freq 1 \
53
  --dist_eval \
54
  --mode lora \
55
  --finetune_mode lora \
 
61
  --hr1h_data_dir 1.5_1h_nc \
62
  --lr_degree_tag 5.625 \
63
  --hr_degree_tag 1.5 \
64
+ --no_include_endpoints \
65
  --norm_mean_path ../aux_data/normalize_mean_wb2.npz \
66
  --norm_std_path ../aux_data/normalize_std_wb2.npz \
67
+ --norm_std_floor 1e-8 \
68
+ --lambda_s ${LAMBDA_S} \
69
+ --lambda_t ${LAMBDA_T} \
70
+ --lambda_st ${LAMBDA_ST} \
71
+ --lambda_comm ${LAMBDA_COMM} \
72
+ --lambda_backbone ${LAMBDA_BACKBONE} \
73
+ --lt_supervision_source ${LT_SUPERVISION_SOURCE} \
74
+ --disable_scale_balanced_loss \
75
+ --spatial_hidden_surface ${SPATIAL_HIDDEN_SURFACE} \
76
+ --spatial_hidden_upper ${SPATIAL_HIDDEN_UPPER} \
77
+ --temporal_hidden_surface ${TEMPORAL_HIDDEN_SURFACE} \
78
+ --temporal_hidden_upper ${TEMPORAL_HIDDEN_UPPER} \
79
+ --temporal_head_mode ${TEMPORAL_HEAD_MODE} \
80
  --no_auto_resume \
81
  --spatial_correction_scale 1.0 \
82
  --temporal_correction_scale 1.0 \
downscaling/script_smoke_feasibility.sh CHANGED
@@ -48,6 +48,7 @@ for MODE in "${MODES[@]}"; do
48
  --hr1h_data_dir 1.5_1h_nc \
49
  --lr_degree_tag 5.625 \
50
  --hr_degree_tag 1.5 \
 
51
  --norm_mean_path ../aux_data/normalize_mean_wb2.npz \
52
  --norm_std_path ../aux_data/normalize_std_wb2.npz \
53
  --norm_std_floor 1e-8 \
 
48
  --hr1h_data_dir 1.5_1h_nc \
49
  --lr_degree_tag 5.625 \
50
  --hr_degree_tag 1.5 \
51
+ --no_include_endpoints \
52
  --norm_mean_path ../aux_data/normalize_mean_wb2.npz \
53
  --norm_std_path ../aux_data/normalize_std_wb2.npz \
54
  --norm_std_floor 1e-8 \