| import tempfile |
| import unittest |
| from pathlib import Path |
|
|
| import h5py |
| import numpy as np |
| import torch |
|
|
| from src.data.dataset import ( |
| SPACE_BANDS, |
| SPACE_TIME_BANDS, |
| STATIC_BANDS, |
| TIME_BANDS, |
| Dataset, |
| Normalizer, |
| to_cartesian, |
| ) |
|
|
# A tif known to be unreadable; the Dataset is expected to drop/replace it.
BROKEN_FILE = "min_lat=24.7979_min_lon=-105.1508_max_lat=24.8069_max_lon=-105.141_dates=2022-01-01_2023-12-31.tif"
# Two valid export tifs used as fixtures throughout the tests below.
TEST_FILENAMES = [
    "min_lat=5.4427_min_lon=101.4016_max_lat=5.4518_max_lon=101.4107_dates=2022-01-01_2023-12-31.tif",
    "min_lat=-27.6721_min_lon=25.6796_max_lat=-27.663_max_lon=25.6897_dates=2022-01-01_2023-12-31.tif",
]
# Fixture directory lives two levels up from this test file, under data/tifs.
TIFS_FOLDER = Path(__file__).parents[1] / "data/tifs"
TEST_FILES = list(map(TIFS_FOLDER.joinpath, TEST_FILENAMES))
|
|
|
|
class TestDataset(unittest.TestCase):
    """Tests for the Dataset pipeline: tif parsing, normalization, spatial
    subsetting, lat/lon validation, and h5py export."""

    def test_tif_to_array(self):
        """_tif_to_array yields finite arrays whose band axes match the band lists."""
        ds = Dataset(TIFS_FOLDER, download=False)
        for test_file in TEST_FILES:
            s_t_x, sp_x, t_x, st_x, months = ds._tif_to_array(test_file)
            # All four arrays must be free of NaNs and infs.
            for arr in (s_t_x, sp_x, t_x, st_x):
                self.assertFalse(np.isnan(arr).any())
                self.assertFalse(np.isinf(arr).any())
            # Spatial dims of the space-only array match the space-time array.
            self.assertEqual(sp_x.shape[0], s_t_x.shape[0])
            self.assertEqual(sp_x.shape[1], s_t_x.shape[1])
            # Temporal dim of the time-only array matches the space-time array.
            self.assertEqual(t_x.shape[0], s_t_x.shape[2])
            # The last axis of each array is its band dimension.
            self.assertEqual(len(SPACE_TIME_BANDS), s_t_x.shape[-1])
            self.assertEqual(len(SPACE_BANDS), sp_x.shape[-1])
            self.assertEqual(len(TIME_BANDS), t_x.shape[-1])
            self.assertEqual(len(STATIC_BANDS), st_x.shape[-1])
            # Fixture dates start 2022-01-01, so the first month index is 0 (January).
            self.assertEqual(months[0], 0)

    def test_files_are_replaced(self):
        """Iterating the dataset drops the known-broken tif from ds.tifs."""
        ds = Dataset(TIFS_FOLDER, download=False)
        self.assertIn(TIFS_FOLDER / BROKEN_FILE, ds.tifs)
        for sample in ds:
            # Each sample is the 5-tuple produced by _tif_to_array.
            self.assertEqual(len(sample), 5)
        # After a full pass, the broken tif should have been removed.
        self.assertNotIn(TIFS_FOLDER / BROKEN_FILE, ds.tifs)

    def test_normalization(self):
        """Normalization values are keyed by band-group length and yield NaN-free samples."""
        ds = Dataset(TIFS_FOLDER, download=False)
        norm_dicts = ds.load_normalization_values(path=Path("config/normalization.json"))
        for num_bands in [
            len(SPACE_TIME_BANDS),
            len(SPACE_BANDS),
            len(STATIC_BANDS),
            len(TIME_BANDS),
        ]:
            subdict = norm_dicts[num_bands]
            self.assertIn("mean", subdict)
            self.assertIn("std", subdict)
            self.assertEqual(len(subdict["mean"]), len(subdict["std"]))
        ds.normalizer = Normalizer(normalizing_dicts=norm_dicts)
        for sample in ds:
            for arr in sample:
                self.assertFalse(np.isnan(arr).any())

    def test_subset_image_with_minimum_size(self):
        """A subset request equal to the image size returns the inputs unchanged."""
        image = np.ones((3, 3, 1))
        months = static = np.ones(1)
        output = Dataset.subset_image(image, image, months, static, months, 3, 1)
        self.assertTrue(np.equal(image, output[0]).all())
        self.assertTrue(np.equal(image, output[1]).all())
        self.assertTrue(np.equal(months, output[2]).all())

    def test_subset_with_too_small_image(self):
        """Requesting a subset larger than the image raises an AssertionError."""
        image = np.ones((2, 2, 1))
        months = static = np.ones(1)
        self.assertRaises(
            AssertionError, Dataset.subset_image, image, image, months, static, months, 3, 1
        )

    def test_subset_with_larger_images(self):
        """A subset of a larger all-ones image has the requested shape and values."""
        image = np.ones((5, 5, 1))
        months = static = np.ones(1)
        output = Dataset.subset_image(image, image, months, static, months, 3, 1)
        self.assertTrue(np.equal(np.ones((3, 3, 1)), output[0]).all())
        self.assertTrue(np.equal(np.ones((3, 3, 1)), output[1]).all())
        self.assertTrue(np.equal(months, output[2]).all())

    def test_latlon_checks_float(self):
        """to_cartesian accepts in-range float lat/lon and rejects out-of-range values."""
        _ = to_cartesian(30.0, 40.0)
        with self.assertRaises(AssertionError):
            to_cartesian(1000.0, 1000.0)

    def test_latlon_checks_np(self):
        """to_cartesian accepts in-range numpy lat/lon and rejects out-of-range values."""
        _ = to_cartesian(np.array([30.0]), np.array([40.0]))
        with self.assertRaises(AssertionError):
            to_cartesian(np.array([1000.0]), np.array([1000.0]))

    def test_latlon_checks_tensor(self):
        """to_cartesian accepts in-range torch lat/lon and rejects out-of-range values."""
        _ = to_cartesian(torch.tensor([30.0]), torch.tensor([40.0]))
        with self.assertRaises(AssertionError):
            to_cartesian(torch.tensor([1000.0]), torch.tensor([1000.0]))

    def test_process_h5pys(self):
        """process_h5pys writes one .h5 per valid tif, with 24 monthly timesteps each."""
        with tempfile.TemporaryDirectory() as tempdir_str:
            tempdir = Path(tempdir_str)
            dataset = Dataset(
                TIFS_FOLDER,
                download=False,
                h5py_folder=tempdir,
                h5pys_only=False,
            )
            dataset.process_h5pys()
            h5py_files = list(tempdir.glob("*.h5"))
            # Two valid fixture tifs -> two h5 files (the broken tif produces none).
            self.assertEqual(len(h5py_files), 2)
            for h5_file in h5py_files:
                with h5py.File(h5_file, "r") as f:
                    # dates=2022-01-01_2023-12-31 spans 24 months.
                    self.assertEqual(f["t_x"].shape[0], 24)
|
|