"""
evaluate_aurora.py (multi-day TEST evaluation)

- Computes latitude-weighted RMSE/ACC for every surf/atmos variable at every lead time.
- (atmos) Additionally computes and logs per-level RMSE/ACC.
- The test range is selected by years / number of days.
- Key idea: the data loader only assembles the model inputs with a short `steps`,
  while the real ground truth is loaded on the fly from NetCDF at every lead time,
  which enables multi-day evaluation.
"""
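
# Example invocation (illustrative only; paths, years, and checkpoint settings are
# placeholders rather than a prescribed configuration):
#   python evaluate_aurora.py \
#       --base_dir /data/era5_240 --test_years 2023 --steps 28 --lead_hour_stride 6 \
#       --ckpt_source hf --ckpt_name aurora-0.25-small-pretrained.ckpt \
#       --clim_path ./climatology.nc --out_dir ./out_test --tb_dir ./tb_test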

import os
import math
import argparse
import xarray as xr
from pathlib import Path
from collections import defaultdict
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm

from train_aurora_15 import (
    setup_model, build_loader, Climatology,
    lat_weights, rmse_lat_weighted, acc_lat_weighted, rollout,
    ddp_is_init, init_ddp, get_rank, ar_mean
)
from train_aurora_15 import (
    find_coord_names, find_time_name, find_level_name,
    ensure_order, get_var, transpose_lat_lon_last
)


DEFAULT_SURF_VARS = ("2t", "10u", "10v", "msl")
DEFAULT_ATMOS_VARS = ("t", "u", "v", "q", "z")

SURF_NAMES_SHORTS = {
    "2t": (["2m_temperature"], ["t2m", "2t"]),
    "10u": (["10m_u_component_of_wind"], ["u10", "10u"]),
    "10v": (["10m_v_component_of_wind"], ["v10", "10v"]),
    "msl": (["mean_sea_level_pressure"], ["msl"]),
}
ATMOS_NAMES_SHORTS = {
    "t": (["temperature"], ["t"]),
    "u": (["u_component_of_wind"], ["u"]),
    "v": (["v_component_of_wind"], ["v"]),
    "q": (["specific_humidity"], ["q", "hus", "q_specific"]),
    "z": (["geopotential"], ["z", "gh"]),
}


def parse_years(s: str) -> List[int]:
    s = s.strip()
    if "-" in s:
        a, b = s.split("-")
        return list(range(int(a), int(b) + 1))
    return [int(x) for x in s.split(",") if x.strip()]
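# parse_years accepts either a year range or a comma-separated list, e.g.:
#   parse_years("2020-2022") -> [2020, 2021, 2022]
#   parse_years("2019,2021") -> [2019, 2021]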


def _open_one_step_from_base(base_dir: Path, day: str, hour: int):
    """
    Open base_dir/era5_240_{YYYY}_{hour}h.nc directly and return a single-time
    slice for the given day and hour.
    (Bypasses the limitations of the Dataset's self.paths wiring.)
    """
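    # e.g. day="2023-01-03", hour=6 -> base_dir / "era5_240_2023_6h.nc"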
    y = int(day[:4])
    nc = base_dir / f"era5_240_{y}_{hour}h.nc"
    if not nc.exists():
        raise FileNotFoundError(f"GT file not found for {day} {hour:02d}h: {nc}")
    ds = xr.open_dataset(str(nc), engine="netcdf4")

    tname = find_time_name(ds)
    tgt = np.datetime64(f"{day}T{hour:02d}:00:00").astype("datetime64[s]")
    times = ds[tname].values.astype("datetime64[s]")
    idx = np.where(times == tgt)[0]
    if idx.size == 0:
        # Fall back to the nearest available time if there is no exact match.
        j = int(np.argmin(np.abs(times - tgt)))
        ds = ds.isel({tname: j})
    else:
        ds = ds.isel({tname: int(idx[0])})
    ds = ds.expand_dims({tname: 1})
    return ds


def _extract_truth_for_time(ds, surf_vars, atmos_vars, target_H: int, target_W: int):
    """
    ds: xr.Dataset holding a single time step.
    Returns:
        gt_surf_step: {var: torch.Tensor(H, W)}
        gt_atm_step:  {var: torch.Tensor(L, H, W)}
    """
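    # Illustrative shapes only (assuming the 1.5-degree / 240-longitude grid that the
    # default --deg_new and the era5_240 filenames suggest): each surface field would
    # be (121, 240) and each atmospheric field (num_levels, 121, 240).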
    lat, lon = find_coord_names(ds)
    ds = ensure_order(ds, lat, lon)
    tname = find_time_name(ds)

    gt_surf_step = {}
    for v in surf_vars:
        names, shorts = SURF_NAMES_SHORTS.get(v, ([], []))
        da = transpose_lat_lon_last(get_var(ds, names=names, shorts=shorts), lat, lon)
        arr = da.transpose(tname, lat, lon).values
        gt_surf_step[v] = torch.from_numpy(arr[0, :target_H, :target_W].astype(np.float32))

    gt_atm_step = {}
    for v in atmos_vars:
        names, shorts = ATMOS_NAMES_SHORTS.get(v, ([], []))
        da = transpose_lat_lon_last(get_var(ds, names=names, shorts=shorts), lat, lon)
        lvl = find_level_name(ds)
        arr = da.transpose(tname, lvl, lat, lon).values
        gt_atm_step[v] = torch.from_numpy(arr[0, :, :target_H, :target_W].astype(np.float32))

    return gt_surf_step, gt_atm_step


def _compute_day_hour_from_valid_time(valid_time: np.datetime64) -> Tuple[str, int]:
    ts = pd.Timestamp(str(valid_time))
    return ts.strftime("%Y-%m-%d"), int(ts.hour)
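# Example: _compute_day_hour_from_valid_time(np.datetime64("2023-01-03T18:00:00"))
# returns ("2023-01-03", 18).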


def main():
    ap = argparse.ArgumentParser(description="Aurora TEST evaluation (multi-day GT)")

    ap.add_argument("--base_dir", type=Path, required=True)
    ap.add_argument("--test_years", type=str, default="2023")
    ap.add_argument("--max_test_days", type=int, default=0)
    ap.add_argument("--input_len", type=int, default=2)
    ap.add_argument("--steps", type=int, default=28, help="number of evaluation lead times")
    ap.add_argument("--lead_hour_stride", type=int, default=6, help="spacing between lead times in hours")
    ap.add_argument("--lead_offset_mode", type=str, default="first", choices=["first", "rolling", "list"],
                    help="the loader only builds the inputs, so 'first' is recommended")
    ap.add_argument("--lead_offset_list", type=str, default="")
    ap.add_argument("--cross_day_extra", type=int, default=2)

    ap.add_argument("--precision", type=str, default="bf16", choices=["bf16", "fp16", "fp32"])
    ap.add_argument("--batch_size", type=int, default=1)
    ap.add_argument("--num_workers", type=int, default=2)

    ap.add_argument("--out_dir", type=Path, default=Path("./out_test"))
    ap.add_argument("--tb_dir", type=Path, default=Path("./tb_test"))

    ap.add_argument("--deg_old", type=float, default=0.25)
    ap.add_argument("--deg_new", type=float, default=1.5)
    ap.add_argument("--patch_new", type=int, default=1)
    ap.add_argument("--win_orig", type=str, default="2,6,12")
    ap.add_argument("--win_min_hw", type=int, default=2)
    ap.add_argument("--finetune_mode", type=str, default="lora", choices=["lora", "hybrid", "full"])
    ap.add_argument("--lora_mode", type=str, default="single", choices=["single", "all", "from_second"])
    ap.add_argument("--act_ckpt", action="store_true")
    ap.add_argument("--disable_droppath", action="store_true")
    ap.add_argument("--local_rank", type=int, default=0)

    ap.add_argument("--ckpt_source", type=str, default="hf", choices=["hf", "local"])
    ap.add_argument("--ckpt_repo", type=str, default="microsoft/aurora")
    ap.add_argument("--ckpt_name", type=str, default="aurora-0.25-small-pretrained.ckpt")
    ap.add_argument("--ckpt_revision", type=str, default=None)
    ap.add_argument("--ckpt_path", type=Path, default=None)
    ap.add_argument("--patch_old", type=int, default=-1)
    ap.add_argument("--force_no_adapt", action="store_true")

    ap.add_argument("--clim_path", type=Path, default=None)

    args = ap.parse_args()
    test_years = parse_years(args.test_years)

    # DDP / device setup
    ddp = ("RANK" in os.environ and "WORLD_SIZE" in os.environ)
    if ddp:
        local_rank = init_ddp(args.local_rank)
    else:
        local_rank = int(os.environ.get("LOCAL_RANK", args.local_rank))
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)

    torch.backends.cudnn.benchmark = True
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")

    # Model
    model, patch = setup_model(args)
    model = model.to(device)
    model.eval()

    # Outputs / TensorBoard (rank 0 only)
    args.out_dir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(str(args.tb_dir)) if get_rank() == 0 else None

    # Test loader: it only builds the inputs; the ground truth for every lead time
    # is loaded directly from NetCDF below, so steps_for_loader stays at 1.
    steps_for_loader = 1
    _, test_loader, test_sampler = build_loader(
        args.base_dir, test_years, steps_for_loader, patch, args.max_test_days,
        args.batch_size, args.num_workers, shuffle=False,
        input_len=args.input_len, lead_offset_mode=args.lead_offset_mode,
        lead_offset_list=args.lead_offset_list, cross_day_extra=args.cross_day_extra
    )
    if ddp and test_sampler is not None:
        test_sampler.set_epoch(0)

    # Climatology for ACC
    clim = Climatology(args.clim_path)

    # Variable lists (fall back to the defaults if the model does not expose them)
    surf_vars = tuple(getattr(model, "surf_vars", DEFAULT_SURF_VARS))
    atmos_vars = tuple(getattr(model, "atmos_vars", DEFAULT_ATMOS_VARS))

    # Metric accumulators
    metrics = {
        "surf": defaultdict(lambda: defaultdict(lambda: {"rmse_sum": 0.0, "rmse_n": 0,
                                                         "acc_sum": 0.0, "acc_n": 0})),
        "atmos": defaultdict(lambda: defaultdict(lambda: {"rmse_sum": 0.0, "rmse_n": 0,
                                                          "acc_sum": 0.0, "acc_n": 0}))
    }
    metrics_levels = defaultdict(
        lambda: defaultdict(
            lambda: defaultdict(
                lambda: {"rmse_sum": 0.0, "rmse_n": 0, "acc_sum": 0.0, "acc_n": 0}
            )
        )
    )
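    # Accumulator layout (averages are computed as sum / n at reporting time):
    #   metrics[domain][var][lead_step_index]                 -> {"rmse_sum", "rmse_n", "acc_sum", "acc_n"}
    #   metrics_levels[var][pressure_level][lead_step_index]  -> same counter dict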

    rows_overall, rows_levels = [], []

    use_amp = args.precision in ("bf16", "fp16")
    amp_dtype = (torch.bfloat16 if args.precision == "bf16"
                 else torch.float16 if args.precision == "fp16" else None)

    if get_rank() == 0 and len(test_loader) == 0:
        print("[WARN] The test loader is empty (check the input/steps combination). "
              "This script loads multi-day GT directly, so steps_for_loader=1 can stay as is.")

    pbar = tqdm(total=len(test_loader), disable=(get_rank() != 0), dynamic_ncols=True, desc="test")

    with torch.no_grad():
        for batch_list in test_loader:
            sample = batch_list[0]
            batch = sample.batch.to(device)

            EVAL_STEPS = int(args.steps)

            if use_amp:
                with torch.amp.autocast('cuda', dtype=amp_dtype):
                    preds = list(rollout(model, batch, steps=EVAL_STEPS))
            else:
                preds = list(rollout(model, batch, steps=EVAL_STEPS))

            # Valid times of the autoregressive predictions, counted from the last
            # input time in fixed lead_hour_stride increments.
            last_in_time = np.datetime64(batch.metadata.time[0], 's')
            valid_times = [
                last_in_time + np.timedelta64((s + 1) * args.lead_hour_stride, "h")
                for s in range(EVAL_STEPS)
            ]
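            # e.g. with last_in_time = 2023-01-01T00 and lead_hour_stride = 6:
            # valid_times = [2023-01-01T06, 2023-01-01T12, ...] (EVAL_STEPS entries)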

            H, W = sample.lat.shape[0], sample.lon.shape[0]
            w_lat = lat_weights(sample.lat).to(device)

            # Pressure levels for per-level logging; fall back to plain indices
            # if the batch metadata does not carry atmos_levels.
            levels_tuple = tuple(getattr(
                batch.metadata, "atmos_levels",
                tuple(range(preds[0].atmos_vars[atmos_vars[0]][0, 0].shape[0]))
            ))

            # Load the ground truth for every valid time directly from NetCDF.
            gt_surf_all: List[Dict[str, torch.Tensor]] = []
            gt_atm_all: List[Dict[str, torch.Tensor]] = []
            for s, vt in enumerate(valid_times):
                day, hour = _compute_day_hour_from_valid_time(vt)
                ds_one = _open_one_step_from_base(args.base_dir, day, hour)
                gt_surf_s, gt_atm_s = _extract_truth_for_time(ds_one, surf_vars, atmos_vars, H, W)
                gt_surf_all.append(gt_surf_s)
                gt_atm_all.append(gt_atm_s)

            # Surface variables: lat-weighted RMSE / ACC per lead time
            for s in range(EVAL_STEPS):
                lead_hours = (s + 1) * args.lead_hour_stride
                for v in surf_vars:
                    pred = preds[s].surf_vars[v][0, 0]
                    gt = gt_surf_all[s][v].to(pred.device)

                    rmse = rmse_lat_weighted(pred, gt, w_lat)
                    rmse_r = ar_mean(rmse)
                    rmse_val = float(rmse_r.item())
                    metrics["surf"][v][s]["rmse_sum"] += rmse_val
                    metrics["surf"][v][s]["rmse_n"] += 1

                    acc_val = float("nan")
                    if clim.ok:
                        names, shorts = SURF_NAMES_SHORTS.get(v, ([], []))
                        clim_da = clim.get(names, shorts, valid_times[s])
                        if clim_da is not None:
                            clim2 = torch.from_numpy(
                                clim_da.values[:H, :W].astype(np.float32)
                            ).to(pred.device)
                            acc = acc_lat_weighted(pred, gt, clim2, w_lat)
                            if acc is not None and torch.isfinite(acc):
                                acc_r = ar_mean(acc)
                                acc_val = float(acc_r.item())
                                metrics["surf"][v][s]["acc_sum"] += acc_val
                                metrics["surf"][v][s]["acc_n"] += 1

                    if writer:
                        writer.add_scalar(f"test/surf/{v}/rmse/lead_{lead_hours}h", rmse_val)
                        if not math.isnan(acc_val):
                            writer.add_scalar(f"test/surf/{v}/acc/lead_{lead_hours}h", acc_val)

            # Atmospheric variables: level-mean and per-level RMSE / ACC
            for s in range(EVAL_STEPS):
                lead_hours = (s + 1) * args.lead_hour_stride
                for v in atmos_vars:
                    pred3 = preds[s].atmos_vars[v][0, 0]
                    gt3 = gt_atm_all[s][v].to(pred3.device)

                    # RMSE averaged over levels
                    rmse_levels = [rmse_lat_weighted(pred3[li], gt3[li], w_lat)
                                   for li in range(pred3.shape[0])]
                    if rmse_levels:
                        rmse_mean_r = ar_mean(torch.mean(torch.stack(rmse_levels)))
                        rmse_overall = float(rmse_mean_r.item())
                        metrics["atmos"][v][s]["rmse_sum"] += rmse_overall
                        metrics["atmos"][v][s]["rmse_n"] += 1
                    else:
                        rmse_overall = float("nan")

                    # ACC averaged over levels (requires climatology)
                    acc_overall = float("nan")
                    if clim.ok:
                        names, shorts = ATMOS_NAMES_SHORTS.get(v, ([], []))
                        clim_da = clim.get(names, shorts, valid_times[s])
                        if clim_da is not None and clim_da.values.ndim == 3:
                            clim_vals = clim_da.values
                            L = min(pred3.shape[0], clim_vals.shape[0])
                            Hc = min(pred3.shape[1], clim_vals.shape[1])
                            Wc = min(pred3.shape[2], clim_vals.shape[2])
                            acc_levels = []
                            for li in range(L):
                                c2 = torch.from_numpy(clim_vals[li, :Hc, :Wc].astype(np.float32)).to(pred3.device)
                                p2 = pred3[li, :Hc, :Wc]
                                g2 = gt3[li, :Hc, :Wc]
                                acc_li = acc_lat_weighted(p2, g2, c2, w_lat[:Hc])
                                if acc_li is not None and torch.isfinite(acc_li):
                                    acc_levels.append(acc_li)
                            if acc_levels:
                                acc_overall = float(ar_mean(torch.mean(torch.stack(acc_levels))).item())
                                metrics["atmos"][v][s]["acc_sum"] += acc_overall
                                metrics["atmos"][v][s]["acc_n"] += 1

                    if writer and not math.isnan(rmse_overall):
                        writer.add_scalar(f"test/atmos/{v}/rmse/lead_{lead_hours}h", rmse_overall)
                    if writer and not math.isnan(acc_overall):
                        writer.add_scalar(f"test/atmos/{v}/acc/lead_{lead_hours}h", acc_overall)

                    # Per-level RMSE
                    Lp = pred3.shape[0]
                    for li in range(Lp):
                        level_val = int(levels_tuple[li]) if li < len(levels_tuple) else li
                        rmse_li = float(ar_mean(rmse_lat_weighted(pred3[li], gt3[li], w_lat)).item())
                        agg = metrics_levels[v][level_val][s]
                        agg["rmse_sum"] += rmse_li
                        agg["rmse_n"] += 1
                        if writer:
                            writer.add_scalar(f"test/atmos/{v}/level_{level_val}/rmse/lead_{lead_hours}h", rmse_li)

                    # Per-level ACC (requires climatology)
                    if clim.ok:
                        names, shorts = ATMOS_NAMES_SHORTS.get(v, ([], []))
                        clim_da = clim.get(names, shorts, valid_times[s])
                        if clim_da is not None and clim_da.values.ndim == 3:
                            clim_vals = clim_da.values
                            Lc = min(pred3.shape[0], clim_vals.shape[0])
                            Hc = min(pred3.shape[1], clim_vals.shape[1])
                            Wc = min(pred3.shape[2], clim_vals.shape[2])
                            for li in range(Lc):
                                level_val = int(levels_tuple[li]) if li < len(levels_tuple) else li
                                c2 = torch.from_numpy(clim_vals[li, :Hc, :Wc].astype(np.float32)).to(pred3.device)
                                p2 = pred3[li, :Hc, :Wc]
                                g2 = gt3[li, :Hc, :Wc]
                                acc_li = acc_lat_weighted(p2, g2, c2, w_lat[:Hc])
                                if acc_li is not None and torch.isfinite(acc_li):
                                    acc_li_val = float(ar_mean(acc_li).item())
                                    agg = metrics_levels[v][level_val][s]
                                    agg["acc_sum"] += acc_li_val
                                    agg["acc_n"] += 1
                                    if writer:
                                        writer.add_scalar(f"test/atmos/{v}/level_{level_val}/acc/lead_{lead_hours}h", acc_li_val)

            if get_rank() == 0:
                pbar.update(1)

    if writer:
        writer.flush()

    if get_rank() == 0:
        from pandas import DataFrame

        print("\n=== TEST RESULTS (lat-weighted | variable-level averages) ===")
        print("Domain | Var | Lead(h) | RMSE(avg over samples) | ACC(avg over samples)")
        print("-------+------+---------+------------------------+----------------------")

        rows_overall = []
        for domain in ("surf", "atmos"):
            for v, per_lead in metrics[domain].items():
                for s, agg in per_lead.items():
                    lead_hours = (s + 1) * args.lead_hour_stride
                    rmse_avg = (agg["rmse_sum"] / max(1, agg["rmse_n"])) if agg["rmse_n"] > 0 else float("nan")
                    acc_avg = (agg["acc_sum"] / max(1, agg["acc_n"])) if agg["acc_n"] > 0 else float("nan")
                    print(f"{domain:6s} | {v:4s} | {lead_hours:7d} | {rmse_avg:22.6f} | {acc_avg:20.6f}")
                    rows_overall.append({
                        "domain": domain, "var": v,
                        "lead_step": s + 1, "lead_hours": lead_hours,
                        "rmse": rmse_avg, "acc": acc_avg,
                        "rmse_count": agg["rmse_n"], "acc_count": agg["acc_n"],
                    })

        df_overall = DataFrame(rows_overall)
        if not df_overall.empty:
            df_overall = df_overall.sort_values(["domain", "var", "lead_step"])
        csv_path_overall = args.out_dir / "glora_test_metrics.csv"
        df_overall.to_csv(csv_path_overall, index=False)
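        # The overall CSV can later be reshaped for inspection, e.g. (illustrative):
        #   pd.read_csv(csv_path_overall).pivot_table(index="lead_hours",
        #                                             columns="var", values="rmse")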

        print("\n=== TEST RESULTS (lat-weighted | ATMOS PER-LEVEL) ===")
        print("Var | Level | Lead(h) | RMSE(avg over samples) | ACC(avg over samples)")
        print("-----+-------+---------+------------------------+----------------------")

        rows_levels = []
        for v, per_level in metrics_levels.items():
            for level_val, per_lead in sorted(per_level.items(), key=lambda kv: kv[0]):
                for s, agg in per_lead.items():
                    lead_hours = (s + 1) * args.lead_hour_stride
                    rmse_avg = (agg["rmse_sum"] / max(1, agg["rmse_n"])) if agg["rmse_n"] > 0 else float("nan")
                    acc_avg = (agg["acc_sum"] / max(1, agg["acc_n"])) if agg["acc_n"] > 0 else float("nan")
                    print(f"{v:4s} | {level_val:5d} | {lead_hours:7d} | {rmse_avg:22.6f} | {acc_avg:20.6f}")
                    rows_levels.append({
                        "domain": "atmos", "var": v, "level": int(level_val),
                        "lead_step": s + 1, "lead_hours": lead_hours,
                        "rmse": rmse_avg, "acc": acc_avg,
                        "rmse_count": agg["rmse_n"], "acc_count": agg["acc_n"],
                    })

        df_levels = pd.DataFrame(rows_levels)
        if not df_levels.empty:
            df_levels = df_levels.sort_values(["var", "level", "lead_step"])
        csv_path_levels = args.out_dir / "glora_test_metrics_levels.csv"
        df_levels.to_csv(csv_path_levels, index=False)

        print(f"\nSaved CSV (overall): {csv_path_overall}")
        print(f"Saved CSV (per-level): {csv_path_levels}")

    if writer:
        writer.close()

    if ddp_is_init():
        torch.distributed.barrier()
        torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()