| import primefac |
| import scipy |
| import math |
|
|
|
|
| def butterfly_factors(n): |
| pf = list(primefac.primefac(n)) |
| return (math.prod(pf[0::2]), math.prod(pf[1::2])) |
|
|
| def gen_rand_orthos(m,p): |
| if (p != 2): |
| return torch.tensor(scipy.stats.special_ortho_group.rvs(p, size=m)).to(torch.float32) |
| X = torch.zeros(m,2,2) |
| t = torch.rand(m) * (2 * math.pi) |
| sin_t = torch.sin(t) |
| cos_t = torch.cos(t) |
| X[:,0,0] = cos_t |
| X[:,1,1] = cos_t |
| X[:,0,1] = sin_t |
| X[:,1,0] = -sin_t |
| return X |
|
|
| |
| def gen_rand_ortho_butterfly_noblock(n): |
| return ([gen_rand_orthos(1, p) for p in butterfly_factors(n)], torch.randperm(n), torch.randperm(n)) |
|
|
| |
| def mul_ortho_butterfly(Bpp, x): |
| (B, p_in, p_out) = Bpp |
| assert((len(x.shape) == 1) or (len(x.shape) == 2)) |
| orig_dim = 2 |
| if (len(x.shape) == 1): |
| (n,) = x.shape |
| x = x.reshape(n,1) |
| orig_dim = 1 |
| (n,q) = x.shape |
| x = x[p_in,:] |
| pfn = tuple(butterfly_factors(n)) |
| for i in range(len(pfn)): |
| mpfx = math.prod(pfn[0:i]) |
| p = pfn[i] |
| msfx = math.prod(pfn[(i+1):]) |
| x = x.reshape(mpfx, p, msfx, q).permute(0,2,1,3).reshape(mpfx * msfx, p, q) |
| x = B[i] @ x |
| x = x.reshape(mpfx, msfx, p, q).permute(0,2,1,3).reshape(n,q) |
| x = x[p_out,:] |
| if (orig_dim == 1): |
| x = x.reshape(n) |
| return x |
|
|
| |
| |
| def rand_ortho_butterfly_noblock(n): |
| return mul_ortho_butterfly(gen_rand_ortho_butterfly_noblock(n), torch.eye(n)) |