Transcendental-Programmer
commited on
Commit
·
341b6b4
1
Parent(s):
e3af1ef
Add unit tests for core modules: latent explorer, attribute directions, and custom loss
Browse files- tests/test_attribute_directions.py +23 -0
- tests/test_custom_loss.py +28 -0
- tests/test_latent_explorer.py +44 -0
tests/test_attribute_directions.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
import numpy as np
|
| 3 |
+
from faceforge_core.attribute_directions import LatentDirectionFinder
|
| 4 |
+
|
| 5 |
+
class TestLatentDirectionFinder(unittest.TestCase):
|
| 6 |
+
def setUp(self):
|
| 7 |
+
# 100 samples, 5D latent
|
| 8 |
+
self.latents = np.random.randn(100, 5)
|
| 9 |
+
self.labels = [0]*50 + [1]*50
|
| 10 |
+
self.finder = LatentDirectionFinder(self.latents)
|
| 11 |
+
|
| 12 |
+
def test_pca_direction(self):
|
| 13 |
+
components, explained = self.finder.pca_direction(n_components=2)
|
| 14 |
+
self.assertEqual(components.shape, (2, 5))
|
| 15 |
+
self.assertEqual(explained.shape, (2,))
|
| 16 |
+
|
| 17 |
+
def test_classifier_direction(self):
|
| 18 |
+
direction = self.finder.classifier_direction(self.labels)
|
| 19 |
+
self.assertEqual(direction.shape, (5,))
|
| 20 |
+
self.assertAlmostEqual(np.linalg.norm(direction), 1.0, places=5)
|
| 21 |
+
|
| 22 |
+
if __name__ == "__main__":
|
| 23 |
+
unittest.main()
|
tests/test_custom_loss.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
import torch
|
| 3 |
+
from faceforge_core.custom_loss import attribute_preserving_loss
|
| 4 |
+
|
| 5 |
+
class TestAttributePreservingLoss(unittest.TestCase):
|
| 6 |
+
def setUp(self):
|
| 7 |
+
self.generated = torch.ones((2, 3, 4, 4))
|
| 8 |
+
self.original = torch.zeros((2, 3, 4, 4))
|
| 9 |
+
self.y_target = torch.ones((2, 1))
|
| 10 |
+
self.attr_predictor = lambda x: torch.ones((2, 1))
|
| 11 |
+
|
| 12 |
+
def test_loss_value(self):
|
| 13 |
+
loss = attribute_preserving_loss(
|
| 14 |
+
self.generated, self.original, self.attr_predictor, self.y_target, lambda_pred=2.0, lambda_recon=3.0
|
| 15 |
+
)
|
| 16 |
+
# pred_loss = 0, recon_loss = mean((1-0)^2) = 1
|
| 17 |
+
self.assertAlmostEqual(loss.item(), 3.0)
|
| 18 |
+
|
| 19 |
+
def test_loss_with_nonzero_pred(self):
|
| 20 |
+
attr_predictor = lambda x: torch.zeros((2, 1))
|
| 21 |
+
loss = attribute_preserving_loss(
|
| 22 |
+
self.generated, self.original, attr_predictor, self.y_target, lambda_pred=2.0, lambda_recon=3.0
|
| 23 |
+
)
|
| 24 |
+
# pred_loss = mean((0-1)^2) = 1, recon_loss = 1
|
| 25 |
+
self.assertAlmostEqual(loss.item(), 5.0)
|
| 26 |
+
|
| 27 |
+
if __name__ == "__main__":
|
| 28 |
+
unittest.main()
|
tests/test_latent_explorer.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
import numpy as np
|
| 3 |
+
from faceforge_core.latent_explorer import LatentSpaceExplorer, LatentPoint
|
| 4 |
+
|
| 5 |
+
class TestLatentSpaceExplorer(unittest.TestCase):
|
| 6 |
+
def setUp(self):
|
| 7 |
+
self.explorer = LatentSpaceExplorer()
|
| 8 |
+
self.dummy_encoding = np.array([1.0, 2.0])
|
| 9 |
+
|
| 10 |
+
def test_add_point(self):
|
| 11 |
+
self.explorer.add_point("test", self.dummy_encoding, (0.5, 0.5))
|
| 12 |
+
self.assertEqual(len(self.explorer.points), 1)
|
| 13 |
+
self.assertEqual(self.explorer.points[0].text, "test")
|
| 14 |
+
np.testing.assert_array_equal(self.explorer.points[0].encoding, self.dummy_encoding)
|
| 15 |
+
self.assertEqual(self.explorer.points[0].xy_pos, (0.5, 0.5))
|
| 16 |
+
|
| 17 |
+
def test_delete_point(self):
|
| 18 |
+
self.explorer.add_point("test", self.dummy_encoding)
|
| 19 |
+
self.explorer.delete_point(0)
|
| 20 |
+
self.assertEqual(len(self.explorer.points), 0)
|
| 21 |
+
|
| 22 |
+
def test_modify_point(self):
|
| 23 |
+
self.explorer.add_point("test", self.dummy_encoding)
|
| 24 |
+
new_encoding = np.array([3.0, 4.0])
|
| 25 |
+
self.explorer.modify_point(0, "new", new_encoding)
|
| 26 |
+
self.assertEqual(self.explorer.points[0].text, "new")
|
| 27 |
+
np.testing.assert_array_equal(self.explorer.points[0].encoding, new_encoding)
|
| 28 |
+
|
| 29 |
+
def test_sample_encoding_distance(self):
|
| 30 |
+
self.explorer.add_point("a", np.array([1.0, 0.0]), (0.0, 0.0))
|
| 31 |
+
self.explorer.add_point("b", np.array([0.0, 1.0]), (1.0, 0.0))
|
| 32 |
+
sampled = self.explorer.sample_encoding((0.5, 0.0), mode="distance")
|
| 33 |
+
self.assertIsNotNone(sampled)
|
| 34 |
+
self.assertEqual(sampled.shape, (2,))
|
| 35 |
+
|
| 36 |
+
def test_sample_encoding_circle(self):
|
| 37 |
+
self.explorer.add_point("a", np.array([1.0, 0.0]), (1.0, 0.0))
|
| 38 |
+
self.explorer.add_point("b", np.array([0.0, 1.0]), (0.0, 1.0))
|
| 39 |
+
sampled = self.explorer.sample_encoding((1.0, 1.0), mode="circle")
|
| 40 |
+
self.assertIsNotNone(sampled)
|
| 41 |
+
self.assertEqual(sampled.shape, (2,))
|
| 42 |
+
|
| 43 |
+
if __name__ == "__main__":
|
| 44 |
+
unittest.main()
|