| import unittest |
| import pytest |
|
|
| pytest.importorskip("torch") |
| pytest.importorskip("tensorly") |
| import torch |
| import numpy as np |
| import tensorly as tl |
| from tensorly.cp_tensor import cp_to_tensor |
| from tensorly.tucker_tensor import tucker_to_tensor |
| from tensorly.tt_tensor import tt_to_tensor |
| from tensorly.tr_tensor import tr_to_tensor |
| from typing import List, Tuple, Union, Dict |
| import sys |
| import os |
|
|
| try: |
| import htensor |
| HTENSOR_AVAILABLE = True |
| except ImportError: |
| HTENSOR_AVAILABLE = False |
|
|
| from scipy.fft import fft, ifft |
|
|
| |
| |
|
|
| from tensorus.tensor_ops import TensorOps |
| from tensorus.tensor_decompositions import TensorDecompositionOps |
|
|
| class TestTensorOps(unittest.TestCase): |
|
|
| |
def test_add_tensor_tensor(self):
    """TensorOps.add sums two same-shaped tensors element-wise."""
    lhs = torch.tensor([[1., 2.], [3., 4.]])
    rhs = torch.tensor([[5., 6.], [7., 8.]])
    wanted = torch.tensor([[6., 8.], [10., 12.]])
    self.assertTrue(torch.equal(TensorOps.add(lhs, rhs), wanted))
|
|
def test_add_tensor_scalar(self):
    """TensorOps.add broadcasts a Python scalar onto every element."""
    base = torch.tensor([[1., 2.], [3., 4.]])
    offset = 10.
    wanted = torch.tensor([[11., 12.], [13., 14.]])
    shifted = TensorOps.add(base, offset)
    self.assertTrue(torch.equal(shifted, wanted))
|
|
def test_add_type_error(self):
    """A string operand is rejected by TensorOps.add with TypeError."""
    vec = torch.tensor([1., 2.])
    self.assertRaises(TypeError, TensorOps.add, vec, "not_a_tensor_or_scalar")
|
|
def test_subtract_tensor_tensor(self):
    """TensorOps.subtract takes the element-wise difference of two tensors."""
    minuend = torch.tensor([[5., 6.], [7., 8.]])
    subtrahend = torch.tensor([[1., 2.], [3., 4.]])
    wanted = torch.tensor([[4., 4.], [4., 4.]])
    self.assertTrue(torch.equal(TensorOps.subtract(minuend, subtrahend), wanted))
|
|
def test_subtract_tensor_scalar(self):
    """Subtracting the scalar 1.0 lowers every element by one."""
    source = torch.tensor([[5., 6.], [7., 8.]])
    wanted = torch.tensor([[4., 5.], [6., 7.]])
    lowered = TensorOps.subtract(source, 1.0)
    self.assertTrue(torch.equal(lowered, wanted))
|
|
def test_multiply_tensor_tensor(self):
    """TensorOps.multiply is an element-wise (Hadamard) product."""
    values = torch.tensor([[1., 2.], [3., 4.]])
    twos = torch.tensor([[2., 2.], [2., 2.]])
    wanted = torch.tensor([[2., 4.], [6., 8.]])
    self.assertTrue(torch.equal(TensorOps.multiply(values, twos), wanted))
|
|
def test_multiply_tensor_scalar(self):
    """Multiplying by the scalar 2.0 doubles every element."""
    values = torch.tensor([[1., 2.], [3., 4.]])
    wanted = torch.tensor([[2., 4.], [6., 8.]])
    doubled = TensorOps.multiply(values, 2.0)
    self.assertTrue(torch.equal(doubled, wanted))
|
|
def test_divide_tensor_tensor(self):
    """TensorOps.divide divides two same-shaped tensors element-wise."""
    numerator = torch.tensor([[10., 20.], [30., 40.]])
    denominator = torch.tensor([[2., 5.], [3., 4.]])
    wanted = torch.tensor([[5., 4.], [10., 10.]])
    self.assertTrue(torch.equal(TensorOps.divide(numerator, denominator), wanted))
|
|
def test_divide_tensor_scalar(self):
    """Dividing by the scalar 10 scales every element down by ten."""
    numerator = torch.tensor([[10., 20.], [30., 40.]])
    divisor = 10.
    wanted = torch.tensor([[1., 2.], [3., 4.]])
    quotient = TensorOps.divide(numerator, divisor)
    self.assertTrue(torch.equal(quotient, wanted))
|
|
def test_divide_by_zero_scalar(self):
    """Dividing by the scalar 0 must raise ValueError."""
    numerator = torch.tensor([[10., 20.], [30., 40.]])
    self.assertRaises(ValueError, TensorOps.divide, numerator, 0)
|
|
def test_divide_by_zero_tensor(self):
    """A zero inside a divisor tensor is not an error.

    Unlike the scalar-zero case, the result must match ``torch.divide``
    exactly, including the non-finite value at the zero position.
    """
    numerator = torch.tensor([[10., 20.], [30., 40.]])
    has_zero = torch.tensor([[1., 0.], [3., 1.]])
    reference = torch.divide(numerator, has_zero)
    self.assertTrue(torch.equal(TensorOps.divide(numerator, has_zero), reference))
| |
|
|
| |
def test_matmul_valid(self):
    """TensorOps.matmul agrees with torch.matmul on compatible 2x2 operands."""
    a = torch.tensor([[1., 2.], [3., 4.]])
    b = torch.tensor([[5., 6.], [7., 8.]])
    reference = torch.matmul(a, b)
    product = TensorOps.matmul(a, b)
    self.assertTrue(torch.equal(product, reference))
|
|
def test_matmul_shape_mismatch(self):
    """Incompatible inner dimensions (2x2 @ 3x3) must raise ValueError."""
    lhs = torch.tensor([[1., 2.], [3., 4.]])
    # The right operand has 3 rows, but lhs has only 2 columns, so the
    # product is undefined.
    rhs = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
    with self.assertRaises(ValueError):
        TensorOps.matmul(lhs, rhs)
|
|
def test_matmul_ndim_error(self):
    """A 0-d (scalar) tensor operand is rejected with ValueError."""
    zero_dim = torch.tensor(1.)
    vector = torch.tensor([1., 2.])
    self.assertRaises(ValueError, TensorOps.matmul, zero_dim, vector)
|
|
def test_outer_valid(self):
    """The outer product of lengths 2 and 3 matches torch.outer."""
    left = torch.tensor([1., 2.])
    right = torch.tensor([3., 4., 5.])
    reference = torch.outer(left, right)
    self.assertTrue(torch.equal(TensorOps.outer(left, right), reference))
|
|
def test_outer_invalid_rank(self):
    """A 2-D first operand is rejected by TensorOps.outer with ValueError."""
    row_matrix = torch.tensor([[1., 2.]])
    vector = torch.tensor([1., 2.])
    self.assertRaises(ValueError, TensorOps.outer, row_matrix, vector)
|
|
def test_cross_valid(self):
    """The cross product of the x and y unit vectors matches torch.cross."""
    e_x = torch.tensor([1., 0., 0.])
    e_y = torch.tensor([0., 1., 0.])
    reference = torch.cross(e_x, e_y, dim=0)
    self.assertTrue(torch.equal(TensorOps.cross(e_x, e_y), reference))
|
|
def test_cross_invalid_dim_size(self):
    """Length-4 vectors along dim 0 are rejected; the cross dim must have size 3."""
    a = torch.tensor([1., 2., 3., 4.])
    b = torch.tensor([4., 5., 6., 7.])
    self.assertRaises(ValueError, TensorOps.cross, a, b, dim=0)
|
|
def test_dot_valid(self):
    """The dot product of two length-3 vectors matches torch.dot."""
    u = torch.tensor([1., 2., 3.])
    v = torch.tensor([4., 5., 6.])
    reference = torch.dot(u, v)
    self.assertTrue(torch.equal(TensorOps.dot(u, v), reference))
|
|
def test_dot_shape_mismatch(self):
    """Vectors of different lengths (3 vs 2) are rejected with ValueError."""
    longer = torch.tensor([1., 2., 3.])
    shorter = torch.tensor([1., 2.])
    self.assertRaises(ValueError, TensorOps.dot, longer, shorter)
|
|
def test_dot_invalid_rank(self):
    """A 2-D operand is rejected by TensorOps.dot with ValueError."""
    matrix = torch.tensor([[1., 2.], [3., 4.]])
    vector = torch.tensor([1., 2.])
    self.assertRaises(ValueError, TensorOps.dot, matrix, vector)
|
|
|
|
| |
def test_sum_all_elements(self):
    """With no dim, TensorOps.sum reduces all elements to a 0-d tensor."""
    grid = torch.tensor([[1., 2.], [3., 4.]])
    total = TensorOps.sum(grid)
    self.assertTrue(torch.equal(total, torch.tensor(10.)))
|
|
def test_sum_along_dimension(self):
    """Summing along dim 0 collapses rows; along dim 1 collapses columns."""
    grid = torch.tensor([[1., 2.], [3., 4.]])
    cases = (
        (0, torch.tensor([4., 6.])),
        (1, torch.tensor([3., 7.])),
    )
    for dim, wanted in cases:
        self.assertTrue(torch.equal(TensorOps.sum(grid, dim=dim), wanted))
|
|
def test_sum_keepdim(self):
    """keepdim=True keeps the reduced axis as size 1, giving shape (1, 2)."""
    grid = torch.tensor([[1., 2.], [3., 4.]])
    wanted = torch.tensor([[4., 6.]])
    kept = TensorOps.sum(grid, dim=0, keepdim=True)
    self.assertTrue(torch.equal(kept, wanted))
|
|
def test_mean_operations(self):
    """TensorOps.mean matches torch.mean both globally and along dim 0."""
    grid = torch.tensor([[1., 2.], [3., 4.]])
    for kwargs in ({}, {"dim": 0}):
        self.assertTrue(
            torch.allclose(TensorOps.mean(grid, **kwargs), torch.mean(grid, **kwargs))
        )
|
|
def test_min_and_max(self):
    """TensorOps.min/max return (values, indices) identical to torch.min/max."""
    grid = torch.tensor([[1., 3.], [2., 0.]])
    checks = (
        (TensorOps.min, torch.min, 1),
        (TensorOps.max, torch.max, 0),
    )
    for ours, reference, dim in checks:
        got_values, got_indices = ours(grid, dim=dim)
        ref_values, ref_indices = reference(grid, dim=dim)
        self.assertTrue(torch.equal(got_values, ref_values))
        self.assertTrue(torch.equal(got_indices, ref_indices))
