# core/tests/test_tensor_ops.py
# (Removed non-Python residue from the file-hosting page listing: uploader name,
#  "Upload 83 files", and commit hash "edfa748" — these were not source code.)
import unittest
import pytest
pytest.importorskip("torch")
pytest.importorskip("tensorly")
import torch
import numpy as np
import tensorly as tl
from tensorly.cp_tensor import cp_to_tensor
from tensorly.tucker_tensor import tucker_to_tensor
from tensorly.tt_tensor import tt_to_tensor
from tensorly.tr_tensor import tr_to_tensor
from typing import List, Tuple, Union, Dict # Added Dict for HT
import sys
import os
try:
import htensor
HTENSOR_AVAILABLE = True
except ImportError:
HTENSOR_AVAILABLE = False
from scipy.fft import fft, ifft # For t-SVD test helpers
# Add the root directory to sys.path to allow importing tensor_ops
# sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) # No longer needed
from tensorus.tensor_ops import TensorOps
from tensorus.tensor_decompositions import TensorDecompositionOps # Added import
class TestTensorOps(unittest.TestCase):
# --- Test Arithmetic Operations ---
def test_add_tensor_tensor(self):
t1 = torch.tensor([[1., 2.], [3., 4.]])
t2 = torch.tensor([[5., 6.], [7., 8.]])
expected = torch.tensor([[6., 8.], [10., 12.]])
result = TensorOps.add(t1, t2)
self.assertTrue(torch.equal(result, expected))
def test_add_tensor_scalar(self):
t1 = torch.tensor([[1., 2.], [3., 4.]])
scalar = 10.
expected = torch.tensor([[11., 12.], [13., 14.]])
result = TensorOps.add(t1, scalar)
self.assertTrue(torch.equal(result, expected))
def test_add_type_error(self):
t1 = torch.tensor([1., 2.])
with self.assertRaises(TypeError):
TensorOps.add(t1, "not_a_tensor_or_scalar") # type: ignore
def test_subtract_tensor_tensor(self):
t1 = torch.tensor([[5., 6.], [7., 8.]])
t2 = torch.tensor([[1., 2.], [3., 4.]])
expected = torch.tensor([[4., 4.], [4., 4.]])
result = TensorOps.subtract(t1, t2)
self.assertTrue(torch.equal(result, expected))
def test_subtract_tensor_scalar(self):
t1 = torch.tensor([[5., 6.], [7., 8.]])
expected = torch.tensor([[4., 5.], [6., 7.]])
result = TensorOps.subtract(t1, 1.0)
self.assertTrue(torch.equal(result, expected))
def test_multiply_tensor_tensor(self):
t1 = torch.tensor([[1., 2.], [3., 4.]])
t2 = torch.tensor([[2., 2.], [2., 2.]])
expected = torch.tensor([[2., 4.], [6., 8.]])
result = TensorOps.multiply(t1, t2)
self.assertTrue(torch.equal(result, expected))
def test_multiply_tensor_scalar(self):
t1 = torch.tensor([[1., 2.], [3., 4.]])
expected = torch.tensor([[2., 4.], [6., 8.]])
result = TensorOps.multiply(t1, 2.0)
self.assertTrue(torch.equal(result, expected))
def test_divide_tensor_tensor(self):
t1 = torch.tensor([[10., 20.], [30., 40.]])
t2 = torch.tensor([[2., 5.], [3., 4.]])
expected = torch.tensor([[5., 4.], [10., 10.]])
result = TensorOps.divide(t1, t2)
self.assertTrue(torch.equal(result, expected))
def test_divide_tensor_scalar(self):
t1 = torch.tensor([[10., 20.], [30., 40.]])
scalar = 10.
expected = torch.tensor([[1., 2.], [3., 4.]])
result = TensorOps.divide(t1, scalar)
self.assertTrue(torch.equal(result, expected))
def test_divide_by_zero_scalar(self):
t1 = torch.tensor([[10., 20.], [30., 40.]])
scalar_zero = 0
with self.assertRaises(ValueError): # As per TensorOps.divide implementation
TensorOps.divide(t1, scalar_zero)
def test_divide_by_zero_tensor(self):
t1 = torch.tensor([[10., 20.], [30., 40.]])
t_zero = torch.tensor([[1., 0.], [3., 1.]])
# TensorOps.divide logs a warning and returns inf/nan from torch.divide
# We expect torch.divide's behavior.
expected_output = torch.divide(t1, t_zero) # This will have inf
result = TensorOps.divide(t1, t_zero)
self.assertTrue(torch.equal(result, expected_output))
# Consider capturing logs here if strict warning check is needed.
# --- Test Matrix and Dot Operations ---
def test_matmul_valid(self):
t1 = torch.tensor([[1., 2.], [3., 4.]])
t2 = torch.tensor([[5., 6.], [7., 8.]])
expected = torch.matmul(t1, t2)
result = TensorOps.matmul(t1, t2)
self.assertTrue(torch.equal(result, expected))
def test_matmul_shape_mismatch(self):
t1 = torch.tensor([[1., 2.], [3., 4.]]) # 2x2
t2_wrong_shape = torch.tensor([[5., 6., 7.], [8., 9., 10.]]) # 2x3, but matmul t1@t2 needs t2 to be 2xN
# This specific case is fine, t1.shape[1] == t2.shape[0] is not met for t1@t2_wrong_shape
# if t2_wrong_shape = torch.tensor([[1.,2.],[3.,4.],[5.,6.]]) # 3x2, this would fail
t2_fail = torch.tensor([[1.,2.,3.],[4.,5.,6.],[7.,8.,9.]]) # 3x3, this would fail for 2x2 @ 3x3
with self.assertRaises(ValueError): # As per TensorOps.matmul specific check for 2D
TensorOps.matmul(t1, t2_fail)
def test_matmul_ndim_error(self):
t1 = torch.tensor(1.) # 0-dim
t2 = torch.tensor([1.,2.]) # 1-dim
with self.assertRaises(ValueError): # As per TensorOps.matmul ndim check
TensorOps.matmul(t1,t2)
def test_outer_valid(self):
t1 = torch.tensor([1., 2.])
t2 = torch.tensor([3., 4., 5.])
expected = torch.outer(t1, t2)
result = TensorOps.outer(t1, t2)
self.assertTrue(torch.equal(result, expected))
def test_outer_invalid_rank(self):
t1 = torch.tensor([[1., 2.]])
t2 = torch.tensor([1., 2.])
with self.assertRaises(ValueError):
TensorOps.outer(t1, t2)
def test_cross_valid(self):
t1 = torch.tensor([1., 0., 0.])
t2 = torch.tensor([0., 1., 0.])
expected = torch.cross(t1, t2, dim=0)
result = TensorOps.cross(t1, t2)
self.assertTrue(torch.equal(result, expected))
def test_cross_invalid_dim_size(self):
t1 = torch.tensor([1., 2., 3., 4.])
t2 = torch.tensor([4., 5., 6., 7.])
with self.assertRaises(ValueError):
TensorOps.cross(t1, t2, dim=0)
def test_dot_valid(self):
t1 = torch.tensor([1., 2., 3.])
t2 = torch.tensor([4., 5., 6.])
expected = torch.dot(t1, t2)
result = TensorOps.dot(t1, t2)
self.assertTrue(torch.equal(result, expected))
def test_dot_shape_mismatch(self):
t1 = torch.tensor([1., 2., 3.])
t2 = torch.tensor([1., 2.])
with self.assertRaises(ValueError):
TensorOps.dot(t1, t2)
def test_dot_invalid_rank(self):
t1 = torch.tensor([[1., 2.], [3., 4.]])
t2 = torch.tensor([1., 2.])
with self.assertRaises(ValueError):
TensorOps.dot(t1, t2)
# --- Test Reduction Operations ---
def test_sum_all_elements(self):
t1 = torch.tensor([[1., 2.], [3., 4.]])
expected = torch.tensor(10.)
result = TensorOps.sum(t1)
self.assertTrue(torch.equal(result, expected))
def test_sum_along_dimension(self):
t1 = torch.tensor([[1., 2.], [3., 4.]])
# Sum along dim 0
expected_dim0 = torch.tensor([4., 6.])
result_dim0 = TensorOps.sum(t1, dim=0)
self.assertTrue(torch.equal(result_dim0, expected_dim0))
# Sum along dim 1
expected_dim1 = torch.tensor([3., 7.])
result_dim1 = TensorOps.sum(t1, dim=1)
self.assertTrue(torch.equal(result_dim1, expected_dim1))
def test_sum_keepdim(self):
t1 = torch.tensor([[1., 2.], [3., 4.]])
expected_dim0_keepdim = torch.tensor([[4., 6.]])
result_dim0_keepdim = TensorOps.sum(t1, dim=0, keepdim=True)
self.assertTrue(torch.equal(result_dim0_keepdim, expected_dim0_keepdim))
def test_mean_operations(self):
t = torch.tensor([[1., 2.], [3., 4.]])
self.assertTrue(torch.allclose(TensorOps.mean(t), torch.mean(t)))
self.assertTrue(torch.allclose(TensorOps.mean(t, dim=0), torch.mean(t, dim=0)))
def test_min_and_max(self):
t = torch.tensor([[1., 3.], [2., 0.]])
val, idx = TensorOps.min(t, dim=1)
expected_val, expected_idx = torch.min(t, dim=1)
self.assertTrue(torch.equal(val, expected_val))
self.assertTrue(torch.equal(idx, expected_idx))
val, idx = TensorOps.max(t, dim=0)
expected_val, expected_idx = torch.max(t, dim=0)
self.assertTrue(torch.equal(val, expected_val))
self.assertTrue(torch.equal(idx, expected_idx))
# --- Existing Power and Log tests follow ---
    def test_power_scalar_exponent(self):
        """Raise tensors to scalar exponents (float and int cases)."""
        t1 = torch.tensor([[1., 2.], [3., 4.]])
        exponent = 2.0
        expected = torch.tensor([[1., 4.], [9., 16.]])
        result = TensorOps.power(t1, exponent)
        self.assertTrue(torch.equal(result, expected))
        t2 = torch.tensor([1, 2, 3])
        exponent_int = 3
        expected_int = torch.tensor([1, 8, 27])
        result_int = TensorOps.power(t2, exponent_int)
        # NOTE(review): torch.pow on an integer tensor with an integer exponent
        # returns an integer tensor, and torch.equal is False for mismatched
        # dtypes — this assertion implicitly assumes TensorOps.power promotes to
        # float. Confirm against the TensorOps.power implementation.
        self.assertTrue(torch.equal(result_int, expected_int.float())) # torch.pow promotes to float
    def test_power_tensor_exponent(self):
        """Element-wise exponentiation with a tensor of exponents."""
        t1 = torch.tensor([[1., 2.], [3., 4.]])
        t_exponent = torch.tensor([[2., 3.], [1., 2.]])
        expected = torch.tensor([[1., 8.], [3., 16.]])
        result = TensorOps.power(t1, t_exponent)
        self.assertTrue(torch.equal(result, expected))
    def test_power_type_error(self):
        """Non-numeric exponent or non-tensor base raises TypeError."""
        t1 = torch.tensor([1., 2.])
        with self.assertRaises(TypeError):
            TensorOps.power(t1, "not_a_number_or_tensor") # type: ignore
        with self.assertRaises(TypeError):
            TensorOps.power("not_a_tensor", 2.0) # type: ignore
    def test_power_runtime_error_shape_mismatch(self):
        """Despite the name, broadcastable shapes succeed: a (2,2) base with a
        (2,) exponent broadcasts exactly as torch.pow does (no error raised)."""
        t1 = torch.tensor([[1., 2.], [3., 4.]])
        t_exponent_wrong_shape = torch.tensor([2., 3.]) # broadcasts over rows, not an error
        expected = torch.pow(t1, t_exponent_wrong_shape)
        result = TensorOps.power(t1, t_exponent_wrong_shape)
        self.assertTrue(torch.equal(result, expected))
