#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
evaluate_aurora.py (MULTI-DAY TEST EVAL)
- ๋ชจ๋“  surf/atmos ๋ณ€์ˆ˜์— ๋Œ€ํ•ด ๋ชจ๋“  lead time๋ณ„ ์œ„๋„๊ฐ€์ค‘ RMSE/ACC ๊ณ„์‚ฐ
- (atmos) ๊ฐ ๋ ˆ๋ฒจ๋ณ„(per-level) RMSE/ACC ์ถ”๊ฐ€ ๊ณ„์‚ฐ/๋กœ๊น…
- ํ…Œ์ŠคํŠธ ๋ฒ”์œ„๋ฅผ ์—ฐ/์ผ์ˆ˜๋กœ ์ง€์ •
- ํ•ต์‹ฌ: ๋ฐ์ดํ„ฐ๋กœ๋”๋Š” ์งง์€ steps๋กœ ์ž…๋ ฅ๋งŒ ๊ตฌ์„ฑํ•˜๊ณ ,
์ง„์งœ GT๋Š” ๋ฆฌ๋“œํƒ€์ž„๋งˆ๋‹ค NetCDF์—์„œ '์ฆ‰์‹œ ๋กœ๋“œ'ํ•˜์—ฌ ๋ฉ€ํ‹ฐ๋ฐ์ด ํ‰๊ฐ€ ์ง€์›
"""
import os
import math
import argparse
import xarray as xr
from pathlib import Path
from collections import defaultdict
from typing import Dict, List, Tuple, Optional
import numpy as np
import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
# ---- Reuse functions/classes from the training script ----
from train_aurora_15 import (
setup_model, build_loader, Climatology,
lat_weights, rmse_lat_weighted, acc_lat_weighted, rollout,
ddp_is_init, init_ddp, get_rank, ar_mean
)
# Also reuse the coordinate/variable utilities
from train_aurora_15 import (
find_coord_names, find_time_name, find_level_name,
ensure_order, get_var, transpose_lat_lon_last
)
DEFAULT_SURF_VARS = ("2t", "10u", "10v", "msl")
DEFAULT_ATMOS_VARS = ("t", "u", "v", "q", "z")
SURF_NAMES_SHORTS = {
"2t": (["2m_temperature"], ["t2m", "2t"]),
"10u": (["10m_u_component_of_wind"], ["u10", "10u"]),
"10v": (["10m_v_component_of_wind"], ["v10", "10v"]),
"msl": (["mean_sea_level_pressure"], ["msl"]),
}
ATMOS_NAMES_SHORTS = {
"t": (["temperature"], ["t"]),
"u": (["u_component_of_wind"], ["u"]),
"v": (["v_component_of_wind"], ["v"]),
"q": (["specific_humidity"], ["q", "hus", "q_specific"]),
"z": (["geopotential"], ["z", "gh"]),
}
def parse_years(s: str) -> List[int]:
s = s.strip()
if "-" in s:
a, b = s.split("-")
return list(range(int(a), int(b) + 1))
return [int(x) for x in s.split(",") if x.strip()]
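# Quick reference for the accepted formats (derived from the parsing rules above):
#   parse_years("2021-2023")  -> [2021, 2022, 2023]
#   parse_years("2020,2022")  -> [2020, 2022]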
# ---------- Multi-day GT loader ----------
def _open_one_step_from_base(base_dir: Path, day: str, hour: int):
"""
base_dir/era5_240_{YYYY}_{hour}h.nc ๋ฅผ ์ง์ ‘ ์—ด์–ด์„œ
ํ•ด๋‹น day T hour ์˜ ๋‹จ์ผ ์‹œ๊ฐ slice๋ฅผ ๋ฐ˜ํ™˜.
(Dataset์˜ self.paths ์—ฐ๋„ ์ œํ•œ์„ ์šฐํšŒ)
"""
y = int(day[:4])
nc = base_dir / f"era5_240_{y}_{hour}h.nc"
if not nc.exists():
raise FileNotFoundError(f"GT file not found for {day} {hour:02d}h: {nc}")
ds = xr.open_dataset(str(nc), engine="netcdf4")
tname = find_time_name(ds)
tgt = np.datetime64(f"{day}T{hour:02d}:00:00").astype("datetime64[s]")
times = ds[tname].values.astype("datetime64[s]")
idx = np.where(times == tgt)[0]
if idx.size == 0:
j = int(np.argmin(np.abs(times - tgt)))
ds = ds.isel({tname: j})
else:
ds = ds.isel({tname: int(idx[0])})
ds = ds.expand_dims({tname: 1})
return ds
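# Example: day="2023-01-05", hour=6 resolves to base_dir/era5_240_2023_6h.nc and
# selects the 2023-01-05T06:00:00 slice (falling back to the nearest available
# timestamp if the exact one is missing).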
def _extract_truth_for_time(ds, surf_vars, atmos_vars, target_H: int, target_W: int):
"""
ds: xr.Dataset (๋‹จ์ผ ์‹œ๊ฐ)
๋ฐ˜ํ™˜:
gt_surf_step: {var: torch.Tensor(H,W)}
gt_atm_step: {var: torch.Tensor(L,H,W)}
"""
lat, lon = find_coord_names(ds)
ds = ensure_order(ds, lat, lon)
tname = find_time_name(ds)
gt_surf_step = {}
for v in surf_vars:
names, shorts = SURF_NAMES_SHORTS.get(v, ([], []))
da = transpose_lat_lon_last(get_var(ds, names=names, shorts=shorts), lat, lon)
arr = da.transpose(tname, lat, lon).values # (1,H,W)
gt_surf_step[v] = torch.from_numpy(arr[0, :target_H, :target_W].astype(np.float32))
gt_atm_step = {}
for v in atmos_vars:
names, shorts = ATMOS_NAMES_SHORTS.get(v, ([], []))
da = transpose_lat_lon_last(get_var(ds, names=names, shorts=shorts), lat, lon)
lvl = find_level_name(ds)
arr = da.transpose(tname, lvl, lat, lon).values # (1,L,H,W)
gt_atm_step[v] = torch.from_numpy(arr[0, :, :target_H, :target_W].astype(np.float32))
return gt_surf_step, gt_atm_step
def _compute_day_hour_from_valid_time(valid_time: np.datetime64) -> Tuple[str, int]:
ts = pd.Timestamp(str(valid_time))
return ts.strftime("%Y-%m-%d"), int(ts.hour)
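# Example: np.datetime64("2023-01-05T06:00:00") -> ("2023-01-05", 6)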
def main():
ap = argparse.ArgumentParser(description="Aurora TEST evaluation (multi-day GT)")
    # Data / rollout
ap.add_argument("--base_dir", type=Path, required=True)
ap.add_argument("--test_years", type=str, default="2023")
ap.add_argument("--max_test_days", type=int, default=0)
ap.add_argument("--input_len", type=int, default=2)
ap.add_argument("--steps", type=int, default=28, help="ํ‰๊ฐ€ ๋ฆฌ๋“œํƒ€์ž„ ๊ฐœ์ˆ˜")
ap.add_argument("--lead_hour_stride", type=int, default=6, help="๋ฆฌ๋“œํƒ€์ž„ ์‹œ๊ฐ„ ๊ฐ„๊ฒฉ")
ap.add_argument("--lead_offset_mode", type=str, default="first", choices=["first","rolling","list"],
help="๋กœ๋”๋Š” ์ž…๋ ฅ๋งŒ ํ•„์š”ํ•˜๋ฏ€๋กœ first ๊ถŒ์žฅ")
ap.add_argument("--lead_offset_list", type=str, default="")
ap.add_argument("--cross_day_extra", type=int, default=2)
    # Precision / batch
ap.add_argument("--precision", type=str, default="bf16", choices=["bf16","fp16","fp32"])
ap.add_argument("--batch_size", type=int, default=1)
ap.add_argument("--num_workers", type=int, default=2)
    # Output / logging
ap.add_argument("--out_dir", type=Path, default=Path("./out_test"))
ap.add_argument("--tb_dir", type=Path, default=Path("./tb_test"))
    # Model / resolution
ap.add_argument("--deg_old", type=float, default=0.25)
ap.add_argument("--deg_new", type=float, default=1.5)
ap.add_argument("--patch_new", type=int, default=1)
ap.add_argument("--win_orig", type=str, default="2,6,12")
ap.add_argument("--win_min_hw", type=int, default=2)
ap.add_argument("--finetune_mode", type=str, default="lora", choices=["lora","hybrid","full"])
ap.add_argument("--lora_mode", type=str, default="single", choices=["single","all","from_second"])
ap.add_argument("--act_ckpt", action="store_true")
ap.add_argument("--disable_droppath", action="store_true")
ap.add_argument("--local_rank", type=int, default=0)
# ckpt
ap.add_argument("--ckpt_source", type=str, default="hf", choices=["hf", "local"])
ap.add_argument("--ckpt_repo", type=str, default="microsoft/aurora")
ap.add_argument("--ckpt_name", type=str, default="aurora-0.25-small-pretrained.ckpt")
ap.add_argument("--ckpt_revision", type=str, default=None)
ap.add_argument("--ckpt_path", type=Path, default=None)
ap.add_argument("--patch_old", type=int, default=-1)
ap.add_argument("--force_no_adapt", action="store_true")
    # Climatology
ap.add_argument("--clim_path", type=Path, default=None)
args = ap.parse_args()
test_years = parse_years(args.test_years)
# DDP
ddp = ("RANK" in os.environ and "WORLD_SIZE" in os.environ)
if ddp:
local_rank = init_ddp(args.local_rank)
else:
local_rank = int(os.environ.get("LOCAL_RANK", args.local_rank))
if torch.cuda.is_available():
torch.cuda.set_device(local_rank)
torch.backends.cudnn.benchmark = True
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    # Model
model, patch = setup_model(args)
model = model.to(device)
model.eval()
    # Logging
args.out_dir.mkdir(parents=True, exist_ok=True)
writer = SummaryWriter(str(args.tb_dir)) if get_rank() == 0 else None
    # NOTE: the loader only needs the inputs -> keep steps_for_loader small so samples are still generated
    steps_for_loader = 1  # 1 is recommended so samples are reliably produced
_, test_loader, test_sampler = build_loader(
args.base_dir, test_years, steps_for_loader, patch, args.max_test_days,
args.batch_size, args.num_workers, shuffle=False,
input_len=args.input_len, lead_offset_mode=args.lead_offset_mode,
lead_offset_list=args.lead_offset_list, cross_day_extra=args.cross_day_extra
)
if ddp and test_sampler is not None:
test_sampler.set_epoch(0)
    # Climatology
clim = Climatology(args.clim_path)
    # Variable lists
surf_vars = tuple(getattr(model, "surf_vars", DEFAULT_SURF_VARS))
atmos_vars = tuple(getattr(model, "atmos_vars", DEFAULT_ATMOS_VARS))
    # Aggregation buckets (per variable)
metrics = {
"surf": defaultdict(lambda: defaultdict(lambda: {"rmse_sum": 0.0, "rmse_n": 0,
"acc_sum": 0.0, "acc_n": 0})),
"atmos": defaultdict(lambda: defaultdict(lambda: {"rmse_sum": 0.0, "rmse_n": 0,
"acc_sum": 0.0, "acc_n": 0}))
}
    # Per-level aggregation
metrics_levels = defaultdict( # var
lambda: defaultdict( # level_value
lambda: defaultdict( # lead
lambda: {"rmse_sum": 0.0, "rmse_n": 0, "acc_sum": 0.0, "acc_n": 0}
)
)
)
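    # Access pattern: metrics_levels[var][level][lead_step] accumulates
    # {"rmse_sum", "rmse_n", "acc_sum", "acc_n"}; averages are formed at report time.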
rows_overall, rows_levels = [], []
# AMP
use_amp = args.precision in ("bf16", "fp16")
amp_dtype = (torch.bfloat16 if args.precision == "bf16"
else torch.float16 if args.precision == "fp16" else None)
if get_rank() == 0 and len(test_loader) == 0:
print("[WARN] ํ…Œ์ŠคํŠธ ๋กœ๋”๊ฐ€ ๋น„์—ˆ์Šต๋‹ˆ๋‹ค. (์ž…๋ ฅ/์Šคํ… ์กฐํ•ฉ์„ ํ™•์ธํ•˜์„ธ์š”) "
"์ด ์Šคํฌ๋ฆฝํŠธ๋Š” GT๋ฅผ ๋ฉ€ํ‹ฐ๋ฐ์ด๋กœ ์ง์ ‘ ๋กœ๋“œํ•˜๋‹ˆ, steps_for_loader=1๋กœ ์œ ์ง€ํ•ด๋„ ๋ฉ๋‹ˆ๋‹ค.")
pbar = tqdm(total=len(test_loader), disable=(get_rank() != 0), dynamic_ncols=True, desc="test")
with torch.no_grad():
for batch_list in test_loader:
sample = batch_list[0]
batch = sample.batch.to(device)
            # The model rollout is run for the true evaluation lead times
EVAL_STEPS = int(args.steps)
if use_amp:
with torch.amp.autocast('cuda', dtype=amp_dtype):
preds = list(rollout(model, batch, steps=EVAL_STEPS))
else:
preds = list(rollout(model, batch, steps=EVAL_STEPS))
            # Last time step of the input window
last_in_time = np.datetime64(batch.metadata.time[0], 's')
            # valid_time of each lead at evaluation time
valid_times = [
last_in_time + np.timedelta64((s + 1) * args.lead_hour_stride, "h")
for s in range(EVAL_STEPS)
]
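            # Example: with the defaults (--steps 28, --lead_hour_stride 6) the valid
            # times cover last_in_time + 6h up to + 168h, i.e. a 7-day forecast.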
            # Evaluation grid
H, W = sample.lat.shape[0], sample.lon.shape[0]
w_lat = lat_weights(sample.lat).to(device)
            # Level values (hPa when available)
levels_tuple = tuple(getattr(batch.metadata, "atmos_levels", tuple(range(preds[0].atmos_vars[atmos_vars[0]][0,0].shape[0]))))
            # ---- Load the true GT for every lead time ----
gt_surf_all: List[Dict[str, torch.Tensor]] = []
gt_atm_all: List[Dict[str, torch.Tensor]] = []
            # Load one time step at a time from NetCDF, using the day/hour derived from each valid time
for s, vt in enumerate(valid_times):
day, hour = _compute_day_hour_from_valid_time(vt)
ds_one = _open_one_step_from_base(args.base_dir, day, hour) # xr.Dataset
gt_surf_s, gt_atm_s = _extract_truth_for_time(ds_one, surf_vars, atmos_vars, H, W)
gt_surf_all.append(gt_surf_s)
gt_atm_all.append(gt_atm_s)
            # ---- SURF metrics ----
for s in range(EVAL_STEPS):
lead_hours = (s + 1) * args.lead_hour_stride
for v in surf_vars:
pred = preds[s].surf_vars[v][0, 0] # (H,W)
gt = gt_surf_all[s][v].to(pred.device) # (H,W)
rmse = rmse_lat_weighted(pred, gt, w_lat)
rmse_r = ar_mean(rmse)
rmse_val = float(rmse_r.item())
metrics["surf"][v][s]["rmse_sum"] += rmse_val
metrics["surf"][v][s]["rmse_n"] += 1
acc_val = float("nan")
if clim.ok:
names, shorts = SURF_NAMES_SHORTS.get(v, ([], []))
clim_da = clim.get(names, shorts, valid_times[s])
if clim_da is not None:
clim2 = torch.from_numpy(
clim_da.values[:H, :W].astype(np.float32)
).to(pred.device)
acc = acc_lat_weighted(pred, gt, clim2, w_lat)
if acc is not None and torch.isfinite(acc):
acc_r = ar_mean(acc)
acc_val = float(acc_r.item())
metrics["surf"][v][s]["acc_sum"] += acc_val
metrics["surf"][v][s]["acc_n"] += 1
if writer:
writer.add_scalar(f"test/surf/{v}/rmse/lead_{lead_hours}h", rmse_val)
if not math.isnan(acc_val):
writer.add_scalar(f"test/surf/{v}/acc/lead_{lead_hours}h", acc_val)
            # ---- ATMOS metrics (variable mean + per-level) ----
for s in range(EVAL_STEPS):
lead_hours = (s + 1) * args.lead_hour_stride
for v in atmos_vars:
pred3 = preds[s].atmos_vars[v][0, 0] # (L,H,W)
gt3 = gt_atm_all[s][v].to(pred3.device) # (L,H,W)
                    # Variable-mean RMSE
rmse_levels = [rmse_lat_weighted(pred3[li], gt3[li], w_lat)
for li in range(pred3.shape[0])]
if rmse_levels:
rmse_mean_r = ar_mean(torch.mean(torch.stack(rmse_levels)))
rmse_overall = float(rmse_mean_r.item())
metrics["atmos"][v][s]["rmse_sum"] += rmse_overall
metrics["atmos"][v][s]["rmse_n"] += 1
else:
rmse_overall = float("nan")
                    # Variable-mean ACC
acc_overall = float("nan")
if clim.ok:
names, shorts = ATMOS_NAMES_SHORTS.get(v, ([], []))
clim_da = clim.get(names, shorts, valid_times[s])
if clim_da is not None and clim_da.values.ndim == 3:
clim_vals = clim_da.values
L = min(pred3.shape[0], clim_vals.shape[0])
Hc = min(pred3.shape[1], clim_vals.shape[1])
Wc = min(pred3.shape[2], clim_vals.shape[2])
acc_levels = []
for li in range(L):
c2 = torch.from_numpy(clim_vals[li, :Hc, :Wc].astype(np.float32)).to(pred3.device)
p2 = pred3[li, :Hc, :Wc]
g2 = gt3[li, :Hc, :Wc]
acc_li = acc_lat_weighted(p2, g2, c2, w_lat[:Hc])
if acc_li is not None and torch.isfinite(acc_li):
acc_levels.append(acc_li)
if acc_levels:
acc_overall = float(ar_mean(torch.mean(torch.stack(acc_levels))).item())
metrics["atmos"][v][s]["acc_sum"] += acc_overall
metrics["atmos"][v][s]["acc_n"] += 1
if writer and not math.isnan(rmse_overall):
writer.add_scalar(f"test/atmos/{v}/rmse/lead_{lead_hours}h", rmse_overall)
if writer and not math.isnan(acc_overall):
writer.add_scalar(f"test/atmos/{v}/acc/lead_{lead_hours}h", acc_overall)
                    # Per-level RMSE/ACC
Lp = pred3.shape[0]
for li in range(Lp):
level_val = int(levels_tuple[li]) if li < len(levels_tuple) else li
# RMSE
rmse_li = float(ar_mean(rmse_lat_weighted(pred3[li], gt3[li], w_lat)).item())
agg = metrics_levels[v][level_val][s]
agg["rmse_sum"] += rmse_li
agg["rmse_n"] += 1
if writer:
writer.add_scalar(f"test/atmos/{v}/level_{level_val}/rmse/lead_{lead_hours}h", rmse_li)
# ACC per-level
if clim.ok:
names, shorts = ATMOS_NAMES_SHORTS.get(v, ([], []))
clim_da = clim.get(names, shorts, valid_times[s])
if clim_da is not None and clim_da.values.ndim == 3:
clim_vals = clim_da.values
Lc = min(pred3.shape[0], clim_vals.shape[0])
Hc = min(pred3.shape[1], clim_vals.shape[1])
Wc = min(pred3.shape[2], clim_vals.shape[2])
for li in range(Lc):
level_val = int(levels_tuple[li]) if li < len(levels_tuple) else li
c2 = torch.from_numpy(clim_vals[li, :Hc, :Wc].astype(np.float32)).to(pred3.device)
p2 = pred3[li, :Hc, :Wc]
g2 = gt3[li, :Hc, :Wc]
acc_li = acc_lat_weighted(p2, g2, c2, w_lat[:Hc])
if acc_li is not None and torch.isfinite(acc_li):
acc_li_val = float(ar_mean(acc_li).item())
agg = metrics_levels[v][level_val][s]
agg["acc_sum"] += acc_li_val
agg["acc_n"] += 1
if writer:
writer.add_scalar(f"test/atmos/{v}/level_{level_val}/acc/lead_{lead_hours}h", acc_li_val)
if get_rank() == 0:
pbar.update(1)
if writer:
writer.flush()
    # ---- Aggregation / console output / CSV ----
if get_rank() == 0:
print("\n=== TEST RESULTS (lat-weighted | variable-level averages) ===")
print("Domain | Var | Lead(h) | RMSE(avg over samples) | ACC(avg over samples)")
print("-------+------+---------+------------------------+----------------------")
rows_overall = []
for domain in ("surf","atmos"):
for v, per_lead in metrics[domain].items():
for s, agg in per_lead.items():
lead_hours = (s + 1) * args.lead_hour_stride
rmse_avg = (agg["rmse_sum"]/max(1,agg["rmse_n"])) if agg["rmse_n"]>0 else float("nan")
acc_avg = (agg["acc_sum"]/max(1,agg["acc_n"])) if agg["acc_n"]>0 else float("nan")
print(f"{domain:6s} | {v:4s} | {lead_hours:7d} | {rmse_avg:22.6f} | {acc_avg:20.6f}")
rows_overall.append({
"domain": domain, "var": v,
"lead_step": s+1, "lead_hours": lead_hours,
"rmse": rmse_avg, "acc": acc_avg,
"rmse_count": agg["rmse_n"], "acc_count": agg["acc_n"],
})
        df_overall = pd.DataFrame(rows_overall)
if not df_overall.empty:
df_overall = df_overall.sort_values(["domain","var","lead_step"])
csv_path_overall = args.out_dir / "glora_test_metrics.csv"
df_overall.to_csv(csv_path_overall, index=False)
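        # Columns of the overall CSV (one row per domain/var/lead):
        # domain, var, lead_step, lead_hours, rmse, acc, rmse_count, acc_count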
print("\n=== TEST RESULTS (lat-weighted | ATMOS PER-LEVEL) ===")
print("Var | Level | Lead(h) | RMSE(avg over samples) | ACC(avg over samples)")
print("-----+-------+---------+------------------------+----------------------")
rows_levels = []
for v, per_level in metrics_levels.items():
for level_val, per_lead in sorted(per_level.items(), key=lambda kv: kv[0]):
for s, agg in per_lead.items():
lead_hours = (s + 1) * args.lead_hour_stride
rmse_avg = (agg["rmse_sum"]/max(1,agg["rmse_n"])) if agg["rmse_n"]>0 else float("nan")
acc_avg = (agg["acc_sum"]/max(1,agg["acc_n"])) if agg["acc_n"]>0 else float("nan")
print(f"{v:4s} | {level_val:5d} | {lead_hours:7d} | {rmse_avg:22.6f} | {acc_avg:20.6f}")
rows_levels.append({
"domain":"atmos", "var":v, "level":int(level_val),
"lead_step": s+1, "lead_hours": lead_hours,
"rmse": rmse_avg, "acc": acc_avg,
"rmse_count": agg["rmse_n"], "acc_count": agg["acc_n"],
})
df_levels = pd.DataFrame(rows_levels)
if not df_levels.empty:
df_levels = df_levels.sort_values(["var","level","lead_step"])
csv_path_levels = args.out_dir / "glora_test_metrics_levels.csv"
df_levels.to_csv(csv_path_levels, index=False)
print(f"\nSaved CSV (overall): {csv_path_overall}")
print(f"Saved CSV (per-level): {csv_path_levels}")
if writer:
writer.close()
if ddp_is_init():
torch.distributed.barrier()
torch.distributed.destroy_process_group()
if __name__ == "__main__":
main()