|
|
| |
def test_power_scalar_exponent(self):
    """Raising to a scalar exponent works for both float and integer tensors.

    The integer-tensor result is compared against a ``.float()``-converted
    expectation.
    """
    floats = torch.tensor([[1., 2.], [3., 4.]])
    squared = TensorOps.power(floats, 2.0)
    self.assertTrue(torch.equal(squared, torch.tensor([[1., 4.], [9., 16.]])))

    ints = torch.tensor([1, 2, 3])
    cubes_expected = torch.tensor([1, 8, 27]).float()
    cubed = TensorOps.power(ints, 3)
    self.assertTrue(torch.equal(cubed, cubes_expected))
|
|
def test_power_tensor_exponent(self):
    """A same-shaped exponent tensor is applied element-wise."""
    base = torch.tensor([[1., 2.], [3., 4.]])
    exponents = torch.tensor([[2., 3.], [1., 2.]])
    wanted = torch.tensor([[1., 8.], [3., 16.]])
    self.assertTrue(torch.equal(TensorOps.power(base, exponents), wanted))
|
|
def test_power_type_error(self):
    """A non-numeric exponent and a non-tensor base both raise TypeError."""
    base = torch.tensor([1., 2.])
    self.assertRaises(TypeError, TensorOps.power, base, "not_a_number_or_tensor")
    self.assertRaises(TypeError, TensorOps.power, "not_a_tensor", 2.0)
|
|
def test_power_runtime_error_shape_mismatch(self):
    """A (2,) exponent broadcasts against a (2, 2) base.

    Despite its name, this test expects no error. The result must equal
    ``torch.pow`` with standard broadcasting, so each row is raised
    element-wise to [2, 3].
    """
    base = torch.tensor([[1., 2.], [3., 4.]])
    row_exponent = torch.tensor([2., 3.])
    reference = torch.pow(base, row_exponent)
    self.assertTrue(torch.equal(TensorOps.power(base, row_exponent), reference))
|
|
def test_log_valid_inputs(self):
    """For strictly positive inputs TensorOps.log equals torch.log exactly."""
    samples = (
        torch.tensor([[1., 2.], [3., 4.]]),
        torch.tensor([10., 20., 30.]),
    )
    for sample in samples:
        reference = torch.log(sample)
        self.assertTrue(torch.equal(TensorOps.log(sample), reference))
|
|
def test_log_non_positive_inputs(self):
    """Zero and negative entries follow torch.log instead of raising.

    ``equal_nan=True`` lets the NaN produced by the negative entry compare
    equal to the reference.
    """
    for sample in (torch.tensor([1., 0., 3.]), torch.tensor([1., -2., 3.])):
        reference = torch.log(sample)
        outcome = TensorOps.log(sample)
        self.assertTrue(torch.allclose(outcome, reference, equal_nan=True))
|
|
def test_log_type_error(self):
    """A string argument to TensorOps.log raises TypeError."""
    self.assertRaises(TypeError, TensorOps.log, "not_a_tensor")
|
|
| |
|
|
def test_compute_gradient(self):
    """The gradient of x*x at x = 2 is 4."""
    point = torch.tensor(2.0, requires_grad=True)
    slope = TensorOps.compute_gradient(lambda v: v * v, point)
    self.assertTrue(torch.allclose(slope, torch.tensor(4.0)))
|
|
def test_compute_jacobian(self):
    """The Jacobian of (a + b, a * b) at (1, 2) is [[1, 1], [2, 1]]."""
    point = torch.tensor([1.0, 2.0])

    def sum_and_product(v):
        return torch.stack([v[0] + v[1], v[0] * v[1]])

    jacobian = TensorOps.compute_jacobian(sum_and_product, point)
    self.assertTrue(torch.allclose(jacobian, torch.tensor([[1., 1.], [2.0, 1.0]])))
|
|
def test_matrix_eigendecomposition(self):
    """A diagonal matrix yields its diagonal as eigenvalues.

    Eigenvalues are compared after sorting their real parts. Eigenvectors
    are compared by absolute value against the identity, which ignores
    sign or phase.
    """
    diag = torch.tensor([[2., 0.], [0., 3.]])
    eigvals, eigvecs = TensorOps.matrix_eigendecomposition(diag)
    sorted_real = torch.sort(eigvals.real).values
    self.assertTrue(torch.allclose(sorted_real, torch.tensor([2., 3.])))
    self.assertTrue(torch.allclose(torch.abs(eigvecs), torch.eye(2)))
|
|
def test_matrix_trace_and_tensor_trace(self):
    """matrix_trace sums the diagonal; tensor_trace rejects unequal axis lengths."""
    square = torch.tensor([[1., 2.], [3., 4.]])
    self.assertEqual(TensorOps.matrix_trace(square).item(), 5.0)

    # Axes 0 and 1 have lengths 2 and 3, so they share no full diagonal.
    cube = torch.arange(24.).reshape(2, 3, 4)
    self.assertRaises(ValueError, TensorOps.tensor_trace, cube, axis1=0, axis2=1)
|
|
def test_tensor_trace_valid(self):
    """Tracing a 3x3x3 tensor over axes (0, 1) matches diagonal().sum(-1)."""
    cube = torch.arange(27.).reshape(3, 3, 3).float()
    # diagonal() places the (0, 1) diagonal on the last axis; summing it
    # leaves one value per index of the remaining axis.
    reference = cube.diagonal(dim1=0, dim2=1).sum(-1)
    traced = TensorOps.tensor_trace(cube, axis1=0, axis2=1)
    self.assertTrue(torch.equal(traced, reference))
|
|
def test_svd_reconstruction(self):
    """U @ diag(S) @ Vh rebuilds the input matrix."""
    matrix = torch.tensor([[3., 1.], [1., 3.]], dtype=torch.float32)
    u, s, vh = TensorOps.svd(matrix)
    self.assertTrue(torch.allclose(u @ torch.diag(s) @ vh, matrix))
|
|
def test_qr_reconstruction(self):
    """Q @ R rebuilds a random 4x3 matrix within 1e-5 tolerance."""
    matrix = torch.randn(4, 3)
    q, r = TensorOps.qr_decomposition(matrix)
    self.assertTrue(torch.allclose(q @ r, matrix, atol=1e-5, rtol=1e-5))
|
|
def test_lu_decomposition(self):
    """The returned factors satisfy P @ A == L @ U."""
    matrix = torch.tensor([[4., 3.], [6., 3.]], dtype=torch.float32)
    perm, lower, upper = TensorOps.lu_decomposition(matrix)
    self.assertTrue(torch.allclose(perm @ matrix, lower @ upper))
|
|
def test_cholesky_valid(self):
    """For A = B @ B.T with a full-rank B, the Cholesky factor satisfies L @ L.T == A."""
    seed_factor = torch.tensor([[2., 0.], [1., 1.]], dtype=torch.float32)
    spd = seed_factor @ seed_factor.t()
    lower = TensorOps.cholesky_decomposition(spd)
    self.assertTrue(torch.allclose(lower @ lower.t(), spd))
|
|
def test_cholesky_non_symmetric_error(self):
    """A non-symmetric matrix is rejected by the Cholesky routine with ValueError."""
    asymmetric = torch.tensor([[1., 2.], [3., 4.]], dtype=torch.float32)
    self.assertRaises(ValueError, TensorOps.cholesky_decomposition, asymmetric)
|
|
def test_matrix_inverse(self):
    """The inverse keeps the input dtype and satisfies A @ inv(A) ≈ I."""
    matrix = torch.tensor([[4., 7.], [2., 6.]], dtype=torch.float32)
    inverse = TensorOps.matrix_inverse(matrix)
    identity = torch.eye(2, dtype=torch.float32)
    product = matrix @ inverse
    self.assertEqual(inverse.dtype, matrix.dtype)
    self.assertTrue(torch.allclose(product, identity))
|
|
def test_matrix_inverse_non_square_error(self):
    """Inverting a non-square 2x3 matrix raises ValueError."""
    rectangular = torch.randn(2, 3)
    self.assertRaises(ValueError, TensorOps.matrix_inverse, rectangular)
|
|
def test_matrix_determinant_and_rank(self):
    """A singular 2x2 matrix (second row = 2 * first row) has det 0 and rank 1."""
    A = torch.tensor([[1., 2.], [2., 4.]], dtype=torch.float32)
    det = TensorOps.matrix_determinant(A)
    rank = TensorOps.matrix_rank(A)
    # The determinant is a floating-point result, so compare it with a
    # tolerance. A backend could round to a tiny non-zero value instead of
    # exactly 0.0 for a singular matrix.
    self.assertAlmostEqual(det.item(), 0.0, places=6)
    self.assertEqual(rank.item(), 1)
|
|
def test_convolutions(self):
    """Check 1-D 'valid' convolution values and the 2-D 'same' output shape.

    NOTE(review): a 2x2 image convolved with a 2x2 kernel in 'same' mode is
    expected to produce a (3, 3) output here. That is larger than the input
    and looks like 'full'-mode sizing; confirm this is the intended contract
    of TensorOps.convolve_2d.
    """
    signal = torch.tensor([1., 2., 3.])
    pair_kernel = torch.tensor([1., 1.])
    valid_1d = TensorOps.convolve_1d(signal, pair_kernel, mode="valid")
    self.assertTrue(torch.allclose(valid_1d, torch.tensor([3., 5.])))

    image = torch.tensor([[1., 2.], [3., 4.]])
    diag_kernel = torch.tensor([[1., 0.], [0., 1.]])
    same_2d = TensorOps.convolve_2d(image, diag_kernel, mode="same")
    self.assertEqual(same_2d.shape, torch.Size([3, 3]))
|
|
def test_convolve_3d(self):
    """3-D 'valid' convolution matches conv3d with a flipped kernel; 'same' keeps the shape.

    ``torch.nn.functional.conv3d`` computes a cross-correlation, so the
    reference flips the kernel on all three spatial axes.
    """
    volume = torch.arange(27.).reshape(3, 3, 3)
    box = torch.ones((2, 2, 2))
    # conv3d expects (N, C, D, H, W), so add batch and channel axes of size 1.
    batched_volume = volume.unsqueeze(0).unsqueeze(0)
    flipped_kernel = box.flip(0, 1, 2).unsqueeze(0).unsqueeze(0)
    reference = torch.nn.functional.conv3d(batched_volume, flipped_kernel).squeeze(0).squeeze(0)
    valid_out = TensorOps.convolve_3d(volume, box, mode="valid")
    self.assertTrue(torch.allclose(valid_out, reference))

    same_out = TensorOps.convolve_3d(volume, torch.ones((3, 3, 3)), mode="same")
    self.assertEqual(same_out.shape, volume.shape)