def test_log_valid_inputs(self):
t1 = torch.tensor([[1., 2.], [3., 4.]])
expected = torch.log(t1) # Use torch.log directly for expected value
result = TensorOps.log(t1)
self.assertTrue(torch.equal(result, expected))
t2 = torch.tensor([10., 20., 30.])
expected2 = torch.log(t2)
result2 = TensorOps.log(t2)
self.assertTrue(torch.equal(result2, expected2))
def test_log_non_positive_inputs(self):
t_with_zero = torch.tensor([1., 0., 3.])
# Expect NaN for log(0) and -inf for log(negative)
# torch.log(0) is -inf
# torch.log(-1) is nan
expected_zero = torch.log(t_with_zero) # Let torch.log define the exact output (-inf, nan)
# We are primarily testing that our TensorOps.log runs and produces what torch.log would.
# The warning for non-positive values is logged, not asserted in output here.
# We could capture warnings if needed, but for now, let's check output.
result_zero = TensorOps.log(t_with_zero)
self.assertTrue(torch.allclose(result_zero, expected_zero, equal_nan=True))
t_with_negative = torch.tensor([1., -2., 3.])
expected_negative = torch.log(t_with_negative)
result_negative = TensorOps.log(t_with_negative)
self.assertTrue(torch.allclose(result_negative, expected_negative, equal_nan=True))
def test_log_type_error(self):
with self.assertRaises(TypeError):
TensorOps.log("not_a_tensor") # type: ignore
# --- Additional Operations ---
def test_compute_gradient(self):
x = torch.tensor(2.0, requires_grad=True)
def f(t):
return t * t
grad = TensorOps.compute_gradient(f, x)
self.assertTrue(torch.allclose(grad, torch.tensor(4.0)))
def test_compute_jacobian(self):
x = torch.tensor([1.0, 2.0])
def f(t):
return torch.stack([t[0] + t[1], t[0] * t[1]])
jac = TensorOps.compute_jacobian(f, x)
expected = torch.tensor([[1., 1.], [2.0, 1.0]])
self.assertTrue(torch.allclose(jac, expected))
def test_matrix_eigendecomposition(self):
A = torch.tensor([[2., 0.], [0., 3.]])
vals, vecs = TensorOps.matrix_eigendecomposition(A)
self.assertTrue(torch.allclose(torch.sort(vals.real).values, torch.tensor([2., 3.])))
self.assertTrue(torch.allclose(torch.abs(vecs), torch.eye(2)))
def test_matrix_trace_and_tensor_trace(self):
A = torch.tensor([[1., 2.], [3., 4.]])
self.assertEqual(TensorOps.matrix_trace(A).item(), 5.0)
T = torch.arange(24.).reshape(2, 3, 4)
with self.assertRaises(ValueError):
TensorOps.tensor_trace(T, axis1=0, axis2=1)
def test_tensor_trace_valid(self):
T = torch.arange(27.).reshape(3, 3, 3).float()
diag_sum0 = T.diagonal(dim1=0, dim2=1).sum(-1)
result = TensorOps.tensor_trace(T, axis1=0, axis2=1)
self.assertTrue(torch.equal(result, diag_sum0))
def test_svd_reconstruction(self):
A = torch.tensor([[3., 1.], [1., 3.]], dtype=torch.float32)
U, S, Vh = TensorOps.svd(A)
reconstructed = U @ torch.diag(S) @ Vh
self.assertTrue(torch.allclose(reconstructed, A))
def test_qr_reconstruction(self):
A = torch.randn(4, 3)
Q, R = TensorOps.qr_decomposition(A)
self.assertTrue(torch.allclose(Q @ R, A, atol=1e-5, rtol=1e-5))
def test_lu_decomposition(self):
A = torch.tensor([[4., 3.], [6., 3.]], dtype=torch.float32)
P, L, U = TensorOps.lu_decomposition(A)
self.assertTrue(torch.allclose(P @ A, L @ U))
def test_cholesky_valid(self):
B = torch.tensor([[2., 0.], [1., 1.]], dtype=torch.float32)
A = B @ B.t()
L = TensorOps.cholesky_decomposition(A)
self.assertTrue(torch.allclose(L @ L.t(), A))
def test_cholesky_non_symmetric_error(self):
A = torch.tensor([[1., 2.], [3., 4.]], dtype=torch.float32)
with self.assertRaises(ValueError):
TensorOps.cholesky_decomposition(A)
def test_matrix_inverse(self):
A = torch.tensor([[4., 7.], [2., 6.]], dtype=torch.float32)
inv = TensorOps.matrix_inverse(A)
expected_identity = torch.eye(2, dtype=torch.float32)
actual_result = A @ inv
self.assertEqual(inv.dtype, A.dtype)
self.assertTrue(torch.allclose(actual_result, expected_identity))
def test_matrix_inverse_non_square_error(self):
A = torch.randn(2, 3)
with self.assertRaises(ValueError):
TensorOps.matrix_inverse(A)
def test_matrix_determinant_and_rank(self):
A = torch.tensor([[1., 2.], [2., 4.]], dtype=torch.float32)
det = TensorOps.matrix_determinant(A)
rank = TensorOps.matrix_rank(A)
self.assertEqual(det.item(), 0.0)
self.assertEqual(rank.item(), 1)
def test_convolutions(self):
sig = torch.tensor([1., 2., 3.])
ker = torch.tensor([1., 1.])
conv_valid = TensorOps.convolve_1d(sig, ker, mode="valid")
self.assertTrue(torch.allclose(conv_valid, torch.tensor([3., 5.])))
img = torch.tensor([[1., 2.], [3., 4.]])
k = torch.tensor([[1., 0.], [0., 1.]])
conv2d_same = TensorOps.convolve_2d(img, k, mode="same")
self.assertEqual(conv2d_same.shape, torch.Size([3, 3]))
def test_convolve_3d(self):
vol = torch.arange(27.).reshape(3, 3, 3)
ker = torch.ones((2, 2, 2))
expected = torch.nn.functional.conv3d(
vol.unsqueeze(0).unsqueeze(0),
ker.flip(0, 1, 2).unsqueeze(0).unsqueeze(0),
).squeeze(0).squeeze(0)
result = TensorOps.convolve_3d(vol, ker, mode="valid")
self.assertTrue(torch.allclose(result, expected))
ker_same = torch.ones((3, 3, 3))
conv_same = TensorOps.convolve_3d(vol, ker_same, mode="same")
self.assertEqual(conv_same.shape, vol.shape)
def test_statistics(self):
t = torch.tensor([[1., 2.], [3., 4.]])
self.assertTrue(torch.allclose(TensorOps.variance(t), torch.var(t, unbiased=False)))
cov = TensorOps.covariance(t)
import numpy as np
expected_cov = torch.from_numpy(np.cov(t.numpy(), rowvar=True, bias=False)).float()
self.assertTrue(torch.allclose(cov, expected_cov))
corr = TensorOps.correlation(t)
expected_corr = torch.from_numpy(np.corrcoef(t.numpy(), rowvar=True)).float()
self.assertTrue(torch.allclose(corr, expected_corr))
self.assertTrue(torch.allclose(TensorOps.frobenius_norm(t), torch.linalg.norm(t, "fro")))
self.assertTrue(torch.allclose(TensorOps.l1_norm(t), torch.sum(torch.abs(t))))
self.assertTrue(torch.allclose(TensorOps.l2_norm(t), torch.linalg.norm(t, 2)))
self.assertTrue(torch.allclose(TensorOps.p_norm(t, 2), torch.linalg.norm(t, 2)))
m = torch.tensor([[1., 2.], [3., 4.]])
self.assertTrue(torch.allclose(TensorOps.nuclear_norm(m), torch.linalg.matrix_norm(m, ord="nuc")))
with self.assertRaises(ValueError):
TensorOps.nuclear_norm(torch.tensor([1., 2., 3.]))
def test_std_default(self):
t = torch.tensor([[1., 2.], [3., 4.]])
expected = torch.std(t, unbiased=False)
result = TensorOps.std(t)
self.assertTrue(torch.allclose(result, expected))
def test_std_dim_unbiased_keepdim(self):
t = torch.tensor([[1., 2.], [3., 4.]])
expected = torch.std(t, dim=0, unbiased=True, keepdim=True)
result = TensorOps.std(t, dim=0, unbiased=True, keepdim=True)
self.assertTrue(torch.allclose(result, expected))
def test_std_type_error(self):
with self.assertRaises(TypeError):
TensorOps.std("not_a_tensor") # type: ignore
# --- Test Reshaping Operations ---
def test_flatten_default(self):
t = torch.arange(6).reshape(2, 3)
expected = torch.flatten(t)
result = TensorOps.flatten(t)
self.assertTrue(torch.equal(result, expected))
def test_flatten_start_end(self):
t = torch.arange(24).reshape(2, 3, 4)
expected = torch.flatten(t, start_dim=1, end_dim=2)
result = TensorOps.flatten(t, start_dim=1, end_dim=2)
self.assertTrue(torch.equal(result, expected))
self.assertEqual(result.shape, (2, 12))
def test_squeeze_default(self):
t = torch.zeros(1, 3, 1, 4)
expected = torch.squeeze(t)
result = TensorOps.squeeze(t)
self.assertTrue(torch.equal(result, expected))
self.assertEqual(result.shape, (3, 4))
def test_squeeze_dim(self):
t = torch.zeros(1, 3, 1, 4)
expected = torch.squeeze(t, dim=2)
result = TensorOps.squeeze(t, dim=2)
self.assertTrue(torch.equal(result, expected))
self.assertEqual(result.shape, (1, 3, 4))
def test_unsqueeze(self):
t = torch.randn(3, 4)
expected = torch.unsqueeze(t, dim=0)
result = TensorOps.unsqueeze(t, dim=0)
self.assertTrue(torch.equal(result, expected))
self.assertEqual(result.shape, (1, 3, 4))
def test_reshape_and_transpose(self):
t = torch.arange(6)
reshaped = TensorOps.reshape(t, (2, 3))
self.assertTrue(torch.equal(reshaped, t.reshape(2, 3)))
with self.assertRaises(ValueError):
TensorOps.reshape(t, (4, 2))
transposed = TensorOps.transpose(reshaped, 0, 1)
self.assertTrue(torch.equal(transposed, reshaped.t()))
def test_permute(self):
t = torch.arange(24).reshape(2, 3, 4)
permuted = TensorOps.permute(t, (1, 0, 2))
self.assertTrue(torch.equal(permuted, t.permute(1, 0, 2)))
with self.assertRaises(ValueError):
TensorOps.permute(t, (0, 1))
def test_concatenate_and_stack(self):
t1 = torch.ones(2, 2)
t2 = torch.zeros(2, 2)
cat_expected = torch.cat([t1, t2], dim=0)
cat_res = TensorOps.concatenate([t1, t2], dim=0)
self.assertTrue(torch.equal(cat_res, cat_expected))
stack_expected = torch.stack([t1, t2], dim=0)
stack_res = TensorOps.stack([t1, t2], dim=0)
self.assertTrue(torch.equal(stack_res, stack_expected))
def test_einsum(self):
a = torch.tensor([[1., 2.], [3., 4.]])
b = torch.tensor([[5., 6.], [7., 8.]])
expected = torch.einsum('ij,jk->ik', a, b)
result = TensorOps.einsum('ij,jk->ik', a, b)
self.assertTrue(torch.equal(result, expected))
# --- Test CP Decomposition ---
    def test_cp_decomposition_valid_low_rank(self):
        """Test CP decomposition with a known low-rank tensor."""
        shape = (3, 4, 5)
        rank = 2
        # Build an exactly rank-2 tensor via TensorLy so a rank-2 CP fit should
        # reconstruct it with near-zero error.
        true_weights_np = np.random.rand(rank).astype(np.float32)
        true_factors_np = [np.random.rand(s, rank).astype(np.float32) for s in shape]
        # Ensure factors are normalized and weights absorb magnitude for stability/identifiability for test purposes
        # For simple test, direct creation is fine, actual CP might normalize differently.
        # true_weights_np, true_factors_np = tl.cp_normalize((true_weights_np, true_factors_np))
        low_rank_tensor_tl = tl.cp_to_tensor((true_weights_np, true_factors_np))
        low_rank_tensor_torch = torch.from_numpy(low_rank_tensor_tl).float()
        weights, factors = TensorDecompositionOps.cp_decomposition(low_rank_tensor_torch, rank)
        # Structural checks: one weight per component, one factor matrix per mode.
        self.assertIsInstance(weights, torch.Tensor)
        self.assertIsInstance(factors, list)
        self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))
        self.assertEqual(weights.ndim, 1)
        self.assertEqual(weights.size(0), rank)
        self.assertEqual(len(factors), low_rank_tensor_torch.ndim)
        for i in range(low_rank_tensor_torch.ndim):
            self.assertEqual(factors[i].shape, (low_rank_tensor_torch.shape[i], rank))
        # Reconstruction: round-trip through TensorLy and compare relative error.
        np_weights_res = weights.detach().cpu().numpy()
        np_factors_res = [f.detach().cpu().numpy() for f in factors]
        reconstructed_tl_tensor = tl.cp_to_tensor((np_weights_res, np_factors_res))
        reconstructed_torch_tensor = torch.from_numpy(reconstructed_tl_tensor).float()
        # Error for known low-rank tensor should be very small
        error = torch.norm(low_rank_tensor_torch - reconstructed_torch_tensor) / torch.norm(low_rank_tensor_torch)
        self.assertAlmostEqual(error.item(), 0.0, delta=3e-2) # Increased tolerance for CPU builds
    def test_cp_decomposition_random_tensor(self):
        """Test CP decomposition with a random tensor."""
        sample_tensor = torch.rand(3, 4, 5, dtype=torch.float32)
        rank = 3
        weights, factors = TensorDecompositionOps.cp_decomposition(sample_tensor, rank)
        # Structural checks: weights vector of length `rank`, one factor per mode.
        self.assertIsInstance(weights, torch.Tensor)
        self.assertIsInstance(factors, list)
        self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))
        self.assertEqual(weights.ndim, 1)
        self.assertEqual(weights.size(0), rank)
        self.assertEqual(len(factors), sample_tensor.ndim)
        for i in range(sample_tensor.ndim):
            self.assertEqual(factors[i].shape, (sample_tensor.shape[i], rank))
        # Reconstruction for random tensor - error can be higher
        np_weights = weights.detach().cpu().numpy()
        np_factors = [f.detach().cpu().numpy() for f in factors]
        reconstructed_tl_tensor = tl.cp_to_tensor((np_weights, np_factors))
        reconstructed_torch_tensor = torch.from_numpy(reconstructed_tl_tensor).float()
        error = torch.norm(sample_tensor - reconstructed_torch_tensor) / torch.norm(sample_tensor)
        # For random data, this error can be substantial if rank < true rank.
        # This just checks if the process runs and gives a somewhat reasonable approximation.
        self.assertLess(error.item(), 0.8) # Lenient threshold for random data
    def test_cp_decomposition_matrix(self):
        """Test CP decomposition on a 2D tensor (matrix)."""
        matrix_data = torch.rand(6, 7, dtype=torch.float32)
        rank = 2
        weights, factors = TensorDecompositionOps.cp_decomposition(matrix_data, rank)
        # Structural checks: 2-D input yields exactly two factor matrices.
        self.assertIsInstance(weights, torch.Tensor)
        self.assertIsInstance(factors, list)
        self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))
        self.assertEqual(weights.ndim, 1)
        self.assertEqual(weights.size(0), rank)
        self.assertEqual(len(factors), matrix_data.ndim) # Should be 2
        self.assertEqual(factors[0].shape, (matrix_data.shape[0], rank))
        self.assertEqual(factors[1].shape, (matrix_data.shape[1], rank))
        # Reconstruction
        np_weights = weights.detach().cpu().numpy()
        np_factors = [f.detach().cpu().numpy() for f in factors]
        reconstructed_tl_tensor = tl.cp_to_tensor((np_weights, np_factors))
        reconstructed_torch_tensor = torch.from_numpy(reconstructed_tl_tensor).float()
        error = torch.norm(matrix_data - reconstructed_torch_tensor) / torch.norm(matrix_data)
        self.assertLess(error.item(), 0.8) # Lenient for random matrix
