| """Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" |
|
|
| import os |
| import pickle |
| from datetime import datetime |
| from typing import Generator, TypedDict |
|
|
| import numpy as np |
| import pytest |
| import torch |
| from huggingface_hub import hf_hub_download |
|
|
| from aurora import Batch, Metadata |
| from aurora.batch import interpolate_numpy |
|
|
|
|
| class SavedMetadata(TypedDict): |
| """Type of metadata of a saved test batch.""" |
|
|
| lat: np.ndarray |
| lon: np.ndarray |
| time: list[datetime] |
| atmos_levels: list[int | float] |
|
|
|
|
| class SavedBatch(TypedDict): |
| """Type of a saved test batch.""" |
|
|
| surf_vars: dict[str, np.ndarray] |
| static_vars: dict[str, np.ndarray] |
| atmos_vars: dict[str, np.ndarray] |
| metadata: SavedMetadata |
|
|
|
|
| @pytest.fixture() |
| def test_input_output() -> Generator[tuple[Batch, SavedBatch], None, None]: |
| |
| path = hf_hub_download( |
| repo_id="microsoft/aurora", |
| filename="aurora-0.25-small-pretrained-test-input.pickle", |
| ) |
| with open(path, "rb") as f: |
| test_input: SavedBatch = pickle.load(f) |
|
|
| |
| path = hf_hub_download( |
| repo_id="microsoft/aurora", |
| filename="aurora-0.25-static.pickle", |
| ) |
| with open(path, "rb") as f: |
| static_vars: dict[str, np.ndarray] = pickle.load(f) |
|
|
| |
| path = hf_hub_download( |
| repo_id="microsoft/aurora", |
| filename="aurora-0.25-small-pretrained-test-output.pickle", |
| ) |
| with open(path, "rb") as f: |
| test_output: SavedBatch = pickle.load(f) |
|
|
| |
| |
| if os.name == "nt": |
|
|
| class PatchedDateTime(datetime): |
| def timestamp(self) -> float: |
| |
| return -631134000.0 |
|
|
| test_input["metadata"]["time"] = [PatchedDateTime(1950, 1, 1, 6, 0)] |
|
|
| static_vars = { |
| k: interpolate_numpy( |
| v, |
| np.linspace(90, -90, v.shape[0]), |
| np.linspace(0, 360, v.shape[1], endpoint=False), |
| test_input["metadata"]["lat"], |
| test_input["metadata"]["lon"], |
| ) |
| for k, v in static_vars.items() |
| } |
|
|
| |
| batch = Batch( |
| surf_vars={k: torch.from_numpy(v) for k, v in test_input["surf_vars"].items()}, |
| static_vars={k: torch.from_numpy(v) for k, v in static_vars.items()}, |
| atmos_vars={k: torch.from_numpy(v) for k, v in test_input["atmos_vars"].items()}, |
| metadata=Metadata( |
| lat=torch.from_numpy(test_input["metadata"]["lat"]), |
| lon=torch.from_numpy(test_input["metadata"]["lon"]), |
| atmos_levels=tuple(test_input["metadata"]["atmos_levels"]), |
| time=tuple(test_input["metadata"]["time"]), |
| ), |
| ) |
|
|
| yield batch, test_output |
|
|