|
|
def test_statistics(self):
    """Cross-check TensorOps statistics and norms against torch/NumPy references.

    Covers:
    - population variance;
    - covariance and correlation with rows as variables (``rowvar=True``);
    - Frobenius, L1, L2 and p-norms;
    - the nuclear norm, including its ValueError on 1-D input.
    """
    t = torch.tensor([[1., 2.], [3., 4.]])
    self.assertTrue(torch.allclose(TensorOps.variance(t), torch.var(t, unbiased=False)))

    # ``np`` is the module-level NumPy import.
    cov = TensorOps.covariance(t)
    expected_cov = torch.from_numpy(np.cov(t.numpy(), rowvar=True, bias=False)).float()
    self.assertTrue(torch.allclose(cov, expected_cov))
    corr = TensorOps.correlation(t)
    expected_corr = torch.from_numpy(np.corrcoef(t.numpy(), rowvar=True)).float()
    self.assertTrue(torch.allclose(corr, expected_corr))

    self.assertTrue(torch.allclose(TensorOps.frobenius_norm(t), torch.linalg.norm(t, "fro")))
    self.assertTrue(torch.allclose(TensorOps.l1_norm(t), torch.sum(torch.abs(t))))
    self.assertTrue(torch.allclose(TensorOps.l2_norm(t), torch.linalg.norm(t, 2)))
    self.assertTrue(torch.allclose(TensorOps.p_norm(t, 2), torch.linalg.norm(t, 2)))

    m = torch.tensor([[1., 2.], [3., 4.]])
    self.assertTrue(torch.allclose(TensorOps.nuclear_norm(m), torch.linalg.matrix_norm(m, ord="nuc")))
    with self.assertRaises(ValueError):
        TensorOps.nuclear_norm(torch.tensor([1., 2., 3.]))
|
|
def test_std_default(self):
    """By default TensorOps.std is the population (unbiased=False) std over all elements."""
    grid = torch.tensor([[1., 2.], [3., 4.]])
    reference = torch.std(grid, unbiased=False)
    self.assertTrue(torch.allclose(TensorOps.std(grid), reference))
|
|
def test_std_dim_unbiased_keepdim(self):
    """dim, unbiased and keepdim are forwarded with torch.std semantics."""
    grid = torch.tensor([[1., 2.], [3., 4.]])
    options = {"dim": 0, "unbiased": True, "keepdim": True}
    reference = torch.std(grid, **options)
    self.assertTrue(torch.allclose(TensorOps.std(grid, **options), reference))
|
|
def test_std_type_error(self):
    """A string argument to TensorOps.std raises TypeError."""
    self.assertRaises(TypeError, TensorOps.std, "not_a_tensor")
|
|
| |
def test_flatten_default(self):
    """With default dims, flatten collapses a 2x3 tensor like torch.flatten."""
    grid = torch.arange(6).reshape(2, 3)
    reference = torch.flatten(grid)
    self.assertTrue(torch.equal(TensorOps.flatten(grid), reference))
|
|
def test_flatten_start_end(self):
    """Flattening dims 1..2 of a (2, 3, 4) tensor yields shape (2, 12)."""
    block = torch.arange(24).reshape(2, 3, 4)
    reference = torch.flatten(block, start_dim=1, end_dim=2)
    merged = TensorOps.flatten(block, start_dim=1, end_dim=2)
    self.assertTrue(torch.equal(merged, reference))
    self.assertEqual(merged.shape, (2, 12))
|
|
def test_squeeze_default(self):
    """With no dim, squeeze removes every size-1 axis: (1, 3, 1, 4) -> (3, 4)."""
    padded = torch.zeros(1, 3, 1, 4)
    reference = torch.squeeze(padded)
    squeezed = TensorOps.squeeze(padded)
    self.assertTrue(torch.equal(squeezed, reference))
    self.assertEqual(squeezed.shape, (3, 4))
|
|
def test_squeeze_dim(self):
    """squeeze(dim=2) removes only that size-1 axis: (1, 3, 1, 4) -> (1, 3, 4)."""
    padded = torch.zeros(1, 3, 1, 4)
    reference = torch.squeeze(padded, dim=2)
    squeezed = TensorOps.squeeze(padded, dim=2)
    self.assertTrue(torch.equal(squeezed, reference))
    self.assertEqual(squeezed.shape, (1, 3, 4))
|
|
def test_unsqueeze(self):
    """unsqueeze(dim=0) prepends a size-1 axis: (3, 4) -> (1, 3, 4)."""
    grid = torch.randn(3, 4)
    reference = torch.unsqueeze(grid, dim=0)
    lifted = TensorOps.unsqueeze(grid, dim=0)
    self.assertTrue(torch.equal(lifted, reference))
    self.assertEqual(lifted.shape, (1, 3, 4))
|
|
def test_reshape_and_transpose(self):
    """Reshape accepts (2, 3) and rejects (4, 2) for 6 elements; transpose swaps axes."""
    flat = torch.arange(6)
    matrix = TensorOps.reshape(flat, (2, 3))
    self.assertTrue(torch.equal(matrix, flat.reshape(2, 3)))
    # (4, 2) needs 8 elements, but only 6 are available.
    self.assertRaises(ValueError, TensorOps.reshape, flat, (4, 2))

    swapped = TensorOps.transpose(matrix, 0, 1)
    self.assertTrue(torch.equal(swapped, matrix.t()))
|
|
def test_permute(self):
    """permute reorders axes like torch; a too-short order (0, 1) for 3-D raises ValueError."""
    block = torch.arange(24).reshape(2, 3, 4)
    reordered = TensorOps.permute(block, (1, 0, 2))
    self.assertTrue(torch.equal(reordered, block.permute(1, 0, 2)))
    self.assertRaises(ValueError, TensorOps.permute, block, (0, 1))
|
|
def test_concatenate_and_stack(self):
    """concatenate joins along an existing axis; stack creates a new one."""
    parts = [torch.ones(2, 2), torch.zeros(2, 2)]

    joined_ref = torch.cat(parts, dim=0)
    self.assertTrue(torch.equal(TensorOps.concatenate(parts, dim=0), joined_ref))

    stacked_ref = torch.stack(parts, dim=0)
    self.assertTrue(torch.equal(TensorOps.stack(parts, dim=0), stacked_ref))
|
|
def test_einsum(self):
    """TensorOps.einsum forwards the equation and operands like torch.einsum."""
    left = torch.tensor([[1., 2.], [3., 4.]])
    right = torch.tensor([[5., 6.], [7., 8.]])
    equation = 'ij,jk->ik'
    reference = torch.einsum(equation, left, right)
    self.assertTrue(torch.equal(TensorOps.einsum(equation, left, right), reference))
|
|
| |
|
|
def test_cp_decomposition_valid_low_rank(self):
    """Test CP decomposition with a known low-rank tensor.

    The target is built with tensorly from random rank-2 weights and
    factors. The decomposition is reconstructed with tensorly and must match
    the target to a relative Frobenius error within 3e-2.
    """
    dims = (3, 4, 5)
    rank = 2

    # Draw weights first, then one factor matrix per mode (global NumPy RNG).
    gt_weights = np.random.rand(rank).astype(np.float32)
    gt_factors = [np.random.rand(d, rank).astype(np.float32) for d in dims]
    target = torch.from_numpy(tl.cp_to_tensor((gt_weights, gt_factors))).float()

    weights, factors = TensorDecompositionOps.cp_decomposition(target, rank)

    self.assertIsInstance(weights, torch.Tensor)
    self.assertIsInstance(factors, list)
    self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))

    # weights has shape (rank,); factor k has shape (target.shape[k], rank).
    self.assertEqual(weights.ndim, 1)
    self.assertEqual(weights.size(0), rank)
    self.assertEqual(len(factors), target.ndim)
    for mode, factor in enumerate(factors):
        self.assertEqual(factor.shape, (target.shape[mode], rank))

    rebuilt = torch.from_numpy(tl.cp_to_tensor((
        weights.detach().cpu().numpy(),
        [f.detach().cpu().numpy() for f in factors],
    ))).float()
    rel_err = torch.norm(target - rebuilt) / torch.norm(target)
    self.assertAlmostEqual(rel_err.item(), 0.0, delta=3e-2)
|
|
def test_cp_decomposition_random_tensor(self):
    """Test CP decomposition with a random tensor.

    Unstructured random data need not be exactly rank 3, so only a loose
    relative-error bound (< 0.8) is asserted on the reconstruction.
    """
    data = torch.rand(3, 4, 5, dtype=torch.float32)
    rank = 3

    weights, factors = TensorDecompositionOps.cp_decomposition(data, rank)

    self.assertIsInstance(weights, torch.Tensor)
    self.assertIsInstance(factors, list)
    self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))

    self.assertEqual(weights.ndim, 1)
    self.assertEqual(weights.size(0), rank)
    self.assertEqual(len(factors), data.ndim)
    for mode, factor in enumerate(factors):
        self.assertEqual(factor.shape, (data.shape[mode], rank))

    rebuilt = torch.from_numpy(tl.cp_to_tensor((
        weights.detach().cpu().numpy(),
        [f.detach().cpu().numpy() for f in factors],
    ))).float()
    rel_err = torch.norm(data - rebuilt) / torch.norm(data)
    self.assertLess(rel_err.item(), 0.8)
|
|
def test_cp_decomposition_matrix(self):
    """Test CP decomposition on a 2D tensor (matrix).

    Expects one (rows, rank) and one (cols, rank) factor, and a
    reconstruction relative error below 0.8.
    """
    mat = torch.rand(6, 7, dtype=torch.float32)
    rank = 2

    weights, factors = TensorDecompositionOps.cp_decomposition(mat, rank)

    self.assertIsInstance(weights, torch.Tensor)
    self.assertIsInstance(factors, list)
    self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))

    self.assertEqual(weights.ndim, 1)
    self.assertEqual(weights.size(0), rank)
    self.assertEqual(len(factors), mat.ndim)
    row_factor, col_factor = factors[0], factors[1]
    self.assertEqual(row_factor.shape, (mat.shape[0], rank))
    self.assertEqual(col_factor.shape, (mat.shape[1], rank))

    rebuilt = torch.from_numpy(tl.cp_to_tensor((
        weights.detach().cpu().numpy(),
        [f.detach().cpu().numpy() for f in factors],
    ))).float()
    rel_err = torch.norm(mat - rebuilt) / torch.norm(mat)
    self.assertLess(rel_err.item(), 0.8)