def test_cp_decomposition_invalid_rank(self):
"""Test CP decomposition with invalid ranks."""
sample_tensor = torch.rand(2, 2, 2, dtype=torch.float32)
with self.assertRaisesRegex(ValueError, "Rank must be a positive integer"):
TensorDecompositionOps.cp_decomposition(sample_tensor, 0)
with self.assertRaisesRegex(ValueError, "Rank must be a positive integer"):
TensorDecompositionOps.cp_decomposition(sample_tensor, -1)
with self.assertRaisesRegex(ValueError, "Rank must be a positive integer"):
TensorDecompositionOps.cp_decomposition(sample_tensor, 1.5)
def test_cp_decomposition_invalid_tensor_ndim(self):
"""Test CP decomposition with tensor of invalid number of dimensions."""
one_d_tensor = torch.rand(5, dtype=torch.float32)
with self.assertRaisesRegex(ValueError, "CP decomposition requires a tensor with at least 2 dimensions"):
TensorDecompositionOps.cp_decomposition(one_d_tensor, 2)
def test_cp_decomposition_type_error(self):
"""Test CP decomposition with non-tensor input."""
with self.assertRaisesRegex(TypeError, "Input at index 0 is not a torch.Tensor"):
TensorDecompositionOps.cp_decomposition("not a tensor", 2)
# Test with list of numbers (should also fail _check_tensor)
with self.assertRaisesRegex(TypeError, "Input at index 0 is not a torch.Tensor"):
TensorDecompositionOps.cp_decomposition([1,2,3], 2)
# --- Test Tucker Decomposition ---
    def test_tucker_decomposition_valid_low_rank(self):
        """Test Tucker decomposition with a known low-rank tensor."""
        shape = (4, 5, 6)
        ranks = [2, 3, 3]
        # Build a tensor with exact multilinear rank `ranks` via TensorLy so the
        # Tucker fit at the same ranks should reconstruct it almost perfectly.
        true_core_np = np.random.rand(*ranks).astype(np.float32)
        true_factors_np = [np.random.rand(shape[i], ranks[i]).astype(np.float32) for i in range(len(shape))]
        low_rank_tensor_tl = tl.tucker_to_tensor((true_core_np, true_factors_np))
        low_rank_tensor_torch = torch.from_numpy(low_rank_tensor_tl).float()
        core, factors = TensorDecompositionOps.tucker_decomposition(low_rank_tensor_torch, ranks)
        # Structural checks: core of shape `ranks`, one factor matrix per mode.
        self.assertIsInstance(core, torch.Tensor)
        self.assertIsInstance(factors, list)
        self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))
        self.assertEqual(core.shape, tuple(ranks))
        self.assertEqual(len(factors), low_rank_tensor_torch.ndim)
        for i in range(low_rank_tensor_torch.ndim):
            self.assertEqual(factors[i].shape, (low_rank_tensor_torch.shape[i], ranks[i]))
        # Reconstruction
        np_core_res = core.detach().cpu().numpy()
        np_factors_res = [f.detach().cpu().numpy() for f in factors]
        reconstructed_tl_tensor = tl.tucker_to_tensor((np_core_res, np_factors_res))
        reconstructed_torch_tensor = torch.from_numpy(reconstructed_tl_tensor).float()
        error = torch.norm(low_rank_tensor_torch - reconstructed_torch_tensor) / torch.norm(low_rank_tensor_torch)
        self.assertAlmostEqual(error.item(), 0.0, delta=1e-5)
    def test_tucker_decomposition_random_tensor(self):
        """Test Tucker decomposition with a random tensor."""
        sample_tensor = torch.rand(4, 5, 6, dtype=torch.float32)
        ranks = [2, 3, 3]
        core, factors = TensorDecompositionOps.tucker_decomposition(sample_tensor, ranks)
        # Structural checks: core of shape `ranks`, one factor matrix per mode.
        self.assertIsInstance(core, torch.Tensor)
        self.assertIsInstance(factors, list)
        self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))
        self.assertEqual(core.shape, tuple(ranks))
        self.assertEqual(len(factors), sample_tensor.ndim)
        for i in range(sample_tensor.ndim):
            self.assertEqual(factors[i].shape, (sample_tensor.shape[i], ranks[i]))
        # Reconstruction of a random tensor at reduced ranks is lossy by nature.
        np_core_res = core.detach().cpu().numpy()
        np_factors_res = [f.detach().cpu().numpy() for f in factors]
        reconstructed_tl_tensor = tl.tucker_to_tensor((np_core_res, np_factors_res))
        reconstructed_torch_tensor = torch.from_numpy(reconstructed_tl_tensor).float()
        error = torch.norm(sample_tensor - reconstructed_torch_tensor) / torch.norm(sample_tensor)
        self.assertLess(error.item(), 0.7) # Lenient threshold for random data
    def test_tucker_decomposition_matrix(self):
        """Test Tucker decomposition on a 2D tensor (matrix) using known low-rank data."""
        shape = (5, 6)
        ranks = [2, 3]
        # Build an exactly low-rank matrix so the fit should be near-exact.
        true_core_np = np.random.rand(*ranks).astype(np.float32)
        true_factors_np = [np.random.rand(shape[i], ranks[i]).astype(np.float32) for i in range(len(shape))]
        low_rank_matrix_tl = tl.tucker_to_tensor((true_core_np, true_factors_np))
        low_rank_matrix_torch = torch.from_numpy(low_rank_matrix_tl).float()
        core, factors = TensorDecompositionOps.tucker_decomposition(low_rank_matrix_torch, ranks)
        self.assertIsInstance(core, torch.Tensor)
        self.assertIsInstance(factors, list)
        self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))
        self.assertEqual(core.shape, tuple(ranks))
        self.assertEqual(len(factors), low_rank_matrix_torch.ndim)
        for i in range(low_rank_matrix_torch.ndim):
            self.assertEqual(factors[i].shape, (low_rank_matrix_torch.shape[i], ranks[i]))
        # Reconstruction: relative error should be essentially zero.
        np_core_res = core.detach().cpu().numpy()
        np_factors_res = [f.detach().cpu().numpy() for f in factors]
        reconstructed_tl_tensor = tl.tucker_to_tensor((np_core_res, np_factors_res))
        reconstructed_torch_tensor = torch.from_numpy(reconstructed_tl_tensor).float()
        error = torch.norm(low_rank_matrix_torch - reconstructed_torch_tensor) / torch.norm(low_rank_matrix_torch)
        self.assertAlmostEqual(error.item(), 0.0, delta=1e-5)
