| """Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" |
|
|
| import dataclasses |
| import os |
| from datetime import timedelta |
| from typing import cast |
|
|
| import numpy as np |
| import pytest |
| import torch |
| import torch.distributed as dist |
|
|
| from tests.conftest import SavedBatch |
|
|
| from aurora import Aurora, AuroraSmallPretrained, Batch |
|
|
|
|
@pytest.fixture(scope="session")
def aurora_small() -> Aurora:
    """Return a session-wide small Aurora model in double precision, eval mode."""
    m = AuroraSmallPretrained(use_lora=True)
    # `strict=False` because the LoRA parameters are not part of the checkpoint.
    m.load_checkpoint(strict=False)
    # `.double()` and `.eval()` both return the module, so they can be chained.
    return m.double().eval()
|
|
|
|
def test_aurora_small(aurora_small: Aurora, test_input_output: tuple[Batch, SavedBatch]) -> None:
    """Check the small model's predictions against saved reference output."""
    batch, ref = test_input_output

    # Double the batch dimension to also exercise batch sizes bigger than one.
    batch = dataclasses.replace(
        batch,
        surf_vars={k: v.repeat(2, 1, 1, 1) for k, v in batch.surf_vars.items()},
        atmos_vars={k: v.repeat(2, 1, 1, 1, 1) for k, v in batch.atmos_vars.items()},
    )
    # Mirror the doubling in the saved reference arrays (numpy `repeat` on axis 0).
    ref = cast(SavedBatch, dict(ref))
    for group in ("surf_vars", "atmos_vars"):
        ref[group] = {k: v.repeat(2, axis=0) for k, v in ref[group].items()}

    with torch.inference_mode():
        pred = aurora_small.forward(batch)

    def check_close(actual: np.ndarray, expected: np.ndarray, tol: float) -> None:
        # Relative mean absolute error must stay within `tol`.
        rel_err = np.abs(actual - expected).mean() / np.abs(expected).mean()
        assert rel_err <= tol

    # Per-variable relative tolerances.
    tolerances = {
        "2t": 1e-4,
        "10u": 5e-3,
        "10v": 5e-3,
        "msl": 1e-4,
        "u": 5e-3,
        "v": 5e-3,
        "t": 1e-4,
        "q": 5e-3,
    }

    # Surface and atmospheric variables are compared against the saved
    # reference; static variables must pass through (nearly) unchanged.
    for k, v in pred.surf_vars.items():
        check_close(v.numpy(), ref["surf_vars"][k], tolerances[k])
    for k, v in pred.static_vars.items():
        check_close(v.numpy(), batch.static_vars[k].numpy(), 1e-10)
    for k, v in pred.atmos_vars.items():
        check_close(v.numpy(), ref["atmos_vars"][k], tolerances[k])

    # The metadata must be preserved exactly.
    np.testing.assert_allclose(pred.metadata.lon, ref["metadata"]["lon"])
    np.testing.assert_allclose(pred.metadata.lat, ref["metadata"]["lat"])
    assert pred.metadata.atmos_levels == tuple(ref["metadata"]["atmos_levels"])
    assert pred.metadata.time == tuple(ref["metadata"]["time"])
|
|
|
|
@pytest.mark.skipif(
    os.name == "nt",
    reason=(
        "DDP test suddenly started failing with error "
        "`RuntimeError: makeDeviceForHostname(): unsupported gloo device` on Windows."
    ),
)
def test_aurora_small_ddp(
    aurora_small: Aurora, test_input_output: tuple[Batch, SavedBatch]
) -> None:
    """Smoke-test the small model wrapped in DistributedDataParallel."""
    batch, _ = test_input_output

    # Spin up a single-process gloo group if one is not already running. The
    # rendezvous address must be set before initialisation.
    if not dist.is_initialized():
        os.environ.update(MASTER_ADDR="localhost", MASTER_PORT="12355")
        dist.init_process_group("gloo", rank=0, world_size=1)

    ddp_model = torch.nn.parallel.DistributedDataParallel(aurora_small)

    # One forward pass suffices; the output itself is not checked here.
    with torch.inference_mode():
        ddp_model.forward(batch)
