| import random |
| import unittest |
|
|
| import torch |
| from einops import repeat |
|
|
| from src.masking import ( |
| MASKING_MODES, |
| MAX_MASKING_STRATEGIES, |
| SPACE_BAND_GROUPS_IDX, |
| SPACE_TIME_BANDS_GROUPS_IDX, |
| STATIC_BAND_GROUPS_IDX, |
| TIME_BAND_GROUPS_IDX, |
| batch_mask_random, |
| batch_mask_space, |
| batch_mask_time, |
| check_modes_for_conflicts, |
| filter_unmasking_mode_candidates, |
| weighted_sample_without_replacement, |
| ) |
|
|
|
|
| class TestMasking(unittest.TestCase): |
| def check_all_values_in_masks( |
| self, space_time_mask, space_mask, time_mask, static_mask, masking_modes, unmasking_modes |
| ): |
| self.assertTrue( |
| (space_time_mask == 2).any() |
| | (space_mask == 2).any() |
| | (time_mask == 2).any() |
| | (static_mask == 2).any(), |
| f"2 check failed for {masking_modes}, {unmasking_modes}", |
| ) |
| self.assertTrue( |
| (space_time_mask == 0).any() |
| | (space_mask == 0).any() |
| | (time_mask == 0).any() |
| | (static_mask == 0).any(), |
| f"0 check failed for {masking_modes}, {unmasking_modes}", |
| ) |
| self.assertTrue( |
| (space_time_mask == 1).any() |
| | (space_mask == 1).any() |
| | (time_mask == 1).any() |
| | (static_mask == 1).any(), |
| f"1 check failed for {masking_modes}, {unmasking_modes}", |
| ) |
|
|
| def test_mask_by_time(self): |
| |
| self._test_mask_by_for_f( |
| batch_mask_time, |
| [ |
| ("static", "LS"), |
| ("static", "location"), |
| ("space", "WC"), |
| ("space_time", "S2_SWIR"), |
| ("space", "DW"), |
| ("space_time", "S2_NIR_20m"), |
| ], |
| [("time", "TC"), ("time", "VIIRS")], |
| ) |
|
|
| for _ in range(100): |
| num_masking_modes = random.choice(list(range(2, MAX_MASKING_STRATEGIES + 1))) |
| num_unmasking_modes = 1 |
|
|
| masking_modes = weighted_sample_without_replacement( |
| MASKING_MODES, weights=[1] * len(MASKING_MODES), k=num_masking_modes |
| ) |
| unmasking_modes = weighted_sample_without_replacement( |
| MASKING_MODES, weights=[1] * len(MASKING_MODES), k=num_unmasking_modes |
| ) |
| self.assertTrue( |
| len(unmasking_modes) == num_unmasking_modes, f"Got {len(unmasking_modes)}" |
| ) |
| masking_modes, unmasking_modes = check_modes_for_conflicts( |
| masking_modes, unmasking_modes |
| ) |
| self.assertTrue( |
| len(unmasking_modes) == num_unmasking_modes, f"Got {len(unmasking_modes)}" |
| ) |
| self.assertTrue(len(masking_modes) >= 1, f"Got {len(masking_modes)}") |
| for m_m in masking_modes: |
| self.assertTrue(m_m not in unmasking_modes, f"{m_m} in {unmasking_modes}") |
| for u_m in unmasking_modes: |
| self.assertTrue(u_m not in masking_modes, f"{u_m} in {masking_modes}") |
| self.assertTrue(len(masking_modes) >= 1) |
| self.assertTrue(len(unmasking_modes) >= 1) |
| self._test_mask_by_for_f(batch_mask_space, masking_modes, unmasking_modes) |
|
|
| def test_mask_by_space(self): |
| for _ in range(100): |
| num_masking_modes = random.choice(list(range(2, MAX_MASKING_STRATEGIES + 1))) |
| num_unmasking_modes = 1 |
|
|
| masking_modes = weighted_sample_without_replacement( |
| MASKING_MODES, weights=[1] * len(MASKING_MODES), k=num_masking_modes |
| ) |
| unmasking_modes = weighted_sample_without_replacement( |
| MASKING_MODES, weights=[1] * len(MASKING_MODES), k=num_unmasking_modes |
| ) |
| self.assertTrue( |
| len(unmasking_modes) == num_unmasking_modes, f"Got {len(unmasking_modes)}" |
| ) |
| masking_modes, unmasking_modes = check_modes_for_conflicts( |
| masking_modes, unmasking_modes |
| ) |
| self.assertTrue( |
| len(unmasking_modes) == num_unmasking_modes, f"Got {len(unmasking_modes)}" |
| ) |
| self.assertTrue(len(masking_modes) >= 1, f"Got {len(masking_modes)}") |
| for m_m in masking_modes: |
| self.assertTrue(m_m not in unmasking_modes, f"{m_m} in {unmasking_modes}") |
| for u_m in unmasking_modes: |
| self.assertTrue(u_m not in masking_modes, f"{u_m} in {masking_modes}") |
| self.assertTrue(len(masking_modes) >= 1) |
| self.assertTrue(len(unmasking_modes) >= 1) |
| self._test_mask_by_for_f(batch_mask_space, masking_modes, unmasking_modes) |
|
|
| def _test_mask_by_for_f(self, f, masking_modes, unmasking_modes): |
| for t in range(4, 8): |
| b, h, w = 2, 16, 16 |
| space_time_input = torch.ones((b, h, w, t, 8)) |
| space_input = torch.ones((b, h, w, 8)) |
| time_input = torch.ones((b, t, 8)) |
| static_input = torch.ones((b, 8)) |
| months = repeat(torch.arange(0, t), "t -> b t", b=b) |
| ratio = 0.25 |
| output = f( |
| space_time_input, |
| space_input, |
| time_input, |
| static_input, |
| months, |
| encode_ratio=ratio, |
| decode_ratio=ratio, |
| mode=masking_modes, |
| decoder_mode=unmasking_modes, |
| patch_size=4, |
| ) |
| self.check_all_values_in_masks( |
| output.space_time_mask, |
| output.space_mask, |
| output.time_mask, |
| output.static_mask, |
| masking_modes, |
| unmasking_modes, |
| ) |
| self.assertEqual( |
| (b, h, w, t, len(SPACE_TIME_BANDS_GROUPS_IDX)), |
| output.space_time_mask.shape, |
| ) |
| self.assertEqual((b, h, w, len(SPACE_BAND_GROUPS_IDX)), output.space_mask.shape) |
| self.assertEqual((b, t, len(TIME_BAND_GROUPS_IDX)), output.time_mask.shape) |
| self.assertEqual((b, len(STATIC_BAND_GROUPS_IDX)), output.static_mask.shape) |
|
|
| def test_mask_by_random(self): |
| b, t, h, w, p = 2, 8, 16, 16, 4 |
| h_tokens, w_tokens = h / p, w / p |
| space_time_input = torch.ones((b, h, w, t, 8)) |
| space_input = torch.ones((b, h, w, 8)) |
| time_input = torch.ones((b, t, 8)) |
| static_input = torch.ones((b, 8)) |
| months = repeat(torch.arange(0, t), "t -> b t", b=b) |
| ratio = 0.25 |
|
|
| output = batch_mask_random( |
| space_time_input, |
| space_input, |
| time_input, |
| static_input, |
| months, |
| encode_ratio=ratio, |
| decode_ratio=ratio, |
| patch_size=p, |
| ignore_band_groups=["DW", "DW_static"], |
| ) |
| self.check_all_values_in_masks( |
| output.space_time_mask, |
| output.space_mask, |
| output.time_mask, |
| output.static_mask, |
| "random", |
| "random", |
| ) |
| self.assertEqual( |
| (b, h, w, t, len(SPACE_TIME_BANDS_GROUPS_IDX)), output.space_time_mask.shape |
| ) |
| self.assertEqual((b, h, w, len(SPACE_BAND_GROUPS_IDX)), output.space_mask.shape) |
| self.assertEqual((b, t, len(TIME_BAND_GROUPS_IDX)), output.time_mask.shape) |
| self.assertEqual((b, len(STATIC_BAND_GROUPS_IDX)), output.static_mask.shape) |
|
|
| for i in range(1, p): |
| self.assertTrue( |
| torch.equal( |
| output.space_time_mask[:, i::p, i::p], |
| output.space_time_mask[:, i - 1 :: p, i - 1 :: p], |
| ) |
| ) |
| self.assertTrue( |
| torch.equal( |
| output.space_mask[:, i::p, i::p], |
| output.space_mask[:, i - 1 :: p, i - 1 :: p], |
| ) |
| ) |
| space_time_per_token = output.space_time_mask[:, i::p, i::p] |
| space_time_encode_per_instance = ( |
| space_time_per_token[space_time_per_token == 0] + 1 |
| ).sum() |
| space_time_decode_per_instance = space_time_per_token[space_time_per_token == 2].sum() |
| space_per_token = output.space_mask[:, i::p, i::p] |
| space_encode_per_instance = (space_per_token[space_per_token == 0] + 1).sum() |
| space_decode_per_instance = space_per_token[space_per_token == 2].sum() |
| time_per_token = output.time_mask |
| time_encode_per_instance = (time_per_token[time_per_token == 0] + 1).sum() |
| time_decode_per_instance = time_per_token[time_per_token == 2].sum() |
| static_per_token = output.static_mask |
| static_encode_per_instance = (static_per_token[static_per_token == 0] + 1).sum() |
| static_decode_per_instance = static_per_token[static_per_token == 2].sum() |
| total_tokens = ( |
| (h_tokens * w_tokens * t * len(SPACE_TIME_BANDS_GROUPS_IDX)) |
| |
| + (h_tokens * w_tokens * (len(SPACE_BAND_GROUPS_IDX) - 1)) |
| + (t * len(TIME_BAND_GROUPS_IDX)) |
| |
| + (len(STATIC_BAND_GROUPS_IDX) - 1) |
| ) * b |
| |
| self.assertTrue( |
| ( |
| space_time_encode_per_instance |
| + space_encode_per_instance |
| + time_encode_per_instance |
| + static_encode_per_instance |
| <= int(total_tokens * ratio) + 1 |
| ).all() |
| and ( |
| space_time_encode_per_instance |
| + space_encode_per_instance |
| + time_encode_per_instance |
| + static_encode_per_instance |
| >= int(total_tokens * ratio) - 1 |
| ).all() |
| ) |
| self.assertTrue( |
| ( |
| |
| |
| ( |
| space_time_decode_per_instance |
| + space_decode_per_instance |
| + time_decode_per_instance |
| + static_decode_per_instance |
| ) |
| / 2 |
| <= (int(total_tokens * ratio) + 1) |
| ).all() |
| and ( |
| ( |
| space_time_decode_per_instance |
| + space_decode_per_instance |
| + time_decode_per_instance |
| + static_decode_per_instance |
| ) |
| / 2 |
| >= int(total_tokens * ratio) - 1 |
| ).all() |
| ) |
|
|
| |
| expected_mask_index = list(SPACE_BAND_GROUPS_IDX.keys()).index("DW") |
| self.assertTrue((output.space_mask[:, :, :, expected_mask_index] == 1).all()) |
|
|
| expected_mask_index = list(STATIC_BAND_GROUPS_IDX.keys()).index("DW_static") |
| self.assertTrue((output.static_mask[:, expected_mask_index] == 1).all()) |
|
|
| def test_filter_candidates(self): |
| candidates = [ |
| [("space", "SRTM"), ("space_time", "S2_RGB")], |
| [("space", "SRTM"), ("space", "DW")], |
| [("space", "DW")], |
| ] |
| ignore_bands = ["DW"] |
| outputs = filter_unmasking_mode_candidates(candidates, ignore_bands) |
| print(outputs) |
| self.assertEqual(outputs, [[("space", "SRTM"), ("space_time", "S2_RGB")]]) |
|
|