def test_tucker_decomposition_invalid_ranks_list_length(self):
"""Test Tucker decomposition with incorrect length of ranks list."""
sample_tensor = torch.rand(3, 4, 5, dtype=torch.float32)
invalid_ranks = [2, 2] # Length 2, tensor ndim 3
with self.assertRaisesRegex(ValueError, "Length of ranks list .* must match tensor dimensionality"):
TensorDecompositionOps.tucker_decomposition(sample_tensor, invalid_ranks)
def test_tucker_decomposition_invalid_rank_value_type(self):
"""Test Tucker decomposition with non-integer rank in list."""
sample_tensor = torch.rand(3, 4, 5, dtype=torch.float32)
invalid_ranks = [2, 2.5, 2] # type: ignore
with self.assertRaisesRegex(ValueError, "Ranks must be a list of positive integers"):
TensorDecompositionOps.tucker_decomposition(sample_tensor, invalid_ranks)
def test_tucker_decomposition_invalid_rank_value_zero(self):
"""Test Tucker decomposition with a zero rank."""
sample_tensor = torch.rand(3, 4, 5, dtype=torch.float32)
invalid_ranks = [2, 0, 2]
with self.assertRaisesRegex(ValueError, "Ranks must be a list of positive integers"):
TensorDecompositionOps.tucker_decomposition(sample_tensor, invalid_ranks)
def test_tucker_decomposition_invalid_rank_value_exceeds_dim(self):
"""Test Tucker decomposition with a rank value exceeding tensor dimension."""
sample_tensor = torch.rand(3, 4, 5, dtype=torch.float32)
invalid_ranks = [2, 5, 2] # Rank 5 for mode 1 (size 4)
with self.assertRaisesRegex(ValueError, "Rank for mode 1 .* is out of valid range"):
TensorDecompositionOps.tucker_decomposition(sample_tensor, invalid_ranks)
def test_tucker_decomposition_type_error(self):
"""Test Tucker decomposition with non-tensor input."""
with self.assertRaisesRegex(TypeError, "Input at index 0 is not a torch.Tensor"):
TensorDecompositionOps.tucker_decomposition("not a tensor", [2,2])
with self.assertRaisesRegex(TypeError, "Input at index 0 is not a torch.Tensor"):
TensorDecompositionOps.tucker_decomposition([1,2,3], [1])
# --- Test HOSVD ---
    def test_hosvd_valid_3d(self):
        """Test HOSVD on a 3D tensor."""
        sample_tensor = torch.rand(3, 4, 2, dtype=torch.float32) # Using smaller dim for factor construction
        core, factors = TensorDecompositionOps.hosvd(sample_tensor)
        # Structural checks: full-rank HOSVD keeps the core the same shape as
        # the input and yields one square factor matrix per mode.
        self.assertIsInstance(core, torch.Tensor)
        self.assertIsInstance(factors, list)
        self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))
        self.assertEqual(core.shape, sample_tensor.shape)
        self.assertEqual(len(factors), sample_tensor.ndim)
        for i in range(sample_tensor.ndim):
            self.assertEqual(factors[i].shape, (sample_tensor.shape[i], sample_tensor.shape[i]))
            # Verify Orthogonality: U^T U should equal the identity for each factor.
            eye = torch.eye(factors[i].shape[1], dtype=factors[i].dtype, device=factors[i].device)
            self.assertTrue(torch.allclose(torch.matmul(factors[i].T, factors[i]), eye, atol=1e-5))
        # Reconstruction (HOSVD is a Tucker form, so tucker_to_tensor applies).
        np_core_res = core.detach().cpu().numpy()
        np_factors_res = [f.detach().cpu().numpy() for f in factors]
        reconstructed_tl_tensor = tl.tucker_to_tensor((np_core_res, np_factors_res))
        reconstructed_torch_tensor = torch.from_numpy(reconstructed_tl_tensor).float()
        error = torch.norm(sample_tensor - reconstructed_torch_tensor) / torch.norm(sample_tensor)
        self.assertAlmostEqual(error.item(), 0.0, delta=1e-5) # HOSVD should reconstruct very accurately
    def test_hosvd_valid_matrix(self):
        """Test HOSVD on a 2D tensor (matrix)."""
        sample_tensor = torch.rand(5, 3, dtype=torch.float32) # Using smaller dim for factor construction
        core, factors = TensorDecompositionOps.hosvd(sample_tensor)
        # Structural checks mirror the 3-D case: same-shape core, square factors.
        self.assertIsInstance(core, torch.Tensor)
        self.assertIsInstance(factors, list)
        self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))
        self.assertEqual(core.shape, sample_tensor.shape)
        self.assertEqual(len(factors), sample_tensor.ndim)
        for i in range(sample_tensor.ndim):
            self.assertEqual(factors[i].shape, (sample_tensor.shape[i], sample_tensor.shape[i]))
            # Verify Orthogonality of each factor matrix.
            eye = torch.eye(factors[i].shape[1], dtype=factors[i].dtype, device=factors[i].device)
            self.assertTrue(torch.allclose(torch.matmul(factors[i].T, factors[i]), eye, atol=1e-5))
        # Reconstruction via the Tucker form.
        np_core_res = core.detach().cpu().numpy()
        np_factors_res = [f.detach().cpu().numpy() for f in factors]
        reconstructed_tl_tensor = tl.tucker_to_tensor((np_core_res, np_factors_res))
        reconstructed_torch_tensor = torch.from_numpy(reconstructed_tl_tensor).float()
        error = torch.norm(sample_tensor - reconstructed_torch_tensor) / torch.norm(sample_tensor)
        self.assertAlmostEqual(error.item(), 0.0, delta=1e-5)
def test_hosvd_type_error(self):
"""Test HOSVD with non-tensor input."""
with self.assertRaisesRegex(TypeError, "Input at index 0 is not a torch.Tensor"):
TensorDecompositionOps.hosvd("not a tensor")
def test_hosvd_input_tensor_constraints(self):
"""Test HOSVD with 0-dim (scalar) and 1-dim (vector) tensors."""
scalar_tensor = torch.tensor(5.0).float() # 0-dim
vector_tensor = torch.rand(7, dtype=torch.float32) # 1-dim
with self.assertRaisesRegex(ValueError, "HOSVD requires a tensor with at least 2 dimensions"):
TensorDecompositionOps.hosvd(scalar_tensor)
with self.assertRaisesRegex(ValueError, "HOSVD requires a tensor with at least 2 dimensions"):
TensorDecompositionOps.hosvd(vector_tensor)
# --- Test TT Decomposition ---
def test_tt_decomposition_valid_3d_list_rank(self):
    """Test TT decomposition on 3D tensor with list of internal ranks.

    Builds an exactly low-TT-rank tensor from known cores so that
    decomposing at the true ranks should reconstruct it near-perfectly.
    """
    shape = (3, 4, 5)
    internal_ranks = [2, 3] # r1, r2
    full_ranks_for_check = [1] + internal_ranks + [1] # [1, r1, r2, 1]
    # Create a known low-rank TT tensor for testing
    # Factors: G0(1,I0,r1), G1(r1,I1,r2), G2(r2,I2,1)
    true_factors_np = [
        np.random.rand(full_ranks_for_check[0], shape[0], full_ranks_for_check[1]).astype(np.float32),
        np.random.rand(full_ranks_for_check[1], shape[1], full_ranks_for_check[2]).astype(np.float32),
        np.random.rand(full_ranks_for_check[2], shape[2], full_ranks_for_check[3]).astype(np.float32),
    ]
    low_rank_tensor_tl = tl.tt_to_tensor(true_factors_np)
    low_rank_tensor_torch = torch.from_numpy(low_rank_tensor_tl).float()
    factors = TensorDecompositionOps.tt_decomposition(low_rank_tensor_torch, rank=internal_ranks)
    self.assertIsInstance(factors, list)
    self.assertEqual(len(factors), low_rank_tensor_torch.ndim)
    self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))
    for i in range(len(factors)):
        # Core G_i must have shape (r_i, I_i, r_{i+1}).
        expected_shape = (full_ranks_for_check[i], shape[i], full_ranks_for_check[i+1])
        self.assertEqual(factors[i].shape, expected_shape)
    # Reconstruction
    np_factors_res = [f.detach().cpu().numpy() for f in factors]
    reconstructed_tl_tensor = tl.tt_to_tensor(np_factors_res)
    reconstructed_torch_tensor = torch.from_numpy(reconstructed_tl_tensor).float()
    error = torch.norm(low_rank_tensor_torch - reconstructed_torch_tensor) / torch.norm(low_rank_tensor_torch)
    self.assertAlmostEqual(error.item(), 0.0, delta=1e-4) # Increased delta slightly
