| import unittest |
|
|
| import torch |
|
|
| from src.data.utils import ( |
| S2_BANDS, |
| SPACE_TIME_BANDS, |
| SPACE_TIME_BANDS_GROUPS_IDX, |
| construct_galileo_input, |
| ) |
|
|
|
|
class TestDataUtils(unittest.TestCase):
    """Tests for helpers imported from ``src.data.utils``."""

    def test_construct_galileo_input_s2(self):
        """Check masks and values when ``construct_galileo_input`` gets only S2.

        The helper is called with only an ``s2`` tensor, once with
        ``normalize=True`` and once with ``normalize=False``. The test expects:

        * ``space_mask``, ``time_mask`` and ``static_mask`` to be entirely 1;
        * in ``space_time_mask``, every channel group whose name does not
          contain ``"S2"`` to be 1 and every group whose name does to be 0;
        * without normalization, the S2 channels of ``space_time_x`` to equal
          the input tensor exactly.
        """
        t, h, w = 2, 4, 4
        s2 = torch.randn((t, h, w, len(S2_BANDS)))

        # These index lists depend only on module-level constants, so build
        # them once rather than on every loop iteration.
        not_s2 = [
            idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S2" not in key
        ]
        s2_mask_indices = [
            idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S2" in key
        ]
        s2_indices = [idx for idx, val in enumerate(SPACE_TIME_BANDS) if val in S2_BANDS]

        # Guard against vacuous passes: ``.all()`` over an empty slice is True,
        # so an empty index list would make the mask checks below succeed
        # without testing anything.
        self.assertTrue(
            s2_mask_indices, "no S2 channel groups found in SPACE_TIME_BANDS_GROUPS_IDX"
        )
        # The value comparison below needs one space-time channel per S2 band;
        # check it explicitly so a mismatch gives a readable failure.
        self.assertEqual(len(s2_indices), len(S2_BANDS))

        for normalize in [True, False]:
            # subTest reports which ``normalize`` value failed and still runs
            # the remaining iteration.
            with self.subTest(normalize=normalize):
                masked_output = construct_galileo_input(s2=s2, normalize=normalize)

                self.assertTrue((masked_output.space_mask == 1).all())
                self.assertTrue((masked_output.time_mask == 1).all())
                self.assertTrue((masked_output.static_mask == 1).all())

                # Only the S2 channel groups should be unmasked (0);
                # everything else in the space-time mask should be 1.
                self.assertTrue((masked_output.space_time_mask[:, :, :, not_s2] == 1).all())
                self.assertTrue(
                    (masked_output.space_time_mask[:, :, :, s2_mask_indices] == 0).all()
                )

                # Without normalization, the S2 values should pass through unchanged.
                if not normalize:
                    self.assertTrue(
                        torch.equal(masked_output.space_time_x[:, :, :, s2_indices], s2)
                    )
|
|