| from functools import partial |
| from typing import Sequence |
|
|
| import torch as T |
|
|
|
|
def svd_truncated(mat: T.Tensor, rank: int):
    """Return the leading ``rank`` singular triplets of a 2-D tensor.

    The result uses the same ``(U, S, V)`` shape layout as
    :func:`torch.svd_lowrank`: ``U`` is ``(m, k)``, ``S`` is ``(k,)`` and
    ``V`` is ``(n, k)`` with ``k = min(rank, m, n)``, so that
    ``mat ≈ U @ diag(S) @ V.T``.

    Args:
        mat: Matrix of shape ``(m, n)`` to factorize.
        rank: Number of leading singular triplets to keep.

    Returns:
        Tuple ``(U, S, V)`` as described above.
    """
    # Reduced SVD: full_matrices=True would build square (m, m) and (n, n)
    # factors -- wasteful, often infeasible for very wide unfoldings, and
    # it gives U more columns than S has entries when rank > min(m, n).
    lvecs, svals, rvecs = T.linalg.svd(mat, full_matrices=False)
    return lvecs[:, :rank], svals[:rank], rvecs[:rank, :].T
|
|
|
|
def ttd(ten: T.Tensor, rank: Sequence[int], noiters: int = 1000,
        method: str = 'tsvd') -> Sequence[T.Tensor]:
    """Decompose a tensor into tensor-train (TT) cores by sequential SVD.

    The tensor is swept from its first mode to its last. At each mode the
    current remainder is unfolded into a matrix whose rows index
    ``(left bond, mode)``. That matrix is factorized to the requested bond
    rank, and the left factor scaled by the singular values becomes the
    core. The transposed right factor is carried on as the new remainder.
    Whatever remains after the last factorization is reshaped directly
    into the final core, so no factor is discarded.

    Args:
        ten: Tensor to decompose; ``d = ten.ndim`` is the number of cores.
        rank: TT ranks ``(r_0, r_1, ..., r_d)``. Must have length
            ``d + 1``, positive entries, and ``r_0 == r_d == 1``.
        noiters: Subspace iterations, passed as ``niter`` to
            :func:`torch.svd_lowrank`; only used when ``method == 'tsvd'``.
        method: ``'svd'`` to factorize with ``svd_truncated`` (truncated
            :func:`torch.linalg.svd`) or ``'tsvd'`` to use the randomized
            :func:`torch.svd_lowrank`.

    Returns:
        List of ``d`` cores. Core ``k`` has shape ``(r_k, n_k, r_{k+1})``
        with ``n_k = ten.shape[k]``. A 0-d tensor yields an empty list.

    Raises:
        ValueError: If ``rank`` is malformed, ``method`` is unknown, or a
            bond rank exceeds the smaller dimension of its unfolding.
            Such a rank cannot be produced by the factorization.
    """
    if ten.ndim + 1 != len(rank):
        raise ValueError(f'Expected {ten.ndim + 1} ranks for a '
                         f'{ten.ndim}-d tensor, got {len(rank)}.')
    if rank[0] != 1 or rank[-1] != 1:
        raise ValueError(f'Boundary ranks must equal 1, got '
                         f'rank[0]={rank[0]} and rank[-1]={rank[-1]}.')
    if any(r < 1 for r in rank):
        raise ValueError(f'All ranks must be positive, got {list(rank)}.')

    if method == 'svd':
        factorize = svd_truncated
    elif method == 'tsvd':
        factorize = partial(T.svd_lowrank, niter=noiters)
    else:
        raise ValueError(f'Unknown method: {method}.')

    cores = []
    shape = ten.shape
    if not shape:
        # A 0-d tensor has no modes, hence no cores.
        return cores

    # Factorize every mode but the last; the last core is what remains.
    for r_left, size, r_right in zip(rank[:-2], shape[:-1], rank[1:-1]):
        # Unfold: rows index (left bond, current mode), columns the rest.
        mat = ten.reshape(r_left * size, -1)
        if r_right > min(mat.shape):
            raise ValueError(
                f'Rank {r_right} exceeds the maximum rank {min(mat.shape)} '
                f'of the {tuple(mat.shape)} unfolding.')
        lvecs, svals, rvecs = factorize(mat, r_right)
        # Absorb the singular values into the left factor (the core).
        core = lvecs * svals[None, :]
        cores.append(core.reshape(r_left, size, r_right))
        # New remainder: (r_right, product of the remaining mode sizes).
        ten = rvecs.T

    # The remainder has shape (rank[-2], shape[-1]). Use it verbatim as the
    # final core. Re-factorizing this single column would drop its 1x1
    # (+/-1) right singular factor and could flip the reconstruction's sign.
    cores.append(ten.reshape(rank[-2], shape[-1], rank[-1]))
    return cores
|
|