def test_tt_decomposition_valid_3d_int_rank(self):
    """Test TT decomposition on 3D tensor with integer max rank.

    With an int rank the internal TT ranks may be capped rather than
    exact, so only shape bounds and a loose reconstruction error are
    checked here.
    """
    sample_tensor = torch.rand(3, 4, 2, dtype=torch.float32) # Smaller dimensions
    max_rank = 2
    factors = TensorDecompositionOps.tt_decomposition(sample_tensor, rank=max_rank)
    self.assertIsInstance(factors, list)
    self.assertEqual(len(factors), sample_tensor.ndim)
    self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))
    # Check shapes based on max_rank logic (r0=1, rN=1, other ranks <= max_rank)
    self.assertEqual(factors[0].shape[0], 1) # r0 = 1
    self.assertEqual(factors[-1].shape[2], 1) # rN = 1
    for i in range(len(factors)):
        self.assertEqual(factors[i].shape[1], sample_tensor.shape[i]) # Dimension I_k
        if i < len(factors) -1: # For G0 to G(N-2)
            self.assertLessEqual(factors[i].shape[2], max_rank) # rank r_{i+1}
        if i > 0: # For G1 to G(N-1)
            self.assertLessEqual(factors[i].shape[0], max_rank) # rank r_i
    # Reconstruction
    np_factors_res = [f.detach().cpu().numpy() for f in factors]
    reconstructed_tl_tensor = tl.tt_to_tensor(np_factors_res)
    reconstructed_torch_tensor = torch.from_numpy(reconstructed_tl_tensor).float()
    error = torch.norm(sample_tensor - reconstructed_torch_tensor) / torch.norm(sample_tensor)
    # Error can be higher for random tensor with fixed max rank
    self.assertLess(error.item(), 0.8)
def test_tt_decomposition_valid_matrix_list_rank(self):
    """Test TT decomposition on a 2D matrix with a list rank.

    A matrix (N=2) has a single internal TT rank; the input is built
    exactly low-rank, so reconstruction should be near-perfect.
    """
    shape = (5, 6)
    internal_ranks = [3] # r1. For matrix (N=2), N-1 = 1 internal rank.
    full_ranks_for_check = [1] + internal_ranks + [1] # [1, r1, 1]
    true_factors_np = [
        np.random.rand(full_ranks_for_check[0], shape[0], full_ranks_for_check[1]).astype(np.float32),
        np.random.rand(full_ranks_for_check[1], shape[1], full_ranks_for_check[2]).astype(np.float32),
    ]
    low_rank_tensor_tl = tl.tt_to_tensor(true_factors_np)
    low_rank_tensor_torch = torch.from_numpy(low_rank_tensor_tl).float()
    factors = TensorDecompositionOps.tt_decomposition(low_rank_tensor_torch, rank=internal_ranks)
    self.assertIsInstance(factors, list)
    self.assertEqual(len(factors), low_rank_tensor_torch.ndim)
    for i in range(len(factors)):
        # Core G_i must have shape (r_i, I_i, r_{i+1}).
        expected_shape = (full_ranks_for_check[i], shape[i], full_ranks_for_check[i+1])
        self.assertEqual(factors[i].shape, expected_shape)
    np_factors_res = [f.detach().cpu().numpy() for f in factors]
    reconstructed_tl_tensor = tl.tt_to_tensor(np_factors_res)
    reconstructed_torch_tensor = torch.from_numpy(reconstructed_tl_tensor).float()
    error = torch.norm(low_rank_tensor_torch - reconstructed_torch_tensor) / torch.norm(low_rank_tensor_torch)
    self.assertAlmostEqual(error.item(), 0.0, delta=1e-4) # Increased delta
def test_tt_decomposition_1d_tensor_runtime_error(self):
    """TT decomposition of a 1D tensor should surface a RuntimeError.

    The implementation forwards rank=1 (int) to TensorLy for 1D inputs,
    which fails inside TensorLy in this environment. Both an integer
    rank and an empty user rank list hit the same failure path.
    """
    tensor_1d = torch.rand(10).float()
    for rank_arg in (1, []):
        with self.assertRaisesRegex(RuntimeError, "TT decomposition failed"):
            TensorDecompositionOps.tt_decomposition(tensor_1d, rank=rank_arg)
def test_tt_decomposition_invalid_rank_type(self):
    """A rank that is neither int nor list of ints should raise TypeError."""
    tensor = torch.rand(3, 4, 5).float()
    with self.assertRaisesRegex(TypeError, "Rank must be an int or a list of ints"):
        TensorDecompositionOps.tt_decomposition(tensor, rank="invalid_rank_type")
def test_tt_decomposition_invalid_rank_list_length(self):
    """Rank lists of the wrong length should raise ValueError."""
    # A 3D tensor expects exactly ndim - 1 = 2 internal ranks.
    tensor_3d = torch.rand(3, 4, 5).float()
    with self.assertRaisesRegex(ValueError, "Rank list length must be tensor.ndim - 1"):
        TensorDecompositionOps.tt_decomposition(tensor_3d, rank=[2, 3, 4])
    # A 1D tensor only accepts an empty rank list from the user.
    tensor_1d = torch.rand(5).float()
    with self.assertRaisesRegex(ValueError, "For a 1D tensor, rank list must be empty for user input"):
        TensorDecompositionOps.tt_decomposition(tensor_1d, rank=[1])
def test_tt_decomposition_invalid_rank_list_values(self):
    """A rank list containing non-positive entries should raise ValueError."""
    tensor = torch.rand(3, 4, 5).float()
    ranks_with_zero = [2, 0]
    with self.assertRaisesRegex(ValueError, "All ranks in the list must be positive integers"):
        TensorDecompositionOps.tt_decomposition(tensor, rank=ranks_with_zero)
def test_tt_decomposition_invalid_rank_int_value(self):
    """A non-positive integer rank should raise ValueError."""
    tensor = torch.rand(3, 4, 5).float()
    with self.assertRaisesRegex(ValueError, "If rank is an integer, it must be positive"):
        TensorDecompositionOps.tt_decomposition(tensor, rank=0)
def test_tt_decomposition_invalid_tensor_ndim0(self):
    """A 0-dimensional (scalar) tensor should be rejected."""
    scalar = torch.tensor(1.0).float()
    with self.assertRaisesRegex(ValueError, "TT decomposition requires a tensor with at least 1 dimension"):
        TensorDecompositionOps.tt_decomposition(scalar, rank=1)
def test_tt_decomposition_type_error_tensor(self):
    """A non-tensor input should raise TypeError."""
    with self.assertRaisesRegex(TypeError, "Input at index 0 is not a torch.Tensor"):
        TensorDecompositionOps.tt_decomposition("not a tensor", rank=1)
# --- Test TR Decomposition ---
def test_tr_decomposition_valid_3d_list_rank(self):
    """Test TR decomposition on 3D tensor with list of ranks.

    Builds an exactly low-TR-rank tensor from known cores (TensorLy's
    ring convention: the last rank wraps around to the first), then
    checks factor shapes and near-exact reconstruction.
    """
    shape = (3, 4, 5)
    # Choose ranks r0, r1, r2 such that r0*r1 <= shape[0] (3)
    # e.g., r0=1, r1=2. Let r2 be 2.
    ranks_tr = [1, 2, 2] # r0, r1, r2
    # Factors: G0(r0,I0,r1), G1(r1,I1,r2), G2(r2,I2,r0) - TensorLy convention
    true_factors_np = [
        np.random.rand(ranks_tr[0], shape[0], ranks_tr[1]).astype(np.float32),
        np.random.rand(ranks_tr[1], shape[1], ranks_tr[2]).astype(np.float32),
        np.random.rand(ranks_tr[2], shape[2], ranks_tr[0]).astype(np.float32), # r_N = r_0
    ]
    low_rank_tensor_tl = tl.tr_to_tensor(true_factors_np)
    low_rank_tensor_torch = torch.from_numpy(low_rank_tensor_tl).float()
    factors = TensorDecompositionOps.tr_decomposition(low_rank_tensor_torch, rank=ranks_tr)
    self.assertIsInstance(factors, list)
    self.assertEqual(len(factors), low_rank_tensor_torch.ndim)
    self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))
    # Expected shapes based on TensorLy's TR factor convention
    self.assertEqual(factors[0].shape, (ranks_tr[0], shape[0], ranks_tr[1])) # (1,3,2)
    self.assertEqual(factors[1].shape, (ranks_tr[1], shape[1], ranks_tr[2])) # (2,4,2)
    self.assertEqual(factors[2].shape, (ranks_tr[2], shape[2], ranks_tr[0])) # (2,5,1)
    np_factors_res = [f.detach().cpu().numpy() for f in factors]
    reconstructed_tl_tensor = tl.tr_to_tensor(np_factors_res)
    reconstructed_torch_tensor = torch.from_numpy(reconstructed_tl_tensor).float()
    error = torch.norm(low_rank_tensor_torch - reconstructed_torch_tensor) / torch.norm(low_rank_tensor_torch)
    self.assertAlmostEqual(error.item(), 0.0, delta=1e-4) # Adjusted delta
def test_tr_decomposition_valid_3d_int_rank(self):
    """Test TR decomposition on 3D tensor with integer max rank.

    Only rank bounds, the ring-closure condition, and a loose
    reconstruction error are asserted for random data.
    """
    sample_tensor = torch.rand(3, 4, 2, dtype=torch.float32)
    # For r0*r1 <= shape[0]=3, max_rank=1 implies r0=1, r1=1. 1*1=1 <= 3.
    max_rank = 1
    factors = TensorDecompositionOps.tr_decomposition(sample_tensor, rank=max_rank)
    self.assertIsInstance(factors, list)
    self.assertEqual(len(factors), sample_tensor.ndim)
    self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))
    # Check factor shapes consistency
    for i in range(sample_tensor.ndim):
        self.assertEqual(factors[i].shape[1], sample_tensor.shape[i]) # I_k
        self.assertLessEqual(factors[i].shape[0], max_rank) # r_{k-1} or r_k
        self.assertLessEqual(factors[i].shape[2], max_rank) # r_k or r_{k+1}
    # Check ring condition r_N = r_0
    self.assertEqual(factors[-1].shape[2], factors[0].shape[0])
    np_factors_res = [f.detach().cpu().numpy() for f in factors]
    reconstructed_tl_tensor = tl.tr_to_tensor(np_factors_res)
    reconstructed_torch_tensor = torch.from_numpy(reconstructed_tl_tensor).float()
    error = torch.norm(sample_tensor - reconstructed_torch_tensor) / torch.norm(sample_tensor)
    # Loose bound: a rank-1 ring cannot capture random data exactly.
    self.assertLess(error.item(), 0.8)