|
|
def test_cp_decomposition_invalid_rank(self):
    """Test CP decomposition with invalid ranks: zero, negative and non-integer."""
    data = torch.rand(2, 2, 2, dtype=torch.float32)
    for bad_rank in (0, -1, 1.5):
        with self.assertRaisesRegex(ValueError, "Rank must be a positive integer"):
            TensorDecompositionOps.cp_decomposition(data, bad_rank)
|
|
def test_cp_decomposition_invalid_tensor_ndim(self):
    """Test CP decomposition with tensor of invalid number of dimensions (1-D input)."""
    vector = torch.rand(5, dtype=torch.float32)
    self.assertRaisesRegex(
        ValueError,
        "CP decomposition requires a tensor with at least 2 dimensions",
        TensorDecompositionOps.cp_decomposition,
        vector,
        2,
    )
|
|
def test_cp_decomposition_type_error(self):
    """Test CP decomposition with non-tensor input (a string and a Python list)."""
    for not_a_tensor in ("not a tensor", [1, 2, 3]):
        with self.assertRaisesRegex(TypeError, "Input at index 0 is not a torch.Tensor"):
            TensorDecompositionOps.cp_decomposition(not_a_tensor, 2)
|
|
| |
|
|
def test_tucker_decomposition_valid_low_rank(self):
    """Test Tucker decomposition with a known low-rank tensor.

    The target comes from a random (2, 3, 3) core and random factors, so a
    decomposition at those ranks must reconstruct it to 1e-5 relative error.
    """
    dims = (4, 5, 6)
    ranks = [2, 3, 3]

    # Draw the core first, then one factor per mode (global NumPy RNG).
    gt_core = np.random.rand(*ranks).astype(np.float32)
    gt_factors = [np.random.rand(d, r).astype(np.float32) for d, r in zip(dims, ranks)]
    target = torch.from_numpy(tl.tucker_to_tensor((gt_core, gt_factors))).float()

    core, factors = TensorDecompositionOps.tucker_decomposition(target, ranks)

    self.assertIsInstance(core, torch.Tensor)
    self.assertIsInstance(factors, list)
    self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))

    # Core shape equals the requested ranks; factor k is (dims[k], ranks[k]).
    self.assertEqual(core.shape, tuple(ranks))
    self.assertEqual(len(factors), target.ndim)
    for mode, factor in enumerate(factors):
        self.assertEqual(factor.shape, (target.shape[mode], ranks[mode]))

    rebuilt = torch.from_numpy(tl.tucker_to_tensor((
        core.detach().cpu().numpy(),
        [f.detach().cpu().numpy() for f in factors],
    ))).float()
    rel_err = torch.norm(target - rebuilt) / torch.norm(target)
    self.assertAlmostEqual(rel_err.item(), 0.0, delta=1e-5)
|
|
def test_tucker_decomposition_random_tensor(self):
    """Test Tucker decomposition with a random tensor.

    Checks the output types and shapes, and requires a truncated
    reconstruction relative error below 0.7.
    """
    data = torch.rand(4, 5, 6, dtype=torch.float32)
    ranks = [2, 3, 3]

    core, factors = TensorDecompositionOps.tucker_decomposition(data, ranks)

    self.assertIsInstance(core, torch.Tensor)
    self.assertIsInstance(factors, list)
    self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))

    self.assertEqual(core.shape, tuple(ranks))
    self.assertEqual(len(factors), data.ndim)
    for mode, factor in enumerate(factors):
        self.assertEqual(factor.shape, (data.shape[mode], ranks[mode]))

    rebuilt = torch.from_numpy(tl.tucker_to_tensor((
        core.detach().cpu().numpy(),
        [f.detach().cpu().numpy() for f in factors],
    ))).float()
    rel_err = torch.norm(data - rebuilt) / torch.norm(data)
    self.assertLess(rel_err.item(), 0.7)
|
|
def test_tucker_decomposition_matrix(self):
    """Test Tucker decomposition on a 2D tensor (matrix) using known low-rank data."""
    dims = (5, 6)
    ranks = [2, 3]

    # Draw the core first, then one factor per mode (global NumPy RNG).
    gt_core = np.random.rand(*ranks).astype(np.float32)
    gt_factors = [np.random.rand(d, r).astype(np.float32) for d, r in zip(dims, ranks)]
    target = torch.from_numpy(tl.tucker_to_tensor((gt_core, gt_factors))).float()

    core, factors = TensorDecompositionOps.tucker_decomposition(target, ranks)

    self.assertIsInstance(core, torch.Tensor)
    self.assertIsInstance(factors, list)
    self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))

    self.assertEqual(core.shape, tuple(ranks))
    self.assertEqual(len(factors), target.ndim)
    for mode, factor in enumerate(factors):
        self.assertEqual(factor.shape, (target.shape[mode], ranks[mode]))

    rebuilt = torch.from_numpy(tl.tucker_to_tensor((
        core.detach().cpu().numpy(),
        [f.detach().cpu().numpy() for f in factors],
    ))).float()
    rel_err = torch.norm(target - rebuilt) / torch.norm(target)
    self.assertAlmostEqual(rel_err.item(), 0.0, delta=1e-5)
|
|
def test_tucker_decomposition_invalid_ranks_list_length(self):
    """Test Tucker decomposition with incorrect length of ranks list (2 ranks for 3-D input)."""
    data = torch.rand(3, 4, 5, dtype=torch.float32)
    self.assertRaisesRegex(
        ValueError,
        "Length of ranks list .* must match tensor dimensionality",
        TensorDecompositionOps.tucker_decomposition,
        data,
        [2, 2],
    )
|
|
def test_tucker_decomposition_invalid_rank_value_type(self):
    """Test Tucker decomposition with non-integer rank in list (2.5 in the middle)."""
    data = torch.rand(3, 4, 5, dtype=torch.float32)
    self.assertRaisesRegex(
        ValueError,
        "Ranks must be a list of positive integers",
        TensorDecompositionOps.tucker_decomposition,
        data,
        [2, 2.5, 2],
    )
|
|
|
|
def test_tucker_decomposition_invalid_rank_value_zero(self):
    """Test Tucker decomposition with a zero rank for mode 1."""
    data = torch.rand(3, 4, 5, dtype=torch.float32)
    self.assertRaisesRegex(
        ValueError,
        "Ranks must be a list of positive integers",
        TensorDecompositionOps.tucker_decomposition,
        data,
        [2, 0, 2],
    )
|
|
def test_tucker_decomposition_invalid_rank_value_exceeds_dim(self):
    """Test Tucker decomposition with a rank value exceeding tensor dimension (5 > 4 on mode 1)."""
    data = torch.rand(3, 4, 5, dtype=torch.float32)
    self.assertRaisesRegex(
        ValueError,
        "Rank for mode 1 .* is out of valid range",
        TensorDecompositionOps.tucker_decomposition,
        data,
        [2, 5, 2],
    )
|
|
def test_tucker_decomposition_type_error(self):
    """Test Tucker decomposition with non-tensor input (a string and a Python list)."""
    bad_calls = (
        ("not a tensor", [2, 2]),
        ([1, 2, 3], [1]),
    )
    for not_a_tensor, ranks in bad_calls:
        with self.assertRaisesRegex(TypeError, "Input at index 0 is not a torch.Tensor"):
            TensorDecompositionOps.tucker_decomposition(not_a_tensor, ranks)
|
|
| |
|
|
def test_hosvd_valid_3d(self):
    """Test HOSVD on a 3D tensor.

    The core keeps the input shape. Each factor must be square with
    orthonormal columns (U^T U = I), and reconstruction must be exact to
    1e-5 relative error.
    """
    data = torch.rand(3, 4, 2, dtype=torch.float32)

    core, factors = TensorDecompositionOps.hosvd(data)

    self.assertIsInstance(core, torch.Tensor)
    self.assertIsInstance(factors, list)
    self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))

    self.assertEqual(core.shape, data.shape)
    self.assertEqual(len(factors), data.ndim)
    for mode, factor in enumerate(factors):
        size = data.shape[mode]
        self.assertEqual(factor.shape, (size, size))
        identity = torch.eye(factor.shape[1], dtype=factor.dtype, device=factor.device)
        self.assertTrue(torch.allclose(factor.T @ factor, identity, atol=1e-5))

    rebuilt = torch.from_numpy(tl.tucker_to_tensor((
        core.detach().cpu().numpy(),
        [f.detach().cpu().numpy() for f in factors],
    ))).float()
    rel_err = torch.norm(data - rebuilt) / torch.norm(data)
    self.assertAlmostEqual(rel_err.item(), 0.0, delta=1e-5)
|
|
def test_hosvd_valid_matrix(self):
    """Test HOSVD on a 2D tensor (matrix).

    Same contract as the 3-D case: full-size core, square orthonormal
    factors, and exact reconstruction to 1e-5 relative error.
    """
    data = torch.rand(5, 3, dtype=torch.float32)

    core, factors = TensorDecompositionOps.hosvd(data)

    self.assertIsInstance(core, torch.Tensor)
    self.assertIsInstance(factors, list)
    self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))

    self.assertEqual(core.shape, data.shape)
    self.assertEqual(len(factors), data.ndim)
    for mode, factor in enumerate(factors):
        size = data.shape[mode]
        self.assertEqual(factor.shape, (size, size))
        identity = torch.eye(factor.shape[1], dtype=factor.dtype, device=factor.device)
        self.assertTrue(torch.allclose(factor.T @ factor, identity, atol=1e-5))

    rebuilt = torch.from_numpy(tl.tucker_to_tensor((
        core.detach().cpu().numpy(),
        [f.detach().cpu().numpy() for f in factors],
    ))).float()
    rel_err = torch.norm(data - rebuilt) / torch.norm(data)
    self.assertAlmostEqual(rel_err.item(), 0.0, delta=1e-5)
