| import sys |
| import unittest |
| from pathlib import Path |
|
|
| import torch |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) |
|
|
| from iconoclast.direction import ( |
| compute_benign_subspace_basis, |
| project_directions_out_of_subspace, |
| ) |
|
|
|
|
class DirectionTests(unittest.TestCase):
    """Unit tests for the benign-subspace helpers in ``iconoclast.direction``."""

    def test_benign_subspace_projection_removes_principal_good_direction(self):
        # Four "good" residual samples of shape (4, 1, 3) whose only nonzero
        # entries lie on the first coordinate, symmetric around zero.  The
        # test expects a rank-1 benign basis built from them to span that
        # coordinate, so projecting [1, 1, 0] out of it should leave [0, 1, 0].
        samples = torch.tensor(
            [
                [[3.0, 0.0, 0.0]],
                [[1.0, 0.0, 0.0]],
                [[-1.0, 0.0, 0.0]],
                [[-3.0, 0.0, 0.0]],
            ]
        )
        benign_basis = compute_benign_subspace_basis(samples, rank=1)

        result = project_directions_out_of_subspace(
            torch.tensor([[1.0, 1.0, 0.0]]), benign_basis
        )

        # Compare component by component with the same 5-decimal tolerance.
        for column, want in enumerate((0.0, 1.0, 0.0)):
            self.assertAlmostEqual(result[0, column].item(), want, places=5)

    def test_zero_rank_benign_subspace_is_disabled(self):
        # Requesting rank=0 is expected to disable the benign subspace, so the
        # projection must return values exactly equal to its input.
        samples = torch.tensor(
            [
                [[1.0, 0.0]],
                [[-1.0, 0.0]],
            ]
        )
        benign_basis = compute_benign_subspace_basis(samples, rank=0)
        original = torch.tensor([[0.6, 0.8]])

        result = project_directions_out_of_subspace(original, benign_basis)

        self.assertTrue(torch.equal(result, original))
|
|
|
if __name__ == "__main__":
    # Run the tests defined in this module when it is executed as a script.
    unittest.main(module="__main__")
|
|