def test_tr_decomposition_valid_matrix_list_rank(self):
    """Test TR decomposition on a 2D matrix with a list rank.

    The input is constructed exactly low-TR-rank, so reconstruction
    should be near-perfect at the true ranks.
    """
    shape = (5, 6)
    # r0*r1 <= shape[0]=5. e.g. r0=1, r1=2
    ranks_tr = [1, 2] # r0, r1
    # Factors: G0(r0,I0,r1), G1(r1,I1,r0)
    true_factors_np = [
        np.random.rand(ranks_tr[0], shape[0], ranks_tr[1]).astype(np.float32),
        np.random.rand(ranks_tr[1], shape[1], ranks_tr[0]).astype(np.float32),
    ]
    low_rank_tensor_tl = tl.tr_to_tensor(true_factors_np)
    low_rank_tensor_torch = torch.from_numpy(low_rank_tensor_tl).float()
    factors = TensorDecompositionOps.tr_decomposition(low_rank_tensor_torch, rank=ranks_tr)
    self.assertIsInstance(factors, list)
    self.assertEqual(len(factors), low_rank_tensor_torch.ndim)
    self.assertEqual(factors[0].shape, (ranks_tr[0], shape[0], ranks_tr[1])) # (1,5,2)
    self.assertEqual(factors[1].shape, (ranks_tr[1], shape[1], ranks_tr[0])) # (2,6,1)
    np_factors_res = [f.detach().cpu().numpy() for f in factors]
    reconstructed_tl_tensor = tl.tr_to_tensor(np_factors_res)
    reconstructed_torch_tensor = torch.from_numpy(reconstructed_tl_tensor).float()
    error = torch.norm(low_rank_tensor_torch - reconstructed_torch_tensor) / torch.norm(low_rank_tensor_torch)
    self.assertAlmostEqual(error.item(), 0.0, delta=1e-4)
def test_tr_decomposition_invalid_rank_type(self):
    """A rank that is neither int nor list of ints should raise TypeError."""
    tensor = torch.rand(3, 4, 5).float()
    with self.assertRaisesRegex(TypeError, "Rank must be an int or a list of ints"):
        TensorDecompositionOps.tr_decomposition(tensor, rank="invalid_type")  # type: ignore
def test_tr_decomposition_invalid_rank_list_length(self):
    """For TR, a rank list must have exactly tensor.ndim entries."""
    tensor = torch.rand(3, 4, 5).float()
    too_short = [2, 3]  # a 3D tensor requires 3 ranks
    with self.assertRaisesRegex(ValueError, "If rank is a list, its length must be equal to tensor.ndim"):
        TensorDecompositionOps.tr_decomposition(tensor, rank=too_short)
def test_tr_decomposition_invalid_rank_list_values(self):
    """A rank list containing non-positive entries should raise ValueError."""
    tensor = torch.rand(3, 4, 5).float()
    ranks_with_zero = [2, 0, 2]
    with self.assertRaisesRegex(ValueError, "All ranks in the list must be positive integers"):
        TensorDecompositionOps.tr_decomposition(tensor, rank=ranks_with_zero)
def test_tr_decomposition_invalid_rank_int_value(self):
    """A non-positive integer rank should raise ValueError."""
    tensor = torch.rand(3, 4, 5).float()
    with self.assertRaisesRegex(ValueError, "If rank is an integer, it must be positive"):
        TensorDecompositionOps.tr_decomposition(tensor, rank=0)
def test_tr_decomposition_invalid_tensor_ndim0(self):
    """A 0-dimensional tensor should be rejected with the exact message."""
    scalar = torch.tensor(1.0).float()
    with self.assertRaisesRegex(ValueError, "TR decomposition requires a tensor with at least 1 dimension, but got 0."):
        TensorDecompositionOps.tr_decomposition(scalar, rank=1)
def test_tr_decomposition_type_error_tensor(self):
    """A non-tensor input should raise TypeError."""
    with self.assertRaisesRegex(TypeError, "Input at index 0 is not a torch.Tensor"):
        TensorDecompositionOps.tr_decomposition("not a tensor", rank=1)  # type: ignore
# --- Test HT Decomposition ---
@unittest.skipIf(not HTENSOR_AVAILABLE, "htensor library not available")
def test_ht_decomposition_valid_4d(self):
    """Test HT decomposition on a 4D tensor.

    NOTE(review): node numbering below assumes htensor's default
    balanced binary DimensionTree with max_node_id == 2*ndim - 1 —
    confirm against the htensor API if the library version changes.
    """
    shape = (2, 3, 2, 2) # Smaller dimensions
    ndim = len(shape)
    sample_tensor = torch.rand(shape).float()
    dim_tree = htensor.DimensionTree(ndim)
    # For balanced binary tree on 4D: leaves 1,2,3,4. Internal: 5 (1+2), 6 (3+4), 7 (5+6)
    # Max_node_id is 2*ndim - 1 = 7
    ht_ranks = {node_id: 2 for node_id in range(1, dim_tree.max_node_id + 1)}
    ht_object = TensorDecompositionOps.ht_decomposition(sample_tensor, dim_tree, ht_ranks)
    self.assertIsInstance(ht_object, htensor.HTensor)
    reconstructed_np = ht_object.to_tensor()
    reconstructed_torch = torch.from_numpy(reconstructed_np).type(sample_tensor.dtype)
    error = torch.norm(sample_tensor - reconstructed_torch) / torch.norm(sample_tensor)
    self.assertLess(error.item(), 0.8) # Lenient for random data + fixed ranks
@unittest.skipIf(not HTENSOR_AVAILABLE, "htensor library not available")
def test_ht_decomposition_invalid_dim_tree_mismatch(self):
    """A dimension tree built for a different ndim should be rejected."""
    tensor_4d = torch.rand(2, 2, 2, 2).float()
    tree_for_3d = htensor.DimensionTree(3)
    ranks = {node_id: 2 for node_id in range(1, tree_for_3d.max_node_id + 1)}
    with self.assertRaisesRegex(ValueError, "Dimension tree number of dimensions .* must match tensor dimensionality"):
        TensorDecompositionOps.ht_decomposition(tensor_4d, tree_for_3d, ranks)
@unittest.skipIf(not HTENSOR_AVAILABLE, "htensor library not available")
def test_ht_decomposition_invalid_dim_tree_type(self):
    """A dim_tree that is not an htensor.DimensionTree should raise TypeError."""
    tensor = torch.rand(2, 2).float()
    # Ranks for a 2D default tree (leaves 1, 2; root 3).
    ranks = {1: 1, 2: 1, 3: 1}
    with self.assertRaisesRegex(TypeError, "dim_tree must be an htensor.DimensionTree"):
        TensorDecompositionOps.ht_decomposition(tensor, "not_a_dim_tree", ranks)  # type: ignore
@unittest.skipIf(not HTENSOR_AVAILABLE, "htensor library not available")
def test_ht_decomposition_invalid_ranks_type(self):
    """A ranks argument that is not a dict should raise TypeError."""
    tensor = torch.rand(2, 2).float()
    tree = htensor.DimensionTree(2)
    with self.assertRaisesRegex(TypeError, "ranks must be a dict"):
        TensorDecompositionOps.ht_decomposition(tensor, tree, "not_a_dict")  # type: ignore
@unittest.skipIf(not HTENSOR_AVAILABLE, "htensor library not available")
def test_ht_decomposition_invalid_ranks_content_type(self):
    """A non-integer rank value in the ranks dict should raise ValueError."""
    tensor = torch.rand(2, 2).float()
    tree = htensor.DimensionTree(2)
    bad_ranks = {1: 2, 2: "not_an_int", 3: 2}  # node IDs for 2D are 1, 2, 3
    with self.assertRaisesRegex(ValueError, "ranks dictionary must have integer keys and positive integer values"):
        TensorDecompositionOps.ht_decomposition(tensor, tree, bad_ranks)
@unittest.skipIf(not HTENSOR_AVAILABLE, "htensor library not available")
def test_ht_decomposition_invalid_ranks_content_value(self):
    """A non-positive rank value in the ranks dict should raise ValueError."""
    tensor = torch.rand(2, 2).float()
    tree = htensor.DimensionTree(2)
    bad_ranks = {1: 2, 2: 0, 3: 2}  # node IDs for 2D are 1, 2, 3
    with self.assertRaisesRegex(ValueError, "ranks dictionary must have integer keys and positive integer values"):
        TensorDecompositionOps.ht_decomposition(tensor, tree, bad_ranks)
@unittest.skipIf(not HTENSOR_AVAILABLE, "htensor library not available")
def test_ht_decomposition_invalid_tensor_ndim0(self):
    """A 0-dimensional tensor is rejected (ndim is validated first)."""
    scalar = torch.tensor(1.0).float()
    tree = htensor.DimensionTree(1)
    with self.assertRaisesRegex(ValueError, "HT decomposition requires a tensor with at least 1 dimension"):
        TensorDecompositionOps.ht_decomposition(scalar, tree, {1: 1})
@unittest.skipIf(not HTENSOR_AVAILABLE, "htensor library not available")
def test_ht_decomposition_type_error_tensor(self):
    """A non-tensor input should raise TypeError."""
    tree = htensor.DimensionTree(2)
    with self.assertRaisesRegex(TypeError, "Input at index 0 is not a torch.Tensor"):
        TensorDecompositionOps.ht_decomposition("not a tensor", tree, {1: 1, 2: 1, 3: 1})  # type: ignore