|
|
def test_hosvd_type_error(self):
    """Test HOSVD with non-tensor input (a string)."""
    self.assertRaisesRegex(
        TypeError,
        "Input at index 0 is not a torch.Tensor",
        TensorDecompositionOps.hosvd,
        "not a tensor",
    )
|
|
def test_hosvd_input_tensor_constraints(self):
    """Test HOSVD with 0-dim (scalar) and 1-dim (vector) tensors; both are rejected."""
    too_small = (
        torch.tensor(5.0).float(),
        torch.rand(7, dtype=torch.float32),
    )
    for low_dim in too_small:
        with self.assertRaisesRegex(ValueError, "HOSVD requires a tensor with at least 2 dimensions"):
            TensorDecompositionOps.hosvd(low_dim)
|
|
| |
|
|
def test_tt_decomposition_valid_3d_list_rank(self):
    """TT decomposition with explicit internal ranks recovers an exactly low-TT-rank 3D tensor."""
    shape = (3, 4, 5)
    internal_ranks = [2, 3]
    # The boundary TT ranks are always 1.
    bond_dims = [1, *internal_ranks, 1]

    # Build a tensor with known TT ranks from random cores of shape (r_k, n_k, r_{k+1}).
    true_cores = [
        np.random.rand(r_left, n, r_right).astype(np.float32)
        for r_left, n, r_right in zip(bond_dims[:-1], shape, bond_dims[1:])
    ]
    target = torch.from_numpy(tl.tt_to_tensor(true_cores)).float()

    cores = TensorDecompositionOps.tt_decomposition(target, rank=internal_ranks)

    self.assertIsInstance(cores, list)
    self.assertEqual(len(cores), target.ndim)
    self.assertTrue(all(isinstance(c, torch.Tensor) for c in cores))

    for k, core in enumerate(cores):
        self.assertEqual(core.shape, (bond_dims[k], shape[k], bond_dims[k + 1]))

    rebuilt = torch.from_numpy(
        tl.tt_to_tensor([c.detach().cpu().numpy() for c in cores])
    ).float()
    rel_err = torch.norm(target - rebuilt) / torch.norm(target)
    self.assertAlmostEqual(rel_err.item(), 0.0, delta=1e-4)
|
|
def test_tt_decomposition_valid_3d_int_rank(self):
    """TT decomposition with an integer max rank yields bounded-rank cores and a loose fit."""
    data = torch.rand(3, 4, 2, dtype=torch.float32)
    max_rank = 2

    cores = TensorDecompositionOps.tt_decomposition(data, rank=max_rank)

    self.assertIsInstance(cores, list)
    self.assertEqual(len(cores), data.ndim)
    self.assertTrue(all(isinstance(c, torch.Tensor) for c in cores))

    # Boundary ranks are 1; inner bond dimensions must not exceed max_rank.
    self.assertEqual(cores[0].shape[0], 1)
    self.assertEqual(cores[-1].shape[2], 1)
    last = len(cores) - 1
    for k, core in enumerate(cores):
        self.assertEqual(core.shape[1], data.shape[k])
        if k < last:
            self.assertLessEqual(core.shape[2], max_rank)
        if k > 0:
            self.assertLessEqual(core.shape[0], max_rank)

    rebuilt = torch.from_numpy(
        tl.tt_to_tensor([c.detach().cpu().numpy() for c in cores])
    ).float()
    rel_err = torch.norm(data - rebuilt) / torch.norm(data)
    # A rank-truncated TT of random data is lossy, so only a loose fit is required.
    self.assertLess(rel_err.item(), 0.8)
|
|
def test_tt_decomposition_valid_matrix_list_rank(self):
    """TT decomposition of an exactly rank-3 matrix recovers it with the requested core shapes."""
    shape = (5, 6)
    internal_ranks = [3]
    bond_dims = [1, *internal_ranks, 1]

    true_cores = [
        np.random.rand(r_left, n, r_right).astype(np.float32)
        for r_left, n, r_right in zip(bond_dims[:-1], shape, bond_dims[1:])
    ]
    target = torch.from_numpy(tl.tt_to_tensor(true_cores)).float()

    cores = TensorDecompositionOps.tt_decomposition(target, rank=internal_ranks)

    self.assertIsInstance(cores, list)
    self.assertEqual(len(cores), target.ndim)
    for k, core in enumerate(cores):
        self.assertEqual(core.shape, (bond_dims[k], shape[k], bond_dims[k + 1]))

    rebuilt = torch.from_numpy(
        tl.tt_to_tensor([c.detach().cpu().numpy() for c in cores])
    ).float()
    rel_err = torch.norm(target - rebuilt) / torch.norm(target)
    self.assertAlmostEqual(rel_err.item(), 0.0, delta=1e-4)
|
|
def test_tt_decomposition_1d_tensor_runtime_error(self):
    """TT decomposition of a 1D tensor reports the backend failure as RuntimeError."""
    vector = torch.rand(10).float()
    # Both an integer rank and an empty rank list are expected to fail the same way.
    for rank in (1, []):
        with self.assertRaisesRegex(RuntimeError, "TT decomposition failed"):
            TensorDecompositionOps.tt_decomposition(vector, rank=rank)
|
|
|
|
def test_tt_decomposition_invalid_rank_type(self):
    """A rank that is neither an int nor a list of ints raises TypeError."""
    data = torch.rand(3, 4, 5).float()
    bad_rank = "invalid_rank_type"
    with self.assertRaisesRegex(TypeError, "Rank must be an int or a list of ints"):
        TensorDecompositionOps.tt_decomposition(data, rank=bad_rank)
|
|
def test_tt_decomposition_invalid_rank_list_length(self):
    """Rank lists of the wrong length are rejected for both 3D and 1D tensors."""
    # N-D case: a 3D tensor takes exactly ndim - 1 = 2 internal ranks, not 3.
    cube = torch.rand(3, 4, 5).float()
    with self.assertRaisesRegex(ValueError, "Rank list length must be tensor.ndim - 1"):
        TensorDecompositionOps.tt_decomposition(cube, rank=[2, 3, 4])

    # 1D case: a user-supplied rank list must be empty.
    line = torch.rand(5).float()
    with self.assertRaisesRegex(ValueError, "For a 1D tensor, rank list must be empty for user input"):
        TensorDecompositionOps.tt_decomposition(line, rank=[1])
|
|
|
|
def test_tt_decomposition_invalid_rank_list_values(self):
    """A zero entry in the TT rank list raises ValueError."""
    data = torch.rand(3, 4, 5).float()
    ranks_with_zero = [2, 0]
    with self.assertRaisesRegex(ValueError, "All ranks in the list must be positive integers"):
        TensorDecompositionOps.tt_decomposition(data, rank=ranks_with_zero)
|
|
def test_tt_decomposition_invalid_rank_int_value(self):
    """A non-positive integer TT rank raises ValueError."""
    data = torch.rand(3, 4, 5).float()
    with self.assertRaisesRegex(ValueError, "If rank is an integer, it must be positive"):
        TensorDecompositionOps.tt_decomposition(data, rank=0)
|
|
def test_tt_decomposition_invalid_tensor_ndim0(self):
    """TT decomposition refuses a 0-dimensional (scalar) tensor."""
    scalar = torch.tensor(1.0).float()
    expected = "TT decomposition requires a tensor with at least 1 dimension"
    with self.assertRaisesRegex(ValueError, expected):
        TensorDecompositionOps.tt_decomposition(scalar, rank=1)
|
|
def test_tt_decomposition_type_error_tensor(self):
    """TT decomposition rejects input that is not a torch.Tensor."""
    bogus_input = "not a tensor"
    with self.assertRaisesRegex(TypeError, "Input at index 0 is not a torch.Tensor"):
        TensorDecompositionOps.tt_decomposition(bogus_input, rank=1)
|
|
| |
|
|
def test_tr_decomposition_valid_3d_list_rank(self):
    """TR decomposition with explicit ranks recovers an exactly low-TR-rank 3D tensor."""
    shape = (3, 4, 5)
    ranks_tr = [1, 2, 2]
    # Core k has shape (r_k, n_k, r_{k+1}); the last core wraps back to r_0 (ring closure).
    next_ranks = ranks_tr[1:] + ranks_tr[:1]

    true_cores = [
        np.random.rand(r_in, n, r_out).astype(np.float32)
        for r_in, n, r_out in zip(ranks_tr, shape, next_ranks)
    ]
    target = torch.from_numpy(tl.tr_to_tensor(true_cores)).float()

    cores = TensorDecompositionOps.tr_decomposition(target, rank=ranks_tr)

    self.assertIsInstance(cores, list)
    self.assertEqual(len(cores), target.ndim)
    self.assertTrue(all(isinstance(c, torch.Tensor) for c in cores))

    for k, core in enumerate(cores):
        self.assertEqual(core.shape, (ranks_tr[k], shape[k], next_ranks[k]))

    rebuilt = torch.from_numpy(
        tl.tr_to_tensor([c.detach().cpu().numpy() for c in cores])
    ).float()
    rel_err = torch.norm(target - rebuilt) / torch.norm(target)
    self.assertAlmostEqual(rel_err.item(), 0.0, delta=1e-4)
|
|
def test_tr_decomposition_valid_3d_int_rank(self):
    """TR decomposition with an integer max rank yields bounded, ring-closed cores."""
    data = torch.rand(3, 4, 2, dtype=torch.float32)
    max_rank = 1

    cores = TensorDecompositionOps.tr_decomposition(data, rank=max_rank)

    self.assertIsInstance(cores, list)
    self.assertEqual(len(cores), data.ndim)
    self.assertTrue(all(isinstance(c, torch.Tensor) for c in cores))

    for k in range(data.ndim):
        left_rank, mode_size, right_rank = cores[k].shape[0], cores[k].shape[1], cores[k].shape[2]
        self.assertEqual(mode_size, data.shape[k])
        self.assertLessEqual(left_rank, max_rank)
        self.assertLessEqual(right_rank, max_rank)
    # The ring closes: the last core's right rank equals the first core's left rank.
    self.assertEqual(cores[-1].shape[2], cores[0].shape[0])

    rebuilt = torch.from_numpy(
        tl.tr_to_tensor([c.detach().cpu().numpy() for c in cores])
    ).float()
    rel_err = torch.norm(data - rebuilt) / torch.norm(data)
    self.assertLess(rel_err.item(), 0.8)