|
|
|
|
def test_aurora_small_decoder_init() -> None:
    """Decoder heads must start with non-zero weights and zero biases."""
    model = AuroraSmallPretrained(use_lora=True)

    heads = list(model.decoder.surf_heads.values()) + list(model.decoder.atmos_heads.values())
    for head in heads:
        # Weights should be randomly initialised, biases zero-initialised.
        assert not torch.all(head.weight == 0)
        assert torch.all(head.bias == 0)
|
|
|
|
def test_aurora_small_lat_lon_matrices(
    aurora_small: Aurora, test_input_output: tuple[Batch, SavedBatch]
) -> None:
    """Matrix-valued lat/lon coordinates must reproduce the vector predictions.

    Runs the model once with 1D latitude/longitude vectors and once with the
    equivalent broadcast 2D matrices, and requires both outputs to agree.
    """
    batch, test_output = test_input_output

    with torch.inference_mode():
        pred = aurora_small.forward(batch)

    # Broadcast the 1D coordinate vectors into (n_lat, n_lon) matrices.
    # NOTE(review): this mutates the fixture's batch in place — assumes
    # `test_input_output` is function-scoped; confirm so later tests are
    # unaffected.
    n_lat = len(batch.metadata.lat)
    n_lon = len(batch.metadata.lon)
    batch.metadata.lat = batch.metadata.lat[:, None].expand(n_lat, n_lon)
    batch.metadata.lon = batch.metadata.lon[None, :].expand(n_lat, n_lon)

    # Fix: run the second forward pass under `inference_mode` too, matching the
    # first pass. The original ran it with autograd enabled, needlessly building
    # a graph and producing grad-tracking outputs.
    with torch.inference_mode():
        pred_matrix = aurora_small.forward(batch)

    # The two passes must agree on every variable.
    for k in pred.surf_vars:
        np.testing.assert_allclose(
            pred.surf_vars[k],
            pred_matrix.surf_vars[k],
            rtol=1e-5,
        )
    for k in pred.static_vars:
        np.testing.assert_allclose(
            pred.static_vars[k],
            pred_matrix.static_vars[k],
            rtol=1e-5,
        )
    for k in pred.atmos_vars:
        np.testing.assert_allclose(
            pred.atmos_vars[k],
            pred_matrix.atmos_vars[k],
            rtol=1e-5,
        )
|
|
|
|
def test_aurora_small_flags(test_input_output: tuple[Batch, SavedBatch]) -> None:
    """Different model construction flags must yield visibly different predictions."""
    batch, _ = test_input_output

    flag_collections: list[dict] = [
        {},
        {"stabilise_level_agg": True},
        {"timestep": timedelta(hours=12)},
    ]

    # Run one forward pass per flag configuration.
    preds = []
    for flags in flag_collections:
        model = AuroraSmallPretrained(use_lora=True, **flags)
        model.load_checkpoint(
            "microsoft/aurora",
            "aurora-0.25-small-pretrained.ckpt",
            strict=False,
        )
        model = model.double().eval()
        with torch.inference_mode():
            preds.append(model.forward(batch).normalise(model.surf_stats))

    # Every pair of configurations must disagree on the dynamic variables while
    # agreeing on the static ones, which the model merely passes through.
    for i, first in enumerate(preds):
        for second in preds[i + 1 :]:
            for k, v in first.surf_vars.items():
                assert not np.allclose(v, second.surf_vars[k], rtol=5e-2)
            for k, v in first.static_vars.items():
                np.testing.assert_allclose(v, second.static_vars[k], rtol=1e-5)
            for k, v in first.atmos_vars.items():
                assert not np.allclose(v, second.atmos_vars[k], rtol=5e-2)
|
|