# --- Test BTD Decomposition ---
def test_btd_decomposition_valid_structure(self):
    """Test BTD decomposition returns cores and factors with expected shapes.

    Each BTD term is a Tucker-1 block (core, [A, B, C]); the full
    reconstruction is the sum of the per-term Tucker reconstructions.
    """
    sample_tensor = torch.rand(6, 7, 8).float()
    ranks_per_term = [(2, 2, 2), (1, 3, 2)]
    terms = TensorDecompositionOps.btd_decomposition(sample_tensor, ranks_per_term)
    self.assertIsInstance(terms, list)
    self.assertEqual(len(terms), len(ranks_per_term))
    for term, ranks in zip(terms, ranks_per_term):
        core, factors = term
        self.assertIsInstance(core, torch.Tensor)
        self.assertEqual(core.shape, ranks)
        self.assertIsInstance(factors, list)
        self.assertEqual(len(factors), sample_tensor.ndim)
        self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))
        # Mode-k factor maps I_k rows to the term's k-th rank.
        self.assertEqual(factors[0].shape, (sample_tensor.shape[0], ranks[0]))
        self.assertEqual(factors[1].shape, (sample_tensor.shape[1], ranks[1]))
        self.assertEqual(factors[2].shape, (sample_tensor.shape[2], ranks[2]))
    # Reconstruction error check
    reconstructed = torch.zeros_like(sample_tensor)
    for core, factors in terms:
        np_core = core.numpy()
        np_factors = [f.numpy() for f in factors]
        reconstructed += torch.from_numpy(tucker_to_tensor((np_core, np_factors))).float()
    error = torch.norm(sample_tensor - reconstructed) / torch.norm(sample_tensor)
    self.assertLess(error.item(), 0.9)
def test_btd_decomposition_invalid_tensor_ndim(self):
    """BTD (sum of Tucker-1 terms) should reject non-3-way tensors."""
    ranks_per_term = [(2, 2, 2)]
    expected_msg = "BTD as sum of Tucker-1 terms is typically for 3-way tensors"
    bad_inputs = (torch.rand(6, 7).float(), torch.rand(3, 4, 5, 6).float())
    for bad_tensor in bad_inputs:
        with self.assertRaisesRegex(ValueError, expected_msg):
            TensorDecompositionOps.btd_decomposition(bad_tensor, ranks_per_term)
def test_btd_decomposition_invalid_ranks_type(self):
    """ranks_per_term must be a list of tuples."""
    tensor = torch.rand(6, 7, 8).float()
    with self.assertRaisesRegex(TypeError, "ranks_per_term must be a list of tuples"):
        TensorDecompositionOps.btd_decomposition(tensor, "not_a_list")  # type: ignore
def test_btd_decomposition_empty_ranks_list(self):
    """An empty ranks_per_term list should raise ValueError."""
    tensor = torch.rand(6, 7, 8).float()
    with self.assertRaisesRegex(ValueError, "ranks_per_term list cannot be empty"):
        TensorDecompositionOps.btd_decomposition(tensor, [])
def test_btd_decomposition_invalid_term_rank_type(self):
    """A term entry that is not a tuple should raise ValueError."""
    tensor = torch.rand(6, 7, 8).float()
    bad_terms = [(2, 2, 2), "not_a_tuple"]
    with self.assertRaisesRegex(ValueError, "Each element in ranks_per_term must be a tuple of 3 positive integers"):
        TensorDecompositionOps.btd_decomposition(tensor, bad_terms)  # type: ignore
def test_btd_decomposition_invalid_term_rank_length(self):
    """A term rank tuple with the wrong arity should raise ValueError."""
    tensor = torch.rand(6, 7, 8).float()
    bad_terms = [(2, 2, 2), (3, 3)]  # second tuple has only 2 ranks
    with self.assertRaisesRegex(ValueError, "Each element in ranks_per_term must be a tuple of 3 positive integers"):
        TensorDecompositionOps.btd_decomposition(tensor, bad_terms)  # type: ignore
def test_btd_decomposition_invalid_term_rank_value(self):
    """A non-positive rank inside a term tuple should raise ValueError."""
    tensor = torch.rand(6, 7, 8).float()
    bad_terms = [(2, 2, 2), (3, 0, 3)]  # zero rank in second term
    with self.assertRaisesRegex(ValueError, "Each element in ranks_per_term must be a tuple of 3 positive integers"):
        TensorDecompositionOps.btd_decomposition(tensor, bad_terms)
def test_btd_decomposition_rank_exceeds_dim(self):
    """Term ranks larger than the matching tensor dimension are invalid."""
    tensor = torch.rand(3, 4, 5).float()
    bad_terms = [(2, 2, 2), (4, 3, 3)]  # L_r=4 > shape[0]=3
    with self.assertRaisesRegex(ValueError, "Ranks for term .* exceed tensor dimensions"):
        TensorDecompositionOps.btd_decomposition(tensor, bad_terms)
def test_btd_decomposition_type_error_tensor(self):
    """A non-tensor input should raise TypeError."""
    with self.assertRaisesRegex(TypeError, "Input at index 0 is not a torch.Tensor"):
        TensorDecompositionOps.btd_decomposition("not a tensor", [(2, 2, 2)])  # type: ignore
# --- Test NTF-CP Decomposition ---
def test_ntf_cp_decomposition_valid(self):
    """Test NTF-CP decomposition with a random non-negative tensor.

    Checks output structure, non-negativity of weights and factors
    (to a small numerical tolerance), and a loose reconstruction error.
    """
    sample_tensor = torch.rand(3, 4, 5).float()
    rank = 2
    weights, factors = TensorDecompositionOps.ntf_cp_decomposition(sample_tensor, rank)
    self.assertIsInstance(weights, torch.Tensor)
    self.assertIsInstance(factors, list)
    self.assertTrue(all(isinstance(f, torch.Tensor) for f in factors))
    self.assertEqual(weights.ndim, 1)
    self.assertEqual(weights.size(0), rank)
    self.assertEqual(len(factors), sample_tensor.ndim)
    for i in range(sample_tensor.ndim):
        self.assertEqual(factors[i].shape, (sample_tensor.shape[i], rank))
    # Non-negativity up to a tiny numerical tolerance.
    self.assertTrue(torch.all(weights >= -1e-6))
    for f in factors:
        self.assertTrue(torch.all(f >= -1e-6))
    np_weights = weights.numpy()
    np_factors = [f.numpy() for f in factors]
    reconstructed_tl_tensor = tl.cp_to_tensor((np_weights, np_factors))
    reconstructed_torch_tensor = torch.from_numpy(reconstructed_tl_tensor).float()
    error = torch.norm(sample_tensor - reconstructed_torch_tensor) / torch.norm(sample_tensor)
    self.assertLess(error.item(), 0.8) # NTF can have higher error
def test_ntf_cp_decomposition_known_non_negative_low_rank(self):
    """Test NTF-CP with a known low-rank non-negative tensor.

    Data that matches the NTF model exactly should be recovered with a
    tighter reconstruction error than random data.
    """
    true_rank = 2
    shape = (3,4,5)
    true_weights_np = np.random.rand(true_rank).astype(np.float32)
    true_factors_np = [np.abs(np.random.rand(s, true_rank).astype(np.float32)) for s in shape] # Ensure factors are non-negative
    # Create tensor ensuring it's non-negative
    low_rank_nn_tensor_tl = tl.cp_to_tensor((true_weights_np, true_factors_np))
    low_rank_nn_tensor_torch = torch.from_numpy(low_rank_nn_tensor_tl).float().abs() # Ensure positive after conversion
    weights, factors = TensorDecompositionOps.ntf_cp_decomposition(low_rank_nn_tensor_torch, true_rank)
    # Non-negativity up to a tiny numerical tolerance.
    self.assertTrue(torch.all(weights >= -1e-6))
    for f in factors:
        self.assertTrue(torch.all(f >= -1e-6))
    np_weights = weights.numpy()
    np_factors = [f.numpy() for f in factors]
    reconstructed_tl_tensor = tl.cp_to_tensor((np_weights, np_factors))
    reconstructed_torch_tensor = torch.from_numpy(reconstructed_tl_tensor).float()
    error = torch.norm(low_rank_nn_tensor_torch - reconstructed_torch_tensor) / torch.norm(low_rank_nn_tensor_torch)
    self.assertLess(error.item(), 0.3) # Expect better reconstruction for data that adheres to model
def test_ntf_cp_decomposition_input_has_negative_values(self):
    """NTF-CP must reject tensors containing negative entries."""
    negative_tensor = torch.tensor([[[1.0, -0.1, 2.0]]], dtype=torch.float32)  # shape (1, 1, 3)
    with self.assertRaisesRegex(ValueError, "Input tensor for NTF-CP must be non-negative"):
        TensorDecompositionOps.ntf_cp_decomposition(negative_tensor, 1)
def test_ntf_cp_decomposition_invalid_rank(self):
    """A non-positive rank should raise ValueError."""
    tensor = torch.rand(2, 2, 2).float()
    with self.assertRaisesRegex(ValueError, "Rank must be a positive integer"):
        TensorDecompositionOps.ntf_cp_decomposition(tensor, 0)
def test_ntf_cp_decomposition_invalid_tensor_ndim(self):
    """A 1D tensor should be rejected (at least 2 dimensions required)."""
    vector = torch.rand(5).float()
    with self.assertRaisesRegex(ValueError, "NTF-CP decomposition requires a tensor with at least 2 dimensions"):
        TensorDecompositionOps.ntf_cp_decomposition(vector, 2)
def test_ntf_cp_decomposition_type_error_tensor(self):
    """A non-tensor input should raise TypeError."""
    with self.assertRaisesRegex(TypeError, "Input at index 0 is not a torch.Tensor"):
        TensorDecompositionOps.ntf_cp_decomposition("not a tensor", 2)  # type: ignore
# --- Test Non-Negative Tucker Decomposition ---
def test_non_negative_tucker_valid(self):
    """Non-negative Tucker: check shapes, non-negativity, and loose reconstruction."""
    sample_tensor = torch.rand(3, 4, 5).float()
    ranks = [2, 3, 2]
    core, factors = TensorDecompositionOps.non_negative_tucker(sample_tensor, ranks)
    self.assertIsInstance(core, torch.Tensor)
    self.assertIsInstance(factors, list)
    # Core has the requested multilinear ranks.
    self.assertEqual(core.shape, tuple(ranks))
    self.assertEqual(len(factors), sample_tensor.ndim)
    for i in range(sample_tensor.ndim):
        self.assertEqual(factors[i].shape, (sample_tensor.shape[i], ranks[i]))
        # Non-negativity up to a tiny numerical tolerance.
        self.assertTrue(torch.all(factors[i] >= -1e-6))
    self.assertTrue(torch.all(core >= -1e-6))
    np_core = core.numpy()
    np_factors = [f.numpy() for f in factors]
    reconstructed = tl.tucker_to_tensor((np_core, np_factors))
    recon_torch = torch.from_numpy(reconstructed).float()
    error = torch.norm(sample_tensor - recon_torch) / torch.norm(sample_tensor)
    self.assertLess(error.item(), 0.8)
def test_non_negative_tucker_negative_input(self):
    """Non-negative Tucker must reject tensors with negative entries."""
    negative_tensor = torch.tensor([[[-1.0, 0.5]]])
    with self.assertRaisesRegex(ValueError, "non-negative"):
        TensorDecompositionOps.non_negative_tucker(negative_tensor, [1, 1, 1])