|
|
def test_tr_decomposition_valid_matrix_list_rank(self):
    """TR decomposition of a low-TR-rank matrix recovers it with the requested core shapes."""
    shape = (5, 6)
    ranks_tr = [1, 2]
    # Ring closure: the second core's right rank wraps back to ranks_tr[0].
    next_ranks = ranks_tr[1:] + ranks_tr[:1]

    true_cores = [
        np.random.rand(r_in, n, r_out).astype(np.float32)
        for r_in, n, r_out in zip(ranks_tr, shape, next_ranks)
    ]
    target = torch.from_numpy(tl.tr_to_tensor(true_cores)).float()

    cores = TensorDecompositionOps.tr_decomposition(target, rank=ranks_tr)

    self.assertIsInstance(cores, list)
    self.assertEqual(len(cores), target.ndim)
    self.assertEqual(cores[0].shape, (ranks_tr[0], shape[0], next_ranks[0]))
    self.assertEqual(cores[1].shape, (ranks_tr[1], shape[1], next_ranks[1]))

    rebuilt = torch.from_numpy(
        tl.tr_to_tensor([c.detach().cpu().numpy() for c in cores])
    ).float()
    rel_err = torch.norm(target - rebuilt) / torch.norm(target)
    self.assertAlmostEqual(rel_err.item(), 0.0, delta=1e-4)
|
|
def test_tr_decomposition_invalid_rank_type(self):
    """A rank that is neither an int nor a list of ints raises TypeError."""
    cube = torch.rand(3, 4, 5).float()
    bad_rank = "invalid_type"
    with self.assertRaisesRegex(TypeError, "Rank must be an int or a list of ints"):
        TensorDecompositionOps.tr_decomposition(cube, rank=bad_rank)
|
|
def test_tr_decomposition_invalid_rank_list_length(self):
    """A TR rank list whose length differs from tensor.ndim raises ValueError."""
    cube = torch.rand(3, 4, 5).float()
    two_ranks_for_three_modes = [2, 3]
    with self.assertRaisesRegex(ValueError, "If rank is a list, its length must be equal to tensor.ndim"):
        TensorDecompositionOps.tr_decomposition(cube, rank=two_ranks_for_three_modes)
|
|
def test_tr_decomposition_invalid_rank_list_values(self):
    """A zero entry in the TR rank list raises ValueError."""
    cube = torch.rand(3, 4, 5).float()
    with self.assertRaisesRegex(ValueError, "All ranks in the list must be positive integers"):
        TensorDecompositionOps.tr_decomposition(cube, rank=[2, 0, 2])
|
|
def test_tr_decomposition_invalid_rank_int_value(self):
    """A non-positive integer TR rank raises ValueError."""
    cube = torch.rand(3, 4, 5).float()
    with self.assertRaisesRegex(ValueError, "If rank is an integer, it must be positive"):
        TensorDecompositionOps.tr_decomposition(cube, rank=0)
|
|
def test_tr_decomposition_invalid_tensor_ndim0(self):
    """TR decomposition refuses a 0-dimensional (scalar) tensor."""
    scalar = torch.tensor(1.0).float()
    expected = "TR decomposition requires a tensor with at least 1 dimension, but got 0."
    with self.assertRaisesRegex(ValueError, expected):
        TensorDecompositionOps.tr_decomposition(scalar, rank=1)
|
|
def test_tr_decomposition_type_error_tensor(self):
    """TR decomposition rejects input that is not a torch.Tensor."""
    bogus_input = "not a tensor"
    with self.assertRaisesRegex(TypeError, "Input at index 0 is not a torch.Tensor"):
        TensorDecompositionOps.tr_decomposition(bogus_input, rank=1)
|
|
| |
|
|
@unittest.skipIf(not HTENSOR_AVAILABLE, "htensor library not available")
def test_ht_decomposition_valid_4d(self):
    """HT decomposition of a random 4D tensor returns an HTensor with a loose fit."""
    shape = (2, 3, 2, 2)
    data = torch.rand(shape).float()

    tree = htensor.DimensionTree(len(shape))
    # Rank 2 at every node id from 1 to max_node_id.
    node_ranks = dict.fromkeys(range(1, tree.max_node_id + 1), 2)

    ht = TensorDecompositionOps.ht_decomposition(data, tree, node_ranks)
    self.assertIsInstance(ht, htensor.HTensor)

    rebuilt = torch.from_numpy(ht.to_tensor()).type(data.dtype)
    rel_err = torch.norm(data - rebuilt) / torch.norm(data)
    self.assertLess(rel_err.item(), 0.8)
|
|
@unittest.skipIf(not HTENSOR_AVAILABLE, "htensor library not available")
def test_ht_decomposition_invalid_dim_tree_mismatch(self):
    """A dimension tree built for 3 modes is rejected for a 4D tensor."""
    four_way = torch.rand(2, 2, 2, 2).float()
    three_mode_tree = htensor.DimensionTree(3)
    node_ranks = dict.fromkeys(range(1, three_mode_tree.max_node_id + 1), 2)
    pattern = "Dimension tree number of dimensions .* must match tensor dimensionality"
    with self.assertRaisesRegex(ValueError, pattern):
        TensorDecompositionOps.ht_decomposition(four_way, three_mode_tree, node_ranks)
|
|
@unittest.skipIf(not HTENSOR_AVAILABLE, "htensor library not available")
def test_ht_decomposition_invalid_dim_tree_type(self):
    """A dim_tree that is not an htensor.DimensionTree raises TypeError."""
    matrix = torch.rand(2, 2).float()
    fake_tree = "not_a_dim_tree"
    node_ranks = {1: 1, 2: 1, 3: 1}
    with self.assertRaisesRegex(TypeError, "dim_tree must be an htensor.DimensionTree"):
        TensorDecompositionOps.ht_decomposition(matrix, fake_tree, node_ranks)
|
|
@unittest.skipIf(not HTENSOR_AVAILABLE, "htensor library not available")
def test_ht_decomposition_invalid_ranks_type(self):
    """HT ranks that are not a dict raise TypeError."""
    matrix = torch.rand(2, 2).float()
    tree = htensor.DimensionTree(2)
    with self.assertRaisesRegex(TypeError, "ranks must be a dict"):
        TensorDecompositionOps.ht_decomposition(matrix, tree, "not_a_dict")
|
|
@unittest.skipIf(not HTENSOR_AVAILABLE, "htensor library not available")
def test_ht_decomposition_invalid_ranks_content_type(self):
    """A non-integer value inside the HT ranks dict raises ValueError."""
    matrix = torch.rand(2, 2).float()
    tree = htensor.DimensionTree(2)
    node_ranks = {1: 2, 2: "not_an_int", 3: 2}
    message = "ranks dictionary must have integer keys and positive integer values"
    with self.assertRaisesRegex(ValueError, message):
        TensorDecompositionOps.ht_decomposition(matrix, tree, node_ranks)
|
|
@unittest.skipIf(not HTENSOR_AVAILABLE, "htensor library not available")
def test_ht_decomposition_invalid_ranks_content_value(self):
    """A zero rank inside the HT ranks dict raises ValueError."""
    matrix = torch.rand(2, 2).float()
    tree = htensor.DimensionTree(2)
    node_ranks = {1: 2, 2: 0, 3: 2}
    message = "ranks dictionary must have integer keys and positive integer values"
    with self.assertRaisesRegex(ValueError, message):
        TensorDecompositionOps.ht_decomposition(matrix, tree, node_ranks)
|
|
@unittest.skipIf(not HTENSOR_AVAILABLE, "htensor library not available")
def test_ht_decomposition_invalid_tensor_ndim0(self):
    """HT decomposition refuses a 0-dimensional tensor."""
    scalar = torch.tensor(1.0).float()
    tree = htensor.DimensionTree(1)
    single_node_ranks = {1: 1}
    with self.assertRaisesRegex(ValueError, "HT decomposition requires a tensor with at least 1 dimension"):
        TensorDecompositionOps.ht_decomposition(scalar, tree, single_node_ranks)
|
|
@unittest.skipIf(not HTENSOR_AVAILABLE, "htensor library not available")
def test_ht_decomposition_type_error_tensor(self):
    """HT decomposition rejects input that is not a torch.Tensor."""
    tree = htensor.DimensionTree(2)
    node_ranks = {1: 1, 2: 1, 3: 1}
    bogus_input = "not a tensor"
    with self.assertRaisesRegex(TypeError, "Input at index 0 is not a torch.Tensor"):
        TensorDecompositionOps.ht_decomposition(bogus_input, tree, node_ranks)
|
|
| |
|
|
def test_btd_decomposition_valid_structure(self):
    """BTD returns one (core, factors) Tucker term per requested rank tuple.

    Checks container types, that each core's shape equals its requested
    ranks, that factor n has shape (I_n, R_n) for every mode, and that the
    sum of the reconstructed terms gives a loose fit to the input.
    """
    sample_tensor = torch.rand(6, 7, 8).float()
    ranks_per_term = [(2, 2, 2), (1, 3, 2)]

    terms = TensorDecompositionOps.btd_decomposition(sample_tensor, ranks_per_term)

    self.assertIsInstance(terms, list)
    self.assertEqual(len(terms), len(ranks_per_term))

    for (core, factors), ranks in zip(terms, ranks_per_term):
        self.assertIsInstance(core, torch.Tensor)
        self.assertEqual(core.shape, ranks)
        self.assertIsInstance(factors, list)
        self.assertEqual(len(factors), sample_tensor.ndim)
        self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))
        # One factor per mode, mapping mode size I_n down to rank R_n.
        for mode, (factor, rank) in enumerate(zip(factors, ranks)):
            self.assertEqual(factor.shape, (sample_tensor.shape[mode], rank))

    # The BTD approximation is the sum of the Tucker terms. detach().cpu() is
    # needed because .numpy() raises on tensors that require grad or live off
    # the CPU; this also matches how the other tests in this file convert.
    reconstructed = torch.zeros_like(sample_tensor)
    for core, factors in terms:
        np_core = core.detach().cpu().numpy()
        np_factors = [f.detach().cpu().numpy() for f in factors]
        reconstructed += torch.from_numpy(tucker_to_tensor((np_core, np_factors))).float()
    error = torch.norm(sample_tensor - reconstructed) / torch.norm(sample_tensor)
    self.assertLess(error.item(), 0.9)
|
|
def test_btd_decomposition_invalid_tensor_ndim(self):
    """BTD rejects tensors that are not 3-way (both 2D and 4D inputs)."""
    matrix = torch.rand(6, 7).float()
    four_way = torch.rand(3, 4, 5, 6).float()
    ranks_per_term = [(2, 2, 2)]
    message = "BTD as sum of Tucker-1 terms is typically for 3-way tensors"
    for wrong_order in (matrix, four_way):
        with self.assertRaisesRegex(ValueError, message):
            TensorDecompositionOps.btd_decomposition(wrong_order, ranks_per_term)
|
|
def test_btd_decomposition_invalid_ranks_type(self):
    """ranks_per_term that is not a list raises TypeError."""
    cube = torch.rand(6, 7, 8).float()
    not_a_list = "not_a_list"
    with self.assertRaisesRegex(TypeError, "ranks_per_term must be a list of tuples"):
        TensorDecompositionOps.btd_decomposition(cube, not_a_list)
|
|
def test_btd_decomposition_empty_ranks_list(self):
    """An empty ranks_per_term list raises ValueError."""
    cube = torch.rand(6, 7, 8).float()
    no_terms = []
    with self.assertRaisesRegex(ValueError, "ranks_per_term list cannot be empty"):
        TensorDecompositionOps.btd_decomposition(cube, no_terms)
|
|
def test_btd_decomposition_invalid_term_rank_type(self):
    """A ranks_per_term entry that is not a tuple raises ValueError."""
    cube = torch.rand(6, 7, 8).float()
    message = "Each element in ranks_per_term must be a tuple of 3 positive integers"
    with self.assertRaisesRegex(ValueError, message):
        TensorDecompositionOps.btd_decomposition(cube, [(2, 2, 2), "not_a_tuple"])
|
|
def test_btd_decomposition_invalid_term_rank_length(self):
    """A term rank tuple with only two entries raises ValueError."""
    cube = torch.rand(6, 7, 8).float()
    message = "Each element in ranks_per_term must be a tuple of 3 positive integers"
    with self.assertRaisesRegex(ValueError, message):
        TensorDecompositionOps.btd_decomposition(cube, [(2, 2, 2), (3, 3)])
|
|
def test_btd_decomposition_invalid_term_rank_value(self):
    """A term rank tuple containing zero raises ValueError."""
    cube = torch.rand(6, 7, 8).float()
    message = "Each element in ranks_per_term must be a tuple of 3 positive integers"
    with self.assertRaisesRegex(ValueError, message):
        TensorDecompositionOps.btd_decomposition(cube, [(2, 2, 2), (3, 0, 3)])
|
|
def test_btd_decomposition_rank_exceeds_dim(self):
    """A term rank larger than the matching tensor dimension raises ValueError."""
    small_cube = torch.rand(3, 4, 5).float()
    # The second term asks for rank 4 on a mode of size 3.
    oversized_terms = [(2, 2, 2), (4, 3, 3)]
    with self.assertRaisesRegex(ValueError, "Ranks for term .* exceed tensor dimensions"):
        TensorDecompositionOps.btd_decomposition(small_cube, oversized_terms)
|
|
def test_btd_decomposition_type_error_tensor(self):
    """BTD rejects input that is not a torch.Tensor."""
    bogus_input = "not a tensor"
    with self.assertRaisesRegex(TypeError, "Input at index 0 is not a torch.Tensor"):
        TensorDecompositionOps.btd_decomposition(bogus_input, [(2, 2, 2)])
|
|
| |
|
|
def test_ntf_cp_decomposition_valid(self):
    """NTF-CP on random non-negative data returns non-negative, correctly shaped factors.

    Checks weight/factor types and shapes, non-negativity (up to a small
    round-off tolerance), and a loose reconstruction fit.
    """
    sample_tensor = torch.rand(3, 4, 5).float()
    rank = 2
    weights, factors = TensorDecompositionOps.ntf_cp_decomposition(sample_tensor, rank)

    self.assertIsInstance(weights, torch.Tensor)
    self.assertIsInstance(factors, list)
    self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))

    # One weight per component, and one (I_n, rank) factor per mode.
    self.assertEqual(weights.ndim, 1)
    self.assertEqual(weights.size(0), rank)
    self.assertEqual(len(factors), sample_tensor.ndim)
    for mode, factor in enumerate(factors):
        self.assertEqual(factor.shape, (sample_tensor.shape[mode], rank))

    # The small negative slack absorbs floating-point round-off.
    self.assertTrue(torch.all(weights >= -1e-6))
    for f in factors:
        self.assertTrue(torch.all(f >= -1e-6))

    # detach().cpu() before .numpy(): .numpy() raises on tensors that require
    # grad or are not on the CPU (consistent with the other tests here).
    np_weights = weights.detach().cpu().numpy()
    np_factors = [f.detach().cpu().numpy() for f in factors]
    reconstructed_tl_tensor = tl.cp_to_tensor((np_weights, np_factors))
    reconstructed_torch_tensor = torch.from_numpy(reconstructed_tl_tensor).float()
    error = torch.norm(sample_tensor - reconstructed_torch_tensor) / torch.norm(sample_tensor)
    self.assertLess(error.item(), 0.8)
|
|
def test_ntf_cp_decomposition_known_non_negative_low_rank(self):
    """NTF-CP closely fits a tensor that is exactly non-negative CP rank 2.

    The fixture comes from a fixed-seed local generator. With the tight error
    threshold below, an unseeded fixture could make this test flaky, and a
    local generator leaves the global NumPy RNG untouched for other tests.
    """
    true_rank = 2
    shape = (3, 4, 5)
    rng = np.random.default_rng(0)
    # Uniform [0, 1) samples are already non-negative, so the CP tensor built
    # from them is non-negative too; no abs() is needed.
    true_weights_np = rng.random(true_rank).astype(np.float32)
    true_factors_np = [rng.random((s, true_rank)).astype(np.float32) for s in shape]

    low_rank_nn_tensor_tl = tl.cp_to_tensor((true_weights_np, true_factors_np))
    low_rank_nn_tensor_torch = torch.from_numpy(low_rank_nn_tensor_tl).float()

    weights, factors = TensorDecompositionOps.ntf_cp_decomposition(low_rank_nn_tensor_torch, true_rank)

    # Non-negativity up to round-off.
    self.assertTrue(torch.all(weights >= -1e-6))
    for f in factors:
        self.assertTrue(torch.all(f >= -1e-6))

    # detach().cpu() before .numpy(): .numpy() raises on grad-tracking or non-CPU tensors.
    np_weights = weights.detach().cpu().numpy()
    np_factors = [f.detach().cpu().numpy() for f in factors]
    reconstructed_tl_tensor = tl.cp_to_tensor((np_weights, np_factors))
    reconstructed_torch_tensor = torch.from_numpy(reconstructed_tl_tensor).float()
    error = torch.norm(low_rank_nn_tensor_torch - reconstructed_torch_tensor) / torch.norm(low_rank_nn_tensor_torch)
    self.assertLess(error.item(), 0.3)
|
|
def test_ntf_cp_decomposition_input_has_negative_values(self):
    """NTF-CP refuses an input containing a negative entry."""
    has_negative = torch.tensor([[[1.0, -0.1, 2.0]]], dtype=torch.float32)
    with self.assertRaisesRegex(ValueError, "Input tensor for NTF-CP must be non-negative"):
        TensorDecompositionOps.ntf_cp_decomposition(has_negative, 1)
|
|
def test_ntf_cp_decomposition_invalid_rank(self):
    """A zero NTF-CP rank raises ValueError."""
    cube = torch.rand(2, 2, 2).float()
    zero_rank = 0
    with self.assertRaisesRegex(ValueError, "Rank must be a positive integer"):
        TensorDecompositionOps.ntf_cp_decomposition(cube, zero_rank)
|
|
def test_ntf_cp_decomposition_invalid_tensor_ndim(self):
    """NTF-CP refuses a 1D input."""
    vector = torch.rand(5).float()
    expected = "NTF-CP decomposition requires a tensor with at least 2 dimensions"
    with self.assertRaisesRegex(ValueError, expected):
        TensorDecompositionOps.ntf_cp_decomposition(vector, 2)
|
|
def test_ntf_cp_decomposition_type_error_tensor(self):
    """NTF-CP rejects input that is not a torch.Tensor."""
    bogus_input = "not a tensor"
    with self.assertRaisesRegex(TypeError, "Input at index 0 is not a torch.Tensor"):
        TensorDecompositionOps.ntf_cp_decomposition(bogus_input, 2)
|
|
| |
|
|
def test_non_negative_tucker_valid(self):
    """Non-negative Tucker returns a non-negative core/factors with the requested ranks.

    Checks the core shape equals ``ranks``, that factor n is (I_n, ranks[n]),
    non-negativity up to round-off, and a loose reconstruction fit.
    """
    sample_tensor = torch.rand(3, 4, 5).float()
    ranks = [2, 3, 2]
    core, factors = TensorDecompositionOps.non_negative_tucker(sample_tensor, ranks)

    self.assertIsInstance(core, torch.Tensor)
    self.assertIsInstance(factors, list)
    self.assertEqual(core.shape, tuple(ranks))
    self.assertEqual(len(factors), sample_tensor.ndim)
    for mode, factor in enumerate(factors):
        self.assertEqual(factor.shape, (sample_tensor.shape[mode], ranks[mode]))
        self.assertTrue(torch.all(factor >= -1e-6))
    self.assertTrue(torch.all(core >= -1e-6))

    # detach().cpu() before .numpy(): .numpy() raises on grad-tracking or non-CPU tensors.
    np_core = core.detach().cpu().numpy()
    np_factors = [f.detach().cpu().numpy() for f in factors]
    reconstructed = tl.tucker_to_tensor((np_core, np_factors))
    recon_torch = torch.from_numpy(reconstructed).float()
    error = torch.norm(sample_tensor - recon_torch) / torch.norm(sample_tensor)
    self.assertLess(error.item(), 0.8)
|
|
def test_non_negative_tucker_negative_input(self):
    """Non-negative Tucker refuses an input with a negative entry."""
    has_negative = torch.tensor([[[-1.0, 0.5]]])
    unit_ranks = [1, 1, 1]
    with self.assertRaisesRegex(ValueError, "non-negative"):
        TensorDecompositionOps.non_negative_tucker(has_negative, unit_ranks)