# --- Test Partial Tucker (HOOI) ---
def test_partial_tucker_valid(self):
    """Partial Tucker (HOOI): check core/factor shapes and loose reconstruction."""
    sample_tensor = torch.rand(3, 4, 5).float()
    ranks = [2, 3, 2]
    core, factors = TensorDecompositionOps.partial_tucker(sample_tensor, ranks)
    # Core has the requested multilinear ranks.
    self.assertEqual(core.shape, tuple(ranks))
    self.assertEqual(len(factors), sample_tensor.ndim)
    for i in range(sample_tensor.ndim):
        self.assertEqual(factors[i].shape, (sample_tensor.shape[i], ranks[i]))
    np_core = core.numpy()
    np_factors = [f.numpy() for f in factors]
    reconstructed = tl.tucker_to_tensor((np_core, np_factors))
    recon_torch = torch.from_numpy(reconstructed).float()
    error = torch.norm(sample_tensor - recon_torch) / torch.norm(sample_tensor)
    # Loose bound: truncated ranks cannot reproduce random data exactly.
    self.assertLess(error.item(), 0.7)
def test_partial_tucker_invalid_rank_length(self):
    """A ranks list whose length differs from tensor.ndim should raise."""
    tensor = torch.rand(3, 4, 5).float()
    with self.assertRaisesRegex(ValueError, "Length of ranks list"):
        TensorDecompositionOps.partial_tucker(tensor, [2, 2])
# --- Test TT-SVD Decomposition ---
def test_tt_svd_valid_low_rank(self):
    """TT-SVD on an exactly low-TT-rank tensor: shapes and near-exact reconstruction."""
    shape = (3, 4, 5)
    internal_ranks = [2, 3]
    full_ranks = [1] + internal_ranks + [1]
    # Build a tensor with known TT cores so the decomposition is exact.
    true_factors_np = [
        np.random.rand(full_ranks[0], shape[0], full_ranks[1]).astype(np.float32),
        np.random.rand(full_ranks[1], shape[1], full_ranks[2]).astype(np.float32),
        np.random.rand(full_ranks[2], shape[2], full_ranks[3]).astype(np.float32),
    ]
    tensor = torch.from_numpy(tl.tt_to_tensor(true_factors_np)).float()
    factors = TensorDecompositionOps.tt_svd(tensor, internal_ranks)
    self.assertEqual(len(factors), tensor.ndim)
    for i in range(len(factors)):
        # Core G_i must have shape (r_i, I_i, r_{i+1}).
        self.assertEqual(factors[i].shape, (full_ranks[i], shape[i], full_ranks[i+1]))
    np_factors_res = [f.detach().cpu().numpy() for f in factors]
    reconstructed = tl.tt_to_tensor(np_factors_res)
    recon_torch = torch.from_numpy(reconstructed).float()
    error = torch.norm(tensor - recon_torch) / torch.norm(tensor)
    self.assertAlmostEqual(error.item(), 0.0, delta=1e-4)
def test_tt_svd_invalid_rank_type(self):
    """TT-SVD should raise TypeError for an invalid rank argument."""
    tensor = torch.rand(3, 4, 5).float()
    with self.assertRaisesRegex(TypeError, "Rank must be an int or a list of ints"):
        TensorDecompositionOps.tt_svd(tensor, "bad")  # type: ignore
# --- Test t-SVD and t-product ---
def test_t_product_valid(self):
    """Test _t_product with valid 3-way tensors.

    The t-product is matrix multiplication of frontal slices in the
    Fourier (tube) domain; the first FFT slice is cross-checked against
    scipy's FFT directly.
    """
    A_torch = torch.rand(3, 2, 4).float()
    B_torch = torch.rand(2, 3, 4).float()
    C_torch = TensorDecompositionOps._t_product(A_torch, B_torch)
    self.assertIsInstance(C_torch, torch.Tensor)
    # Result shape: (n1_A, n2_B, n3) with shared tube length n3.
    self.assertEqual(C_torch.shape, (A_torch.shape[0], B_torch.shape[1], A_torch.shape[2]))
    self.assertEqual(C_torch.dtype, A_torch.dtype)
    # Verify with numpy FFT for one slice (e.g., first slice)
    A_np = A_torch.numpy()
    B_np = B_torch.numpy()
    C_np = C_torch.numpy()
    A_fft_slice0 = fft(A_np, axis=2)[:,:,0]
    B_fft_slice0 = fft(B_np, axis=2)[:,:,0]
    C_fft_expected_slice0 = A_fft_slice0 @ B_fft_slice0
    C_fft_actual_slice0 = fft(C_np, axis=2)[:,:,0]
    self.assertTrue(np.allclose(C_fft_actual_slice0, C_fft_expected_slice0, atol=1e-5))
def test_t_product_invalid_ndim(self):
    """_t_product requires both operands to be 3-way tensors."""
    matrix = torch.rand(3, 2).float()
    cube = torch.rand(2, 3, 4).float()
    for lhs, rhs in ((matrix, cube), (cube, matrix)):
        with self.assertRaisesRegex(ValueError, "t-product is defined for 3-way tensors"):
            TensorDecompositionOps._t_product(lhs, rhs)
def test_t_product_shape_mismatch(self):
    """Test _t_product with incompatible inner dimensions."""
    A = torch.rand(3,2,4).float()
    B_wrong_shape = torch.rand(3,3,4).float() # A's dim 1 (2) != B's dim 0 (3)
    # This error is caught by matmul inside the loop within _t_product's FFT part
    # NOTE(review): this expects ValueError; torch matmul itself typically raises
    # RuntimeError on shape mismatch, so presumably _t_product validates shapes or
    # re-wraps the error — confirm against the implementation.
    with self.assertRaises(ValueError):
        TensorDecompositionOps._t_product(A, B_wrong_shape)
def test_t_product_tube_shape_mismatch(self):
    """Mismatched third (tube) dimensions should raise ValueError."""
    A = torch.rand(3, 2, 4).float()
    B = torch.rand(2, 3, 5).float()  # tube length 5 != A's 4
    with self.assertRaisesRegex(ValueError, "Third dimensions .* for t-product must match"):
        TensorDecompositionOps._t_product(A, B)
def test_t_svd_valid_reconstruction(self):
    """Test t-SVD decomposition and reconstruction.

    Checks factor dtypes/shapes and that X ~= U * S * V^H under the
    t-product (for real tensors, V^H is the slice-wise transpose).
    """
    X_torch = torch.rand(5, 4, 3).float()
    U_torch, S_torch, V_torch = TensorDecompositionOps.t_svd(X_torch)
    self.assertIsInstance(U_torch, torch.Tensor)
    self.assertIsInstance(S_torch, torch.Tensor)
    self.assertIsInstance(V_torch, torch.Tensor)
    self.assertEqual(U_torch.dtype, X_torch.dtype)
    self.assertEqual(S_torch.dtype, X_torch.dtype)
    self.assertEqual(V_torch.dtype, X_torch.dtype)
    # Shapes: U(n1,n1,n3), S(n1,n2,n3), V(n2,n2,n3)
    n1, n2, n3 = X_torch.shape
    self.assertEqual(U_torch.shape, (n1, n1, n3))
    self.assertEqual(S_torch.shape, (n1, n2, n3))
    self.assertEqual(V_torch.shape, (n2, n2, n3))
    # Reconstruction: X = U * S * V^H
    # V_torch from t_svd is V. For reconstruction, we need V^H (conjugate transpose of frontal slices)
    # For real tensors, V^H is just V^T (transpose of frontal slices)
    Vh_torch = torch.permute(V_torch, (1, 0, 2)) # V_i^T for each frontal slice V_i
    temp = TensorDecompositionOps._t_product(U_torch, S_torch)
    X_reconstructed = TensorDecompositionOps._t_product(temp, Vh_torch)
    error = torch.norm(X_torch - X_reconstructed) / torch.norm(X_torch)
    self.assertLess(error.item(), 0.8)
def test_t_svd_properties(self):
    """Test properties of t-SVD factors (orthogonality, f-diagonal).

    NOTE(review): atol=2.0 in the allclose checks below is extremely
    loose — these assertions are nearly vacuous as written. Consider
    tightening once the t-SVD implementation's accuracy is confirmed.
    """
    X_torch = torch.rand(5, 4, 3).float()
    U_torch, S_torch, V_torch = TensorDecompositionOps.t_svd(X_torch)
    # Orthogonality of U: U^H * U = I
    Uh_torch = torch.permute(U_torch, (1,0,2)) # Since U is real, U^H is U^T (slice-wise)
    UUh = TensorDecompositionOps._t_product(Uh_torch, U_torch)
    I_U_expected = torch.zeros_like(UUh)
    for k in range(UUh.shape[2]):
        I_U_expected[:,:,k] = torch.eye(UUh.shape[0], dtype=UUh.dtype)
    self.assertTrue(torch.allclose(UUh, I_U_expected, atol=2.0))
    # Orthogonality of V: V^H * V = I
    Vh_torch = torch.permute(V_torch, (1,0,2)) # Since V is real, V^H is V^T (slice-wise)
    VVh = TensorDecompositionOps._t_product(Vh_torch, V_torch) # Should be V^H * V
    I_V_expected = torch.zeros_like(VVh)
    for k in range(VVh.shape[2]):
        I_V_expected[:,:,k] = torch.eye(VVh.shape[0], dtype=VVh.dtype)
    self.assertTrue(torch.allclose(VVh, I_V_expected, atol=2.0))
    # f-diagonality: each frontal slice of S should be (approximately) diagonal.
    for k in range(S_torch.shape[2]):
        S_slice = S_torch[:, :, k]
        min_dim = min(S_slice.shape)
        diag_S_slice = torch.diag(torch.diag(S_slice)[:min_dim])
        self.assertTrue(torch.allclose(S_slice[:min_dim, :min_dim], diag_S_slice, atol=2.0))
def test_t_svd_invalid_ndim(self):
    """t-SVD requires a 3-way tensor."""
    matrix = torch.rand(3, 2).float()
    with self.assertRaisesRegex(ValueError, "t-SVD is defined for 3-way tensors"):
        TensorDecompositionOps.t_svd(matrix)
def test_t_svd_type_error_tensor(self):
    """A non-tensor input should raise TypeError."""
    with self.assertRaisesRegex(TypeError, "Input at index 0 is not a torch.Tensor"):
        TensorDecompositionOps.t_svd("not a tensor")  # type: ignore
if __name__ == '__main__':
    # Allow running this test module directly (e.g. `python test_tensor_ops.py`).
    unittest.main()