|
|
| |
|
|
def test_partial_tucker_valid(self):
    """partial_tucker returns a core of shape ``ranks`` and one (I_n, R_n) factor per mode.

    Also requires a loose reconstruction fit on random data.
    """
    sample_tensor = torch.rand(3, 4, 5).float()
    ranks = [2, 3, 2]
    core, factors = TensorDecompositionOps.partial_tucker(sample_tensor, ranks)

    self.assertEqual(core.shape, tuple(ranks))
    self.assertEqual(len(factors), sample_tensor.ndim)
    for mode, factor in enumerate(factors):
        self.assertEqual(factor.shape, (sample_tensor.shape[mode], ranks[mode]))

    # detach().cpu() before .numpy(): .numpy() raises on grad-tracking or non-CPU tensors.
    np_core = core.detach().cpu().numpy()
    np_factors = [f.detach().cpu().numpy() for f in factors]
    reconstructed = tl.tucker_to_tensor((np_core, np_factors))
    recon_torch = torch.from_numpy(reconstructed).float()
    error = torch.norm(sample_tensor - recon_torch) / torch.norm(sample_tensor)
    self.assertLess(error.item(), 0.7)
|
|
def test_partial_tucker_invalid_rank_length(self):
    """A ranks list shorter than tensor.ndim raises ValueError."""
    cube = torch.rand(3, 4, 5).float()
    too_short = [2, 2]
    with self.assertRaisesRegex(ValueError, "Length of ranks list"):
        TensorDecompositionOps.partial_tucker(cube, too_short)
|
|
| |
|
|
def test_tt_svd_valid_low_rank(self):
    """TT-SVD with the true internal ranks recovers an exactly low-TT-rank tensor."""
    shape = (3, 4, 5)
    internal_ranks = [2, 3]
    bond_dims = [1, *internal_ranks, 1]
    true_cores = [
        np.random.rand(r_left, n, r_right).astype(np.float32)
        for r_left, n, r_right in zip(bond_dims[:-1], shape, bond_dims[1:])
    ]
    target = torch.from_numpy(tl.tt_to_tensor(true_cores)).float()

    cores = TensorDecompositionOps.tt_svd(target, internal_ranks)

    self.assertEqual(len(cores), target.ndim)
    for k, core in enumerate(cores):
        self.assertEqual(core.shape, (bond_dims[k], shape[k], bond_dims[k + 1]))

    rebuilt = torch.from_numpy(
        tl.tt_to_tensor([c.detach().cpu().numpy() for c in cores])
    ).float()
    rel_err = torch.norm(target - rebuilt) / torch.norm(target)
    self.assertAlmostEqual(rel_err.item(), 0.0, delta=1e-4)
|
|
def test_tt_svd_invalid_rank_type(self):
    """tt_svd rejects a string rank with TypeError."""
    cube = torch.rand(3, 4, 5).float()
    bad_rank = "bad"
    with self.assertRaisesRegex(TypeError, "Rank must be an int or a list of ints"):
        TensorDecompositionOps.tt_svd(cube, bad_rank)
|
|
| |
|
|
def test_t_product_valid(self):
    """Check _t_product against the Fourier-domain definition of the t-product.

    After an FFT along the third mode, the t-product becomes independent
    matrix products of matching frontal slices:
    fft(C)[:, :, k] == fft(A)[:, :, k] @ fft(B)[:, :, k] for every k.
    All slices are compared. Slice 0 (DC) alone only sees the tube sums, so
    it cannot catch errors in the circulant structure of the product.
    """
    A_torch = torch.rand(3, 2, 4).float()
    B_torch = torch.rand(2, 3, 4).float()
    C_torch = TensorDecompositionOps._t_product(A_torch, B_torch)

    self.assertIsInstance(C_torch, torch.Tensor)
    self.assertEqual(C_torch.shape, (A_torch.shape[0], B_torch.shape[1], A_torch.shape[2]))
    self.assertEqual(C_torch.dtype, A_torch.dtype)

    # Build the reference in float64 so the comparison is limited by the
    # float32 result under test, not by the reference computation.
    A_fft = fft(A_torch.detach().cpu().numpy().astype(np.float64), axis=2)
    B_fft = fft(B_torch.detach().cpu().numpy().astype(np.float64), axis=2)
    # Slice-wise matrix product over all frequencies: (i,j,k) x (j,l,k) -> (i,l,k).
    C_fft_expected = np.einsum("ijk,jlk->ilk", A_fft, B_fft)

    C_fft_actual = fft(C_torch.detach().cpu().numpy().astype(np.float64), axis=2)
    # atol leaves headroom for float32 round-off in C accumulated by the FFT.
    self.assertTrue(np.allclose(C_fft_actual, C_fft_expected, atol=1e-4))
|
|
def test_t_product_invalid_ndim(self):
    """_t_product raises ValueError when either operand is not 3-way."""
    matrix = torch.rand(3, 2).float()
    third_order = torch.rand(2, 3, 4).float()
    # The non-3-way operand is tried in both the left and the right position.
    for left, right in ((matrix, third_order), (third_order, matrix)):
        with self.assertRaisesRegex(ValueError, "t-product is defined for 3-way tensors"):
            TensorDecompositionOps._t_product(left, right)
|
|
def test_t_product_shape_mismatch(self):
    """_t_product raises ValueError when A's column count differs from B's row count."""
    left = torch.rand(3, 2, 4).float()
    right = torch.rand(3, 3, 4).float()  # 3 rows, but left has only 2 columns
    self.assertRaises(ValueError, TensorDecompositionOps._t_product, left, right)
|
|
def test_t_product_tube_shape_mismatch(self):
    """_t_product raises ValueError when the tube lengths (third dimensions) differ."""
    left = torch.rand(3, 2, 4).float()
    right = torch.rand(2, 3, 5).float()
    pattern = "Third dimensions .* for t-product must match"
    with self.assertRaisesRegex(ValueError, pattern):
        TensorDecompositionOps._t_product(left, right)
|
|
def test_t_svd_valid_reconstruction(self):
    """t-SVD factors reconstruct the input via X = U * S * V^T under the t-product.

    V^T must be the t-transpose: transpose every frontal slice, then reverse
    the order of frontal slices 2..n3. Only this matches the conjugate
    transpose of each Fourier-domain slice. A plain permute of the first two
    modes is the t-transpose only when n3 == 1.
    """
    X_torch = torch.rand(5, 4, 3).float()

    U_torch, S_torch, V_torch = TensorDecompositionOps.t_svd(X_torch)

    self.assertIsInstance(U_torch, torch.Tensor)
    self.assertIsInstance(S_torch, torch.Tensor)
    self.assertIsInstance(V_torch, torch.Tensor)
    self.assertEqual(U_torch.dtype, X_torch.dtype)
    self.assertEqual(S_torch.dtype, X_torch.dtype)
    self.assertEqual(V_torch.dtype, X_torch.dtype)

    # Full t-SVD shapes: U is n1 x n1 x n3, S is n1 x n2 x n3, V is n2 x n2 x n3.
    n1, n2, n3 = X_torch.shape
    self.assertEqual(U_torch.shape, (n1, n1, n3))
    self.assertEqual(S_torch.shape, (n1, n2, n3))
    self.assertEqual(V_torch.shape, (n2, n2, n3))

    def t_transpose(t):
        # Keep frontal slice 0 in place and reverse slices 1..n3-1, after a slice-wise transpose.
        order = [0] + list(range(t.shape[2] - 1, 0, -1))
        return t.permute(1, 0, 2)[:, :, order]

    Vt_torch = t_transpose(V_torch)

    temp = TensorDecompositionOps._t_product(U_torch, S_torch)
    X_reconstructed = TensorDecompositionOps._t_product(temp, Vt_torch)

    error = torch.norm(X_torch - X_reconstructed) / torch.norm(X_torch)
    # NOTE(review): threshold kept as loose as before; with an exact t-SVD the
    # error should be near zero, so tighten once t_svd is confirmed exact.
    self.assertLess(error.item(), 0.8)
|
|
def test_t_svd_properties(self):
    """t-SVD factors are t-orthogonal and S is f-diagonal.

    The t-product definitions used here:
    * The identity tensor has the identity matrix in frontal slice 0 and
      zeros in every other frontal slice. It is not eye in every slice.
    * The transpose is the t-transpose: a slice-wise transpose, followed by
      reversing the order of frontal slices 2..n3.
    * f-diagonal means every frontal slice is diagonal. For a non-square
      slice, that includes zero rows below the leading square block.
    """
    X_torch = torch.rand(5, 4, 3).float()
    U_torch, S_torch, V_torch = TensorDecompositionOps.t_svd(X_torch)

    def t_transpose(t):
        order = [0] + list(range(t.shape[2] - 1, 0, -1))
        return t.permute(1, 0, 2)[:, :, order]

    def t_identity(like):
        identity = torch.zeros_like(like)
        identity[:, :, 0] = torch.eye(like.shape[0], dtype=like.dtype, device=like.device)
        return identity

    # NOTE(review): atol=2.0 is kept from the original and is very loose;
    # tighten once t_svd is confirmed to return exactly t-orthogonal factors.
    UtU = TensorDecompositionOps._t_product(t_transpose(U_torch), U_torch)
    self.assertTrue(torch.allclose(UtU, t_identity(UtU), atol=2.0))

    VtV = TensorDecompositionOps._t_product(t_transpose(V_torch), V_torch)
    self.assertTrue(torch.allclose(VtV, t_identity(VtV), atol=2.0))

    # f-diagonal: zero out the main diagonal of each frontal slice and require
    # the remainder (over the whole slice) to vanish.
    for k in range(S_torch.shape[2]):
        off_diagonal = S_torch[:, :, k].clone()
        off_diagonal.diagonal().zero_()
        self.assertTrue(torch.allclose(off_diagonal, torch.zeros_like(off_diagonal), atol=2.0))
|
|
def test_t_svd_invalid_ndim(self):
    """t-SVD refuses a 2D input."""
    matrix = torch.rand(3, 2).float()
    expected = "t-SVD is defined for 3-way tensors"
    with self.assertRaisesRegex(ValueError, expected):
        TensorDecompositionOps.t_svd(matrix)
|
|
def test_t_svd_type_error_tensor(self):
    """t-SVD rejects input that is not a torch.Tensor."""
    bogus_input = "not a tensor"
    with self.assertRaisesRegex(TypeError, "Input at index 0 is not a torch.Tensor"):
        TensorDecompositionOps.t_svd(bogus_input)
|
|
if __name__ == "__main__":
    # Allow running this module directly (python <file>) as well as via pytest.
    unittest.main()
|
|